edit: faiss 전용 rest 서버 추가

This commit is contained in:
2025-04-28 11:26:19 +09:00
parent 19e62f5724
commit 6b212125a4
76 changed files with 6014 additions and 92 deletions

3
.gitignore vendored
View File

@@ -144,4 +144,5 @@ cython_debug/
log
sheet_counter.txt
*output
*output
datas/eyewear_all/*

View File

@@ -24,5 +24,10 @@ RESTful API Server
https://github.com/acheong08/BingImageCreator
```
### faiss
검색된 이미지는 search_'matching'.png 로 저장됨
matching : 매칭률 (클수록 좋음)
### notice
const.py 에 지정한 OUTPUT_FOLDER 하위폴더에 이미지 저장됨.
const.py 에 지정한 OUTPUT_FOLDER 하위폴더에 이미지 저장됨.
faiss api 사용시 rest_vactor 서버 구동해야함.

View File

@@ -1,3 +1,4 @@
# OUTPUT_FOLDER = "/home/fermat/STORAGE/01.Projects/A2TEC/K_EYEWEAR/02.ML_DATA/Image_generator_result"
OUTPUT_FOLDER = "./output"
ILLEGAL_FILE_NAME = ['<', '>', ':', '"', '/', '\ ', '|', '?', '*']

View File

@@ -4,8 +4,8 @@ from pathlib import Path
from bingart import BingArt
from custom_apps.utils import cookie_manager
from rest.app.utils.parsing_utils import prompt_to_filenames
from rest.app.utils.date_utils import D
from main_rest.app.utils.parsing_utils import prompt_to_filenames
from main_rest.app.utils.date_utils import D
from const import OUTPUT_FOLDER

View File

@@ -16,8 +16,8 @@ import pkg_resources
import regex
import requests
from rest.app.utils.parsing_utils import prompt_to_filenames
from rest.app.utils.date_utils import D
from main_rest.app.utils.parsing_utils import prompt_to_filenames
from main_rest.app.utils.date_utils import D
BING_URL = os.getenv("BING_URL", "https://www.bing.com")
# Generate random IP between range 13.104.0.0/14

View File

@@ -0,0 +1,2 @@
DATASET_BIN = "./datas/eyewear_all.fvecs.bin"
DATASET_TEXT = "./datas/eyewear_all.fnames.txt"

50
custom_apps/faiss/main.py Normal file
View File

@@ -0,0 +1,50 @@
import os
import shutil
from custom_apps.faiss.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
from custom_apps.faiss.const import *
# from rest.app.utils.date_utils import D
# from const import OUTPUT_FOLDER
dataset_list = get_dataset_list(DATASET_TEXT)
def search_idxs(image_path,dataset_bin=DATASET_BIN,index_type="hnsw",search_num=4):
if index_type not in ["hnsw", "l2"]:
raise ValueError("index_type must be either 'hnsw' or 'l2'")
DIM = 1280
dataset_fvces, dataset_index = preprocessing(DIM,dataset_bin,index_type)
org_fvces, org_index = preprocessing_quary(DIM,image_path,index_type)
dists, idxs = dataset_index.search(normalize(org_fvces), search_num)
# print(dists[0])
# print(idxs[0])
index_image_save(image_path, dists[0], idxs[0])
def index_image_save(query_image_path, dists, idxs):
directory_path, file = os.path.split(query_image_path)
_name, extension = os.path.splitext(file)
if not os.path.exists(directory_path):
raise ValueError(f"Folder {directory_path} does not exist.")
for dist,index in zip(dists, idxs):
if dist > 1:
dist = 0
else:
dist = 1-dist
origin_file_path = dataset_list[index]
dest_file_path = os.path.join(directory_path, f"search_{round(float(dist),4)}{extension}")
shutil.copy(origin_file_path, dest_file_path)

116
custom_apps/faiss/utils.py Normal file
View File

@@ -0,0 +1,116 @@
"""
quary 이미지(이미지 경로) 입력받아 처리
"""
import os
import time
import math
import numpy as np
from sklearn.preprocessing import normalize
import faiss
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.models import Model
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from PIL import Image
def populate(index, fvecs, batch_size=1000):
nloop = math.ceil(fvecs.shape[0] / batch_size)
for n in range(nloop):
s = time.time()
index.add(normalize(fvecs[n * batch_size : min((n + 1) * batch_size, fvecs.shape[0])]))
# print(n * batch_size, time.time() - s)
return index
def get_index(index_type, dim):
if index_type == 'hnsw':
m = 48
index = faiss.IndexHNSWFlat(dim, m)
index.hnsw.efConstruction = 128
return index
elif index_type == 'l2':
return faiss.IndexFlatL2(dim)
raise
def preprocessing(dim,fvec_file,index_type):
# index_type = 'hnsw'
# index_type = 'l2'
# f-string 방식 (python3 이상에서 지원)
index_file = f'{fvec_file}.{index_type}.index'
fvecs = np.memmap(fvec_file, dtype='float32', mode='r').view('float32').reshape(-1, dim)
if os.path.exists(index_file):
index = faiss.read_index(index_file)
if index_type == 'hnsw':
index.hnsw.efSearch = 256
else:
index = get_index(index_type, dim)
index = populate(index, fvecs)
faiss.write_index(index, index_file)
# print(index.ntotal)
return fvecs,index
def preprocessing_quary(dim,image_path,index_type):
# index_type = 'hnsw'
# index_type = 'l2'
fvecs = fvces_quary(image_path)
index = get_index(index_type, dim)
index = populate(index, fvecs)
return fvecs,index
def preprocess(img_path, input_shape):
img = tf.io.read_file(img_path)
img = tf.image.decode_jpeg(img, channels=input_shape[2])
img = tf.image.resize(img, input_shape[:2])
img = preprocess_input(img)
return img
def preprocess_pil(pil_img_data, input_shape):
pil_img = np.asarray(Image.open(pil_img_data))
pil_img = tf.image.resize(pil_img, input_shape[:2])
pil_img = preprocess_input(pil_img)
return pil_img
def fvces_quary(image_path):
batch_size = 100
input_shape = (224, 224, 3)
base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
include_top=False,
weights='imagenet')
base.trainable = False
model = Model(inputs=base.input, outputs=layers.GlobalAveragePooling2D()(base.output))
list_ds = tf.data.Dataset.from_tensor_slices([image_path])
ds = list_ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
dataset = ds.batch(batch_size).prefetch(-1)
for batch in dataset:
fvecs = model.predict(batch)
return fvecs
# index_type = 'hnsw'
# index_type = 'l2'
def get_dataset_list(filepath):
content_list = []
import os
if os.path.isfile(filepath) and filepath.endswith(".txt"):
with open(filepath, 'r', encoding='utf-8') as file:
for line in file:
content_list.append(line.strip())
return content_list

View File

@@ -5,8 +5,8 @@ import os
from vertexai.preview.vision_models import ImageGenerationModel
from const import OUTPUT_FOLDER
from rest.app.utils.parsing_utils import prompt_to_filenames
from rest.app.utils.date_utils import D
from main_rest.app.utils.parsing_utils import prompt_to_filenames
from main_rest.app.utils.date_utils import D
class ImagenConst:
@@ -48,4 +48,44 @@ def imagen_generate_image(prompt,download_count=1):
for i in range(len(images.images)):
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False)
return len(images.images)
return len(images.images)
def imagen_generate_image_data(prompt,download_count=1):
vertexai.init(project=ImagenConst.project_id, location=ImagenConst.location)
model = ImageGenerationModel.from_pretrained(ImagenConst.model)
images = model.generate_images(
prompt=prompt,
# Optional parameters
number_of_images=download_count,
language="ko",
# You can't use a seed value and watermark at the same time.
# add_watermark=False,
# seed=100,
aspect_ratio="1:1",
safety_filter_level="block_some",
person_generation="dont_allow",
)
return images.images[0]._pil_image
def imagen_generate_image_path(image_prompt):
MODEL = "imagen"
QUERY = "query"
create_time = D.date_file_name()
folder_name = os.path.join(OUTPUT_FOLDER,f"{MODEL}_{QUERY}_{image_prompt}_{create_time}")
if not os.path.exists(folder_name):
os.makedirs(folder_name)
generate_img = imagen_generate_image_data(image_prompt)
generate_img.save(os.path.join(folder_name,f"query.png"))
return os.path.join(folder_name,f"query.png")
if __name__ == '__main__':
pass
# imagen_generate_image_data("cat")

View File

@@ -95,18 +95,18 @@ def get_logger():
return custom_logger
if custom_logger is None:
custom_logger = logging.getLogger(LOGGER_NAME)
# if custom_logger is None:
# custom_logger = logging.getLogger(LOGGER_NAME)
if not os.path.exists(LOGGER_DIR):
os.makedirs(LOGGER_DIR)
# if not os.path.exists(LOGGER_DIR):
# os.makedirs(LOGGER_DIR)
if not __LOGGER_FILE_PATH:
logger_init(custom_logger, level=LOGGER_LEVEL)
else:
if not os.path.exists(LOGGER_DIR):
os.makedirs(LOGGER_DIR)
logger_init(custom_logger, level=LOGGER_LEVEL, file_log_path=__LOGGER_FILE_PATH)
# if not __LOGGER_FILE_PATH:
# logger_init(custom_logger, level=LOGGER_LEVEL)
# else:
# if not os.path.exists(LOGGER_DIR):
# os.makedirs(LOGGER_DIR)
# logger_init(custom_logger, level=LOGGER_LEVEL, file_log_path=__LOGGER_FILE_PATH)
def test():
custom_logger.info('Module: custom_log.py')

42
custom_logger/main_log.py Normal file
View File

@@ -0,0 +1,42 @@
import datetime
import os
import logging
from custom_logger.custom_log import logger_init
main_logger = None
_now = datetime.datetime.now()
LOGGER_NAME = 'main'
LOGGER_FILE_NAME = f'{_now.strftime("%Y-%m-%d %H_%M_%S")}_{LOGGER_NAME}.log'
LOGGER_LEVEL = logging.INFO
LOGGER_DIR = "log/"
__LOGGER_FILE_PATH = LOGGER_DIR + LOGGER_FILE_NAME
def get_main_logger():
"""
로거 객체 반환
로거 객체가 없을 경우에는 로그 초기화를 진행하고 생성된 로그 객체를 반환함
:return: 로그 객체
"""
if not main_logger.handlers:
logger_init(main_logger)
return main_logger
if main_logger is None:
main_logger = logging.getLogger(LOGGER_NAME)
if not os.path.exists(LOGGER_DIR):
os.makedirs(LOGGER_DIR)
if not __LOGGER_FILE_PATH:
logger_init(main_logger, level=LOGGER_LEVEL)
else:
if not os.path.exists(LOGGER_DIR):
os.makedirs(LOGGER_DIR)
logger_init(main_logger, level=LOGGER_LEVEL, file_log_path=__LOGGER_FILE_PATH)

View File

@@ -0,0 +1,42 @@
import datetime
import os
import logging
from custom_logger.custom_log import logger_init
vactor_logger = None
_now = datetime.datetime.now()
LOGGER_NAME = 'vactor'
LOGGER_FILE_NAME = f'{_now.strftime("%Y-%m-%d %H_%M_%S")}_{LOGGER_NAME}.log'
LOGGER_LEVEL = logging.INFO
LOGGER_DIR = "log/"
__LOGGER_FILE_PATH = LOGGER_DIR + LOGGER_FILE_NAME
def get_vactor_logger():
"""
로거 객체 반환
로거 객체가 없을 경우에는 로그 초기화를 진행하고 생성된 로그 객체를 반환함
:return: 로그 객체
"""
if not vactor_logger.handlers:
logger_init(vactor_logger)
return vactor_logger
if vactor_logger is None:
vactor_logger = logging.getLogger(LOGGER_NAME)
if not os.path.exists(LOGGER_DIR):
os.makedirs(LOGGER_DIR)
if not __LOGGER_FILE_PATH:
logger_init(vactor_logger, level=LOGGER_LEVEL)
else:
if not os.path.exists(LOGGER_DIR):
os.makedirs(LOGGER_DIR)
logger_init(vactor_logger, level=LOGGER_LEVEL, file_log_path=__LOGGER_FILE_PATH)

2000
datas/eyewear_all.fnames.txt Normal file

File diff suppressed because it is too large Load Diff

BIN
datas/eyewear_all.fvecs.bin Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

14
environment_vactor.yml Normal file
View File

@@ -0,0 +1,14 @@
name: fm_rest_vactor
channels:
- pytorch
- nvidia
- rapidsai
- conda-forge
- defaults
dependencies:
- python=3.10
- libnvjitlink
- faiss-gpu-cuvs=1.10.0
- tensorflow=2.11.0
- pip:
- -r requirements_vactor.txt

View File

@@ -1,6 +0,0 @@
import uvicorn
from rest.app.common.config import conf
if __name__ == '__main__':
uvicorn.run('rest.app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)

View File

@@ -12,8 +12,8 @@
from dataclasses import dataclass
from os import path, environ
from rest.app.common import consts
from rest.app.models import UserInfo
from main_rest.app.common import consts
from main_rest.app.models import UserInfo
base_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__))))

View File

@@ -126,8 +126,8 @@ Base = declarative_base()
# NOTE(hsj100): ADMINISTRATOR
def create_admin(db_session):
import bcrypt
from rest.app.database.schema import Users
from rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO
from main_rest.app.database.schema import Users
from main_rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO
session = db_session()

View File

@@ -16,10 +16,10 @@ from sqlalchemy import func, desc
from fastapi import APIRouter, Depends, Body
from sqlalchemy.orm import Session
from rest.app import models as M
from rest.app.database.conn import Base, db
from rest.app.database.schema import Users, UserLog
from rest.app.utils.extra import query_to_groupby, query_to_groupby_date
from main_rest.app import models as M
from main_rest.app.database.conn import Base, db
from main_rest.app.database.schema import Users, UserLog
from main_rest.app.utils.extra import query_to_groupby, query_to_groupby_date
def get_month_info_list(start: datetime, end: datetime):

View File

@@ -21,9 +21,9 @@ from sqlalchemy import (
)
from sqlalchemy.orm import Session, relationship
from rest.app.database.conn import Base, db
from rest.app.utils.date_utils import D
from rest.app.models import (
from main_rest.app.database.conn import Base, db
from main_rest.app.utils.date_utils import D
from main_rest.app.models import (
SexType,
UserType,
MemberType,

View File

@@ -1,4 +1,4 @@
from rest.app.common.consts import MAX_API_KEY, MAX_API_WHITELIST
from main_rest.app.common.consts import MAX_API_KEY, MAX_API_WHITELIST
class StatusCode:

View File

@@ -20,16 +20,16 @@ from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from rest.app.common import consts
from main_rest.app.common import consts
from rest.app.database.conn import db
from rest.app.common.config import conf
from rest.app.middlewares.token_validator import access_control
from rest.app.middlewares.trusted_hosts import TrustedHostMiddleware
from rest.app.routes import dev, index, auth, users, services
from main_rest.app.database.conn import db
from main_rest.app.common.config import conf
from main_rest.app.middlewares.token_validator import access_control
from main_rest.app.middlewares.trusted_hosts import TrustedHostMiddleware
from main_rest.app.routes import dev, index, auth, users, services
from contextlib import asynccontextmanager
from custom_logger.custom_log import custom_logger as LOG
from custom_logger.main_log import main_logger as LOG
API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)

View File

@@ -12,19 +12,19 @@ from jwt.exceptions import ExpiredSignatureError, DecodeError
from starlette.requests import Request
from starlette.responses import JSONResponse
from rest.app.common.consts import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX
from rest.app.database.conn import db
from rest.app.database.schema import Users, ApiKeys
from rest.app.errors import exceptions as ex
from main_rest.app.common.consts import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX
from main_rest.app.database.conn import db
from main_rest.app.database.schema import Users, ApiKeys
from main_rest.app.errors import exceptions as ex
from rest.app.common import consts
from rest.app.common.config import conf
from rest.app.errors.exceptions import APIException, SqlFailureEx, APIQueryStringEx
from rest.app.models import UserToken
from main_rest.app.common import consts
from main_rest.app.common.config import conf
from main_rest.app.errors.exceptions import APIException, SqlFailureEx, APIQueryStringEx
from main_rest.app.models import UserToken
from rest.app.utils.date_utils import D
from rest.app.utils.logger import api_logger
from rest.app.utils.query_utils import to_dict
from main_rest.app.utils.date_utils import D
from main_rest.app.utils.logger import api_logger
from main_rest.app.utils.query_utils import to_dict
from dataclasses import asdict

View File

@@ -19,7 +19,7 @@ from pydantic.main import BaseModel
from pydantic.networks import EmailStr, IPvAnyAddress
from typing import Optional
from rest.app.common.consts import (
from main_rest.app.common.consts import (
SW_TITLE,
SW_VERSION,
MAIL_REG_TITLE,
@@ -29,7 +29,7 @@ from rest.app.common.consts import (
ADMIN_INIT_ACCOUNT_INFO,
DEFAULT_USER_ACCOUNT_PW
)
from rest.app.utils.date_utils import D
from main_rest.app.utils.date_utils import D
class SWInfo(BaseModel):
@@ -566,6 +566,10 @@ class UserLogUpdateMultiReq(BaseModel):
#===============================================================================
#===============================================================================
class IndexType:
hnsw = "hnsw"
l2 = "l2"
class ImageGenerateReq(BaseModel):
"""
### [Request] image generate request
@@ -580,6 +584,15 @@ class BingCookieSetReq(BaseModel):
"""
cookie : str = Field('',description='쿠키 데이터', example='')
class VactorImageSearchReq(BaseModel):
"""
### [Request] vactor image search request
"""
prompt : str = Field(description='프롬프트', example='검은색 안경')
index_type : str = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
#===============================================================================
#===============================================================================

View File

@@ -17,13 +17,13 @@ import bcrypt
import jwt
from datetime import datetime, timedelta
from rest.app.common import consts
from rest.app import models as M
from rest.app.database.conn import db
from rest.app.common.config import conf
from rest.app.database.schema import Users, UserLog
from rest.app.utils.extra import query_to_groupby, AESCryptoCBC
from rest.app.utils.date_utils import D
from main_rest.app.common import consts
from main_rest.app import models as M
from main_rest.app.database.conn import db
from main_rest.app.common.config import conf
from main_rest.app.database.schema import Users, UserLog
from main_rest.app.utils.extra import query_to_groupby, AESCryptoCBC
from main_rest.app.utils.date_utils import D
router = APIRouter(prefix='/auth')

View File

@@ -15,13 +15,13 @@ from sqlalchemy.orm import Session
import bcrypt
from starlette.requests import Request
from rest.app.common import consts
from rest.app import models as M
from rest.app.database.conn import db, Base
from rest.app.database.schema import Users, UserLog
from main_rest.app.common import consts
from main_rest.app import models as M
from main_rest.app.database.conn import db, Base
from main_rest.app.database.schema import Users, UserLog
from rest.app.utils.extra import FernetCrypto, AESCryptoCBC, AESCipher
from custom_logger.custom_log import custom_logger as LOG
from main_rest.app.utils.extra import FernetCrypto, AESCryptoCBC, AESCipher
from custom_logger.main_log import main_logger as LOG
# mail test

View File

@@ -11,8 +11,8 @@
from fastapi import APIRouter
from rest.app.utils.date_utils import D
from rest.app.models import SWInfo
from main_rest.app.utils.date_utils import D
from main_rest.app.models import SWInfo
router = APIRouter()

View File

@@ -14,15 +14,15 @@ from fastapi import APIRouter, Depends, Body
from starlette.requests import Request
from typing import Annotated, List
from rest.app.common import consts
from rest.app import models as M
from rest.app.utils.date_utils import D
from custom_logger.custom_log import custom_logger as LOG
from main_rest.app.common import consts
from main_rest.app import models as M
from main_rest.app.utils.date_utils import D
from custom_logger.main_log import main_logger as LOG
from custom_apps.bingimagecreator.utils import DallEArgument,dalle3_generate_image
from custom_apps.bingart.bingart import BingArtGenerator
from custom_apps.imagen.custom_imagen import imagen_generate_image
from rest.app.utils.parsing_utils import download_range
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path
from main_rest.app.utils.parsing_utils import download_range
from custom_apps.utils import cookie_manager
router = APIRouter(prefix="/services")
@@ -144,6 +144,40 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
return response.set_message(img_len=img_length)
except Exception as e:
LOG.error(traceback.format_exc())
return response.set_error(error=e)
@router.post("/vactorImageSearch/imageGenerate/imagen", summary="벡터 이미지 검색 - imagen", response_model=M.ResponseBase)
async def vactor_image(request: Request, request_body_info: M.VactorImageSearchReq):
"""
## 벡터 이미지 검색 - imagen
> imagen AI를 이용하여 이미지 생성 vactor 검색
### Requriements
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
> - const.py 지정한 OUTPUT_FOLDER 하위에 imagen 폴더가 있어야함.
"""
response = M.ResponseBase()
try:
if request_body_info.index_type not in [M.IndexType.hnsw, M.IndexType.l2]:
raise Exception(f"index_type is hnsw or l2 (current value = {request_body_info.index_type})")
img_path = imagen_generate_image_path(image_prompt=request_body_info.prompt)
vactor_request_data = {'quary_image_path' : img_path,'index_type' : request_body_info.index_type, 'search_num' : request_body_info.search_num}
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search', data=json.dumps(vactor_request_data))
if vactor_response.status_code != 200:
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
if json.loads(vactor_response.text)["error"] != None:
raise Exception(f"vactor error: {json.loads(vactor_response.text)['error']}")
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
return response.set_error(error=e)

View File

@@ -13,13 +13,13 @@ from fastapi import APIRouter
from starlette.requests import Request
import bcrypt
from rest.app.common import consts
from rest.app import models as M
from rest.app.common.config import conf
from rest.app.database.schema import Users
from rest.app.database.crud import table_select, table_update, table_delete
from main_rest.app.common import consts
from main_rest.app import models as M
from main_rest.app.common.config import conf
from main_rest.app.database.schema import Users
from main_rest.app.database.crud import table_select, table_update, table_delete
from rest.app.utils.extra import AESCryptoCBC
from main_rest.app.utils.extra import AESCryptoCBC
router = APIRouter(prefix='/user')

View File

@@ -26,10 +26,10 @@ from itertools import groupby
from operator import attrgetter
import uuid
from rest.app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT
from rest.app.utils.date_utils import D
from rest.app import models as M
from rest.app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY
from main_rest.app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT
from main_rest.app.utils.date_utils import D
from main_rest.app import models as M
from main_rest.app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY
async def send_mail(sender, sender_pw, title, recipient, contents_plain, contents_html, cc_list, smtp_host=SMTP_HOST, smtp_port=SMTP_PORT):

View File

@@ -12,20 +12,26 @@ cryptography
pycryptodomex
pycryptodome
email-validator
requests
#imagen
google-cloud-aiplatform
Pillow
#bing img
# #bing img
aiohttp
regex
requests
httpx
nest_asyncio
#bing art
bingart==1.10
# #bing art
bingart==1.1.0
#DALL-E 3
# openai
# openai
# faiss
# scikit-learn
# flax
# tensorflow

18
requirements_vactor.txt Normal file
View File

@@ -0,0 +1,18 @@
#rest
fastapi==0.115.0
uvicorn==0.16.0
pymysql==1.0.2
sqlalchemy==1.4.29
bcrypt==3.2.0
pyjwt==2.3.0
yagmail==0.14.261
boto3==1.20.32
pytest==6.2.5
cryptography
pycryptodomex
pycryptodome
email-validator
#faiss
scikit-learn
pillow

6
rest_main.py Normal file
View File

@@ -0,0 +1,6 @@
import uvicorn
from main_rest.app.common.config import conf
if __name__ == '__main__':
uvicorn.run('main_rest.app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)

6
rest_vactor.py Normal file
View File

@@ -0,0 +1,6 @@
import uvicorn
from vactor_rest.app.common.config import conf
if __name__ == '__main__':
uvicorn.run('vactor_rest.app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)

49
vactor_rest/.travis.yml Normal file
View File

@@ -0,0 +1,49 @@
os: linux
dist: bionic
language: python
services:
- docker
- mysql
python:
- 3.8
before_install:
- pip install awscli
- export PATH=$PATH:$HOME/.local/bin
script:
- pytest -v
before_deploy:
- if [ $TRAVIS_BRANCH == "master" ]; then export EB_ENV=notification-prd-api; fi
- if [ $TRAVIS_BRANCH == "develop" ]; then export EB_ENV=notification-dev-api; fi
- export REPO_NAME=$(echo $TRAVIS_REPO_SLUG | sed "s_^.*/__")
- export ELASTIC_BEANSTALK_LABEL=${REPO_NAME}-${TRAVIS_COMMIT::7}-$(date +%y%m%d%H%M%S)
deploy:
skip_cleanup: true
provider: elasticbeanstalk
access_key_id: $AWS_ACCESS
secret_access_key: $AWS_SECRET
region: ap-northeast-2
bucket: elasticbeanstalk-ap-northeast-2-884465654078
bucket_path: notification-api
app: notification-api
env: $EB_ENV
on:
all_branches: true
condition: $TRAVIS_BRANCH =~ ^develop|master
#notifications:
# slack:
# - rooms:
# - secure: ***********
# if: branch = master
# template:
# - "Repo `%{repository_slug}` *%{result}* build (<%{build_url}|#%{build_number}>) for commit (<%{compare_url}|%{commit}>) on branch `%{branch}`."
# - rooms:
# - secure: ***********
# if: branch = staging
# template:
# - "Repo `%{repository_slug}` *%{result}* build (<%{build_url}|#%{build_number}>) for commit (<%{compare_url}|%{commit}>) on branch `%{branch}`."
# - rooms:
# - secure: ***********
# if: branch = develop
# template:
# - "Repo `%{repository_slug}` *%{result}* build (<%{build_url}|#%{build_number}>) for commit (<%{compare_url}|%{commit}>) on branch `%{branch}`."

View File

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
"""
@File: api_request_sample.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: test api key
"""
import base64
import hmac
from datetime import datetime, timedelta
import requests
def parse_params_to_str(params):
url = "?"
for key, value in params.items():
url = url + str(key) + '=' + str(value) + '&'
return url[1:-1]
def hash_string(qs, secret_key):
mac = hmac.new(bytes(secret_key, encoding='utf8'), bytes(qs, encoding='utf-8'), digestmod='sha256')
d = mac.digest()
validating_secret = str(base64.b64encode(d).decode('utf-8'))
return validating_secret
def sample_request():
access_key = 'c0883231-4aa9-4a1f-a77b-3ef250af-e449-42e9-856a-b3ada17c426b'
secret_key = 'QhOaeXTAAkW6yWt31jWDeERkBsZ3X4UmPds656YD'
cur_time = datetime.utcnow()+timedelta(hours=9)
cur_timestamp = int(cur_time.timestamp())
qs = dict(key=access_key, timestamp=cur_timestamp)
header_secret = hash_string(parse_params_to_str(qs), secret_key)
url = f'http://127.0.0.1:8080/api/services?{parse_params_to_str(qs)}'
res = requests.get(url, headers=dict(secret=header_secret))
return res
print(sample_request().json())

View File

@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
"""
@File: config.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: Configurations
"""
from dataclasses import dataclass
from os import path, environ
from vactor_rest.app.common import consts
from vactor_rest.app.models import UserInfo
base_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__))))
@dataclass
class Config:
"""
기본 Configuration
"""
BASE_DIR: str = base_dir
DB_POOL_RECYCLE: int = 900
DB_ECHO: bool = True
DEBUG: bool = False
TEST_MODE: bool = False
DEV_TEST_CONNECT_ACCOUNT: str | None = None
# NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행)
SERVICE_AUTH_API_KEY: bool = False
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}?charset={consts.DB_CHARSET}')
REST_SERVER_PORT = consts.REST_SERVER_PORT
SW_TITLE = consts.SW_TITLE
SW_VERSION = consts.SW_VERSION
SW_DESCRIPTION = consts.SW_DESCRIPTION
TERMS_OF_SERVICE = consts.TERMS_OF_SERVICE
CONTEACT = consts.CONTEACT
LICENSE_INFO = consts.LICENSE_INFO
GLOBAL_TOKEN = consts.ADMIN_INIT_ACCOUNT_INFO.connect_token
COOKIES_AUTH = 'Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MTQsImVtYWlsIjoia29hbGFAZGluZ3JyLmNvbSIsIm5hbWUiOm51bGwsInBob25lX251bWJlciI6bnVsbCwicHJvZmlsZV9pbWciOm51bGwsInNuc190eXBlIjpudWxsfQ.4vgrFvxgH8odoXMvV70BBqyqXOFa2NDQtzYkGywhV48'
@dataclass
class LocalConfig(Config):
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
DEBUG: bool = True
@dataclass
class ProdConfig(Config):
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
@dataclass
class TestConfig(Config):
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}/{consts.DB_NAME}_test?charset={consts.DB_CHARSET}')
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
TEST_MODE: bool = True
@dataclass
class DevConfig(Config):
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
DEBUG: bool = True
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}/{consts.DB_NAME}_dev?charset={consts.DB_CHARSET}')
REST_SERVER_PORT = consts.REST_SERVER_PORT + 1
SW_TITLE = '[Dev] ' + consts.SW_TITLE
@dataclass
class MyConfig(Config):
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
DEBUG: bool = True
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}/{consts.DB_NAME}_my?charset={consts.DB_CHARSET}')
REST_SERVER_PORT = consts.REST_SERVER_PORT + 2
# NOTE(hsj100): DEV_TEST_CONNECT_ACCOUNT
# DEV_TEST_CONNECT_ACCOUNT: UserInfo = UserInfo(**consts.ADMIN_INIT_ACCOUNT_INFO)
# DEV_TEST_CONNECT_ACCOUNT: str = None
# NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행)
SERVICE_AUTH_API_KEY: bool = False
SW_TITLE = '[My] ' + consts.SW_TITLE
def conf():
"""
환경 불러오기
:return:
"""
config = dict(prod=ProdConfig, local=LocalConfig, test=TestConfig, dev=DevConfig, my=MyConfig)
return config[environ.get('API_ENV', 'local')]()
return config[environ.get('API_ENV', 'dev')]()
return config[environ.get('API_ENV', 'my')]()
return config[environ.get('API_ENV', 'test')]()

View File

@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
"""
@File: consts.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: consts
"""
# SUPPORT PROJECT
SUPPORT_PROJECT_BASIC = 'PROJECT_BASIC'
PROJECT_NAME = 'FERMAT-TEST(Vactor REST API)'
SW_TITLE= f'{PROJECT_NAME} - REST API'
SW_VERSION = '0.1.0'
SW_DESCRIPTION = f'''
### FERMAT-TEST(Vactor REST API) REST API
## API 이용법
- 개별 API 설명과 Request/Response schema 참조
'''
TERMS_OF_SERVICE = 'http://www.a2tec.co.kr'
CONTEACT={
'name': 'A2TEC (주)에이투텍',
'url': 'http://www.a2tec.co.kr',
'email': 'marketing@a2tec.co.kr'
}
LICENSE_INFO = {
'name': 'Copyright by A2TEC', 'url': 'http://www.a2tec.co.kr'
}
REST_SERVER_PORT = 51002
DEFAULT_USER_ACCOUNT_PW = '1234'
class AdminInfo:
def __init__(self):
self.id: int = 1
self.user_type: str = 'admin'
self.account: str = 'a2d2_lc_manager@naver.com' # !ekdnfeldpsdptm1 다울디엔에스
self.pw: str = '$2b$12$PklBvVXdLhOQnIiNanlnIu.DJh5MspRARVChJQfFu1qg35vBoIuX2'
self.name: str = 'administrator'
self.email: str = 'a2d2_lc_manager@naver.com' # daool1020
self.email_pw: str = 'gAAAAABioV5NucuS9nQugZJnz-KjVG_FGnaowB9KAfhOoWjjiQ4jGLuYJh4Qe94mT_lCm6m3HhuOJqUeOgjppwREDpIQYzrUXA=='
self.address: str = '대구광역시 동구 동촌로351 에이스빌딩 4F'
self.phone_number: str = '053-384-3010'
self.connect_token: str = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MSwiYWNjb3VudCI6ImEyZDJfbGNfbWFuYWdlckBuYXZlci5jb20iLCJuYW1lIjoiYWRtaW5pc3RyYXRvciIsInBob25lX251bWJlciI6IjA1My0zODQtMzAxMCIsInByb2ZpbGVfaW1nIjpudWxsLCJhY2NvdW50X3R5cGUiOiJlbWFpbCJ9.SlQSCfAof1bv2YxmW2DO4dIBrbHLg1jPO3AJsX6xKbw'
def get_dict(self):
info = {}
for k, v in self.__dict__.items():
if type(v) is tuple:
info[k] = v[0]
else:
info[k] = v
return info
ADMIN_INIT_ACCOUNT_INFO = AdminInfo()
FERNET_SECRET_KEY = b'wQjpSYkmc4kX8MaAovk1NIHF02R2wZX760eeBTeIHW4='
AES_CBC_PUBLIC_KEY = b'daooldns12345678'
AES_CBC_IV = b'daooldns12345678'
COOKIES_AUTH = 'Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MTQsImVtYWlsIjoia29hbGFAZGluZ3JyLmNvbSIsIm5hbWUiOm51bGwsInBob25lX251bWJlciI6bnVsbCwicHJvZmlsZV9pbWciOm51bGwsInNuc190eXBlIjpudWxsfQ.4vgrFvxgH8odoXMvV70BBqyqXOFa2NDQtzYkGywhV48'
JWT_SECRET = 'ABCD1234!'
JWT_ALGORITHM = 'HS256'
EXCEPT_PATH_LIST = ['/', '/openapi.json']
EXCEPT_PATH_REGEX = '^(/docs|/redoc|/api/auth' +\
'|/api/user/check_account_exist' +\
'|/api/services' + \
'|/api/temp' + \
'|/api/dev' + \
'|/static' + \
')'
MAX_API_KEY = 3
MAX_API_WHITELIST = 10
NUM_RETRY_UUID_GEN = 3
# DATABASE
DB_ADDRESS = "localhost"
DB_PORT = 53306
DB_USER_ID = 'root'
DB_USER_PW = '1234'
DB_NAME = 'FM_TEST'
DB_CHARSET = 'utf8mb4'
# MAIL
# SMTP_HOST = 'smtp.gmail.com'
# SMTP_PORT = 587
SMTP_HOST = 'smtp.naver.com'
SMTP_PORT = 587
MAIL_REG_TITLE = f'{PROJECT_NAME} - Registration'
MAIL_REG_CONTENTS = '''
안녕하세요.
감사합니다.
'''

View File

@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
"""
@File: conn.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: DB Connections
"""
from fastapi import FastAPI
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import logging
def _database_exist(engine, schema_name):
query = f'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = "{schema_name}"'
with engine.connect() as conn:
result_proxy = conn.execute(query)
result = result_proxy.scalar()
return bool(result)
def _drop_database(engine, schema_name):
with engine.connect() as conn:
conn.execute(f'DROP DATABASE {schema_name};')
def _create_database(engine, schema_name):
with engine.connect() as conn:
conn.execute(f'CREATE DATABASE {schema_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_bin;')
class SQLAlchemy:
def __init__(self, app: FastAPI = None, **kwargs):
self._engine = None
self._session = None
if app is not None:
self.init_app(app=app, **kwargs)
def init_app(self, app: FastAPI, **kwargs):
"""
DB 초기화 함수
:param app: FastAPI 인스턴스
:param kwargs:
:return:
"""
database_url = kwargs.get('DB_URL')
pool_recycle = kwargs.setdefault('DB_POOL_RECYCLE', 900)
is_testing = kwargs.setdefault('TEST_MODE', False)
echo = kwargs.setdefault('DB_ECHO', True)
self._engine = create_engine(
database_url,
echo=echo,
pool_recycle=pool_recycle,
pool_pre_ping=True,
)
if is_testing: # create schema
db_url = self._engine.url
if db_url.host != 'localhost':
raise Exception('db host must be \'localhost\' in test environment')
except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
schema_name = db_url.database
temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
if _database_exist(temp_engine, schema_name):
_drop_database(temp_engine, schema_name)
_create_database(temp_engine, schema_name)
temp_engine.dispose()
else:
db_url = self._engine.url
except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
schema_name = db_url.database
temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
if not _database_exist(temp_engine, schema_name):
_create_database(temp_engine, schema_name)
Base.metadata.create_all(db.engine)
temp_engine.dispose()
self._session = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
# NOTE(hsj100): ADMINISTRATOR
create_admin(self._session)
@app.on_event('startup')
def startup():
self._engine.connect()
logging.info('DB connected.')
@app.on_event('shutdown')
def shutdown():
self._session.close_all()
self._engine.dispose()
logging.info('DB disconnected')
def get_db(self):
"""
요청마다 DB 세션 유지 함수
:return:
"""
if self._session is None:
raise Exception('must be called \'init_app\'')
db_session = None
try:
db_session = self._session()
yield db_session
finally:
db_session.close()
@property
def session(self):
return self.get_db
@property
def engine(self):
return self._engine
db = SQLAlchemy()
Base = declarative_base()
# NOTE(hsj100): ADMINISTRATOR
def create_admin(db_session):
import bcrypt
from vactor_rest.app.database.schema import Users
from vactor_rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO
session = db_session()
if not session:
raise Exception('cat`t create account of admin')
if Users.get(account=ADMIN_INIT_ACCOUNT_INFO.account):
return
admin = {**ADMIN_INIT_ACCOUNT_INFO.get_dict()}
Users.create(session=session, auto_commit=True, **admin)

View File

@@ -0,0 +1,300 @@
# -*- coding: utf-8 -*-
"""
@File: crud.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: CRUD
"""
import math
from datetime import datetime
from dateutil.relativedelta import relativedelta
from sqlalchemy import func, desc
from fastapi import APIRouter, Depends, Body
from sqlalchemy.orm import Session
from vactor_rest.app import models as M
from vactor_rest.app.database.conn import Base, db
from vactor_rest.app.database.schema import Users, UserLog
from vactor_rest.app.utils.extra import query_to_groupby, query_to_groupby_date
def get_month_info_list(start: datetime, end: datetime):
delta = relativedelta(end, start)
month_delta = 12 * delta.years + delta.months + (1 if delta.days > 0 else 0)
month_list = list()
for i in range(month_delta):
count = start + relativedelta(months=i)
month_info = dict()
month_info['period_title'] = count.strftime('%Y-%m')
month_info['year'] = count.year
month_info['month'] = count.month
month_info['start'] = (start + relativedelta(months=i)).replace(day=1)
month_info['end'] = (start + relativedelta(months=i + 1)).replace(day=1)
month_list.append(month_info)
return month_list
def request_parser(request_data: dict = None) -> dict:
"""
request information -> dict
:param request_data:
:return:
"""
result_dict = dict()
if not request_data:
return result_dict
for key, val in request_data.items():
if val is not None:
result_dict[key] = val if val != 'null' else None
return result_dict
def dict_to_filter_fmt(dict_data, get_attributes_callback=None):
"""
dict -> sqlalchemy filter (criterion: sql expression)
:param dict_data:
:param get_attributes_callback:
:return:
"""
if get_attributes_callback is None:
raise Exception('invalid get_attributes_callback')
criterion = list()
for key, val in dict_data.items():
key = key.split('__')
if len(key) > 2:
raise Exception('length of split(key) should be no more than 2.')
key_length = len(key)
col = get_attributes_callback(key[0])
if col is None:
continue
if key_length == 1:
criterion.append((col == val))
elif key_length == 2 and key[1] == 'gt':
criterion.append((col > val))
elif key_length == 2 and key[1] == 'gte':
criterion.append((col >= val))
elif key_length == 2 and key[1] == 'lt':
criterion.append((col < val))
elif key_length == 2 and key[1] == 'lte':
criterion.append((col <= val))
elif key_length == 2 and key[1] == 'in':
criterion.append((col.in_(val)))
elif key_length == 2 and key[1] == 'like':
criterion.append((col.like(val)))
return criterion
async def table_select(accessor_info, target_table, request_body_info, response_model, response_model_data):
"""
table_read
"""
try:
# parameter
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
# if not request_body_info:
# raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
if not response_model_data:
raise Exception('invalid response_model_data')
# paging - request
paging_request = None
if request_body_info:
if hasattr(request_body_info, 'paging'):
paging_request = request_body_info.paging
request_body_info = request_body_info.search
# request
request_info = request_parser(request_body_info.dict())
# search
criterion = None
if isinstance(request_body_info, M.UserLogDaySearchReq):
# UserLog search
# request
request_info = request_parser(request_body_info.dict())
# search UserLog
def get_attributes_callback(key: str):
return getattr(UserLog, key)
criterion = dict_to_filter_fmt(request_info, get_attributes_callback)
# search
session = next(db.session())
search_info = session.query(UserLog)\
.filter(UserLog.mac.isnot(None)
, UserLog.api == '/api/auth/login'
, UserLog.type == M.UserLogMessageType.info
, UserLog.message == 'ok'
, *criterion)\
.order_by(desc(UserLog.created_at), desc(UserLog.updated_at))\
.all()
if not search_info:
raise Exception('not found data')
group_by_day = query_to_groupby_date(search_info, 'updated_at')
# result
result_info = list()
for day, info_list in group_by_day.items():
info_by_mac = query_to_groupby(info_list, 'mac', first=True)
for log_info in info_by_mac.values():
result_info.append(response_model_data.from_orm(log_info))
else:
# basic search (single table)
# request
request_info = request_parser(request_body_info.dict())
# search
search_info = target_table.filter(**request_info).all()
if not search_info:
raise Exception('not found data')
# result
result_info = list()
for purchase_info in search_info:
result_info.append(response_model_data.from_orm(purchase_info))
# response - paging
paging_response = None
if paging_request:
total_contents_num = len(result_info)
total_page_num = math.ceil(total_contents_num / paging_request.page_contents_num)
start_contents_index = (paging_request.start_page - 1) * paging_request.page_contents_num
end_contents_index = start_contents_index + paging_request.page_contents_num
if end_contents_index < total_contents_num:
result_info = result_info[start_contents_index: end_contents_index]
else:
result_info = result_info[start_contents_index:]
paging_response = M.PagingRes()
paging_response.total_page_num = total_page_num
paging_response.total_contents_num = total_contents_num
paging_response.start_page = paging_request.start_page
paging_response.search_contents_num = len(result_info)
return response_model(data=result_info, paging=paging_response)
except Exception as e:
return response_model.set_error(str(e))
async def table_update(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
search_info = None
try:
# parameter
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
if not request_body_info:
raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
# if not response_model_data:
# raise Exception('invalid response_model_data')
# request
if not request_body_info.search_info:
raise Exception('invalid request_body: search_info')
request_search_info = request_parser(request_body_info.search_info.dict())
if not request_search_info:
raise Exception('invalid request_body: search_info')
if not request_body_info.update_info:
raise Exception('invalid request_body: update_info')
request_update_info = request_parser(request_body_info.update_info.dict())
if not request_update_info:
raise Exception('invalid request_body: update_info')
# search
search_info = target_table.filter(**request_search_info)
# process
search_info.update(auto_commit=True, synchronize_session=False, **request_update_info)
# result
return response_model()
except Exception as e:
if search_info:
search_info.close()
return response_model.set_error(str(e))
async def table_delete(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
search_info = None
try:
# request
if not accessor_info:
raise Exception('invalid accessor')
if not target_table:
raise Exception(f'invalid table_name:{target_table}')
if not request_body_info:
raise Exception('invalid request_body_info')
if not response_model:
raise Exception('invalid response_model')
# if not response_model_data:
# raise Exception('invalid response_model_data')
# request
request_search_info = request_parser(request_body_info.dict())
if not request_search_info:
raise Exception('invalid request_body')
# search
search_info = target_table.filter(**request_search_info)
temp_search = search_info.all()
# process
search_info.delete(auto_commit=True, synchronize_session=False)
# update license num
uuid_list = list()
for _license in temp_search:
if not hasattr(temp_search, 'uuid'):
# case: license
break
if _license.uuid not in uuid_list:
uuid_list.append(_license.uuid)
license_num = target_table.filter(uuid=_license.uuid).count()
target_table.filter(uuid=_license.uuid).update(auto_commit=True, synchronize_session=False, num=license_num)
# result
return response_model()
except Exception as e:
if search_info:
search_info.close()
return response_model.set_error(str(e))

View File

@@ -0,0 +1,304 @@
# -*- coding: utf-8 -*-
"""
@File: schema.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: database schema
"""
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
func,
Enum,
Boolean,
ForeignKey,
)
from sqlalchemy.orm import Session, relationship
from vactor_rest.app.database.conn import Base, db
from vactor_rest.app.utils.date_utils import D
from vactor_rest.app.models import (
SexType,
UserType,
MemberType,
AccountType,
UserStatusType,
UserLoginType,
UserLogMessageType
)
class BaseMixin:
id = Column(Integer, primary_key=True, index=True)
created_at = Column(DateTime, nullable=False, default=D.datetime())
updated_at = Column(DateTime, nullable=False, default=D.datetime(), onupdate=D.datetime())
def __init__(self):
self._q = None
self._session = None
self.served = None
def all_columns(self):
return [c for c in self.__table__.columns if c.primary_key is False and c.name != 'created_at']
def __hash__(self):
return hash(self.id)
@classmethod
def create(cls, session: Session, auto_commit=False, **kwargs):
"""
테이블 데이터 적재 전용 함수
:param session:
:param auto_commit: 자동 커밋 여부
:param kwargs: 적재 할 데이터
:return:
"""
obj = cls()
# NOTE(hsj100) : FIX_DATETIME
if 'created_at' not in kwargs:
obj.created_at = D.datetime()
if 'updated_at' not in kwargs:
obj.updated_at = D.datetime()
for col in obj.all_columns():
col_name = col.name
if col_name in kwargs:
setattr(obj, col_name, kwargs.get(col_name))
session.add(obj)
session.flush()
if auto_commit:
session.commit()
return obj
@classmethod
def get(cls, session: Session = None, **kwargs):
"""
Simply get a Row
:param session:
:param kwargs:
:return:
"""
sess = next(db.session()) if not session else session
query = sess.query(cls)
for key, val in kwargs.items():
col = getattr(cls, key)
query = query.filter(col == val)
if query.count() > 1:
raise Exception('Only one row is supposed to be returned, but got more than one.')
result = query.first()
if not session:
sess.close()
return result
@classmethod
def filter(cls, session: Session = None, **kwargs):
"""
Simply get a Row
:param session:
:param kwargs:
:return:
"""
cond = []
for key, val in kwargs.items():
key = key.split('__')
if len(key) > 2:
raise Exception('length of split(key) should be no more than 2.')
col = getattr(cls, key[0])
if len(key) == 1: cond.append((col == val))
elif len(key) == 2 and key[1] == 'gt': cond.append((col > val))
elif len(key) == 2 and key[1] == 'gte': cond.append((col >= val))
elif len(key) == 2 and key[1] == 'lt': cond.append((col < val))
elif len(key) == 2 and key[1] == 'lte': cond.append((col <= val))
elif len(key) == 2 and key[1] == 'in': cond.append((col.in_(val)))
elif len(key) == 2 and key[1] == 'like': cond.append((col.like(val)))
obj = cls()
if session:
obj._session = session
obj.served = True
else:
obj._session = next(db.session())
obj.served = False
query = obj._session.query(cls)
query = query.filter(*cond)
obj._q = query
return obj
@classmethod
def cls_attr(cls, col_name=None):
if col_name:
col = getattr(cls, col_name)
return col
else:
return cls
def order_by(self, *args: str):
for a in args:
if a.startswith('-'):
col_name = a[1:]
is_asc = False
else:
col_name = a
is_asc = True
col = self.cls_attr(col_name)
self._q = self._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc())
return self
def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
# NOTE(hsj100) : FIX_DATETIME
if 'updated_at' not in kwargs:
kwargs['updated_at'] = D.datetime()
qs = self._q.update(kwargs, synchronize_session=synchronize_session)
get_id = self.id
ret = None
self._session.flush()
if qs > 0 :
ret = self._q.first()
if auto_commit:
self._session.commit()
return ret
def first(self):
result = self._q.first()
self.close()
return result
def delete(self, auto_commit: bool = False, synchronize_session='evaluate'):
self._q.delete(synchronize_session=synchronize_session)
if auto_commit:
self._session.commit()
def all(self):
print(self.served)
result = self._q.all()
self.close()
return result
def count(self):
result = self._q.count()
self.close()
return result
def close(self):
if not self.served:
self._session.close()
else:
self._session.flush()
class ApiKeys(Base, BaseMixin):
__tablename__ = 'api_keys'
__table_args__ = {'extend_existing': True}
access_key = Column(String(length=64), nullable=False, index=True)
secret_key = Column(String(length=64), nullable=False)
user_memo = Column(String(length=40), nullable=True)
status = Column(Enum('active', 'stopped', 'deleted'), default='active')
is_whitelisted = Column(Boolean, default=False)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
whitelist = relationship('ApiWhiteLists', backref='api_keys')
users = relationship('Users', back_populates='keys')
class ApiWhiteLists(Base, BaseMixin):
__tablename__ = 'api_whitelists'
__table_args__ = {'extend_existing': True}
ip_addr = Column(String(length=64), nullable=False)
api_key_id = Column(Integer, ForeignKey('api_keys.id'), nullable=False)
class Users(Base, BaseMixin):
__tablename__ = 'users'
__table_args__ = {'extend_existing': True}
status = Column(Enum(UserStatusType), nullable=False, default=UserStatusType.active)
user_type = Column(Enum(UserType), nullable=False, default=UserType.user)
account_type = Column(Enum(AccountType), nullable=False, default=AccountType.email)
account = Column(String(length=255), nullable=False, unique=True)
pw = Column(String(length=2000), nullable=True)
email = Column(String(length=255), nullable=True)
name = Column(String(length=255), nullable=False)
sex = Column(Enum(SexType), nullable=False, default=SexType.male)
rrn = Column(String(length=255), nullable=True)
address = Column(String(length=2000), nullable=True)
phone_number = Column(String(length=20), nullable=True)
picture = Column(String(length=1000), nullable=True)
marketing_agree = Column(Boolean, nullable=False, default=False)
keys = relationship('ApiKeys', back_populates='users')
# extra
login = Column(Enum(UserLoginType), nullable=False, default=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
member_type = Column(Enum(MemberType), nullable=False, default=MemberType.personal)
# uuid = Column(String(length=36), nullable=True, unique=True)
class UserLog(Base, BaseMixin):
__tablename__ = 'userlog'
__table_args__ = {'extend_existing': True}
account = Column(String(length=255), nullable=False)
mac = Column(String(length=255), nullable=True)
type = Column(Enum(UserLogMessageType), nullable=False, default=UserLogMessageType.info)
api = Column(String(length=511), nullable=False)
message = Column(String(length=5000), nullable=True)
# NOTE(hsj100): 다단계 자동 삭제
# user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'), nullable=False)
# user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
# users = relationship('Users', back_populates='userlog')
# users = relationship('Users', back_populates='userlog', lazy=False)
# ===============================================================================================
class CustomBaseMixin(BaseMixin):
id = Column(Integer, primary_key=True, index=True)
def all_columns(self):
return [c for c in self.__table__.columns if c.primary_key is False]
@classmethod
def create(cls, session: Session, auto_commit=False, **kwargs):
"""
테이블 데이터 적재 전용 함수
:param session:
:param auto_commit: 자동 커밋 여부
:param kwargs: 적재 할 데이터
:return:
"""
obj = cls()
for col in obj.all_columns():
col_name = col.name
if col_name in kwargs:
setattr(obj, col_name, kwargs.get(col_name))
session.add(obj)
session.flush()
if auto_commit:
session.commit()
return obj
def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
qs = self._q.update(kwargs, synchronize_session=synchronize_session)
get_id = self.id
ret = None
self._session.flush()
if qs > 0 :
ret = self._q.first()
if auto_commit:
self._session.commit()
return ret

View File

@@ -0,0 +1,188 @@
from vactor_rest.app.common.consts import MAX_API_KEY, MAX_API_WHITELIST
class StatusCode:
HTTP_500 = 500
HTTP_400 = 400
HTTP_401 = 401
HTTP_403 = 403
HTTP_404 = 404
HTTP_405 = 405
class APIException(Exception):
status_code: int
code: str
msg: str
detail: str
ex: Exception
def __init__(
self,
*,
status_code: int = StatusCode.HTTP_500,
code: str = '000000',
msg: str | None = None,
detail: str | None = None,
ex: Exception = None,
):
self.status_code = status_code
self.code = code
self.msg = msg
self.detail = detail
self.ex = ex
super().__init__(ex)
class NotFoundUserEx(APIException):
def __init__(self, user_id: int = None, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_404,
msg=f'해당 유저를 찾을 수 없습니다.',
detail=f'Not Found User ID : {user_id}',
code=f'{StatusCode.HTTP_400}{"1".zfill(4)}',
ex=ex,
)
class NotAuthorized(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_401,
msg=f'로그인이 필요한 서비스 입니다.',
detail='Authorization Required',
code=f'{StatusCode.HTTP_401}{"1".zfill(4)}',
ex=ex,
)
class TokenExpiredEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'세션이 만료되어 로그아웃 되었습니다.',
detail='Token Expired',
code=f'{StatusCode.HTTP_400}{"1".zfill(4)}',
ex=ex,
)
class TokenDecodeEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'비정상적인 접근입니다.',
detail='Token has been compromised.',
code=f'{StatusCode.HTTP_400}{"2".zfill(4)}',
ex=ex,
)
class NoKeyMatchEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_404,
msg=f'해당 키에 대한 권한이 없거나 해당 키가 없습니다.',
detail='No Keys Matched',
code=f'{StatusCode.HTTP_404}{"3".zfill(4)}',
ex=ex,
)
class MaxKeyCountEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'API 키 생성은 {MAX_API_KEY}개 까지 가능합니다.',
detail='Max Key Count Reached',
code=f'{StatusCode.HTTP_400}{"4".zfill(4)}',
ex=ex,
)
class MaxWLCountEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'화이트리스트 생성은 {MAX_API_WHITELIST}개 까지 가능합니다.',
detail='Max Whitelist Count Reached',
code=f'{StatusCode.HTTP_400}{"5".zfill(4)}',
ex=ex,
)
class InvalidIpEx(APIException):
def __init__(self, ip: str, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'{ip}는 올바른 IP 가 아닙니다.',
detail=f'invalid IP : {ip}',
code=f'{StatusCode.HTTP_400}{"6".zfill(4)}',
ex=ex,
)
class SqlFailureEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_500,
msg=f'이 에러는 서버측 에러 입니다. 자동으로 리포팅 되며, 빠르게 수정하겠습니다.',
detail='Internal Server Error',
code=f'{StatusCode.HTTP_500}{"2".zfill(4)}',
ex=ex,
)
class APIQueryStringEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'쿼리스트링은 key, timestamp 2개만 허용되며, 2개 모두 요청시 제출되어야 합니다.',
detail='Query String Only Accept key and timestamp.',
code=f'{StatusCode.HTTP_400}{"7".zfill(4)}',
ex=ex,
)
class APIHeaderInvalidEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'헤더에 키 해싱된 Secret 이 없거나, 유효하지 않습니다.',
detail='Invalid HMAC secret in Header',
code=f'{StatusCode.HTTP_400}{"8".zfill(4)}',
ex=ex,
)
class APITimestampEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'쿼리스트링에 포함된 타임스탬프는 KST 이며, 현재 시간보다 작아야 하고, 현재시간 - 10초 보다는 커야 합니다.',
detail='timestamp in Query String must be KST, Timestamp must be less than now, and greater than now - 10.',
code=f'{StatusCode.HTTP_400}{"9".zfill(4)}',
ex=ex,
)
class NotFoundAccessKeyEx(APIException):
def __init__(self, api_key: str, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_404,
msg=f'API 키를 찾을 수 없습니다.',
detail=f'Not found such API Access Key : {api_key}',
code=f'{StatusCode.HTTP_404}{"10".zfill(4)}',
ex=ex,
)
class KakaoSendFailureEx(APIException):
def __init__(self, ex: Exception = None):
super().__init__(
status_code=StatusCode.HTTP_400,
msg=f'카카오톡 전송에 실패했습니다.',
detail=f'Failed to send KAKAO MSG.',
code=f'{StatusCode.HTTP_400}{"11".zfill(4)}',
ex=ex,
)

116
vactor_rest/app/main.py Normal file
View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
"""
@File: main.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: Main
"""
from dataclasses import asdict
import uvicorn
from fastapi import FastAPI, Depends
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from vactor_rest.app.common import consts
from vactor_rest.app.database.conn import db
from vactor_rest.app.common.config import conf
from vactor_rest.app.middlewares.token_validator import access_control
from vactor_rest.app.middlewares.trusted_hosts import TrustedHostMiddleware
from vactor_rest.app.routes import dev, index, auth, users, services
from contextlib import asynccontextmanager
from custom_logger.vactor_log import vactor_logger as LOG
API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
@asynccontextmanager
async def lifespan(app: FastAPI):
# When service starts.
LOG.info("REST start")
yield
# When service is stopped.
LOG.info("REST shutdown")
def create_app():
"""
fast_api_app 생성
:return: fast_api_app
"""
configurations = conf()
fast_api_app = FastAPI(
title=configurations.SW_TITLE,
version=configurations.SW_VERSION,
description=configurations.SW_DESCRIPTION,
terms_of_service=configurations.TERMS_OF_SERVICE,
contact=configurations.CONTEACT,
license_info={'name': 'Copyright by A2TEC', 'url': 'http://www.a2tec.co.kr'},
lifespan=lifespan
)
# 데이터 베이스 이니셜라이즈
conf_dict = asdict(configurations)
db.init_app(fast_api_app, **conf_dict)
# 레디스 이니셜라이즈
# 미들웨어 정의
fast_api_app.add_middleware(middleware_class=BaseHTTPMiddleware, dispatch=access_control)
fast_api_app.add_middleware(
CORSMiddleware,
allow_origins=conf().ALLOW_SITE,
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
fast_api_app.add_middleware(TrustedHostMiddleware, allowed_hosts=conf().TRUSTED_HOSTS, except_path=['/health'])
# 라우터 정의
fast_api_app.include_router(index.router, tags=['Defaults'])
# if conf().DEBUG:
# fast_api_app.include_router(auth.router, tags=['Authentication'], prefix='/api')
fast_api_app.include_router(services.router, tags=['Services'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
# if conf().DEBUG:
# fast_api_app.include_router(users.router, tags=['Users'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
# fast_api_app.include_router(dev.router, tags=['Developments'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
import os
# fast_api_app.mount('/static', StaticFiles(directory=os.path.abspath('./rest/app/static')), name="static")
return fast_api_app
app = create_app()
#TODO(jwkim): 422 error handler
# @app.exception_handler(RequestValidationError)
# async def validation_exception_handler(request, exc):
# err = exc.errors()
# return JSONResponse(
# status_code=422,
# content={
# "result": False,
# "error": f"VALIDATION ERROR : {err}",
# "data": None},
# )
if __name__ == '__main__':
uvicorn.run('main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)

View File

@@ -0,0 +1,166 @@
import base64
import hmac
import json
import time
import typing
import re
import jwt
import sqlalchemy.exc
from jwt.exceptions import ExpiredSignatureError, DecodeError
from starlette.requests import Request
from starlette.responses import JSONResponse
from vactor_rest.app.common.consts import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX
from vactor_rest.app.database.conn import db
from vactor_rest.app.database.schema import Users, ApiKeys
from vactor_rest.app.errors import exceptions as ex
from vactor_rest.app.common import consts
from vactor_rest.app.common.config import conf
from vactor_rest.app.errors.exceptions import APIException, SqlFailureEx, APIQueryStringEx
from vactor_rest.app.models import UserToken
from vactor_rest.app.utils.date_utils import D
from vactor_rest.app.utils.logger import api_logger
from vactor_rest.app.utils.query_utils import to_dict
from dataclasses import asdict
async def access_control(request: Request, call_next):
request.state.req_time = D.datetime()
request.state.start = time.time()
request.state.inspect = None
request.state.user = None
request.state.service = None
ip = request.headers['x-forwarded-for'] if 'x-forwarded-for' in request.headers.keys() else request.client.host
request.state.ip = ip.split(',')[0] if ',' in ip else ip
headers = request.headers
cookies = request.cookies
url = request.url.path
if await url_pattern_check(url, EXCEPT_PATH_REGEX) or url in EXCEPT_PATH_LIST:
response = await call_next(request)
if url != '/':
await api_logger(request=request, response=response)
return response
try:
if url.startswith('/api'):
# api 인경우 헤더로 토큰 검사
# NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행)
if url.startswith('/api/services') and conf().SERVICE_AUTH_API_KEY:
qs = str(request.query_params)
qs_list = qs.split('&')
session = next(db.session())
if not conf().DEBUG:
try:
qs_dict = {qs_split.split('=')[0]: qs_split.split('=')[1] for qs_split in qs_list}
except Exception:
raise ex.APIQueryStringEx()
qs_keys = qs_dict.keys()
if 'key' not in qs_keys or 'timestamp' not in qs_keys:
raise ex.APIQueryStringEx()
if 'secret' not in headers.keys():
raise ex.APIHeaderInvalidEx()
api_key = ApiKeys.get(session=session, access_key=qs_dict['key'])
if not api_key:
raise ex.NotFoundAccessKeyEx(api_key=qs_dict['key'])
mac = hmac.new(bytes(api_key.secret_key, encoding='utf8'), bytes(qs, encoding='utf-8'), digestmod='sha256')
d = mac.digest()
validating_secret = str(base64.b64encode(d).decode('utf-8'))
if headers['secret'] != validating_secret:
raise ex.APIHeaderInvalidEx()
now_timestamp = int(D.datetime(diff=9).timestamp())
if now_timestamp - 10 > int(qs_dict['timestamp']) or now_timestamp < int(qs_dict['timestamp']):
raise ex.APITimestampEx()
user_info = to_dict(api_key.users)
request.state.user = UserToken(**user_info)
else:
# Request User 가 필요함
if 'authorization' in headers.keys():
key = headers.get('Authorization')
api_key_obj = ApiKeys.get(session=session, access_key=key)
user_info = to_dict(Users.get(session=session, id=api_key_obj.user_id))
request.state.user = UserToken(**user_info)
# 토큰 없음
else:
if 'Authorization' not in headers.keys():
raise ex.NotAuthorized()
session.close()
response = await call_next(request)
return response
else:
if 'authorization' in headers.keys():
# 토큰 존재
token_info = await token_decode(access_token=headers.get('Authorization'))
request.state.user = UserToken(**token_info)
elif conf().DEV_TEST_CONNECT_ACCOUNT:
# NOTE(hsj100): DEV_TEST_CONNECT_ACCOUNT
request.state.user = UserToken.from_orm(conf().DEV_TEST_CONNECT_ACCOUNT)
else:
# 토큰 없음
if 'Authorization' not in headers.keys():
raise ex.NotAuthorized()
else:
# 템플릿 렌더링인 경우 쿠키에서 토큰 검사
cookies['Authorization'] = conf().COOKIES_AUTH
if 'Authorization' not in cookies.keys():
raise ex.NotAuthorized()
token_info = await token_decode(access_token=cookies.get('Authorization'))
request.state.user = UserToken(**token_info)
response = await call_next(request)
await api_logger(request=request, response=response)
except Exception as e:
error = await exception_handler(e)
error_dict = dict(status=error.status_code, msg=error.msg, detail=error.detail, code=error.code)
response = JSONResponse(status_code=error.status_code, content=error_dict)
await api_logger(request=request, error=error)
return response
async def url_pattern_check(path, pattern):
result = re.match(pattern, path)
if result:
return True
return False
async def token_decode(access_token):
"""
:param access_token:
:return:
"""
try:
access_token = access_token.replace('Bearer ', "")
payload = jwt.decode(access_token, key=consts.JWT_SECRET, algorithms=[consts.JWT_ALGORITHM])
except ExpiredSignatureError:
raise ex.TokenExpiredEx()
except DecodeError:
raise ex.TokenDecodeEx()
return payload
async def exception_handler(error: Exception):
print(error)
if isinstance(error, sqlalchemy.exc.OperationalError):
error = SqlFailureEx(ex=error)
if not isinstance(error, APIException):
error = APIException(ex=error, detail=str(error))
return error

View File

@@ -0,0 +1,63 @@
import typing
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
ENFORCE_DOMAIN_WILDCARD = 'Domain wildcard patterns must be lNo module named \'app\'ike \'*.example.com\'.'
class TrustedHostMiddleware:
def __init__(
self,
app: ASGIApp,
allowed_hosts: typing.Sequence[str] = None,
except_path: typing.Sequence[str] = None,
www_redirect: bool = True,
) -> None:
if allowed_hosts is None:
allowed_hosts = ['*']
if except_path is None:
except_path = []
for pattern in allowed_hosts:
assert '*' not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
if pattern.startswith('*') and pattern != '*':
assert pattern.startswith('*.'), ENFORCE_DOMAIN_WILDCARD
self.app = app
self.allowed_hosts = list(allowed_hosts)
self.allow_any = '*' in allowed_hosts
self.www_redirect = www_redirect
self.except_path = list(except_path)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.allow_any or scope['type'] not in ('http', 'websocket',): # pragma: no cover
await self.app(scope, receive, send)
return
headers = Headers(scope=scope)
host = headers.get('host', "").split(':')[0]
is_valid_host = False
found_www_redirect = False
for pattern in self.allowed_hosts:
if (
host == pattern
or (pattern.startswith('*') and host.endswith(pattern[1:]))
or URL(scope=scope).path in self.except_path
):
is_valid_host = True
break
elif 'www.' + host == pattern:
found_www_redirect = True
if is_valid_host:
await self.app(scope, receive, send)
else:
if found_www_redirect and self.www_redirect:
url = URL(scope=scope)
redirect_url = url.replace(netloc='www.' + url.netloc)
response = RedirectResponse(url=str(redirect_url)) # type: Response
else:
response = PlainTextResponse('Invalid host header', status_code=400)
await response(scope, receive, send)

581
vactor_rest/app/models.py Normal file
View File

@@ -0,0 +1,581 @@
# -*- coding: utf-8 -*-
"""
@File: models.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: data models
"""
from datetime import datetime
from enum import Enum
from typing import List
from pydantic import TypeAdapter
from pydantic import Field, ConfigDict
from pydantic.main import BaseModel
from pydantic.networks import EmailStr, IPvAnyAddress
from typing import Optional
from vactor_rest.app.common.consts import (
SW_TITLE,
SW_VERSION,
MAIL_REG_TITLE,
MAIL_REG_CONTENTS,
SMTP_HOST,
SMTP_PORT,
ADMIN_INIT_ACCOUNT_INFO,
DEFAULT_USER_ACCOUNT_PW
)
from vactor_rest.app.utils.date_utils import D
class SWInfo(BaseModel):
"""
### 서비스 정보
"""
name: str = Field(SW_TITLE, description='SW 이름', example=SW_VERSION)
version: str = Field(SW_VERSION, description='SW 버전', example=SW_VERSION)
date: str = Field(D.date_str(), description='현재날짜', example='%Y.%m.%dT%H:%M:%S')
data1: str = Field(None, description='TestData1', example=None)
data2: str = Field(None, description='TestData2', example=None)
data3: str = Field(None, description='TestData3', example=None)
data4: str = Field(None, description='TestData4', example=None)
data5: str = Field(None, description='TestData1', example=None)
data6: str = Field(None, description='TestData2', example=None)
data7: str = Field(None, description='TestData3', example=None)
data8: str = Field(None, description='TestData4', example=None)
class CustomEnum(Enum):
@classmethod
def get_elements_str(cls, is_key=True):
if is_key:
result_str = cls.__members__.keys()
else:
result_str = cls.__members__.values()
return '[' + ', '.join(result_str) + ']'
class SexType(str, CustomEnum):
male: str = 'male'
female: str = 'female'
class AccountType(str, CustomEnum):
email: str = 'email'
# facebook: str = 'facebook'
# google: str = 'google'
# kakao: str = 'kakao'
class UserStatusType(str, CustomEnum):
active: str = 'active'
deleted: str = 'deleted'
blocked: str = 'blocked'
class UserLoginType(str, CustomEnum):
login: str = 'login'
logout: str = 'logout'
class MemberType(str, CustomEnum):
personal: str = 'personal'
company: str = 'company'
class UserType(str, CustomEnum):
admin: str = 'admin'
user: str = 'user'
class UserLogMessageType(str, CustomEnum):
info: str = 'info'
error: str = 'error'
class Token(BaseModel):
Authorization: str = Field(None, description='인증키', example='Bearer [token]')
class EmailRecipients(BaseModel):
name: str
email: str
class SendEmail(BaseModel):
email_to: List[EmailRecipients] = None
class KakaoMsgBody(BaseModel):
msg: str | None = None
class MessageOk(BaseModel):
message: str = Field(default='OK')
class UserToken(BaseModel):
id: int
account: str | None = None
name: str | None = None
phone_number: str | None = None
profile_img: str | None = None
account_type: str | None = None
# model_config = ConfigDict(from_attributes=True)
class AddApiKey(BaseModel):
user_memo: str | None = None
model_config = ConfigDict(from_attributes=True)
class GetApiKeyList(AddApiKey):
id: int = None
access_key: str | None = None
created_at: datetime = None
class GetApiKeys(GetApiKeyList):
secret_key: str | None = None
class CreateAPIWhiteLists(BaseModel):
ip_addr: str | None = None
class GetAPIWhiteLists(CreateAPIWhiteLists):
id: int
model_config = ConfigDict(from_attributes=True)
class ResponseBase(BaseModel):
"""
### [Response] API End-Point
**정상처리**\n
- result: true\n
- error: null\n
**오류발생**\n
- result: false\n
- error: 오류내용\n
"""
result: bool = Field(True, description='처리상태(성공: true, 실패: false)', example=True)
error: str | None = Field(None, description='오류내용(성공: null, 실패: 오류내용)', example=None)
@staticmethod
def set_error(error):
ResponseBase.result = False
ResponseBase.error = str(error)
return ResponseBase
@staticmethod
def set_message():
ResponseBase.result = True
ResponseBase.error = None
return ResponseBase
model_config = ConfigDict(from_attributes=True)
class PagingReq(BaseModel):
"""
### [Request] 페이징 정보
"""
start_page: int = Field(None, description='시작 페이지 번호(base: 1)', example=1)
page_contents_num: int = Field(None, description='페이지 내용 개수', example=2)
model_config = ConfigDict(from_attributes=True)
class PagingRes(BaseModel):
"""
### [Response] 페이징 정보
"""
total_page_num: int = Field(None, description='전체 페이지 개수', example=100)
total_contents_num: int = Field(None, description='전체 내용 개수', example=100)
start_page: int = Field(None, description='시작 페이지 번호(base: 1)', example=1)
search_contents_num: int = Field(None, description='검색된 내용 개수', example=100)
model_config = ConfigDict(from_attributes=True)
class TokenRes(ResponseBase):
"""
### [Response] 토큰 정보
"""
Authorization: str = Field(None, description='인증키', example='Bearer [token]')
model_config = ConfigDict(from_attributes=True)
class UserLogInfo(BaseModel):
"""
### 사용자 로그 정보
"""
id: int = Field(None, description='Table Index', example='1')
created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
account: str = Field(None, description='계정', example='user1@test.com')
mac: str = Field(None, description='MAC(네트워크 인터페이스 식별자)', example='11:22:33:44:55:66')
type: str = Field(None, description='로그 타입' + UserLogMessageType.get_elements_str(), example=UserLogMessageType.info)
api: str = Field(None, description='API 이름', example='/api/auth/login')
message: str = Field(None, description='로그 내용', example='ok')
model_config = ConfigDict(from_attributes=True)
class UserInfo(BaseModel):
"""
### 유저 정보
"""
id: int = Field(None, description='Table Index', example='1')
created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
status: UserStatusType = Field(None, description='계정상태' + UserStatusType.get_elements_str(), example=UserStatusType.active)
user_type: UserType = Field(None, description='유저 타입' + UserType.get_elements_str(), example=UserType.user)
account_type: AccountType = Field(None, description='계정종류' + AccountType.get_elements_str(), example=AccountType.email)
account: str = Field(None, description='계정', example='user1@test.com')
email: str = Field(None, description='전자메일', example='user1@test.com')
name: str = Field(None, description='이름', example='user1')
sex: str = Field(None, description='성별' + SexType.get_elements_str(), example=SexType.male)
rrn: str = Field(None, description='주민등록번호', example='123456-1234567')
address: str = Field(None, description='주소', example='대구1')
phone_number: str = Field(None, description='연락처', example='010-1234-1234')
picture: str = Field(None, description='프로필사진', example='profile1.png')
marketing_agree: bool = Field(False, description='마케팅동의 여부', example=False)
# extra
login: UserLoginType = Field(None, description='로그인 상태' + UserLoginType.get_elements_str(), example=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
member_type: MemberType = Field(None, description='회원 타입' + MemberType.get_elements_str(), example=MemberType.personal)
model_config = ConfigDict(from_attributes=True)
class SendMailReq(BaseModel):
"""
### [Request] 메일 전송
"""
smtp_host: str = Field(None, description='SMTP 서버 주소', example=SMTP_HOST)
smtp_port: int = Field(None, description='SMTP 서버 포트', example=SMTP_PORT)
title: str = Field(None, description='제목', example=MAIL_REG_TITLE)
recipient: str = Field(None, description='수신자', example='user1@test.com')
cc_list: list = Field(None, description='참조 리스트', example='["user2@test.com", "user3@test.com"]')
# recipient: str = Field(None, description='수신자', example='hsj100@a2tec.co.kr')
# cc_list: list = Field(None, description='참조 리스트', example=None)
contents_plain: str = Field(None, description='내용', example=MAIL_REG_CONTENTS.format('user1@test.com', 10))
contents_html: str = Field(None, description='내용', example='<p>내 고양이는 <strong>아주 고약해.</p></strong>')
model_config = ConfigDict(from_attributes=True)
class UserInfoRes(ResponseBase):
"""
### [Response] 유저 정보
"""
data: UserInfo = None
model_config = ConfigDict(from_attributes=True)
class UserLoginReq(BaseModel):
"""
### 유저 로그인 정보
"""
account: str = Field(description='계정', example='user1@test.com')
pw: str = Field(None, description='비밀번호 [관리자 필수]', example='1234')
# SW-LICENSE
sw: str = Field(None, description='라이선스 SW 이름 [유저 필수]', example='저작도구')
mac: str = Field(None, description='MAC [유저 필수]', example='11:22:33:44:55:01')
class UserRegisterReq(BaseModel):
"""
### [Request] 유저 등록
"""
status: Optional[str] = Field(UserStatusType.active, description='계정상태' + UserStatusType.get_elements_str(), example=UserStatusType.active)
user_type: Optional[UserType] = Field(UserType.user, description='유저 타입' + UserType.get_elements_str(), example=UserType.user)
account: str = Field(description='계정', example='test@test.com')
pw: Optional[str] = Field(DEFAULT_USER_ACCOUNT_PW, description='비밀번호', example='1234')
email: Optional[str] = Field(None, description='전자메일', example='test@test.com')
name: str = Field(description='이름', example='test')
sex: Optional[SexType] = Field(SexType.male, description='성별' + SexType.get_elements_str(), example=SexType.male)
rrn: Optional[str] = Field(None, description='주민등록번호', example='19910101-1234567')
address: Optional[str] = Field(None, description='주소', example='대구광역시 동구 동촌로 351, 4층 (용계동, 에이스빌딩)')
phone_number: Optional[str] = Field(None, description='휴대전화', example='010-1234-1234')
picture: Optional[str] = Field(None, description='사진', example='profile1.png')
marketing_agree: Optional[bool] = Field(False, description='마케팅동의 여부', example=False)
# extra
member_type: Optional[MemberType] = Field(MemberType.personal, description='회원 타입' + MemberType.get_elements_str(), example=MemberType.personal)
# relationship:license
license_sw: str = Field(description='라이선스 SW 이름', example='저작도구')
license_start: datetime = Field(description='라이선스 시작날짜', example='2022-02-10T15:00:00')
license_end: datetime = Field(description='라이선스 시작날짜', example='2023-02-10T15:00:00')
license_num: Optional[int] = Field(1, description='계약된 라이선스 개수', example=1)
license_manager_name: Optional[str] = Field(ADMIN_INIT_ACCOUNT_INFO.name, description='담당자 이름', example=ADMIN_INIT_ACCOUNT_INFO.name)
license_manager_phone: Optional[str] = Field(ADMIN_INIT_ACCOUNT_INFO.phone_number, description='담당자 연락처', example=ADMIN_INIT_ACCOUNT_INFO.phone_number)
model_config = ConfigDict(from_attributes=True)
class UserSearchReq(BaseModel):
"""
### [Request] 유저 검색 (기본)
"""
# basic
id: Optional[int] = Field(None, description='등록번호', example='1')
id__in: Optional[list] = Field(None, description='등록번호 리스트', example=[1,])
created_at: Optional[str] = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
created_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
created_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
created_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
created_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
status: Optional[UserStatusType] = Field(None, description='계정상태' + UserStatusType.get_elements_str(), example=UserStatusType.active)
user_type: Optional[UserType] = Field(None, description='유저 타입' + UserType.get_elements_str(), example=UserType.user)
account_type: Optional[AccountType] = Field(None, description='계정종류' + AccountType.get_elements_str(), example=AccountType.email)
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
email: Optional[str] = Field(None, description='전자메일', example='user1@test.com')
email__like: Optional[str] = Field(None, description='전자메일 부분검색', example='%user%')
name: Optional[str] = Field(None, description='이름', example='test')
name__like: Optional[str] = Field(None, description='이름 부분검색', example='%user%')
sex: Optional[SexType] = Field(None, description='성별' + SexType.get_elements_str(), example=SexType.male)
rrn: Optional[str] = Field(None, description='주민등록번호', example='123456-1234567')
address: Optional[str] = Field(None, description='주소', example='대구1')
address__like: Optional[str] = Field(None, description='주소 부분검색', example='%대구%')
phone_number: Optional[str] = Field(None, description='연락처', example='010-1234-1234')
picture: Optional[str] = Field(None, description='프로필사진', example='profile1.png')
marketing_agree: Optional[bool] = Field(None, description='마케팅동의 여부', example=False)
# extra
login: Optional[UserLoginType] = Field(None, description='로그인 상태' + UserLoginType.get_elements_str(), example=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
member_type: Optional[MemberType] = Field(None, description='회원 타입' + MemberType.get_elements_str(), example=MemberType.personal)
model_config = ConfigDict(from_attributes=True)
class UserSearchRes(ResponseBase):
"""
### [Response] 유저 검색
"""
data: List[UserInfo] = []
model_config = ConfigDict(from_attributes=True)
class UserSearchPagingReq(BaseModel):
"""
### [Request] 유저 페이징 검색
"""
# paging
paging: Optional[PagingReq] = None
search: Optional[UserSearchReq] = None
class UserSearchPagingRes(ResponseBase):
"""
### [Response] 유저 페이징 검색
"""
paging: PagingRes = None
data: List[UserInfo] = []
model_config = ConfigDict(from_attributes=True)
class UserUpdateReq(BaseModel):
"""
### [Request] 유저 변경
"""
status: Optional[UserStatusType] = Field(None, description='계정상태' + UserStatusType.get_elements_str(), example=UserStatusType.active)
user_type: Optional[UserType] = Field(None, description='유저 타입' + UserType.get_elements_str(), example=UserType.user)
account_type: Optional[AccountType] = Field(None, description='계정종류' + AccountType.get_elements_str(), example=AccountType.email)
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
email: Optional[str] = Field(None, description='전자메일', example='user1@test.com')
name: Optional[str] = Field(None, description='이름', example='test')
sex: Optional[SexType] = Field(None, description='성별' + SexType.get_elements_str(), example=SexType.male)
rrn: Optional[str] = Field(None, description='주민등록번호', example='123456-1234567')
address: Optional[str] = Field(None, description='주소', example='대구1')
phone_number: Optional[str] = Field(None, description='연락처', example='010-1234-1234')
picture: Optional[str] = Field(None, description='프로필사진', example='profile1.png')
marketing_agree: Optional[bool] = Field(False, description='마케팅동의 여부', example=False)
# extra
login: Optional[UserLoginType] = Field(None, description='로그인 상태' + UserLoginType.get_elements_str(), example=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
# uuid: Optional[str] = Field(None, description='UUID', example='12345678-1234-5678-1234-567800000001')
member_type: Optional[MemberType] = Field(None, description='회원 타입' + MemberType.get_elements_str(), example=MemberType.personal)
model_config = ConfigDict(from_attributes=True)
class UserUpdateMultiReq(BaseModel):
"""
### [Request] 유저 변경 (multi)
"""
search_info: UserSearchReq = None
update_info: UserUpdateReq = None
model_config = ConfigDict(from_attributes=True)
class UserUpdatePWReq(BaseModel):
"""
### [Request] 유저 비밀번호 변경
"""
account: str = Field(None, description='계정', example='user1@test.com')
current_pw: str = Field(None, description='현재 비밀번호', example='1234')
new_pw: str = Field(None, description='신규 비밀번호', example='5678')
class UserLogSearchReq(BaseModel):
"""
### [Request] 유저로그 검색
"""
# basic
id: Optional[int] = Field(None, description='등록번호', example='1')
id__in: Optional[list] = Field(None, description='등록번호 리스트', example=[1,])
created_at: Optional[str] = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
created_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
created_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
created_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
created_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
mac__like: Optional[str] = Field(None, description='MAC 부분검색', example='%33%')
type: Optional[UserLogMessageType] = Field(None, description='유저로그 메시지 타입' + UserLogMessageType.get_elements_str(), example=UserLogMessageType.error)
api: Optional[str] = Field(None, description='API 이름', example='/api/auth/login')
api__like: Optional[str] = Field(None, description='API 이름 부분검색', example='%login%')
message: Optional[str] = Field(None, description='로그내용', example='invalid password')
message__like: Optional[str] = Field(None, description='로그내용 부분검색', example='%invalid%')
model_config = ConfigDict(from_attributes=True)
class UserLogDayInfo(BaseModel):
"""
### 유저로그 일별 접속 정보 (출석)
"""
account: str = Field(None, description='계정', example='user1@test.com')
mac: str = Field(None, description='MAC(네트워크 인터페이스 식별자)', example='11:22:33:44:55:66')
updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
message: str = Field(None, description='로그 내용', example='ok')
model_config = ConfigDict(from_attributes=True)
class UserLogDaySearchReq(BaseModel):
"""
### [Request] 유저로그 일별 마지막 접속 검색
"""
updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
mac__like: Optional[str] = Field(None, description='MAC 부분검색', example='%33%')
model_config = ConfigDict(from_attributes=True)
class UserLogDaySearchPagingReq(BaseModel):
"""
### [Request] 유저로그 일별 마지막 접속 페이징 검색
"""
# paging
paging: Optional[PagingReq] = None
search: Optional[UserLogDaySearchReq] = None
class UserLogDaySearchPagingRes(ResponseBase):
"""
### [Response] 유저로그 일별 마지막 접속 검색
"""
paging: PagingRes = None
data: List[UserLogDayInfo] = []
model_config = ConfigDict(from_attributes=True)
class UserLogPagingReq(BaseModel):
"""
### [Request] 유저로그 페이징 검색
"""
# paging
paging: Optional[PagingReq] = None
search: Optional[UserLogSearchReq] = None
class UserLogPagingRes(ResponseBase):
"""
### [Response] 유저로그 페이징 검색
"""
paging: PagingRes = None
data: List[UserLogInfo] = []
model_config = ConfigDict(from_attributes=True)
class UserLogUpdateReq(BaseModel):
"""
### [Request] 유저로그 변경
"""
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
type: Optional[UserLogMessageType] = Field(None, description='유저로그 메시지 타입' + UserLogMessageType.get_elements_str(), example=UserLogMessageType.error)
api: Optional[str] = Field(None, description='API 이름', example='/api/auth/login')
message: Optional[str] = Field(None, description='로그내용', example='invalid password')
model_config = ConfigDict(from_attributes=True)
class UserLogUpdateMultiReq(BaseModel):
"""
### [Request] 유저로그 변경 (multi)
"""
search_info: UserLogSearchReq = None
update_info: UserLogUpdateReq = None
model_config = ConfigDict(from_attributes=True)
#===============================================================================
#===============================================================================
#===============================================================================
#===============================================================================
class IndexType(str, Enum):
hnsw = "hnsw"
l2 = "l2"
class VactorSearchReq(BaseModel):
"""
### [Request] vactor 검색
"""
quary_image_path : str = Field(description='quary image', example='path')
index_type : IndexType = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)

View File

@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
"""
@File: auth.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: authentication api
"""
from itertools import groupby
from operator import attrgetter
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
import bcrypt
import jwt
from datetime import datetime, timedelta
from vactor_rest.app.common import consts
from vactor_rest.app import models as M
from vactor_rest.app.database.conn import db
from vactor_rest.app.common.config import conf
from vactor_rest.app.database.schema import Users, UserLog
from vactor_rest.app.utils.extra import query_to_groupby, AESCryptoCBC
from vactor_rest.app.utils.date_utils import D
router = APIRouter(prefix='/auth')
@router.get('/find-account/{account}', response_model=M.ResponseBase, summary='계정유무 검사')
async def find_account(account: str):
"""
## 계정유무 검사
주어진 계정이 존재하면 true, 없으면 false 처리
**결과**
- ResponseBase
"""
try:
search_info = Users.get(account=account)
if not search_info:
raise Exception(f'not found data: {account}')
return M.ResponseBase()
except Exception as e:
return M.ResponseBase.set_error(str(e))
@router.post('/logout/{account}', status_code=200, response_model=M.TokenRes, summary='사용자 접속종료')
async def logout(account: str):
"""
## 사용자 접속종료
현재 버전에서는 로그인/로그아웃의 상태를 유지하지 않고 상태값만을 서버에서 사용하기 때문에,\n
***로그상태는 실제상황과 다를 수 있다.***
정상처리시 Authorization(null) 반환
**결과**
- TokenRes
"""
user_info = None
try:
# TODO(hsj100): LOGIN_STATUS
user_info = Users.filter(account=account)
if not user_info:
raise Exception('not found user')
user_info.update(auto_commit=True, login='logout')
return M.TokenRes()
except Exception as e:
if user_info:
user_info.close()
return M.ResponseBase.set_error(e)
async def is_account_exist(account: str):
get_account = Users.get(account=account)
return True if get_account else False
def create_access_token(*, data: dict = None, expires_delta: int = None):
if conf().GLOBAL_TOKEN:
return conf().GLOBAL_TOKEN
to_encode = data.copy()
if expires_delta:
to_encode.update({'exp': datetime.utcnow() + timedelta(hours=expires_delta)})
encoded_jwt = jwt.encode(to_encode, consts.JWT_SECRET, algorithm=consts.JWT_ALGORITHM)
return encoded_jwt

View File

@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
"""
@File: dev.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: Developments Test
"""
import struct
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
import bcrypt
from starlette.requests import Request
from vactor_rest.app.common import consts
from vactor_rest.app import models as M
from vactor_rest.app.database.conn import db, Base
from vactor_rest.app.database.schema import Users, UserLog
from vactor_rest.app.utils.extra import FernetCrypto, AESCryptoCBC, AESCipher
from custom_logger.vactor_log import vactor_logger as LOG
# mail test
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
def send_mail():
"""
구글 계정사용시 : 보안 수준이 낮은 앱에서의 접근 활성화
:return:
"""
sender = 'jolimola@gmail.com'
sender_pw = '!ghkdtmdwns1'
# recipient = 'hsj100@a2tec.co.kr'
recipient = 'jwkim@daooldns.co.kr'
list_cc = ['cc1@gmail.com', 'cc2@naver.com']
str_cc = ','.join(list_cc)
title = 'Test mail'
contents = '''
This is test mail
using smtplib.
'''
smtp_server = smtplib.SMTP( # 1
host='smtp.gmail.com',
port=587
)
smtp_server.ehlo() # 2
smtp_server.starttls() # 2
smtp_server.ehlo() # 2
smtp_server.login(sender, sender_pw) # 3
msg = MIMEMultipart() # 4
msg['From'] = sender # 5
msg['To'] = recipient # 5
# msg['Cc'] = str_cc # 5
msg['Subject'] = contents # 5
msg.attach(MIMEText(contents, 'plain')) # 6
smtp_server.send_message(msg) # 7
smtp_server.quit() # 8
router = APIRouter(prefix='/dev')
@router.get('/test', summary='테스트', response_model=M.SWInfo)
async def test(request: Request):
"""
## ELB 상태 체크용 API
**결과**
- SWInfo
"""
a = M.SWInfo()
a.name = '!ekdnfeldpsdptm1'
# a.name = 'testtesttest123'
simpleEnDecrypt = FernetCrypto()
a.data1 = simpleEnDecrypt.encrypt(a.name)
a.data2 = bcrypt.hashpw(a.name.encode('utf-8'), bcrypt.gensalt())
t = bytes(a.name.encode('utf-8'))
enc = AESCryptoCBC().encrypt(t)
dec = AESCryptoCBC().decrypt(enc)
t = enc.decode('utf-8')
# enc = AESCipher('daooldns12345678').encrypt(a.name).decode('utf-8')
# enc = 'E563ZFt+yJL8YY5yYlYyk602MSscPP2SCCD8UtXXpMI='
# dec = AESCipher('daooldns12345678').decrypt(enc).decode('utf-8')
a.data3 = f'enc: {enc}, {t}'
a.data4 = f'dec: {dec}'
a.name = '!ekdnfeldpsdptm1'
# simpleEnDecrypt = SimpleEnDecrypt()
a.data5 = simpleEnDecrypt.encrypt(a.name)
a.data6 = bcrypt.hashpw(a.name.encode('utf-8'), bcrypt.gensalt())
key = consts.ADMIN_INIT_ACCOUNT_INFO['aes_cbc_key']
t = bytes(a.name.encode('utf-8'))
enc = AESCryptoCBC(key).encrypt(t)
dec = AESCryptoCBC(key).decrypt(enc)
t = enc.decode('utf-8')
# enc = AESCipher('daooldns12345678').encrypt(a.name).decode('utf-8')
# enc = 'E563ZFt+yJL8YY5yYlYyk602MSscPP2SCCD8UtXXpMI='
# dec = AESCipher('daooldns12345678').decrypt(enc).decode('utf-8')
print(f'key: {key}')
a.data7 = f'enc: {enc}, {t}'
a.data8 = f'dec: {dec}'
eee = "gAAAAABioV5NucuS9nQugZJnz-KjVG_FGnaowB9KAfhOoWjjiQ4jGLuYJh4Qe94mT_lCm6m3HhuOJqUeOgjppwREDpIQYzrUXA=="
a.data8 = simpleEnDecrypt.decrypt(eee)
return a

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""
@File: index.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: basic & test api
"""
from fastapi import APIRouter
from vactor_rest.app.utils.date_utils import D
from vactor_rest.app.models import SWInfo
router = APIRouter()
@router.get('/', summary='서비스 정보', response_model=SWInfo)
async def index():
"""
## 서비스 정보
소프트웨어 이름, 버전정보, 현재시간
**결과**
- SWInfo
"""
sw_info = SWInfo()
sw_info.date = D.date_str()
return sw_info

View File

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
"""
@File: services.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: services api
"""
import requests, json, traceback, os
from fastapi import APIRouter, Depends, Body
from starlette.requests import Request
from typing import Annotated, List
from vactor_rest.app.common import consts
from vactor_rest.app import models as M
from vactor_rest.app.utils.date_utils import D
from custom_logger.vactor_log import vactor_logger as LOG
from custom_apps.faiss.main import search_idxs
router = APIRouter(prefix="/services")
@router.post("/faiss/vactor/search", summary="vactor search", response_model=M.ResponseBase)
async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
"""
## 벡터검색
## 정보
>
"""
response = M.ResponseBase()
try:
if os.path.exists(request_body_info.quary_image_path):
search_idxs(image_path=request_body_info.quary_image_path,
index_type=request_body_info.index_type,
search_num=request_body_info.search_num)
else:
raise Exception(f"File {request_body_info.quary_image_path} does not exist.")
return response.set_message()
except Exception as e:
LOG.error(traceback.format_exc())
return response.set_error(e)

View File

@@ -0,0 +1,298 @@
# -*- coding: utf-8 -*-
"""
@File: users.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: users api
"""
from fastapi import APIRouter
from starlette.requests import Request
import bcrypt
from vactor_rest.app.common import consts
from vactor_rest.app import models as M
from vactor_rest.app.common.config import conf
from vactor_rest.app.database.schema import Users
from vactor_rest.app.database.crud import table_select, table_update, table_delete
from vactor_rest.app.utils.extra import AESCryptoCBC
router = APIRouter(prefix='/user')
@router.get('/me', response_model=M.UserSearchRes, summary='접속자 정보')
async def get_me(request: Request):
"""
## 현재 접속된 자신정보 확인
***현 버전 미지원(추후 상세 처리)***
**결과**
- UserSearchRes
"""
target_table = Users
search_info = None
try:
# request
if conf().GLOBAL_TOKEN:
raise Exception('not supported: use search api!')
accessor_info = request.state.user
if not accessor_info:
raise Exception('invalid accessor')
# search
search_info = target_table.get(account=accessor_info.account)
if not search_info:
raise Exception('not found data')
# result
result_info = list()
result_info.append(M.UserInfo.from_orm(search_info))
return M.UserSearchRes(data=result_info)
except Exception as e:
if search_info:
search_info.close()
return M.ResponseBase.set_error(str(e))
@router.post('/search', response_model=M.UserSearchPagingRes, summary='유저정보 검색')
async def user_search(request: Request, request_body_info: M.UserSearchPagingReq):
"""
## 유저정보 검색 (기본)
검색 조건은 유저 테이블 항목만 가능 (Request body: Schema 참조)\n
관련 정보 연동 검색은 별도 API 사용
**세부항목**
- **paging**\n
항목 미사용시에는 페이징기능 없이 검색조건(search) 결과 모두 반환
- **search**\n
검색에 필요한 항목들을 search 에 포함시킨다.\n
검색에 사용된 각 항목들은 AND 조건으로 처리된다.
- **전체검색**\n
empty object 사용 ( {} )\n
예) "search": {}
- **검색항목**\n
- 부분검색 항목\n
SQL 문법( %, _ )을 사용한다.\n
__like: 시작포함(X%), 중간포함(%X%), 끝포함(%X)
- 구간검색 항목
* __lt: 주어진 값보다 작은값
* __lte: 주어진 값보다 같거나 작은 값
* __gt: 주어진 값보다 큰값
* __gte: 주어진 값보다 같거나 큰값
**결과**
- UserSearchPagingRes
"""
return await table_select(request.state.user, Users, request_body_info, M.UserSearchPagingRes, M.UserInfo)
@router.put('/update', response_model=M.ResponseBase, summary='유저정보 변경')
async def user_update(request: Request, request_body_info: M.UserUpdateMultiReq):
"""
## 유저정보 변경
**search_info**: 변경대상\n
**update_info**: 변경내용\n
- **비밀번호** 제외
**결과**
- ResponseBase
"""
return await table_update(request.state.user, Users, request_body_info, M.ResponseBase)
@router.put('/update_pw', response_model=M.ResponseBase, summary='유저 비밀번호 변경')
async def user_update_pw(request: Request, request_body_info: M.UserUpdatePWReq):
"""
## 유저정보 비밀번호 변경
**account**의 **비밀번호**를 변경한다.
**결과**
- ResponseBase
"""
target_table = Users
search_info = None
try:
# request
accessor_info = request.state.user
if not accessor_info:
raise Exception('invalid accessor')
if not request_body_info.account:
raise Exception('invalid account')
# decrypt pw
try:
decode_cur_pw = request_body_info.current_pw.encode('utf-8')
desc_cur_pw = AESCryptoCBC().decrypt(decode_cur_pw)
except Exception as e:
raise Exception(f'failed decryption [current_pw]: {e}')
try:
decode_new_pw = request_body_info.new_pw.encode('utf-8')
desc_new_pw = AESCryptoCBC().decrypt(decode_new_pw)
except Exception as e:
raise Exception(f'failed decryption [new_pw]: {e}')
# search
target_user = target_table.get(account=request_body_info.account)
is_verified = bcrypt.checkpw(desc_cur_pw, target_user.pw.encode('utf-8'))
if not is_verified:
raise Exception('invalid password')
search_info = target_table.filter(id=target_user.id)
if not search_info.first():
raise Exception('not found data')
# process
hash_pw = bcrypt.hashpw(desc_new_pw, bcrypt.gensalt())
result_info = search_info.update(auto_commit=True, pw=hash_pw)
if not result_info or not result_info.id:
raise Exception('failed update')
# result
return M.ResponseBase()
except Exception as e:
if search_info:
search_info.close()
return M.ResponseBase.set_error(str(e))
@router.delete('/delete', response_model=M.ResponseBase, summary='유저정보 삭제')
async def user_delete(request: Request, request_body_info: M.UserSearchReq):
"""
## 유저정보 삭제
조건에 해당하는 정보를 모두 삭제한다.\n
- **본 API는 DB에서 완적삭제를 하는 함수이며, 서버관리자가 사용하는 것을 권장한다.**
- **update API를 사용하여 상태 항목을 변경해서 사용하는 것을 권장.**
`유저삭제시 관계 테이블의 정보도 같이 삭제된다.`
**결과**
- ResponseBase
"""
return await table_delete(request.state.user, Users, request_body_info, M.ResponseBase)
# NOTE(hsj100): apikey
"""
"""
# @router.get('/apikeys', response_model=List[M.GetApiKeyList])
# async def get_api_keys(request: Request):
# """
# API KEY 조회
# :param request:
# :return:
# """
# user = request.state.user
# api_keys = ApiKeys.filter(user_id=user.id).all()
# return api_keys
#
#
# @router.post('/apikeys', response_model=M.GetApiKeys)
# async def create_api_keys(request: Request, key_info: M.AddApiKey, session: Session = Depends(db.session)):
# """
# API KEY 생성
# :param request:
# :param key_info:
# :param session:
# :return:
# """
# user = request.state.user
#
# api_keys = ApiKeys.filter(session, user_id=user.id, status='active').count()
# if api_keys == MAX_API_KEY:
# raise ex.MaxKeyCountEx()
#
# alphabet = string.ascii_letters + string.digits
# s_key = ''.join(secrets.choice(alphabet) for _ in range(40))
# uid = None
# while not uid:
# uid_candidate = f'{str(uuid4())[:-12]}{str(uuid4())}'
# uid_check = ApiKeys.get(access_key=uid_candidate)
# if not uid_check:
# uid = uid_candidate
#
# key_info = key_info.dict()
# new_key = ApiKeys.create(session, auto_commit=True, secret_key=s_key, user_id=user.id, access_key=uid, **key_info)
# return new_key
#
#
# @router.put('/apikeys/{key_id}', response_model=M.GetApiKeyList)
# async def update_api_keys(request: Request, key_id: int, key_info: M.AddApiKey):
# """
# API KEY User Memo Update
# :param request:
# :param key_id:
# :param key_info:
# :return:
# """
# user = request.state.user
# key_data = ApiKeys.filter(id=key_id)
# if key_data and key_data.first().user_id == user.id:
# return key_data.update(auto_commit=True, **key_info.dict())
# raise ex.NoKeyMatchEx()
#
#
# @router.delete('/apikeys/{key_id}')
# async def delete_api_keys(request: Request, key_id: int, access_key: str):
# user = request.state.user
# await check_api_owner(user.id, key_id)
# search_by_key = ApiKeys.filter(access_key=access_key)
# if not search_by_key.first():
# raise ex.NoKeyMatchEx()
# search_by_key.delete(auto_commit=True)
# return MessageOk()
#
#
# @router.get('/apikeys/{key_id}/whitelists', response_model=List[M.GetAPIWhiteLists])
# async def get_api_keys(request: Request, key_id: int):
# user = request.state.user
# await check_api_owner(user.id, key_id)
# whitelists = ApiWhiteLists.filter(api_key_id=key_id).all()
# return whitelists
#
#
# @router.post('/apikeys/{key_id}/whitelists', response_model=M.GetAPIWhiteLists)
# async def create_api_keys(request: Request, key_id: int, ip: M.CreateAPIWhiteLists, session: Session = Depends(db.session)):
# user = request.state.user
# await check_api_owner(user.id, key_id)
# import ipaddress
# try:
# _ip = ipaddress.ip_address(ip.ip_addr)
# except Exception as e:
# raise ex.InvalidIpEx(ip.ip_addr, e)
# if ApiWhiteLists.filter(api_key_id=key_id).count() == MAX_API_WHITELIST:
# raise ex.MaxWLCountEx()
# ip_dup = ApiWhiteLists.get(api_key_id=key_id, ip_addr=ip.ip_addr)
# if ip_dup:
# return ip_dup
# ip_reg = ApiWhiteLists.create(session=session, auto_commit=True, api_key_id=key_id, ip_addr=ip.ip_addr)
# return ip_reg
#
#
# @router.delete('/apikeys/{key_id}/whitelists/{list_id}')
# async def delete_api_keys(request: Request, key_id: int, list_id: int):
# user = request.state.user
# await check_api_owner(user.id, key_id)
# ApiWhiteLists.filter(id=list_id, api_key_id=key_id).delete()
#
# return MessageOk()
#
#
# async def check_api_owner(user_id, key_id):
# api_keys = ApiKeys.get(id=key_id, user_id=user_id)
# if not api_keys:
# raise ex.NoKeyMatchEx()

View File

@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
"""
@File: date_utils.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: date-utility functions
"""
from datetime import datetime, date, timedelta
_TIMEDELTA = 9
class D:
def __init__(self, *args):
self.utc_now = datetime.utcnow()
# NOTE(hsj100): utc->kst
self.timedelta = _TIMEDELTA
@classmethod
def datetime(cls, diff: int=_TIMEDELTA) -> datetime:
return datetime.utcnow() + timedelta(hours=diff) if diff > 0 else datetime.now()
@classmethod
def date(cls, diff: int=_TIMEDELTA) -> date:
return cls.datetime(diff=diff).date()
@classmethod
def date_num(cls, diff: int=_TIMEDELTA) -> int:
return int(cls.date(diff=diff).strftime('%Y%m%d'))
@classmethod
def validate(cls, date_text):
try:
datetime.strptime(date_text, '%Y-%m-%d')
except ValueError:
raise ValueError('Incorrect data format, should be YYYY-MM-DD')
@classmethod
def date_str(cls, diff: int = _TIMEDELTA) -> str:
return cls.datetime(diff=diff).strftime('%Y-%m-%dT%H:%M:%S')
@classmethod
def check_expire_date(cls, expire_date: datetime):
td = expire_date - datetime.now()
timestamp = td.total_seconds()
return timestamp
@classmethod
def date_file_name(cls):
date = datetime.now()
return date.strftime('%y%m%d_%H%M%S.%s')
if __name__ == "__main__":
_date = D()

View File

@@ -0,0 +1,205 @@
# -*- coding: utf-8 -*-
"""
@File: extra.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: extra functions
"""
from hashlib import md5
from base64 import b64decode
from base64 import b64encode
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from cryptography.fernet import Fernet # symmetric encryption
# mail test
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from itertools import groupby
from operator import attrgetter
import uuid
from vactor_rest.app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT
from vactor_rest.app.utils.date_utils import D
from vactor_rest.app import models as M
from vactor_rest.app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY
async def send_mail(sender, sender_pw, title, recipient, contents_plain, contents_html, cc_list, smtp_host=SMTP_HOST, smtp_port=SMTP_PORT):
"""
구글 계정사용시 : 보안 수준이 낮은 앱에서의 접근 활성화
:return:
None: success
Str. Message: error
"""
try:
# check parameters
if not sender:
raise Exception('invalid sender')
if not title:
raise Exception('invalid title')
if not recipient:
raise Exception('invalid recipient')
# sender info.
# sender = consts.ADMIN_INIT_ACCOUNT_INFO.email
# sender_pw = consts.ADMIN_INIT_ACCOUNT_INFO.email_pw
# message
msg = MIMEMultipart()
msg['From'] = sender
msg['To'] = recipient
if cc_list:
list_cc = cc_list
str_cc = ','.join(list_cc)
msg['Cc'] = str_cc
msg['Subject'] = title
if contents_plain:
msg.attach(MIMEText(contents_plain, 'plain'))
if contents_html:
msg.attach(MIMEText(contents_html, 'html'))
# smtp server
smtp_server = smtplib.SMTP(host=smtp_host, port=smtp_port)
smtp_server.ehlo()
smtp_server.starttls()
smtp_server.ehlo()
smtp_server.login(sender, sender_pw)
smtp_server.send_message(msg)
smtp_server.quit()
return None
except Exception as e:
return str(e)
def query_to_groupby(query_result, key, first=False):
"""
쿼리 결과물(list)을 항목값(key)으로 그룹화한다.
:param query_result: 쿼리 결과 리스트
:param key: 그룹 항목값
:return: dict
"""
group_info = dict()
for k, g in groupby(query_result, attrgetter(key)):
if k not in group_info:
if not first:
group_info[k] = list(g)
else:
group_info[k] = list(g)[0]
else:
if not first:
group_info[k].extend(list(g))
return group_info
def query_to_groupby_date(query_result, key):
"""
쿼리 결과물(list)을 항목값(key)으로 그룹화한다.
:param query_result: 쿼리 결과 리스트
:param key: 그룹 항목값
:return: dict
"""
group_info = dict()
for k, g in groupby(query_result, attrgetter(key)):
day_str = k.strftime("%Y-%m-%d")
if day_str not in group_info:
group_info[day_str] = list(g)
else:
group_info[day_str].extend(list(g))
return group_info
class FernetCrypto:
def __init__(self, key=FERNET_SECRET_KEY):
self.key = key
self.f = Fernet(self.key)
def encrypt(self, data, is_out_string=True):
if isinstance(data, bytes):
ou = self.f.encrypt(data) # 바이트형태이면 바로 암호화
else:
ou = self.f.encrypt(data.encode('utf-8')) # 인코딩 후 암호화
if is_out_string is True:
return ou.decode('utf-8') # 출력이 문자열이면 디코딩 후 반환
else:
return ou
def decrypt(self, data, is_out_string=True):
if isinstance(data, bytes):
ou = self.f.decrypt(data) # 바이트형태이면 바로 복호화
else:
ou = self.f.decrypt(data.encode('utf-8')) # 인코딩 후 복호화
if is_out_string is True:
return ou.decode('utf-8') # 출력이 문자열이면 디코딩 후 반환
else:
return ou
class AESCryptoCBC:
def __init__(self, key=AES_CBC_PUBLIC_KEY, iv=AES_CBC_IV):
# Initial vector를 0으로 초기화하여 16바이트 할당함
# iv = chr(0) * 16 #pycrypto 기준
# iv = bytes([0x00] * 16) #pycryptodomex 기준
# aes cbc 생성
self.key = key
self.iv = iv
self.crypto = AES.new(self.key, AES.MODE_CBC, self.iv)
def encrypt(self, data):
# 암호화 message는 16의 배수여야 한다.
# enc = self.crypto.encrypt(data)
# return enc
enc = self.crypto.encrypt(pad(data, AES.block_size))
return b64encode(enc)
def decrypt(self, enc):
# 복호화 enc는 16의 배수여야 한다.
# dec = self.crypto.decrypt(enc)
# return dec
enc = b64decode(enc)
dec = self.crypto.decrypt(enc)
return unpad(dec, AES.block_size)
class AESCipher:
def __init__(self, key):
# self.key = md5(key.encode('utf8')).digest()
self.key = bytes(key.encode('utf-8'))
def encrypt(self, data):
# iv = get_random_bytes(AES.block_size)
iv = bytes('daooldns12345678'.encode('utf-8'))
self.cipher = AES.new(self.key, AES.MODE_CBC, iv)
t = b64encode(self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size)))
return b64encode(iv + self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size)))
def decrypt(self, data):
raw = b64decode(data)
self.cipher = AES.new(self.key, AES.MODE_CBC, raw[:AES.block_size])
return unpad(self.cipher.decrypt(raw[AES.block_size:]), AES.block_size)
def cls_list_to_dict_list(list):
"""
list 내부 element가 dict로 변환 가능한 class일경우
내부 element를 dict 로 변경
"""
_result = []
for i in list:
if isinstance(i, dict):
_result = list
break
_result.append(i.dict())
return _result

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
"""
@File: logger.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: logger
"""
import json
import logging
from datetime import timedelta, datetime
from time import time
from fastapi.requests import Request
from fastapi import Body
from fastapi.logger import logger
logger.setLevel(logging.INFO)
async def api_logger(request: Request, response=None, error=None):
time_format = '%Y/%m/%d %H:%M:%S'
t = time() - request.state.start
status_code = error.status_code if error else response.status_code
error_log = None
user = request.state.user
if error:
if request.state.inspect:
frame = request.state.inspect
error_file = frame.f_code.co_filename
error_func = frame.f_code.co_name
error_line = frame.f_lineno
else:
error_func = error_file = error_line = 'UNKNOWN'
error_log = dict(
errorFunc=error_func,
location='{} line in {}'.format(str(error_line), error_file),
raised=str(error.__class__.__name__),
msg=str(error.ex),
)
account = user.account.split('@') if user and user.account else None
user_log = dict(
client=request.state.ip,
user=user.id if user and user.id else None,
account='**' + account[0][2:-1] + '*@' + account[1] if user and user.account else None,
)
log_dict = dict(
url=request.url.hostname + request.url.path,
method=str(request.method),
statusCode=status_code,
errorDetail=error_log,
client=user_log,
processedTime=str(round(t * 1000, 5)) + 'ms',
datetimeUTC=datetime.utcnow().strftime(time_format),
datetimeKST=(datetime.utcnow() + timedelta(hours=9)).strftime(time_format),
)
if error and error.status_code >= 500:
logger.error(json.dumps(log_dict))
else:
logger.info(json.dumps(log_dict))

View File

@@ -0,0 +1,25 @@
from const import ILLEGAL_FILE_NAME
def prompt_to_filenames(prompt):
"""
prompt 에 사용할 수 없는 문자가 있으면 '_' 로 치환
"""
filename = ''
for i in prompt:
if i in ILLEGAL_FILE_NAME:
filename += '_'
else:
filename += i
return filename
def download_range(download_count:int,max=4):
_min = 1
_max = max
if _min <= download_count and download_count <= _max:
return True
return False

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""
@File: query_utils.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: query-utility functions
"""
from typing import List
def to_dict(model, *args, exclude: List = None):
q_dict = {}
for c in model.__table__.columns:
if not args or c.name in args:
if not exclude or c.name not in exclude:
q_dict[c.name] = getattr(model, c.name)
return q_dict

View File

@@ -0,0 +1,222 @@
# Gunicorn configuration file.
#
# Server socket
#
# bind - The socket to bind.
#
# A string of the form: 'HOST', 'HOST:PORT', 'unix:PATH'.
# An IP is a valid HOST.
#
# backlog - The number of pending connections. This refers
# to the number of clients that can be waiting to be
# served. Exceeding this number results in the client
# getting an error when attempting to connect. It should
# only affect servers under significant load.
#
# Must be a positive integer. Generally set in the 64-2048
# range.
#
bind = "0.0.0.0:5000"
backlog = 2048
#
# Worker processes
#
# workers - The number of worker processes that this server
# should keep alive for handling requests.
#
# A positive integer generally in the 2-4 x $(NUM_CORES)
# range. You'll want to vary this a bit to find the best
# for your particular application's work load.
#
# worker_class - The type of workers to use. The default
# sync class should handle most 'normal' types of work
# loads. You'll want to read
# http://docs.gunicorn.org/en/latest/design.html#choosing-a-worker-type
# for information on when you might want to choose one
# of the other worker classes.
#
# A string referring to a Python path to a subclass of
# gunicorn.workers.base.Worker. The default provided values
# can be seen at
# http://docs.gunicorn.org/en/latest/settings.html#worker-class
#
# worker_connections - For the eventlet and gevent worker classes
# this limits the maximum number of simultaneous clients that
# a single process can handle.
#
# A positive integer generally set to around 1000.
#
# timeout - If a worker does not notify the master process in this
# number of seconds it is killed and a new worker is spawned
# to replace it.
#
# Generally set to thirty seconds. Only set this noticeably
# higher if you're sure of the repercussions for sync workers.
# For the non sync workers it just means that the worker
# process is still communicating and is not tied to the length
# of time required to handle a single request.
#
# keepalive - The number of seconds to wait for the next request
# on a Keep-Alive HTTP connection.
#
# A positive integer. Generally set in the 1-5 seconds range.
#
# reload - Restart workers when code changes.
#
# This setting is intended for development. It will cause
# workers to be restarted whenever application code changes.
workers = 3
threads = 3
worker_class = "uvicorn.workers.UvicornWorker"
worker_connections = 1000
timeout = 60
keepalive = 2
reload = True
#
# spew - Install a trace function that spews every line of Python
# that is executed when running the server. This is the
# nuclear option.
#
# True or False
#
spew = False
#
# Server mechanics
#
# daemon - Detach the main Gunicorn process from the controlling
# terminal with a standard fork/fork sequence.
#
# True or False
#
# raw_env - Pass environment variables to the execution environment.
#
# pidfile - The path to a pid file to write
#
# A path string or None to not write a pid file.
#
# user - Switch worker processes to run as this user.
#
# A valid user id (as an integer) or the name of a user that
# can be retrieved with a call to pwd.getpwnam(value) or None
# to not change the worker process user.
#
# group - Switch worker process to run as this group.
#
# A valid group id (as an integer) or the name of a user that
# can be retrieved with a call to pwd.getgrnam(value) or None
# to change the worker processes group.
#
# umask - A mask for file permissions written by Gunicorn. Note that
# this affects unix socket permissions.
#
# A valid value for the os.umask(mode) call or a string
# compatible with int(value, 0) (0 means Python guesses
# the base, so values like "0", "0xFF", "0022" are valid
# for decimal, hex, and octal representations)
#
# tmp_upload_dir - A directory to store temporary request data when
# requests are read. This will most likely be disappearing soon.
#
# A path to a directory where the process owner can write. Or
# None to signal that Python should choose one on its own.
#
daemon = False
pidfile = None
umask = 0
user = None
group = None
tmp_upload_dir = None
#
# Logging
#
# logfile - The path to a log file to write to.
#
# A path string. "-" means log to stdout.
#
# loglevel - The granularity of log output
#
# A string of "debug", "info", "warning", "error", "critical"
#
errorlog = "-"
loglevel = "info"
accesslog = None
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
#
# Process naming
#
# proc_name - A base to use with setproctitle to change the way
# that Gunicorn processes are reported in the system process
# table. This affects things like 'ps' and 'top'. If you're
# going to be running more than one instance of Gunicorn you'll
# probably want to set a name to tell them apart. This requires
# that you install the setproctitle module.
#
# A string or None to choose a default of something like 'gunicorn'.
#
proc_name = "NotificationAPI"
#
# Server hooks
#
# post_fork - Called just after a worker has been forked.
#
# A callable that takes a server and worker instance
# as arguments.
#
# pre_fork - Called just prior to forking the worker subprocess.
#
# A callable that accepts the same arguments as after_fork
#
# pre_exec - Called just prior to forking off a secondary
# master process during things like config reloading.
#
# A callable that takes a server instance as the sole argument.
#
def post_fork(server, worker):
server.log.info("Worker spawned (pid: %s)", worker.pid)
def pre_fork(server, worker):
pass
def pre_exec(server):
server.log.info("Forked child, re-executing.")
def when_ready(server):
server.log.info("Server is ready. Spawning workers")
def worker_int(worker):
worker.log.info("worker received INT or QUIT signal")
# get traceback info
import threading, sys, traceback
id2name = {th.ident: th.name for th in threading.enumerate()}
code = []
for threadId, stack in sys._current_frames().items():
code.append("\n# Thread: %s(%d)" % (id2name.get(threadId, ""), threadId))
for filename, lineno, name, line in traceback.extract_stack(stack):
code.append('File: "%s", line %d, in %s' % (filename, lineno, name))
if line:
code.append(" %s" % (line.strip()))
worker.log.debug("\n".join(code))
def worker_abort(worker):
worker.log.info("worker received SIGABRT signal")

View File

View File

@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
"""
@File: conftest.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: test config
"""
import asyncio
import os
from os import path
from typing import List
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.database.schema import Users
from app.main import create_app
from app.database.conn import db, Base
from app.models import UserToken
from app.routes.auth import create_access_token
"""
1. DB 생성
2. 테이블 생성
3. 테스트 코드 작동
4. 테이블 레코드 삭제
"""
@pytest.fixture(scope='session')
def app():
os.environ['API_ENV'] = 'test'
return create_app()
@pytest.fixture(scope='session')
def client(app):
# Create tables
Base.metadata.create_all(db.engine)
return TestClient(app=app)
@pytest.fixture(scope='function', autouse=True)
def session():
sess = next(db.session())
yield sess
clear_all_table_data(
session=sess,
metadata=Base.metadata,
except_tables=[]
)
sess.rollback()
@pytest.fixture(scope='function')
def login(session):
"""
테스트전 사용자 미리 등록
:param session:
:return:
"""
db_user = Users.create(session=session, email='ryan_test@dingrr.com', pw='123')
session.commit()
access_token = create_access_token(data=UserToken.from_orm(db_user).dict(exclude={'pw', 'marketing_agree'}),)
return dict(Authorization=f'Bearer {access_token}')
def clear_all_table_data(session: Session, metadata, except_tables: List[str] = None):
session.execute('SET FOREIGN_KEY_CHECKS = 0;')
for table in metadata.sorted_tables:
if table.name not in except_tables:
session.execute(table.delete())
session.execute('SET FOREIGN_KEY_CHECKS = 1;')
session.commit()

View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
"""
@File: test_auth.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: test auth.
"""
from app.database.conn import db
from app.database.schema import Users
def test_registration(client, session):
"""
레버 로그인
:param client:
:param session:
:return:
"""
user = dict(email='ryan@dingrr.com', pw='123', name='라이언', phone='01099999999')
res = client.post('api/auth/register/email', json=user)
res_body = res.json()
print(res.json())
assert res.status_code == 201
assert 'Authorization' in res_body.keys()
def test_registration_exist_email(client, session):
"""
레버 로그인
:param client:
:param session:
:return:
"""
user = dict(email='Hello@dingrr.com', pw='123', name='라이언', phone='01099999999')
db_user = Users.create(session=session, **user)
session.commit()
res = client.post('api/auth/register/email', json=user)
res_body = res.json()
assert res.status_code == 400
assert 'EMAIL_EXISTS' == res_body['msg']

View File

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
"""
@File: test_user.py
@Date: 2020-09-14
@author: A2TEC
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
@brief: test user
"""
from app.database.conn import db
from app.database.schema import Users
def test_create_get_apikey(client, session, login):
"""
레버 로그인
:param client:
:param session:
:return:
"""
key = dict(user_memo='ryan__key')
res = client.post('api/user/apikeys', json=key, headers=login)
res_body = res.json()
assert res.status_code == 200
assert 'secret_key' in res_body
res = client.get('api/user/apikeys', headers=login)
res_body = res.json()
assert res.status_code == 200
assert 'ryan__key' in res_body[0]['user_memo']