edit: faiss 전용 rest 서버 추가

This commit is contained in:
2025-04-28 11:26:19 +09:00
parent 19e62f5724
commit 6b212125a4
76 changed files with 6014 additions and 92 deletions

View File

@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
"""
@File: date_utils.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: date-utility functions
"""
from datetime import datetime, date, timedelta
_TIMEDELTA = 9
class D:
def __init__(self, *args):
self.utc_now = datetime.utcnow()
# NOTE(hsj100): utc->kst
self.timedelta = _TIMEDELTA
@classmethod
def datetime(cls, diff: int=_TIMEDELTA) -> datetime:
return datetime.utcnow() + timedelta(hours=diff) if diff > 0 else datetime.now()
@classmethod
def date(cls, diff: int=_TIMEDELTA) -> date:
return cls.datetime(diff=diff).date()
@classmethod
def date_num(cls, diff: int=_TIMEDELTA) -> int:
return int(cls.date(diff=diff).strftime('%Y%m%d'))
@classmethod
def validate(cls, date_text):
try:
datetime.strptime(date_text, '%Y-%m-%d')
except ValueError:
raise ValueError('Incorrect data format, should be YYYY-MM-DD')
@classmethod
def date_str(cls, diff: int = _TIMEDELTA) -> str:
return cls.datetime(diff=diff).strftime('%Y-%m-%dT%H:%M:%S')
@classmethod
def check_expire_date(cls, expire_date: datetime):
td = expire_date - datetime.now()
timestamp = td.total_seconds()
return timestamp
@classmethod
def date_file_name(cls):
date = datetime.now()
return date.strftime('%y%m%d_%H%M%S.%s')
if __name__ == "__main__":
_date = D()

View File

@@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
"""
@File: extra.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: extra functions
"""
from hashlib import md5
from base64 import b64decode
from base64 import b64encode
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from cryptography.fernet import Fernet # symmetric encryption
# mail test
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from itertools import groupby
from operator import attrgetter
import uuid
from main_rest.app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT
from main_rest.app.utils.date_utils import D
from main_rest.app import models as M
from main_rest.app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY
async def send_mail(sender, sender_pw, title, recipient, contents_plain, contents_html, cc_list, smtp_host=SMTP_HOST, smtp_port=SMTP_PORT):
"""
구글 계정사용시 : 보안 수준이 낮은 앱에서의 접근 활성화
:return:
None: success
Str. Message: error
"""
try:
# check parameters
if not sender:
raise Exception('invalid sender')
if not title:
raise Exception('invalid title')
if not recipient:
raise Exception('invalid recipient')
# sender info.
# sender = consts.ADMIN_INIT_ACCOUNT_INFO.email
# sender_pw = consts.ADMIN_INIT_ACCOUNT_INFO.email_pw
# message
msg = MIMEMultipart()
msg['From'] = sender
msg['To'] = recipient
if cc_list:
list_cc = cc_list
str_cc = ','.join(list_cc)
msg['Cc'] = str_cc
msg['Subject'] = title
if contents_plain:
msg.attach(MIMEText(contents_plain, 'plain'))
if contents_html:
msg.attach(MIMEText(contents_html, 'html'))
# smtp server
smtp_server = smtplib.SMTP(host=smtp_host, port=smtp_port)
smtp_server.ehlo()
smtp_server.starttls()
smtp_server.ehlo()
smtp_server.login(sender, sender_pw)
smtp_server.send_message(msg)
smtp_server.quit()
return None
except Exception as e:
return str(e)
def query_to_groupby(query_result, key, first=False):
"""
쿼리 결과물(list)을 항목값(key)으로 그룹화한다.
:param query_result: 쿼리 결과 리스트
:param key: 그룹 항목값
:return: dict
"""
group_info = dict()
for k, g in groupby(query_result, attrgetter(key)):
if k not in group_info:
if not first:
group_info[k] = list(g)
else:
group_info[k] = list(g)[0]
else:
if not first:
group_info[k].extend(list(g))
return group_info
def query_to_groupby_date(query_result, key):
"""
쿼리 결과물(list)을 항목값(key)으로 그룹화한다.
:param query_result: 쿼리 결과 리스트
:param key: 그룹 항목값
:return: dict
"""
group_info = dict()
for k, g in groupby(query_result, attrgetter(key)):
day_str = k.strftime("%Y-%m-%d")
if day_str not in group_info:
group_info[day_str] = list(g)
else:
group_info[day_str].extend(list(g))
return group_info
class FernetCrypto:
def __init__(self, key=FERNET_SECRET_KEY):
self.key = key
self.f = Fernet(self.key)
def encrypt(self, data, is_out_string=True):
if isinstance(data, bytes):
ou = self.f.encrypt(data) # 바이트형태이면 바로 암호화
else:
ou = self.f.encrypt(data.encode('utf-8')) # 인코딩 후 암호화
if is_out_string is True:
return ou.decode('utf-8') # 출력이 문자열이면 디코딩 후 반환
else:
return ou
def decrypt(self, data, is_out_string=True):
if isinstance(data, bytes):
ou = self.f.decrypt(data) # 바이트형태이면 바로 복호화
else:
ou = self.f.decrypt(data.encode('utf-8')) # 인코딩 후 복호화
if is_out_string is True:
return ou.decode('utf-8') # 출력이 문자열이면 디코딩 후 반환
else:
return ou
class AESCryptoCBC:
def __init__(self, key=AES_CBC_PUBLIC_KEY, iv=AES_CBC_IV):
# Initial vector를 0으로 초기화하여 16바이트 할당함
# iv = chr(0) * 16 #pycrypto 기준
# iv = bytes([0x00] * 16) #pycryptodomex 기준
# aes cbc 생성
self.key = key
self.iv = iv
self.crypto = AES.new(self.key, AES.MODE_CBC, self.iv)
def encrypt(self, data):
# 암호화 message는 16의 배수여야 한다.
# enc = self.crypto.encrypt(data)
# return enc
enc = self.crypto.encrypt(pad(data, AES.block_size))
return b64encode(enc)
def decrypt(self, enc):
# 복호화 enc는 16의 배수여야 한다.
# dec = self.crypto.decrypt(enc)
# return dec
enc = b64decode(enc)
dec = self.crypto.decrypt(enc)
return unpad(dec, AES.block_size)
class AESCipher:
def __init__(self, key):
# self.key = md5(key.encode('utf8')).digest()
self.key = bytes(key.encode('utf-8'))
def encrypt(self, data):
# iv = get_random_bytes(AES.block_size)
iv = bytes('daooldns12345678'.encode('utf-8'))
self.cipher = AES.new(self.key, AES.MODE_CBC, iv)
t = b64encode(self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size)))
return b64encode(iv + self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size)))
def decrypt(self, data):
raw = b64decode(data)
self.cipher = AES.new(self.key, AES.MODE_CBC, raw[:AES.block_size])
return unpad(self.cipher.decrypt(raw[AES.block_size:]), AES.block_size)
def cls_list_to_dict_list(list):
"""
list 내부 element가 dict로 변환 가능한 class일경우
내부 element를 dict 로 변경
"""
_result = []
for i in list:
if isinstance(i, dict):
_result = list
break
_result.append(i.dict())
return _result

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""
@File: logger.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: logger
"""
import json
import logging
from datetime import timedelta, datetime
from time import time
from fastapi.requests import Request
from fastapi import Body
from fastapi.logger import logger
logger.setLevel(logging.INFO)
async def api_logger(request: Request, response=None, error=None):
time_format = '%Y/%m/%d %H:%M:%S'
t = time() - request.state.start
status_code = error.status_code if error else response.status_code
error_log = None
user = request.state.user
if error:
if request.state.inspect:
frame = request.state.inspect
error_file = frame.f_code.co_filename
error_func = frame.f_code.co_name
error_line = frame.f_lineno
else:
error_func = error_file = error_line = 'UNKNOWN'
error_log = dict(
errorFunc=error_func,
location='{} line in {}'.format(str(error_line), error_file),
raised=str(error.__class__.__name__),
msg=str(error.ex),
)
account = user.account.split('@') if user and user.account else None
user_log = dict(
client=request.state.ip,
user=user.id if user and user.id else None,
account='**' + account[0][2:-1] + '*@' + account[1] if user and user.account else None,
)
log_dict = dict(
url=request.url.hostname + request.url.path,
method=str(request.method),
statusCode=status_code,
errorDetail=error_log,
client=user_log,
processedTime=str(round(t * 1000, 5)) + 'ms',
datetimeUTC=datetime.utcnow().strftime(time_format),
datetimeKST=(datetime.utcnow() + timedelta(hours=9)).strftime(time_format),
)
if error and error.status_code >= 500:
logger.error(json.dumps(log_dict))
else:
logger.info(json.dumps(log_dict))

View File

@@ -0,0 +1,25 @@
from const import ILLEGAL_FILE_NAME
def prompt_to_filenames(prompt):
"""
prompt 에 사용할 수 없는 문자가 있으면 '_' 로 치환
"""
filename = ''
for i in prompt:
if i in ILLEGAL_FILE_NAME:
filename += '_'
else:
filename += i
return filename
def download_range(download_count:int,max=4):
_min = 1
_max = max
if _min <= download_count and download_count <= _max:
return True
return False

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""
@File: query_utils.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: query-utility functions
"""
from typing import List
def to_dict(model, *args, exclude: List = None):
q_dict = {}
for c in model.__table__.columns:
if not args or c.name in args:
if not exclude or c.name not in exclude:
q_dict[c.name] = getattr(model, c.name)
return q_dict