edit : 이미지 생성시 생성 갯수 지정 가능

This commit is contained in:
2025-01-17 13:26:19 +09:00
parent bb54fdd85e
commit c296204541
6 changed files with 36 additions and 11 deletions

1
.gitignore vendored
View File

@@ -144,3 +144,4 @@ cython_debug/
log log
sheet_counter.txt sheet_counter.txt
*output

View File

@@ -196,12 +196,13 @@ class ImageGenAsync:
os.mkdir(output_dir) os.mkdir(output_dir)
try: try:
model = "dalle3" model = "dalle3"
jpeg_index = 0 jpeg_index = 1
parsing_file_name = prompt_to_filenames(prompt) parsing_file_name = prompt_to_filenames(prompt)
_datetime = D.date_file_name()
for link in links[:download_count]: for link in links[:download_count]:
if download_count == 1: if download_count == 1:
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{D.date_file_name()}.jpg") _path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{_datetime}.jpg")
if os.path.exists(_path): if os.path.exists(_path):
raise Exception("파일 이미 존재함") raise Exception("파일 이미 존재함")
@@ -216,7 +217,7 @@ class ImageGenAsync:
jpeg_index += 1 jpeg_index += 1
else: else:
_path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{D.date_file_name()}.jpg") _path = os.path.join(output_dir, f"{model}_{parsing_file_name}_{jpeg_index}_{_datetime}.jpg")
if os.path.exists(_path): if os.path.exists(_path):
raise Exception("파일 이미 존재함") raise Exception("파일 이미 존재함")

View File

@@ -22,6 +22,7 @@ def imagen_generate_image(prompt,download_count=1):
_file_name = prompt_to_filenames(prompt) _file_name = prompt_to_filenames(prompt)
_folder = os.path.join(OUTPUT_FOLDER,"imagen") _folder = os.path.join(OUTPUT_FOLDER,"imagen")
_datetime = D.date_file_name()
if not os.path.isdir(_folder): if not os.path.isdir(_folder):
raise FileExistsError(f"FileExistsError: {_folder}") raise FileExistsError(f"FileExistsError: {_folder}")
@@ -40,7 +41,7 @@ def imagen_generate_image(prompt,download_count=1):
) )
if len(images.images) <=1 : if len(images.images) <=1 :
images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{D.date_file_name()}.png"), include_generation_parameters=False) images[0].save(location=os.path.join(_folder,f"imagen_{_file_name}_{_datetime}.png"), include_generation_parameters=False)
else: else:
for i in range(len(images.images)): for i in range(len(images.images)):
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i}_{D.date_file_name()}.png"), include_generation_parameters=False) images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False)

View File

@@ -571,4 +571,4 @@ class ImageGenerateReq(BaseModel):
### [Request] image generate request ### [Request] image generate request
""" """
prompt : str = Field(description='프롬프트', example='검은색 안경') prompt : str = Field(description='프롬프트', example='검은색 안경')
# downloadCount : int = Field(1, description='이미지 생성 갯수', example=1) downloadCount : int = Field(1, description='이미지 생성 갯수', example=1)

View File

@@ -21,6 +21,7 @@ from custom_logger.custom_log import custom_logger as LOG
from custom_apps.dalle3.custom_dalle import DallEArgument,dalle3_generate_image from custom_apps.dalle3.custom_dalle import DallEArgument,dalle3_generate_image
from custom_apps.imagen.custom_imagen import imagen_generate_image from custom_apps.imagen.custom_imagen import imagen_generate_image
from rest.app.utils.parsing_utils import download_range
router = APIRouter(prefix="/services") router = APIRouter(prefix="/services")
@@ -39,8 +40,13 @@ async def dalle3(request: Request, request_body_info: M.ImageGenerateReq):
""" """
response = M.ResponseBase() response = M.ResponseBase()
try: try:
args = DallEArgument(prompt=request_body_info.prompt if not download_range(request_body_info.downloadCount):
# , download_count=request_body_info.downloadCount raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})")
args = DallEArgument(
prompt=request_body_info.prompt,
download_count=request_body_info.downloadCount
) )
dalle3_generate_image(args) dalle3_generate_image(args)
@@ -66,7 +72,12 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
response = M.ResponseBase() response = M.ResponseBase()
try: try:
imagen_generate_image(request_body_info.prompt) if not download_range(request_body_info.downloadCount):
raise Exception(f"downloadCount is 1~4 (current value = {request_body_info.downloadCount})")
imagen_generate_image(prompt=request_body_info.prompt,
download_count=request_body_info.downloadCount
)
return response.set_message() return response.set_message()

View File

@@ -12,3 +12,14 @@ def prompt_to_filenames(prompt):
filename += i filename += i
return filename return filename
def download_range(download_count:int):
_min = 1
_max = 4
if _min <= download_count and download_count <= _max:
return True
return False