edit: faiss 전용 rest 서버 추가
This commit is contained in:
@@ -4,8 +4,8 @@ from pathlib import Path
|
||||
from bingart import BingArt
|
||||
|
||||
from custom_apps.utils import cookie_manager
|
||||
from rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from rest.app.utils.date_utils import D
|
||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from const import OUTPUT_FOLDER
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ import pkg_resources
|
||||
import regex
|
||||
import requests
|
||||
|
||||
from rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from rest.app.utils.date_utils import D
|
||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from main_rest.app.utils.date_utils import D
|
||||
|
||||
BING_URL = os.getenv("BING_URL", "https://www.bing.com")
|
||||
# Generate random IP between range 13.104.0.0/14
|
||||
|
||||
2
custom_apps/faiss/const.py
Normal file
2
custom_apps/faiss/const.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Prebuilt eyewear dataset artifacts consumed by the FAISS search.
DATASET_BIN = "./datas/eyewear_all.fvecs.bin"  # packed float32 feature vectors
DATASET_TEXT = "./datas/eyewear_all.fnames.txt"  # one source-image path per line
50
custom_apps/faiss/main.py
Normal file
50
custom_apps/faiss/main.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from custom_apps.faiss.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
|
||||
from custom_apps.faiss.const import *
|
||||
# from rest.app.utils.date_utils import D
|
||||
# from const import OUTPUT_FOLDER
|
||||
|
||||
# Resolve the dataset's image paths once at import time so search results
# can be mapped from FAISS row indices back to source files.
dataset_list = get_dataset_list(DATASET_TEXT)
||||
def search_idxs(image_path, dataset_bin=DATASET_BIN, index_type="hnsw", search_num=4, dim=1280):
    """Find the dataset images most similar to *image_path* and save them.

    Builds (or loads a cached) FAISS index over the dataset vectors, embeds
    the query image the same way, runs a k-NN search, and copies the matches
    next to the query image via ``index_image_save``.

    Args:
        image_path: Path to the query image file.
        dataset_bin: Path to the packed float32 dataset vector file.
        index_type: FAISS index flavor, either "hnsw" or "l2".
        search_num: Number of nearest neighbours to retrieve.
        dim: Embedding dimensionality. Defaults to 1280 (the original
            hard-coded ``DIM``; MobileNetV2 pooled features are 1280-d).

    Raises:
        ValueError: If *index_type* is not "hnsw" or "l2".
    """
    if index_type not in ["hnsw", "l2"]:
        raise ValueError("index_type must be either 'hnsw' or 'l2'")

    # Dataset vectors plus a searchable index over them (cached on disk).
    dataset_fvces, dataset_index = preprocessing(dim, dataset_bin, index_type)

    # Embed the query image; the small query-side index is unused here.
    org_fvces, org_index = preprocessing_quary(dim, image_path, index_type)

    # Vectors are normalized before search, matching how the dataset index
    # was populated (see populate()).
    dists, idxs = dataset_index.search(normalize(org_fvces), search_num)

    # Copy the matched dataset images next to the query image.
    index_image_save(image_path, dists[0], idxs[0])
||||
def index_image_save(query_image_path, dists, idxs):
    """Copy matched dataset images into the query image's folder.

    Each match is written as ``search_<rank>_<similarity><ext>``, where the
    similarity is ``1 - dist`` (clamped to 0 when ``dist > 1``) rounded to
    4 decimals, and ``<ext>`` is the query image's extension.

    Args:
        query_image_path: Path of the query image; results are written to
            its parent directory.
        dists: Distances for each match (parallel to *idxs*).
        idxs: Row indices into the module-level ``dataset_list``.

    Raises:
        ValueError: If the query image's directory does not exist.
    """
    directory_path, file = os.path.split(query_image_path)
    _name, extension = os.path.splitext(file)

    if not os.path.exists(directory_path):
        raise ValueError(f"Folder {directory_path} does not exist.")

    for rank, (dist, index) in enumerate(zip(dists, idxs), start=1):
        # Map distance to a [0, 1] similarity score; distances above 1
        # collapse to similarity 0, as in the original.
        if dist > 1:
            dist = 0
        else:
            dist = 1 - dist

        origin_file_path = dataset_list[index]
        # Include the rank in the filename: the original used only the
        # rounded similarity, so equal-similarity matches overwrote each
        # other and fewer than search_num results survived.
        dest_file_path = os.path.join(
            directory_path, f"search_{rank}_{round(float(dist), 4)}{extension}"
        )
        shutil.copy(origin_file_path, dest_file_path)
||||
|
||||
|
||||
|
||||
|
||||
|
||||
116
custom_apps/faiss/utils.py
Normal file
116
custom_apps/faiss/utils.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
quary 이미지(이미지 경로) 입력받아 처리
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import math
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import normalize
|
||||
import faiss
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras.layers as layers
|
||||
from tensorflow.keras.models import Model
|
||||
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def populate(index, fvecs, batch_size=1000):
    """Add *fvecs* to *index* in L2-normalized batches.

    Args:
        index: A FAISS index exposing ``add``.
        fvecs: 2-D float32 array of feature vectors.
        batch_size: Number of rows added per ``index.add`` call.

    Returns:
        The same *index*, for call chaining.
    """
    nloop = math.ceil(fvecs.shape[0] / batch_size)
    for n in range(nloop):
        start = n * batch_size
        end = min((n + 1) * batch_size, fvecs.shape[0])
        # Normalize each batch so L2 distance behaves like cosine distance.
        # (The original also captured time.time() per batch for a
        # commented-out print; that dead local is removed.)
        index.add(normalize(fvecs[start:end]))

    return index
||||
|
||||
def get_index(index_type, dim):
    """Create an empty FAISS index of the requested flavor.

    Args:
        index_type: "hnsw" for an approximate HNSW graph index, or "l2"
            for an exact flat L2 index.
        dim: Vector dimensionality.

    Returns:
        A new, empty FAISS index.

    Raises:
        ValueError: If *index_type* is unrecognized. (The original ended
        with a bare ``raise`` outside any except block, which raises
        ``RuntimeError: No active exception to re-raise`` instead of a
        meaningful error.)
    """
    if index_type == 'hnsw':
        m = 48  # graph connectivity: neighbors kept per node
        index = faiss.IndexHNSWFlat(dim, m)
        index.hnsw.efConstruction = 128  # build-time search breadth
        return index
    elif index_type == 'l2':
        return faiss.IndexFlatL2(dim)
    raise ValueError(f"unknown index_type: {index_type!r}")
|
||||
|
||||
def preprocessing(dim, fvec_file, index_type):
    """Load dataset vectors and a FAISS index over them, caching the index.

    The built index is persisted as ``<fvec_file>.<index_type>.index`` and
    reused on subsequent calls instead of being rebuilt.

    Args:
        dim: Vector dimensionality.
        fvec_file: Path to the packed float32 vector file.
        index_type: "hnsw" or "l2" (see ``get_index``).

    Returns:
        Tuple of (memory-mapped vectors, populated FAISS index).
    """
    index_file = f'{fvec_file}.{index_type}.index'

    # Memory-map so the whole dataset never has to fit in RAM. The
    # original's extra .view('float32') was redundant with dtype='float32'
    # and is dropped.
    fvecs = np.memmap(fvec_file, dtype='float32', mode='r').reshape(-1, dim)

    if os.path.exists(index_file):
        index = faiss.read_index(index_file)
        if index_type == 'hnsw':
            index.hnsw.efSearch = 256  # query-time search breadth
    else:
        index = get_index(index_type, dim)
        index = populate(index, fvecs)
        faiss.write_index(index, index_file)

    return fvecs, index
|
||||
|
||||
def preprocessing_quary(dim, image_path, index_type):
    """Embed the query image and wrap its vector in a small FAISS index.

    Args:
        dim: Vector dimensionality.
        image_path: Path to the query image.
        index_type: "hnsw" or "l2" (see ``get_index``).

    Returns:
        Tuple of (query feature vectors, FAISS index containing them).
    """
    query_vectors = fvces_quary(image_path)
    query_index = populate(get_index(index_type, dim), query_vectors)

    return query_vectors, query_index
|
||||
|
||||
def preprocess(img_path, input_shape):
    """Read an image file and prepare it for MobileNetV2.

    Args:
        img_path: Path (or TF string tensor) to a JPEG image.
        input_shape: (height, width, channels) expected by the model.

    Returns:
        A float tensor resized to *input_shape* and scaled with
        ``mobilenet_v2.preprocess_input``.
    """
    raw = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw, channels=input_shape[2])
    resized = tf.image.resize(decoded, input_shape[:2])
    return preprocess_input(resized)
|
||||
|
||||
def preprocess_pil(pil_img_data, input_shape):
    """Prepare an in-memory image for MobileNetV2.

    Args:
        pil_img_data: Anything ``PIL.Image.open`` accepts (path or
            file-like object).
        input_shape: (height, width, channels) expected by the model.

    Returns:
        A float tensor resized to *input_shape* and scaled with
        ``mobilenet_v2.preprocess_input``.
    """
    array = np.asarray(Image.open(pil_img_data))
    resized = tf.image.resize(array, input_shape[:2])
    return preprocess_input(resized)
|
||||
|
||||
def fvces_quary(image_path):
    """Compute a pooled MobileNetV2 feature vector for one image.

    Args:
        image_path: Path to the image file.

    Returns:
        A (1, 1280) float array of globally average-pooled MobileNetV2
        features, or None if the input pipeline yields no batch.
    """
    input_shape = (224, 224, 3)

    # Build the feature extractor once and memoize it on the function
    # object: the original rebuilt the model (reloading imagenet weights)
    # on every call. No new imports are required for this cache.
    if not hasattr(fvces_quary, "_model"):
        base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
                                                 include_top=False,
                                                 weights='imagenet')
        base.trainable = False
        fvces_quary._model = Model(
            inputs=base.input,
            outputs=layers.GlobalAveragePooling2D()(base.output),
        )

    batch_size = 100
    list_ds = tf.data.Dataset.from_tensor_slices([image_path])
    ds = list_ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
    dataset = ds.batch(batch_size).prefetch(-1)

    # A single image produces a single batch; keep the original
    # return-on-first-batch shape.
    for batch in dataset:
        fvecs = fvces_quary._model.predict(batch)
        return fvecs
|
||||
|
||||
|
||||
# index_type = 'hnsw'
|
||||
# index_type = 'l2'
|
||||
|
||||
def get_dataset_list(filepath):
    """Read a ``.txt`` file into a list of stripped lines.

    Blank lines are kept (as empty strings) so list positions stay aligned
    with the row indices of the companion vector file.

    Args:
        filepath: Path to a text file with one dataset entry per line.

    Returns:
        List of stripped lines; an empty list if *filepath* does not exist
        or does not end in ``.txt``.
    """
    # The original re-imported os here; it is already imported at module
    # level, so the nested import is removed.
    if not (os.path.isfile(filepath) and filepath.endswith(".txt")):
        return []
    with open(filepath, 'r', encoding='utf-8') as file:
        return [line.strip() for line in file]
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ import os
|
||||
from vertexai.preview.vision_models import ImageGenerationModel
|
||||
|
||||
from const import OUTPUT_FOLDER
|
||||
from rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from rest.app.utils.date_utils import D
|
||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from main_rest.app.utils.date_utils import D
|
||||
|
||||
|
||||
class ImagenConst:
|
||||
@@ -48,4 +48,44 @@ def imagen_generate_image(prompt,download_count=1):
|
||||
for i in range(len(images.images)):
|
||||
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False)
|
||||
|
||||
return len(images.images)
|
||||
return len(images.images)
|
||||
|
||||
def imagen_generate_image_data(prompt, download_count=1):
    """Generate an image with Vertex AI Imagen and return it as a PIL image.

    Args:
        prompt: Text prompt (Korean supported via ``language="ko"``).
        download_count: Number of images to request; only the first
            generated image is returned.

    Returns:
        The first generated image as a ``PIL.Image``.
    """
    vertexai.init(project=ImagenConst.project_id, location=ImagenConst.location)

    generation_model = ImageGenerationModel.from_pretrained(ImagenConst.model)

    result = generation_model.generate_images(
        prompt=prompt,
        # Optional parameters
        number_of_images=download_count,
        language="ko",
        # You can't use a seed value and watermark at the same time.
        # add_watermark=False,
        # seed=100,
        aspect_ratio="1:1",
        safety_filter_level="block_some",
        person_generation="dont_allow",
    )

    return result.images[0]._pil_image
|
||||
|
||||
def imagen_generate_image_path(image_prompt):
    """Generate an image for *image_prompt* and return the saved file path.

    The image is written as ``query.png`` into a new timestamped folder
    under ``OUTPUT_FOLDER``.

    Args:
        image_prompt: Text prompt passed to Imagen.

    Returns:
        Filesystem path of the saved ``query.png``.
    """
    MODEL = "imagen"
    QUERY = "query"
    create_time = D.date_file_name()

    # NOTE(review): the raw prompt is embedded in the folder name; a
    # prompt containing path separators would escape OUTPUT_FOLDER —
    # confirm callers sanitize it.
    folder_name = os.path.join(OUTPUT_FOLDER, f"{MODEL}_{QUERY}_{image_prompt}_{create_time}")
    # exist_ok=True avoids the check-then-create race of the original
    # os.path.exists() + os.makedirs() pair.
    os.makedirs(folder_name, exist_ok=True)

    generate_img = imagen_generate_image_data(image_prompt)

    # Build the destination path once (the original re-joined it three
    # times, via f-strings with no placeholders).
    output_path = os.path.join(folder_name, "query.png")
    generate_img.save(output_path)

    return output_path
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test (disabled):
    # imagen_generate_image_data("cat")
    pass
|
||||
Reference in New Issue
Block a user