Files
GLASSES_AI_SERVER/custom_apps/faiss/utils.py

116 lines
3.3 KiB
Python

"""
Takes a query image (given as an image file path) and processes it.
"""
import os
import time
import math
import numpy as np
from sklearn.preprocessing import normalize
import faiss
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.models import Model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from PIL import Image
def populate(index, fvecs, batch_size=1000):
    """Add the rows of ``fvecs`` to ``index`` in normalized batches.

    Each slice of at most ``batch_size`` rows is passed through
    ``sklearn.preprocessing.normalize`` (called with its default arguments)
    and then handed to ``index.add``, so only one batch is normalized at a
    time rather than the whole array.

    Args:
        index: Object exposing ``add(vectors)``; the callers in this module
            pass a faiss index.
        fvecs: 2-D array-like supporting ``.shape`` and row slicing.
        batch_size: Maximum number of rows added per ``index.add`` call.

    Returns:
        The same ``index`` object after every row has been added.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A non-positive size used to either divide by zero or silently add
        # nothing; fail loudly instead.
        raise ValueError(
            f'batch_size must be a positive integer, got {batch_size!r}'
        )
    total_rows = fvecs.shape[0]
    for start in range(0, total_rows, batch_size):
        # Slicing past the end is clipped, so the last batch may be shorter.
        index.add(normalize(fvecs[start:start + batch_size]))
    return index
def get_index(index_type, dim):
    """Create an empty faiss index of the requested type.

    Args:
        index_type: ``'hnsw'`` for a ``faiss.IndexHNSWFlat`` or ``'l2'`` for
            a ``faiss.IndexFlatL2``.
        dim: Dimensionality of the vectors the index will hold.

    Returns:
        A newly constructed, empty faiss index.

    Raises:
        ValueError: If ``index_type`` is not one of the supported names.
    """
    if index_type == 'hnsw':
        links_per_node = 48  # second argument of IndexHNSWFlat (HNSW "M")
        index = faiss.IndexHNSWFlat(dim, links_per_node)
        index.hnsw.efConstruction = 128
        return index
    if index_type == 'l2':
        return faiss.IndexFlatL2(dim)
    # A bare ``raise`` here has no active exception to re-raise, which hides
    # the real problem; report the unsupported value explicitly.
    raise ValueError(
        f"unsupported index_type {index_type!r}; expected 'hnsw' or 'l2'"
    )
def preprocessing(dim, fvec_file, index_type):
    """Memory-map a feature-vector file and load or build its faiss index.

    The index is cached next to the vector file as
    ``<fvec_file>.<index_type>.index``. If that file exists it is read back;
    otherwise a new index is created, populated from the vectors and written
    to that path.

    Args:
        dim: Number of float32 values per vector; the raw file is reshaped
            to ``(-1, dim)``.
        fvec_file: Path of the raw float32 vector file.
        index_type: Index kind understood by ``get_index`` (``'hnsw'`` or
            ``'l2'``).

    Returns:
        Tuple ``(fvecs, index)``: the read-only memory-mapped vectors and the
        faiss index.
    """
    index_file = f'{fvec_file}.{index_type}.index'
    # mode='r' keeps the vectors on disk and read-only.
    fvecs = np.memmap(fvec_file, dtype='float32', mode='r').reshape(-1, dim)

    if os.path.exists(index_file):
        index = faiss.read_index(index_file)
    else:
        index = populate(get_index(index_type, dim), fvecs)
        faiss.write_index(index, index_file)

    if index_type == 'hnsw':
        # Apply the search setting on both paths so a freshly built index
        # behaves the same as one reloaded from the cache file.
        index.hnsw.efSearch = 256
    return fvecs, index
def preprocessing_quary(dim, image_path, index_type):
    """Build an in-memory index that holds only the query image's features.

    Args:
        dim: Vector dimensionality passed to ``get_index``.
        image_path: Path of the query image, forwarded to ``fvces_quary``.
        index_type: Index kind understood by ``get_index``.

    Returns:
        Tuple ``(fvecs, index)``: the query feature vectors and a new index
        populated with them.
    """
    query_vectors = fvces_quary(image_path)
    query_index = populate(get_index(index_type, dim), query_vectors)
    return query_vectors, query_index
def preprocess(img_path, input_shape):
    """Read an image file and prepare it as MobileNetV2 input.

    Args:
        img_path: Path of the image file, read with ``tf.io.read_file``.
        input_shape: ``(height, width, channels)``; the first two entries
            are the resize target and the third is the decode channel count.

    Returns:
        The decoded, resized image after ``preprocess_input`` from
        ``tensorflow.keras.applications.mobilenet_v2``.
    """
    encoded = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(encoded, channels=input_shape[2])
    resized = tf.image.resize(decoded, input_shape[:2])
    return preprocess_input(resized)
def preprocess_pil(pil_img_data, input_shape):
    """Open an image with PIL and prepare it as MobileNetV2 input.

    Unlike a plain ``np.asarray(Image.open(...))``, the image is first
    converted to the colour mode implied by ``input_shape[2]`` so that
    grayscale, palette or RGBA inputs end up with the same channel count
    that ``preprocess`` enforces through ``decode_jpeg(channels=...)``.

    Args:
        pil_img_data: Anything accepted by ``PIL.Image.open`` (a path or a
            file-like object).
        input_shape: ``(height, width[, channels])``; the first two entries
            are the resize target. When a third entry of 1, 3 or 4 is
            present the image is converted to ``'L'``, ``'RGB'`` or
            ``'RGBA'`` respectively; otherwise the image mode is left as is.

    Returns:
        The resized image after ``preprocess_input`` from
        ``tensorflow.keras.applications.mobilenet_v2``.
    """
    mode_by_channels = {1: 'L', 3: 'RGB', 4: 'RGBA'}
    wanted_mode = None
    if len(input_shape) > 2:
        wanted_mode = mode_by_channels.get(input_shape[2])

    image = Image.open(pil_img_data)
    if wanted_mode is not None and image.mode != wanted_mode:
        image = image.convert(wanted_mode)

    pixels = np.asarray(image)
    resized = tf.image.resize(pixels, input_shape[:2])
    return preprocess_input(resized)
# Feature-extractor models keyed by input shape, built on first use.
_FEATURE_EXTRACTORS = {}


def _get_feature_extractor(input_shape):
    """Return the MobileNetV2 feature extractor for ``input_shape``, building it once."""
    extractor = _FEATURE_EXTRACTORS.get(input_shape)
    if extractor is None:
        base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
                                                 include_top=False,
                                                 weights='imagenet')
        base.trainable = False
        # Global average pooling collapses the spatial map into one vector
        # per image.
        pooled = layers.GlobalAveragePooling2D()(base.output)
        extractor = Model(inputs=base.input, outputs=pooled)
        _FEATURE_EXTRACTORS[input_shape] = extractor
    return extractor


def fvces_quary(image_path):
    """Compute the MobileNetV2 feature vector of a single query image.

    The network is built once and reused on later calls; rebuilding it and
    reloading the ImageNet weights for every query is slow and keeps adding
    to Keras' global state.

    Args:
        image_path: Path of the image file, decoded by ``preprocess``.

    Returns:
        The array returned by ``model.predict`` for the one-image batch.
    """
    batch_size = 100
    input_shape = (224, 224, 3)
    model = _get_feature_extractor(input_shape)

    list_ds = tf.data.Dataset.from_tensor_slices([image_path])
    # -1 is passed through unchanged from the original pipeline settings.
    ds = list_ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
    dataset = ds.batch(batch_size).prefetch(-1)

    # The dataset holds exactly one image, so this loop runs once.
    for batch in dataset:
        fvecs = model.predict(batch)
    return fvecs
# index_type = 'hnsw'
# index_type = 'l2'
def get_dataset_list(filepath):
    """Return the whitespace-stripped lines of a ``.txt`` file.

    Args:
        filepath: Path string of the file to read.

    Returns:
        A list with one stripped string per line of the file (blank lines
        become empty strings). An empty list is returned when ``filepath``
        is not an existing regular file or does not end with ``.txt``.
    """
    is_text_file = os.path.isfile(filepath) and filepath.endswith(".txt")
    if not is_text_file:
        return []
    with open(filepath, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]