edit : 실제 이미지 생성한 갯수 response 에 추가

This commit is contained in:
2025-01-17 16:06:01 +09:00
parent c296204541
commit 6e5f990fc1
5 changed files with 121 additions and 38 deletions

View File

@@ -181,6 +181,7 @@ class ImageGenAsync:
links: list, links: list,
output_dir: str, output_dir: str,
download_count: int, download_count: int,
counter,
file_name: str = None, file_name: str = None,
prompt: str = None prompt: str = None
) -> None: ) -> None:
@@ -202,7 +203,8 @@ class ImageGenAsync:
for link in links[:download_count]: for link in links[:download_count]:
if download_count == 1: if download_count == 1:
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{_datetime}.jpg") _file_name = f"{model}_{parsing_file_name}_{_datetime}.png"
_path = os.path.join(output_dir, _file_name)
if os.path.exists(_path): if os.path.exists(_path):
raise Exception("파일 이미 존재함") raise Exception("파일 이미 존재함")
@@ -211,26 +213,43 @@ class ImageGenAsync:
response = await self.session.get(link) response = await self.session.get(link)
if response.status_code != 200: if response.status_code != 200:
raise Exception("Could not download image") raise Exception("Could not download image")
# save response to file
with open(_path,"wb") as output_file: # 이미지 파일 생성 x(error)
output_file.write(response.content) if 'svg viewBox' in str(response.content):
counter.add_error_messages(_file_name)
# 이미지 파일 생성
else:
# save response to file
with open(_path,"wb") as output_file:
output_file.write(response.content)
counter.count()
jpeg_index += 1 jpeg_index += 1
else: else:
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{_datetime}.jpg") _file_name = f"{model}_{parsing_file_name}_{jpeg_index}_{_datetime}.png"
_path = os.path.join(output_dir, _file_name)
if os.path.exists(_path): if os.path.exists(_path):
raise Exception("파일 이미 존재함") raise Exception("파일 이미 존재함")
while os.path.exists(_path):
jpeg_index += 1
response = await self.session.get(link) response = await self.session.get(link)
if response.status_code != 200: if response.status_code != 200:
raise Exception("Could not download image") raise Exception("Could not download image")
# save response to file
with open(_path,"wb") as output_file:
output_file.write(response.content)
jpeg_index += 1
# 이미지 파일 생성 x(error)
if 'svg viewBox' in str(response.content) or 'var' in str(response.content):
counter.add_error_messages(_file_name)
# 이미지 파일 생성
else:
# save response to file
with open(_path,"wb") as output_file:
output_file.write(response.content)
counter.count()
jpeg_index += 1
except httpx.InvalidURL as url_exception: except httpx.InvalidURL as url_exception:
raise Exception( raise Exception(
"Inappropriate contents found in the generated images. Please try again or try another prompt.", "Inappropriate contents found in the generated images. Please try again or try another prompt.",
@@ -241,18 +260,12 @@ async def async_image_gen(
prompt: str, prompt: str,
download_count: int, download_count: int,
output_dir: str, output_dir: str,
counter,
u_cookie=None, u_cookie=None,
debug_file=None, debug_file=None,
quiet=False, quiet=False,
all_cookies=None, all_cookies=None,
): ):
async with ImageGenAsync( async with ImageGenAsync(u_cookie,debug_file=debug_file,quiet=quiet,all_cookies=all_cookies,) as image_generator:
u_cookie,
debug_file=debug_file,
quiet=quiet,
all_cookies=all_cookies,
) as image_generator:
images = await image_generator.get_images(prompt) images = await image_generator.get_images(prompt)
await image_generator.save_images( await image_generator.save_images(links=images, counter=counter, output_dir=output_dir, download_count=download_count,prompt=prompt)
images, output_dir=output_dir, download_count=download_count,prompt=prompt
)

View File

@@ -47,12 +47,40 @@ class DallEArgument:
quiet: bool = False quiet: bool = False
asyncio: bool = True asyncio: bool = True
version: bool = False version: bool = False
class Counter:
"""
생성된 이미지 카운트
"""
def __init__(self):
self.counter = 0
self.error_messages = []
def reset(self):
self.counter = 0
self.error_messages = []
def count(self):
self.counter += 1
def get_counter(self):
return self.counter
def add_error_messages(self,error):
self.error_messages.append(error)
def get_error_messages(self):
return self.error_messages
def dalle3_generate_image(args): def dalle3_generate_image(args):
nest_asyncio.apply() nest_asyncio.apply()
counter = Counter()
if not os.path.isdir(args.output_dir): if not os.path.isdir(args.output_dir):
raise FileExistsError(f"FileExistsError: {args.output_dir}") raise FileExistsError(f"FileExistsError: {args.output_dir}")
@@ -76,13 +104,16 @@ def dalle3_generate_image(args):
asyncio.run( asyncio.run(
async_image_gen( async_image_gen(
args.prompt, prompt=args.prompt,
args.download_count, download_count=args.download_count,
args.output_dir, output_dir=args.output_dir,
args.U, counter=counter,
args.debug_file, u_cookie=args.U,
args.quiet, debug_file=args.debug_file,
quiet=args.quiet,
all_cookies=cookie_json, all_cookies=cookie_json,
), ),
) )
return counter

View File

@@ -42,6 +42,10 @@ def imagen_generate_image(prompt,download_count=1):
if len(images.images) <=1 : if len(images.images) <=1 :
images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{_datetime}.png"), include_generation_parameters=False) images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{_datetime}.png"), include_generation_parameters=False)
return 1
else: else:
for i in range(len(images.images)): for i in range(len(images.images)):
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False) images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False)
return len(images.images)

View File

@@ -571,4 +571,31 @@ class ImageGenerateReq(BaseModel):
### [Request] image generate request ### [Request] image generate request
""" """
prompt : str = Field(description='프롬프트', example='검은색 안경') prompt : str = Field(description='프롬프트', example='검은색 안경')
downloadCount : int = Field(1, description='이미지 생성 갯수', example=1) downloadCount : int = Field(1, description='이미지 생성 갯수', example=1)
#===============================================================================
#===============================================================================
#===============================================================================
#===============================================================================
class ImageGenerateRes(ResponseBase):
"""
### image generate response
"""
imageLen : int = Field(0, description='실제 이미지 생성 갯수', example=1)
@staticmethod
def set_error(error,img_len=0):
ImageGenerateRes.result = False
ImageGenerateRes.error = str(error)
ImageGenerateRes.imageLen = img_len
return ImageGenerateRes
@staticmethod
def set_message(img_len):
ImageGenerateRes.result = True
ImageGenerateRes.error = None
ImageGenerateRes.imageLen = img_len
return ImageGenerateRes

View File

@@ -27,7 +27,7 @@ from rest.app.utils.parsing_utils import download_range
router = APIRouter(prefix="/services") router = APIRouter(prefix="/services")
@router.post("/imageGenerate/dalle3", summary="이미지 생성(AI) - DALL-E 3", response_model=M.ResponseBase) @router.post("/imageGenerate/dalle3", summary="이미지 생성(AI) - DALL-E 3", response_model=M.ImageGenerateRes)
async def dalle3(request: Request, request_body_info: M.ImageGenerateReq): async def dalle3(request: Request, request_body_info: M.ImageGenerateReq):
""" """
## 이미지 생성(AI) - DALL-E 3 ## 이미지 생성(AI) - DALL-E 3
@@ -37,8 +37,11 @@ async def dalle3(request: Request, request_body_info: M.ImageGenerateReq):
> - 쿠키 정보 설정(https://github.com/acheong08/BingImageCreator) 추후 set api 추가 예정 -> 현재 고정값 > - 쿠키 정보 설정(https://github.com/acheong08/BingImageCreator) 추후 set api 추가 예정 -> 현재 고정값
> - const.py 에 지정한 OUTPUT_FOLDER 하위에 dalle 폴더가 있어야함. > - const.py 에 지정한 OUTPUT_FOLDER 하위에 dalle 폴더가 있어야함.
## 정보
> 오류 발생시 오류 발생한 파일은 에러 메세지에만 남기고 저장은 안함
""" """
response = M.ResponseBase() response = M.ImageGenerateRes()
try: try:
if not download_range(request_body_info.downloadCount): if not download_range(request_body_info.downloadCount):
raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})") raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})")
@@ -49,16 +52,21 @@ async def dalle3(request: Request, request_body_info: M.ImageGenerateReq):
download_count=request_body_info.downloadCount download_count=request_body_info.downloadCount
) )
dalle3_generate_image(args) info = dalle3_generate_image(args)
return response.set_message() if info.get_error_messages():
error_message = f"파일생성 error: {info.get_error_messages()}"
LOG.error(error_message)
return response.set_error(error=error_message, img_len=info.get_counter())
return response.set_message(img_len=info.get_counter())
except Exception as e: except Exception as e:
LOG.error(traceback.format_exc()) LOG.error(traceback.format_exc())
return response.set_error(e) return response.set_error(e)
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ResponseBase) @router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes)
async def imagen(request: Request, request_body_info: M.ImageGenerateReq): async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
""" """
## 이미지 생성(AI) - imagen ## 이미지 생성(AI) - imagen
@@ -69,18 +77,18 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
> - const.py 에 지정한 OUTPUT_FOLDER 하위에 imagen 폴더가 있어야함. > - const.py 에 지정한 OUTPUT_FOLDER 하위에 imagen 폴더가 있어야함.
""" """
response = M.ResponseBase() response = M.ImageGenerateRes()
try: try:
if not download_range(request_body_info.downloadCount): if not download_range(request_body_info.downloadCount):
raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})") raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})")
imagen_generate_image(prompt=request_body_info.prompt, img_length = imagen_generate_image(prompt=request_body_info.prompt,
download_count=request_body_info.downloadCount download_count=request_body_info.downloadCount
) )
return response.set_message() return response.set_message(img_len=img_length)
except Exception as e: except Exception as e:
LOG.error(traceback.format_exc()) LOG.error(traceback.format_exc())
return response.set_error(e) return response.set_error(error=e)