edit: faiss 전용 rest 서버 추가
This commit is contained in:
116
custom_apps/faiss/utils.py
Normal file
116
custom_apps/faiss/utils.py
Normal file
@@ -0,0 +1,116 @@
"""
Process a query image, given by its file path.
"""
import os
|
||||
import time
|
||||
import math
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import normalize
|
||||
import faiss
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras.layers as layers
|
||||
from tensorflow.keras.models import Model
|
||||
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def populate(index, fvecs, batch_size=1000):
    """Add all rows of *fvecs* to a FAISS index in L2-normalized batches.

    Args:
        index: a FAISS index (e.g. IndexHNSWFlat or IndexFlatL2).
        fvecs: 2-D float32 array (or memmap) of feature vectors.
        batch_size: number of vectors passed to each index.add() call.

    Returns:
        The same index object, after all vectors have been added.
    """
    nloop = math.ceil(fvecs.shape[0] / batch_size)
    for n in range(nloop):
        # Normalize each slice so distances are comparable regardless of
        # vector magnitude. (The unused per-batch timing variable and the
        # commented-out print from the original were removed.)
        index.add(normalize(fvecs[n * batch_size : min((n + 1) * batch_size, fvecs.shape[0])]))
    return index
|
||||
|
||||
def get_index(index_type, dim):
    """Create an empty FAISS index of the requested type.

    Args:
        index_type: 'hnsw' for an approximate HNSW index, or 'l2' for an
            exact flat L2 index.
        dim: dimensionality of the vectors the index will hold.

    Returns:
        A new, empty FAISS index.

    Raises:
        ValueError: if *index_type* is not one of the supported types.
    """
    if index_type == 'hnsw':
        m = 48  # HNSW graph neighbors per node
        index = faiss.IndexHNSWFlat(dim, m)
        index.hnsw.efConstruction = 128  # build-time search breadth
        return index
    elif index_type == 'l2':
        return faiss.IndexFlatL2(dim)
    # The original used a bare `raise` here, which fails with
    # "RuntimeError: No active exception to re-raise"; raise an
    # informative error instead.
    raise ValueError(f"unsupported index_type: {index_type!r}")
|
||||
|
||||
def preprocessing(dim, fvec_file, index_type):
    """Load stored feature vectors and build (or load a cached) FAISS index.

    Args:
        dim: dimensionality of the stored vectors.
        fvec_file: path to a raw float32 file of concatenated vectors.
        index_type: 'hnsw' or 'l2' (see get_index).

    Returns:
        (fvecs, index): the memory-mapped vectors and the ready index.
    """
    # The built index is cached next to the vector file, one per index type.
    index_file = f'{fvec_file}.{index_type}.index'

    # The memmap is already float32, so the original's extra
    # .view('float32') was a no-op and has been dropped.
    fvecs = np.memmap(fvec_file, dtype='float32', mode='r').reshape(-1, dim)

    if os.path.exists(index_file):
        index = faiss.read_index(index_file)
        if index_type == 'hnsw':
            index.hnsw.efSearch = 256  # query-time search breadth
    else:
        index = get_index(index_type, dim)
        index = populate(index, fvecs)
        faiss.write_index(index, index_file)

    return fvecs, index
|
||||
|
||||
def preprocessing_quary(dim, image_path, index_type):
    """Embed a single query image and wrap it in a fresh FAISS index.

    Args:
        dim: dimensionality of the feature vectors.
        image_path: path of the query image to embed.
        index_type: 'hnsw' or 'l2' (see get_index).

    Returns:
        (fvecs, index): the query feature vectors and an index built from
        them. Unlike preprocessing(), nothing is cached to disk.
    """
    query_vecs = fvces_quary(image_path)
    query_index = populate(get_index(index_type, dim), query_vecs)
    return query_vecs, query_index
|
||||
|
||||
def preprocess(img_path, input_shape):
    """Read a JPEG from disk and prepare it as a MobileNetV2 input.

    Args:
        img_path: path (or tf string tensor) of the image file.
        input_shape: (height, width, channels) target shape.

    Returns:
        A float tensor resized to input_shape[:2], channel count forced to
        input_shape[2], preprocessed for MobileNetV2.
    """
    raw = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw, channels=input_shape[2])
    resized = tf.image.resize(decoded, input_shape[:2])
    return preprocess_input(resized)
|
||||
|
||||
def preprocess_pil(pil_img_data, input_shape):
    """Prepare an in-memory image as a MobileNetV2 input.

    Args:
        pil_img_data: anything PIL.Image.open accepts (path or file-like).
        input_shape: (height, width, channels) target shape; only the
            spatial dims are used here (channels are taken as-is from the
            decoded image).

    Returns:
        A float tensor resized to input_shape[:2] and preprocessed for
        MobileNetV2.
    """
    arr = np.asarray(Image.open(pil_img_data))
    resized = tf.image.resize(arr, input_shape[:2])
    return preprocess_input(resized)
|
||||
|
||||
def fvces_quary(image_path):
    """Compute MobileNetV2 feature vectors for a single query image.

    Args:
        image_path: path of the image to embed.

    Returns:
        float32 array of shape (1, D): the GlobalAveragePooling2D output
        of an ImageNet-pretrained MobileNetV2 backbone for the image.
    """
    batch_size = 100
    input_shape = (224, 224, 3)

    # NOTE(review): the model (and its ImageNet weights) is rebuilt on
    # every call; consider caching it at module level if this is hot.
    base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
                                             include_top=False,
                                             weights='imagenet')
    base.trainable = False
    model = Model(inputs=base.input, outputs=layers.GlobalAveragePooling2D()(base.output))

    list_ds = tf.data.Dataset.from_tensor_slices([image_path])
    ds = list_ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
    dataset = ds.batch(batch_size).prefetch(-1)

    # The original returned from inside the loop, silently dropping every
    # batch after the first (and raising NameError for an empty dataset).
    # Collect all batches instead — identical output for the single-image
    # case, correct if this is ever extended to multiple paths.
    parts = [model.predict(batch) for batch in dataset]
    return np.vstack(parts)
|
||||
|
||||
|
||||
# index_type = 'hnsw'
|
||||
# index_type = 'l2'
|
||||
|
||||
def get_dataset_list(filepath):
    """Read a dataset listing file into a list of stripped lines.

    Args:
        filepath: path to a UTF-8 text file; only existing files ending
            in ".txt" are read.

    Returns:
        A list of lines with surrounding whitespace stripped, or [] when
        *filepath* is not an existing .txt file.
    """
    # Guard clause replaces the original's nested body; the redundant
    # function-local `import os` (os is imported at module level) is gone.
    if not (os.path.isfile(filepath) and filepath.endswith(".txt")):
        return []
    with open(filepath, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user