edit: faiss 전용 rest 서버 추가

This commit is contained in:
2025-04-28 11:26:19 +09:00
parent 19e62f5724
commit 6b212125a4
76 changed files with 6014 additions and 92 deletions

View File

@@ -4,8 +4,8 @@ from pathlib import Path
from bingart import BingArt
from custom_apps.utils import cookie_manager
from rest.app.utils.parsing_utils import prompt_to_filenames
from rest.app.utils.date_utils import D
from main_rest.app.utils.parsing_utils import prompt_to_filenames
from main_rest.app.utils.date_utils import D
from const import OUTPUT_FOLDER

View File

@@ -16,8 +16,8 @@ import pkg_resources
import regex
import requests
from rest.app.utils.parsing_utils import prompt_to_filenames
from rest.app.utils.date_utils import D
from main_rest.app.utils.parsing_utils import prompt_to_filenames
from main_rest.app.utils.date_utils import D
BING_URL = os.getenv("BING_URL", "https://www.bing.com")
# Generate random IP between range 13.104.0.0/14

View File

@@ -0,0 +1,2 @@
# Default feature-vector file searched by the FAISS helpers
# (presumably the const module star-imported by custom_apps/faiss/main.py — confirm).
DATASET_BIN = "./datas/eyewear_all.fvecs.bin"
# Text file with one dataset entry per line; presumably line i names the image
# whose vector is row i of DATASET_BIN — confirm against the dataset builder.
DATASET_TEXT = "./datas/eyewear_all.fnames.txt"

50
custom_apps/faiss/main.py Normal file
View File

@@ -0,0 +1,50 @@
import os
import shutil
from custom_apps.faiss.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
from custom_apps.faiss.const import *
# from rest.app.utils.date_utils import D
# from const import OUTPUT_FOLDER
# Loaded once at import time; index_image_save() looks up dataset_list[index]
# to find the file to copy for each search hit.
dataset_list = get_dataset_list(DATASET_TEXT)
def search_idxs(image_path, dataset_bin=DATASET_BIN, index_type="hnsw", search_num=4):
    """Search the dataset for images similar to a query image and copy the hits.

    The hits are copied into the folder that contains ``image_path`` by
    ``index_image_save``; nothing is returned.

    Args:
        image_path: Path of the query image.
        dataset_bin: Path of the feature-vector file to search in.
        index_type: Either ``"hnsw"`` or ``"l2"``.
        search_num: Number of nearest neighbours to request from the index.

    Raises:
        ValueError: If ``index_type`` is not supported, or (from
            ``index_image_save``) if the query image's folder does not exist.
    """
    if index_type not in ("hnsw", "l2"):
        raise ValueError("index_type must be either 'hnsw' or 'l2'")

    # Presumably the length of one feature vector in dataset_bin and of the
    # query embedding — TODO confirm against the embedding model.
    dim = 1280

    # Only the dataset index and the query vectors are needed here.
    # NOTE(review): preprocessing_quary also builds an index from the query
    # vector, which this function never uses.
    _, dataset_index = preprocessing(dim, dataset_bin, index_type)
    query_fvecs, _ = preprocessing_quary(dim, image_path, index_type)

    dists, idxs = dataset_index.search(normalize(query_fvecs), search_num)
    # A single query is searched, so only the first result row is relevant.
    index_image_save(image_path, dists[0], idxs[0])
def index_image_save(query_image_path, dists, idxs):
    """Copy every matched dataset image into the query image's folder.

    Each copy is named ``search_<score><ext>``, where ``score`` is
    ``1 - dist`` (or ``0`` when ``dist`` is greater than 1) rounded to four
    decimals and ``ext`` is the extension of the query image.

    Args:
        query_image_path: Path of the query image; only its folder and
            extension are used.
        dists: Distances of the hits, aligned with ``idxs``.
        idxs: Positions of the hits in the module-level ``dataset_list``.

    Raises:
        ValueError: If the folder of ``query_image_path`` does not exist.
    """
    directory_path, file_name = os.path.split(query_image_path)
    extension = os.path.splitext(file_name)[1]

    # A bare file name has an empty directory part, which means the current
    # working directory; os.path.exists("") would wrongly report it missing.
    directory_path = directory_path or os.curdir
    if not os.path.exists(directory_path):
        raise ValueError(f"Folder {directory_path} does not exist.")

    for dist, index in zip(dists, idxs):
        if index < 0:
            # A negative label marks "no neighbour found" (FAISS pads short
            # result lists with -1). Indexing dataset_list with it would wrap
            # around and copy an unrelated file.
            continue
        score = 0 if dist > 1 else 1 - dist
        origin_file_path = dataset_list[index]
        dest_file_path = os.path.join(
            directory_path, f"search_{round(float(score), 4)}{extension}"
        )
        shutil.copy(origin_file_path, dest_file_path)

116
custom_apps/faiss/utils.py Normal file
View File

@@ -0,0 +1,116 @@
"""
Takes a query image (given as an image path) and processes it.
"""
import os
import time
import math
import numpy as np
from sklearn.preprocessing import normalize
import faiss
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.models import Model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from PIL import Image
def populate(index, fvecs, batch_size=1000):
    """Add the rows of ``fvecs`` to ``index`` in normalized batches.

    Args:
        index: Object exposing ``add`` (a FAISS index in this module).
        fvecs: Array with a ``shape`` attribute; it is sliced along its
            first axis.
        batch_size: Maximum number of rows passed to ``index.add`` at once.

    Returns:
        The same ``index`` object that was passed in.
    """
    total = fvecs.shape[0]
    # Slicing clamps at the end of the array, so the last batch may be short.
    for start in range(0, total, batch_size):
        index.add(normalize(fvecs[start:start + batch_size]))
    return index
def get_index(index_type, dim):
    """Create an empty FAISS index of the requested kind.

    Args:
        index_type: ``'hnsw'`` for ``faiss.IndexHNSWFlat`` or ``'l2'`` for
            ``faiss.IndexFlatL2``.
        dim: Dimensionality of the vectors the index will hold.

    Returns:
        The newly created, empty index.

    Raises:
        ValueError: If ``index_type`` is neither ``'hnsw'`` nor ``'l2'``.
    """
    if index_type == 'hnsw':
        # Second constructor argument of IndexHNSWFlat (graph connectivity M).
        m = 48
        index = faiss.IndexHNSWFlat(dim, m)
        index.hnsw.efConstruction = 128
        return index
    if index_type == 'l2':
        return faiss.IndexFlatL2(dim)
    # A bare `raise` here has no active exception and would surface as an
    # unrelated RuntimeError; report the real problem instead.
    raise ValueError(
        f"index_type must be either 'hnsw' or 'l2', got {index_type!r}"
    )
def preprocessing(dim, fvec_file, index_type):
    """Load the dataset vectors and return them with a ready-to-search index.

    The index is cached on disk as ``<fvec_file>.<index_type>.index``: it is
    read when that file exists, otherwise built from the vectors and written.

    Args:
        dim: Number of float32 values per vector in ``fvec_file``.
        fvec_file: Path of the raw float32 feature-vector file.
        index_type: ``'hnsw'`` or ``'l2'`` (see ``get_index``).

    Returns:
        A ``(fvecs, index)`` tuple: the memory-mapped vectors reshaped to
        ``(-1, dim)`` and the FAISS index built from them.
    """
    index_file = f'{fvec_file}.{index_type}.index'
    # Memory-map the file read-only instead of loading it eagerly.
    fvecs = np.memmap(fvec_file, dtype='float32', mode='r').reshape(-1, dim)

    if os.path.exists(index_file):
        index = faiss.read_index(index_file)
    else:
        index = populate(get_index(index_type, dim), fvecs)
        faiss.write_index(index, index_file)

    if index_type == 'hnsw':
        # Applied on both paths so the first run (fresh build) searches with
        # the same setting as later runs that load the cached index.
        index.hnsw.efSearch = 256
    return fvecs, index
def preprocessing_quary(dim, image_path, index_type):
    """Embed the query image and wrap its vector in a fresh index.

    Args:
        dim: Dimensionality passed to ``get_index``.
        image_path: Path of the query image handed to ``fvces_quary``.
        index_type: Index kind passed to ``get_index``.

    Returns:
        A ``(fvecs, index)`` tuple: the query feature vectors and a new index
        populated with them.
    """
    query_vectors = fvces_quary(image_path)
    query_index = populate(get_index(index_type, dim), query_vectors)
    return query_vectors, query_index
def preprocess(img_path, input_shape):
    """Read an image file and convert it into a model-ready tensor.

    Args:
        img_path: Path of the image file to read.
        input_shape: ``(height, width, channels)``; the first two entries are
            the resize target and the third is the decoded channel count.

    Returns:
        The decoded, resized image after MobileNetV2 ``preprocess_input``.
    """
    raw_bytes = tf.io.read_file(img_path)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=input_shape[2])
    height_width = input_shape[:2]
    return preprocess_input(tf.image.resize(pixels, height_width))
def preprocess_pil(pil_img_data, input_shape):
    """Open an image with PIL and convert it into a model-ready tensor.

    Args:
        pil_img_data: Anything ``PIL.Image.open`` accepts.
        input_shape: Only the first two entries (the resize target) are used.

    Returns:
        The resized image after MobileNetV2 ``preprocess_input``.
    """
    # NOTE(review): unlike `preprocess`, the channel count from input_shape is
    # not enforced here, so the result keeps whatever channels the file has —
    # confirm that callers only pass images with the expected channels.
    opened = Image.open(pil_img_data)
    pixel_array = np.asarray(opened)
    height_width = input_shape[:2]
    resized = tf.image.resize(pixel_array, height_width)
    return preprocess_input(resized)
# Feature-extraction models keyed by input shape, so the network is built
# only once per process instead of on every query.
_FEATURE_MODEL_CACHE = {}


def _get_feature_model(input_shape):
    """Return the frozen MobileNetV2 + global-average-pooling model, building it on first use."""
    model = _FEATURE_MODEL_CACHE.get(input_shape)
    if model is None:
        base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
                                                 include_top=False,
                                                 weights='imagenet')
        base.trainable = False
        model = Model(inputs=base.input,
                      outputs=layers.GlobalAveragePooling2D()(base.output))
        _FEATURE_MODEL_CACHE[input_shape] = model
    return model


def fvces_quary(image_path):
    """Compute the feature vectors of a single query image.

    Args:
        image_path: Path of the query image file.

    Returns:
        The output of ``model.predict`` for the batch containing the image.
    """
    batch_size = 100
    input_shape = (224, 224, 3)
    model = _get_feature_model(input_shape)

    list_ds = tf.data.Dataset.from_tensor_slices([image_path])
    ds = list_ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
    dataset = ds.batch(batch_size).prefetch(-1)

    # The dataset is built from a one-element list, so it yields one batch.
    for batch in dataset:
        fvecs = model.predict(batch)
    return fvecs
# index_type = 'hnsw'
# index_type = 'l2'
def get_dataset_list(filepath):
    """Read a ``.txt`` file into a list of whitespace-stripped lines.

    Args:
        filepath: Path of the text file to read (UTF-8).

    Returns:
        One entry per line of the file, in order. An empty list is returned
        when ``filepath`` is not an existing file or does not end in ``.txt``.
    """
    import os

    is_text_file = os.path.isfile(filepath) and filepath.endswith(".txt")
    if not is_text_file:
        return []
    with open(filepath, 'r', encoding='utf-8') as handle:
        return [raw_line.strip() for raw_line in handle]

View File

@@ -5,8 +5,8 @@ import os
from vertexai.preview.vision_models import ImageGenerationModel
from const import OUTPUT_FOLDER
from rest.app.utils.parsing_utils import prompt_to_filenames
from rest.app.utils.date_utils import D
from main_rest.app.utils.parsing_utils import prompt_to_filenames
from main_rest.app.utils.date_utils import D
class ImagenConst:
@@ -48,4 +48,44 @@ def imagen_generate_image(prompt,download_count=1):
for i in range(len(images.images)):
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False)
return len(images.images)
return len(images.images)
def imagen_generate_image_data(prompt,download_count=1):
    """Generate images for ``prompt`` and return the first one's image data.

    Args:
        prompt: Text prompt sent to the image generation model.
        download_count: Number of images requested from the model.

    Returns:
        The ``_pil_image`` attribute of the first generated image.
    """
    vertexai.init(project=ImagenConst.project_id, location=ImagenConst.location)
    generator = ImageGenerationModel.from_pretrained(ImagenConst.model)

    generation_options = {
        "prompt": prompt,
        # Optional parameters
        "number_of_images": download_count,
        "language": "ko",
        # A seed value and a watermark cannot be combined, so neither
        # add_watermark nor seed is passed here.
        "aspect_ratio": "1:1",
        "safety_filter_level": "block_some",
        "person_generation": "dont_allow",
    }
    response = generator.generate_images(**generation_options)

    first_image = response.images[0]
    return first_image._pil_image
def imagen_generate_image_path(image_prompt):
    """Generate one image for ``image_prompt`` and save it as ``query.png``.

    The image is written into a new folder under ``OUTPUT_FOLDER`` named
    ``imagen_query_<image_prompt>_<timestamp>``.

    Args:
        image_prompt: Text prompt passed to ``imagen_generate_image_data``.

    Returns:
        The path of the saved ``query.png`` file.
    """
    model_name = "imagen"
    query_tag = "query"
    create_time = D.date_file_name()
    # NOTE(review): image_prompt is embedded verbatim in the folder name;
    # prompts with path separators or characters the filesystem rejects are
    # not sanitized here — confirm whether callers pre-clean the prompt.
    folder_name = os.path.join(
        OUTPUT_FOLDER, f"{model_name}_{query_tag}_{image_prompt}_{create_time}"
    )
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(folder_name, exist_ok=True)

    query_path = os.path.join(folder_name, "query.png")
    imagen_generate_image_data(image_prompt).save(query_path)
    return query_path
if __name__ == '__main__':
    # Running this module as a script does nothing; the commented-out call
    # below is kept as a manual smoke test.
    pass
    # imagen_generate_image_data("cat")