diff --git a/custom_apps/dalle3/BingImageCreator.py b/custom_apps/dalle3/BingImageCreator.py index 8d96ed3..1d7833e 100644 --- a/custom_apps/dalle3/BingImageCreator.py +++ b/custom_apps/dalle3/BingImageCreator.py @@ -181,6 +181,7 @@ class ImageGenAsync: links: list, output_dir: str, download_count: int, + counter, file_name: str = None, prompt: str = None ) -> None: @@ -202,7 +203,8 @@ class ImageGenAsync: for link in links[:download_count]: if download_count == 1: - _path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{_datetime}.jpg") + _file_name = f"{model}_{parsing_file_name}_{_datetime}.png" + _path = os.path.join(output_dir, _file_name) if os.path.exists(_path): raise Exception("파일 이미 존재함") @@ -211,26 +213,43 @@ class ImageGenAsync: response = await self.session.get(link) if response.status_code != 200: raise Exception("Could not download image") - # save response to file - with open(_path,"wb") as output_file: - output_file.write(response.content) + + # 이미지 파일 생성 x(error) + if 'svg viewBox' in str(response.content): + counter.add_error_messages(_file_name) + + # 이미지 파일 생성 + else: + # save response to file + with open(_path,"wb") as output_file: + output_file.write(response.content) + counter.count() + jpeg_index += 1 else: - _path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{_datetime}.jpg") + _file_name = f"{model}_{parsing_file_name}_{jpeg_index}_{_datetime}.png" + _path = os.path.join(output_dir, _file_name) if os.path.exists(_path): raise Exception("파일 이미 존재함") - - while os.path.exists(_path): - jpeg_index += 1 + response = await self.session.get(link) if response.status_code != 200: raise Exception("Could not download image") - # save response to file - with open(_path,"wb") as output_file: - output_file.write(response.content) - jpeg_index += 1 + # 이미지 파일 생성 x(error) + if 'svg viewBox' in str(response.content) or 'var' in str(response.content): + counter.add_error_messages(_file_name) + + # 이미지 파일 생성 + else: + # save response to file + with open(_path,"wb") as output_file: + output_file.write(response.content) + counter.count() + + jpeg_index += 1 + except httpx.InvalidURL as url_exception: raise Exception( "Inappropriate contents found in the generated images. Please try again or try another prompt.", @@ -241,18 +260,12 @@ async def async_image_gen( prompt: str, download_count: int, output_dir: str, + counter, u_cookie=None, debug_file=None, quiet=False, all_cookies=None, ): - async with ImageGenAsync( - u_cookie, - debug_file=debug_file, - quiet=quiet, - all_cookies=all_cookies, - ) as image_generator: + async with ImageGenAsync(u_cookie,debug_file=debug_file,quiet=quiet,all_cookies=all_cookies,) as image_generator: images = await image_generator.get_images(prompt) - await image_generator.save_images( - images, output_dir=output_dir, download_count=download_count,prompt=prompt - ) \ No newline at end of file + await image_generator.save_images(links=images, counter=counter, output_dir=output_dir, download_count=download_count,prompt=prompt) \ No newline at end of file diff --git a/custom_apps/dalle3/custom_dalle.py b/custom_apps/dalle3/custom_dalle.py index 2df22f4..6b3b4e4 100644 --- a/custom_apps/dalle3/custom_dalle.py +++ b/custom_apps/dalle3/custom_dalle.py @@ -47,12 +47,40 @@ class DallEArgument: quiet: bool = False asyncio: bool = True version: bool = False + +class Counter: + """ + 생성된 이미지 카운트 + """ + + def __init__(self): + self.counter = 0 + self.error_messages = [] + + def reset(self): + self.counter = 0 + self.error_messages = [] + + def count(self): + self.counter += 1 + + def get_counter(self): + return self.counter + + def add_error_messages(self,error): + self.error_messages.append(error) + + def get_error_messages(self): + return self.error_messages + def dalle3_generate_image(args): nest_asyncio.apply() + counter = Counter() + if not os.path.isdir(args.output_dir): raise FileExistsError(f"FileExistsError: {args.output_dir}") @@ -76,13 +104,16 @@ def dalle3_generate_image(args): asyncio.run( async_image_gen( - args.prompt, - args.download_count, - args.output_dir, - args.U, - args.debug_file, - args.quiet, + prompt=args.prompt, + download_count=args.download_count, + output_dir=args.output_dir, + counter=counter, + u_cookie=args.U, + debug_file=args.debug_file, + quiet=args.quiet, all_cookies=cookie_json, ), ) + + return counter diff --git a/custom_apps/imagen/custom_imagen.py b/custom_apps/imagen/custom_imagen.py index 84b2ca9..d449769 100644 --- a/custom_apps/imagen/custom_imagen.py +++ b/custom_apps/imagen/custom_imagen.py @@ -42,6 +42,10 @@ def imagen_generate_image(prompt,download_count=1): if len(images.images) <=1 : images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{_datetime}.png"), include_generation_parameters=False) + return 1 + else: for i in range(len(images.images)): images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False) + + return len(images.images) \ No newline at end of file diff --git a/rest/app/models.py b/rest/app/models.py index 54dae3c..a108a82 100644 --- a/rest/app/models.py +++ b/rest/app/models.py @@ -571,4 +571,31 @@ class ImageGenerateReq(BaseModel): ### [Request] image generate request """ prompt : str = Field(description='프롬프트', example='검은색 안경') - downloadCount : int = Field(1, description='이미지 생성 갯수', example=1) \ No newline at end of file + downloadCount : int = Field(1, description='이미지 생성 갯수', example=1) + + +#=============================================================================== +#=============================================================================== +#=============================================================================== +#=============================================================================== + +class ImageGenerateRes(ResponseBase): + """ + ### image generate response + """ + imageLen : int = Field(0, description='실제 이미지 생성 갯수', example=1) + + @staticmethod + def set_error(error,img_len=0): + ImageGenerateRes.result = False + ImageGenerateRes.error = str(error) + ImageGenerateRes.imageLen = img_len + + return ImageGenerateRes + + @staticmethod + def set_message(img_len): + ImageGenerateRes.result = True + ImageGenerateRes.error = None + ImageGenerateRes.imageLen = img_len + return ImageGenerateRes \ No newline at end of file diff --git a/rest/app/routes/services.py b/rest/app/routes/services.py index 1811f58..45e3285 100644 --- a/rest/app/routes/services.py +++ b/rest/app/routes/services.py @@ -27,7 +27,7 @@ from rest.app.utils.parsing_utils import download_range router = APIRouter(prefix="/services") -@router.post("/imageGenerate/dalle3", summary="이미지 생성(AI) - DALL-E 3", response_model=M.ResponseBase) +@router.post("/imageGenerate/dalle3", summary="이미지 생성(AI) - DALL-E 3", response_model=M.ImageGenerateRes) async def dalle3(request: Request, request_body_info: M.ImageGenerateReq): """ ## 이미지 생성(AI) - DALL-E 3 @@ -37,8 +37,11 @@ async def dalle3(request: Request, request_body_info: M.ImageGenerateReq): > - 쿠키 정보 설정(https://github.com/acheong08/BingImageCreator) 추후 set api 추가 예정 -> 현재 고정값 > - const.py 에 지정한 OUTPUT_FOLDER 하위에 dalle 폴더가 있어야함. + ## 정보 + > 오류 발생시 오류 발생한 파일은 에러 메세지에만 남기고 저장은 안함 + """ - response = M.ResponseBase() + response = M.ImageGenerateRes() try: if not download_range(request_body_info.downloadCount): raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})") @@ -49,16 +52,21 @@ async def dalle3(request: Request, request_body_info: M.ImageGenerateReq): download_count=request_body_info.downloadCount ) - dalle3_generate_image(args) + info = dalle3_generate_image(args) - return response.set_message() - + if info.get_error_messages(): + error_message = f"파일생성 error: {info.get_error_messages()}" + LOG.error(error_message) + return response.set_error(error=error_message, img_len=info.get_counter()) + + return response.set_message(img_len=info.get_counter()) + except Exception as e: LOG.error(traceback.format_exc()) return response.set_error(e) -@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ResponseBase) +@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes) async def imagen(request: Request, request_body_info: M.ImageGenerateReq): """ ## 이미지 생성(AI) - imagen @@ -69,18 +77,18 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq): > - const.py 에 지정한 OUTPUT_FOLDER 하위에 imagen 폴더가 있어야함. """ - response = M.ResponseBase() + response = M.ImageGenerateRes() try: if not download_range(request_body_info.downloadCount): raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})") - imagen_generate_image(prompt=request_body_info.prompt, + img_length = imagen_generate_image(prompt=request_body_info.prompt, download_count=request_body_info.downloadCount ) - return response.set_message() + return response.set_message(img_len=img_length) except Exception as e: LOG.error(traceback.format_exc()) - return response.set_error(e) \ No newline at end of file + return response.set_error(error=e) \ No newline at end of file