edit : clip-vit 모델 추가
This commit is contained in:
@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# When service starts.
|
||||
LOG.info("REST start")
|
||||
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -570,12 +570,62 @@ class IndexType(str, Enum):
|
||||
hnsw = "hnsw"
|
||||
l2 = "l2"
|
||||
|
||||
|
||||
class VitIndexType(str, Enum):
|
||||
cos = "cos"
|
||||
l2 = "l2"
|
||||
|
||||
|
||||
class VitModelType(str, Enum):
|
||||
b32 = "b32"
|
||||
b16 = "b16"
|
||||
l14 = "l14"
|
||||
l14_336 = "l14_336"
|
||||
|
||||
|
||||
class VactorSearchReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor 검색
|
||||
"""
|
||||
quary_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : IndexType = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
|
||||
query_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : IndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
|
||||
|
||||
class VactorSearchVitReportReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor 검색(vit) 후 리포트 이미지 생성
|
||||
"""
|
||||
query_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : VitIndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||
model_type : VitModelType = Field(VitModelType.l14, description='pretrained 모델 정보', example=VitModelType.l14)
|
||||
report_path : str = Field(description='리포트 이미지 저장 경로', example='path')
|
||||
|
||||
|
||||
class VactorSearchVitReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor 검색(vit) 후 이미지 생성
|
||||
"""
|
||||
query_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : VitIndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||
model_type : VitModelType = Field(VitModelType.l14, description='pretrained 모델 정보', example=VitModelType.l14)
|
||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
class VactorSearchVitRes(ResponseBase):
|
||||
img_list : dict = Field({}, description='이미지 결과 리스트', example={})
|
||||
|
||||
@staticmethod
|
||||
def set_error(error):
|
||||
VactorSearchVitRes.img_list = {}
|
||||
VactorSearchVitRes.result = False
|
||||
VactorSearchVitRes.error = str(error)
|
||||
|
||||
return VactorSearchVitRes
|
||||
|
||||
@staticmethod
|
||||
def set_message(msg):
|
||||
VactorSearchVitRes.img_list = msg
|
||||
VactorSearchVitRes.result = True
|
||||
VactorSearchVitRes.error = None
|
||||
|
||||
return VactorSearchVitRes
|
||||
|
||||
@@ -19,12 +19,15 @@ from vactor_rest.app import models as M
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
from custom_logger.vactor_log import vactor_logger as LOG
|
||||
|
||||
from custom_apps.faiss.main import search_idxs
|
||||
from custom_apps.faiss_imagenet.main import search_idxs
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_functions import get_clip_info, save_report_image, get_models
|
||||
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import FEMUsageInfo
|
||||
|
||||
router = APIRouter(prefix="/services")
|
||||
|
||||
|
||||
@router.post("/faiss/vactor/search", summary="vactor search", response_model=M.ResponseBase)
|
||||
@router.post("/faiss/vactor/search/imagenet", summary="imagenet search", response_model=M.ResponseBase)
|
||||
async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
|
||||
"""
|
||||
## 벡터검색
|
||||
@@ -35,15 +38,55 @@ async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if os.path.exists(request_body_info.quary_image_path):
|
||||
search_idxs(image_path=request_body_info.quary_image_path,
|
||||
if os.path.exists(request_body_info.query_image_path):
|
||||
search_idxs(image_path=request_body_info.query_image_path,
|
||||
index_type=request_body_info.index_type,
|
||||
search_num=request_body_info.search_num)
|
||||
else:
|
||||
raise Exception(f"File {request_body_info.quary_image_path} does not exist.")
|
||||
raise Exception(f"File {request_body_info.query_image_path} does not exist.")
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
|
||||
|
||||
@router.post("/faiss/vactor/search/vit/report", summary="vit search report", response_model=M.ResponseBase)
|
||||
async def vactor_report_vit(request: Request, request_body_info: M.VactorSearchVitReportReq):
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if not os.path.exists(request_body_info.query_image_path):
|
||||
raise FileNotFoundError(f"File {request_body_info.query_image_path} does not exist.")
|
||||
|
||||
model = get_models(index_type=request_body_info.index_type, model_type=request_body_info.model_type)
|
||||
|
||||
report_info = get_clip_info(model,request_body_info.query_image_path)
|
||||
save_report_image(report_info, request_body_info.report_path)
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
|
||||
|
||||
@router.post("/faiss/vactor/search/vit", summary="vit search", response_model=M.VactorSearchVitRes)
|
||||
async def vactor_report_vit(request: Request, request_body_info: M.VactorSearchVitReq):
|
||||
response = M.VactorSearchVitRes()
|
||||
try:
|
||||
if not os.path.exists(request_body_info.query_image_path):
|
||||
raise FileNotFoundError(f"File {request_body_info.query_image_path} does not exist.")
|
||||
|
||||
model = get_models(index_type=request_body_info.index_type, model_type=request_body_info.model_type)
|
||||
|
||||
report_info = get_clip_info(model,request_body_info.query_image_path,top_k=request_body_info.search_num)
|
||||
|
||||
return response.set_message({
|
||||
'result_image_paths': report_info.result_image_paths,
|
||||
'result_percents': report_info.result_percents
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
Reference in New Issue
Block a user