edit : 실제 이미지 생성한 갯수 response 에 추가
This commit is contained in:
@@ -181,6 +181,7 @@ class ImageGenAsync:
|
|||||||
links: list,
|
links: list,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
download_count: int,
|
download_count: int,
|
||||||
|
counter,
|
||||||
file_name: str = None,
|
file_name: str = None,
|
||||||
prompt: str = None
|
prompt: str = None
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -202,7 +203,8 @@ class ImageGenAsync:
|
|||||||
|
|
||||||
for link in links[:download_count]:
|
for link in links[:download_count]:
|
||||||
if download_count == 1:
|
if download_count == 1:
|
||||||
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{_datetime}.jpg")
|
_file_name = f"{model}_{parsing_file_name}_{_datetime}.png"
|
||||||
|
_path = os.path.join(output_dir, _file_name)
|
||||||
if os.path.exists(_path):
|
if os.path.exists(_path):
|
||||||
raise Exception("파일 이미 존재함")
|
raise Exception("파일 이미 존재함")
|
||||||
|
|
||||||
@@ -211,24 +213,41 @@ class ImageGenAsync:
|
|||||||
response = await self.session.get(link)
|
response = await self.session.get(link)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise Exception("Could not download image")
|
raise Exception("Could not download image")
|
||||||
# save response to file
|
|
||||||
with open(_path,"wb") as output_file:
|
# 이미지 파일 생성 x(error)
|
||||||
output_file.write(response.content)
|
if 'svg viewBox' in str(response.content):
|
||||||
|
counter.add_error_messages(_file_name)
|
||||||
|
|
||||||
|
# 이미지 파일 생성
|
||||||
|
else:
|
||||||
|
# save response to file
|
||||||
|
with open(_path,"wb") as output_file:
|
||||||
|
output_file.write(response.content)
|
||||||
|
counter.count()
|
||||||
|
|
||||||
jpeg_index += 1
|
jpeg_index += 1
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{_datetime}.jpg")
|
_file_name = f"{model}_{parsing_file_name}_{jpeg_index}_{_datetime}.png"
|
||||||
|
_path = os.path.join(output_dir, _file_name)
|
||||||
if os.path.exists(_path):
|
if os.path.exists(_path):
|
||||||
raise Exception("파일 이미 존재함")
|
raise Exception("파일 이미 존재함")
|
||||||
|
|
||||||
while os.path.exists(_path):
|
|
||||||
jpeg_index += 1
|
|
||||||
response = await self.session.get(link)
|
response = await self.session.get(link)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise Exception("Could not download image")
|
raise Exception("Could not download image")
|
||||||
# save response to file
|
|
||||||
with open(_path,"wb") as output_file:
|
# 이미지 파일 생성 x(error)
|
||||||
output_file.write(response.content)
|
if 'svg viewBox' in str(response.content) or 'var' in str(response.content):
|
||||||
|
counter.add_error_messages(_file_name)
|
||||||
|
|
||||||
|
# 이미지 파일 생성
|
||||||
|
else:
|
||||||
|
# save response to file
|
||||||
|
with open(_path,"wb") as output_file:
|
||||||
|
output_file.write(response.content)
|
||||||
|
counter.count()
|
||||||
|
|
||||||
jpeg_index += 1
|
jpeg_index += 1
|
||||||
|
|
||||||
except httpx.InvalidURL as url_exception:
|
except httpx.InvalidURL as url_exception:
|
||||||
@@ -241,18 +260,12 @@ async def async_image_gen(
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
download_count: int,
|
download_count: int,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
|
counter,
|
||||||
u_cookie=None,
|
u_cookie=None,
|
||||||
debug_file=None,
|
debug_file=None,
|
||||||
quiet=False,
|
quiet=False,
|
||||||
all_cookies=None,
|
all_cookies=None,
|
||||||
):
|
):
|
||||||
async with ImageGenAsync(
|
async with ImageGenAsync(u_cookie,debug_file=debug_file,quiet=quiet,all_cookies=all_cookies,) as image_generator:
|
||||||
u_cookie,
|
|
||||||
debug_file=debug_file,
|
|
||||||
quiet=quiet,
|
|
||||||
all_cookies=all_cookies,
|
|
||||||
) as image_generator:
|
|
||||||
images = await image_generator.get_images(prompt)
|
images = await image_generator.get_images(prompt)
|
||||||
await image_generator.save_images(
|
await image_generator.save_images(links=images, counter=counter, output_dir=output_dir, download_count=download_count,prompt=prompt)
|
||||||
images, output_dir=output_dir, download_count=download_count,prompt=prompt
|
|
||||||
)
|
|
||||||
@@ -49,10 +49,38 @@ class DallEArgument:
|
|||||||
version: bool = False
|
version: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class Counter:
|
||||||
|
"""
|
||||||
|
생성된 이미지 카운트
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.counter = 0
|
||||||
|
self.error_messages = []
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.counter = 0
|
||||||
|
self.error_messages = []
|
||||||
|
|
||||||
|
def count(self):
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
def get_counter(self):
|
||||||
|
return self.counter
|
||||||
|
|
||||||
|
def add_error_messages(self,error):
|
||||||
|
self.error_messages.append(error)
|
||||||
|
|
||||||
|
def get_error_messages(self):
|
||||||
|
return self.error_messages
|
||||||
|
|
||||||
|
|
||||||
def dalle3_generate_image(args):
|
def dalle3_generate_image(args):
|
||||||
|
|
||||||
nest_asyncio.apply()
|
nest_asyncio.apply()
|
||||||
|
|
||||||
|
counter = Counter()
|
||||||
|
|
||||||
if not os.path.isdir(args.output_dir):
|
if not os.path.isdir(args.output_dir):
|
||||||
raise FileExistsError(f"FileExistsError: {args.output_dir}")
|
raise FileExistsError(f"FileExistsError: {args.output_dir}")
|
||||||
|
|
||||||
@@ -76,13 +104,16 @@ def dalle3_generate_image(args):
|
|||||||
|
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
async_image_gen(
|
async_image_gen(
|
||||||
args.prompt,
|
prompt=args.prompt,
|
||||||
args.download_count,
|
download_count=args.download_count,
|
||||||
args.output_dir,
|
output_dir=args.output_dir,
|
||||||
args.U,
|
counter=counter,
|
||||||
args.debug_file,
|
u_cookie=args.U,
|
||||||
args.quiet,
|
debug_file=args.debug_file,
|
||||||
|
quiet=args.quiet,
|
||||||
all_cookies=cookie_json,
|
all_cookies=cookie_json,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return counter
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ def imagen_generate_image(prompt,download_count=1):
|
|||||||
|
|
||||||
if len(images.images) <=1 :
|
if len(images.images) <=1 :
|
||||||
images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{_datetime}.png"), include_generation_parameters=False)
|
images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{_datetime}.png"), include_generation_parameters=False)
|
||||||
|
return 1
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for i in range(len(images.images)):
|
for i in range(len(images.images)):
|
||||||
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False)
|
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False)
|
||||||
|
|
||||||
|
return len(images.images)
|
||||||
@@ -572,3 +572,30 @@ class ImageGenerateReq(BaseModel):
|
|||||||
"""
|
"""
|
||||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||||
downloadCount : int = Field(1, description='이미지 생성 갯수', example=1)
|
downloadCount : int = Field(1, description='이미지 생성 갯수', example=1)
|
||||||
|
|
||||||
|
|
||||||
|
#===============================================================================
|
||||||
|
#===============================================================================
|
||||||
|
#===============================================================================
|
||||||
|
#===============================================================================
|
||||||
|
|
||||||
|
class ImageGenerateRes(ResponseBase):
|
||||||
|
"""
|
||||||
|
### image generate response
|
||||||
|
"""
|
||||||
|
imageLen : int = Field(0, description='실제 이미지 생성 갯수', example=1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_error(error,img_len=0):
|
||||||
|
ImageGenerateRes.result = False
|
||||||
|
ImageGenerateRes.error = str(error)
|
||||||
|
ImageGenerateRes.imageLen = img_len
|
||||||
|
|
||||||
|
return ImageGenerateRes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_message(img_len):
|
||||||
|
ImageGenerateRes.result = True
|
||||||
|
ImageGenerateRes.error = None
|
||||||
|
ImageGenerateRes.imageLen = img_len
|
||||||
|
return ImageGenerateRes
|
||||||
@@ -27,7 +27,7 @@ from rest.app.utils.parsing_utils import download_range
|
|||||||
router = APIRouter(prefix="/services")
|
router = APIRouter(prefix="/services")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/imageGenerate/dalle3", summary="이미지 생성(AI) - DALL-E 3", response_model=M.ResponseBase)
|
@router.post("/imageGenerate/dalle3", summary="이미지 생성(AI) - DALL-E 3", response_model=M.ImageGenerateRes)
|
||||||
async def dalle3(request: Request, request_body_info: M.ImageGenerateReq):
|
async def dalle3(request: Request, request_body_info: M.ImageGenerateReq):
|
||||||
"""
|
"""
|
||||||
## 이미지 생성(AI) - DALL-E 3
|
## 이미지 생성(AI) - DALL-E 3
|
||||||
@@ -37,8 +37,11 @@ async def dalle3(request: Request, request_body_info: M.ImageGenerateReq):
|
|||||||
> - 쿠키 정보 설정(https://github.com/acheong08/BingImageCreator) 추후 set api 추가 예정 -> 현재 고정값
|
> - 쿠키 정보 설정(https://github.com/acheong08/BingImageCreator) 추후 set api 추가 예정 -> 현재 고정값
|
||||||
> - const.py 에 지정한 OUTPUT_FOLDER 하위에 dalle 폴더가 있어야함.
|
> - const.py 에 지정한 OUTPUT_FOLDER 하위에 dalle 폴더가 있어야함.
|
||||||
|
|
||||||
|
## 정보
|
||||||
|
> 오류 발생시 오류 발생한 파일은 에러 메세지에만 남기고 저장은 안함
|
||||||
|
|
||||||
"""
|
"""
|
||||||
response = M.ResponseBase()
|
response = M.ImageGenerateRes()
|
||||||
try:
|
try:
|
||||||
if not download_range(request_body_info.downloadCount):
|
if not download_range(request_body_info.downloadCount):
|
||||||
raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})")
|
raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})")
|
||||||
@@ -49,16 +52,21 @@ async def dalle3(request: Request, request_body_info: M.ImageGenerateReq):
|
|||||||
download_count=request_body_info.downloadCount
|
download_count=request_body_info.downloadCount
|
||||||
)
|
)
|
||||||
|
|
||||||
dalle3_generate_image(args)
|
info = dalle3_generate_image(args)
|
||||||
|
|
||||||
return response.set_message()
|
if info.get_error_messages():
|
||||||
|
error_message = f"파일생성 error: {info.get_error_messages()}"
|
||||||
|
LOG.error(error_message)
|
||||||
|
return response.set_error(error=error_message, img_len=info.get_counter())
|
||||||
|
|
||||||
|
return response.set_message(img_len=info.get_counter())
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error(traceback.format_exc())
|
LOG.error(traceback.format_exc())
|
||||||
return response.set_error(e)
|
return response.set_error(e)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ResponseBase)
|
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes)
|
||||||
async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
||||||
"""
|
"""
|
||||||
## 이미지 생성(AI) - imagen
|
## 이미지 생성(AI) - imagen
|
||||||
@@ -69,18 +77,18 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
|||||||
> - const.py 에 지정한 OUTPUT_FOLDER 하위에 imagen 폴더가 있어야함.
|
> - const.py 에 지정한 OUTPUT_FOLDER 하위에 imagen 폴더가 있어야함.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
response = M.ResponseBase()
|
response = M.ImageGenerateRes()
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if not download_range(request_body_info.downloadCount):
|
if not download_range(request_body_info.downloadCount):
|
||||||
raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})")
|
raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})")
|
||||||
|
|
||||||
imagen_generate_image(prompt=request_body_info.prompt,
|
img_length = imagen_generate_image(prompt=request_body_info.prompt,
|
||||||
download_count=request_body_info.downloadCount
|
download_count=request_body_info.downloadCount
|
||||||
)
|
)
|
||||||
|
|
||||||
return response.set_message()
|
return response.set_message(img_len=img_length)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error(traceback.format_exc())
|
LOG.error(traceback.format_exc())
|
||||||
return response.set_error(e)
|
return response.set_error(error=e)
|
||||||
Reference in New Issue
Block a user