50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
import os
|
|
import shutil
|
|
|
|
from custom_apps.faiss.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
|
|
from custom_apps.faiss.const import *
|
|
# from rest.app.utils.date_utils import D
|
|
# from const import OUTPUT_FOLDER
|
|
|
|
# Dataset image paths, loaded once at import time from DATASET_TEXT.
# index_image_save() maps each search-result index to a file to copy via this
# list, so its order presumably has to match the vectors stored in the dataset
# index (DATASET_BIN) — TODO confirm against get_dataset_list().
dataset_list = get_dataset_list(DATASET_TEXT)
|
|
|
|
def search_idxs(image_path, dataset_bin=DATASET_BIN, index_type="hnsw", search_num=4, dim=1280):
    """Search the dataset index for images similar to a query image and save copies.

    The query image's features are normalized and searched against the dataset
    index. The top ``search_num`` matches are then copied next to the query
    image by ``index_image_save``.

    Args:
        image_path: Path of the query image.
        dataset_bin: Dataset feature file handed to ``preprocessing``.
        index_type: Either ``"hnsw"`` or ``"l2"``.
        search_num: Number of nearest neighbours to retrieve (must be >= 1).
        dim: Feature dimension passed to both preprocessing helpers
            (1280 was previously hard-coded).

    Returns:
        tuple: ``(dists, idxs)`` for the first query row. These are the same
        values passed to ``index_image_save``. (Older versions returned None.)

    Raises:
        ValueError: If ``index_type`` is unsupported or ``search_num`` < 1.
    """
    if index_type not in ["hnsw", "l2"]:
        raise ValueError("index_type must be either 'hnsw' or 'l2'")
    # Fail early with a clear message instead of an opaque error from the index.
    if search_num < 1:
        raise ValueError("search_num must be a positive integer")

    # NOTE(review): preprocessing() runs on every call. If it loads or builds the
    # index, consider caching it when search_idxs is called repeatedly.
    _dataset_fvces, dataset_index = preprocessing(dim, dataset_bin, index_type)
    query_fvces, _query_index = preprocessing_quary(dim, image_path, index_type)

    dists, idxs = dataset_index.search(normalize(query_fvces), search_num)

    # Only the first query row is used. This presumably assumes the query image
    # produces a single feature vector — TODO confirm.
    # NOTE(review): index_image_save resolves indices through the module-level
    # dataset_list (built from DATASET_TEXT). A non-default dataset_bin must be
    # index-aligned with that list.
    index_image_save(image_path, dists[0], idxs[0])
    return dists[0], idxs[0]
|
|
|
|
def index_image_save(query_image_path, dists, idxs, dataset_paths=None):
    """Copy the dataset images matched by a search into the query image's folder.

    Each match is saved as ``search_<score><ext>``:

    - ``<ext>`` is the query image's extension.
    - ``<score>`` is ``1 - dist`` rounded to 4 decimals, or 0 when ``dist > 1``.
    - If two matches round to the same score, the later one gets a
      ``_<rank>`` suffix so that no result overwrites another.

    Args:
        query_image_path: Path of the query image. Results go into its folder,
            or the current directory for a bare file name.
        dists: Distances for one query, aligned with ``idxs``.
        idxs: Dataset indices for one query. Negative entries (FAISS uses -1
            to pad when fewer neighbours are found) are skipped.
        dataset_paths: Sequence mapping a dataset index to an image path.
            Defaults to the module-level ``dataset_list``.

    Raises:
        ValueError: If the query image's folder does not exist.
    """
    if dataset_paths is None:
        dataset_paths = dataset_list

    directory_path, file = os.path.split(query_image_path)
    # A bare file name has an empty directory part; it refers to the cwd.
    directory_path = directory_path or os.curdir
    extension = os.path.splitext(file)[1]

    if not os.path.exists(directory_path):
        raise ValueError(f"Folder {directory_path} does not exist.")

    used_names = set()
    for rank, (dist, index) in enumerate(zip(dists, idxs)):
        # A -1 padding label would otherwise silently select dataset_paths[-1].
        if index < 0:
            continue

        score = 0 if dist > 1 else 1 - dist
        name = f"search_{round(float(score), 4)}"
        # Equal rounded scores (e.g. every dist > 1 maps to 0) would overwrite
        # each other's copies, so disambiguate with the result rank.
        if name in used_names:
            name = f"{name}_{rank}"
        used_names.add(name)

        origin_file_path = dataset_paths[index]
        dest_file_path = os.path.join(directory_path, f"{name}{extension}")
        shutil.copy(origin_file_path, dest_file_path)
|
|
|
|
|
|
|
|
|
|
|