edit : clip-vit 모델 추가

This commit is contained in:
2025-07-08 17:20:32 +09:00
parent 4c2ea70289
commit 309a91bda6
24 changed files with 1395 additions and 40 deletions

View File

@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
@asynccontextmanager
async def lifespan(app: FastAPI):
# When service starts.
LOG.info("REST start")
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
yield

View File

@@ -570,6 +570,19 @@ class IndexType:
hnsw = "hnsw"
l2 = "l2"
class VitIndexType:
cos = "cos"
l2 = "l2"
class VitModelType:
b32 = "b32"
b16 = "b16"
l14 = "l14"
l14_336 = "l14_336"
class ImageGenerateReq(BaseModel):
"""
### [Request] image generate request
@@ -590,10 +603,29 @@ class VactorImageSearchReq(BaseModel):
### [Request] vactor image search request
"""
prompt : str = Field(description='프롬프트', example='검은색 안경')
index_type : str = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
indexType : str = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
class VactorImageSearchVitReq(BaseModel):
"""
### [Request] vactor image search vit
"""
prompt : str = Field(description='프롬프트', example='검은색 안경')
modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
class VactorImageSearchVitReportReq(BaseModel):
"""
### [Request] vactor image search vit request
"""
prompt : str = Field(description='프롬프트', example='검은색 안경')
modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
#===============================================================================
#===============================================================================
#===============================================================================

View File

@@ -21,11 +21,11 @@ from custom_logger.main_log import main_logger as LOG
from custom_apps.bingimagecreator.utils import DallEArgument,dalle3_generate_image
from custom_apps.bingart.bingart import BingArtGenerator
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path, imagen_generate_temp_image_path
from main_rest.app.utils.parsing_utils import download_range
from custom_apps.utils import cookie_manager
from utils.custom_sftp import sftp_client
from const import REMOTE_FOLDER, TEMP_FOLDER
from config import rest_config
router = APIRouter(prefix="/services")
@@ -92,7 +92,6 @@ async def bing_img_generate(request: Request, request_body_info: M.ImageGenerate
LOG.error(traceback.format_exc())
return response.set_error(e)
@router.post("/imageGenerate/bingart", summary="이미지 생성(AI) - bing art (DALL-E 3)", response_model=M.ImageGenerateRes)
async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
"""
@@ -122,7 +121,6 @@ async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
LOG.error(traceback.format_exc())
return response.set_error(e)
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes)
async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
"""
@@ -150,9 +148,8 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
LOG.error(traceback.format_exc())
return response.set_error(error=e)
@router.post("/vactorImageSearch/imageGenerate/imagen", summary="벡터 이미지 검색 - imagen", response_model=M.ResponseBase)
async def vactor_image(request: Request, request_body_info: M.VactorImageSearchReq):
@router.post("/vactorImageSearch/imagenet/imageGenerate/imagen", summary="벡터 이미지 검색(imagenet) - imagen", response_model=M.ResponseBase)
async def vactor_imagenet(request: Request, request_body_info: M.VactorImageSearchReq):
"""
## 벡터 이미지 검색 - imagen
> imagen AI를 이용하여 이미지 생성 후 vactor 검색
@@ -164,13 +161,15 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
"""
response = M.ResponseBase()
try:
if request_body_info.index_type not in [M.IndexType.hnsw, M.IndexType.l2]:
raise Exception(f"index_type is hnsw or l2 (current value = {request_body_info.index_type})")
if request_body_info.indexType not in [M.IndexType.hnsw, M.IndexType.l2]:
raise Exception(f"indexType is hnsw or l2 (current value = {request_body_info.indexType})")
img_path = imagen_generate_image_path(image_prompt=request_body_info.prompt)
vactor_request_data = {'quary_image_path' : img_path,'index_type' : request_body_info.index_type, 'search_num' : request_body_info.search_num}
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search', data=json.dumps(vactor_request_data))
vactor_request_data = {'query_image_path' : img_path,
'index_type' : request_body_info.indexType,
'search_num' : request_body_info.searchNum}
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/imagenet', data=json.dumps(vactor_request_data))
if vactor_response.status_code != 200:
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
@@ -183,14 +182,157 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
_base_bame = os.path.basename(_directory_path)
# remote 폴더 생성
sftp_client.remote_mkdir(os.path.join(REMOTE_FOLDER, _base_bame))
sftp_client.remote_mkdir(os.path.join(rest_config.remote_folder, _base_bame))
# remote 폴더에 이미지 저장
for i in os.listdir(_directory_path):
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(REMOTE_FOLDER, _base_bame, i))
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(rest_config.remote_folder, _base_bame, i))
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
return response.set_error(error=e)
@router.post("/vactorImageSearch/vit/imageGenerate/imagen", summary="벡터 이미지 검색(clip-vit) - imagen", response_model=M.ResponseBase)
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReq):
"""
## 벡터 이미지 검색(clip-vit) - imagen
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 결과 이미지 생성
### Requriements
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
### options
> - modelType -> b32,b16,l14,l14_336
> - indexType -> l2,cos
"""
response = M.ResponseBase()
try:
if not download_range(request_body_info.searchNum, max=10):
raise Exception(f"downloadCound is invalid (current value = {request_body_info.searchNum})")
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
vector_request_data = {'query_image_path' : query_image_path,
'index_type' : request_body_info.indexType,
'model_type' : request_body_info.modelType,
'search_num' : request_body_info.searchNum}
vector_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit', data=json.dumps(vector_request_data))
vector_response_dict = json.loads(vector_response.text)
if vector_response.status_code != 200:
raise Exception(f"response error: {vector_response_dict['error']}")
if vector_response_dict["error"] != None:
raise Exception(f"vactor error: {vector_response_dict['error']}")
result_image_paths = vector_response_dict.get('img_list').get('result_image_paths')
result_percents = vector_response_dict.get('img_list').get('result_percents')
# 원격지 폴더 생성
remote_directory = os.path.join(rest_config.remote_folder, f"imagen_query_{request_body_info.modelType}_{request_body_info.indexType}_{request_body_info.prompt}_{D.date_file_name()}")
sftp_client.remote_mkdir(remote_directory)
# 원격지에 이미지 저장
sftp_client.remote_copy_data(local_path=query_image_path, remote_path=os.path.join(remote_directory,"query.png"))
for img_path, img_percent in zip(result_image_paths,result_percents):
sftp_client.remote_copy_data(local_path=img_path, remote_path=os.path.join(remote_directory,f"search_{img_percent}.png"))
# Clean up temporary files
if 'query_image_path' in locals():
if os.path.exists(query_image_path):
os.remove(query_image_path)
del query_image_path
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
# Clean up temporary files
if 'query_image_path' in locals():
if os.path.exists(query_image_path):
os.remove(query_image_path)
del query_image_path
return response.set_error(error=e)
@router.post("/vactorImageSearch/vit/imageGenerate/imagen/report", summary="벡터 이미지 검색(clip-vit) - imagen, report 생성", response_model=M.ResponseBase)
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReportReq):
"""
## 벡터 이미지 검색(clip-vit) - imagen, report 생성
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 종합결과 이미지 생성
### Requriements
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
### options
> - modelType -> b32,b16,l14,l14_336
> - indexType -> l2,cos
"""
response = M.ResponseBase()
try:
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
report_image_path = f"{os.path.splitext(query_image_path)[0]}_report.png"
vactor_request_data = {'query_image_path' : query_image_path,
'index_type' : request_body_info.indexType,
'model_type' : request_body_info.modelType,
'report_path' : report_image_path}
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit/report', data=json.dumps(vactor_request_data))
if vactor_response.status_code != 200:
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
if json.loads(vactor_response.text)["error"] != None:
raise Exception(f"vactor error: {json.loads(vactor_response.text)['error']}")
# remote 폴더에 이미지 저장
sftp_client.remote_copy_data(local_path=report_image_path,
remote_path=os.path.join(rest_config.remote_folder, f"imagen_report_vit_{request_body_info.prompt}_{D.date_file_name()}.png"))
# Clean up temporary files
if 'query_image_path' in locals():
if os.path.exists(query_image_path):
os.remove(query_image_path)
del query_image_path
if 'report_image_path' in locals():
if os.path.exists(report_image_path):
os.remove(report_image_path)
del report_image_path
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
# Clean up temporary files
if 'query_image_path' in locals():
if os.path.exists(query_image_path):
os.remove(query_image_path)
del query_image_path
if 'report_image_path' in locals():
if os.path.exists(report_image_path):
os.remove(report_image_path)
del report_image_path
return response.set_error(error=e)