edit : clip-vit 모델 추가
This commit is contained in:
2
custom_apps/faiss_imagenet/const.py
Normal file
2
custom_apps/faiss_imagenet/const.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Precomputed dataset artifacts: packed float32 feature vectors plus the
# matching list of image file names (row i of the .bin appears to correspond
# to line i of the .txt — see index_image_save's use of dataset_list[index]).
DATASET_BIN = "./datas/eyewear_all.fvecs.bin"

DATASET_TEXT = "./datas/eyewear_all.fnames.txt"
|
||||
50
custom_apps/faiss_imagenet/main.py
Normal file
50
custom_apps/faiss_imagenet/main.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from custom_apps.faiss_imagenet.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
|
||||
from custom_apps.faiss_imagenet.const import *
|
||||
# from rest.app.utils.date_utils import D
|
||||
# from const import OUTPUT_FOLDER
|
||||
|
||||
# File names for every vector in the dataset, loaded once at import time;
# FAISS result index i maps to dataset_list[i].
dataset_list = get_dataset_list(DATASET_TEXT)
|
||||
|
||||
def search_idxs(image_path, dataset_bin=DATASET_BIN, index_type="hnsw", search_num=4, dim=1280):
    """Search the feature dataset for the images most similar to *image_path*.

    Embeds the query image, searches the FAISS index built from
    ``dataset_bin``, and copies the top matches next to the query image
    (see ``index_image_save``).

    Args:
        image_path: Path to the query image file.
        dataset_bin: Path to the precomputed ``.fvecs.bin`` feature file.
        index_type: FAISS index flavor, either ``"hnsw"`` or ``"l2"``.
        search_num: Number of nearest neighbours to retrieve.
        dim: Feature dimensionality; 1280 matches the MobileNetV2 pooled
            output used by ``fvces_quary`` (was a hard-coded local).

    Returns:
        Tuple ``(dists, idxs)`` — distances and dataset indices of the top
        ``search_num`` matches for the query image.

    Raises:
        ValueError: If ``index_type`` is not one of the supported values.
    """
    if index_type not in ("hnsw", "l2"):
        raise ValueError("index_type must be either 'hnsw' or 'l2'")

    # Only the dataset index and the query features are needed below.
    _dataset_fvecs, dataset_index = preprocessing(dim, dataset_bin, index_type)
    query_fvecs, _query_index = preprocessing_quary(dim, image_path, index_type)

    # Normalize the query vectors before searching, matching how the
    # dataset vectors were normalized on insertion (see populate()).
    dists, idxs = dataset_index.search(normalize(query_fvecs), search_num)

    index_image_save(image_path, dists[0], idxs[0])
    return dists[0], idxs[0]
|
||||
|
||||
def index_image_save(query_image_path, dists, idxs, dataset_paths=None):
    """Copy the matched dataset images into the query image's directory.

    Each match is saved as ``search_<rank>_<similarity><ext>`` so that
    results sort by rank and never overwrite each other (the previous
    ``search_<similarity><ext>`` scheme silently clobbered ties in the
    rounded similarity). The extension now comes from the *source* file so
    content and suffix agree (previously the query image's extension was
    reused regardless of the copied file's actual format).

    Args:
        query_image_path: Path of the query image; its directory receives
            the copies.
        dists: Distances of the matches; converted to a similarity score
            clamped to ``[0, 1]`` (distances above 1 count as 0).
        idxs: Dataset indices of the matches.
        dataset_paths: Optional sequence mapping index -> source file path.
            Defaults to the module-level ``dataset_list``.

    Raises:
        ValueError: If the query image's directory does not exist.
    """
    paths = dataset_list if dataset_paths is None else dataset_paths

    directory_path, _file = os.path.split(query_image_path)
    if not os.path.exists(directory_path):
        raise ValueError(f"Folder {directory_path} does not exist.")

    for rank, (dist, index) in enumerate(zip(dists, idxs)):
        # Map distance to similarity: anything beyond distance 1 is
        # treated as "no meaningful similarity".
        similarity = 0 if dist > 1 else 1 - dist

        origin_file_path = paths[index]
        _, extension = os.path.splitext(origin_file_path)
        dest_name = f"search_{rank}_{round(float(similarity), 4)}{extension}"
        shutil.copy(origin_file_path, os.path.join(directory_path, dest_name))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
116
custom_apps/faiss_imagenet/utils.py
Normal file
116
custom_apps/faiss_imagenet/utils.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
quary 이미지(이미지 경로) 입력받아 처리
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import math
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import normalize
|
||||
import faiss
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras.layers as layers
|
||||
from tensorflow.keras.models import Model
|
||||
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def populate(index, fvecs, batch_size=1000):
    """Add all rows of *fvecs* to *index* in batches of *batch_size*.

    Rows are L2-normalized before insertion so that L2 distance on the
    index behaves like cosine distance.

    Args:
        index: A FAISS index exposing an ``add`` method.
        fvecs: 2-D float32 array of feature vectors, one per row.
        batch_size: Number of rows inserted per ``add`` call.

    Returns:
        The same *index*, after all rows have been added.
    """
    nloop = math.ceil(fvecs.shape[0] / batch_size)
    for n in range(nloop):
        # Numpy clamps out-of-range slice ends, so no explicit min() is
        # needed for the final partial batch. (Also dropped an unused
        # leftover timing variable and a commented-out print.)
        batch = fvecs[n * batch_size : (n + 1) * batch_size]
        index.add(normalize(batch))
    return index
|
||||
|
||||
def get_index(index_type, dim):
    """Create an empty FAISS index of the requested flavor.

    Args:
        index_type: ``"hnsw"`` for an approximate HNSW index, ``"l2"`` for
            an exact flat L2 index.
        dim: Dimensionality of the vectors to be indexed.

    Returns:
        A new, empty FAISS index.

    Raises:
        ValueError: If *index_type* is not supported.
    """
    if index_type == 'hnsw':
        m = 48  # neighbors per HNSW node
        index = faiss.IndexHNSWFlat(dim, m)
        index.hnsw.efConstruction = 128  # build-time search breadth
        return index
    if index_type == 'l2':
        return faiss.IndexFlatL2(dim)
    # The previous bare `raise` had no active exception and produced an
    # unhelpful RuntimeError; raise an explicit, descriptive error instead.
    raise ValueError(f"unsupported index_type: {index_type!r}")
|
||||
|
||||
def preprocessing(dim, fvec_file, index_type):
    """Load the dataset features and a matching FAISS index (disk-cached).

    The index is persisted next to the feature file as
    ``<fvec_file>.<index_type>.index`` and reused on subsequent calls
    instead of being rebuilt.

    Args:
        dim: Feature dimensionality of the stored vectors.
        fvec_file: Path to the raw packed float32 feature file.
        index_type: ``"hnsw"`` or ``"l2"`` (see ``get_index``).

    Returns:
        Tuple ``(fvecs, index)`` — the memory-mapped feature matrix of
        shape ``(n, dim)`` and the populated FAISS index.
    """
    index_file = f'{fvec_file}.{index_type}.index'

    # Memory-map the raw float32 file and view it as (n, dim) rows.
    # (The previous `.view('float32')` on an already-float32 memmap was a
    # no-op and has been dropped.)
    fvecs = np.memmap(fvec_file, dtype='float32', mode='r').reshape(-1, dim)

    if os.path.exists(index_file):
        # Reuse the cached on-disk index.
        index = faiss.read_index(index_file)
        if index_type == 'hnsw':
            index.hnsw.efSearch = 256  # query-time search breadth
    else:
        # Build, populate, and persist a fresh index for next time.
        index = populate(get_index(index_type, dim), fvecs)
        faiss.write_index(index, index_file)

    return fvecs, index
|
||||
|
||||
def preprocessing_quary(dim, image_path, index_type):
    """Embed a single query image and wrap it in a fresh FAISS index.

    NOTE(review): "quary" is a typo for "query"; the name is kept
    unchanged for backward compatibility with existing callers.

    Args:
        dim: Feature dimensionality (must match the embedding model).
        image_path: Path to the query image file.
        index_type: ``"hnsw"`` or ``"l2"`` (see ``get_index``).

    Returns:
        Tuple ``(fvecs, index)`` — the query's feature vectors and a small
        index containing just those vectors.
    """
    fvecs = fvces_quary(image_path)
    index = populate(get_index(index_type, dim), fvecs)
    return fvecs, index
|
||||
|
||||
def preprocess(img_path, input_shape):
    """Read an image file and prepare it for the embedding model.

    Decodes the JPEG with ``input_shape[2]`` channels, resizes to
    ``input_shape[:2]``, then applies MobileNetV2's ``preprocess_input``.
    """
    target_hw, channels = input_shape[:2], input_shape[2]
    raw = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw, channels=channels)
    resized = tf.image.resize(decoded, target_hw)
    return preprocess_input(resized)
|
||||
|
||||
def preprocess_pil(pil_img_data, input_shape):
    """Prepare an image opened via PIL for the embedding model.

    ``pil_img_data`` is whatever ``PIL.Image.open`` accepts — presumably a
    path or file-like object (TODO confirm against callers). The image is
    resized to ``input_shape[:2]`` and run through MobileNetV2's
    ``preprocess_input``; PIL handles decoding, so no decode step is needed.
    """
    pixel_array = np.asarray(Image.open(pil_img_data))
    resized = tf.image.resize(pixel_array, input_shape[:2])
    return preprocess_input(resized)
|
||||
|
||||
def fvces_quary(image_path):
    """Compute MobileNetV2 feature vectors for a single image.

    NOTE(review): "fvces"/"quary" are typos kept for compatibility with
    existing callers.

    Args:
        image_path: Path to the image file to embed.

    Returns:
        2-D numpy array of pooled MobileNetV2 features for the image
        (1280-dimensional, matching the DIM used by the search code).
    """
    input_shape = (224, 224, 3)

    # Building MobileNetV2 is expensive; the original rebuilt it on every
    # call. Cache the model on the function object so repeated queries
    # reuse it.
    model = getattr(fvces_quary, "_model", None)
    if model is None:
        base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
                                                 include_top=False,
                                                 weights='imagenet')
        base.trainable = False
        model = Model(inputs=base.input,
                      outputs=layers.GlobalAveragePooling2D()(base.output))
        fvces_quary._model = model

    # Run the single image through the same tf.data preprocessing pipeline
    # used elsewhere, so query features match the dataset features.
    ds = tf.data.Dataset.from_tensor_slices([image_path])
    ds = ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
    dataset = ds.batch(100).prefetch(-1)

    # Exactly one batch exists (single image); return its features directly.
    for batch in dataset:
        return model.predict(batch)
|
||||
|
||||
|
||||
# index_type = 'hnsw'
|
||||
# index_type = 'l2'
|
||||
|
||||
def get_dataset_list(filepath):
    """Read a ``.txt`` file into a list of whitespace-stripped lines.

    Line *i* of the file names the image behind feature row *i* of the
    dataset (see ``index_image_save``).

    Args:
        filepath: Path to a text file with one entry per line.

    Returns:
        List of stripped lines, or an empty list when *filepath* does not
        exist or does not end in ``.txt`` (preserving the original silent
        fallback behavior).
    """
    # Guard clause replaces the original nested-if; the redundant local
    # `import os` was removed (os is imported at module level).
    if not (os.path.isfile(filepath) and filepath.endswith(".txt")):
        return []
    with open(filepath, 'r', encoding='utf-8') as fh:
        return [line.strip() for line in fh]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user