116 lines
3.3 KiB
Python
116 lines
3.3 KiB
Python
|
|
"""
|
||
|
|
quary 이미지(이미지 경로) 입력받아 처리
|
||
|
|
"""
|
||
|
|
import os
|
||
|
|
import time
|
||
|
|
import math
|
||
|
|
import numpy as np
|
||
|
|
from sklearn.preprocessing import normalize
|
||
|
|
import faiss
|
||
|
|
import tensorflow as tf
|
||
|
|
import tensorflow.keras.layers as layers
|
||
|
|
from tensorflow.keras.models import Model
|
||
|
|
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
|
||
|
|
from PIL import Image
|
||
|
|
|
||
|
|
|
||
|
|
def populate(index, fvecs, batch_size=1000):
    """Add row-wise L2-normalized vectors from ``fvecs`` to a faiss index in batches.

    Args:
        index: a faiss index (anything exposing ``add(ndarray)``).
        fvecs: 2-D float32 array of shape (n, dim), e.g. a ``np.memmap``.
        batch_size: number of rows added per ``index.add`` call.

    Returns:
        The same ``index`` object, populated.
    """
    n_rows = fvecs.shape[0]
    for start in range(0, n_rows, batch_size):
        # Slicing clamps at the end automatically; the explicit min() of the
        # original (and the unused time.time() stopwatch) were dropped.
        chunk = np.asarray(fvecs[start:start + batch_size], dtype=np.float32)
        # L2-normalize each row, leaving all-zero rows untouched — same
        # semantics as sklearn.preprocessing.normalize used originally.
        norms = np.linalg.norm(chunk, axis=1, keepdims=True)
        norms[norms == 0.0] = 1.0
        index.add(chunk / norms)
    return index
|
||
|
|
|
||
|
|
def get_index(index_type, dim):
    """Build an empty faiss index of the requested type.

    Args:
        index_type: ``'hnsw'`` for a graph-based HNSW index, ``'l2'`` for
            exact (flat) L2 search.
        dim: vector dimensionality.

    Returns:
        A new, empty faiss index.

    Raises:
        ValueError: if ``index_type`` is not recognized.
    """
    if index_type == 'hnsw':
        # 48 neighbors per node; efConstruction trades build time for recall.
        m = 48
        index = faiss.IndexHNSWFlat(dim, m)
        index.hnsw.efConstruction = 128
        return index
    elif index_type == 'l2':
        return faiss.IndexFlatL2(dim)
    # The original bare ``raise`` had no active exception and would fail with
    # "RuntimeError: No active exception to re-raise"; raise a catchable,
    # descriptive error instead.
    raise ValueError(f"unsupported index_type: {index_type!r}")
|
||
|
|
|
||
|
|
def preprocessing(dim, fvec_file, index_type):
    """Load raw float32 feature vectors and build (or reload) a faiss index.

    Args:
        dim: vector dimensionality used to reshape the flat file.
        fvec_file: path to a raw float32 binary file of feature vectors.
        index_type: ``'hnsw'`` or ``'l2'`` (see ``get_index``).

    Returns:
        ``(fvecs, index)``: the memory-mapped vectors and a ready-to-query index.
    """
    # The built index is cached on disk next to the vector file,
    # e.g. "feats.fvecs.hnsw.index".
    index_file = f'{fvec_file}.{index_type}.index'

    # Memory-map the float32 payload. The original chained a .view('float32')
    # here, which is a no-op since the memmap dtype is already float32.
    fvecs = np.memmap(fvec_file, dtype='float32', mode='r').reshape(-1, dim)

    if os.path.exists(index_file):
        # Reuse the cached index instead of rebuilding it.
        index = faiss.read_index(index_file)
        if index_type == 'hnsw':
            # Higher efSearch improves recall at some query-time cost.
            index.hnsw.efSearch = 256
    else:
        index = get_index(index_type, dim)
        index = populate(index, fvecs)
        faiss.write_index(index, index_file)

    return fvecs, index
|
||
|
|
|
||
|
|
def preprocessing_quary(dim, image_path, index_type):
    """Embed a query image and wrap the embedding in a fresh faiss index.

    Args:
        dim: vector dimensionality of the index.
        image_path: path of the query image to embed.
        index_type: ``'hnsw'`` or ``'l2'`` (see ``get_index``).

    Returns:
        ``(fvecs, index)``: the query embedding(s) and an index containing them.
    """
    query_vecs = fvces_quary(image_path)
    query_index = populate(get_index(index_type, dim), query_vecs)
    return query_vecs, query_index
|
||
|
|
|
||
|
|
def preprocess(img_path, input_shape):
    """Read, decode and resize one image file for the MobileNetV2 backbone.

    Args:
        img_path: path (or string tensor) of a JPEG image.
        input_shape: ``(height, width, channels)`` target shape.

    Returns:
        A float tensor of shape ``input_shape`` normalized by MobileNetV2's
        ``preprocess_input``.
    """
    raw = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw, channels=input_shape[2])
    resized = tf.image.resize(decoded, input_shape[:2])
    return preprocess_input(resized)
|
||
|
|
|
||
|
|
def preprocess_pil(pil_img_data, input_shape):
    """Open an image via PIL and prepare it for the MobileNetV2 backbone.

    Args:
        pil_img_data: anything ``PIL.Image.open`` accepts (path or file-like).
        input_shape: ``(height, width, channels)``; only height/width are used,
            so the channel count is whatever the source image carries.

    Returns:
        A float tensor normalized by MobileNetV2's ``preprocess_input``.
    """
    pixels = np.asarray(Image.open(pil_img_data))
    resized = tf.image.resize(pixels, input_shape[:2])
    return preprocess_input(resized)
|
||
|
|
|
||
|
|
def fvces_quary(image_path):
    """Embed query image(s) with an ImageNet-pretrained MobileNetV2 backbone.

    Args:
        image_path: path to the query image (wrapped in a one-element dataset).

    Returns:
        A 2-D float array of shape (n_images, feature_dim) — one row per image.

    Raises:
        ValueError: if the dataset yields no batches at all.
    """
    batch_size = 100
    input_shape = (224, 224, 3)

    # NOTE(review): the backbone is rebuilt on every call; hoist it to module
    # scope if this function ends up on a hot path.
    base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
                                             include_top=False,
                                             weights='imagenet')
    base.trainable = False
    # Global-average-pool the conv feature map into one vector per image.
    model = Model(inputs=base.input, outputs=layers.GlobalAveragePooling2D()(base.output))

    list_ds = tf.data.Dataset.from_tensor_slices([image_path])
    # -1 == tf.data.AUTOTUNE for both map parallelism and prefetch depth.
    ds = list_ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
    dataset = ds.batch(batch_size).prefetch(-1)

    # The original overwrote ``fvecs`` each iteration, returning only the LAST
    # batch (and raising NameError on an empty dataset). Collect every batch
    # so the function stays correct for more than ``batch_size`` inputs.
    batches = [model.predict(batch) for batch in dataset]
    if not batches:
        raise ValueError('no images to embed')
    return np.vstack(batches)
|
||
|
|
|
||
|
|
|
||
|
|
# index_type = 'hnsw'
|
||
|
|
# index_type = 'l2'
|
||
|
|
|
||
|
|
def get_dataset_list(filepath):
    """Read a dataset listing file: one entry per line, whitespace-stripped.

    Args:
        filepath: path to a ``.txt`` file.

    Returns:
        List of stripped lines; an empty list when ``filepath`` does not exist
        or is not a ``.txt`` file (best-effort behavior kept from the original).
    """
    # ``os`` is already imported at module level; the redundant in-function
    # import of the original has been dropped.
    if os.path.isfile(filepath) and filepath.endswith(".txt"):
        with open(filepath, 'r', encoding='utf-8') as file:
            return [line.strip() for line in file]
    return []
|
||
|
|
|
||
|
|
|
||
|
|
|