edit : clip-vit 모델 추가
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -145,4 +145,8 @@ log
|
||||
|
||||
*output
|
||||
*temp
|
||||
datas/eyewear_all/*
|
||||
datas/eyewear_all/*
|
||||
datas/clip*
|
||||
datas/openai*
|
||||
|
||||
*result
|
||||
27
config.py
Normal file
27
config.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class Config:
|
||||
def __init__(self):
|
||||
self.set_rel()
|
||||
|
||||
def set_rel(self):
|
||||
self.config = 'rel'
|
||||
self.remote_folder = "/home/fermat/STORAGE/01.Projects/A2TEC/K_EYEWEAR/02.ML_DATA/Image_generator_result"
|
||||
self.sftp_host = "192.168.200.230"
|
||||
self.sftp_port = 22
|
||||
self.sftp_id = "fermat"
|
||||
self.sftp_pw = "1234"
|
||||
|
||||
def set_dev(self):
|
||||
self.config = 'dev'
|
||||
self.remote_folder = "/home/fermat/project/FM_TEST_REST_SERVER/result"
|
||||
self.sftp_host = "192.168.200.231"
|
||||
self.sftp_port = 22
|
||||
self.sftp_id = "fermat"
|
||||
self.sftp_pw = "fermat3514"
|
||||
|
||||
import os
|
||||
if not os.path.exists(self.remote_folder):
|
||||
os.makedirs(self.remote_folder)
|
||||
|
||||
|
||||
rest_config = Config()
|
||||
# rest_config.set_dev()
|
||||
1
const.py
1
const.py
@@ -1,4 +1,3 @@
|
||||
REMOTE_FOLDER = "/home/fermat/STORAGE/01.Projects/A2TEC/K_EYEWEAR/02.ML_DATA/Image_generator_result"
|
||||
TEMP_FOLDER = "./temp"
|
||||
|
||||
ILLEGAL_FILE_NAME = ['<', '>', ':', '"', '/', '\ ', '|', '?', '*']
|
||||
11
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/__init__.py
Executable file
11
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/__init__.py
Executable file
@@ -0,0 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@file : __init__.py
|
||||
@author: hsj100
|
||||
@license: A2D2
|
||||
@brief:
|
||||
|
||||
@section Modify History
|
||||
- 2025-06-27 오후 5:35 hsj100 base
|
||||
|
||||
"""
|
||||
40
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/const.py
Executable file
40
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/const.py
Executable file
@@ -0,0 +1,40 @@
|
||||
# Access token은 huggingface.co > Settings > Access Tokens 에서 생성
|
||||
HUGGINGFACE_TOKEN = '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'
|
||||
|
||||
MODEL_DEVICE_CPU = 'cpu'
|
||||
MODEL_DEVICE_GPU = 'cuda'
|
||||
|
||||
DEFAULT_INDEX_NAME_SUFFIX = 'index.faiss'
|
||||
DEFAULT_SAVE_INDEX_SUFFIX = DEFAULT_INDEX_NAME_SUFFIX
|
||||
|
||||
INDEX_TYPE_L2 = 'l2'
|
||||
INDEX_TYPE_COSINE = 'cosine'
|
||||
|
||||
#VIT MODELS
|
||||
ViTB32 = "clip-vit-base-patch32"
|
||||
ViTB16 = "clip-vit-base-patch16"
|
||||
ViTL14 = "clip-vit-large-patch14"
|
||||
ViTL14_336 = "clip-vit-large-patch14-336"
|
||||
|
||||
#download/save path
|
||||
PRETRAINED_MODEL_PATH = "./datas"
|
||||
FAISS_VECTOR_PATH = "./datas"
|
||||
VECTOR_IMG_LIST_PATH = "./datas"
|
||||
|
||||
INDEX_IMAGES_PATH = "./datas/eyewear_all"
|
||||
|
||||
|
||||
# report image consts
|
||||
class ReportInfoConst:
|
||||
feature_extraction_model = "Feature Extraction Model"
|
||||
model_architecture = "Model Architecture"
|
||||
train_model = "Train Model"
|
||||
feature_extraction_time = "Feature Extraction Time"
|
||||
|
||||
similarity_type = "Similarity Type"
|
||||
query = "Query"
|
||||
result = "Result"
|
||||
|
||||
class similarity:
|
||||
l2 = "L2(Euclidean Distance)"
|
||||
cosine = "Cosine(Inner Product, L2)"
|
||||
135
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/faiss_functions.py
Executable file
135
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/faiss_functions.py
Executable file
@@ -0,0 +1,135 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@file : faiss_functions.py
|
||||
@author: jwkim
|
||||
@license: A2D2
|
||||
|
||||
@section Modify History
|
||||
- 2025-06-27 오후 2:05 hsj100 base
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
Basic logger: custom color logger
|
||||
"""
|
||||
try:
|
||||
from custom_logger.custom_log import custom_logger as log
|
||||
except ImportError:
|
||||
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||
|
||||
|
||||
"""
|
||||
Package: 3rd party
|
||||
"""
|
||||
|
||||
"""
|
||||
Package: standard
|
||||
"""
|
||||
import random
|
||||
import pathlib
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
"""
|
||||
Package: custom
|
||||
"""
|
||||
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.report_utils import *
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import VectorSimilarity
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.report_utils import *
|
||||
"""
|
||||
Definition
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
Implementation
|
||||
"""
|
||||
|
||||
def get_models(index_type, model_type):
|
||||
from vactor_rest.app import models as M
|
||||
|
||||
model = None
|
||||
|
||||
# model 선택
|
||||
if index_type == M.VitIndexType.l2:
|
||||
# l2
|
||||
if model_type == M.VitModelType.l14:
|
||||
model = FEMUsageInfo.openaiclip_vit_l14_l2
|
||||
elif model_type == M.VitModelType.b32:
|
||||
model = FEMUsageInfo.openaiclip_vit_b32_l2
|
||||
elif model_type == M.VitModelType.b16:
|
||||
model = FEMUsageInfo.openaiclip_vit_b16_l2
|
||||
elif model_type == M.VitModelType.l14_336:
|
||||
model = FEMUsageInfo.openaiclip_vit_l14_336_l2
|
||||
else:
|
||||
raise Exception(f"model_type {model_type} is invalid")
|
||||
else:
|
||||
#cosine
|
||||
if model_type == M.VitModelType.l14:
|
||||
model = FEMUsageInfo.openaiclip_vit_l14_cos
|
||||
elif model_type == M.VitModelType.b32:
|
||||
model = FEMUsageInfo.openaiclip_vit_b32_cos
|
||||
elif model_type == M.VitModelType.b16:
|
||||
model = FEMUsageInfo.openaiclip_vit_b16_cos
|
||||
elif model_type == M.VitModelType.l14_336:
|
||||
model = FEMUsageInfo.openaiclip_vit_l14_336_cos
|
||||
else:
|
||||
# NOTE(JWKIM) : 현재 l14만 지원 FEMUsageInfo 에 추가해야함
|
||||
raise Exception(f"model_type {model_type} is invalid")
|
||||
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
raise Exception(f"model is not selected")
|
||||
|
||||
|
||||
def save_report_image(report_info, report_image_path):
|
||||
"""
|
||||
이미지 유사도 검색 결과를 리포트로 생성
|
||||
"""
|
||||
table_setting(report_info, report_image_path)
|
||||
|
||||
|
||||
def get_clip_info(model, query_image_path, top_k=4):
|
||||
"""
|
||||
이미지 유사도 검색
|
||||
"""
|
||||
index_file_path = os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}_{DEFAULT_INDEX_NAME_SUFFIX}')
|
||||
txt_file_path = os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}.txt')
|
||||
vector_model = VectorSimilarity(model_info=model,
|
||||
index_file_path=index_file_path,
|
||||
txt_file_path=txt_file_path)
|
||||
|
||||
# vector_model.create_feature_extraction_model(FEMUsageInfo.openaiclipvit)
|
||||
|
||||
vector_model.save_index_from_files(images_path=INDEX_IMAGES_PATH,
|
||||
save_index_dir=FAISS_VECTOR_PATH,
|
||||
save_txt_dir=VECTOR_IMG_LIST_PATH,
|
||||
index_type=model.value[1].index_type)
|
||||
|
||||
inference_times, result_img_paths, result_percents = vector_model.query_faiss(query_image_path, top_k=top_k)
|
||||
|
||||
report_info = ReportInfo(
|
||||
feature_extraction_model=ReportInfoConst.feature_extraction_model,
|
||||
model_architecture=ReportInfoConst.model_architecture,
|
||||
train_model=os.path.basename(model.value[1].trained_model),
|
||||
feature_extraction_time=inference_times,
|
||||
query_image_path=query_image_path,
|
||||
result_image_paths=result_img_paths,
|
||||
result_percents=result_percents,
|
||||
similarity_type=model.value[1].index_type
|
||||
)
|
||||
|
||||
return report_info
|
||||
|
||||
|
||||
410
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/faiss_similarity_search.py
Executable file
410
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/faiss_similarity_search.py
Executable file
@@ -0,0 +1,410 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@file : faiss_similarity_search.py
|
||||
@author: hsj100
|
||||
@license: A2D2
|
||||
@brief: 특징 벡터를 이용한 이미지 유사도 검색
|
||||
|
||||
@section Modify History
|
||||
- 2025-06-19 오후 2:45 hsj100 base
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
Basic logger: custom color logger
|
||||
"""
|
||||
try:
|
||||
from custom_logger.main_log import vactor_logger as log
|
||||
except ImportError:
|
||||
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||
|
||||
"""
|
||||
Package: standard
|
||||
"""
|
||||
import sys
|
||||
import torch
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
from PIL import Image
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
"""
|
||||
Package: 3rd party
|
||||
"""
|
||||
|
||||
"""
|
||||
Package: custom
|
||||
"""
|
||||
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
|
||||
import os, sys
|
||||
if __name__ == '__main__':
|
||||
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS import feature_extraction_model as FEM
|
||||
|
||||
# 사용할 이미지 임베딩 모델 클래스 추가
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.fem_openaiclipvit import FEOpenAIClipViT
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||
|
||||
|
||||
"""
|
||||
Definition
|
||||
"""
|
||||
FILE_BASE_PATH = ''
|
||||
if __name__ == '__main__':
|
||||
FILE_BASE_PATH = os.path.dirname(os.path.realpath(__file__)) + '/'
|
||||
|
||||
# TRAINED_MODEL_PATH = FILE_BASE_PATH + 'openai/clip-resnet50' # OLD VERSION
|
||||
# TRAINED_MODEL_PATH = FILE_BASE_PATH + 'openai/clip-vit-base-patch32' # huggingface 저장소 위치
|
||||
TRAINED_MODEL_PATH = "./data" # 로컬 다운로드 경로
|
||||
|
||||
DEFAULT_INDEX_IMAGE_PATH = FILE_BASE_PATH + 'index_images'
|
||||
|
||||
class FEMUsageInfo(Enum):
|
||||
"""
|
||||
특징 추출(이미지 임베딩) 모델 추가시, 생성 정보를 추가해서 사용
|
||||
* 변수명 변경되면 인덱스파일 이름 새로 만듬
|
||||
"""
|
||||
openaiclip_vit_b32_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTB32),
|
||||
token=HUGGINGFACE_TOKEN,
|
||||
index_type=INDEX_TYPE_L2)]
|
||||
|
||||
openaiclip_vit_b32_cos = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTB32),
|
||||
token=HUGGINGFACE_TOKEN,
|
||||
index_type=INDEX_TYPE_COSINE)]
|
||||
|
||||
openaiclip_vit_b16_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTB16),
|
||||
token=HUGGINGFACE_TOKEN,
|
||||
index_type=INDEX_TYPE_L2)]
|
||||
|
||||
openaiclip_vit_b16_cos = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTB16),
|
||||
token=HUGGINGFACE_TOKEN,
|
||||
index_type=INDEX_TYPE_COSINE)]
|
||||
|
||||
openaiclip_vit_l14_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTL14),
|
||||
token=HUGGINGFACE_TOKEN,
|
||||
index_type=INDEX_TYPE_L2)]
|
||||
|
||||
openaiclip_vit_l14_cos = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTL14),
|
||||
token=HUGGINGFACE_TOKEN,
|
||||
index_type=INDEX_TYPE_COSINE)]
|
||||
|
||||
openaiclip_vit_b14_336_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTL14_336),
|
||||
token=HUGGINGFACE_TOKEN,
|
||||
index_type=INDEX_TYPE_L2)]
|
||||
|
||||
openaiclip_vit_b14_336_cos = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTL14_336),
|
||||
token=HUGGINGFACE_TOKEN,
|
||||
index_type=INDEX_TYPE_COSINE)]
|
||||
|
||||
openaicliprn = None
|
||||
mobilenetv2 = None
|
||||
|
||||
|
||||
DEFAULT_WEIGHT_PATH = FILE_BASE_PATH + TRAINED_MODEL_PATH
|
||||
|
||||
TEST_IMAGES_PATH = FILE_BASE_PATH + 'testimages'
|
||||
|
||||
"""
|
||||
Implementation
|
||||
"""
|
||||
# log off
|
||||
import logging
|
||||
logging.getLogger('PIL.PngImagePlugin').disabled = True
|
||||
|
||||
# Utility Method
|
||||
def function_name():
|
||||
return sys._getframe(1).f_code.co_name
|
||||
|
||||
|
||||
class FeatureExtractionModel(ABC):
|
||||
"""
|
||||
특징 벡터 추출 모델의 추상 인터페이스
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def init_model(self, pretrained_model):
|
||||
"""
|
||||
로드 모델 및 초기화
|
||||
:param pretrained_model:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError(f'Subclasses must implement this method{function_name()}')
|
||||
|
||||
@abstractmethod
|
||||
def image_embedding(self, image_np):
|
||||
"""
|
||||
이미지 임베딩 - 이미지 데이터(numpy)
|
||||
:param image_np: 이미지 데이터(numpy array)
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError(f'Subclasses must implement this method{function_name()}')
|
||||
|
||||
|
||||
class VectorSimilarity:
|
||||
"""
|
||||
벡터 기반 유사성 검색 모델 (FACEBOOK - FAISS)
|
||||
"""
|
||||
def __init__(self, model_info:FEMUsageInfo=None, index_file_path=None, txt_file_path=None):
|
||||
self.fem_model:FEM.FeatureExtractionModel = None
|
||||
|
||||
self.model_info = model_info
|
||||
self.fem_model_name = ''
|
||||
self.fem_model_args = None
|
||||
self.index_file_path = index_file_path
|
||||
self.txt_file_path = txt_file_path
|
||||
self.index_type = model_info.value[1].index_type
|
||||
|
||||
if model_info is not None:
|
||||
self.create_feature_extraction_model(model_info)
|
||||
|
||||
def create_feature_extraction_model(self, model_info):
|
||||
"""
|
||||
특징 벡터 추출 모델 초기화
|
||||
:return: True/False
|
||||
"""
|
||||
# check FEMUsageInfo
|
||||
if isinstance(model_info, FEMUsageInfo):
|
||||
if model_info not in FEMUsageInfo or model_info.value is None:
|
||||
log.error('invalid FEMArguments')
|
||||
return False
|
||||
else:
|
||||
log.error('invalid FEMArguments')
|
||||
return False
|
||||
|
||||
if model_info.value[0] is None or model_info.value[1] is None:
|
||||
log.error('invalid FEMArguments')
|
||||
return False
|
||||
|
||||
# release model
|
||||
if self.fem_model is not None:
|
||||
del self.fem_model
|
||||
|
||||
# initialize ai-model
|
||||
self.fem_model_name = model_info.name
|
||||
self.fem_model_args = model_info.value[1]
|
||||
self.fem_model = model_info.value[0](self.fem_model_args)
|
||||
|
||||
log.info(f'created fem-model: {model_info.value[0]}')
|
||||
|
||||
def image_embedding_from_file(self, image_file):
|
||||
"""
|
||||
이미지 파일에서 특징 벡터 추출
|
||||
:param image_file: 이미지 파일
|
||||
:return: 특징 벡터 or None
|
||||
"""
|
||||
feature_vectors = None
|
||||
|
||||
if not os.path.isfile(image_file):
|
||||
log.error(f'not found file[{image_file}]')
|
||||
return feature_vectors
|
||||
|
||||
image_data_np = Image.open(image_file).convert('RGB')
|
||||
feature_vectors = self.fem_model.image_embedding(image_data_np)
|
||||
return feature_vectors
|
||||
|
||||
def image_embedding_from_data(self, image_data_np=None):
|
||||
"""
|
||||
이미지 데이터(numpy)에서 특징 벡터 추출
|
||||
:param image_data_np: 이미지 데이터(numpy)
|
||||
:return: 특징 벡터 or None
|
||||
"""
|
||||
feature_vectors = None
|
||||
|
||||
if image_data_np is None:
|
||||
log.error(f'invalid data[{image_data_np}]')
|
||||
return feature_vectors
|
||||
|
||||
feature_vectors = self.fem_model.image_embedding(image_data_np)
|
||||
return feature_vectors
|
||||
|
||||
def save_index_from_files(self, images_path=None, save_index_dir=None, save_txt_dir=None, index_type=INDEX_TYPE_L2):
|
||||
"""
|
||||
인덱스 정보 저장
|
||||
:param images_path: 이미지 파일 경로
|
||||
:param save_index_dir: 인덱스 저장 디렉토리
|
||||
:param save_txt_dir: 이미지 경로 저장 디렉토리
|
||||
:param index_type: 인덱스 타입 (L2, Cosine)
|
||||
:return: 이미지 임베딩 차원 정보
|
||||
"""
|
||||
assert self.fem_model is not None, 'invalid fem_model'
|
||||
assert images_path is not None, 'images_path is required'
|
||||
|
||||
if not os.path.exists(images_path):
|
||||
log.error(f'image_path={images_path} does not exist')
|
||||
|
||||
image_folder = images_path
|
||||
result = None
|
||||
|
||||
# 모든 이미지 임베딩 모으기
|
||||
if os.path.exists(image_folder):
|
||||
image_paths = [os.path.join(image_folder, image_file) for image_file in os.listdir(image_folder)]
|
||||
image_vectors = np.vstack([self.image_embedding_from_file(image_file) for image_file in image_paths])
|
||||
log.debug(f'image_vectors.shape={image_vectors.shape}')
|
||||
result = image_vectors.shape
|
||||
|
||||
# 인덱스 생성 (L2)
|
||||
dimension = image_vectors.shape[1]
|
||||
|
||||
if index_type == INDEX_TYPE_L2:
|
||||
index = faiss.IndexFlatL2(dimension) # L2
|
||||
elif index_type == INDEX_TYPE_COSINE:
|
||||
index = faiss.IndexFlatIP(dimension) # cosine
|
||||
faiss.normalize_L2(image_vectors)
|
||||
else:
|
||||
raise Exception(f'Invalid index_type={index_type}')
|
||||
|
||||
index.add(image_vectors)
|
||||
|
||||
# 인덱스 저장
|
||||
if save_index_dir is None:
|
||||
save_index_path = os.path.join(FILE_BASE_PATH,f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}_{DEFAULT_SAVE_INDEX_SUFFIX}')
|
||||
self.index_file_path = save_index_path
|
||||
else:
|
||||
save_index_path = os.path.join(save_index_dir,f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}_{DEFAULT_SAVE_INDEX_SUFFIX}')
|
||||
os.makedirs(save_index_dir, exist_ok=True)
|
||||
|
||||
# 이미지 경로 파일 저장
|
||||
if save_txt_dir is None:
|
||||
save_txt_path = os.path.join(FILE_BASE_PATH, f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}.txt')
|
||||
self.txt_file_path = save_txt_path
|
||||
else:
|
||||
save_txt_path = os.path.join(save_txt_dir, f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}.txt')
|
||||
os.makedirs(save_txt_dir, exist_ok=True)
|
||||
|
||||
if not os.path.exists(save_txt_path):
|
||||
faiss.write_index(index, save_index_path)
|
||||
log.debug(f'save_index_path={save_index_path}')
|
||||
|
||||
if not os.path.exists(save_txt_path):
|
||||
# 이미지 경로 저장
|
||||
with open(save_txt_path, 'w') as f:
|
||||
for path in image_paths:
|
||||
f.write(f"{path}\n")
|
||||
log.debug(f'save_txt_path={save_txt_path}')
|
||||
|
||||
else:
|
||||
log.error(f'Image folder {image_folder} does not exist')
|
||||
return result
|
||||
|
||||
def set_index_path(self, index_path):
|
||||
assert index_path is not None, 'index_path is required'
|
||||
self.index_file_path = index_path
|
||||
|
||||
def load_index(self, index_file_name=None):
|
||||
"""
|
||||
인덱스 정보 불러오기
|
||||
:return:
|
||||
"""
|
||||
if index_file_name is not None:
|
||||
self.index_file_path = index_file_name
|
||||
|
||||
return faiss.read_index(self.index_file_path)
|
||||
|
||||
def time_diff(self, start_time, end_time):
|
||||
# 시간 차이 계산
|
||||
time_difference = end_time - start_time
|
||||
|
||||
# timedelta 객체에서 분:초:마이크로초 추출
|
||||
total_seconds_float = time_difference.total_seconds()
|
||||
total_seconds_int = int(total_seconds_float) # 정수 부분의 초
|
||||
|
||||
minutes = total_seconds_int // 60
|
||||
seconds = total_seconds_int % 60
|
||||
microseconds = time_difference.microseconds # 남은 마이크로초
|
||||
|
||||
return f"{minutes:02d}:{seconds:02d}.{microseconds:06d}"
|
||||
|
||||
def query_faiss(self, query_image=None, top_k=4):
|
||||
|
||||
assert query_image is not None, 'query_image is required'
|
||||
|
||||
if os.path.exists(self.txt_file_path):
|
||||
with open(self.txt_file_path, 'r') as f:
|
||||
image_paths = [line.strip() for line in f.readlines()]
|
||||
else:
|
||||
logging.error("Image path list TXT file not found.")
|
||||
image_paths = []
|
||||
|
||||
start_vector_time = datetime.now()
|
||||
index = self.load_index(self.index_file_path)
|
||||
query_vector = self.image_embedding_from_file(query_image)
|
||||
end_vector_time = datetime.now()
|
||||
|
||||
diff_vector_time = self.time_diff(start_vector_time,end_vector_time)
|
||||
|
||||
if self.index_type == INDEX_TYPE_COSINE:
|
||||
faiss.normalize_L2(query_vector)
|
||||
|
||||
start_search_time = datetime.now()
|
||||
distances, indices = index.search(query_vector, top_k)
|
||||
end_search_time = datetime.now()
|
||||
|
||||
diff_search_time = self.time_diff(start_search_time,end_search_time)
|
||||
diff_total_time = self.time_diff(start_vector_time,end_search_time)
|
||||
inference_times = f"Total time - {diff_total_time}, vector_time - {diff_vector_time}, search_time - {diff_search_time}"
|
||||
|
||||
result_img_paths = []
|
||||
result_percents = []
|
||||
|
||||
# 결과
|
||||
# for i in range(top_k):
|
||||
# print(f"{i + 1}: {image_paths[indices[0][i]]}, Distance: {distances[0][i]}")
|
||||
for idx, dist in zip(indices[0], distances[0]):
|
||||
log.debug(f"{idx} (거리: {dist:.4f})")
|
||||
|
||||
result_img_paths.append(image_paths[idx])
|
||||
|
||||
if self.index_type == INDEX_TYPE_COSINE:
|
||||
result_percents.append(f"{dist*100:.2f}")
|
||||
else:
|
||||
result_percents.append(f"{((1 - dist)*100):.2f}")
|
||||
|
||||
return inference_times, result_img_paths, result_percents
|
||||
|
||||
def test():
|
||||
"""
|
||||
module test function
|
||||
"""
|
||||
log.info('\nModule: faiss_similarity_search.py')
|
||||
|
||||
# index_file_path = f'{FILE_BASE_PATH}openaiclipvit_clip-vit-base-patch32_index.faiss'
|
||||
model = FEMUsageInfo.openaiclip_vit_l14_cos
|
||||
index_file_path = os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}_{DEFAULT_INDEX_NAME_SUFFIX}')
|
||||
|
||||
cm = VectorSimilarity(model_info=model,
|
||||
index_file_path=index_file_path,
|
||||
txt_file_path=os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}.txt'))
|
||||
|
||||
# cm.create_feature_extraction_model(FEMUsageInfo.openaiclipvit)
|
||||
|
||||
cm.save_index_from_files(images_path=INDEX_IMAGES_PATH,
|
||||
save_index_dir=FAISS_VECTOR_PATH,
|
||||
save_txt_dir=VECTOR_IMG_LIST_PATH,
|
||||
index_type=model.value[1].index_type)
|
||||
|
||||
|
||||
cm.query_faiss(os.path.join(INDEX_IMAGES_PATH,"img1.png"))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
test module
|
||||
"""
|
||||
test()
|
||||
109
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/feature_extraction_model.py
Executable file
109
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/feature_extraction_model.py
Executable file
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@file : feature_extraction_model.py
|
||||
@author: hsj100
|
||||
@license: A2D2
|
||||
@brief: 이미지 임베딩 모델 관리 정보
|
||||
|
||||
@section Modify History
|
||||
- 2025-06-27 오후 2:07 hsj100 base
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
Basic logger: custom color logger
|
||||
"""
|
||||
try:
|
||||
from custom_logger.main_log import vactor_logger as log
|
||||
except ImportError:
|
||||
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||
|
||||
|
||||
"""
|
||||
Package: standard
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import sys
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Optional
|
||||
|
||||
|
||||
"""
|
||||
Package: 3rd party
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
Package: custom
|
||||
"""
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||
|
||||
|
||||
"""
|
||||
Definition
|
||||
"""
|
||||
|
||||
class FEMArguments(BaseModel):
|
||||
"""
|
||||
특징 추출(이미지 임베딩) 모델 인수 정보
|
||||
NOTE: 모델별 추가내용 필요시, 항목 추가해서 사용
|
||||
"""
|
||||
device:str = Field(default=MODEL_DEVICE_GPU, description=f'FEM 동작 장치[{MODEL_DEVICE_CPU},{MODEL_DEVICE_GPU}]')
|
||||
trained_model:str = Field(default=None, description='학습된 모델(경로 또는 파일명) FEM 모델별로 상이함')
|
||||
token:str = Field(default=None, description='FEM 사용시 필요한 token 정보')
|
||||
index_type:str = Field(default=INDEX_TYPE_L2, description=f'인덱스 타입[{INDEX_TYPE_L2}, {INDEX_TYPE_COSINE}]')
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
def function_name():
|
||||
return sys._getframe(1).f_code.co_name
|
||||
|
||||
|
||||
"""
|
||||
Implementation
|
||||
"""
|
||||
class FeatureExtractionModel(ABC):
|
||||
"""
|
||||
특징 벡터 추출 모델의 추상 인터페이스
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, fem_arguments:FEMArguments):
|
||||
"""
|
||||
모델 초기화 및 메모리 적재
|
||||
:param fem_arguments: 모델 파라미터
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
|
||||
|
||||
@abstractmethod
|
||||
def release_model(self):
|
||||
"""
|
||||
모델 메모리 해제
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
|
||||
|
||||
@abstractmethod
|
||||
def image_embedding(self, image_np):
|
||||
"""
|
||||
이미지 임베딩 - 이미지 데이터(numpy)
|
||||
:param image_np: 이미지 데이터(numpy array)
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
|
||||
|
||||
|
||||
def test():
|
||||
"""
|
||||
module test function
|
||||
"""
|
||||
log.info('\nModule: feature_extraction_model.py')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
test module
|
||||
"""
|
||||
test()
|
||||
200
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/fem_openaiclipvit.py
Executable file
200
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/fem_openaiclipvit.py
Executable file
@@ -0,0 +1,200 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@file : fem_openaiclipvit.py
|
||||
@author: hsj100
|
||||
@license: A2D2
|
||||
@brief: OpenAI Clip Model (Image & Text Embedding)
|
||||
|
||||
@section Modify History
|
||||
- 2025-06-27 오후 2:05 hsj100 base
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
Basic logger: custom color logger
|
||||
"""
|
||||
try:
|
||||
from custom_logger.main_log import vactor_logger as log
|
||||
except ImportError:
|
||||
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||
|
||||
|
||||
"""
|
||||
Package: 3rd party
|
||||
"""
|
||||
|
||||
"""
|
||||
Package: standard
|
||||
"""
|
||||
import torch
|
||||
import os, sys
|
||||
|
||||
from transformers import CLIPProcessor, CLIPModel
|
||||
from huggingface_hub import login as huggingface_login
|
||||
|
||||
"""
|
||||
Package: custom
|
||||
"""
|
||||
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS import feature_extraction_model as FEM
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||
|
||||
|
||||
"""
|
||||
Definition
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
Implementation
|
||||
"""
|
||||
# 특징 벡터 추출 모델 개별 정의
|
||||
class FEOpenAIClipViT(FEM.FeatureExtractionModel):
|
||||
def __init__(self, arguments=FEM.FEMArguments()):
|
||||
"""
|
||||
Constructor
|
||||
:param arguments: parameter of this model
|
||||
"""
|
||||
# model interface
|
||||
self.model = None
|
||||
self.processor = None
|
||||
|
||||
# model parameters
|
||||
self.device = None
|
||||
self.trained_model = None # 학습된 모델 결과물이 있는 경로
|
||||
self.huggingface_token = None # huggingface token
|
||||
|
||||
|
||||
self.fem_arguments = None
|
||||
|
||||
self.load_model(fem_arguments=arguments)
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor
|
||||
:return:
|
||||
"""
|
||||
self.release_model()
|
||||
|
||||
def load_model(self, fem_arguments=None):
|
||||
"""
|
||||
추상 인터페이스 구현: 모델 초기화 및 메모리 적재
|
||||
:param fem_arguments: 모델 인자 정보
|
||||
:return:
|
||||
"""
|
||||
assert fem_arguments is not None, 'failed init_model: fem_arguments is None'
|
||||
|
||||
# model parameters
|
||||
self.device = fem_arguments.device
|
||||
self.trained_model = fem_arguments.trained_model
|
||||
self.huggingface_token = fem_arguments.token
|
||||
|
||||
# device (CPU/GPU)
|
||||
if self.device == FEM.MODEL_DEVICE_GPU:
|
||||
if not torch.cuda.is_available():
|
||||
self.device = FEM.MODEL_DEVICE_CPU
|
||||
log.warning('Not found GPU device, use CPU instead')
|
||||
|
||||
# huggingface token
|
||||
if self.huggingface_token:
|
||||
huggingface_login(fem_arguments.token)
|
||||
|
||||
# model path check
|
||||
if not os.path.exists(fem_arguments.trained_model):
|
||||
self._download_model(fem_arguments.trained_model)
|
||||
|
||||
# model interface
|
||||
self.model = CLIPModel.from_pretrained(fem_arguments.trained_model)
|
||||
self.processor = CLIPProcessor.from_pretrained(fem_arguments.trained_model)
|
||||
|
||||
if self.device == FEM.MODEL_DEVICE_GPU:
|
||||
self.model.to(self.device)
|
||||
|
||||
# inference mode
|
||||
self.model.eval()
|
||||
|
||||
log.info(f'load complete:{self.__class__.__name__}')
|
||||
|
||||
def _download_model(self, pretrained_model_path):
|
||||
"""
|
||||
모델 다운로드 (로컬에 없는 경우)
|
||||
:param pretrained_model: 모델 이름
|
||||
:return:
|
||||
"""
|
||||
MODEL_LIST = [ViTB32, ViTB16, ViTL14, ViTL14_336]
|
||||
|
||||
model_name = os.path.basename(pretrained_model_path)
|
||||
|
||||
if model_name not in MODEL_LIST:
|
||||
raise ValueError(f"Unsupported model name: {model_name}. Supported models are: {MODEL_LIST}")
|
||||
|
||||
model = CLIPModel.from_pretrained(f"openai/{model_name}")
|
||||
processor = CLIPProcessor.from_pretrained(f"openai/{model_name}")
|
||||
|
||||
os.makedirs(pretrained_model_path, exist_ok=True)
|
||||
|
||||
model.save_pretrained(pretrained_model_path)
|
||||
processor.save_pretrained(pretrained_model_path)
|
||||
|
||||
log.info(f'Model downloaded and saved to {pretrained_model_path}')
|
||||
|
||||
def release_model(self):
|
||||
"""
|
||||
추상 인터페이스 구현: 모델 메모리 해제
|
||||
모델 해제
|
||||
:return:
|
||||
"""
|
||||
# GPU 메모리 캐시 비우기 (GPU 메모리 해제에 도움)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
log.info(f'release complete:{self.__class__.__name__}')
|
||||
|
||||
def image_embedding(self, image_np):
|
||||
"""
|
||||
추상 인터페이스 구현 : 이미지 임베딩 (특징 벡터 추출)
|
||||
:param image_np: 이미지 데이터 (numpy)
|
||||
:return: 특징 벡터 (numpy)
|
||||
"""
|
||||
result = None
|
||||
# 이미지 로딩 및 전처리
|
||||
# image_np = Image.open(CLIPModel.TEST_IMAGE_PATH if image_path is None else image_path).convert('RGB')
|
||||
inputs_tensor = self.processor(images=image_np, return_tensors='pt')
|
||||
|
||||
if self.device == FEM.MODEL_DEVICE_GPU:
|
||||
inputs_tensor = inputs_tensor.to(self.device)
|
||||
|
||||
# 특징 벡터 추출
|
||||
with torch.no_grad():
|
||||
outputs = self.model.get_image_features(**inputs_tensor) # shape: (1, 512)
|
||||
features = outputs / outputs.norm(p=2, dim=-1, keepdim=True) # 임베딩 정규화 L2
|
||||
|
||||
# log.debug(f'feature vector shape: {features.shape}')
|
||||
|
||||
if self.device == FEM.MODEL_DEVICE_GPU:
|
||||
result = features.cpu().numpy()
|
||||
else:
|
||||
result = features.numpy()
|
||||
|
||||
# log.debug(f'embedding complete:{self.__class__.__name__}')
|
||||
return result
|
||||
|
||||
def test():
|
||||
"""
|
||||
module test function
|
||||
"""
|
||||
log.info('\nModule: fem_openaiclipvit.py')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
test module
|
||||
"""
|
||||
test()
|
||||
130
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/report_utils.py
Executable file
130
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/report_utils.py
Executable file
@@ -0,0 +1,130 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Basic logger: custom color logger
|
||||
"""
|
||||
try:
|
||||
from custom_logger.main_log import vactor_logger as log
|
||||
except ImportError:
|
||||
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||
|
||||
"""
|
||||
Package: standard
|
||||
"""
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
|
||||
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
"""
|
||||
Package: 3rd party
|
||||
"""
|
||||
|
||||
"""
|
||||
Package: custom
|
||||
"""
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import FEMUsageInfo
|
||||
|
||||
"""
|
||||
Definition
|
||||
"""
|
||||
@dataclass
|
||||
class ReportInfo:
|
||||
feature_extraction_model : str
|
||||
model_architecture : str
|
||||
train_model : str
|
||||
feature_extraction_time : str
|
||||
|
||||
query_image_path : str
|
||||
result_image_paths: list
|
||||
result_percents: list
|
||||
similarity_type: str
|
||||
|
||||
|
||||
"""
|
||||
Implementation
|
||||
"""
|
||||
|
||||
def table_setting(report_info:ReportInfo, result_image_path):
|
||||
"""
|
||||
결과 정보를 테이블 형태로 시각화하여 이미지로 저장하는 함수
|
||||
"""
|
||||
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(1920/100, 1080/100), dpi=100) # 2행 1열 서브플롯
|
||||
|
||||
##====================info table (상단 테이블)
|
||||
_info_table_row = [ReportInfoConst.feature_extraction_model,
|
||||
ReportInfoConst.model_architecture,
|
||||
ReportInfoConst.train_model,
|
||||
ReportInfoConst.feature_extraction_time,
|
||||
ReportInfoConst.similarity_type]
|
||||
_info_table_data = [[report_info.feature_extraction_model],
|
||||
[report_info.model_architecture],
|
||||
[report_info.train_model],
|
||||
[report_info.feature_extraction_time],
|
||||
[report_info.similarity_type]]
|
||||
ax1.set_axis_off() # 축 숨기기
|
||||
info_table = ax1.table(cellText=_info_table_data,
|
||||
rowLabels=_info_table_row,
|
||||
loc='center',
|
||||
cellLoc='left',
|
||||
bbox=[0,0,0.8,1])
|
||||
# info_table.auto_set_column_width(True) # 열 너비를 내용에 맞게 자동 조절
|
||||
# info_table.scale(0.8, 1.0)
|
||||
info_table.set_fontsize(20) # 폰트
|
||||
|
||||
for (row, col), cell in info_table.get_celld().items():
|
||||
cell.set_height(0.15)
|
||||
##================================
|
||||
|
||||
##====================result table (하단 테이블)
|
||||
_result_table_col = [ReportInfoConst.query, ReportInfoConst.result, ReportInfoConst.result, ReportInfoConst.result, ReportInfoConst.result]
|
||||
_result_table_data = [['','','','',''],
|
||||
['Similarity Rate(%)',
|
||||
report_info.result_percents[0],
|
||||
report_info.result_percents[1],
|
||||
report_info.result_percents[2],
|
||||
report_info.result_percents[3]]]
|
||||
|
||||
ax2.set_axis_off() # 축 숨기기
|
||||
result_table = ax2.table(cellText=_result_table_data,
|
||||
colLabels=_result_table_col,
|
||||
loc='center',
|
||||
cellLoc='center',
|
||||
bbox=[-0.23,0,1.2,0.8])
|
||||
result_table.set_fontsize(20) # 폰트
|
||||
|
||||
for (row, col), cell in result_table.get_celld().items():
|
||||
if row == 1:
|
||||
cell.set_height(0.5)
|
||||
else:
|
||||
cell.set_height(0.15)
|
||||
|
||||
# 이미지 삽입 - 위치 보정하여 셀 가운데 정렬
|
||||
image_paths = [report_info.query_image_path] + report_info.result_image_paths[:4]
|
||||
target_size = (240, 180)
|
||||
for col, img_path in enumerate(image_paths):
|
||||
try:
|
||||
img = Image.open(img_path).convert("RGB").resize(target_size)
|
||||
img_array = np.asarray(img)
|
||||
|
||||
imagebox = OffsetImage(img_array, zoom=1.0)
|
||||
position_x = -0.13 + 0.24 * col # 열 중심 (0.1, 0.3, 0.5, 0.7, 0.9)
|
||||
position_y = 0.4 # 테이블 세로 가운데 살짝 위쪽
|
||||
|
||||
ab = AnnotationBbox(imagebox,
|
||||
(position_x, position_y),
|
||||
frameon=False,
|
||||
box_alignment=(0.4, 0.5),
|
||||
xycoords='axes fraction')
|
||||
ax2.add_artist(ab)
|
||||
except Exception as e:
|
||||
print(f"[이미지 로딩 실패] {img_path}: {e}")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(result_image_path)
|
||||
@@ -6,8 +6,9 @@ from bingart import BingArt
|
||||
from custom_apps.utils import cookie_manager
|
||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from const import REMOTE_FOLDER, TEMP_FOLDER
|
||||
from const import TEMP_FOLDER
|
||||
from utils.custom_sftp import sftp_client
|
||||
from config import rest_config
|
||||
|
||||
|
||||
class BingArtGenerator:
|
||||
@@ -15,7 +16,7 @@ class BingArtGenerator:
|
||||
model = 'dalle3'
|
||||
detail = 'art'
|
||||
output_folder = os.path.join(TEMP_FOLDER,"dalle","art")
|
||||
remote_folder = os.path.join(REMOTE_FOLDER,"dalle","art")
|
||||
remote_folder = os.path.join(rest_config.remote_folder,"dalle","art")
|
||||
|
||||
def __init__(self):
|
||||
self.bing_art = BingArt(auth_cookie_U=cookie_manager.get_cookie())
|
||||
|
||||
@@ -4,8 +4,8 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from custom_apps.bingimagecreator.BingImageCreator import *
|
||||
from const import REMOTE_FOLDER
|
||||
from custom_apps.utils import cookie_manager
|
||||
from config import rest_config
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -24,7 +24,7 @@ class DallEArgument:
|
||||
U = cookie_manager.get_cookie()
|
||||
prompt: str
|
||||
cookie_file: str|None = None
|
||||
output_dir: str = os.path.join(REMOTE_FOLDER,"dalle","img")
|
||||
output_dir: str = os.path.join(rest_config.remote_folder,"dalle","img")
|
||||
download_count: int = 1
|
||||
debug_file: str|None = None
|
||||
quiet: bool = False
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from custom_apps.faiss.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
|
||||
from custom_apps.faiss.const import *
|
||||
from custom_apps.faiss_imagenet.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
|
||||
from custom_apps.faiss_imagenet.const import *
|
||||
# from rest.app.utils.date_utils import D
|
||||
# from const import OUTPUT_FOLDER
|
||||
|
||||
@@ -4,10 +4,12 @@ import vertexai
|
||||
import os
|
||||
from vertexai.preview.vision_models import ImageGenerationModel
|
||||
|
||||
from const import REMOTE_FOLDER, TEMP_FOLDER
|
||||
from const import TEMP_FOLDER
|
||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from utils.custom_sftp import sftp_client
|
||||
from config import rest_config
|
||||
|
||||
|
||||
class ImagenConst:
|
||||
project_id = "glasses-imagen"
|
||||
@@ -22,7 +24,7 @@ def imagen_generate_image(prompt,download_count=1):
|
||||
|
||||
_file_name = prompt_to_filenames(prompt)
|
||||
_folder = os.path.join(TEMP_FOLDER)
|
||||
_remote_folder = os.path.join(REMOTE_FOLDER,"imagen")
|
||||
_remote_folder = os.path.join(rest_config.remote_folder,"imagen")
|
||||
_datetime = D.date_file_name()
|
||||
|
||||
# if not os.path.isdir(_folder):
|
||||
@@ -88,6 +90,19 @@ def imagen_generate_image_path(image_prompt):
|
||||
|
||||
return os.path.join(folder_name,f"query.png")
|
||||
|
||||
def imagen_generate_temp_image_path(image_prompt):
|
||||
|
||||
create_time = D.date_file_name()
|
||||
|
||||
if not os.path.exists(TEMP_FOLDER):
|
||||
os.makedirs(TEMP_FOLDER)
|
||||
|
||||
generate_img = imagen_generate_image_data(image_prompt)
|
||||
|
||||
generate_img.save(os.path.join(TEMP_FOLDER,f"query_{create_time}.png"))
|
||||
|
||||
return os.path.join(TEMP_FOLDER,f"query_{create_time}.png")
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
# imagen_generate_image_data("cat")
|
||||
@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# When service starts.
|
||||
LOG.info("REST start")
|
||||
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -570,6 +570,19 @@ class IndexType:
|
||||
hnsw = "hnsw"
|
||||
l2 = "l2"
|
||||
|
||||
|
||||
class VitIndexType:
|
||||
cos = "cos"
|
||||
l2 = "l2"
|
||||
|
||||
|
||||
class VitModelType:
|
||||
b32 = "b32"
|
||||
b16 = "b16"
|
||||
l14 = "l14"
|
||||
l14_336 = "l14_336"
|
||||
|
||||
|
||||
class ImageGenerateReq(BaseModel):
|
||||
"""
|
||||
### [Request] image generate request
|
||||
@@ -590,10 +603,29 @@ class VactorImageSearchReq(BaseModel):
|
||||
### [Request] vactor image search request
|
||||
"""
|
||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||
index_type : str = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
|
||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
indexType : str = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
|
||||
class VactorImageSearchVitReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor image search vit
|
||||
"""
|
||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||
modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
|
||||
indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
|
||||
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
|
||||
class VactorImageSearchVitReportReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor image search vit request
|
||||
"""
|
||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||
modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
|
||||
indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
|
||||
|
||||
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
|
||||
@@ -21,11 +21,11 @@ from custom_logger.main_log import main_logger as LOG
|
||||
|
||||
from custom_apps.bingimagecreator.utils import DallEArgument,dalle3_generate_image
|
||||
from custom_apps.bingart.bingart import BingArtGenerator
|
||||
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path
|
||||
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path, imagen_generate_temp_image_path
|
||||
from main_rest.app.utils.parsing_utils import download_range
|
||||
from custom_apps.utils import cookie_manager
|
||||
from utils.custom_sftp import sftp_client
|
||||
from const import REMOTE_FOLDER, TEMP_FOLDER
|
||||
from config import rest_config
|
||||
|
||||
router = APIRouter(prefix="/services")
|
||||
|
||||
@@ -92,7 +92,6 @@ async def bing_img_generate(request: Request, request_body_info: M.ImageGenerate
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
|
||||
|
||||
@router.post("/imageGenerate/bingart", summary="이미지 생성(AI) - bing art (DALL-E 3)", response_model=M.ImageGenerateRes)
|
||||
async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
"""
|
||||
@@ -122,7 +121,6 @@ async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
|
||||
|
||||
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes)
|
||||
async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
"""
|
||||
@@ -150,9 +148,8 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(error=e)
|
||||
|
||||
|
||||
@router.post("/vactorImageSearch/imageGenerate/imagen", summary="벡터 이미지 검색 - imagen", response_model=M.ResponseBase)
|
||||
async def vactor_image(request: Request, request_body_info: M.VactorImageSearchReq):
|
||||
@router.post("/vactorImageSearch/imagenet/imageGenerate/imagen", summary="벡터 이미지 검색(imagenet) - imagen", response_model=M.ResponseBase)
|
||||
async def vactor_imagenet(request: Request, request_body_info: M.VactorImageSearchReq):
|
||||
"""
|
||||
## 벡터 이미지 검색 - imagen
|
||||
> imagen AI를 이용하여 이미지 생성 후 vactor 검색
|
||||
@@ -164,13 +161,15 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if request_body_info.index_type not in [M.IndexType.hnsw, M.IndexType.l2]:
|
||||
raise Exception(f"index_type is hnsw or l2 (current value = {request_body_info.index_type})")
|
||||
if request_body_info.indexType not in [M.IndexType.hnsw, M.IndexType.l2]:
|
||||
raise Exception(f"indexType is hnsw or l2 (current value = {request_body_info.indexType})")
|
||||
|
||||
img_path = imagen_generate_image_path(image_prompt=request_body_info.prompt)
|
||||
|
||||
vactor_request_data = {'quary_image_path' : img_path,'index_type' : request_body_info.index_type, 'search_num' : request_body_info.search_num}
|
||||
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search', data=json.dumps(vactor_request_data))
|
||||
vactor_request_data = {'query_image_path' : img_path,
|
||||
'index_type' : request_body_info.indexType,
|
||||
'search_num' : request_body_info.searchNum}
|
||||
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/imagenet', data=json.dumps(vactor_request_data))
|
||||
|
||||
if vactor_response.status_code != 200:
|
||||
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
|
||||
@@ -183,14 +182,157 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
|
||||
_base_bame = os.path.basename(_directory_path)
|
||||
|
||||
# remote 폴더 생성
|
||||
sftp_client.remote_mkdir(os.path.join(REMOTE_FOLDER, _base_bame))
|
||||
sftp_client.remote_mkdir(os.path.join(rest_config.remote_folder, _base_bame))
|
||||
|
||||
# remote 폴더에 이미지 저장
|
||||
for i in os.listdir(_directory_path):
|
||||
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(REMOTE_FOLDER, _base_bame, i))
|
||||
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(rest_config.remote_folder, _base_bame, i))
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(error=e)
|
||||
|
||||
@router.post("/vactorImageSearch/vit/imageGenerate/imagen", summary="벡터 이미지 검색(clip-vit) - imagen", response_model=M.ResponseBase)
|
||||
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReq):
|
||||
"""
|
||||
## 벡터 이미지 검색(clip-vit) - imagen
|
||||
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 결과 이미지 생성
|
||||
|
||||
### Requriements
|
||||
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
|
||||
|
||||
### options
|
||||
> - modelType -> b32,b16,l14,l14_336
|
||||
> - indexType -> l2,cos
|
||||
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if not download_range(request_body_info.searchNum, max=10):
|
||||
raise Exception(f"downloadCound is invalid (current value = {request_body_info.searchNum})")
|
||||
|
||||
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
|
||||
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
|
||||
|
||||
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
|
||||
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
|
||||
|
||||
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
|
||||
|
||||
vector_request_data = {'query_image_path' : query_image_path,
|
||||
'index_type' : request_body_info.indexType,
|
||||
'model_type' : request_body_info.modelType,
|
||||
'search_num' : request_body_info.searchNum}
|
||||
|
||||
vector_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit', data=json.dumps(vector_request_data))
|
||||
|
||||
vector_response_dict = json.loads(vector_response.text)
|
||||
|
||||
if vector_response.status_code != 200:
|
||||
raise Exception(f"response error: {vector_response_dict['error']}")
|
||||
|
||||
if vector_response_dict["error"] != None:
|
||||
raise Exception(f"vactor error: {vector_response_dict['error']}")
|
||||
|
||||
result_image_paths = vector_response_dict.get('img_list').get('result_image_paths')
|
||||
result_percents = vector_response_dict.get('img_list').get('result_percents')
|
||||
|
||||
# 원격지 폴더 생성
|
||||
remote_directory = os.path.join(rest_config.remote_folder, f"imagen_query_{request_body_info.modelType}_{request_body_info.indexType}_{request_body_info.prompt}_{D.date_file_name()}")
|
||||
sftp_client.remote_mkdir(remote_directory)
|
||||
|
||||
# 원격지에 이미지 저장
|
||||
sftp_client.remote_copy_data(local_path=query_image_path, remote_path=os.path.join(remote_directory,"query.png"))
|
||||
|
||||
for img_path, img_percent in zip(result_image_paths,result_percents):
|
||||
sftp_client.remote_copy_data(local_path=img_path, remote_path=os.path.join(remote_directory,f"search_{img_percent}.png"))
|
||||
|
||||
# Clean up temporary files
|
||||
if 'query_image_path' in locals():
|
||||
if os.path.exists(query_image_path):
|
||||
os.remove(query_image_path)
|
||||
del query_image_path
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
|
||||
# Clean up temporary files
|
||||
if 'query_image_path' in locals():
|
||||
if os.path.exists(query_image_path):
|
||||
os.remove(query_image_path)
|
||||
del query_image_path
|
||||
|
||||
return response.set_error(error=e)
|
||||
|
||||
@router.post("/vactorImageSearch/vit/imageGenerate/imagen/report", summary="벡터 이미지 검색(clip-vit) - imagen, report 생성", response_model=M.ResponseBase)
|
||||
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReportReq):
|
||||
"""
|
||||
## 벡터 이미지 검색(clip-vit) - imagen, report 생성
|
||||
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 종합결과 이미지 생성
|
||||
|
||||
### Requriements
|
||||
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
|
||||
|
||||
### options
|
||||
> - modelType -> b32,b16,l14,l14_336
|
||||
> - indexType -> l2,cos
|
||||
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
|
||||
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
|
||||
|
||||
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
|
||||
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
|
||||
|
||||
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
|
||||
report_image_path = f"{os.path.splitext(query_image_path)[0]}_report.png"
|
||||
|
||||
vactor_request_data = {'query_image_path' : query_image_path,
|
||||
'index_type' : request_body_info.indexType,
|
||||
'model_type' : request_body_info.modelType,
|
||||
'report_path' : report_image_path}
|
||||
|
||||
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit/report', data=json.dumps(vactor_request_data))
|
||||
|
||||
if vactor_response.status_code != 200:
|
||||
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
|
||||
|
||||
if json.loads(vactor_response.text)["error"] != None:
|
||||
raise Exception(f"vactor error: {json.loads(vactor_response.text)['error']}")
|
||||
|
||||
# remote 폴더에 이미지 저장
|
||||
sftp_client.remote_copy_data(local_path=report_image_path,
|
||||
remote_path=os.path.join(rest_config.remote_folder, f"imagen_report_vit_{request_body_info.prompt}_{D.date_file_name()}.png"))
|
||||
|
||||
# Clean up temporary files
|
||||
if 'query_image_path' in locals():
|
||||
if os.path.exists(query_image_path):
|
||||
os.remove(query_image_path)
|
||||
del query_image_path
|
||||
if 'report_image_path' in locals():
|
||||
if os.path.exists(report_image_path):
|
||||
os.remove(report_image_path)
|
||||
del report_image_path
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
|
||||
# Clean up temporary files
|
||||
if 'query_image_path' in locals():
|
||||
if os.path.exists(query_image_path):
|
||||
os.remove(query_image_path)
|
||||
del query_image_path
|
||||
if 'report_image_path' in locals():
|
||||
if os.path.exists(report_image_path):
|
||||
os.remove(report_image_path)
|
||||
del report_image_path
|
||||
|
||||
return response.set_error(error=e)
|
||||
@@ -15,4 +15,10 @@ email-validator
|
||||
|
||||
#faiss
|
||||
scikit-learn
|
||||
pillow
|
||||
pillow
|
||||
flax
|
||||
transformers
|
||||
torch
|
||||
tensorflow
|
||||
matplotlib
|
||||
keras==3.10.0
|
||||
@@ -2,10 +2,11 @@ import paramiko
|
||||
|
||||
class CustomSFTPClient():
|
||||
def __init__(self):
|
||||
host = "192.168.200.230"
|
||||
port = 22
|
||||
id = "fermat"
|
||||
pw = "1234"
|
||||
from config import rest_config
|
||||
host = rest_config.sftp_host
|
||||
port = rest_config.sftp_port
|
||||
id = rest_config.sftp_id
|
||||
pw = rest_config.sftp_pw
|
||||
|
||||
self.ssh_client = paramiko.SSHClient()
|
||||
self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
|
||||
@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# When service starts.
|
||||
LOG.info("REST start")
|
||||
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -570,12 +570,62 @@ class IndexType(str, Enum):
|
||||
hnsw = "hnsw"
|
||||
l2 = "l2"
|
||||
|
||||
|
||||
class VitIndexType(str, Enum):
|
||||
cos = "cos"
|
||||
l2 = "l2"
|
||||
|
||||
|
||||
class VitModelType(str, Enum):
|
||||
b32 = "b32"
|
||||
b16 = "b16"
|
||||
l14 = "l14"
|
||||
l14_336 = "l14_336"
|
||||
|
||||
|
||||
class VactorSearchReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor 검색
|
||||
"""
|
||||
quary_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : IndexType = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
|
||||
query_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : IndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
|
||||
|
||||
class VactorSearchVitReportReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor 검색(vit) 후 리포트 이미지 생성
|
||||
"""
|
||||
query_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : VitIndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||
model_type : VitModelType = Field(VitModelType.l14, description='pretrained 모델 정보', example=VitModelType.l14)
|
||||
report_path : str = Field(description='리포트 이미지 저장 경로', example='path')
|
||||
|
||||
|
||||
class VactorSearchVitReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor 검색(vit) 후 이미지 생성
|
||||
"""
|
||||
query_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : VitIndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||
model_type : VitModelType = Field(VitModelType.l14, description='pretrained 모델 정보', example=VitModelType.l14)
|
||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
class VactorSearchVitRes(ResponseBase):
|
||||
img_list : dict = Field({}, description='이미지 결과 리스트', example={})
|
||||
|
||||
@staticmethod
|
||||
def set_error(error):
|
||||
VactorSearchVitRes.img_list = {}
|
||||
VactorSearchVitRes.result = False
|
||||
VactorSearchVitRes.error = str(error)
|
||||
|
||||
return VactorSearchVitRes
|
||||
|
||||
@staticmethod
|
||||
def set_message(msg):
|
||||
VactorSearchVitRes.img_list = msg
|
||||
VactorSearchVitRes.result = True
|
||||
VactorSearchVitRes.error = None
|
||||
|
||||
return VactorSearchVitRes
|
||||
|
||||
@@ -19,12 +19,15 @@ from vactor_rest.app import models as M
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
from custom_logger.vactor_log import vactor_logger as LOG
|
||||
|
||||
from custom_apps.faiss.main import search_idxs
|
||||
from custom_apps.faiss_imagenet.main import search_idxs
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_functions import get_clip_info, save_report_image, get_models
|
||||
|
||||
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import FEMUsageInfo
|
||||
|
||||
router = APIRouter(prefix="/services")
|
||||
|
||||
|
||||
@router.post("/faiss/vactor/search", summary="vactor search", response_model=M.ResponseBase)
|
||||
@router.post("/faiss/vactor/search/imagenet", summary="imagenet search", response_model=M.ResponseBase)
|
||||
async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
|
||||
"""
|
||||
## 벡터검색
|
||||
@@ -35,15 +38,55 @@ async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if os.path.exists(request_body_info.quary_image_path):
|
||||
search_idxs(image_path=request_body_info.quary_image_path,
|
||||
if os.path.exists(request_body_info.query_image_path):
|
||||
search_idxs(image_path=request_body_info.query_image_path,
|
||||
index_type=request_body_info.index_type,
|
||||
search_num=request_body_info.search_num)
|
||||
else:
|
||||
raise Exception(f"File {request_body_info.quary_image_path} does not exist.")
|
||||
raise Exception(f"File {request_body_info.query_image_path} does not exist.")
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
|
||||
|
||||
@router.post("/faiss/vactor/search/vit/report", summary="vit search report", response_model=M.ResponseBase)
|
||||
async def vactor_report_vit(request: Request, request_body_info: M.VactorSearchVitReportReq):
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if not os.path.exists(request_body_info.query_image_path):
|
||||
raise FileNotFoundError(f"File {request_body_info.query_image_path} does not exist.")
|
||||
|
||||
model = get_models(index_type=request_body_info.index_type, model_type=request_body_info.model_type)
|
||||
|
||||
report_info = get_clip_info(model,request_body_info.query_image_path)
|
||||
save_report_image(report_info, request_body_info.report_path)
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
|
||||
|
||||
@router.post("/faiss/vactor/search/vit", summary="vit search", response_model=M.VactorSearchVitRes)
|
||||
async def vactor_report_vit(request: Request, request_body_info: M.VactorSearchVitReq):
|
||||
response = M.VactorSearchVitRes()
|
||||
try:
|
||||
if not os.path.exists(request_body_info.query_image_path):
|
||||
raise FileNotFoundError(f"File {request_body_info.query_image_path} does not exist.")
|
||||
|
||||
model = get_models(index_type=request_body_info.index_type, model_type=request_body_info.model_type)
|
||||
|
||||
report_info = get_clip_info(model,request_body_info.query_image_path,top_k=request_body_info.search_num)
|
||||
|
||||
return response.set_message({
|
||||
'result_image_paths': report_info.result_image_paths,
|
||||
'result_percents': report_info.result_percents
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
Reference in New Issue
Block a user