edit : clip-vit 모델 추가

This commit is contained in:
2025-07-08 17:20:32 +09:00
parent 4c2ea70289
commit 309a91bda6
24 changed files with 1395 additions and 40 deletions

6
.gitignore vendored
View File

@@ -145,4 +145,8 @@ log
*output
*temp
datas/eyewear_all/*
datas/eyewear_all/*
datas/clip*
datas/openai*
*result

27
config.py Normal file
View File

@@ -0,0 +1,27 @@
import os


class Config:
    """Runtime configuration for the REST server (SFTP target + result folder).

    The release ('rel') profile is selected by default; call :meth:`set_dev`
    to switch to the development profile.
    """

    def __init__(self):
        # Release profile is the default.
        self.set_rel()

    def set_rel(self):
        """Select the release (production) profile."""
        self.config = 'rel'
        self.remote_folder = "/home/fermat/STORAGE/01.Projects/A2TEC/K_EYEWEAR/02.ML_DATA/Image_generator_result"
        self.sftp_host = "192.168.200.230"
        self.sftp_port = 22
        # SECURITY NOTE(review): credentials are hard-coded in source control;
        # move them to environment variables or a secrets store.
        self.sftp_id = "fermat"
        self.sftp_pw = "1234"

    def set_dev(self):
        """Select the development profile and ensure the local result folder exists."""
        self.config = 'dev'
        self.remote_folder = "/home/fermat/project/FM_TEST_REST_SERVER/result"
        self.sftp_host = "192.168.200.231"
        self.sftp_port = 22
        # SECURITY NOTE(review): hard-coded credentials — see set_rel().
        self.sftp_id = "fermat"
        self.sftp_pw = "fermat3514"
        # Only the dev profile points at a local path, so only here do we
        # create the folder (the rel path lives on a remote mount).
        os.makedirs(self.remote_folder, exist_ok=True)


# Module-level singleton used throughout the REST server.
rest_config = Config()
# rest_config.set_dev()

View File

@@ -1,4 +1,3 @@
REMOTE_FOLDER = "/home/fermat/STORAGE/01.Projects/A2TEC/K_EYEWEAR/02.ML_DATA/Image_generator_result"
TEMP_FOLDER = "./temp"
# Characters that are illegal in (Windows) file names.
# FIX: the list previously contained '\ ' (backslash + space) — the intent is
# clearly the backslash character itself.
ILLEGAL_FILE_NAME = ['<', '>', ':', '"', '/', '\\', '|', '?', '*']

View File

@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
"""
@file : __init__.py
@author: hsj100
@license: A2D2
@brief:
@section Modify History
- 2025-06-27 오후 5:35 hsj100 base
"""

View File

@@ -0,0 +1,40 @@
import os

# Access tokens are created at huggingface.co > Settings > Access Tokens.
# SECURITY NOTE(review): a token literal is kept below only as a fallback — it
# has been committed to source control and should be rotated; prefer setting
# the HUGGINGFACE_TOKEN environment variable.
HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN', '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')
# Device identifiers understood by the embedding models.
MODEL_DEVICE_CPU = 'cpu'
MODEL_DEVICE_GPU = 'cuda'
# File-name suffixes for the persisted FAISS index.
DEFAULT_INDEX_NAME_SUFFIX = 'index.faiss'
DEFAULT_SAVE_INDEX_SUFFIX = DEFAULT_INDEX_NAME_SUFFIX
# FAISS index types.
INDEX_TYPE_L2 = 'l2'
INDEX_TYPE_COSINE = 'cosine'
# ViT model checkpoint names (under the 'openai/' HuggingFace namespace).
ViTB32 = "clip-vit-base-patch32"
ViTB16 = "clip-vit-base-patch16"
ViTL14 = "clip-vit-large-patch14"
ViTL14_336 = "clip-vit-large-patch14-336"
# Download/save paths.
PRETRAINED_MODEL_PATH = "./datas"
FAISS_VECTOR_PATH = "./datas"
VECTOR_IMG_LIST_PATH = "./datas"
INDEX_IMAGES_PATH = "./datas/eyewear_all"
# Labels used when rendering report images.
class ReportInfoConst:
    feature_extraction_model = "Feature Extraction Model"
    model_architecture = "Model Architecture"
    train_model = "Train Model"
    feature_extraction_time = "Feature Extraction Time"
    similarity_type = "Similarity Type"
    query = "Query"
    result = "Result"
    class similarity:
        # Display strings for the supported similarity metrics.
        l2 = "L2(Euclidean Distance)"
        cosine = "Cosine(Inner Product, L2)"

View File

@@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
"""
@file : faiss_functions.py
@author: jwkim
@license: A2D2
@section Modify History
- 2025-06-27 오후 2:05 hsj100 base
"""
"""
Basic logger: custom color logger
"""
try:
from custom_logger.custom_log import custom_logger as log
except ImportError:
import logging as log; log.basicConfig(level=log.DEBUG)
"""
Package: 3rd party
"""
"""
Package: standard
"""
import random
import pathlib
import os
import sys
import traceback
"""
Package: custom
"""
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
if __name__ == '__main__':
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
sys.path.append(parent_dir)
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.report_utils import *
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import VectorSimilarity
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.report_utils import *
"""
Definition
"""
"""
Implementation
"""
def get_models(index_type, model_type):
    """
    Resolve an (index type, model type) pair to its FEMUsageInfo member.

    :param index_type: M.VitIndexType value; 'l2' selects the L2 variants,
        anything else selects the cosine variants (matching the original
        branch structure)
    :param model_type: M.VitModelType value (b32 / b16 / l14 / l14_336)
    :return: matching FEMUsageInfo member
    :raises Exception: when model_type is not recognized
    """
    from vactor_rest.app import models as M
    if index_type == M.VitIndexType.l2:
        # L2 (Euclidean) variants — attributes are looked up lazily, one per
        # branch, exactly as before.
        if model_type == M.VitModelType.l14:
            return FEMUsageInfo.openaiclip_vit_l14_l2
        if model_type == M.VitModelType.b32:
            return FEMUsageInfo.openaiclip_vit_b32_l2
        if model_type == M.VitModelType.b16:
            return FEMUsageInfo.openaiclip_vit_b16_l2
        if model_type == M.VitModelType.l14_336:
            return FEMUsageInfo.openaiclip_vit_l14_336_l2
        raise Exception(f"model_type {model_type} is invalid")
    # any non-L2 index type selects the cosine variants
    if model_type == M.VitModelType.l14:
        return FEMUsageInfo.openaiclip_vit_l14_cos
    if model_type == M.VitModelType.b32:
        return FEMUsageInfo.openaiclip_vit_b32_cos
    if model_type == M.VitModelType.b16:
        return FEMUsageInfo.openaiclip_vit_b16_cos
    if model_type == M.VitModelType.l14_336:
        return FEMUsageInfo.openaiclip_vit_l14_336_cos
    # NOTE(JWKIM): currently only l14 is supported; add members to FEMUsageInfo as needed
    raise Exception(f"model_type {model_type} is invalid")
def save_report_image(report_info, report_image_path):
    """
    Render an image-similarity search result as a report image.

    Thin wrapper around report_utils.table_setting().

    :param report_info: ReportInfo with the query/result data to render
    :param report_image_path: file path where the rendered report is saved
    """
    table_setting(report_info, report_image_path)
def get_clip_info(model, query_image_path, top_k=4):
    """
    Run an image-similarity search and collect the data needed for a report.

    :param model: FEMUsageInfo member ([model class, FEMArguments])
    :param query_image_path: path of the query image
    :param top_k: number of similar images to retrieve
    :return: populated ReportInfo
    """
    fem_args = model.value[1]
    # All index/list files share the same "<member>_<checkpoint>_<indextype>" stem.
    file_stem = f'{model.name}_{os.path.basename(fem_args.trained_model)}_{fem_args.index_type}'
    index_file_path = os.path.join(VECTOR_IMG_LIST_PATH, f'{file_stem}_{DEFAULT_INDEX_NAME_SUFFIX}')
    txt_file_path = os.path.join(VECTOR_IMG_LIST_PATH, f'{file_stem}.txt')
    searcher = VectorSimilarity(model_info=model,
                                index_file_path=index_file_path,
                                txt_file_path=txt_file_path)
    # (Re)build the FAISS index from the reference images, then query it.
    searcher.save_index_from_files(images_path=INDEX_IMAGES_PATH,
                                   save_index_dir=FAISS_VECTOR_PATH,
                                   save_txt_dir=VECTOR_IMG_LIST_PATH,
                                   index_type=fem_args.index_type)
    inference_times, result_img_paths, result_percents = searcher.query_faiss(query_image_path, top_k=top_k)
    return ReportInfo(
        feature_extraction_model=ReportInfoConst.feature_extraction_model,
        model_architecture=ReportInfoConst.model_architecture,
        train_model=os.path.basename(fem_args.trained_model),
        feature_extraction_time=inference_times,
        query_image_path=query_image_path,
        result_image_paths=result_img_paths,
        result_percents=result_percents,
        similarity_type=fem_args.index_type
    )

View File

@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
"""
@file : faiss_similarity_search.py
@author: hsj100
@license: A2D2
@brief: 특징 벡터를 이용한 이미지 유사도 검색
@section Modify History
- 2025-06-19 오후 2:45 hsj100 base
"""
"""
Basic logger: custom color logger
"""
try:
from custom_logger.main_log import vactor_logger as log
except ImportError:
import logging as log; log.basicConfig(level=log.DEBUG)
"""
Package: standard
"""
import sys
import torch
import faiss
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from abc import ABC, abstractmethod
from enum import Enum
from datetime import datetime
"""
Package: 3rd party
"""
"""
Package: custom
"""
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
import os, sys
if __name__ == '__main__':
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
sys.path.append(parent_dir)
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS import feature_extraction_model as FEM
# 사용할 이미지 임베딩 모델 클래스 추가
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.fem_openaiclipvit import FEOpenAIClipViT
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
"""
Definition
"""
# Base path used to resolve relative resources when run as a script.
FILE_BASE_PATH = ''
if __name__ == '__main__':
    # Resolve paths relative to this file's directory for standalone runs.
    FILE_BASE_PATH = os.path.dirname(os.path.realpath(__file__)) + '/'
# TRAINED_MODEL_PATH = FILE_BASE_PATH + 'openai/clip-resnet50' # OLD VERSION
# TRAINED_MODEL_PATH = FILE_BASE_PATH + 'openai/clip-vit-base-patch32' # huggingface repository location
TRAINED_MODEL_PATH = "./data"  # local download directory for checkpoints
DEFAULT_INDEX_IMAGE_PATH = FILE_BASE_PATH + 'index_images'
class FEMUsageInfo(Enum):
    """
    Registry of feature-extraction (image-embedding) model configurations.

    Add an entry here when introducing a new embedding model; each value is
    ``[model class, FEMArguments]``.
    * Renaming a member changes the generated index/txt file names.
    """
    openaiclip_vit_b32_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
        trained_model=os.path.join(PRETRAINED_MODEL_PATH, ViTB32),
        token=HUGGINGFACE_TOKEN,
        index_type=INDEX_TYPE_L2)]
    openaiclip_vit_b32_cos = [FEOpenAIClipViT, FEM.FEMArguments(
        trained_model=os.path.join(PRETRAINED_MODEL_PATH, ViTB32),
        token=HUGGINGFACE_TOKEN,
        index_type=INDEX_TYPE_COSINE)]
    openaiclip_vit_b16_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
        trained_model=os.path.join(PRETRAINED_MODEL_PATH, ViTB16),
        token=HUGGINGFACE_TOKEN,
        index_type=INDEX_TYPE_L2)]
    openaiclip_vit_b16_cos = [FEOpenAIClipViT, FEM.FEMArguments(
        trained_model=os.path.join(PRETRAINED_MODEL_PATH, ViTB16),
        token=HUGGINGFACE_TOKEN,
        index_type=INDEX_TYPE_COSINE)]
    openaiclip_vit_l14_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
        trained_model=os.path.join(PRETRAINED_MODEL_PATH, ViTL14),
        token=HUGGINGFACE_TOKEN,
        index_type=INDEX_TYPE_L2)]
    openaiclip_vit_l14_cos = [FEOpenAIClipViT, FEM.FEMArguments(
        trained_model=os.path.join(PRETRAINED_MODEL_PATH, ViTL14),
        token=HUGGINGFACE_TOKEN,
        index_type=INDEX_TYPE_COSINE)]
    # FIX: these two members were misnamed 'openaiclip_vit_b14_336_*' although
    # they wrap the ViT-L/14-336 checkpoint; get_models() in faiss_functions.py
    # looks up the 'l14_336' spelling, which raised AttributeError.
    openaiclip_vit_l14_336_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
        trained_model=os.path.join(PRETRAINED_MODEL_PATH, ViTL14_336),
        token=HUGGINGFACE_TOKEN,
        index_type=INDEX_TYPE_L2)]
    openaiclip_vit_l14_336_cos = [FEOpenAIClipViT, FEM.FEMArguments(
        trained_model=os.path.join(PRETRAINED_MODEL_PATH, ViTL14_336),
        token=HUGGINGFACE_TOKEN,
        index_type=INDEX_TYPE_COSINE)]
    # NOTE(review): both placeholders share the value None, so Enum makes
    # 'mobilenetv2' an alias of 'openaicliprn'; give them distinct values
    # when these models are actually implemented.
    openaicliprn = None
    mobilenetv2 = None
# NOTE(review): assumed to be module-level constants (not enum members), as in
# the original layout — confirm.
DEFAULT_WEIGHT_PATH = FILE_BASE_PATH + TRAINED_MODEL_PATH
TEST_IMAGES_PATH = FILE_BASE_PATH + 'testimages'
"""
Implementation
"""
# log off
import logging
logging.getLogger('PIL.PngImagePlugin').disabled = True
# Utility Method
def function_name():
    """Return the name of the calling function (used in error messages)."""
    caller_frame = sys._getframe(1)
    return caller_frame.f_code.co_name
class FeatureExtractionModel(ABC):
    """
    Abstract interface for feature-vector (image-embedding) extraction models.

    NOTE(review): this appears to duplicate the interface declared in
    feature_extraction_model.py (available here as FEM.FeatureExtractionModel,
    which VectorSimilarity actually uses) — consider removing one of the two.
    """
    @abstractmethod
    def init_model(self, pretrained_model):
        """
        Load the model and initialize it.
        :param pretrained_model: path or identifier of the pretrained model
        :return:
        """
        raise NotImplementedError(f'Subclasses must implement this method{function_name()}')
    @abstractmethod
    def image_embedding(self, image_np):
        """
        Embed an image given as raw data (numpy array).
        :param image_np: image data (numpy array)
        :return:
        """
        raise NotImplementedError(f'Subclasses must implement this method{function_name()}')
class VectorSimilarity:
    """
    Vector-based similarity search built on Facebook FAISS.

    Embeds reference images with the configured feature-extraction model (FEM),
    persists the vectors in a FAISS index file plus a sidecar text file that
    lists the indexed image paths (in index order), and answers top-k
    nearest-neighbour queries.
    """
    def __init__(self, model_info:FEMUsageInfo=None, index_file_path=None, txt_file_path=None):
        """
        :param model_info: FEMUsageInfo member ([model class, FEMArguments])
        :param index_file_path: FAISS index file to read/write
        :param txt_file_path: text file listing the indexed image paths
        """
        self.fem_model:FEM.FeatureExtractionModel = None
        self.model_info = model_info
        self.fem_model_name = ''
        self.fem_model_args = None
        self.index_file_path = index_file_path
        self.txt_file_path = txt_file_path
        # FIX: previously this attribute access ran unconditionally and raised
        # AttributeError when model_info was None (the declared default).
        self.index_type = model_info.value[1].index_type if model_info is not None else None
        if model_info is not None:
            self.create_feature_extraction_model(model_info)
    def create_feature_extraction_model(self, model_info):
        """
        Initialize the feature-extraction model described by model_info.
        :param model_info: FEMUsageInfo member
        :return: True on success / False on invalid arguments
        """
        # validate FEMUsageInfo
        if isinstance(model_info, FEMUsageInfo):
            if model_info not in FEMUsageInfo or model_info.value is None:
                log.error('invalid FEMArguments')
                return False
        else:
            log.error('invalid FEMArguments')
            return False
        if model_info.value[0] is None or model_info.value[1] is None:
            log.error('invalid FEMArguments')
            return False
        # release any previously created model
        if self.fem_model is not None:
            del self.fem_model
        # instantiate the model class with its arguments
        self.fem_model_name = model_info.name
        self.fem_model_args = model_info.value[1]
        self.fem_model = model_info.value[0](self.fem_model_args)
        log.info(f'created fem-model: {model_info.value[0]}')
        # FIX: the docstring promised True/False but the success path
        # implicitly returned None (falsy); return True explicitly.
        return True
    def image_embedding_from_file(self, image_file):
        """
        Extract a feature vector from an image file.
        :param image_file: path of the image file
        :return: feature vector (numpy) or None when the file is missing
        """
        if not os.path.isfile(image_file):
            log.error(f'not found file[{image_file}]')
            return None
        # NOTE: the FEM accepts a PIL image despite the image_np parameter
        # name on the interface.
        image = Image.open(image_file).convert('RGB')
        return self.fem_model.image_embedding(image)
    def image_embedding_from_data(self, image_data_np=None):
        """
        Extract a feature vector from in-memory image data (numpy).
        :param image_data_np: image data (numpy array)
        :return: feature vector (numpy) or None for invalid input
        """
        if image_data_np is None:
            log.error(f'invalid data[{image_data_np}]')
            return None
        return self.fem_model.image_embedding(image_data_np)
    def save_index_from_files(self, images_path=None, save_index_dir=None, save_txt_dir=None, index_type=INDEX_TYPE_L2):
        """
        Build a FAISS index from every image in a folder and persist it.
        :param images_path: folder containing the reference images
        :param save_index_dir: directory for the index file (None -> FILE_BASE_PATH)
        :param save_txt_dir: directory for the image-path list file (None -> FILE_BASE_PATH)
        :param index_type: INDEX_TYPE_L2 or INDEX_TYPE_COSINE
        :return: shape of the stacked embedding matrix, or None on failure
        """
        assert self.fem_model is not None, 'invalid fem_model'
        assert images_path is not None, 'images_path is required'
        if not os.path.exists(images_path):
            log.error(f'image_path={images_path} does not exist')
        image_folder = images_path
        result = None
        # collect embeddings for every image in the folder
        if os.path.exists(image_folder):
            image_paths = [os.path.join(image_folder, image_file) for image_file in os.listdir(image_folder)]
            image_vectors = np.vstack([self.image_embedding_from_file(image_file) for image_file in image_paths])
            log.debug(f'image_vectors.shape={image_vectors.shape}')
            result = image_vectors.shape
            # build the index
            dimension = image_vectors.shape[1]
            if index_type == INDEX_TYPE_L2:
                index = faiss.IndexFlatL2(dimension)  # L2 (Euclidean)
            elif index_type == INDEX_TYPE_COSINE:
                # inner product over L2-normalized vectors == cosine similarity
                index = faiss.IndexFlatIP(dimension)
                faiss.normalize_L2(image_vectors)
            else:
                raise Exception(f'Invalid index_type={index_type}')
            index.add(image_vectors)
            # resolve the index file path
            if save_index_dir is None:
                save_index_path = os.path.join(FILE_BASE_PATH,f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}_{DEFAULT_SAVE_INDEX_SUFFIX}')
                self.index_file_path = save_index_path
            else:
                save_index_path = os.path.join(save_index_dir,f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}_{DEFAULT_SAVE_INDEX_SUFFIX}')
                os.makedirs(save_index_dir, exist_ok=True)
            # resolve the image-path list file path
            if save_txt_dir is None:
                save_txt_path = os.path.join(FILE_BASE_PATH, f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}.txt')
                self.txt_file_path = save_txt_path
            else:
                save_txt_path = os.path.join(save_txt_dir, f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}.txt')
                os.makedirs(save_txt_dir, exist_ok=True)
            # FIX: the index write used to be gated on the TXT file's
            # existence, so a pre-existing txt file silently skipped writing
            # the index file; gate each artifact on its own path.
            if not os.path.exists(save_index_path):
                faiss.write_index(index, save_index_path)
                log.debug(f'save_index_path={save_index_path}')
            if not os.path.exists(save_txt_path):
                # persist the ordered list of indexed image paths
                with open(save_txt_path, 'w') as f:
                    for path in image_paths:
                        f.write(f"{path}\n")
                log.debug(f'save_txt_path={save_txt_path}')
        else:
            log.error(f'Image folder {image_folder} does not exist')
        return result
    def set_index_path(self, index_path):
        """Override the FAISS index file path used by load_index()."""
        assert index_path is not None, 'index_path is required'
        self.index_file_path = index_path
    def load_index(self, index_file_name=None):
        """
        Load the FAISS index from disk.
        :param index_file_name: optional override of the stored index path
        :return: the loaded faiss index
        """
        if index_file_name is not None:
            self.index_file_path = index_file_name
        return faiss.read_index(self.index_file_path)
    def time_diff(self, start_time, end_time):
        """Format the difference of two datetimes as 'MM:SS.microseconds'."""
        time_difference = end_time - start_time
        total_seconds_int = int(time_difference.total_seconds())
        minutes = total_seconds_int // 60
        seconds = total_seconds_int % 60
        microseconds = time_difference.microseconds  # remaining microseconds
        return f"{minutes:02d}:{seconds:02d}.{microseconds:06d}"
    def query_faiss(self, query_image=None, top_k=4):
        """
        Find the top_k most similar indexed images for a query image.
        :param query_image: path of the query image (required)
        :param top_k: number of results to return
        :return: (timing summary string, result image paths, similarity percents)
        """
        assert query_image is not None, 'query_image is required'
        if os.path.exists(self.txt_file_path):
            with open(self.txt_file_path, 'r') as f:
                image_paths = [line.strip() for line in f.readlines()]
        else:
            # FIX: was logging through the root `logging` module; use the
            # module logger for consistency.
            log.error("Image path list TXT file not found.")
            image_paths = []
        start_vector_time = datetime.now()
        index = self.load_index(self.index_file_path)
        query_vector = self.image_embedding_from_file(query_image)
        end_vector_time = datetime.now()
        diff_vector_time = self.time_diff(start_vector_time, end_vector_time)
        if self.index_type == INDEX_TYPE_COSINE:
            # the cosine index stores normalized vectors; normalize the query too
            faiss.normalize_L2(query_vector)
        start_search_time = datetime.now()
        distances, indices = index.search(query_vector, top_k)
        end_search_time = datetime.now()
        diff_search_time = self.time_diff(start_search_time, end_search_time)
        diff_total_time = self.time_diff(start_vector_time, end_search_time)
        inference_times = f"Total time - {diff_total_time}, vector_time - {diff_vector_time}, search_time - {diff_search_time}"
        result_img_paths = []
        result_percents = []
        for idx, dist in zip(indices[0], distances[0]):
            log.debug(f"{idx} (거리: {dist:.4f})")
            result_img_paths.append(image_paths[idx])
            if self.index_type == INDEX_TYPE_COSINE:
                # inner product of unit vectors is the cosine similarity in [0, 1]
                result_percents.append(f"{dist*100:.2f}")
            else:
                # heuristic: map the (squared) L2 distance onto a percentage
                result_percents.append(f"{((1 - dist)*100):.2f}")
        return inference_times, result_img_paths, result_percents
def test():
    """
    Module self-test: builds an index with the ViT-L/14 cosine model over the
    reference images and runs a sample query against it.
    """
    log.info('\nModule: faiss_similarity_search.py')
    # index_file_path = f'{FILE_BASE_PATH}openaiclipvit_clip-vit-base-patch32_index.faiss'
    model = FEMUsageInfo.openaiclip_vit_l14_cos
    # Index/txt file names follow the "<member>_<checkpoint>_<indextype>" pattern.
    index_file_path = os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}_{DEFAULT_INDEX_NAME_SUFFIX}')
    cm = VectorSimilarity(model_info=model,
                          index_file_path=index_file_path,
                          txt_file_path=os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}.txt'))
    # cm.create_feature_extraction_model(FEMUsageInfo.openaiclipvit)
    cm.save_index_from_files(images_path=INDEX_IMAGES_PATH,
                             save_index_dir=FAISS_VECTOR_PATH,
                             save_txt_dir=VECTOR_IMG_LIST_PATH,
                             index_type=model.value[1].index_type)
    cm.query_faiss(os.path.join(INDEX_IMAGES_PATH,"img1.png"))
if __name__ == '__main__':
    """
    test module
    """
    test()

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""
@file : feature_extraction_model.py
@author: hsj100
@license: A2D2
@brief: 이미지 임베딩 모델 관리 정보
@section Modify History
- 2025-06-27 오후 2:07 hsj100 base
"""
"""
Basic logger: custom color logger
"""
try:
from custom_logger.main_log import vactor_logger as log
except ImportError:
import logging as log; log.basicConfig(level=log.DEBUG)
"""
Package: standard
"""
from abc import ABC, abstractmethod
from enum import Enum
import sys
from pydantic import BaseModel, Field, ConfigDict
from typing import Optional
"""
Package: 3rd party
"""
"""
Package: custom
"""
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
"""
Definition
"""
class FEMArguments(BaseModel):
    """
    Argument bundle for feature-extraction (image-embedding) models.
    NOTE: add fields here as individual models require them.
    """
    # Device the model runs on (MODEL_DEVICE_CPU / MODEL_DEVICE_GPU).
    device:str = Field(default=MODEL_DEVICE_GPU, description=f'FEM 동작 장치[{MODEL_DEVICE_CPU},{MODEL_DEVICE_GPU}]')
    # Trained model path or name; the expected format differs per FEM model.
    trained_model:str = Field(default=None, description='학습된 모델(경로 또는 파일명) FEM 모델별로 상이함')
    # Token required by some FEMs (e.g. a HuggingFace access token).
    token:str = Field(default=None, description='FEM 사용시 필요한 token 정보')
    # FAISS index type (INDEX_TYPE_L2 / INDEX_TYPE_COSINE).
    index_type:str = Field(default=INDEX_TYPE_L2, description=f'인덱스 타입[{INDEX_TYPE_L2}, {INDEX_TYPE_COSINE}]')
    # Allow construction from arbitrary attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)
def function_name():
    """Return the name of the calling function (used in NotImplementedError messages)."""
    calling_frame = sys._getframe(1)
    return calling_frame.f_code.co_name
"""
Implementation
"""
class FeatureExtractionModel(ABC):
    """
    Abstract interface for feature-vector (image-embedding) extraction models.
    """
    @abstractmethod
    def load_model(self, fem_arguments:FEMArguments):
        """
        Initialize the model and load it into memory.
        :param fem_arguments: model parameters
        :return:
        """
        raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
    @abstractmethod
    def release_model(self):
        """
        Release the model's memory.
        :return:
        """
        raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
    @abstractmethod
    def image_embedding(self, image_np):
        """
        Embed an image given as raw data (numpy array).
        :param image_np: image data (numpy array)
        :return:
        """
        raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
def test():
    """
    Module self-test: currently only logs the module name.
    """
    log.info('\nModule: feature_extraction_model.py')
if __name__ == '__main__':
    """
    test module
    """
    test()

View File

@@ -0,0 +1,200 @@
# -*- coding: utf-8 -*-
"""
@file : fem_openaiclipvit.py
@author: hsj100
@license: A2D2
@brief: OpenAI Clip Model (Image & Text Embedding)
@section Modify History
- 2025-06-27 오후 2:05 hsj100 base
"""
"""
Basic logger: custom color logger
"""
try:
from custom_logger.main_log import vactor_logger as log
except ImportError:
import logging as log; log.basicConfig(level=log.DEBUG)
"""
Package: 3rd party
"""
"""
Package: standard
"""
import torch
import os, sys
from transformers import CLIPProcessor, CLIPModel
from huggingface_hub import login as huggingface_login
"""
Package: custom
"""
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
if __name__ == '__main__':
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if parent_dir not in sys.path:
sys.path.append(parent_dir)
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS import feature_extraction_model as FEM
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
"""
Definition
"""
"""
Implementation
"""
# 특징 벡터 추출 모델 개별 정의
class FEOpenAIClipViT(FEM.FeatureExtractionModel):
    """
    OpenAI CLIP ViT image-embedding model (HuggingFace transformers backend).

    Downloads the requested openai/clip-vit-* checkpoint on first use (saving
    it to the configured local path) and produces L2-normalized image feature
    vectors.
    """
    def __init__(self, arguments=None):
        """
        Constructor.
        :param arguments: FEMArguments for this model (None -> defaults)
        """
        # FIX: avoid a shared mutable default instance (the previous
        # `arguments=FEM.FEMArguments()` was evaluated once at class
        # definition time); build a fresh default per call instead.
        if arguments is None:
            arguments = FEM.FEMArguments()
        # model interface
        self.model = None
        self.processor = None
        # model parameters
        self.device = None
        self.trained_model = None       # local path of the trained checkpoint
        self.huggingface_token = None   # huggingface access token
        self.fem_arguments = None
        self.load_model(fem_arguments=arguments)
    def __del__(self):
        """
        Destructor: best-effort resource release.
        """
        # FIX: guard the release — __del__ may run on a partially constructed
        # object (load_model assertion failure) or during interpreter
        # shutdown, where module globals may already be gone.
        try:
            self.release_model()
        except Exception:
            pass
    def load_model(self, fem_arguments=None):
        """
        Abstract-interface impl: initialize the model and load it into memory.
        :param fem_arguments: FEMArguments with device/model/token settings
        :return:
        """
        assert fem_arguments is not None, 'failed init_model: fem_arguments is None'
        # model parameters
        self.fem_arguments = fem_arguments  # FIX: was initialized but never stored
        self.device = fem_arguments.device
        self.trained_model = fem_arguments.trained_model
        self.huggingface_token = fem_arguments.token
        # fall back to CPU when no GPU is available
        if self.device == FEM.MODEL_DEVICE_GPU:
            if not torch.cuda.is_available():
                self.device = FEM.MODEL_DEVICE_CPU
                log.warning('Not found GPU device, use CPU instead')
        # huggingface login (only when a token is configured)
        if self.huggingface_token:
            huggingface_login(self.huggingface_token)
        # download the checkpoint when it is not cached locally
        if not os.path.exists(fem_arguments.trained_model):
            self._download_model(fem_arguments.trained_model)
        # model interface
        self.model = CLIPModel.from_pretrained(fem_arguments.trained_model)
        self.processor = CLIPProcessor.from_pretrained(fem_arguments.trained_model)
        if self.device == FEM.MODEL_DEVICE_GPU:
            self.model.to(self.device)
        # inference mode
        self.model.eval()
        log.info(f'load complete:{self.__class__.__name__}')
    def _download_model(self, pretrained_model_path):
        """
        Download the model from the HuggingFace hub (when missing locally)
        and save it under the given local path.
        :param pretrained_model_path: local path whose basename is the model name
        :raises ValueError: when the model name is not one of the supported ViTs
        """
        MODEL_LIST = [ViTB32, ViTB16, ViTL14, ViTL14_336]
        model_name = os.path.basename(pretrained_model_path)
        if model_name not in MODEL_LIST:
            raise ValueError(f"Unsupported model name: {model_name}. Supported models are: {MODEL_LIST}")
        model = CLIPModel.from_pretrained(f"openai/{model_name}")
        processor = CLIPProcessor.from_pretrained(f"openai/{model_name}")
        os.makedirs(pretrained_model_path, exist_ok=True)
        model.save_pretrained(pretrained_model_path)
        processor.save_pretrained(pretrained_model_path)
        log.info(f'Model downloaded and saved to {pretrained_model_path}')
    def release_model(self):
        """
        Abstract-interface impl: release the model's memory.
        :return:
        """
        # FIX: drop the references first so the tensors can actually be
        # collected before the CUDA cache is cleared.
        self.model = None
        self.processor = None
        # empty the GPU memory cache (helps return GPU memory)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        log.info(f'release complete:{self.__class__.__name__}')
    def image_embedding(self, image_np):
        """
        Abstract-interface impl: image embedding (feature-vector extraction).
        :param image_np: image data (numpy array or PIL image — the processor
            accepts both)
        :return: L2-normalized feature vector (numpy)
        """
        # preprocessing
        inputs_tensor = self.processor(images=image_np, return_tensors='pt')
        if self.device == FEM.MODEL_DEVICE_GPU:
            inputs_tensor = inputs_tensor.to(self.device)
        # feature extraction
        with torch.no_grad():
            outputs = self.model.get_image_features(**inputs_tensor)
        features = outputs / outputs.norm(p=2, dim=-1, keepdim=True)  # L2-normalize
        # FIX: .cpu() is a no-op on CPU tensors, so one path covers both devices.
        return features.cpu().numpy()
def test():
    """
    Module self-test: currently only logs the module name.
    """
    log.info('\nModule: fem_openaiclipvit.py')
if __name__ == '__main__':
    """
    test module
    """
    test()

View File

@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
"""
Basic logger: custom color logger
"""
try:
from custom_logger.main_log import vactor_logger as log
except ImportError:
import logging as log; log.basicConfig(level=log.DEBUG)
"""
Package: standard
"""
import numpy as np
import matplotlib.pyplot as plt
import os
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from PIL import Image
from dataclasses import dataclass
"""
Package: 3rd party
"""
"""
Package: custom
"""
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import FEMUsageInfo
"""
Definition
"""
@dataclass
class ReportInfo:
    """Data bundle rendered into a similarity-search report image."""
    feature_extraction_model : str   # label text for the feature-extraction-model row
    model_architecture : str         # label text for the model-architecture row
    train_model : str                # trained checkpoint name (basename)
    feature_extraction_time : str    # formatted timing summary string
    query_image_path : str           # path of the query image
    result_image_paths: list         # paths of the top-k result images
    result_percents: list            # similarity percentages (strings), same order
    similarity_type: str             # index type used ('l2' or 'cosine')
"""
Implementation
"""
def table_setting(report_info:ReportInfo, result_image_path):
    """
    Render the search-result information as two tables and save it as an image.

    Top table: model/run metadata. Bottom table: the query image plus the
    top-4 result images with their similarity percentages.

    :param report_info: ReportInfo with the data to render
    :param result_image_path: path where the rendered report image is saved
    """
    # NOTE(review): the layout is hard-coded for exactly 4 results; fewer than
    # 4 entries in result_percents raises IndexError — confirm callers always
    # use top_k >= 4.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(1920/100, 1080/100), dpi=100)  # 2 rows x 1 column
    ##====================info table (top)
    _info_table_row = [ReportInfoConst.feature_extraction_model,
                       ReportInfoConst.model_architecture,
                       ReportInfoConst.train_model,
                       ReportInfoConst.feature_extraction_time,
                       ReportInfoConst.similarity_type]
    _info_table_data = [[report_info.feature_extraction_model],
                        [report_info.model_architecture],
                        [report_info.train_model],
                        [report_info.feature_extraction_time],
                        [report_info.similarity_type]]
    ax1.set_axis_off()  # hide the axes
    info_table = ax1.table(cellText=_info_table_data,
                           rowLabels=_info_table_row,
                           loc='center',
                           cellLoc='left',
                           bbox=[0,0,0.8,1])
    info_table.set_fontsize(20)
    for (row, col), cell in info_table.get_celld().items():
        cell.set_height(0.15)
    ##====================result table (bottom)
    _result_table_col = [ReportInfoConst.query, ReportInfoConst.result, ReportInfoConst.result, ReportInfoConst.result, ReportInfoConst.result]
    _result_table_data = [['','','','',''],
                          ['Similarity Rate(%)',
                           report_info.result_percents[0],
                           report_info.result_percents[1],
                           report_info.result_percents[2],
                           report_info.result_percents[3]]]
    ax2.set_axis_off()  # hide the axes
    result_table = ax2.table(cellText=_result_table_data,
                             colLabels=_result_table_col,
                             loc='center',
                             cellLoc='center',
                             bbox=[-0.23,0,1.2,0.8])
    result_table.set_fontsize(20)
    for (row, col), cell in result_table.get_celld().items():
        if row == 1:
            # the image row needs extra height
            cell.set_height(0.5)
        else:
            cell.set_height(0.15)
    # place the query + result thumbnails, centered over their table columns
    image_paths = [report_info.query_image_path] + report_info.result_image_paths[:4]
    target_size = (240, 180)
    for col, img_path in enumerate(image_paths):
        try:
            img = Image.open(img_path).convert("RGB").resize(target_size)
            img_array = np.asarray(img)
            imagebox = OffsetImage(img_array, zoom=1.0)
            position_x = -0.13 + 0.24 * col  # column centers
            position_y = 0.4                 # slightly above vertical center
            ab = AnnotationBbox(imagebox,
                                (position_x, position_y),
                                frameon=False,
                                box_alignment=(0.4, 0.5),
                                xycoords='axes fraction')
            ax2.add_artist(ab)
        except Exception as e:
            # FIX: route the failure through the module logger instead of print().
            log.error(f"[이미지 로딩 실패] {img_path}: {e}")
    plt.tight_layout()
    plt.savefig(result_image_path)
    # FIX: close the figure — otherwise repeated calls leak matplotlib figures.
    plt.close(fig)

View File

@@ -6,8 +6,9 @@ from bingart import BingArt
from custom_apps.utils import cookie_manager
from main_rest.app.utils.parsing_utils import prompt_to_filenames
from main_rest.app.utils.date_utils import D
from const import REMOTE_FOLDER, TEMP_FOLDER
from const import TEMP_FOLDER
from utils.custom_sftp import sftp_client
from config import rest_config
class BingArtGenerator:
@@ -15,7 +16,7 @@ class BingArtGenerator:
model = 'dalle3'
detail = 'art'
output_folder = os.path.join(TEMP_FOLDER,"dalle","art")
remote_folder = os.path.join(REMOTE_FOLDER,"dalle","art")
remote_folder = os.path.join(rest_config.remote_folder,"dalle","art")
def __init__(self):
self.bing_art = BingArt(auth_cookie_U=cookie_manager.get_cookie())

View File

@@ -4,8 +4,8 @@ from dataclasses import dataclass
from pathlib import Path
from custom_apps.bingimagecreator.BingImageCreator import *
from const import REMOTE_FOLDER
from custom_apps.utils import cookie_manager
from config import rest_config
@dataclass
@@ -24,7 +24,7 @@ class DallEArgument:
U = cookie_manager.get_cookie()
prompt: str
cookie_file: str|None = None
output_dir: str = os.path.join(REMOTE_FOLDER,"dalle","img")
output_dir: str = os.path.join(rest_config.remote_folder,"dalle","img")
download_count: int = 1
debug_file: str|None = None
quiet: bool = False

View File

@@ -1,8 +1,8 @@
import os
import shutil
from custom_apps.faiss.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
from custom_apps.faiss.const import *
from custom_apps.faiss_imagenet.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
from custom_apps.faiss_imagenet.const import *
# from rest.app.utils.date_utils import D
# from const import OUTPUT_FOLDER

View File

@@ -4,10 +4,12 @@ import vertexai
import os
from vertexai.preview.vision_models import ImageGenerationModel
from const import REMOTE_FOLDER, TEMP_FOLDER
from const import TEMP_FOLDER
from main_rest.app.utils.parsing_utils import prompt_to_filenames
from main_rest.app.utils.date_utils import D
from utils.custom_sftp import sftp_client
from config import rest_config
class ImagenConst:
project_id = "glasses-imagen"
@@ -22,7 +24,7 @@ def imagen_generate_image(prompt,download_count=1):
_file_name = prompt_to_filenames(prompt)
_folder = os.path.join(TEMP_FOLDER)
_remote_folder = os.path.join(REMOTE_FOLDER,"imagen")
_remote_folder = os.path.join(rest_config.remote_folder,"imagen")
_datetime = D.date_file_name()
# if not os.path.isdir(_folder):
@@ -88,6 +90,19 @@ def imagen_generate_image_path(image_prompt):
return os.path.join(folder_name,f"query.png")
def imagen_generate_temp_image_path(image_prompt):
    """
    Generate an image for the prompt into TEMP_FOLDER and return its path.

    :param image_prompt: text prompt passed to the Imagen model
    :return: path of the saved PNG (timestamped to avoid collisions)
    """
    create_time = D.date_file_name()
    # exist_ok avoids the racy exists()/makedirs() pair.
    os.makedirs(TEMP_FOLDER, exist_ok=True)
    # FIX: build the target path once and reuse it for save + return (it was
    # previously constructed twice).
    image_path = os.path.join(TEMP_FOLDER, f"query_{create_time}.png")
    generate_img = imagen_generate_image_data(image_prompt)
    generate_img.save(image_path)
    return image_path
if __name__ == '__main__':
pass
# imagen_generate_image_data("cat")

View File

@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
@asynccontextmanager
async def lifespan(app: FastAPI):
# When service starts.
LOG.info("REST start")
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
yield

View File

@@ -570,6 +570,19 @@ class IndexType:
hnsw = "hnsw"
l2 = "l2"
class VitIndexType:
    """FAISS index types accepted by the CLIP-ViT search endpoints."""
    cos = "cos"  # cosine similarity (inner product over normalized vectors)
    l2 = "l2"    # Euclidean distance
class VitModelType:
    """Pretrained CLIP-ViT checkpoint variants accepted by the search endpoints."""
    b32 = "b32"          # clip-vit-base-patch32
    b16 = "b16"          # clip-vit-base-patch16
    l14 = "l14"          # clip-vit-large-patch14
    l14_336 = "l14_336"  # clip-vit-large-patch14-336
class ImageGenerateReq(BaseModel):
"""
### [Request] image generate request
@@ -590,10 +603,29 @@ class VactorImageSearchReq(BaseModel):
### [Request] vactor image search request
"""
prompt : str = Field(description='프롬프트', example='검은색 안경')
index_type : str = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
indexType : str = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
class VactorImageSearchVitReq(BaseModel):
    """
    ### [Request] vactor image search vit
    """
    # Text prompt used to generate the query image.
    prompt : str = Field(description='프롬프트', example='검은색 안경')
    # Pretrained CLIP-ViT variant (see VitModelType).
    modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
    # FAISS index type (see VitIndexType).
    indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
    # Number of result images to return.
    searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
class VactorImageSearchVitReportReq(BaseModel):
    """
    ### [Request] vector image search (clip-vit) with report generation
    """
    # Text prompt used to generate the query image via imagen.
    prompt : str = Field(description='프롬프트', example='검은색 안경')
    # Pretrained CLIP-ViT variant: b32, b16, l14 or l14_336.
    modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
    # Faiss metric: l2 or cos.
    indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
#===============================================================================
#===============================================================================
#===============================================================================

View File

@@ -21,11 +21,11 @@ from custom_logger.main_log import main_logger as LOG
from custom_apps.bingimagecreator.utils import DallEArgument,dalle3_generate_image
from custom_apps.bingart.bingart import BingArtGenerator
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path, imagen_generate_temp_image_path
from main_rest.app.utils.parsing_utils import download_range
from custom_apps.utils import cookie_manager
from utils.custom_sftp import sftp_client
from const import REMOTE_FOLDER, TEMP_FOLDER
from config import rest_config
router = APIRouter(prefix="/services")
@@ -92,7 +92,6 @@ async def bing_img_generate(request: Request, request_body_info: M.ImageGenerate
LOG.error(traceback.format_exc())
return response.set_error(e)
@router.post("/imageGenerate/bingart", summary="이미지 생성(AI) - bing art (DALL-E 3)", response_model=M.ImageGenerateRes)
async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
"""
@@ -122,7 +121,6 @@ async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
LOG.error(traceback.format_exc())
return response.set_error(e)
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes)
async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
"""
@@ -150,9 +148,8 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
LOG.error(traceback.format_exc())
return response.set_error(error=e)
@router.post("/vactorImageSearch/imageGenerate/imagen", summary="벡터 이미지 검색 - imagen", response_model=M.ResponseBase)
async def vactor_image(request: Request, request_body_info: M.VactorImageSearchReq):
@router.post("/vactorImageSearch/imagenet/imageGenerate/imagen", summary="벡터 이미지 검색(imagenet) - imagen", response_model=M.ResponseBase)
async def vactor_imagenet(request: Request, request_body_info: M.VactorImageSearchReq):
"""
## 벡터 이미지 검색 - imagen
> imagen AI를 이용하여 이미지 생성 후 vactor 검색
@@ -164,13 +161,15 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
"""
response = M.ResponseBase()
try:
if request_body_info.index_type not in [M.IndexType.hnsw, M.IndexType.l2]:
raise Exception(f"index_type is hnsw or l2 (current value = {request_body_info.index_type})")
if request_body_info.indexType not in [M.IndexType.hnsw, M.IndexType.l2]:
raise Exception(f"indexType is hnsw or l2 (current value = {request_body_info.indexType})")
img_path = imagen_generate_image_path(image_prompt=request_body_info.prompt)
vactor_request_data = {'quary_image_path' : img_path,'index_type' : request_body_info.index_type, 'search_num' : request_body_info.search_num}
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search', data=json.dumps(vactor_request_data))
vactor_request_data = {'query_image_path' : img_path,
'index_type' : request_body_info.indexType,
'search_num' : request_body_info.searchNum}
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/imagenet', data=json.dumps(vactor_request_data))
if vactor_response.status_code != 200:
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
@@ -183,14 +182,157 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
_base_bame = os.path.basename(_directory_path)
# remote 폴더 생성
sftp_client.remote_mkdir(os.path.join(REMOTE_FOLDER, _base_bame))
sftp_client.remote_mkdir(os.path.join(rest_config.remote_folder, _base_bame))
# remote 폴더에 이미지 저장
for i in os.listdir(_directory_path):
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(REMOTE_FOLDER, _base_bame, i))
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(rest_config.remote_folder, _base_bame, i))
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
return response.set_error(error=e)
@router.post("/vactorImageSearch/vit/imageGenerate/imagen", summary="벡터 이미지 검색(clip-vit) - imagen", response_model=M.ResponseBase)
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReq):
"""
## 벡터 이미지 검색(clip-vit) - imagen
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 결과 이미지 생성
### Requriements
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
### options
> - modelType -> b32,b16,l14,l14_336
> - indexType -> l2,cos
"""
response = M.ResponseBase()
try:
if not download_range(request_body_info.searchNum, max=10):
raise Exception(f"downloadCound is invalid (current value = {request_body_info.searchNum})")
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
vector_request_data = {'query_image_path' : query_image_path,
'index_type' : request_body_info.indexType,
'model_type' : request_body_info.modelType,
'search_num' : request_body_info.searchNum}
vector_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit', data=json.dumps(vector_request_data))
vector_response_dict = json.loads(vector_response.text)
if vector_response.status_code != 200:
raise Exception(f"response error: {vector_response_dict['error']}")
if vector_response_dict["error"] != None:
raise Exception(f"vactor error: {vector_response_dict['error']}")
result_image_paths = vector_response_dict.get('img_list').get('result_image_paths')
result_percents = vector_response_dict.get('img_list').get('result_percents')
# 원격지 폴더 생성
remote_directory = os.path.join(rest_config.remote_folder, f"imagen_query_{request_body_info.modelType}_{request_body_info.indexType}_{request_body_info.prompt}_{D.date_file_name()}")
sftp_client.remote_mkdir(remote_directory)
# 원격지에 이미지 저장
sftp_client.remote_copy_data(local_path=query_image_path, remote_path=os.path.join(remote_directory,"query.png"))
for img_path, img_percent in zip(result_image_paths,result_percents):
sftp_client.remote_copy_data(local_path=img_path, remote_path=os.path.join(remote_directory,f"search_{img_percent}.png"))
# Clean up temporary files
if 'query_image_path' in locals():
if os.path.exists(query_image_path):
os.remove(query_image_path)
del query_image_path
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
# Clean up temporary files
if 'query_image_path' in locals():
if os.path.exists(query_image_path):
os.remove(query_image_path)
del query_image_path
return response.set_error(error=e)
@router.post("/vactorImageSearch/vit/imageGenerate/imagen/report", summary="벡터 이미지 검색(clip-vit) - imagen, report 생성", response_model=M.ResponseBase)
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReportReq):
"""
## 벡터 이미지 검색(clip-vit) - imagen, report 생성
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 종합결과 이미지 생성
### Requriements
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
### options
> - modelType -> b32,b16,l14,l14_336
> - indexType -> l2,cos
"""
response = M.ResponseBase()
try:
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
report_image_path = f"{os.path.splitext(query_image_path)[0]}_report.png"
vactor_request_data = {'query_image_path' : query_image_path,
'index_type' : request_body_info.indexType,
'model_type' : request_body_info.modelType,
'report_path' : report_image_path}
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit/report', data=json.dumps(vactor_request_data))
if vactor_response.status_code != 200:
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
if json.loads(vactor_response.text)["error"] != None:
raise Exception(f"vactor error: {json.loads(vactor_response.text)['error']}")
# remote 폴더에 이미지 저장
sftp_client.remote_copy_data(local_path=report_image_path,
remote_path=os.path.join(rest_config.remote_folder, f"imagen_report_vit_{request_body_info.prompt}_{D.date_file_name()}.png"))
# Clean up temporary files
if 'query_image_path' in locals():
if os.path.exists(query_image_path):
os.remove(query_image_path)
del query_image_path
if 'report_image_path' in locals():
if os.path.exists(report_image_path):
os.remove(report_image_path)
del report_image_path
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
# Clean up temporary files
if 'query_image_path' in locals():
if os.path.exists(query_image_path):
os.remove(query_image_path)
del query_image_path
if 'report_image_path' in locals():
if os.path.exists(report_image_path):
os.remove(report_image_path)
del report_image_path
return response.set_error(error=e)

View File

@@ -15,4 +15,10 @@ email-validator
#faiss
scikit-learn
pillow
pillow
flax
transformers
torch
tensorflow
matplotlib
keras==3.10.0

View File

@@ -2,10 +2,11 @@ import paramiko
class CustomSFTPClient():
def __init__(self):
host = "192.168.200.230"
port = 22
id = "fermat"
pw = "1234"
from config import rest_config
host = rest_config.sftp_host
port = rest_config.sftp_port
id = rest_config.sftp_id
pw = rest_config.sftp_pw
self.ssh_client = paramiko.SSHClient()
self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

View File

@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
@asynccontextmanager
async def lifespan(app: FastAPI):
# When service starts.
LOG.info("REST start")
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
yield

View File

@@ -570,12 +570,62 @@ class IndexType(str, Enum):
hnsw = "hnsw"
l2 = "l2"
class VitIndexType(str, Enum):
    """Faiss metric identifiers for clip-vit vector search."""
    cos = "cos"  # cosine similarity
    l2 = "l2"    # Euclidean (L2) distance
class VitModelType(str, Enum):
    """Pretrained CLIP-ViT model variant identifiers."""
    b32 = "b32"          # clip-vit-base-patch32
    b16 = "b16"          # clip-vit-base-patch16
    l14 = "l14"          # clip-vit-large-patch14
    l14_336 = "l14_336"  # clip-vit-large-patch14-336
class VactorSearchReq(BaseModel):
"""
### [Request] vactor 검색
"""
quary_image_path : str = Field(description='quary image', example='path')
index_type : IndexType = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
query_image_path : str = Field(description='quary image', example='path')
index_type : IndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
class VactorSearchVitReportReq(BaseModel):
    """
    ### [Request] vit vector search with report image generation
    """
    # Path of the query image on the vector service host.
    query_image_path : str = Field(description='quary image', example='path')
    # Fix: default/example previously used IndexType.l2 for a VitIndexType
    # field; it only worked because both str-enums share the value "l2".
    index_type : VitIndexType = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
    # Pretrained CLIP-ViT variant to load.
    model_type : VitModelType = Field(VitModelType.l14, description='pretrained 모델 정보', example=VitModelType.l14)
    # Where the combined report image will be written.
    report_path : str = Field(description='리포트 이미지 저장 경로', example='path')
class VactorSearchVitReq(BaseModel):
    """
    ### [Request] vit vector search
    """
    # Path of the query image on the vector service host.
    query_image_path : str = Field(description='quary image', example='path')
    # Fix: default/example previously used IndexType.l2 for a VitIndexType
    # field; it only worked because both str-enums share the value "l2".
    index_type : VitIndexType = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
    # Pretrained CLIP-ViT variant to load.
    model_type : VitModelType = Field(VitModelType.l14, description='pretrained 모델 정보', example=VitModelType.l14)
    # Number of nearest neighbours to return.
    search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
class VactorSearchVitRes(ResponseBase):
    """
    ### [Response] vit vector search result

    Carries the matched image paths/percentages in ``img_list`` on top of the
    base result/error fields.
    """
    # Search payload, e.g. {'result_image_paths': [...], 'result_percents': [...]}.
    img_list : dict = Field({}, description='이미지 결과 리스트', example={})

    # Fix: the original @staticmethods mutated CLASS attributes and returned
    # the class object, so state leaked across concurrent requests. Callers
    # already instantiate the model and call these on the instance, so plain
    # instance methods (returning self) are a drop-in replacement.
    def set_error(self, error):
        """Mark this response as failed and clear the result payload."""
        self.img_list = {}
        self.result = False
        self.error = str(error)
        return self

    def set_message(self, msg):
        """Store the search payload and mark this response successful."""
        self.img_list = msg
        self.result = True
        self.error = None
        return self

View File

@@ -19,12 +19,15 @@ from vactor_rest.app import models as M
from vactor_rest.app.utils.date_utils import D
from custom_logger.vactor_log import vactor_logger as LOG
from custom_apps.faiss.main import search_idxs
from custom_apps.faiss_imagenet.main import search_idxs
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_functions import get_clip_info, save_report_image, get_models
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import FEMUsageInfo
router = APIRouter(prefix="/services")
@router.post("/faiss/vactor/search", summary="vactor search", response_model=M.ResponseBase)
@router.post("/faiss/vactor/search/imagenet", summary="imagenet search", response_model=M.ResponseBase)
async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
"""
## 벡터검색
@@ -35,15 +38,55 @@ async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
"""
response = M.ResponseBase()
try:
if os.path.exists(request_body_info.quary_image_path):
search_idxs(image_path=request_body_info.quary_image_path,
if os.path.exists(request_body_info.query_image_path):
search_idxs(image_path=request_body_info.query_image_path,
index_type=request_body_info.index_type,
search_num=request_body_info.search_num)
else:
raise Exception(f"File {request_body_info.quary_image_path} does not exist.")
raise Exception(f"File {request_body_info.query_image_path} does not exist.")
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
return response.set_error(e)
@router.post("/faiss/vactor/search/vit/report", summary="vit search report", response_model=M.ResponseBase)
async def vactor_report_vit(request: Request, request_body_info: M.VactorSearchVitReportReq):
response = M.ResponseBase()
try:
if not os.path.exists(request_body_info.query_image_path):
raise FileNotFoundError(f"File {request_body_info.query_image_path} does not exist.")
model = get_models(index_type=request_body_info.index_type, model_type=request_body_info.model_type)
report_info = get_clip_info(model,request_body_info.query_image_path)
save_report_image(report_info, request_body_info.report_path)
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
return response.set_error(e)
@router.post("/faiss/vactor/search/vit", summary="vit search", response_model=M.VactorSearchVitRes)
async def vactor_report_vit(request: Request, request_body_info: M.VactorSearchVitReq):
response = M.VactorSearchVitRes()
try:
if not os.path.exists(request_body_info.query_image_path):
raise FileNotFoundError(f"File {request_body_info.query_image_path} does not exist.")
model = get_models(index_type=request_body_info.index_type, model_type=request_body_info.model_type)
report_info = get_clip_info(model,request_body_info.query_image_path,top_k=request_body_info.search_num)
return response.set_message({
'result_image_paths': report_info.result_image_paths,
'result_percents': report_info.result_percents
})
except Exception as e:
LOG.error(traceback.format_exc())
return response.set_error(e)