edit: faiss 전용 rest 서버 추가

This commit is contained in:
2025-04-28 11:26:19 +09:00
parent 19e62f5724
commit 6b212125a4
76 changed files with 6014 additions and 92 deletions

116
custom_apps/faiss/utils.py Normal file
View File

@@ -0,0 +1,116 @@
"""
quary 이미지(이미지 경로) 입력받아 처리
"""
import os
import time
import math
import numpy as np
from sklearn.preprocessing import normalize
import faiss
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.models import Model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from PIL import Image
def populate(index, fvecs, batch_size=1000):
nloop = math.ceil(fvecs.shape[0] / batch_size)
for n in range(nloop):
s = time.time()
index.add(normalize(fvecs[n * batch_size : min((n + 1) * batch_size, fvecs.shape[0])]))
# print(n * batch_size, time.time() - s)
return index
def get_index(index_type, dim):
if index_type == 'hnsw':
m = 48
index = faiss.IndexHNSWFlat(dim, m)
index.hnsw.efConstruction = 128
return index
elif index_type == 'l2':
return faiss.IndexFlatL2(dim)
raise
def preprocessing(dim,fvec_file,index_type):
# index_type = 'hnsw'
# index_type = 'l2'
# f-string 방식 (python3 이상에서 지원)
index_file = f'{fvec_file}.{index_type}.index'
fvecs = np.memmap(fvec_file, dtype='float32', mode='r').view('float32').reshape(-1, dim)
if os.path.exists(index_file):
index = faiss.read_index(index_file)
if index_type == 'hnsw':
index.hnsw.efSearch = 256
else:
index = get_index(index_type, dim)
index = populate(index, fvecs)
faiss.write_index(index, index_file)
# print(index.ntotal)
return fvecs,index
def preprocessing_quary(dim,image_path,index_type):
# index_type = 'hnsw'
# index_type = 'l2'
fvecs = fvces_quary(image_path)
index = get_index(index_type, dim)
index = populate(index, fvecs)
return fvecs,index
def preprocess(img_path, input_shape):
img = tf.io.read_file(img_path)
img = tf.image.decode_jpeg(img, channels=input_shape[2])
img = tf.image.resize(img, input_shape[:2])
img = preprocess_input(img)
return img
def preprocess_pil(pil_img_data, input_shape):
pil_img = np.asarray(Image.open(pil_img_data))
pil_img = tf.image.resize(pil_img, input_shape[:2])
pil_img = preprocess_input(pil_img)
return pil_img
def fvces_quary(image_path):
batch_size = 100
input_shape = (224, 224, 3)
base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
include_top=False,
weights='imagenet')
base.trainable = False
model = Model(inputs=base.input, outputs=layers.GlobalAveragePooling2D()(base.output))
list_ds = tf.data.Dataset.from_tensor_slices([image_path])
ds = list_ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
dataset = ds.batch(batch_size).prefetch(-1)
for batch in dataset:
fvecs = model.predict(batch)
return fvecs
# index_type = 'hnsw'
# index_type = 'l2'
def get_dataset_list(filepath):
content_list = []
import os
if os.path.isfile(filepath) and filepath.endswith(".txt"):
with open(filepath, 'r', encoding='utf-8') as file:
for line in file:
content_list.append(line.strip())
return content_list