edit : clip-vit 모델 추가
This commit is contained in:
@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# When service starts.
|
||||
LOG.info("REST start")
|
||||
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -570,6 +570,19 @@ class IndexType:
|
||||
hnsw = "hnsw"
|
||||
l2 = "l2"
|
||||
|
||||
|
||||
class VitIndexType:
|
||||
cos = "cos"
|
||||
l2 = "l2"
|
||||
|
||||
|
||||
class VitModelType:
|
||||
b32 = "b32"
|
||||
b16 = "b16"
|
||||
l14 = "l14"
|
||||
l14_336 = "l14_336"
|
||||
|
||||
|
||||
class ImageGenerateReq(BaseModel):
|
||||
"""
|
||||
### [Request] image generate request
|
||||
@@ -590,10 +603,29 @@ class VactorImageSearchReq(BaseModel):
|
||||
### [Request] vactor image search request
|
||||
"""
|
||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||
index_type : str = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
|
||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
indexType : str = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
|
||||
class VactorImageSearchVitReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor image search vit
|
||||
"""
|
||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||
modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
|
||||
indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
|
||||
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
|
||||
class VactorImageSearchVitReportReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor image search vit request
|
||||
"""
|
||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||
modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
|
||||
indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
|
||||
|
||||
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
|
||||
@@ -21,11 +21,11 @@ from custom_logger.main_log import main_logger as LOG
|
||||
|
||||
from custom_apps.bingimagecreator.utils import DallEArgument,dalle3_generate_image
|
||||
from custom_apps.bingart.bingart import BingArtGenerator
|
||||
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path
|
||||
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path, imagen_generate_temp_image_path
|
||||
from main_rest.app.utils.parsing_utils import download_range
|
||||
from custom_apps.utils import cookie_manager
|
||||
from utils.custom_sftp import sftp_client
|
||||
from const import REMOTE_FOLDER, TEMP_FOLDER
|
||||
from config import rest_config
|
||||
|
||||
router = APIRouter(prefix="/services")
|
||||
|
||||
@@ -92,7 +92,6 @@ async def bing_img_generate(request: Request, request_body_info: M.ImageGenerate
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
|
||||
|
||||
@router.post("/imageGenerate/bingart", summary="이미지 생성(AI) - bing art (DALL-E 3)", response_model=M.ImageGenerateRes)
|
||||
async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
"""
|
||||
@@ -122,7 +121,6 @@ async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
|
||||
|
||||
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes)
|
||||
async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
"""
|
||||
@@ -150,9 +148,8 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(error=e)
|
||||
|
||||
|
||||
@router.post("/vactorImageSearch/imageGenerate/imagen", summary="벡터 이미지 검색 - imagen", response_model=M.ResponseBase)
|
||||
async def vactor_image(request: Request, request_body_info: M.VactorImageSearchReq):
|
||||
@router.post("/vactorImageSearch/imagenet/imageGenerate/imagen", summary="벡터 이미지 검색(imagenet) - imagen", response_model=M.ResponseBase)
|
||||
async def vactor_imagenet(request: Request, request_body_info: M.VactorImageSearchReq):
|
||||
"""
|
||||
## 벡터 이미지 검색 - imagen
|
||||
> imagen AI를 이용하여 이미지 생성 후 vactor 검색
|
||||
@@ -164,13 +161,15 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if request_body_info.index_type not in [M.IndexType.hnsw, M.IndexType.l2]:
|
||||
raise Exception(f"index_type is hnsw or l2 (current value = {request_body_info.index_type})")
|
||||
if request_body_info.indexType not in [M.IndexType.hnsw, M.IndexType.l2]:
|
||||
raise Exception(f"indexType is hnsw or l2 (current value = {request_body_info.indexType})")
|
||||
|
||||
img_path = imagen_generate_image_path(image_prompt=request_body_info.prompt)
|
||||
|
||||
vactor_request_data = {'quary_image_path' : img_path,'index_type' : request_body_info.index_type, 'search_num' : request_body_info.search_num}
|
||||
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search', data=json.dumps(vactor_request_data))
|
||||
vactor_request_data = {'query_image_path' : img_path,
|
||||
'index_type' : request_body_info.indexType,
|
||||
'search_num' : request_body_info.searchNum}
|
||||
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/imagenet', data=json.dumps(vactor_request_data))
|
||||
|
||||
if vactor_response.status_code != 200:
|
||||
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
|
||||
@@ -183,14 +182,157 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
|
||||
_base_bame = os.path.basename(_directory_path)
|
||||
|
||||
# remote 폴더 생성
|
||||
sftp_client.remote_mkdir(os.path.join(REMOTE_FOLDER, _base_bame))
|
||||
sftp_client.remote_mkdir(os.path.join(rest_config.remote_folder, _base_bame))
|
||||
|
||||
# remote 폴더에 이미지 저장
|
||||
for i in os.listdir(_directory_path):
|
||||
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(REMOTE_FOLDER, _base_bame, i))
|
||||
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(rest_config.remote_folder, _base_bame, i))
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(error=e)
|
||||
|
||||
@router.post("/vactorImageSearch/vit/imageGenerate/imagen", summary="벡터 이미지 검색(clip-vit) - imagen", response_model=M.ResponseBase)
|
||||
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReq):
|
||||
"""
|
||||
## 벡터 이미지 검색(clip-vit) - imagen
|
||||
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 결과 이미지 생성
|
||||
|
||||
### Requriements
|
||||
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
|
||||
|
||||
### options
|
||||
> - modelType -> b32,b16,l14,l14_336
|
||||
> - indexType -> l2,cos
|
||||
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if not download_range(request_body_info.searchNum, max=10):
|
||||
raise Exception(f"downloadCound is invalid (current value = {request_body_info.searchNum})")
|
||||
|
||||
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
|
||||
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
|
||||
|
||||
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
|
||||
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
|
||||
|
||||
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
|
||||
|
||||
vector_request_data = {'query_image_path' : query_image_path,
|
||||
'index_type' : request_body_info.indexType,
|
||||
'model_type' : request_body_info.modelType,
|
||||
'search_num' : request_body_info.searchNum}
|
||||
|
||||
vector_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit', data=json.dumps(vector_request_data))
|
||||
|
||||
vector_response_dict = json.loads(vector_response.text)
|
||||
|
||||
if vector_response.status_code != 200:
|
||||
raise Exception(f"response error: {vector_response_dict['error']}")
|
||||
|
||||
if vector_response_dict["error"] != None:
|
||||
raise Exception(f"vactor error: {vector_response_dict['error']}")
|
||||
|
||||
result_image_paths = vector_response_dict.get('img_list').get('result_image_paths')
|
||||
result_percents = vector_response_dict.get('img_list').get('result_percents')
|
||||
|
||||
# 원격지 폴더 생성
|
||||
remote_directory = os.path.join(rest_config.remote_folder, f"imagen_query_{request_body_info.modelType}_{request_body_info.indexType}_{request_body_info.prompt}_{D.date_file_name()}")
|
||||
sftp_client.remote_mkdir(remote_directory)
|
||||
|
||||
# 원격지에 이미지 저장
|
||||
sftp_client.remote_copy_data(local_path=query_image_path, remote_path=os.path.join(remote_directory,"query.png"))
|
||||
|
||||
for img_path, img_percent in zip(result_image_paths,result_percents):
|
||||
sftp_client.remote_copy_data(local_path=img_path, remote_path=os.path.join(remote_directory,f"search_{img_percent}.png"))
|
||||
|
||||
# Clean up temporary files
|
||||
if 'query_image_path' in locals():
|
||||
if os.path.exists(query_image_path):
|
||||
os.remove(query_image_path)
|
||||
del query_image_path
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
|
||||
# Clean up temporary files
|
||||
if 'query_image_path' in locals():
|
||||
if os.path.exists(query_image_path):
|
||||
os.remove(query_image_path)
|
||||
del query_image_path
|
||||
|
||||
return response.set_error(error=e)
|
||||
|
||||
@router.post("/vactorImageSearch/vit/imageGenerate/imagen/report", summary="벡터 이미지 검색(clip-vit) - imagen, report 생성", response_model=M.ResponseBase)
|
||||
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReportReq):
|
||||
"""
|
||||
## 벡터 이미지 검색(clip-vit) - imagen, report 생성
|
||||
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 종합결과 이미지 생성
|
||||
|
||||
### Requriements
|
||||
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
|
||||
|
||||
### options
|
||||
> - modelType -> b32,b16,l14,l14_336
|
||||
> - indexType -> l2,cos
|
||||
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
|
||||
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
|
||||
|
||||
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
|
||||
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
|
||||
|
||||
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
|
||||
report_image_path = f"{os.path.splitext(query_image_path)[0]}_report.png"
|
||||
|
||||
vactor_request_data = {'query_image_path' : query_image_path,
|
||||
'index_type' : request_body_info.indexType,
|
||||
'model_type' : request_body_info.modelType,
|
||||
'report_path' : report_image_path}
|
||||
|
||||
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit/report', data=json.dumps(vactor_request_data))
|
||||
|
||||
if vactor_response.status_code != 200:
|
||||
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
|
||||
|
||||
if json.loads(vactor_response.text)["error"] != None:
|
||||
raise Exception(f"vactor error: {json.loads(vactor_response.text)['error']}")
|
||||
|
||||
# remote 폴더에 이미지 저장
|
||||
sftp_client.remote_copy_data(local_path=report_image_path,
|
||||
remote_path=os.path.join(rest_config.remote_folder, f"imagen_report_vit_{request_body_info.prompt}_{D.date_file_name()}.png"))
|
||||
|
||||
# Clean up temporary files
|
||||
if 'query_image_path' in locals():
|
||||
if os.path.exists(query_image_path):
|
||||
os.remove(query_image_path)
|
||||
del query_image_path
|
||||
if 'report_image_path' in locals():
|
||||
if os.path.exists(report_image_path):
|
||||
os.remove(report_image_path)
|
||||
del report_image_path
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
|
||||
# Clean up temporary files
|
||||
if 'query_image_path' in locals():
|
||||
if os.path.exists(query_image_path):
|
||||
os.remove(query_image_path)
|
||||
del query_image_path
|
||||
if 'report_image_path' in locals():
|
||||
if os.path.exists(report_image_path):
|
||||
os.remove(report_image_path)
|
||||
del report_image_path
|
||||
|
||||
return response.set_error(error=e)
|
||||
Reference in New Issue
Block a user