# -*- coding: utf-8 -*-
"""
@File: conn.py
@Date: 2020-09-14
@author: A2TEC

@section MODIFYINFO Modification history
    - modifier/date : description
    - 2022-01-14/hsj100@a2tec.co.kr : refactoring

@brief: DB Connections
"""
import copy
import logging

from fastapi import FastAPI
from sqlalchemy import create_engine, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import close_all_sessions, sessionmaker


def _server_url(db_url):
    """
    Return a copy of *db_url* with the database (schema) component removed.

    The result keeps driver, credentials, host, port and query options intact. It can be
    used to connect to the server itself, e.g. to create or drop the schema.
    Building it from the URL object (instead of an f-string) avoids producing
    'host:None' when no port is given and mangling passwords with URL-special characters.

    :param db_url: sqlalchemy URL object (e.g. ``engine.url``)
    :return: URL object without a database name
    """
    if hasattr(db_url, 'set'):
        # SQLAlchemy 1.4+: URL is immutable; .set() returns a modified copy.
        return db_url.set(database=None)
    # Older SQLAlchemy: URL is mutable, so copy it before editing.
    server_url = copy.copy(db_url)
    server_url.database = None
    return server_url


def _database_exist(engine, schema_name):
    """
    Check whether a schema named *schema_name* exists on the server behind *engine*.

    The schema name is passed as a bound parameter, so quotes inside the name
    cannot break (or alter) the query.

    :param engine: engine connected to the DB server
    :param schema_name: schema (database) name to look for
    :return: True if the schema exists, otherwise False
    """
    query = text('SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = :schema_name')
    with engine.connect() as conn:
        found = conn.execute(query, {'schema_name': schema_name}).scalar()
    return found is not None


def _drop_database(engine, schema_name):
    """
    Drop the schema *schema_name* on the server behind *engine*.

    :param engine: engine connected to the DB server
    :param schema_name: schema (database) name to drop
    """
    # Identifiers cannot be bound parameters; quote them with the dialect's rules instead.
    quoted_name = engine.dialect.identifier_preparer.quote(schema_name)
    with engine.begin() as conn:
        conn.execute(text(f'DROP DATABASE {quoted_name};'))


def _create_database(engine, schema_name):
    """
    Create the schema *schema_name* (utf8mb4 / utf8mb4_bin) on the server behind *engine*.

    :param engine: engine connected to the DB server
    :param schema_name: schema (database) name to create
    """
    quoted_name = engine.dialect.identifier_preparer.quote(schema_name)
    with engine.begin() as conn:
        conn.execute(text(f'CREATE DATABASE {quoted_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_bin;'))


def _prepare_schema(db_url, *, recreate, echo, pool_recycle):
    """
    Make sure the schema named in *db_url* exists on its server.

    :param db_url: full database URL (including the schema name)
    :param recreate: drop the schema first if it already exists (used in test mode)
    :param echo: passed through to the temporary server-level engine
    :param pool_recycle: passed through to the temporary server-level engine
    :return: True if the schema was created by this call (so it has no tables yet)
    """
    schema_name = db_url.database
    temp_engine = create_engine(_server_url(db_url), echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
    try:
        exists = _database_exist(temp_engine, schema_name)
        if exists and recreate:
            _drop_database(temp_engine, schema_name)
            exists = False
        if exists:
            return False
        _create_database(temp_engine, schema_name)
        return True
    finally:
        # Always release the temporary engine's pool, even if a statement above failed.
        temp_engine.dispose()


class SQLAlchemy:
    """Owns the SQLAlchemy engine/session factory and wires them into a FastAPI app."""

    def __init__(self, app: FastAPI = None, **kwargs):
        self._engine = None
        self._session = None
        if app is not None:
            self.init_app(app=app, **kwargs)

    def init_app(self, app: FastAPI, **kwargs):
        """
        Initialize the DB: create the engine, ensure the schema exists, and register app hooks.

        :param app: FastAPI instance; startup/shutdown handlers are registered on it
        :param kwargs: configuration values
            DB_URL: SQLAlchemy database URL
            DB_POOL_RECYCLE: pool_recycle value for the engine (default 900)
            TEST_MODE: if true, the host must be 'localhost' and the schema is dropped
                and recreated (default False)
            DB_ECHO: echo flag for the engine (default True)
        :return: None
        :raises Exception: in TEST_MODE when the DB host is not 'localhost'
        """
        database_url = kwargs.get('DB_URL')
        pool_recycle = kwargs.get('DB_POOL_RECYCLE', 900)
        is_testing = kwargs.get('TEST_MODE', False)
        echo = kwargs.get('DB_ECHO', True)

        self._engine = create_engine(
            database_url,
            echo=echo,
            pool_recycle=pool_recycle,
            pool_pre_ping=True,
        )
        db_url = self._engine.url

        # Safety guard: test mode drops the schema, so only allow it against a local server.
        if is_testing and db_url.host != 'localhost':
            raise Exception('db host must be \'localhost\' in test environment')

        created = _prepare_schema(db_url, recreate=is_testing, echo=echo, pool_recycle=pool_recycle)
        if created:
            # A freshly (re)created schema is empty; create tables so create_admin below can query them.
            # NOTE(review): Base.metadata only knows models whose modules are already imported — confirm
            # the model module is imported before init_app runs.
            Base.metadata.create_all(self._engine)

        self._session = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)

        # NOTE(hsj100): ADMINISTRATOR
        create_admin(self._session)

        @app.on_event('startup')
        def startup():
            # Check connectivity, then return the connection to the pool instead of leaking it.
            with self._engine.connect():
                pass
            logging.info('DB connected.')

        @app.on_event('shutdown')
        def shutdown():
            close_all_sessions()
            self._engine.dispose()
            logging.info('DB disconnected')

    def get_db(self):
        """
        Yield one DB session per request and close it afterwards (generator dependency).

        :return: generator yielding a session from the configured sessionmaker
        :raises Exception: if init_app has not been called
        """
        if self._session is None:
            raise Exception('must be called \'init_app\'')
        # Create the session outside try: if creation fails there is nothing to close,
        # and the original error propagates unmasked.
        db_session = self._session()
        try:
            yield db_session
        finally:
            db_session.close()

    @property
    def session(self):
        """The per-request session generator function (``get_db``)."""
        return self.get_db

    @property
    def engine(self):
        """The SQLAlchemy engine, or None before init_app."""
        return self._engine


db = SQLAlchemy()
Base = declarative_base()


# NOTE(hsj100): ADMINISTRATOR
def create_admin(db_session):
    """
    Create the initial administrator account if it does not exist yet.

    :param db_session: session factory (sessionmaker) used to open a session
    :raises Exception: if no session could be opened
    """
    import bcrypt  # NOTE(review): not used directly here; presumably kept so a missing bcrypt fails early
    from rest.app.database.schema import Users
    from rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO

    session = db_session()
    if session is None:
        raise Exception("can't create account of admin")

    try:
        if Users.get(account=ADMIN_INIT_ACCOUNT_INFO.account):
            return
        Users.create(session=session, auto_commit=True, **ADMIN_INIT_ACCOUNT_INFO.get_dict())
    finally:
        session.close()