edit : clip-vit 모델 추가
This commit is contained in:
6
.gitignore
vendored
6
.gitignore
vendored
@@ -145,4 +145,8 @@ log
|
|||||||
|
|
||||||
*output
|
*output
|
||||||
*temp
|
*temp
|
||||||
datas/eyewear_all/*
|
datas/eyewear_all/*
|
||||||
|
datas/clip*
|
||||||
|
datas/openai*
|
||||||
|
|
||||||
|
*result
|
||||||
27
config.py
Normal file
27
config.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
class Config:
|
||||||
|
def __init__(self):
|
||||||
|
self.set_rel()
|
||||||
|
|
||||||
|
def set_rel(self):
|
||||||
|
self.config = 'rel'
|
||||||
|
self.remote_folder = "/home/fermat/STORAGE/01.Projects/A2TEC/K_EYEWEAR/02.ML_DATA/Image_generator_result"
|
||||||
|
self.sftp_host = "192.168.200.230"
|
||||||
|
self.sftp_port = 22
|
||||||
|
self.sftp_id = "fermat"
|
||||||
|
self.sftp_pw = "1234"
|
||||||
|
|
||||||
|
def set_dev(self):
|
||||||
|
self.config = 'dev'
|
||||||
|
self.remote_folder = "/home/fermat/project/FM_TEST_REST_SERVER/result"
|
||||||
|
self.sftp_host = "192.168.200.231"
|
||||||
|
self.sftp_port = 22
|
||||||
|
self.sftp_id = "fermat"
|
||||||
|
self.sftp_pw = "fermat3514"
|
||||||
|
|
||||||
|
import os
|
||||||
|
if not os.path.exists(self.remote_folder):
|
||||||
|
os.makedirs(self.remote_folder)
|
||||||
|
|
||||||
|
|
||||||
|
rest_config = Config()
|
||||||
|
# rest_config.set_dev()
|
||||||
1
const.py
1
const.py
@@ -1,4 +1,3 @@
|
|||||||
REMOTE_FOLDER = "/home/fermat/STORAGE/01.Projects/A2TEC/K_EYEWEAR/02.ML_DATA/Image_generator_result"
|
|
||||||
TEMP_FOLDER = "./temp"
|
TEMP_FOLDER = "./temp"
|
||||||
|
|
||||||
ILLEGAL_FILE_NAME = ['<', '>', ':', '"', '/', '\ ', '|', '?', '*']
|
ILLEGAL_FILE_NAME = ['<', '>', ':', '"', '/', '\ ', '|', '?', '*']
|
||||||
11
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/__init__.py
Executable file
11
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/__init__.py
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
@file : __init__.py
|
||||||
|
@author: hsj100
|
||||||
|
@license: A2D2
|
||||||
|
@brief:
|
||||||
|
|
||||||
|
@section Modify History
|
||||||
|
- 2025-06-27 오후 5:35 hsj100 base
|
||||||
|
|
||||||
|
"""
|
||||||
40
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/const.py
Executable file
40
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/const.py
Executable file
@@ -0,0 +1,40 @@
|
|||||||
|
# Access token은 huggingface.co > Settings > Access Tokens 에서 생성
|
||||||
|
HUGGINGFACE_TOKEN = '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'
|
||||||
|
|
||||||
|
MODEL_DEVICE_CPU = 'cpu'
|
||||||
|
MODEL_DEVICE_GPU = 'cuda'
|
||||||
|
|
||||||
|
DEFAULT_INDEX_NAME_SUFFIX = 'index.faiss'
|
||||||
|
DEFAULT_SAVE_INDEX_SUFFIX = DEFAULT_INDEX_NAME_SUFFIX
|
||||||
|
|
||||||
|
INDEX_TYPE_L2 = 'l2'
|
||||||
|
INDEX_TYPE_COSINE = 'cosine'
|
||||||
|
|
||||||
|
#VIT MODELS
|
||||||
|
ViTB32 = "clip-vit-base-patch32"
|
||||||
|
ViTB16 = "clip-vit-base-patch16"
|
||||||
|
ViTL14 = "clip-vit-large-patch14"
|
||||||
|
ViTL14_336 = "clip-vit-large-patch14-336"
|
||||||
|
|
||||||
|
#download/save path
|
||||||
|
PRETRAINED_MODEL_PATH = "./datas"
|
||||||
|
FAISS_VECTOR_PATH = "./datas"
|
||||||
|
VECTOR_IMG_LIST_PATH = "./datas"
|
||||||
|
|
||||||
|
INDEX_IMAGES_PATH = "./datas/eyewear_all"
|
||||||
|
|
||||||
|
|
||||||
|
# report image consts
|
||||||
|
class ReportInfoConst:
|
||||||
|
feature_extraction_model = "Feature Extraction Model"
|
||||||
|
model_architecture = "Model Architecture"
|
||||||
|
train_model = "Train Model"
|
||||||
|
feature_extraction_time = "Feature Extraction Time"
|
||||||
|
|
||||||
|
similarity_type = "Similarity Type"
|
||||||
|
query = "Query"
|
||||||
|
result = "Result"
|
||||||
|
|
||||||
|
class similarity:
|
||||||
|
l2 = "L2(Euclidean Distance)"
|
||||||
|
cosine = "Cosine(Inner Product, L2)"
|
||||||
135
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/faiss_functions.py
Executable file
135
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/faiss_functions.py
Executable file
@@ -0,0 +1,135 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
@file : faiss_functions.py
|
||||||
|
@author: jwkim
|
||||||
|
@license: A2D2
|
||||||
|
|
||||||
|
@section Modify History
|
||||||
|
- 2025-06-27 오후 2:05 hsj100 base
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Basic logger: custom color logger
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from custom_logger.custom_log import custom_logger as log
|
||||||
|
except ImportError:
|
||||||
|
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: 3rd party
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: standard
|
||||||
|
"""
|
||||||
|
import random
|
||||||
|
import pathlib
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: custom
|
||||||
|
"""
|
||||||
|
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
|
||||||
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
if parent_dir not in sys.path:
|
||||||
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.report_utils import *
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import VectorSimilarity
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.report_utils import *
|
||||||
|
"""
|
||||||
|
Definition
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Implementation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_models(index_type, model_type):
|
||||||
|
from vactor_rest.app import models as M
|
||||||
|
|
||||||
|
model = None
|
||||||
|
|
||||||
|
# model 선택
|
||||||
|
if index_type == M.VitIndexType.l2:
|
||||||
|
# l2
|
||||||
|
if model_type == M.VitModelType.l14:
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_l14_l2
|
||||||
|
elif model_type == M.VitModelType.b32:
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_b32_l2
|
||||||
|
elif model_type == M.VitModelType.b16:
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_b16_l2
|
||||||
|
elif model_type == M.VitModelType.l14_336:
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_l14_336_l2
|
||||||
|
else:
|
||||||
|
raise Exception(f"model_type {model_type} is invalid")
|
||||||
|
else:
|
||||||
|
#cosine
|
||||||
|
if model_type == M.VitModelType.l14:
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_l14_cos
|
||||||
|
elif model_type == M.VitModelType.b32:
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_b32_cos
|
||||||
|
elif model_type == M.VitModelType.b16:
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_b16_cos
|
||||||
|
elif model_type == M.VitModelType.l14_336:
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_l14_336_cos
|
||||||
|
else:
|
||||||
|
# NOTE(JWKIM) : 현재 l14만 지원 FEMUsageInfo 에 추가해야함
|
||||||
|
raise Exception(f"model_type {model_type} is invalid")
|
||||||
|
|
||||||
|
if model:
|
||||||
|
return model
|
||||||
|
else:
|
||||||
|
raise Exception(f"model is not selected")
|
||||||
|
|
||||||
|
|
||||||
|
def save_report_image(report_info, report_image_path):
|
||||||
|
"""
|
||||||
|
이미지 유사도 검색 결과를 리포트로 생성
|
||||||
|
"""
|
||||||
|
table_setting(report_info, report_image_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_clip_info(model, query_image_path, top_k=4):
|
||||||
|
"""
|
||||||
|
이미지 유사도 검색
|
||||||
|
"""
|
||||||
|
index_file_path = os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}_{DEFAULT_INDEX_NAME_SUFFIX}')
|
||||||
|
txt_file_path = os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}.txt')
|
||||||
|
vector_model = VectorSimilarity(model_info=model,
|
||||||
|
index_file_path=index_file_path,
|
||||||
|
txt_file_path=txt_file_path)
|
||||||
|
|
||||||
|
# vector_model.create_feature_extraction_model(FEMUsageInfo.openaiclipvit)
|
||||||
|
|
||||||
|
vector_model.save_index_from_files(images_path=INDEX_IMAGES_PATH,
|
||||||
|
save_index_dir=FAISS_VECTOR_PATH,
|
||||||
|
save_txt_dir=VECTOR_IMG_LIST_PATH,
|
||||||
|
index_type=model.value[1].index_type)
|
||||||
|
|
||||||
|
inference_times, result_img_paths, result_percents = vector_model.query_faiss(query_image_path, top_k=top_k)
|
||||||
|
|
||||||
|
report_info = ReportInfo(
|
||||||
|
feature_extraction_model=ReportInfoConst.feature_extraction_model,
|
||||||
|
model_architecture=ReportInfoConst.model_architecture,
|
||||||
|
train_model=os.path.basename(model.value[1].trained_model),
|
||||||
|
feature_extraction_time=inference_times,
|
||||||
|
query_image_path=query_image_path,
|
||||||
|
result_image_paths=result_img_paths,
|
||||||
|
result_percents=result_percents,
|
||||||
|
similarity_type=model.value[1].index_type
|
||||||
|
)
|
||||||
|
|
||||||
|
return report_info
|
||||||
|
|
||||||
|
|
||||||
410
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/faiss_similarity_search.py
Executable file
410
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/faiss_similarity_search.py
Executable file
@@ -0,0 +1,410 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
@file : faiss_similarity_search.py
|
||||||
|
@author: hsj100
|
||||||
|
@license: A2D2
|
||||||
|
@brief: 특징 벡터를 이용한 이미지 유사도 검색
|
||||||
|
|
||||||
|
@section Modify History
|
||||||
|
- 2025-06-19 오후 2:45 hsj100 base
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Basic logger: custom color logger
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from custom_logger.main_log import vactor_logger as log
|
||||||
|
except ImportError:
|
||||||
|
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: standard
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import CLIPProcessor, CLIPModel
|
||||||
|
from PIL import Image
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: 3rd party
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: custom
|
||||||
|
"""
|
||||||
|
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
|
||||||
|
import os, sys
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
|
||||||
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
if parent_dir not in sys.path:
|
||||||
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS import feature_extraction_model as FEM
|
||||||
|
|
||||||
|
# 사용할 이미지 임베딩 모델 클래스 추가
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.fem_openaiclipvit import FEOpenAIClipViT
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Definition
|
||||||
|
"""
|
||||||
|
FILE_BASE_PATH = ''
|
||||||
|
if __name__ == '__main__':
|
||||||
|
FILE_BASE_PATH = os.path.dirname(os.path.realpath(__file__)) + '/'
|
||||||
|
|
||||||
|
# TRAINED_MODEL_PATH = FILE_BASE_PATH + 'openai/clip-resnet50' # OLD VERSION
|
||||||
|
# TRAINED_MODEL_PATH = FILE_BASE_PATH + 'openai/clip-vit-base-patch32' # huggingface 저장소 위치
|
||||||
|
TRAINED_MODEL_PATH = "./data" # 로컬 다운로드 경로
|
||||||
|
|
||||||
|
DEFAULT_INDEX_IMAGE_PATH = FILE_BASE_PATH + 'index_images'
|
||||||
|
|
||||||
|
class FEMUsageInfo(Enum):
|
||||||
|
"""
|
||||||
|
특징 추출(이미지 임베딩) 모델 추가시, 생성 정보를 추가해서 사용
|
||||||
|
* 변수명 변경되면 인덱스파일 이름 새로 만듬
|
||||||
|
"""
|
||||||
|
openaiclip_vit_b32_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||||
|
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTB32),
|
||||||
|
token=HUGGINGFACE_TOKEN,
|
||||||
|
index_type=INDEX_TYPE_L2)]
|
||||||
|
|
||||||
|
openaiclip_vit_b32_cos = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||||
|
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTB32),
|
||||||
|
token=HUGGINGFACE_TOKEN,
|
||||||
|
index_type=INDEX_TYPE_COSINE)]
|
||||||
|
|
||||||
|
openaiclip_vit_b16_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||||
|
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTB16),
|
||||||
|
token=HUGGINGFACE_TOKEN,
|
||||||
|
index_type=INDEX_TYPE_L2)]
|
||||||
|
|
||||||
|
openaiclip_vit_b16_cos = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||||
|
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTB16),
|
||||||
|
token=HUGGINGFACE_TOKEN,
|
||||||
|
index_type=INDEX_TYPE_COSINE)]
|
||||||
|
|
||||||
|
openaiclip_vit_l14_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||||
|
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTL14),
|
||||||
|
token=HUGGINGFACE_TOKEN,
|
||||||
|
index_type=INDEX_TYPE_L2)]
|
||||||
|
|
||||||
|
openaiclip_vit_l14_cos = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||||
|
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTL14),
|
||||||
|
token=HUGGINGFACE_TOKEN,
|
||||||
|
index_type=INDEX_TYPE_COSINE)]
|
||||||
|
|
||||||
|
openaiclip_vit_b14_336_l2 = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||||
|
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTL14_336),
|
||||||
|
token=HUGGINGFACE_TOKEN,
|
||||||
|
index_type=INDEX_TYPE_L2)]
|
||||||
|
|
||||||
|
openaiclip_vit_b14_336_cos = [FEOpenAIClipViT, FEM.FEMArguments(
|
||||||
|
trained_model=os.path.join(PRETRAINED_MODEL_PATH,ViTL14_336),
|
||||||
|
token=HUGGINGFACE_TOKEN,
|
||||||
|
index_type=INDEX_TYPE_COSINE)]
|
||||||
|
|
||||||
|
openaicliprn = None
|
||||||
|
mobilenetv2 = None
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_WEIGHT_PATH = FILE_BASE_PATH + TRAINED_MODEL_PATH
|
||||||
|
|
||||||
|
TEST_IMAGES_PATH = FILE_BASE_PATH + 'testimages'
|
||||||
|
|
||||||
|
"""
|
||||||
|
Implementation
|
||||||
|
"""
|
||||||
|
# log off
|
||||||
|
import logging
|
||||||
|
logging.getLogger('PIL.PngImagePlugin').disabled = True
|
||||||
|
|
||||||
|
# Utility Method
|
||||||
|
def function_name():
|
||||||
|
return sys._getframe(1).f_code.co_name
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractionModel(ABC):
|
||||||
|
"""
|
||||||
|
특징 벡터 추출 모델의 추상 인터페이스
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init_model(self, pretrained_model):
|
||||||
|
"""
|
||||||
|
로드 모델 및 초기화
|
||||||
|
:param pretrained_model:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f'Subclasses must implement this method{function_name()}')
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def image_embedding(self, image_np):
|
||||||
|
"""
|
||||||
|
이미지 임베딩 - 이미지 데이터(numpy)
|
||||||
|
:param image_np: 이미지 데이터(numpy array)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f'Subclasses must implement this method{function_name()}')
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSimilarity:
|
||||||
|
"""
|
||||||
|
벡터 기반 유사성 검색 모델 (FACEBOOK - FAISS)
|
||||||
|
"""
|
||||||
|
def __init__(self, model_info:FEMUsageInfo=None, index_file_path=None, txt_file_path=None):
|
||||||
|
self.fem_model:FEM.FeatureExtractionModel = None
|
||||||
|
|
||||||
|
self.model_info = model_info
|
||||||
|
self.fem_model_name = ''
|
||||||
|
self.fem_model_args = None
|
||||||
|
self.index_file_path = index_file_path
|
||||||
|
self.txt_file_path = txt_file_path
|
||||||
|
self.index_type = model_info.value[1].index_type
|
||||||
|
|
||||||
|
if model_info is not None:
|
||||||
|
self.create_feature_extraction_model(model_info)
|
||||||
|
|
||||||
|
def create_feature_extraction_model(self, model_info):
|
||||||
|
"""
|
||||||
|
특징 벡터 추출 모델 초기화
|
||||||
|
:return: True/False
|
||||||
|
"""
|
||||||
|
# check FEMUsageInfo
|
||||||
|
if isinstance(model_info, FEMUsageInfo):
|
||||||
|
if model_info not in FEMUsageInfo or model_info.value is None:
|
||||||
|
log.error('invalid FEMArguments')
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
log.error('invalid FEMArguments')
|
||||||
|
return False
|
||||||
|
|
||||||
|
if model_info.value[0] is None or model_info.value[1] is None:
|
||||||
|
log.error('invalid FEMArguments')
|
||||||
|
return False
|
||||||
|
|
||||||
|
# release model
|
||||||
|
if self.fem_model is not None:
|
||||||
|
del self.fem_model
|
||||||
|
|
||||||
|
# initialize ai-model
|
||||||
|
self.fem_model_name = model_info.name
|
||||||
|
self.fem_model_args = model_info.value[1]
|
||||||
|
self.fem_model = model_info.value[0](self.fem_model_args)
|
||||||
|
|
||||||
|
log.info(f'created fem-model: {model_info.value[0]}')
|
||||||
|
|
||||||
|
def image_embedding_from_file(self, image_file):
|
||||||
|
"""
|
||||||
|
이미지 파일에서 특징 벡터 추출
|
||||||
|
:param image_file: 이미지 파일
|
||||||
|
:return: 특징 벡터 or None
|
||||||
|
"""
|
||||||
|
feature_vectors = None
|
||||||
|
|
||||||
|
if not os.path.isfile(image_file):
|
||||||
|
log.error(f'not found file[{image_file}]')
|
||||||
|
return feature_vectors
|
||||||
|
|
||||||
|
image_data_np = Image.open(image_file).convert('RGB')
|
||||||
|
feature_vectors = self.fem_model.image_embedding(image_data_np)
|
||||||
|
return feature_vectors
|
||||||
|
|
||||||
|
def image_embedding_from_data(self, image_data_np=None):
|
||||||
|
"""
|
||||||
|
이미지 데이터(numpy)에서 특징 벡터 추출
|
||||||
|
:param image_data_np: 이미지 데이터(numpy)
|
||||||
|
:return: 특징 벡터 or None
|
||||||
|
"""
|
||||||
|
feature_vectors = None
|
||||||
|
|
||||||
|
if image_data_np is None:
|
||||||
|
log.error(f'invalid data[{image_data_np}]')
|
||||||
|
return feature_vectors
|
||||||
|
|
||||||
|
feature_vectors = self.fem_model.image_embedding(image_data_np)
|
||||||
|
return feature_vectors
|
||||||
|
|
||||||
|
def save_index_from_files(self, images_path=None, save_index_dir=None, save_txt_dir=None, index_type=INDEX_TYPE_L2):
|
||||||
|
"""
|
||||||
|
인덱스 정보 저장
|
||||||
|
:param images_path: 이미지 파일 경로
|
||||||
|
:param save_index_dir: 인덱스 저장 디렉토리
|
||||||
|
:param save_txt_dir: 이미지 경로 저장 디렉토리
|
||||||
|
:param index_type: 인덱스 타입 (L2, Cosine)
|
||||||
|
:return: 이미지 임베딩 차원 정보
|
||||||
|
"""
|
||||||
|
assert self.fem_model is not None, 'invalid fem_model'
|
||||||
|
assert images_path is not None, 'images_path is required'
|
||||||
|
|
||||||
|
if not os.path.exists(images_path):
|
||||||
|
log.error(f'image_path={images_path} does not exist')
|
||||||
|
|
||||||
|
image_folder = images_path
|
||||||
|
result = None
|
||||||
|
|
||||||
|
# 모든 이미지 임베딩 모으기
|
||||||
|
if os.path.exists(image_folder):
|
||||||
|
image_paths = [os.path.join(image_folder, image_file) for image_file in os.listdir(image_folder)]
|
||||||
|
image_vectors = np.vstack([self.image_embedding_from_file(image_file) for image_file in image_paths])
|
||||||
|
log.debug(f'image_vectors.shape={image_vectors.shape}')
|
||||||
|
result = image_vectors.shape
|
||||||
|
|
||||||
|
# 인덱스 생성 (L2)
|
||||||
|
dimension = image_vectors.shape[1]
|
||||||
|
|
||||||
|
if index_type == INDEX_TYPE_L2:
|
||||||
|
index = faiss.IndexFlatL2(dimension) # L2
|
||||||
|
elif index_type == INDEX_TYPE_COSINE:
|
||||||
|
index = faiss.IndexFlatIP(dimension) # cosine
|
||||||
|
faiss.normalize_L2(image_vectors)
|
||||||
|
else:
|
||||||
|
raise Exception(f'Invalid index_type={index_type}')
|
||||||
|
|
||||||
|
index.add(image_vectors)
|
||||||
|
|
||||||
|
# 인덱스 저장
|
||||||
|
if save_index_dir is None:
|
||||||
|
save_index_path = os.path.join(FILE_BASE_PATH,f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}_{DEFAULT_SAVE_INDEX_SUFFIX}')
|
||||||
|
self.index_file_path = save_index_path
|
||||||
|
else:
|
||||||
|
save_index_path = os.path.join(save_index_dir,f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}_{DEFAULT_SAVE_INDEX_SUFFIX}')
|
||||||
|
os.makedirs(save_index_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 이미지 경로 파일 저장
|
||||||
|
if save_txt_dir is None:
|
||||||
|
save_txt_path = os.path.join(FILE_BASE_PATH, f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}.txt')
|
||||||
|
self.txt_file_path = save_txt_path
|
||||||
|
else:
|
||||||
|
save_txt_path = os.path.join(save_txt_dir, f'{self.fem_model_name}_{os.path.basename(self.fem_model_args.trained_model)}_{index_type}.txt')
|
||||||
|
os.makedirs(save_txt_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if not os.path.exists(save_txt_path):
|
||||||
|
faiss.write_index(index, save_index_path)
|
||||||
|
log.debug(f'save_index_path={save_index_path}')
|
||||||
|
|
||||||
|
if not os.path.exists(save_txt_path):
|
||||||
|
# 이미지 경로 저장
|
||||||
|
with open(save_txt_path, 'w') as f:
|
||||||
|
for path in image_paths:
|
||||||
|
f.write(f"{path}\n")
|
||||||
|
log.debug(f'save_txt_path={save_txt_path}')
|
||||||
|
|
||||||
|
else:
|
||||||
|
log.error(f'Image folder {image_folder} does not exist')
|
||||||
|
return result
|
||||||
|
|
||||||
|
def set_index_path(self, index_path):
|
||||||
|
assert index_path is not None, 'index_path is required'
|
||||||
|
self.index_file_path = index_path
|
||||||
|
|
||||||
|
def load_index(self, index_file_name=None):
|
||||||
|
"""
|
||||||
|
인덱스 정보 불러오기
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if index_file_name is not None:
|
||||||
|
self.index_file_path = index_file_name
|
||||||
|
|
||||||
|
return faiss.read_index(self.index_file_path)
|
||||||
|
|
||||||
|
def time_diff(self, start_time, end_time):
|
||||||
|
# 시간 차이 계산
|
||||||
|
time_difference = end_time - start_time
|
||||||
|
|
||||||
|
# timedelta 객체에서 분:초:마이크로초 추출
|
||||||
|
total_seconds_float = time_difference.total_seconds()
|
||||||
|
total_seconds_int = int(total_seconds_float) # 정수 부분의 초
|
||||||
|
|
||||||
|
minutes = total_seconds_int // 60
|
||||||
|
seconds = total_seconds_int % 60
|
||||||
|
microseconds = time_difference.microseconds # 남은 마이크로초
|
||||||
|
|
||||||
|
return f"{minutes:02d}:{seconds:02d}.{microseconds:06d}"
|
||||||
|
|
||||||
|
def query_faiss(self, query_image=None, top_k=4):
|
||||||
|
|
||||||
|
assert query_image is not None, 'query_image is required'
|
||||||
|
|
||||||
|
if os.path.exists(self.txt_file_path):
|
||||||
|
with open(self.txt_file_path, 'r') as f:
|
||||||
|
image_paths = [line.strip() for line in f.readlines()]
|
||||||
|
else:
|
||||||
|
logging.error("Image path list TXT file not found.")
|
||||||
|
image_paths = []
|
||||||
|
|
||||||
|
start_vector_time = datetime.now()
|
||||||
|
index = self.load_index(self.index_file_path)
|
||||||
|
query_vector = self.image_embedding_from_file(query_image)
|
||||||
|
end_vector_time = datetime.now()
|
||||||
|
|
||||||
|
diff_vector_time = self.time_diff(start_vector_time,end_vector_time)
|
||||||
|
|
||||||
|
if self.index_type == INDEX_TYPE_COSINE:
|
||||||
|
faiss.normalize_L2(query_vector)
|
||||||
|
|
||||||
|
start_search_time = datetime.now()
|
||||||
|
distances, indices = index.search(query_vector, top_k)
|
||||||
|
end_search_time = datetime.now()
|
||||||
|
|
||||||
|
diff_search_time = self.time_diff(start_search_time,end_search_time)
|
||||||
|
diff_total_time = self.time_diff(start_vector_time,end_search_time)
|
||||||
|
inference_times = f"Total time - {diff_total_time}, vector_time - {diff_vector_time}, search_time - {diff_search_time}"
|
||||||
|
|
||||||
|
result_img_paths = []
|
||||||
|
result_percents = []
|
||||||
|
|
||||||
|
# 결과
|
||||||
|
# for i in range(top_k):
|
||||||
|
# print(f"{i + 1}: {image_paths[indices[0][i]]}, Distance: {distances[0][i]}")
|
||||||
|
for idx, dist in zip(indices[0], distances[0]):
|
||||||
|
log.debug(f"{idx} (거리: {dist:.4f})")
|
||||||
|
|
||||||
|
result_img_paths.append(image_paths[idx])
|
||||||
|
|
||||||
|
if self.index_type == INDEX_TYPE_COSINE:
|
||||||
|
result_percents.append(f"{dist*100:.2f}")
|
||||||
|
else:
|
||||||
|
result_percents.append(f"{((1 - dist)*100):.2f}")
|
||||||
|
|
||||||
|
return inference_times, result_img_paths, result_percents
|
||||||
|
|
||||||
|
def test():
|
||||||
|
"""
|
||||||
|
module test function
|
||||||
|
"""
|
||||||
|
log.info('\nModule: faiss_similarity_search.py')
|
||||||
|
|
||||||
|
# index_file_path = f'{FILE_BASE_PATH}openaiclipvit_clip-vit-base-patch32_index.faiss'
|
||||||
|
model = FEMUsageInfo.openaiclip_vit_l14_cos
|
||||||
|
index_file_path = os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}_{DEFAULT_INDEX_NAME_SUFFIX}')
|
||||||
|
|
||||||
|
cm = VectorSimilarity(model_info=model,
|
||||||
|
index_file_path=index_file_path,
|
||||||
|
txt_file_path=os.path.join(VECTOR_IMG_LIST_PATH,f'{model.name}_{os.path.basename(model.value[1].trained_model)}_{model.value[1].index_type}.txt'))
|
||||||
|
|
||||||
|
# cm.create_feature_extraction_model(FEMUsageInfo.openaiclipvit)
|
||||||
|
|
||||||
|
cm.save_index_from_files(images_path=INDEX_IMAGES_PATH,
|
||||||
|
save_index_dir=FAISS_VECTOR_PATH,
|
||||||
|
save_txt_dir=VECTOR_IMG_LIST_PATH,
|
||||||
|
index_type=model.value[1].index_type)
|
||||||
|
|
||||||
|
|
||||||
|
cm.query_faiss(os.path.join(INDEX_IMAGES_PATH,"img1.png"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
"""
|
||||||
|
test module
|
||||||
|
"""
|
||||||
|
test()
|
||||||
109
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/feature_extraction_model.py
Executable file
109
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/feature_extraction_model.py
Executable file
@@ -0,0 +1,109 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
@file : feature_extraction_model.py
|
||||||
|
@author: hsj100
|
||||||
|
@license: A2D2
|
||||||
|
@brief: 이미지 임베딩 모델 관리 정보
|
||||||
|
|
||||||
|
@section Modify History
|
||||||
|
- 2025-06-27 오후 2:07 hsj100 base
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Basic logger: custom color logger
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from custom_logger.main_log import vactor_logger as log
|
||||||
|
except ImportError:
|
||||||
|
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: standard
|
||||||
|
"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
import sys
|
||||||
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: 3rd party
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: custom
|
||||||
|
"""
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Definition
|
||||||
|
"""
|
||||||
|
|
||||||
|
class FEMArguments(BaseModel):
|
||||||
|
"""
|
||||||
|
특징 추출(이미지 임베딩) 모델 인수 정보
|
||||||
|
NOTE: 모델별 추가내용 필요시, 항목 추가해서 사용
|
||||||
|
"""
|
||||||
|
device:str = Field(default=MODEL_DEVICE_GPU, description=f'FEM 동작 장치[{MODEL_DEVICE_CPU},{MODEL_DEVICE_GPU}]')
|
||||||
|
trained_model:str = Field(default=None, description='학습된 모델(경로 또는 파일명) FEM 모델별로 상이함')
|
||||||
|
token:str = Field(default=None, description='FEM 사용시 필요한 token 정보')
|
||||||
|
index_type:str = Field(default=INDEX_TYPE_L2, description=f'인덱스 타입[{INDEX_TYPE_L2}, {INDEX_TYPE_COSINE}]')
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
def function_name():
|
||||||
|
return sys._getframe(1).f_code.co_name
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Implementation
|
||||||
|
"""
|
||||||
|
class FeatureExtractionModel(ABC):
|
||||||
|
"""
|
||||||
|
특징 벡터 추출 모델의 추상 인터페이스
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model(self, fem_arguments:FEMArguments):
|
||||||
|
"""
|
||||||
|
모델 초기화 및 메모리 적재
|
||||||
|
:param fem_arguments: 모델 파라미터
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def release_model(self):
|
||||||
|
"""
|
||||||
|
모델 메모리 해제
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def image_embedding(self, image_np):
|
||||||
|
"""
|
||||||
|
이미지 임베딩 - 이미지 데이터(numpy)
|
||||||
|
:param image_np: 이미지 데이터(numpy array)
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f'Subclasses must implement this method[{function_name()}]')
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
"""
|
||||||
|
module test function
|
||||||
|
"""
|
||||||
|
log.info('\nModule: feature_extraction_model.py')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
"""
|
||||||
|
test module
|
||||||
|
"""
|
||||||
|
test()
|
||||||
200
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/fem_openaiclipvit.py
Executable file
200
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/fem_openaiclipvit.py
Executable file
@@ -0,0 +1,200 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
@file : fem_openaiclipvit.py
|
||||||
|
@author: hsj100
|
||||||
|
@license: A2D2
|
||||||
|
@brief: OpenAI Clip Model (Image & Text Embedding)
|
||||||
|
|
||||||
|
@section Modify History
|
||||||
|
- 2025-06-27 오후 2:05 hsj100 base
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Basic logger: custom color logger
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from custom_logger.main_log import vactor_logger as log
|
||||||
|
except ImportError:
|
||||||
|
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: 3rd party
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: standard
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import os, sys
|
||||||
|
|
||||||
|
from transformers import CLIPProcessor, CLIPModel
|
||||||
|
from huggingface_hub import login as huggingface_login
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: custom
|
||||||
|
"""
|
||||||
|
# NOTE(hsj100): 코드 단독 실행을 위한 기본 경로 지정 (일반적으로는 본 코드 보다는 모듈 실행 방법을 권장)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 현재 스크립트의 디렉토리를 기준으로 부모 디렉토리의 경로를 가져옴
|
||||||
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||||
|
if parent_dir not in sys.path:
|
||||||
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS import feature_extraction_model as FEM
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Definition
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Implementation
|
||||||
|
"""
|
||||||
|
# 특징 벡터 추출 모델 개별 정의
|
||||||
|
class FEOpenAIClipViT(FEM.FeatureExtractionModel):
|
||||||
|
def __init__(self, arguments=FEM.FEMArguments()):
|
||||||
|
"""
|
||||||
|
Constructor
|
||||||
|
:param arguments: parameter of this model
|
||||||
|
"""
|
||||||
|
# model interface
|
||||||
|
self.model = None
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
|
# model parameters
|
||||||
|
self.device = None
|
||||||
|
self.trained_model = None # 학습된 모델 결과물이 있는 경로
|
||||||
|
self.huggingface_token = None # huggingface token
|
||||||
|
|
||||||
|
|
||||||
|
self.fem_arguments = None
|
||||||
|
|
||||||
|
self.load_model(fem_arguments=arguments)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""
|
||||||
|
Destructor
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self.release_model()
|
||||||
|
|
||||||
|
def load_model(self, fem_arguments=None):
|
||||||
|
"""
|
||||||
|
추상 인터페이스 구현: 모델 초기화 및 메모리 적재
|
||||||
|
:param fem_arguments: 모델 인자 정보
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
assert fem_arguments is not None, 'failed init_model: fem_arguments is None'
|
||||||
|
|
||||||
|
# model parameters
|
||||||
|
self.device = fem_arguments.device
|
||||||
|
self.trained_model = fem_arguments.trained_model
|
||||||
|
self.huggingface_token = fem_arguments.token
|
||||||
|
|
||||||
|
# device (CPU/GPU)
|
||||||
|
if self.device == FEM.MODEL_DEVICE_GPU:
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
self.device = FEM.MODEL_DEVICE_CPU
|
||||||
|
log.warning('Not found GPU device, use CPU instead')
|
||||||
|
|
||||||
|
# huggingface token
|
||||||
|
if self.huggingface_token:
|
||||||
|
huggingface_login(fem_arguments.token)
|
||||||
|
|
||||||
|
# model path check
|
||||||
|
if not os.path.exists(fem_arguments.trained_model):
|
||||||
|
self._download_model(fem_arguments.trained_model)
|
||||||
|
|
||||||
|
# model interface
|
||||||
|
self.model = CLIPModel.from_pretrained(fem_arguments.trained_model)
|
||||||
|
self.processor = CLIPProcessor.from_pretrained(fem_arguments.trained_model)
|
||||||
|
|
||||||
|
if self.device == FEM.MODEL_DEVICE_GPU:
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
# inference mode
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
log.info(f'load complete:{self.__class__.__name__}')
|
||||||
|
|
||||||
|
def _download_model(self, pretrained_model_path):
|
||||||
|
"""
|
||||||
|
모델 다운로드 (로컬에 없는 경우)
|
||||||
|
:param pretrained_model: 모델 이름
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
MODEL_LIST = [ViTB32, ViTB16, ViTL14, ViTL14_336]
|
||||||
|
|
||||||
|
model_name = os.path.basename(pretrained_model_path)
|
||||||
|
|
||||||
|
if model_name not in MODEL_LIST:
|
||||||
|
raise ValueError(f"Unsupported model name: {model_name}. Supported models are: {MODEL_LIST}")
|
||||||
|
|
||||||
|
model = CLIPModel.from_pretrained(f"openai/{model_name}")
|
||||||
|
processor = CLIPProcessor.from_pretrained(f"openai/{model_name}")
|
||||||
|
|
||||||
|
os.makedirs(pretrained_model_path, exist_ok=True)
|
||||||
|
|
||||||
|
model.save_pretrained(pretrained_model_path)
|
||||||
|
processor.save_pretrained(pretrained_model_path)
|
||||||
|
|
||||||
|
log.info(f'Model downloaded and saved to {pretrained_model_path}')
|
||||||
|
|
||||||
|
def release_model(self):
|
||||||
|
"""
|
||||||
|
추상 인터페이스 구현: 모델 메모리 해제
|
||||||
|
모델 해제
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# GPU 메모리 캐시 비우기 (GPU 메모리 해제에 도움)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
log.info(f'release complete:{self.__class__.__name__}')
|
||||||
|
|
||||||
|
def image_embedding(self, image_np):
|
||||||
|
"""
|
||||||
|
추상 인터페이스 구현 : 이미지 임베딩 (특징 벡터 추출)
|
||||||
|
:param image_np: 이미지 데이터 (numpy)
|
||||||
|
:return: 특징 벡터 (numpy)
|
||||||
|
"""
|
||||||
|
result = None
|
||||||
|
# 이미지 로딩 및 전처리
|
||||||
|
# image_np = Image.open(CLIPModel.TEST_IMAGE_PATH if image_path is None else image_path).convert('RGB')
|
||||||
|
inputs_tensor = self.processor(images=image_np, return_tensors='pt')
|
||||||
|
|
||||||
|
if self.device == FEM.MODEL_DEVICE_GPU:
|
||||||
|
inputs_tensor = inputs_tensor.to(self.device)
|
||||||
|
|
||||||
|
# 특징 벡터 추출
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model.get_image_features(**inputs_tensor) # shape: (1, 512)
|
||||||
|
features = outputs / outputs.norm(p=2, dim=-1, keepdim=True) # 임베딩 정규화 L2
|
||||||
|
|
||||||
|
# log.debug(f'feature vector shape: {features.shape}')
|
||||||
|
|
||||||
|
if self.device == FEM.MODEL_DEVICE_GPU:
|
||||||
|
result = features.cpu().numpy()
|
||||||
|
else:
|
||||||
|
result = features.numpy()
|
||||||
|
|
||||||
|
# log.debug(f'embedding complete:{self.__class__.__name__}')
|
||||||
|
return result
|
||||||
|
|
||||||
|
def test():
|
||||||
|
"""
|
||||||
|
module test function
|
||||||
|
"""
|
||||||
|
log.info('\nModule: fem_openaiclipvit.py')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
"""
|
||||||
|
test module
|
||||||
|
"""
|
||||||
|
test()
|
||||||
130
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/report_utils.py
Executable file
130
custom_apps/FEATURE_VECTOR_SIMILARITY_FAISS/report_utils.py
Executable file
@@ -0,0 +1,130 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
Basic logger: custom color logger
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from custom_logger.main_log import vactor_logger as log
|
||||||
|
except ImportError:
|
||||||
|
import logging as log; log.basicConfig(level=log.DEBUG)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: standard
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
|
||||||
|
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
|
||||||
|
from PIL import Image
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: 3rd party
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Package: custom
|
||||||
|
"""
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.const import *
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import FEMUsageInfo
|
||||||
|
|
||||||
|
"""
|
||||||
|
Definition
|
||||||
|
"""
|
||||||
|
@dataclass
|
||||||
|
class ReportInfo:
|
||||||
|
feature_extraction_model : str
|
||||||
|
model_architecture : str
|
||||||
|
train_model : str
|
||||||
|
feature_extraction_time : str
|
||||||
|
|
||||||
|
query_image_path : str
|
||||||
|
result_image_paths: list
|
||||||
|
result_percents: list
|
||||||
|
similarity_type: str
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Implementation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def table_setting(report_info:ReportInfo, result_image_path):
|
||||||
|
"""
|
||||||
|
결과 정보를 테이블 형태로 시각화하여 이미지로 저장하는 함수
|
||||||
|
"""
|
||||||
|
|
||||||
|
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(1920/100, 1080/100), dpi=100) # 2행 1열 서브플롯
|
||||||
|
|
||||||
|
##====================info table (상단 테이블)
|
||||||
|
_info_table_row = [ReportInfoConst.feature_extraction_model,
|
||||||
|
ReportInfoConst.model_architecture,
|
||||||
|
ReportInfoConst.train_model,
|
||||||
|
ReportInfoConst.feature_extraction_time,
|
||||||
|
ReportInfoConst.similarity_type]
|
||||||
|
_info_table_data = [[report_info.feature_extraction_model],
|
||||||
|
[report_info.model_architecture],
|
||||||
|
[report_info.train_model],
|
||||||
|
[report_info.feature_extraction_time],
|
||||||
|
[report_info.similarity_type]]
|
||||||
|
ax1.set_axis_off() # 축 숨기기
|
||||||
|
info_table = ax1.table(cellText=_info_table_data,
|
||||||
|
rowLabels=_info_table_row,
|
||||||
|
loc='center',
|
||||||
|
cellLoc='left',
|
||||||
|
bbox=[0,0,0.8,1])
|
||||||
|
# info_table.auto_set_column_width(True) # 열 너비를 내용에 맞게 자동 조절
|
||||||
|
# info_table.scale(0.8, 1.0)
|
||||||
|
info_table.set_fontsize(20) # 폰트
|
||||||
|
|
||||||
|
for (row, col), cell in info_table.get_celld().items():
|
||||||
|
cell.set_height(0.15)
|
||||||
|
##================================
|
||||||
|
|
||||||
|
##====================result table (하단 테이블)
|
||||||
|
_result_table_col = [ReportInfoConst.query, ReportInfoConst.result, ReportInfoConst.result, ReportInfoConst.result, ReportInfoConst.result]
|
||||||
|
_result_table_data = [['','','','',''],
|
||||||
|
['Similarity Rate(%)',
|
||||||
|
report_info.result_percents[0],
|
||||||
|
report_info.result_percents[1],
|
||||||
|
report_info.result_percents[2],
|
||||||
|
report_info.result_percents[3]]]
|
||||||
|
|
||||||
|
ax2.set_axis_off() # 축 숨기기
|
||||||
|
result_table = ax2.table(cellText=_result_table_data,
|
||||||
|
colLabels=_result_table_col,
|
||||||
|
loc='center',
|
||||||
|
cellLoc='center',
|
||||||
|
bbox=[-0.23,0,1.2,0.8])
|
||||||
|
result_table.set_fontsize(20) # 폰트
|
||||||
|
|
||||||
|
for (row, col), cell in result_table.get_celld().items():
|
||||||
|
if row == 1:
|
||||||
|
cell.set_height(0.5)
|
||||||
|
else:
|
||||||
|
cell.set_height(0.15)
|
||||||
|
|
||||||
|
# 이미지 삽입 - 위치 보정하여 셀 가운데 정렬
|
||||||
|
image_paths = [report_info.query_image_path] + report_info.result_image_paths[:4]
|
||||||
|
target_size = (240, 180)
|
||||||
|
for col, img_path in enumerate(image_paths):
|
||||||
|
try:
|
||||||
|
img = Image.open(img_path).convert("RGB").resize(target_size)
|
||||||
|
img_array = np.asarray(img)
|
||||||
|
|
||||||
|
imagebox = OffsetImage(img_array, zoom=1.0)
|
||||||
|
position_x = -0.13 + 0.24 * col # 열 중심 (0.1, 0.3, 0.5, 0.7, 0.9)
|
||||||
|
position_y = 0.4 # 테이블 세로 가운데 살짝 위쪽
|
||||||
|
|
||||||
|
ab = AnnotationBbox(imagebox,
|
||||||
|
(position_x, position_y),
|
||||||
|
frameon=False,
|
||||||
|
box_alignment=(0.4, 0.5),
|
||||||
|
xycoords='axes fraction')
|
||||||
|
ax2.add_artist(ab)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[이미지 로딩 실패] {img_path}: {e}")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(result_image_path)
|
||||||
@@ -6,8 +6,9 @@ from bingart import BingArt
|
|||||||
from custom_apps.utils import cookie_manager
|
from custom_apps.utils import cookie_manager
|
||||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||||
from main_rest.app.utils.date_utils import D
|
from main_rest.app.utils.date_utils import D
|
||||||
from const import REMOTE_FOLDER, TEMP_FOLDER
|
from const import TEMP_FOLDER
|
||||||
from utils.custom_sftp import sftp_client
|
from utils.custom_sftp import sftp_client
|
||||||
|
from config import rest_config
|
||||||
|
|
||||||
|
|
||||||
class BingArtGenerator:
|
class BingArtGenerator:
|
||||||
@@ -15,7 +16,7 @@ class BingArtGenerator:
|
|||||||
model = 'dalle3'
|
model = 'dalle3'
|
||||||
detail = 'art'
|
detail = 'art'
|
||||||
output_folder = os.path.join(TEMP_FOLDER,"dalle","art")
|
output_folder = os.path.join(TEMP_FOLDER,"dalle","art")
|
||||||
remote_folder = os.path.join(REMOTE_FOLDER,"dalle","art")
|
remote_folder = os.path.join(rest_config.remote_folder,"dalle","art")
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.bing_art = BingArt(auth_cookie_U=cookie_manager.get_cookie())
|
self.bing_art = BingArt(auth_cookie_U=cookie_manager.get_cookie())
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ from dataclasses import dataclass
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from custom_apps.bingimagecreator.BingImageCreator import *
|
from custom_apps.bingimagecreator.BingImageCreator import *
|
||||||
from const import REMOTE_FOLDER
|
|
||||||
from custom_apps.utils import cookie_manager
|
from custom_apps.utils import cookie_manager
|
||||||
|
from config import rest_config
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -24,7 +24,7 @@ class DallEArgument:
|
|||||||
U = cookie_manager.get_cookie()
|
U = cookie_manager.get_cookie()
|
||||||
prompt: str
|
prompt: str
|
||||||
cookie_file: str|None = None
|
cookie_file: str|None = None
|
||||||
output_dir: str = os.path.join(REMOTE_FOLDER,"dalle","img")
|
output_dir: str = os.path.join(rest_config.remote_folder,"dalle","img")
|
||||||
download_count: int = 1
|
download_count: int = 1
|
||||||
debug_file: str|None = None
|
debug_file: str|None = None
|
||||||
quiet: bool = False
|
quiet: bool = False
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from custom_apps.faiss.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
|
from custom_apps.faiss_imagenet.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
|
||||||
from custom_apps.faiss.const import *
|
from custom_apps.faiss_imagenet.const import *
|
||||||
# from rest.app.utils.date_utils import D
|
# from rest.app.utils.date_utils import D
|
||||||
# from const import OUTPUT_FOLDER
|
# from const import OUTPUT_FOLDER
|
||||||
|
|
||||||
@@ -4,10 +4,12 @@ import vertexai
|
|||||||
import os
|
import os
|
||||||
from vertexai.preview.vision_models import ImageGenerationModel
|
from vertexai.preview.vision_models import ImageGenerationModel
|
||||||
|
|
||||||
from const import REMOTE_FOLDER, TEMP_FOLDER
|
from const import TEMP_FOLDER
|
||||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||||
from main_rest.app.utils.date_utils import D
|
from main_rest.app.utils.date_utils import D
|
||||||
from utils.custom_sftp import sftp_client
|
from utils.custom_sftp import sftp_client
|
||||||
|
from config import rest_config
|
||||||
|
|
||||||
|
|
||||||
class ImagenConst:
|
class ImagenConst:
|
||||||
project_id = "glasses-imagen"
|
project_id = "glasses-imagen"
|
||||||
@@ -22,7 +24,7 @@ def imagen_generate_image(prompt,download_count=1):
|
|||||||
|
|
||||||
_file_name = prompt_to_filenames(prompt)
|
_file_name = prompt_to_filenames(prompt)
|
||||||
_folder = os.path.join(TEMP_FOLDER)
|
_folder = os.path.join(TEMP_FOLDER)
|
||||||
_remote_folder = os.path.join(REMOTE_FOLDER,"imagen")
|
_remote_folder = os.path.join(rest_config.remote_folder,"imagen")
|
||||||
_datetime = D.date_file_name()
|
_datetime = D.date_file_name()
|
||||||
|
|
||||||
# if not os.path.isdir(_folder):
|
# if not os.path.isdir(_folder):
|
||||||
@@ -88,6 +90,19 @@ def imagen_generate_image_path(image_prompt):
|
|||||||
|
|
||||||
return os.path.join(folder_name,f"query.png")
|
return os.path.join(folder_name,f"query.png")
|
||||||
|
|
||||||
|
def imagen_generate_temp_image_path(image_prompt):
|
||||||
|
|
||||||
|
create_time = D.date_file_name()
|
||||||
|
|
||||||
|
if not os.path.exists(TEMP_FOLDER):
|
||||||
|
os.makedirs(TEMP_FOLDER)
|
||||||
|
|
||||||
|
generate_img = imagen_generate_image_data(image_prompt)
|
||||||
|
|
||||||
|
generate_img.save(os.path.join(TEMP_FOLDER,f"query_{create_time}.png"))
|
||||||
|
|
||||||
|
return os.path.join(TEMP_FOLDER,f"query_{create_time}.png")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
pass
|
||||||
# imagen_generate_image_data("cat")
|
# imagen_generate_image_data("cat")
|
||||||
@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# When service starts.
|
# When service starts.
|
||||||
LOG.info("REST start")
|
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|||||||
@@ -570,6 +570,19 @@ class IndexType:
|
|||||||
hnsw = "hnsw"
|
hnsw = "hnsw"
|
||||||
l2 = "l2"
|
l2 = "l2"
|
||||||
|
|
||||||
|
|
||||||
|
class VitIndexType:
|
||||||
|
cos = "cos"
|
||||||
|
l2 = "l2"
|
||||||
|
|
||||||
|
|
||||||
|
class VitModelType:
|
||||||
|
b32 = "b32"
|
||||||
|
b16 = "b16"
|
||||||
|
l14 = "l14"
|
||||||
|
l14_336 = "l14_336"
|
||||||
|
|
||||||
|
|
||||||
class ImageGenerateReq(BaseModel):
|
class ImageGenerateReq(BaseModel):
|
||||||
"""
|
"""
|
||||||
### [Request] image generate request
|
### [Request] image generate request
|
||||||
@@ -590,10 +603,29 @@ class VactorImageSearchReq(BaseModel):
|
|||||||
### [Request] vactor image search request
|
### [Request] vactor image search request
|
||||||
"""
|
"""
|
||||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||||
index_type : str = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
|
indexType : str = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||||
|
|
||||||
|
|
||||||
|
class VactorImageSearchVitReq(BaseModel):
|
||||||
|
"""
|
||||||
|
### [Request] vactor image search vit
|
||||||
|
"""
|
||||||
|
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||||
|
modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
|
||||||
|
indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
|
||||||
|
searchNum : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||||
|
|
||||||
|
|
||||||
|
class VactorImageSearchVitReportReq(BaseModel):
|
||||||
|
"""
|
||||||
|
### [Request] vactor image search vit request
|
||||||
|
"""
|
||||||
|
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||||
|
modelType : str = Field(VitModelType.l14, description='pretrained model 타입', example=VitModelType.l14)
|
||||||
|
indexType : str = Field(VitIndexType.l2, description='인덱스 타입', example=VitIndexType.l2)
|
||||||
|
|
||||||
|
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
#===============================================================================
|
#===============================================================================
|
||||||
|
|||||||
@@ -21,11 +21,11 @@ from custom_logger.main_log import main_logger as LOG
|
|||||||
|
|
||||||
from custom_apps.bingimagecreator.utils import DallEArgument,dalle3_generate_image
|
from custom_apps.bingimagecreator.utils import DallEArgument,dalle3_generate_image
|
||||||
from custom_apps.bingart.bingart import BingArtGenerator
|
from custom_apps.bingart.bingart import BingArtGenerator
|
||||||
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path
|
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path, imagen_generate_temp_image_path
|
||||||
from main_rest.app.utils.parsing_utils import download_range
|
from main_rest.app.utils.parsing_utils import download_range
|
||||||
from custom_apps.utils import cookie_manager
|
from custom_apps.utils import cookie_manager
|
||||||
from utils.custom_sftp import sftp_client
|
from utils.custom_sftp import sftp_client
|
||||||
from const import REMOTE_FOLDER, TEMP_FOLDER
|
from config import rest_config
|
||||||
|
|
||||||
router = APIRouter(prefix="/services")
|
router = APIRouter(prefix="/services")
|
||||||
|
|
||||||
@@ -92,7 +92,6 @@ async def bing_img_generate(request: Request, request_body_info: M.ImageGenerate
|
|||||||
LOG.error(traceback.format_exc())
|
LOG.error(traceback.format_exc())
|
||||||
return response.set_error(e)
|
return response.set_error(e)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/imageGenerate/bingart", summary="이미지 생성(AI) - bing art (DALL-E 3)", response_model=M.ImageGenerateRes)
|
@router.post("/imageGenerate/bingart", summary="이미지 생성(AI) - bing art (DALL-E 3)", response_model=M.ImageGenerateRes)
|
||||||
async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
|
async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
|
||||||
"""
|
"""
|
||||||
@@ -122,7 +121,6 @@ async def bing_art(request: Request, request_body_info: M.ImageGenerateReq):
|
|||||||
LOG.error(traceback.format_exc())
|
LOG.error(traceback.format_exc())
|
||||||
return response.set_error(e)
|
return response.set_error(e)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes)
|
@router.post("/imageGenerate/imagen", summary="이미지 생성(AI) - imagen", response_model=M.ImageGenerateRes)
|
||||||
async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
||||||
"""
|
"""
|
||||||
@@ -150,9 +148,8 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
|||||||
LOG.error(traceback.format_exc())
|
LOG.error(traceback.format_exc())
|
||||||
return response.set_error(error=e)
|
return response.set_error(error=e)
|
||||||
|
|
||||||
|
@router.post("/vactorImageSearch/imagenet/imageGenerate/imagen", summary="벡터 이미지 검색(imagenet) - imagen", response_model=M.ResponseBase)
|
||||||
@router.post("/vactorImageSearch/imageGenerate/imagen", summary="벡터 이미지 검색 - imagen", response_model=M.ResponseBase)
|
async def vactor_imagenet(request: Request, request_body_info: M.VactorImageSearchReq):
|
||||||
async def vactor_image(request: Request, request_body_info: M.VactorImageSearchReq):
|
|
||||||
"""
|
"""
|
||||||
## 벡터 이미지 검색 - imagen
|
## 벡터 이미지 검색 - imagen
|
||||||
> imagen AI를 이용하여 이미지 생성 후 vactor 검색
|
> imagen AI를 이용하여 이미지 생성 후 vactor 검색
|
||||||
@@ -164,13 +161,15 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
|
|||||||
"""
|
"""
|
||||||
response = M.ResponseBase()
|
response = M.ResponseBase()
|
||||||
try:
|
try:
|
||||||
if request_body_info.index_type not in [M.IndexType.hnsw, M.IndexType.l2]:
|
if request_body_info.indexType not in [M.IndexType.hnsw, M.IndexType.l2]:
|
||||||
raise Exception(f"index_type is hnsw or l2 (current value = {request_body_info.index_type})")
|
raise Exception(f"indexType is hnsw or l2 (current value = {request_body_info.indexType})")
|
||||||
|
|
||||||
img_path = imagen_generate_image_path(image_prompt=request_body_info.prompt)
|
img_path = imagen_generate_image_path(image_prompt=request_body_info.prompt)
|
||||||
|
|
||||||
vactor_request_data = {'quary_image_path' : img_path,'index_type' : request_body_info.index_type, 'search_num' : request_body_info.search_num}
|
vactor_request_data = {'query_image_path' : img_path,
|
||||||
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search', data=json.dumps(vactor_request_data))
|
'index_type' : request_body_info.indexType,
|
||||||
|
'search_num' : request_body_info.searchNum}
|
||||||
|
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/imagenet', data=json.dumps(vactor_request_data))
|
||||||
|
|
||||||
if vactor_response.status_code != 200:
|
if vactor_response.status_code != 200:
|
||||||
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
|
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
|
||||||
@@ -183,14 +182,157 @@ async def vactor_image(request: Request, request_body_info: M.VactorImageSearchR
|
|||||||
_base_bame = os.path.basename(_directory_path)
|
_base_bame = os.path.basename(_directory_path)
|
||||||
|
|
||||||
# remote 폴더 생성
|
# remote 폴더 생성
|
||||||
sftp_client.remote_mkdir(os.path.join(REMOTE_FOLDER, _base_bame))
|
sftp_client.remote_mkdir(os.path.join(rest_config.remote_folder, _base_bame))
|
||||||
|
|
||||||
# remote 폴더에 이미지 저장
|
# remote 폴더에 이미지 저장
|
||||||
for i in os.listdir(_directory_path):
|
for i in os.listdir(_directory_path):
|
||||||
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(REMOTE_FOLDER, _base_bame, i))
|
sftp_client.remote_copy_data(local_path=os.path.join(_directory_path, i), remote_path=os.path.join(rest_config.remote_folder, _base_bame, i))
|
||||||
|
|
||||||
return response.set_message()
|
return response.set_message()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error(traceback.format_exc())
|
LOG.error(traceback.format_exc())
|
||||||
|
return response.set_error(error=e)
|
||||||
|
|
||||||
|
@router.post("/vactorImageSearch/vit/imageGenerate/imagen", summary="벡터 이미지 검색(clip-vit) - imagen", response_model=M.ResponseBase)
|
||||||
|
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReq):
|
||||||
|
"""
|
||||||
|
## 벡터 이미지 검색(clip-vit) - imagen
|
||||||
|
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 결과 이미지 생성
|
||||||
|
|
||||||
|
### Requriements
|
||||||
|
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
|
||||||
|
|
||||||
|
### options
|
||||||
|
> - modelType -> b32,b16,l14,l14_336
|
||||||
|
> - indexType -> l2,cos
|
||||||
|
|
||||||
|
"""
|
||||||
|
response = M.ResponseBase()
|
||||||
|
try:
|
||||||
|
if not download_range(request_body_info.searchNum, max=10):
|
||||||
|
raise Exception(f"downloadCound is invalid (current value = {request_body_info.searchNum})")
|
||||||
|
|
||||||
|
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
|
||||||
|
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
|
||||||
|
|
||||||
|
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
|
||||||
|
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
|
||||||
|
|
||||||
|
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
|
||||||
|
|
||||||
|
vector_request_data = {'query_image_path' : query_image_path,
|
||||||
|
'index_type' : request_body_info.indexType,
|
||||||
|
'model_type' : request_body_info.modelType,
|
||||||
|
'search_num' : request_body_info.searchNum}
|
||||||
|
|
||||||
|
vector_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit', data=json.dumps(vector_request_data))
|
||||||
|
|
||||||
|
vector_response_dict = json.loads(vector_response.text)
|
||||||
|
|
||||||
|
if vector_response.status_code != 200:
|
||||||
|
raise Exception(f"response error: {vector_response_dict['error']}")
|
||||||
|
|
||||||
|
if vector_response_dict["error"] != None:
|
||||||
|
raise Exception(f"vactor error: {vector_response_dict['error']}")
|
||||||
|
|
||||||
|
result_image_paths = vector_response_dict.get('img_list').get('result_image_paths')
|
||||||
|
result_percents = vector_response_dict.get('img_list').get('result_percents')
|
||||||
|
|
||||||
|
# 원격지 폴더 생성
|
||||||
|
remote_directory = os.path.join(rest_config.remote_folder, f"imagen_query_{request_body_info.modelType}_{request_body_info.indexType}_{request_body_info.prompt}_{D.date_file_name()}")
|
||||||
|
sftp_client.remote_mkdir(remote_directory)
|
||||||
|
|
||||||
|
# 원격지에 이미지 저장
|
||||||
|
sftp_client.remote_copy_data(local_path=query_image_path, remote_path=os.path.join(remote_directory,"query.png"))
|
||||||
|
|
||||||
|
for img_path, img_percent in zip(result_image_paths,result_percents):
|
||||||
|
sftp_client.remote_copy_data(local_path=img_path, remote_path=os.path.join(remote_directory,f"search_{img_percent}.png"))
|
||||||
|
|
||||||
|
# Clean up temporary files
|
||||||
|
if 'query_image_path' in locals():
|
||||||
|
if os.path.exists(query_image_path):
|
||||||
|
os.remove(query_image_path)
|
||||||
|
del query_image_path
|
||||||
|
|
||||||
|
return response.set_message()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error(traceback.format_exc())
|
||||||
|
|
||||||
|
# Clean up temporary files
|
||||||
|
if 'query_image_path' in locals():
|
||||||
|
if os.path.exists(query_image_path):
|
||||||
|
os.remove(query_image_path)
|
||||||
|
del query_image_path
|
||||||
|
|
||||||
|
return response.set_error(error=e)
|
||||||
|
|
||||||
|
@router.post("/vactorImageSearch/vit/imageGenerate/imagen/report", summary="벡터 이미지 검색(clip-vit) - imagen, report 생성", response_model=M.ResponseBase)
|
||||||
|
async def vactor_vit_report(request: Request, request_body_info: M.VactorImageSearchVitReportReq):
|
||||||
|
"""
|
||||||
|
## 벡터 이미지 검색(clip-vit) - imagen, report 생성
|
||||||
|
> imagen AI를 이용하여 이미지 생성 후 vactor 검색 그후 종합결과 이미지 생성
|
||||||
|
|
||||||
|
### Requriements
|
||||||
|
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
|
||||||
|
|
||||||
|
### options
|
||||||
|
> - modelType -> b32,b16,l14,l14_336
|
||||||
|
> - indexType -> l2,cos
|
||||||
|
|
||||||
|
"""
|
||||||
|
response = M.ResponseBase()
|
||||||
|
try:
|
||||||
|
if request_body_info.modelType not in [M.VitModelType.b32, M.VitModelType.b16, M.VitModelType.l14, M.VitModelType.l14_336]:
|
||||||
|
raise Exception(f"modelType is invalid (current value = {request_body_info.modelType})")
|
||||||
|
|
||||||
|
if request_body_info.indexType not in [M.VitIndexType.cos, M.VitIndexType.l2]:
|
||||||
|
raise Exception(f"indexType is invalid (current value = {request_body_info.indexType})")
|
||||||
|
|
||||||
|
query_image_path = imagen_generate_temp_image_path(image_prompt=request_body_info.prompt)
|
||||||
|
report_image_path = f"{os.path.splitext(query_image_path)[0]}_report.png"
|
||||||
|
|
||||||
|
vactor_request_data = {'query_image_path' : query_image_path,
|
||||||
|
'index_type' : request_body_info.indexType,
|
||||||
|
'model_type' : request_body_info.modelType,
|
||||||
|
'report_path' : report_image_path}
|
||||||
|
|
||||||
|
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search/vit/report', data=json.dumps(vactor_request_data))
|
||||||
|
|
||||||
|
if vactor_response.status_code != 200:
|
||||||
|
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
|
||||||
|
|
||||||
|
if json.loads(vactor_response.text)["error"] != None:
|
||||||
|
raise Exception(f"vactor error: {json.loads(vactor_response.text)['error']}")
|
||||||
|
|
||||||
|
# remote 폴더에 이미지 저장
|
||||||
|
sftp_client.remote_copy_data(local_path=report_image_path,
|
||||||
|
remote_path=os.path.join(rest_config.remote_folder, f"imagen_report_vit_{request_body_info.prompt}_{D.date_file_name()}.png"))
|
||||||
|
|
||||||
|
# Clean up temporary files
|
||||||
|
if 'query_image_path' in locals():
|
||||||
|
if os.path.exists(query_image_path):
|
||||||
|
os.remove(query_image_path)
|
||||||
|
del query_image_path
|
||||||
|
if 'report_image_path' in locals():
|
||||||
|
if os.path.exists(report_image_path):
|
||||||
|
os.remove(report_image_path)
|
||||||
|
del report_image_path
|
||||||
|
|
||||||
|
return response.set_message()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error(traceback.format_exc())
|
||||||
|
|
||||||
|
# Clean up temporary files
|
||||||
|
if 'query_image_path' in locals():
|
||||||
|
if os.path.exists(query_image_path):
|
||||||
|
os.remove(query_image_path)
|
||||||
|
del query_image_path
|
||||||
|
if 'report_image_path' in locals():
|
||||||
|
if os.path.exists(report_image_path):
|
||||||
|
os.remove(report_image_path)
|
||||||
|
del report_image_path
|
||||||
|
|
||||||
return response.set_error(error=e)
|
return response.set_error(error=e)
|
||||||
@@ -15,4 +15,10 @@ email-validator
|
|||||||
|
|
||||||
#faiss
|
#faiss
|
||||||
scikit-learn
|
scikit-learn
|
||||||
pillow
|
pillow
|
||||||
|
flax
|
||||||
|
transformers
|
||||||
|
torch
|
||||||
|
tensorflow
|
||||||
|
matplotlib
|
||||||
|
keras==3.10.0
|
||||||
@@ -2,10 +2,11 @@ import paramiko
|
|||||||
|
|
||||||
class CustomSFTPClient():
|
class CustomSFTPClient():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
host = "192.168.200.230"
|
from config import rest_config
|
||||||
port = 22
|
host = rest_config.sftp_host
|
||||||
id = "fermat"
|
port = rest_config.sftp_port
|
||||||
pw = "1234"
|
id = rest_config.sftp_id
|
||||||
|
pw = rest_config.sftp_pw
|
||||||
|
|
||||||
self.ssh_client = paramiko.SSHClient()
|
self.ssh_client = paramiko.SSHClient()
|
||||||
self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# When service starts.
|
# When service starts.
|
||||||
LOG.info("REST start")
|
LOG.info(f"REST start (port : {conf().REST_SERVER_PORT})")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|||||||
@@ -570,12 +570,62 @@ class IndexType(str, Enum):
|
|||||||
hnsw = "hnsw"
|
hnsw = "hnsw"
|
||||||
l2 = "l2"
|
l2 = "l2"
|
||||||
|
|
||||||
|
|
||||||
|
class VitIndexType(str, Enum):
|
||||||
|
cos = "cos"
|
||||||
|
l2 = "l2"
|
||||||
|
|
||||||
|
|
||||||
|
class VitModelType(str, Enum):
|
||||||
|
b32 = "b32"
|
||||||
|
b16 = "b16"
|
||||||
|
l14 = "l14"
|
||||||
|
l14_336 = "l14_336"
|
||||||
|
|
||||||
|
|
||||||
class VactorSearchReq(BaseModel):
|
class VactorSearchReq(BaseModel):
|
||||||
"""
|
"""
|
||||||
### [Request] vactor 검색
|
### [Request] vactor 검색
|
||||||
"""
|
"""
|
||||||
quary_image_path : str = Field(description='quary image', example='path')
|
query_image_path : str = Field(description='quary image', example='path')
|
||||||
index_type : IndexType = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
|
index_type : IndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||||
|
|
||||||
|
|
||||||
|
class VactorSearchVitReportReq(BaseModel):
|
||||||
|
"""
|
||||||
|
### [Request] vactor 검색(vit) 후 리포트 이미지 생성
|
||||||
|
"""
|
||||||
|
query_image_path : str = Field(description='quary image', example='path')
|
||||||
|
index_type : VitIndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||||
|
model_type : VitModelType = Field(VitModelType.l14, description='pretrained 모델 정보', example=VitModelType.l14)
|
||||||
|
report_path : str = Field(description='리포트 이미지 저장 경로', example='path')
|
||||||
|
|
||||||
|
|
||||||
|
class VactorSearchVitReq(BaseModel):
|
||||||
|
"""
|
||||||
|
### [Request] vactor 검색(vit) 후 이미지 생성
|
||||||
|
"""
|
||||||
|
query_image_path : str = Field(description='quary image', example='path')
|
||||||
|
index_type : VitIndexType = Field(IndexType.l2, description='인덱스 타입', example=IndexType.l2)
|
||||||
|
model_type : VitModelType = Field(VitModelType.l14, description='pretrained 모델 정보', example=VitModelType.l14)
|
||||||
|
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||||
|
|
||||||
|
class VactorSearchVitRes(ResponseBase):
|
||||||
|
img_list : dict = Field({}, description='이미지 결과 리스트', example={})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_error(error):
|
||||||
|
VactorSearchVitRes.img_list = {}
|
||||||
|
VactorSearchVitRes.result = False
|
||||||
|
VactorSearchVitRes.error = str(error)
|
||||||
|
|
||||||
|
return VactorSearchVitRes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_message(msg):
|
||||||
|
VactorSearchVitRes.img_list = msg
|
||||||
|
VactorSearchVitRes.result = True
|
||||||
|
VactorSearchVitRes.error = None
|
||||||
|
|
||||||
|
return VactorSearchVitRes
|
||||||
|
|||||||
@@ -19,12 +19,15 @@ from vactor_rest.app import models as M
|
|||||||
from vactor_rest.app.utils.date_utils import D
|
from vactor_rest.app.utils.date_utils import D
|
||||||
from custom_logger.vactor_log import vactor_logger as LOG
|
from custom_logger.vactor_log import vactor_logger as LOG
|
||||||
|
|
||||||
from custom_apps.faiss.main import search_idxs
|
from custom_apps.faiss_imagenet.main import search_idxs
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_functions import get_clip_info, save_report_image, get_models
|
||||||
|
|
||||||
|
from custom_apps.FEATURE_VECTOR_SIMILARITY_FAISS.faiss_similarity_search import FEMUsageInfo
|
||||||
|
|
||||||
router = APIRouter(prefix="/services")
|
router = APIRouter(prefix="/services")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/faiss/vactor/search", summary="vactor search", response_model=M.ResponseBase)
|
@router.post("/faiss/vactor/search/imagenet", summary="imagenet search", response_model=M.ResponseBase)
|
||||||
async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
|
async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
|
||||||
"""
|
"""
|
||||||
## 벡터검색
|
## 벡터검색
|
||||||
@@ -35,15 +38,55 @@ async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
|
|||||||
"""
|
"""
|
||||||
response = M.ResponseBase()
|
response = M.ResponseBase()
|
||||||
try:
|
try:
|
||||||
if os.path.exists(request_body_info.quary_image_path):
|
if os.path.exists(request_body_info.query_image_path):
|
||||||
search_idxs(image_path=request_body_info.quary_image_path,
|
search_idxs(image_path=request_body_info.query_image_path,
|
||||||
index_type=request_body_info.index_type,
|
index_type=request_body_info.index_type,
|
||||||
search_num=request_body_info.search_num)
|
search_num=request_body_info.search_num)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"File {request_body_info.quary_image_path} does not exist.")
|
raise Exception(f"File {request_body_info.query_image_path} does not exist.")
|
||||||
|
|
||||||
return response.set_message()
|
return response.set_message()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error(traceback.format_exc())
|
||||||
|
return response.set_error(e)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/faiss/vactor/search/vit/report", summary="vit search report", response_model=M.ResponseBase)
|
||||||
|
async def vactor_report_vit(request: Request, request_body_info: M.VactorSearchVitReportReq):
|
||||||
|
response = M.ResponseBase()
|
||||||
|
try:
|
||||||
|
if not os.path.exists(request_body_info.query_image_path):
|
||||||
|
raise FileNotFoundError(f"File {request_body_info.query_image_path} does not exist.")
|
||||||
|
|
||||||
|
model = get_models(index_type=request_body_info.index_type, model_type=request_body_info.model_type)
|
||||||
|
|
||||||
|
report_info = get_clip_info(model,request_body_info.query_image_path)
|
||||||
|
save_report_image(report_info, request_body_info.report_path)
|
||||||
|
|
||||||
|
return response.set_message()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
LOG.error(traceback.format_exc())
|
||||||
|
return response.set_error(e)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/faiss/vactor/search/vit", summary="vit search", response_model=M.VactorSearchVitRes)
|
||||||
|
async def vactor_report_vit(request: Request, request_body_info: M.VactorSearchVitReq):
|
||||||
|
response = M.VactorSearchVitRes()
|
||||||
|
try:
|
||||||
|
if not os.path.exists(request_body_info.query_image_path):
|
||||||
|
raise FileNotFoundError(f"File {request_body_info.query_image_path} does not exist.")
|
||||||
|
|
||||||
|
model = get_models(index_type=request_body_info.index_type, model_type=request_body_info.model_type)
|
||||||
|
|
||||||
|
report_info = get_clip_info(model,request_body_info.query_image_path,top_k=request_body_info.search_num)
|
||||||
|
|
||||||
|
return response.set_message({
|
||||||
|
'result_image_paths': report_info.result_image_paths,
|
||||||
|
'result_percents': report_info.result_percents
|
||||||
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOG.error(traceback.format_exc())
|
LOG.error(traceback.format_exc())
|
||||||
return response.set_error(e)
|
return response.set_error(e)
|
||||||
Reference in New Issue
Block a user