# -*- coding: utf-8 -*-
"""
@File: schema.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO Modification history
    - modifier/date : changes
    - 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: database schema
"""
from sqlalchemy import (
    Column,
    Integer,
    String,
    DateTime,
    func,
    Enum,
    Boolean,
    ForeignKey,
)
from sqlalchemy.orm import Session, relationship

from rest.app.database.conn import Base, db
from rest.app.utils.date_utils import D
from rest.app.models import (
    SexType,
    UserType,
    MemberType,
    AccountType,
    UserStatusType,
    UserLoginType,
    UserLogMessageType
)


# Operators accepted by BaseMixin.filter() as ``<column>__<op>=value``.
# A key without a suffix (``<column>=value``) means equality.
_FILTER_OPERATORS = {
    'gt': lambda col, val: col > val,
    'gte': lambda col, val: col >= val,
    'lt': lambda col, val: col < val,
    'lte': lambda col, val: col <= val,
    'in': lambda col, val: col.in_(val),
    'like': lambda col, val: col.like(val),
}


def _now():
    """Return ``D.datetime()``.

    Used as a column default/onupdate *callable* so the timestamp is computed
    for each INSERT/UPDATE instead of once when this module is imported.
    """
    return D.datetime()


class BaseMixin:
    """Declarative mixin adding an integer primary key, created/updated
    timestamps and small query helpers (create/get/filter/order_by/update/
    first/delete/all/count)."""

    id = Column(Integer, primary_key=True, index=True)
    # Pass a callable, not ``D.datetime()``: a called value would be evaluated
    # once at import time and then reused for every row.
    created_at = Column(DateTime, nullable=False, default=_now)
    updated_at = Column(DateTime, nullable=False, default=_now, onupdate=_now)

    def __init__(self):
        # Query-builder state used by filter()/order_by()/update()/first()/...
        # NOTE(review): for models declared as ``(Base, BaseMixin)`` a
        # constructor defined on ``Base`` would precede this one in the MRO;
        # filter() assigns these attributes explicitly either way.
        self._q = None
        self._session = None
        self.served = None

    def all_columns(self):
        """Return the table's columns except the primary key(s) and ``created_at``."""
        return [c for c in self.__table__.columns if c.primary_key is False and c.name != 'created_at']

    def __hash__(self):
        return hash(self.id)

    @classmethod
    def create(cls, session: Session, auto_commit=False, **kwargs):
        """
        Insert one row into the table.

        :param session: session the new object is added to and flushed with
        :param auto_commit: commit the session after flushing
        :param kwargs: column values to store; keys that are not non-primary-key
                       columns (including ``id``) are ignored
        :return: the new, flushed instance
        """
        obj = cls()
        # NOTE(hsj100) : FIX_DATETIME
        # created_at is excluded from all_columns(), so an explicit value has to
        # be applied here or it would be silently dropped.
        obj.created_at = kwargs['created_at'] if 'created_at' in kwargs else D.datetime()
        if 'updated_at' not in kwargs:
            obj.updated_at = D.datetime()
        for col in obj.all_columns():
            if col.name in kwargs:
                setattr(obj, col.name, kwargs[col.name])
        session.add(obj)
        session.flush()
        if auto_commit:
            session.commit()
        return obj

    @classmethod
    def get(cls, session: Session = None, **kwargs):
        """
        Fetch at most one row whose columns equal the given values.

        :param session: session to use; when omitted a new one is taken from
                        ``db.session()`` and closed before returning
        :param kwargs: column name -> value equality conditions
        :return: the matching row, or None if there is none
        :raises Exception: if more than one row matches
        """
        sess = next(db.session()) if not session else session
        try:
            query = sess.query(cls)
            for key, val in kwargs.items():
                query = query.filter(getattr(cls, key) == val)
            # Fetching two rows is enough to detect "more than one" in a single query.
            rows = query.limit(2).all()
            if len(rows) > 1:
                raise Exception('Only one row is supposed to be returned, but got more than one.')
            return rows[0] if rows else None
        finally:
            # Close a session we opened ourselves, even when raising.
            if not session:
                sess.close()

    @classmethod
    def filter(cls, session: Session = None, **kwargs):
        """
        Start a chainable query filtered by ``kwargs``.

        Keys are ``<column>`` (equality) or ``<column>__<op>`` where ``op`` is
        one of gt, gte, lt, lte, in, like.

        :param session: session to use; when omitted a new one is taken from
                        ``db.session()`` and ``served`` is set to False so
                        close() releases it
        :param kwargs: filter conditions
        :return: a model instance carrying the query (use order_by/first/all/...)
        :raises Exception: if a key has more than one ``__`` separator
        :raises ValueError: if the operator suffix is not supported
        """
        cond = []
        for key, val in kwargs.items():
            parts = key.split('__')
            if len(parts) > 2:
                raise Exception('length of split(key) should be no more than 2.')
            col = getattr(cls, parts[0])
            if len(parts) == 1:
                cond.append(col == val)
                continue
            op = _FILTER_OPERATORS.get(parts[1])
            if op is None:
                # Previously an unknown suffix was silently ignored, which
                # dropped the condition and widened the query.
                raise ValueError(f'unsupported filter operator: {parts[1]!r} in {key!r}')
            cond.append(op(col, val))

        obj = cls()
        if session:
            obj._session = session
            obj.served = True
        else:
            obj._session = next(db.session())
            obj.served = False
        obj._q = obj._session.query(cls).filter(*cond)
        return obj

    @classmethod
    def cls_attr(cls, col_name=None):
        """Return the mapped attribute ``col_name``, or the class itself if no name is given."""
        if col_name:
            return getattr(cls, col_name)
        return cls

    def order_by(self, *args: str):
        """
        Add ORDER BY clauses to the pending query.

        :param args: column names; a leading '-' means descending
        :return: self, for chaining
        """
        for a in args:
            if a.startswith('-'):
                col_name = a[1:]
                is_asc = False
            else:
                col_name = a
                is_asc = True
            col = self.cls_attr(col_name)
            self._q = self._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc())
        return self

    def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
        """
        Bulk-UPDATE the rows matched by the pending query.

        :param auto_commit: commit the session after the update
        :param synchronize_session: passed through to ``Query.update``
        :param kwargs: column name -> new value; ``updated_at`` is set to now
                       unless provided
        :return: the first row matching the query's conditions, re-queried after
                 the update, or None if no row was updated
        """
        # NOTE(hsj100) : FIX_DATETIME
        if 'updated_at' not in kwargs:
            kwargs['updated_at'] = D.datetime()
        updated = self._q.update(kwargs, synchronize_session=synchronize_session)
        ret = None
        self._session.flush()
        if updated > 0:
            ret = self._q.first()
        if auto_commit:
            self._session.commit()
        # NOTE(review): a session opened by filter() (served=False) is not
        # closed here; the returned object stays bound to it.
        return ret

    def first(self):
        """Return the first row of the pending query, then close() the session."""
        result = self._q.first()
        self.close()
        return result

    def delete(self, auto_commit: bool = False, synchronize_session='evaluate'):
        """Bulk-DELETE the rows matched by the pending query; commit if auto_commit."""
        self._q.delete(synchronize_session=synchronize_session)
        if auto_commit:
            self._session.commit()
        # NOTE(review): like update(), this does not close a session opened by filter().

    def all(self):
        """Return all rows of the pending query, then close() the session."""
        result = self._q.all()
        self.close()
        return result

    def count(self):
        """Return the row count of the pending query, then close() the session."""
        result = self._q.count()
        self.close()
        return result

    def close(self):
        """Close a session opened by filter(); only flush a caller-provided one."""
        if not self.served:
            self._session.close()
        else:
            self._session.flush()


class ApiKeys(Base, BaseMixin):
    """API key pair owned by a user (status: active / stopped / deleted)."""
    __tablename__ = 'api_keys'
    __table_args__ = {'extend_existing': True}

    access_key = Column(String(length=64), nullable=False, index=True)
    secret_key = Column(String(length=64), nullable=False)
    user_memo = Column(String(length=40), nullable=True)
    status = Column(Enum('active', 'stopped', 'deleted'), default='active')
    is_whitelisted = Column(Boolean, default=False)
    user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    whitelist = relationship('ApiWhiteLists', backref='api_keys')
    users = relationship('Users', back_populates='keys')


class ApiWhiteLists(Base, BaseMixin):
    """IP address entry attached to an API key."""
    __tablename__ = 'api_whitelists'
    __table_args__ = {'extend_existing': True}

    ip_addr = Column(String(length=64), nullable=False)
    api_key_id = Column(Integer, ForeignKey('api_keys.id'), nullable=False)


class Users(Base, BaseMixin):
    """User account row; ``account`` is unique."""
    __tablename__ = 'users'
    __table_args__ = {'extend_existing': True}

    status = Column(Enum(UserStatusType), nullable=False, default=UserStatusType.active)
    user_type = Column(Enum(UserType), nullable=False, default=UserType.user)
    account_type = Column(Enum(AccountType), nullable=False, default=AccountType.email)
    account = Column(String(length=255), nullable=False, unique=True)
    pw = Column(String(length=2000), nullable=True)
    email = Column(String(length=255), nullable=True)
    name = Column(String(length=255), nullable=False)
    sex = Column(Enum(SexType), nullable=False, default=SexType.male)
    rrn = Column(String(length=255), nullable=True)
    address = Column(String(length=2000), nullable=True)
    phone_number = Column(String(length=20), nullable=True)
    picture = Column(String(length=1000), nullable=True)
    marketing_agree = Column(Boolean, nullable=False, default=False)
    keys = relationship('ApiKeys', back_populates='users')

    # extra
    login = Column(Enum(UserLoginType), nullable=False, default=UserLoginType.logout)  # TODO(hsj100): LOGIN_STATUS
    member_type = Column(Enum(MemberType), nullable=False, default=MemberType.personal)
    # uuid = Column(String(length=36), nullable=True, unique=True)


class UserLog(Base, BaseMixin):
    """Log message recorded per account and API (no foreign key to users)."""
    __tablename__ = 'userlog'
    __table_args__ = {'extend_existing': True}

    account = Column(String(length=255), nullable=False)
    mac = Column(String(length=255), nullable=True)
    type = Column(Enum(UserLogMessageType), nullable=False, default=UserLogMessageType.info)
    api = Column(String(length=511), nullable=False)
    message = Column(String(length=5000), nullable=True)

    # NOTE(hsj100): multi-level automatic deletion (cascade)
    # user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'), nullable=False)
    # user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    # users = relationship('Users', back_populates='userlog')
    # users = relationship('Users', back_populates='userlog', lazy=False)


# ===============================================================================================

class CustomBaseMixin(BaseMixin):
    """Variant of BaseMixin whose all_columns() includes ``created_at`` and whose
    create()/update() do not stamp timestamps explicitly (the column
    default/onupdate callables apply instead)."""

    id = Column(Integer, primary_key=True, index=True)

    def all_columns(self):
        """Return the table's columns except the primary key(s)."""
        return [c for c in self.__table__.columns if c.primary_key is False]

    @classmethod
    def create(cls, session: Session, auto_commit=False, **kwargs):
        """
        Insert one row into the table.

        :param session: session the new object is added to and flushed with
        :param auto_commit: commit the session after flushing
        :param kwargs: column values to store; keys that are not non-primary-key
                       columns are ignored
        :return: the new, flushed instance
        """
        obj = cls()
        for col in obj.all_columns():
            if col.name in kwargs:
                setattr(obj, col.name, kwargs[col.name])
        session.add(obj)
        session.flush()
        if auto_commit:
            session.commit()
        return obj

    def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
        """
        Bulk-UPDATE the rows matched by the pending query without touching
        ``updated_at`` explicitly.

        :param auto_commit: commit the session after the update
        :param synchronize_session: passed through to ``Query.update``
        :param kwargs: column name -> new value
        :return: the first row matching the query's conditions, re-queried after
                 the update, or None if no row was updated
        """
        updated = self._q.update(kwargs, synchronize_session=synchronize_session)
        ret = None
        self._session.flush()
        if updated > 0:
            ret = self._q.first()
        if auto_commit:
            self._session.commit()
        return ret