From c296204541f8a3727b2b41ce1db0323b75d03757 Mon Sep 17 00:00:00 2001 From: jwkim Date: Fri, 17 Jan 2025 13:26:19 +0900 Subject: [PATCH] =?UTF-8?q?edit=20:=20=EC=9D=B4=EB=AF=B8=EC=A7=80=20?= =?UTF-8?q?=EC=83=9D=EC=84=B1=EC=8B=9C=20=EC=83=9D=EC=84=B1=20=EA=B0=AF?= =?UTF-8?q?=EC=88=98=20=EC=A7=80=EC=A0=95=20=EA=B0=80=EB=8A=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + custom_apps/dalle3/BingImageCreator.py | 7 ++++--- custom_apps/imagen/custom_imagen.py | 5 +++-- rest/app/models.py | 2 +- rest/app/routes/services.py | 19 +++++++++++++++---- rest/app/utils/parsing_utils.py | 13 ++++++++++++- 6 files changed, 36 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 90e61bd..30a79af 100644 --- a/.gitignore +++ b/.gitignore @@ -144,3 +144,4 @@ cython_debug/ log sheet_counter.txt +*output \ No newline at end of file diff --git a/custom_apps/dalle3/BingImageCreator.py b/custom_apps/dalle3/BingImageCreator.py index e75442c..8d96ed3 100644 --- a/custom_apps/dalle3/BingImageCreator.py +++ b/custom_apps/dalle3/BingImageCreator.py @@ -196,12 +196,13 @@ class ImageGenAsync: os.mkdir(output_dir) try: model = "dalle3" - jpeg_index = 0 + jpeg_index = 1 parsing_file_name = prompt_to_filenames(prompt) + _datetime = D.date_file_name() for link in links[:download_count]: if download_count == 1: - _path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{D.date_file_name()}.jpg") + _path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{_datetime}.jpg") if os.path.exists(_path): raise Exception("파일 이미 존재함") @@ -216,7 +217,7 @@ class ImageGenAsync: jpeg_index += 1 else: - _path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{D.date_file_name()}.jpg") + _path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{_datetime}.jpg") if os.path.exists(_path): raise Exception("파일 이미 존재함") diff --git a/custom_apps/imagen/custom_imagen.py b/custom_apps/imagen/custom_imagen.py index ddc27e8..84b2ca9 100644 --- a/custom_apps/imagen/custom_imagen.py +++ b/custom_apps/imagen/custom_imagen.py @@ -22,6 +22,7 @@ def imagen_generate_image(prompt,download_count=1): _file_name = prompt_to_filenames(prompt) _folder = os.path.join(OUTPUT_FOLDER,"imagen") + _datetime = D.date_file_name() if not os.path.isdir(_folder): raise FileExistsError(f"FileExistsError: {_folder}") @@ -40,7 +41,7 @@ def imagen_generate_image(prompt,download_count=1): ) if len(images.images) <=1 : - images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{D.date_file_name()}.png"), include_generation_parameters=False) + images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{_datetime}.png"), include_generation_parameters=False) else: for i in range(len(images.images)): - images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i}_{D.date_file_name()}.png"), include_generation_parameters=False) + images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False) diff --git a/rest/app/models.py b/rest/app/models.py index add73f4..54dae3c 100644 --- a/rest/app/models.py +++ b/rest/app/models.py @@ -571,4 +571,4 @@ class ImageGenerateReq(BaseModel): ### [Request] image generate request """ prompt : str = Field(description='프롬프트', example='검은색 안경') - # downloadCount : int = Field(1, description='이미지 생성 갯수', example=1) \ No newline at end of file + downloadCount : int = Field(1, description='이미지 생성 갯수', example=1) \ No newline at end of file diff --git a/rest/app/routes/services.py b/rest/app/routes/services.py index 6ed9e0a..1811f58 100644 --- a/rest/app/routes/services.py +++ b/rest/app/routes/services.py @@ -21,6 +21,7 @@ from custom_logger.custom_log import custom_logger as LOG from custom_apps.dalle3.custom_dalle import DallEArgument,dalle3_generate_image from custom_apps.imagen.custom_imagen import imagen_generate_image +from rest.app.utils.parsing_utils import download_range router = APIRouter(prefix="/services") @@ -39,9 +40,14 @@ async def dalle3(request: Request, request_body_info: M.ImageGenerateReq): """ response = M.ResponseBase() try: - args = DallEArgument(prompt=request_body_info.prompt - # , download_count=request_body_info.downloadCount - ) + if not download_range(request_body_info.downloadCount): + raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})") + + + args = DallEArgument( + prompt=request_body_info.prompt, + download_count=request_body_info.downloadCount + ) dalle3_generate_image(args) @@ -66,7 +72,12 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq): response = M.ResponseBase() try: - imagen_generate_image(request_body_info.prompt) + if not download_range(request_body_info.downloadCount): + raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})") + + imagen_generate_image(prompt=request_body_info.prompt, + download_count=request_body_info.downloadCount + ) return response.set_message() diff --git a/rest/app/utils/parsing_utils.py b/rest/app/utils/parsing_utils.py index fa78c0b..cd2015e 100644 --- a/rest/app/utils/parsing_utils.py +++ b/rest/app/utils/parsing_utils.py @@ -11,4 +11,15 @@ def prompt_to_filenames(prompt): else: filename += i - return filename \ No newline at end of file + return filename + + +def download_range(download_count:int): + _min = 1 + _max = 4 + + if _min <= download_count and download_count <= _max: + return True + + return False + \ No newline at end of file