edit : 기본 rest서버구조 추가

This commit is contained in:
2025-01-14 17:26:27 +09:00
parent e53282d7e7
commit bd28fcdca4
37 changed files with 4194 additions and 1 deletions

142
rest/app/database/conn.py Normal file
View File

@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
"""
@File: conn.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: DB Connections
"""
from fastapi import FastAPI
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import logging
def _database_exist(engine, schema_name):
query = f'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = "{schema_name}"'
with engine.connect() as conn:
result_proxy = conn.execute(query)
result = result_proxy.scalar()
return bool(result)
def _drop_database(engine, schema_name):
with engine.connect() as conn:
conn.execute(f'DROP DATABASE {schema_name};')
def _create_database(engine, schema_name):
with engine.connect() as conn:
conn.execute(f'CREATE DATABASE {schema_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_bin;')
class SQLAlchemy:
def __init__(self, app: FastAPI = None, **kwargs):
self._engine = None
self._session = None
if app is not None:
self.init_app(app=app, **kwargs)
def init_app(self, app: FastAPI, **kwargs):
"""
DB 초기화 함수
:param app: FastAPI 인스턴스
:param kwargs:
:return:
"""
database_url = kwargs.get('DB_URL')
pool_recycle = kwargs.setdefault('DB_POOL_RECYCLE', 900)
is_testing = kwargs.setdefault('TEST_MODE', False)
echo = kwargs.setdefault('DB_ECHO', True)
self._engine = create_engine(
database_url,
echo=echo,
pool_recycle=pool_recycle,
pool_pre_ping=True,
)
if is_testing: # create schema
db_url = self._engine.url
if db_url.host != 'localhost':
raise Exception('db host must be \'localhost\' in test environment')
except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
schema_name = db_url.database
temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
if _database_exist(temp_engine, schema_name):
_drop_database(temp_engine, schema_name)
_create_database(temp_engine, schema_name)
temp_engine.dispose()
else:
db_url = self._engine.url
except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
schema_name = db_url.database
temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
if not _database_exist(temp_engine, schema_name):
_create_database(temp_engine, schema_name)
Base.metadata.create_all(db.engine)
temp_engine.dispose()
self._session = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
# NOTE(hsj100): ADMINISTRATOR
create_admin(self._session)
@app.on_event('startup')
def startup():
self._engine.connect()
logging.info('DB connected.')
@app.on_event('shutdown')
def shutdown():
self._session.close_all()
self._engine.dispose()
logging.info('DB disconnected')
def get_db(self):
"""
요청마다 DB 세션 유지 함수
:return:
"""
if self._session is None:
raise Exception('must be called \'init_app\'')
db_session = None
try:
db_session = self._session()
yield db_session
finally:
db_session.close()
@property
def session(self):
return self.get_db
@property
def engine(self):
return self._engine
db = SQLAlchemy()
Base = declarative_base()
# NOTE(hsj100): ADMINISTRATOR
def create_admin(db_session):
import bcrypt
from rest.app.database.schema import Users
from rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO
session = db_session()
if not session:
raise Exception('cat`t create account of admin')
if Users.get(account=ADMIN_INIT_ACCOUNT_INFO.account):
return
admin = {**ADMIN_INIT_ACCOUNT_INFO.get_dict()}
Users.create(session=session, auto_commit=True, **admin)

300
rest/app/database/crud.py Normal file
View File

@@ -0,0 +1,300 @@
# -*- coding: utf-8 -*-
"""
@File: crud.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: CRUD
"""
import math
from datetime import datetime
from dateutil.relativedelta import relativedelta
from sqlalchemy import func, desc
from fastapi import APIRouter, Depends, Body
from sqlalchemy.orm import Session
from rest.app import models as M
from rest.app.database.conn import Base, db
from rest.app.database.schema import Users, UserLog
from rest.app.utils.extra import query_to_groupby, query_to_groupby_date
def get_month_info_list(start: datetime, end: datetime):
delta = relativedelta(end, start)
month_delta = 12 * delta.years + delta.months + (1 if delta.days > 0 else 0)
month_list = list()
for i in range(month_delta):
count = start + relativedelta(months=i)
month_info = dict()
month_info['period_title'] = count.strftime('%Y-%m')
month_info['year'] = count.year
month_info['month'] = count.month
month_info['start'] = (start + relativedelta(months=i)).replace(day=1)
month_info['end'] = (start + relativedelta(months=i + 1)).replace(day=1)
month_list.append(month_info)
return month_list
def request_parser(request_data: dict = None) -> dict:
"""
request information -> dict
:param request_data:
:return:
"""
result_dict = dict()
if not request_data:
return result_dict
for key, val in request_data.items():
if val is not None:
result_dict[key] = val if val != 'null' else None
return result_dict
def dict_to_filter_fmt(dict_data, get_attributes_callback=None):
"""
dict -> sqlalchemy filter (criterion: sql expression)
:param dict_data:
:param get_attributes_callback:
:return:
"""
if get_attributes_callback is None:
raise Exception('invalid get_attributes_callback')
criterion = list()
for key, val in dict_data.items():
key = key.split('__')
if len(key) > 2:
raise Exception('length of split(key) should be no more than 2.')
key_length = len(key)
col = get_attributes_callback(key[0])
if col is None:
continue
if key_length == 1:
criterion.append((col == val))
elif key_length == 2 and key[1] == 'gt':
criterion.append((col > val))
elif key_length == 2 and key[1] == 'gte':
criterion.append((col >= val))
elif key_length == 2 and key[1] == 'lt':
criterion.append((col < val))
elif key_length == 2 and key[1] == 'lte':
criterion.append((col <= val))
elif key_length == 2 and key[1] == 'in':
criterion.append((col.in_(val)))
elif key_length == 2 and key[1] == 'like':
criterion.append((col.like(val)))
return criterion
async def table_select(accessor_info, target_table, request_body_info, response_model, response_model_data):
"""
table_read
"""
try:
# parameter
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
# if not request_body_info:
# raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
if not response_model_data:
raise Exception('invalid response_model_data')
# paging - request
paging_request = None
if request_body_info:
if hasattr(request_body_info, 'paging'):
paging_request = request_body_info.paging
request_body_info = request_body_info.search
# request
request_info = request_parser(request_body_info.dict())
# search
criterion = None
if isinstance(request_body_info, M.UserLogDaySearchReq):
# UserLog search
# request
request_info = request_parser(request_body_info.dict())
# search UserLog
def get_attributes_callback(key: str):
return getattr(UserLog, key)
criterion = dict_to_filter_fmt(request_info, get_attributes_callback)
# search
session = next(db.session())
search_info = session.query(UserLog)\
.filter(UserLog.mac.isnot(None)
, UserLog.api == '/api/auth/login'
, UserLog.type == M.UserLogMessageType.info
, UserLog.message == 'ok'
, *criterion)\
.order_by(desc(UserLog.created_at), desc(UserLog.updated_at))\
.all()
if not search_info:
raise Exception('not found data')
group_by_day = query_to_groupby_date(search_info, 'updated_at')
# result
result_info = list()
for day, info_list in group_by_day.items():
info_by_mac = query_to_groupby(info_list, 'mac', first=True)
for log_info in info_by_mac.values():
result_info.append(response_model_data.from_orm(log_info))
else:
# basic search (single table)
# request
request_info = request_parser(request_body_info.dict())
# search
search_info = target_table.filter(**request_info).all()
if not search_info:
raise Exception('not found data')
# result
result_info = list()
for purchase_info in search_info:
result_info.append(response_model_data.from_orm(purchase_info))
# response - paging
paging_response = None
if paging_request:
total_contents_num = len(result_info)
total_page_num = math.ceil(total_contents_num / paging_request.page_contents_num)
start_contents_index = (paging_request.start_page - 1) * paging_request.page_contents_num
end_contents_index = start_contents_index + paging_request.page_contents_num
if end_contents_index < total_contents_num:
result_info = result_info[start_contents_index: end_contents_index]
else:
result_info = result_info[start_contents_index:]
paging_response = M.PagingRes()
paging_response.total_page_num = total_page_num
paging_response.total_contents_num = total_contents_num
paging_response.start_page = paging_request.start_page
paging_response.search_contents_num = len(result_info)
return response_model(data=result_info, paging=paging_response)
except Exception as e:
return response_model.set_error(str(e))
async def table_update(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
search_info = None
try:
# parameter
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
if not request_body_info:
raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
# if not response_model_data:
# raise Exception('invalid response_model_data')
# request
if not request_body_info.search_info:
raise Exception('invalid request_body: search_info')
request_search_info = request_parser(request_body_info.search_info.dict())
if not request_search_info:
raise Exception('invalid request_body: search_info')
if not request_body_info.update_info:
raise Exception('invalid request_body: update_info')
request_update_info = request_parser(request_body_info.update_info.dict())
if not request_update_info:
raise Exception('invalid request_body: update_info')
# search
search_info = target_table.filter(**request_search_info)
# process
search_info.update(auto_commit=True, synchronize_session=False, **request_update_info)
# result
return response_model()
except Exception as e:
if search_info:
search_info.close()
return response_model.set_error(str(e))
async def table_delete(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
search_info = None
try:
# request
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
if not request_body_info:
raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
# if not response_model_data:
# raise Exception('invalid response_model_data')
# request
request_search_info = request_parser(request_body_info.dict())
if not request_search_info:
raise Exception('invalid request_body')
# search
search_info = target_table.filter(**request_search_info)
temp_search = search_info.all()
# process
search_info.delete(auto_commit=True, synchronize_session=False)
# update license num
uuid_list = list()
for _license in temp_search:
if not hasattr(temp_search, 'uuid'):
# case: license
break
if _license.uuid not in uuid_list:
uuid_list.append(_license.uuid)
license_num = target_table.filter(uuid=_license.uuid).count()
target_table.filter(uuid=_license.uuid).update(auto_commit=True, synchronize_session=False, num=license_num)
# result
return response_model()
except Exception as e:
if search_info:
search_info.close()
return response_model.set_error(str(e))

304
rest/app/database/schema.py Normal file
View File

@@ -0,0 +1,304 @@
# -*- coding: utf-8 -*-
"""
@File: schema.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: database schema
"""
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
func,
Enum,
Boolean,
ForeignKey,
)
from sqlalchemy.orm import Session, relationship
from rest.app.database.conn import Base, db
from rest.app.utils.date_utils import D
from rest.app.models import (
SexType,
UserType,
MemberType,
AccountType,
UserStatusType,
UserLoginType,
UserLogMessageType
)
class BaseMixin:
id = Column(Integer, primary_key=True, index=True)
created_at = Column(DateTime, nullable=False, default=D.datetime())
updated_at = Column(DateTime, nullable=False, default=D.datetime(), onupdate=D.datetime())
def __init__(self):
self._q = None
self._session = None
self.served = None
def all_columns(self):
return [c for c in self.__table__.columns if c.primary_key is False and c.name != 'created_at']
def __hash__(self):
return hash(self.id)
@classmethod
def create(cls, session: Session, auto_commit=False, **kwargs):
"""
테이블 데이터 적재 전용 함수
:param session:
:param auto_commit: 자동 커밋 여부
:param kwargs: 적재 할 데이터
:return:
"""
obj = cls()
# NOTE(hsj100) : FIX_DATETIME
if 'created_at' not in kwargs:
obj.created_at = D.datetime()
if 'updated_at' not in kwargs:
obj.updated_at = D.datetime()
for col in obj.all_columns():
col_name = col.name
if col_name in kwargs:
setattr(obj, col_name, kwargs.get(col_name))
session.add(obj)
session.flush()
if auto_commit:
session.commit()
return obj
@classmethod
def get(cls, session: Session = None, **kwargs):
"""
Simply get a Row
:param session:
:param kwargs:
:return:
"""
sess = next(db.session()) if not session else session
query = sess.query(cls)
for key, val in kwargs.items():
col = getattr(cls, key)
query = query.filter(col == val)
if query.count() > 1:
raise Exception('Only one row is supposed to be returned, but got more than one.')
result = query.first()
if not session:
sess.close()
return result
@classmethod
def filter(cls, session: Session = None, **kwargs):
"""
Simply get a Row
:param session:
:param kwargs:
:return:
"""
cond = []
for key, val in kwargs.items():
key = key.split('__')
if len(key) > 2:
raise Exception('length of split(key) should be no more than 2.')
col = getattr(cls, key[0])
if len(key) == 1: cond.append((col == val))
elif len(key) == 2 and key[1] == 'gt': cond.append((col > val))
elif len(key) == 2 and key[1] == 'gte': cond.append((col >= val))
elif len(key) == 2 and key[1] == 'lt': cond.append((col < val))
elif len(key) == 2 and key[1] == 'lte': cond.append((col <= val))
elif len(key) == 2 and key[1] == 'in': cond.append((col.in_(val)))
elif len(key) == 2 and key[1] == 'like': cond.append((col.like(val)))
obj = cls()
if session:
obj._session = session
obj.served = True
else:
obj._session = next(db.session())
obj.served = False
query = obj._session.query(cls)
query = query.filter(*cond)
obj._q = query
return obj
@classmethod
def cls_attr(cls, col_name=None):
if col_name:
col = getattr(cls, col_name)
return col
else:
return cls
def order_by(self, *args: str):
for a in args:
if a.startswith('-'):
col_name = a[1:]
is_asc = False
else:
col_name = a
is_asc = True
col = self.cls_attr(col_name)
self._q = self._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc())
return self
def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
# NOTE(hsj100) : FIX_DATETIME
if 'updated_at' not in kwargs:
kwargs['updated_at'] = D.datetime()
qs = self._q.update(kwargs, synchronize_session=synchronize_session)
get_id = self.id
ret = None
self._session.flush()
if qs > 0 :
ret = self._q.first()
if auto_commit:
self._session.commit()
return ret
def first(self):
result = self._q.first()
self.close()
return result
def delete(self, auto_commit: bool = False, synchronize_session='evaluate'):
self._q.delete(synchronize_session=synchronize_session)
if auto_commit:
self._session.commit()
def all(self):
print(self.served)
result = self._q.all()
self.close()
return result
def count(self):
result = self._q.count()
self.close()
return result
def close(self):
if not self.served:
self._session.close()
else:
self._session.flush()
class ApiKeys(Base, BaseMixin):
__tablename__ = 'api_keys'
__table_args__ = {'extend_existing': True}
access_key = Column(String(length=64), nullable=False, index=True)
secret_key = Column(String(length=64), nullable=False)
user_memo = Column(String(length=40), nullable=True)
status = Column(Enum('active', 'stopped', 'deleted'), default='active')
is_whitelisted = Column(Boolean, default=False)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
whitelist = relationship('ApiWhiteLists', backref='api_keys')
users = relationship('Users', back_populates='keys')
class ApiWhiteLists(Base, BaseMixin):
__tablename__ = 'api_whitelists'
__table_args__ = {'extend_existing': True}
ip_addr = Column(String(length=64), nullable=False)
api_key_id = Column(Integer, ForeignKey('api_keys.id'), nullable=False)
class Users(Base, BaseMixin):
__tablename__ = 'users'
__table_args__ = {'extend_existing': True}
status = Column(Enum(UserStatusType), nullable=False, default=UserStatusType.active)
user_type = Column(Enum(UserType), nullable=False, default=UserType.user)
account_type = Column(Enum(AccountType), nullable=False, default=AccountType.email)
account = Column(String(length=255), nullable=False, unique=True)
pw = Column(String(length=2000), nullable=True)
email = Column(String(length=255), nullable=True)
name = Column(String(length=255), nullable=False)
sex = Column(Enum(SexType), nullable=False, default=SexType.male)
rrn = Column(String(length=255), nullable=True)
address = Column(String(length=2000), nullable=True)
phone_number = Column(String(length=20), nullable=True)
picture = Column(String(length=1000), nullable=True)
marketing_agree = Column(Boolean, nullable=False, default=False)
keys = relationship('ApiKeys', back_populates='users')
# extra
login = Column(Enum(UserLoginType), nullable=False, default=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
member_type = Column(Enum(MemberType), nullable=False, default=MemberType.personal)
# uuid = Column(String(length=36), nullable=True, unique=True)
class UserLog(Base, BaseMixin):
__tablename__ = 'userlog'
__table_args__ = {'extend_existing': True}
account = Column(String(length=255), nullable=False)
mac = Column(String(length=255), nullable=True)
type = Column(Enum(UserLogMessageType), nullable=False, default=UserLogMessageType.info)
api = Column(String(length=511), nullable=False)
message = Column(String(length=5000), nullable=True)
# NOTE(hsj100): 다단계 자동 삭제
# user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'), nullable=False)
# user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
# users = relationship('Users', back_populates='userlog')
# users = relationship('Users', back_populates='userlog', lazy=False)
# ===============================================================================================
class CustomBaseMixin(BaseMixin):
id = Column(Integer, primary_key=True, index=True)
def all_columns(self):
return [c for c in self.__table__.columns if c.primary_key is False]
@classmethod
def create(cls, session: Session, auto_commit=False, **kwargs):
"""
테이블 데이터 적재 전용 함수
:param session:
:param auto_commit: 자동 커밋 여부
:param kwargs: 적재 할 데이터
:return:
"""
obj = cls()
for col in obj.all_columns():
col_name = col.name
if col_name in kwargs:
setattr(obj, col_name, kwargs.get(col_name))
session.add(obj)
session.flush()
if auto_commit:
session.commit()
return obj
def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
qs = self._q.update(kwargs, synchronize_session=synchronize_session)
get_id = self.id
ret = None
self._session.flush()
if qs > 0 :
ret = self._q.first()
if auto_commit:
self._session.commit()
return ret