edit: faiss 전용 rest 서버 추가

This commit is contained in:
2025-04-28 11:26:19 +09:00
parent 19e62f5724
commit 6b212125a4
76 changed files with 6014 additions and 92 deletions

View File

@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
"""
@File: conn.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: DB Connections
"""
from fastapi import FastAPI
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import logging
def _database_exist(engine, schema_name):
query = f'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = "{schema_name}"'
with engine.connect() as conn:
result_proxy = conn.execute(query)
result = result_proxy.scalar()
return bool(result)
def _drop_database(engine, schema_name):
with engine.connect() as conn:
conn.execute(f'DROP DATABASE {schema_name};')
def _create_database(engine, schema_name):
with engine.connect() as conn:
conn.execute(f'CREATE DATABASE {schema_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_bin;')
class SQLAlchemy:
def __init__(self, app: FastAPI = None, **kwargs):
self._engine = None
self._session = None
if app is not None:
self.init_app(app=app, **kwargs)
def init_app(self, app: FastAPI, **kwargs):
"""
DB 초기화 함수
:param app: FastAPI 인스턴스
:param kwargs:
:return:
"""
database_url = kwargs.get('DB_URL')
pool_recycle = kwargs.setdefault('DB_POOL_RECYCLE', 900)
is_testing = kwargs.setdefault('TEST_MODE', False)
echo = kwargs.setdefault('DB_ECHO', True)
self._engine = create_engine(
database_url,
echo=echo,
pool_recycle=pool_recycle,
pool_pre_ping=True,
)
if is_testing: # create schema
db_url = self._engine.url
if db_url.host != 'localhost':
raise Exception('db host must be \'localhost\' in test environment')
except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
schema_name = db_url.database
temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
if _database_exist(temp_engine, schema_name):
_drop_database(temp_engine, schema_name)
_create_database(temp_engine, schema_name)
temp_engine.dispose()
else:
db_url = self._engine.url
except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
schema_name = db_url.database
temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
if not _database_exist(temp_engine, schema_name):
_create_database(temp_engine, schema_name)
Base.metadata.create_all(db.engine)
temp_engine.dispose()
self._session = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
# NOTE(hsj100): ADMINISTRATOR
create_admin(self._session)
@app.on_event('startup')
def startup():
self._engine.connect()
logging.info('DB connected.')
@app.on_event('shutdown')
def shutdown():
self._session.close_all()
self._engine.dispose()
logging.info('DB disconnected')
def get_db(self):
"""
요청마다 DB 세션 유지 함수
:return:
"""
if self._session is None:
raise Exception('must be called \'init_app\'')
db_session = None
try:
db_session = self._session()
yield db_session
finally:
db_session.close()
@property
def session(self):
return self.get_db
@property
def engine(self):
return self._engine
db = SQLAlchemy()
Base = declarative_base()
# NOTE(hsj100): ADMINISTRATOR
def create_admin(db_session):
import bcrypt
from main_rest.app.database.schema import Users
from main_rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO
session = db_session()
if not session:
raise Exception('cat`t create account of admin')
if Users.get(account=ADMIN_INIT_ACCOUNT_INFO.account):
return
admin = {**ADMIN_INIT_ACCOUNT_INFO.get_dict()}
Users.create(session=session, auto_commit=True, **admin)

View File

@@ -0,0 +1,300 @@
# -*- coding: utf-8 -*-
"""
@File: crud.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: CRUD
"""
import math
from datetime import datetime
from dateutil.relativedelta import relativedelta
from sqlalchemy import func, desc
from fastapi import APIRouter, Depends, Body
from sqlalchemy.orm import Session
from main_rest.app import models as M
from main_rest.app.database.conn import Base, db
from main_rest.app.database.schema import Users, UserLog
from main_rest.app.utils.extra import query_to_groupby, query_to_groupby_date
def get_month_info_list(start: datetime, end: datetime):
delta = relativedelta(end, start)
month_delta = 12 * delta.years + delta.months + (1 if delta.days > 0 else 0)
month_list = list()
for i in range(month_delta):
count = start + relativedelta(months=i)
month_info = dict()
month_info['period_title'] = count.strftime('%Y-%m')
month_info['year'] = count.year
month_info['month'] = count.month
month_info['start'] = (start + relativedelta(months=i)).replace(day=1)
month_info['end'] = (start + relativedelta(months=i + 1)).replace(day=1)
month_list.append(month_info)
return month_list
def request_parser(request_data: dict = None) -> dict:
"""
request information -> dict
:param request_data:
:return:
"""
result_dict = dict()
if not request_data:
return result_dict
for key, val in request_data.items():
if val is not None:
result_dict[key] = val if val != 'null' else None
return result_dict
def dict_to_filter_fmt(dict_data, get_attributes_callback=None):
"""
dict -> sqlalchemy filter (criterion: sql expression)
:param dict_data:
:param get_attributes_callback:
:return:
"""
if get_attributes_callback is None:
raise Exception('invalid get_attributes_callback')
criterion = list()
for key, val in dict_data.items():
key = key.split('__')
if len(key) > 2:
raise Exception('length of split(key) should be no more than 2.')
key_length = len(key)
col = get_attributes_callback(key[0])
if col is None:
continue
if key_length == 1:
criterion.append((col == val))
elif key_length == 2 and key[1] == 'gt':
criterion.append((col > val))
elif key_length == 2 and key[1] == 'gte':
criterion.append((col >= val))
elif key_length == 2 and key[1] == 'lt':
criterion.append((col < val))
elif key_length == 2 and key[1] == 'lte':
criterion.append((col <= val))
elif key_length == 2 and key[1] == 'in':
criterion.append((col.in_(val)))
elif key_length == 2 and key[1] == 'like':
criterion.append((col.like(val)))
return criterion
async def table_select(accessor_info, target_table, request_body_info, response_model, response_model_data):
"""
table_read
"""
try:
# parameter
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
# if not request_body_info:
# raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
if not response_model_data:
raise Exception('invalid response_model_data')
# paging - request
paging_request = None
if request_body_info:
if hasattr(request_body_info, 'paging'):
paging_request = request_body_info.paging
request_body_info = request_body_info.search
# request
request_info = request_parser(request_body_info.dict())
# search
criterion = None
if isinstance(request_body_info, M.UserLogDaySearchReq):
# UserLog search
# request
request_info = request_parser(request_body_info.dict())
# search UserLog
def get_attributes_callback(key: str):
return getattr(UserLog, key)
criterion = dict_to_filter_fmt(request_info, get_attributes_callback)
# search
session = next(db.session())
search_info = session.query(UserLog)\
.filter(UserLog.mac.isnot(None)
, UserLog.api == '/api/auth/login'
, UserLog.type == M.UserLogMessageType.info
, UserLog.message == 'ok'
, *criterion)\
.order_by(desc(UserLog.created_at), desc(UserLog.updated_at))\
.all()
if not search_info:
raise Exception('not found data')
group_by_day = query_to_groupby_date(search_info, 'updated_at')
# result
result_info = list()
for day, info_list in group_by_day.items():
info_by_mac = query_to_groupby(info_list, 'mac', first=True)
for log_info in info_by_mac.values():
result_info.append(response_model_data.from_orm(log_info))
else:
# basic search (single table)
# request
request_info = request_parser(request_body_info.dict())
# search
search_info = target_table.filter(**request_info).all()
if not search_info:
raise Exception('not found data')
# result
result_info = list()
for purchase_info in search_info:
result_info.append(response_model_data.from_orm(purchase_info))
# response - paging
paging_response = None
if paging_request:
total_contents_num = len(result_info)
total_page_num = math.ceil(total_contents_num / paging_request.page_contents_num)
start_contents_index = (paging_request.start_page - 1) * paging_request.page_contents_num
end_contents_index = start_contents_index + paging_request.page_contents_num
if end_contents_index < total_contents_num:
result_info = result_info[start_contents_index: end_contents_index]
else:
result_info = result_info[start_contents_index:]
paging_response = M.PagingRes()
paging_response.total_page_num = total_page_num
paging_response.total_contents_num = total_contents_num
paging_response.start_page = paging_request.start_page
paging_response.search_contents_num = len(result_info)
return response_model(data=result_info, paging=paging_response)
except Exception as e:
return response_model.set_error(str(e))
async def table_update(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
search_info = None
try:
# parameter
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
if not request_body_info:
raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
# if not response_model_data:
# raise Exception('invalid response_model_data')
# request
if not request_body_info.search_info:
raise Exception('invalid request_body: search_info')
request_search_info = request_parser(request_body_info.search_info.dict())
if not request_search_info:
raise Exception('invalid request_body: search_info')
if not request_body_info.update_info:
raise Exception('invalid request_body: update_info')
request_update_info = request_parser(request_body_info.update_info.dict())
if not request_update_info:
raise Exception('invalid request_body: update_info')
# search
search_info = target_table.filter(**request_search_info)
# process
search_info.update(auto_commit=True, synchronize_session=False, **request_update_info)
# result
return response_model()
except Exception as e:
if search_info:
search_info.close()
return response_model.set_error(str(e))
async def table_delete(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
search_info = None
try:
# request
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
if not request_body_info:
raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
# if not response_model_data:
# raise Exception('invalid response_model_data')
# request
request_search_info = request_parser(request_body_info.dict())
if not request_search_info:
raise Exception('invalid request_body')
# search
search_info = target_table.filter(**request_search_info)
temp_search = search_info.all()
# process
search_info.delete(auto_commit=True, synchronize_session=False)
# update license num
uuid_list = list()
for _license in temp_search:
if not hasattr(temp_search, 'uuid'):
# case: license
break
if _license.uuid not in uuid_list:
uuid_list.append(_license.uuid)
license_num = target_table.filter(uuid=_license.uuid).count()
target_table.filter(uuid=_license.uuid).update(auto_commit=True, synchronize_session=False, num=license_num)
# result
return response_model()
except Exception as e:
if search_info:
search_info.close()
return response_model.set_error(str(e))

View File

@@ -0,0 +1,304 @@
# -*- coding: utf-8 -*-
"""
@File: schema.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: database schema
"""
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
func,
Enum,
Boolean,
ForeignKey,
)
from sqlalchemy.orm import Session, relationship
from main_rest.app.database.conn import Base, db
from main_rest.app.utils.date_utils import D
from main_rest.app.models import (
SexType,
UserType,
MemberType,
AccountType,
UserStatusType,
UserLoginType,
UserLogMessageType
)
class BaseMixin:
id = Column(Integer, primary_key=True, index=True)
created_at = Column(DateTime, nullable=False, default=D.datetime())
updated_at = Column(DateTime, nullable=False, default=D.datetime(), onupdate=D.datetime())
def __init__(self):
self._q = None
self._session = None
self.served = None
def all_columns(self):
return [c for c in self.__table__.columns if c.primary_key is False and c.name != 'created_at']
def __hash__(self):
return hash(self.id)
@classmethod
def create(cls, session: Session, auto_commit=False, **kwargs):
"""
테이블 데이터 적재 전용 함수
:param session:
:param auto_commit: 자동 커밋 여부
:param kwargs: 적재 할 데이터
:return:
"""
obj = cls()
# NOTE(hsj100) : FIX_DATETIME
if 'created_at' not in kwargs:
obj.created_at = D.datetime()
if 'updated_at' not in kwargs:
obj.updated_at = D.datetime()
for col in obj.all_columns():
col_name = col.name
if col_name in kwargs:
setattr(obj, col_name, kwargs.get(col_name))
session.add(obj)
session.flush()
if auto_commit:
session.commit()
return obj
@classmethod
def get(cls, session: Session = None, **kwargs):
"""
Simply get a Row
:param session:
:param kwargs:
:return:
"""
sess = next(db.session()) if not session else session
query = sess.query(cls)
for key, val in kwargs.items():
col = getattr(cls, key)
query = query.filter(col == val)
if query.count() > 1:
raise Exception('Only one row is supposed to be returned, but got more than one.')
result = query.first()
if not session:
sess.close()
return result
@classmethod
def filter(cls, session: Session = None, **kwargs):
"""
Simply get a Row
:param session:
:param kwargs:
:return:
"""
cond = []
for key, val in kwargs.items():
key = key.split('__')
if len(key) > 2:
raise Exception('length of split(key) should be no more than 2.')
col = getattr(cls, key[0])
if len(key) == 1: cond.append((col == val))
elif len(key) == 2 and key[1] == 'gt': cond.append((col > val))
elif len(key) == 2 and key[1] == 'gte': cond.append((col >= val))
elif len(key) == 2 and key[1] == 'lt': cond.append((col < val))
elif len(key) == 2 and key[1] == 'lte': cond.append((col <= val))
elif len(key) == 2 and key[1] == 'in': cond.append((col.in_(val)))
elif len(key) == 2 and key[1] == 'like': cond.append((col.like(val)))
obj = cls()
if session:
obj._session = session
obj.served = True
else:
obj._session = next(db.session())
obj.served = False
query = obj._session.query(cls)
query = query.filter(*cond)
obj._q = query
return obj
@classmethod
def cls_attr(cls, col_name=None):
if col_name:
col = getattr(cls, col_name)
return col
else:
return cls
def order_by(self, *args: str):
for a in args:
if a.startswith('-'):
col_name = a[1:]
is_asc = False
else:
col_name = a
is_asc = True
col = self.cls_attr(col_name)
self._q = self._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc())
return self
def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
# NOTE(hsj100) : FIX_DATETIME
if 'updated_at' not in kwargs:
kwargs['updated_at'] = D.datetime()
qs = self._q.update(kwargs, synchronize_session=synchronize_session)
get_id = self.id
ret = None
self._session.flush()
if qs > 0 :
ret = self._q.first()
if auto_commit:
self._session.commit()
return ret
def first(self):
result = self._q.first()
self.close()
return result
def delete(self, auto_commit: bool = False, synchronize_session='evaluate'):
self._q.delete(synchronize_session=synchronize_session)
if auto_commit:
self._session.commit()
def all(self):
print(self.served)
result = self._q.all()
self.close()
return result
def count(self):
result = self._q.count()
self.close()
return result
def close(self):
if not self.served:
self._session.close()
else:
self._session.flush()
class ApiKeys(Base, BaseMixin):
__tablename__ = 'api_keys'
__table_args__ = {'extend_existing': True}
access_key = Column(String(length=64), nullable=False, index=True)
secret_key = Column(String(length=64), nullable=False)
user_memo = Column(String(length=40), nullable=True)
status = Column(Enum('active', 'stopped', 'deleted'), default='active')
is_whitelisted = Column(Boolean, default=False)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
whitelist = relationship('ApiWhiteLists', backref='api_keys')
users = relationship('Users', back_populates='keys')
class ApiWhiteLists(Base, BaseMixin):
__tablename__ = 'api_whitelists'
__table_args__ = {'extend_existing': True}
ip_addr = Column(String(length=64), nullable=False)
api_key_id = Column(Integer, ForeignKey('api_keys.id'), nullable=False)
class Users(Base, BaseMixin):
__tablename__ = 'users'
__table_args__ = {'extend_existing': True}
status = Column(Enum(UserStatusType), nullable=False, default=UserStatusType.active)
user_type = Column(Enum(UserType), nullable=False, default=UserType.user)
account_type = Column(Enum(AccountType), nullable=False, default=AccountType.email)
account = Column(String(length=255), nullable=False, unique=True)
pw = Column(String(length=2000), nullable=True)
email = Column(String(length=255), nullable=True)
name = Column(String(length=255), nullable=False)
sex = Column(Enum(SexType), nullable=False, default=SexType.male)
rrn = Column(String(length=255), nullable=True)
address = Column(String(length=2000), nullable=True)
phone_number = Column(String(length=20), nullable=True)
picture = Column(String(length=1000), nullable=True)
marketing_agree = Column(Boolean, nullable=False, default=False)
keys = relationship('ApiKeys', back_populates='users')
# extra
login = Column(Enum(UserLoginType), nullable=False, default=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
member_type = Column(Enum(MemberType), nullable=False, default=MemberType.personal)
# uuid = Column(String(length=36), nullable=True, unique=True)
class UserLog(Base, BaseMixin):
__tablename__ = 'userlog'
__table_args__ = {'extend_existing': True}
account = Column(String(length=255), nullable=False)
mac = Column(String(length=255), nullable=True)
type = Column(Enum(UserLogMessageType), nullable=False, default=UserLogMessageType.info)
api = Column(String(length=511), nullable=False)
message = Column(String(length=5000), nullable=True)
# NOTE(hsj100): 다단계 자동 삭제
# user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'), nullable=False)
# user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
# users = relationship('Users', back_populates='userlog')
# users = relationship('Users', back_populates='userlog', lazy=False)
# ===============================================================================================
class CustomBaseMixin(BaseMixin):
id = Column(Integer, primary_key=True, index=True)
def all_columns(self):
return [c for c in self.__table__.columns if c.primary_key is False]
@classmethod
def create(cls, session: Session, auto_commit=False, **kwargs):
"""
테이블 데이터 적재 전용 함수
:param session:
:param auto_commit: 자동 커밋 여부
:param kwargs: 적재 할 데이터
:return:
"""
obj = cls()
for col in obj.all_columns():
col_name = col.name
if col_name in kwargs:
setattr(obj, col_name, kwargs.get(col_name))
session.add(obj)
session.flush()
if auto_commit:
session.commit()
return obj
def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
qs = self._q.update(kwargs, synchronize_session=synchronize_session)
get_id = self.id
ret = None
self._session.flush()
if qs > 0 :
ret = self._q.first()
if auto_commit:
self._session.commit()
return ret