edit: faiss 전용 rest 서버 추가
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -144,4 +144,5 @@ cython_debug/
|
||||
log
|
||||
sheet_counter.txt
|
||||
|
||||
*output
|
||||
*output
|
||||
datas/eyewear_all/*
|
||||
@@ -24,5 +24,10 @@ RESTful API Server
|
||||
https://github.com/acheong08/BingImageCreator
|
||||
```
|
||||
|
||||
### faiss
|
||||
검색된 이미지는 search_'matching'.png 로 저장됨
|
||||
matching : 매칭률 (클수록 좋음)
|
||||
|
||||
### notice
|
||||
const.py 에 지정한 OUTPUT_FOLDER 하위폴더에 이미지 저장됨.
|
||||
const.py 에 지정한 OUTPUT_FOLDER 하위폴더에 이미지 저장됨.
|
||||
faiss api 사용시 rest_vactor 서버 구동해야함.
|
||||
1
const.py
1
const.py
@@ -1,3 +1,4 @@
|
||||
# OUTPUT_FOLDER = "/home/fermat/STORAGE/01.Projects/A2TEC/K_EYEWEAR/02.ML_DATA/Image_generator_result"
|
||||
OUTPUT_FOLDER = "./output"
|
||||
|
||||
ILLEGAL_FILE_NAME = ['<', '>', ':', '"', '/', '\ ', '|', '?', '*']
|
||||
@@ -4,8 +4,8 @@ from pathlib import Path
|
||||
from bingart import BingArt
|
||||
|
||||
from custom_apps.utils import cookie_manager
|
||||
from rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from rest.app.utils.date_utils import D
|
||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from const import OUTPUT_FOLDER
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ import pkg_resources
|
||||
import regex
|
||||
import requests
|
||||
|
||||
from rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from rest.app.utils.date_utils import D
|
||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from main_rest.app.utils.date_utils import D
|
||||
|
||||
BING_URL = os.getenv("BING_URL", "https://www.bing.com")
|
||||
# Generate random IP between range 13.104.0.0/14
|
||||
|
||||
2
custom_apps/faiss/const.py
Normal file
2
custom_apps/faiss/const.py
Normal file
@@ -0,0 +1,2 @@
|
||||
DATASET_BIN = "./datas/eyewear_all.fvecs.bin"
|
||||
DATASET_TEXT = "./datas/eyewear_all.fnames.txt"
|
||||
50
custom_apps/faiss/main.py
Normal file
50
custom_apps/faiss/main.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from custom_apps.faiss.utils import preprocessing, preprocessing_quary, normalize, get_dataset_list
|
||||
from custom_apps.faiss.const import *
|
||||
# from rest.app.utils.date_utils import D
|
||||
# from const import OUTPUT_FOLDER
|
||||
|
||||
dataset_list = get_dataset_list(DATASET_TEXT)
|
||||
|
||||
def search_idxs(image_path,dataset_bin=DATASET_BIN,index_type="hnsw",search_num=4):
|
||||
|
||||
if index_type not in ["hnsw", "l2"]:
|
||||
raise ValueError("index_type must be either 'hnsw' or 'l2'")
|
||||
|
||||
DIM = 1280
|
||||
|
||||
dataset_fvces, dataset_index = preprocessing(DIM,dataset_bin,index_type)
|
||||
|
||||
org_fvces, org_index = preprocessing_quary(DIM,image_path,index_type)
|
||||
|
||||
dists, idxs = dataset_index.search(normalize(org_fvces), search_num)
|
||||
|
||||
# print(dists[0])
|
||||
# print(idxs[0])
|
||||
|
||||
index_image_save(image_path, dists[0], idxs[0])
|
||||
|
||||
def index_image_save(query_image_path, dists, idxs):
|
||||
directory_path, file = os.path.split(query_image_path)
|
||||
_name, extension = os.path.splitext(file)
|
||||
|
||||
if not os.path.exists(directory_path):
|
||||
raise ValueError(f"Folder {directory_path} does not exist.")
|
||||
|
||||
for dist,index in zip(dists, idxs):
|
||||
|
||||
if dist > 1:
|
||||
dist = 0
|
||||
else:
|
||||
dist = 1-dist
|
||||
|
||||
origin_file_path = dataset_list[index]
|
||||
dest_file_path = os.path.join(directory_path, f"search_{round(float(dist),4)}{extension}")
|
||||
shutil.copy(origin_file_path, dest_file_path)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
116
custom_apps/faiss/utils.py
Normal file
116
custom_apps/faiss/utils.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
quary 이미지(이미지 경로) 입력받아 처리
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import math
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import normalize
|
||||
import faiss
|
||||
import tensorflow as tf
|
||||
import tensorflow.keras.layers as layers
|
||||
from tensorflow.keras.models import Model
|
||||
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def populate(index, fvecs, batch_size=1000):
|
||||
nloop = math.ceil(fvecs.shape[0] / batch_size)
|
||||
for n in range(nloop):
|
||||
s = time.time()
|
||||
index.add(normalize(fvecs[n * batch_size : min((n + 1) * batch_size, fvecs.shape[0])]))
|
||||
# print(n * batch_size, time.time() - s)
|
||||
|
||||
return index
|
||||
|
||||
def get_index(index_type, dim):
|
||||
if index_type == 'hnsw':
|
||||
m = 48
|
||||
index = faiss.IndexHNSWFlat(dim, m)
|
||||
index.hnsw.efConstruction = 128
|
||||
return index
|
||||
elif index_type == 'l2':
|
||||
return faiss.IndexFlatL2(dim)
|
||||
raise
|
||||
|
||||
def preprocessing(dim,fvec_file,index_type):
|
||||
# index_type = 'hnsw'
|
||||
# index_type = 'l2'
|
||||
|
||||
# f-string 방식 (python3 이상에서 지원)
|
||||
index_file = f'{fvec_file}.{index_type}.index'
|
||||
|
||||
fvecs = np.memmap(fvec_file, dtype='float32', mode='r').view('float32').reshape(-1, dim)
|
||||
|
||||
if os.path.exists(index_file):
|
||||
index = faiss.read_index(index_file)
|
||||
if index_type == 'hnsw':
|
||||
index.hnsw.efSearch = 256
|
||||
else:
|
||||
index = get_index(index_type, dim)
|
||||
index = populate(index, fvecs)
|
||||
faiss.write_index(index, index_file)
|
||||
# print(index.ntotal)
|
||||
|
||||
return fvecs,index
|
||||
|
||||
def preprocessing_quary(dim,image_path,index_type):
|
||||
# index_type = 'hnsw'
|
||||
# index_type = 'l2'
|
||||
|
||||
fvecs = fvces_quary(image_path)
|
||||
index = get_index(index_type, dim)
|
||||
index = populate(index, fvecs)
|
||||
|
||||
return fvecs,index
|
||||
|
||||
def preprocess(img_path, input_shape):
|
||||
|
||||
img = tf.io.read_file(img_path)
|
||||
img = tf.image.decode_jpeg(img, channels=input_shape[2])
|
||||
img = tf.image.resize(img, input_shape[:2])
|
||||
img = preprocess_input(img)
|
||||
return img
|
||||
|
||||
def preprocess_pil(pil_img_data, input_shape):
|
||||
|
||||
pil_img = np.asarray(Image.open(pil_img_data))
|
||||
pil_img = tf.image.resize(pil_img, input_shape[:2])
|
||||
pil_img = preprocess_input(pil_img)
|
||||
return pil_img
|
||||
|
||||
def fvces_quary(image_path):
|
||||
|
||||
batch_size = 100
|
||||
input_shape = (224, 224, 3)
|
||||
base = tf.keras.applications.MobileNetV2(input_shape=input_shape,
|
||||
include_top=False,
|
||||
weights='imagenet')
|
||||
base.trainable = False
|
||||
model = Model(inputs=base.input, outputs=layers.GlobalAveragePooling2D()(base.output))
|
||||
|
||||
list_ds = tf.data.Dataset.from_tensor_slices([image_path])
|
||||
|
||||
ds = list_ds.map(lambda x: preprocess(x, input_shape), num_parallel_calls=-1)
|
||||
|
||||
dataset = ds.batch(batch_size).prefetch(-1)
|
||||
|
||||
for batch in dataset:
|
||||
fvecs = model.predict(batch)
|
||||
return fvecs
|
||||
|
||||
|
||||
# index_type = 'hnsw'
|
||||
# index_type = 'l2'
|
||||
|
||||
def get_dataset_list(filepath):
|
||||
content_list = []
|
||||
import os
|
||||
if os.path.isfile(filepath) and filepath.endswith(".txt"):
|
||||
with open(filepath, 'r', encoding='utf-8') as file:
|
||||
for line in file:
|
||||
content_list.append(line.strip())
|
||||
return content_list
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ import os
|
||||
from vertexai.preview.vision_models import ImageGenerationModel
|
||||
|
||||
from const import OUTPUT_FOLDER
|
||||
from rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from rest.app.utils.date_utils import D
|
||||
from main_rest.app.utils.parsing_utils import prompt_to_filenames
|
||||
from main_rest.app.utils.date_utils import D
|
||||
|
||||
|
||||
class ImagenConst:
|
||||
@@ -48,4 +48,44 @@ def imagen_generate_image(prompt,download_count=1):
|
||||
for i in range(len(images.images)):
|
||||
images[i].save(location=os.path.join(_folder,f"imagen_{_file_name}_{i+1}_{_datetime}.png"), include_generation_parameters=False)
|
||||
|
||||
return len(images.images)
|
||||
return len(images.images)
|
||||
|
||||
def imagen_generate_image_data(prompt,download_count=1):
|
||||
vertexai.init(project=ImagenConst.project_id, location=ImagenConst.location)
|
||||
|
||||
model = ImageGenerationModel.from_pretrained(ImagenConst.model)
|
||||
|
||||
images = model.generate_images(
|
||||
prompt=prompt,
|
||||
# Optional parameters
|
||||
number_of_images=download_count,
|
||||
language="ko",
|
||||
# You can't use a seed value and watermark at the same time.
|
||||
# add_watermark=False,
|
||||
# seed=100,
|
||||
aspect_ratio="1:1",
|
||||
safety_filter_level="block_some",
|
||||
person_generation="dont_allow",
|
||||
)
|
||||
|
||||
return images.images[0]._pil_image
|
||||
|
||||
def imagen_generate_image_path(image_prompt):
|
||||
|
||||
MODEL = "imagen"
|
||||
QUERY = "query"
|
||||
create_time = D.date_file_name()
|
||||
|
||||
folder_name = os.path.join(OUTPUT_FOLDER,f"{MODEL}_{QUERY}_{image_prompt}_{create_time}")
|
||||
if not os.path.exists(folder_name):
|
||||
os.makedirs(folder_name)
|
||||
|
||||
generate_img = imagen_generate_image_data(image_prompt)
|
||||
|
||||
generate_img.save(os.path.join(folder_name,f"query.png"))
|
||||
|
||||
return os.path.join(folder_name,f"query.png")
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
# imagen_generate_image_data("cat")
|
||||
@@ -95,18 +95,18 @@ def get_logger():
|
||||
return custom_logger
|
||||
|
||||
|
||||
if custom_logger is None:
|
||||
custom_logger = logging.getLogger(LOGGER_NAME)
|
||||
# if custom_logger is None:
|
||||
# custom_logger = logging.getLogger(LOGGER_NAME)
|
||||
|
||||
if not os.path.exists(LOGGER_DIR):
|
||||
os.makedirs(LOGGER_DIR)
|
||||
# if not os.path.exists(LOGGER_DIR):
|
||||
# os.makedirs(LOGGER_DIR)
|
||||
|
||||
if not __LOGGER_FILE_PATH:
|
||||
logger_init(custom_logger, level=LOGGER_LEVEL)
|
||||
else:
|
||||
if not os.path.exists(LOGGER_DIR):
|
||||
os.makedirs(LOGGER_DIR)
|
||||
logger_init(custom_logger, level=LOGGER_LEVEL, file_log_path=__LOGGER_FILE_PATH)
|
||||
# if not __LOGGER_FILE_PATH:
|
||||
# logger_init(custom_logger, level=LOGGER_LEVEL)
|
||||
# else:
|
||||
# if not os.path.exists(LOGGER_DIR):
|
||||
# os.makedirs(LOGGER_DIR)
|
||||
# logger_init(custom_logger, level=LOGGER_LEVEL, file_log_path=__LOGGER_FILE_PATH)
|
||||
|
||||
def test():
|
||||
custom_logger.info('Module: custom_log.py')
|
||||
|
||||
42
custom_logger/main_log.py
Normal file
42
custom_logger/main_log.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import datetime
|
||||
import os
|
||||
import logging
|
||||
|
||||
from custom_logger.custom_log import logger_init
|
||||
|
||||
main_logger = None
|
||||
|
||||
_now = datetime.datetime.now()
|
||||
|
||||
LOGGER_NAME = 'main'
|
||||
LOGGER_FILE_NAME = f'{_now.strftime("%Y-%m-%d %H_%M_%S")}_{LOGGER_NAME}.log'
|
||||
|
||||
LOGGER_LEVEL = logging.INFO
|
||||
LOGGER_DIR = "log/"
|
||||
__LOGGER_FILE_PATH = LOGGER_DIR + LOGGER_FILE_NAME
|
||||
|
||||
def get_main_logger():
|
||||
"""
|
||||
로거 객체 반환
|
||||
로거 객체가 없을 경우에는 로그 초기화를 진행하고 생성된 로그 객체를 반환함
|
||||
:return: 로그 객체
|
||||
"""
|
||||
|
||||
if not main_logger.handlers:
|
||||
logger_init(main_logger)
|
||||
|
||||
return main_logger
|
||||
|
||||
if main_logger is None:
|
||||
|
||||
main_logger = logging.getLogger(LOGGER_NAME)
|
||||
|
||||
if not os.path.exists(LOGGER_DIR):
|
||||
os.makedirs(LOGGER_DIR)
|
||||
|
||||
if not __LOGGER_FILE_PATH:
|
||||
logger_init(main_logger, level=LOGGER_LEVEL)
|
||||
else:
|
||||
if not os.path.exists(LOGGER_DIR):
|
||||
os.makedirs(LOGGER_DIR)
|
||||
logger_init(main_logger, level=LOGGER_LEVEL, file_log_path=__LOGGER_FILE_PATH)
|
||||
42
custom_logger/vactor_log.py
Normal file
42
custom_logger/vactor_log.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import datetime
|
||||
import os
|
||||
import logging
|
||||
|
||||
from custom_logger.custom_log import logger_init
|
||||
|
||||
vactor_logger = None
|
||||
|
||||
_now = datetime.datetime.now()
|
||||
|
||||
LOGGER_NAME = 'vactor'
|
||||
LOGGER_FILE_NAME = f'{_now.strftime("%Y-%m-%d %H_%M_%S")}_{LOGGER_NAME}.log'
|
||||
|
||||
LOGGER_LEVEL = logging.INFO
|
||||
LOGGER_DIR = "log/"
|
||||
__LOGGER_FILE_PATH = LOGGER_DIR + LOGGER_FILE_NAME
|
||||
|
||||
def get_vactor_logger():
|
||||
"""
|
||||
로거 객체 반환
|
||||
로거 객체가 없을 경우에는 로그 초기화를 진행하고 생성된 로그 객체를 반환함
|
||||
:return: 로그 객체
|
||||
"""
|
||||
|
||||
if not vactor_logger.handlers:
|
||||
logger_init(vactor_logger)
|
||||
|
||||
return vactor_logger
|
||||
|
||||
if vactor_logger is None:
|
||||
|
||||
vactor_logger = logging.getLogger(LOGGER_NAME)
|
||||
|
||||
if not os.path.exists(LOGGER_DIR):
|
||||
os.makedirs(LOGGER_DIR)
|
||||
|
||||
if not __LOGGER_FILE_PATH:
|
||||
logger_init(vactor_logger, level=LOGGER_LEVEL)
|
||||
else:
|
||||
if not os.path.exists(LOGGER_DIR):
|
||||
os.makedirs(LOGGER_DIR)
|
||||
logger_init(vactor_logger, level=LOGGER_LEVEL, file_log_path=__LOGGER_FILE_PATH)
|
||||
2000
datas/eyewear_all.fnames.txt
Normal file
2000
datas/eyewear_all.fnames.txt
Normal file
File diff suppressed because it is too large
Load Diff
BIN
datas/eyewear_all.fvecs.bin
Normal file
BIN
datas/eyewear_all.fvecs.bin
Normal file
Binary file not shown.
BIN
datas/eyewear_all.fvecs.bin.hnsw.index
Normal file
BIN
datas/eyewear_all.fvecs.bin.hnsw.index
Normal file
Binary file not shown.
BIN
datas/eyewear_all.fvecs.bin.l2.index
Normal file
BIN
datas/eyewear_all.fvecs.bin.l2.index
Normal file
Binary file not shown.
14
environment_vactor.yml
Normal file
14
environment_vactor.yml
Normal file
@@ -0,0 +1,14 @@
|
||||
name: fm_rest_vactor
|
||||
channels:
|
||||
- pytorch
|
||||
- nvidia
|
||||
- rapidsai
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.10
|
||||
- libnvjitlink
|
||||
- faiss-gpu-cuvs=1.10.0
|
||||
- tensorflow=2.11.0
|
||||
- pip:
|
||||
- -r requirements_vactor.txt
|
||||
6
main.py
6
main.py
@@ -1,6 +0,0 @@
|
||||
import uvicorn
|
||||
from rest.app.common.config import conf
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run('rest.app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)
|
||||
@@ -12,8 +12,8 @@
|
||||
from dataclasses import dataclass
|
||||
from os import path, environ
|
||||
|
||||
from rest.app.common import consts
|
||||
from rest.app.models import UserInfo
|
||||
from main_rest.app.common import consts
|
||||
from main_rest.app.models import UserInfo
|
||||
|
||||
base_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__))))
|
||||
|
||||
@@ -126,8 +126,8 @@ Base = declarative_base()
|
||||
# NOTE(hsj100): ADMINISTRATOR
|
||||
def create_admin(db_session):
|
||||
import bcrypt
|
||||
from rest.app.database.schema import Users
|
||||
from rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO
|
||||
from main_rest.app.database.schema import Users
|
||||
from main_rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO
|
||||
|
||||
session = db_session()
|
||||
|
||||
@@ -16,10 +16,10 @@ from sqlalchemy import func, desc
|
||||
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from sqlalchemy.orm import Session
|
||||
from rest.app import models as M
|
||||
from rest.app.database.conn import Base, db
|
||||
from rest.app.database.schema import Users, UserLog
|
||||
from rest.app.utils.extra import query_to_groupby, query_to_groupby_date
|
||||
from main_rest.app import models as M
|
||||
from main_rest.app.database.conn import Base, db
|
||||
from main_rest.app.database.schema import Users, UserLog
|
||||
from main_rest.app.utils.extra import query_to_groupby, query_to_groupby_date
|
||||
|
||||
|
||||
def get_month_info_list(start: datetime, end: datetime):
|
||||
@@ -21,9 +21,9 @@ from sqlalchemy import (
|
||||
)
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
from rest.app.database.conn import Base, db
|
||||
from rest.app.utils.date_utils import D
|
||||
from rest.app.models import (
|
||||
from main_rest.app.database.conn import Base, db
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from main_rest.app.models import (
|
||||
SexType,
|
||||
UserType,
|
||||
MemberType,
|
||||
@@ -1,4 +1,4 @@
|
||||
from rest.app.common.consts import MAX_API_KEY, MAX_API_WHITELIST
|
||||
from main_rest.app.common.consts import MAX_API_KEY, MAX_API_WHITELIST
|
||||
|
||||
|
||||
class StatusCode:
|
||||
@@ -20,16 +20,16 @@ from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from rest.app.common import consts
|
||||
from main_rest.app.common import consts
|
||||
|
||||
from rest.app.database.conn import db
|
||||
from rest.app.common.config import conf
|
||||
from rest.app.middlewares.token_validator import access_control
|
||||
from rest.app.middlewares.trusted_hosts import TrustedHostMiddleware
|
||||
from rest.app.routes import dev, index, auth, users, services
|
||||
from main_rest.app.database.conn import db
|
||||
from main_rest.app.common.config import conf
|
||||
from main_rest.app.middlewares.token_validator import access_control
|
||||
from main_rest.app.middlewares.trusted_hosts import TrustedHostMiddleware
|
||||
from main_rest.app.routes import dev, index, auth, users, services
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from custom_logger.custom_log import custom_logger as LOG
|
||||
from custom_logger.main_log import main_logger as LOG
|
||||
|
||||
|
||||
API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
|
||||
@@ -12,19 +12,19 @@ from jwt.exceptions import ExpiredSignatureError, DecodeError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from rest.app.common.consts import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX
|
||||
from rest.app.database.conn import db
|
||||
from rest.app.database.schema import Users, ApiKeys
|
||||
from rest.app.errors import exceptions as ex
|
||||
from main_rest.app.common.consts import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX
|
||||
from main_rest.app.database.conn import db
|
||||
from main_rest.app.database.schema import Users, ApiKeys
|
||||
from main_rest.app.errors import exceptions as ex
|
||||
|
||||
from rest.app.common import consts
|
||||
from rest.app.common.config import conf
|
||||
from rest.app.errors.exceptions import APIException, SqlFailureEx, APIQueryStringEx
|
||||
from rest.app.models import UserToken
|
||||
from main_rest.app.common import consts
|
||||
from main_rest.app.common.config import conf
|
||||
from main_rest.app.errors.exceptions import APIException, SqlFailureEx, APIQueryStringEx
|
||||
from main_rest.app.models import UserToken
|
||||
|
||||
from rest.app.utils.date_utils import D
|
||||
from rest.app.utils.logger import api_logger
|
||||
from rest.app.utils.query_utils import to_dict
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from main_rest.app.utils.logger import api_logger
|
||||
from main_rest.app.utils.query_utils import to_dict
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
@@ -19,7 +19,7 @@ from pydantic.main import BaseModel
|
||||
from pydantic.networks import EmailStr, IPvAnyAddress
|
||||
from typing import Optional
|
||||
|
||||
from rest.app.common.consts import (
|
||||
from main_rest.app.common.consts import (
|
||||
SW_TITLE,
|
||||
SW_VERSION,
|
||||
MAIL_REG_TITLE,
|
||||
@@ -29,7 +29,7 @@ from rest.app.common.consts import (
|
||||
ADMIN_INIT_ACCOUNT_INFO,
|
||||
DEFAULT_USER_ACCOUNT_PW
|
||||
)
|
||||
from rest.app.utils.date_utils import D
|
||||
from main_rest.app.utils.date_utils import D
|
||||
|
||||
|
||||
class SWInfo(BaseModel):
|
||||
@@ -566,6 +566,10 @@ class UserLogUpdateMultiReq(BaseModel):
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
|
||||
class IndexType:
|
||||
hnsw = "hnsw"
|
||||
l2 = "l2"
|
||||
|
||||
class ImageGenerateReq(BaseModel):
|
||||
"""
|
||||
### [Request] image generate request
|
||||
@@ -580,6 +584,15 @@ class BingCookieSetReq(BaseModel):
|
||||
"""
|
||||
cookie : str = Field('',description='쿠키 데이터', example='')
|
||||
|
||||
|
||||
class VactorImageSearchReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor image search request
|
||||
"""
|
||||
prompt : str = Field(description='프롬프트', example='검은색 안경')
|
||||
index_type : str = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
|
||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
@@ -17,13 +17,13 @@ import bcrypt
|
||||
import jwt
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from rest.app.common import consts
|
||||
from rest.app import models as M
|
||||
from rest.app.database.conn import db
|
||||
from rest.app.common.config import conf
|
||||
from rest.app.database.schema import Users, UserLog
|
||||
from rest.app.utils.extra import query_to_groupby, AESCryptoCBC
|
||||
from rest.app.utils.date_utils import D
|
||||
from main_rest.app.common import consts
|
||||
from main_rest.app import models as M
|
||||
from main_rest.app.database.conn import db
|
||||
from main_rest.app.common.config import conf
|
||||
from main_rest.app.database.schema import Users, UserLog
|
||||
from main_rest.app.utils.extra import query_to_groupby, AESCryptoCBC
|
||||
from main_rest.app.utils.date_utils import D
|
||||
|
||||
router = APIRouter(prefix='/auth')
|
||||
|
||||
@@ -15,13 +15,13 @@ from sqlalchemy.orm import Session
|
||||
import bcrypt
|
||||
from starlette.requests import Request
|
||||
|
||||
from rest.app.common import consts
|
||||
from rest.app import models as M
|
||||
from rest.app.database.conn import db, Base
|
||||
from rest.app.database.schema import Users, UserLog
|
||||
from main_rest.app.common import consts
|
||||
from main_rest.app import models as M
|
||||
from main_rest.app.database.conn import db, Base
|
||||
from main_rest.app.database.schema import Users, UserLog
|
||||
|
||||
from rest.app.utils.extra import FernetCrypto, AESCryptoCBC, AESCipher
|
||||
from custom_logger.custom_log import custom_logger as LOG
|
||||
from main_rest.app.utils.extra import FernetCrypto, AESCryptoCBC, AESCipher
|
||||
from custom_logger.main_log import main_logger as LOG
|
||||
|
||||
|
||||
# mail test
|
||||
@@ -11,8 +11,8 @@
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from rest.app.utils.date_utils import D
|
||||
from rest.app.models import SWInfo
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from main_rest.app.models import SWInfo
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
@@ -14,15 +14,15 @@ from fastapi import APIRouter, Depends, Body
|
||||
from starlette.requests import Request
|
||||
from typing import Annotated, List
|
||||
|
||||
from rest.app.common import consts
|
||||
from rest.app import models as M
|
||||
from rest.app.utils.date_utils import D
|
||||
from custom_logger.custom_log import custom_logger as LOG
|
||||
from main_rest.app.common import consts
|
||||
from main_rest.app import models as M
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from custom_logger.main_log import main_logger as LOG
|
||||
|
||||
from custom_apps.bingimagecreator.utils import DallEArgument,dalle3_generate_image
|
||||
from custom_apps.bingart.bingart import BingArtGenerator
|
||||
from custom_apps.imagen.custom_imagen import imagen_generate_image
|
||||
from rest.app.utils.parsing_utils import download_range
|
||||
from custom_apps.imagen.custom_imagen import imagen_generate_image, imagen_generate_image_path
|
||||
from main_rest.app.utils.parsing_utils import download_range
|
||||
from custom_apps.utils import cookie_manager
|
||||
|
||||
router = APIRouter(prefix="/services")
|
||||
@@ -144,6 +144,40 @@ async def imagen(request: Request, request_body_info: M.ImageGenerateReq):
|
||||
|
||||
return response.set_message(img_len=img_length)
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(error=e)
|
||||
|
||||
|
||||
@router.post("/vactorImageSearch/imageGenerate/imagen", summary="벡터 이미지 검색 - imagen", response_model=M.ResponseBase)
|
||||
async def vactor_image(request: Request, request_body_info: M.VactorImageSearchReq):
|
||||
"""
|
||||
## 벡터 이미지 검색 - imagen
|
||||
> imagen AI를 이용하여 이미지 생성 후 vactor 검색
|
||||
|
||||
### Requriements
|
||||
> - googlecli 설치(https://cloud.google.com/sdk/docs/install?hl=ko#linux)
|
||||
> - const.py 에 지정한 OUTPUT_FOLDER 하위에 imagen 폴더가 있어야함.
|
||||
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if request_body_info.index_type not in [M.IndexType.hnsw, M.IndexType.l2]:
|
||||
raise Exception(f"index_type is hnsw or l2 (current value = {request_body_info.index_type})")
|
||||
|
||||
img_path = imagen_generate_image_path(image_prompt=request_body_info.prompt)
|
||||
|
||||
vactor_request_data = {'quary_image_path' : img_path,'index_type' : request_body_info.index_type, 'search_num' : request_body_info.search_num}
|
||||
vactor_response = requests.post('http://localhost:51002/api/services/faiss/vactor/search', data=json.dumps(vactor_request_data))
|
||||
|
||||
if vactor_response.status_code != 200:
|
||||
raise Exception(f"response error: {json.loads(vactor_response.text)['error']}")
|
||||
|
||||
if json.loads(vactor_response.text)["error"] != None:
|
||||
raise Exception(f"vactor error: {json.loads(vactor_response.text)['error']}")
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(error=e)
|
||||
@@ -13,13 +13,13 @@ from fastapi import APIRouter
|
||||
from starlette.requests import Request
|
||||
import bcrypt
|
||||
|
||||
from rest.app.common import consts
|
||||
from rest.app import models as M
|
||||
from rest.app.common.config import conf
|
||||
from rest.app.database.schema import Users
|
||||
from rest.app.database.crud import table_select, table_update, table_delete
|
||||
from main_rest.app.common import consts
|
||||
from main_rest.app import models as M
|
||||
from main_rest.app.common.config import conf
|
||||
from main_rest.app.database.schema import Users
|
||||
from main_rest.app.database.crud import table_select, table_update, table_delete
|
||||
|
||||
from rest.app.utils.extra import AESCryptoCBC
|
||||
from main_rest.app.utils.extra import AESCryptoCBC
|
||||
|
||||
|
||||
router = APIRouter(prefix='/user')
|
||||
@@ -26,10 +26,10 @@ from itertools import groupby
|
||||
from operator import attrgetter
|
||||
import uuid
|
||||
|
||||
from rest.app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT
|
||||
from rest.app.utils.date_utils import D
|
||||
from rest.app import models as M
|
||||
from rest.app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY
|
||||
from main_rest.app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT
|
||||
from main_rest.app.utils.date_utils import D
|
||||
from main_rest.app import models as M
|
||||
from main_rest.app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY
|
||||
|
||||
|
||||
async def send_mail(sender, sender_pw, title, recipient, contents_plain, contents_html, cc_list, smtp_host=SMTP_HOST, smtp_port=SMTP_PORT):
|
||||
@@ -12,20 +12,26 @@ cryptography
|
||||
pycryptodomex
|
||||
pycryptodome
|
||||
email-validator
|
||||
requests
|
||||
|
||||
#imagen
|
||||
google-cloud-aiplatform
|
||||
Pillow
|
||||
|
||||
#bing img
|
||||
# #bing img
|
||||
aiohttp
|
||||
regex
|
||||
requests
|
||||
httpx
|
||||
nest_asyncio
|
||||
|
||||
#bing art
|
||||
bingart==1.10
|
||||
# #bing art
|
||||
bingart==1.1.0
|
||||
|
||||
#DALL-E 3
|
||||
# openai
|
||||
# openai
|
||||
|
||||
# faiss
|
||||
# scikit-learn
|
||||
# flax
|
||||
# tensorflow
|
||||
18
requirements_vactor.txt
Normal file
18
requirements_vactor.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
#rest
|
||||
fastapi==0.115.0
|
||||
uvicorn==0.16.0
|
||||
pymysql==1.0.2
|
||||
sqlalchemy==1.4.29
|
||||
bcrypt==3.2.0
|
||||
pyjwt==2.3.0
|
||||
yagmail==0.14.261
|
||||
boto3==1.20.32
|
||||
pytest==6.2.5
|
||||
cryptography
|
||||
pycryptodomex
|
||||
pycryptodome
|
||||
email-validator
|
||||
|
||||
#faiss
|
||||
scikit-learn
|
||||
pillow
|
||||
6
rest_main.py
Normal file
6
rest_main.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import uvicorn
|
||||
from main_rest.app.common.config import conf
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run('main_rest.app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)
|
||||
6
rest_vactor.py
Normal file
6
rest_vactor.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import uvicorn
|
||||
from vactor_rest.app.common.config import conf
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run('vactor_rest.app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)
|
||||
49
vactor_rest/.travis.yml
Normal file
49
vactor_rest/.travis.yml
Normal file
@@ -0,0 +1,49 @@
|
||||
os: linux
|
||||
dist: bionic
|
||||
language: python
|
||||
services:
|
||||
- docker
|
||||
- mysql
|
||||
python:
|
||||
- 3.8
|
||||
before_install:
|
||||
- pip install awscli
|
||||
- export PATH=$PATH:$HOME/.local/bin
|
||||
script:
|
||||
- pytest -v
|
||||
before_deploy:
|
||||
- if [ $TRAVIS_BRANCH == "master" ]; then export EB_ENV=notification-prd-api; fi
|
||||
- if [ $TRAVIS_BRANCH == "develop" ]; then export EB_ENV=notification-dev-api; fi
|
||||
- export REPO_NAME=$(echo $TRAVIS_REPO_SLUG | sed "s_^.*/__")
|
||||
- export ELASTIC_BEANSTALK_LABEL=${REPO_NAME}-${TRAVIS_COMMIT::7}-$(date +%y%m%d%H%M%S)
|
||||
deploy:
|
||||
skip_cleanup: true
|
||||
provider: elasticbeanstalk
|
||||
access_key_id: $AWS_ACCESS
|
||||
secret_access_key: $AWS_SECRET
|
||||
region: ap-northeast-2
|
||||
bucket: elasticbeanstalk-ap-northeast-2-884465654078
|
||||
bucket_path: notification-api
|
||||
app: notification-api
|
||||
env: $EB_ENV
|
||||
on:
|
||||
all_branches: true
|
||||
condition: $TRAVIS_BRANCH =~ ^develop|master
|
||||
|
||||
#notifications:
|
||||
# slack:
|
||||
# - rooms:
|
||||
# - secure: ***********
|
||||
# if: branch = master
|
||||
# template:
|
||||
# - "Repo `%{repository_slug}` *%{result}* build (<%{build_url}|#%{build_number}>) for commit (<%{compare_url}|%{commit}>) on branch `%{branch}`."
|
||||
# - rooms:
|
||||
# - secure: ***********
|
||||
# if: branch = staging
|
||||
# template:
|
||||
# - "Repo `%{repository_slug}` *%{result}* build (<%{build_url}|#%{build_number}>) for commit (<%{compare_url}|%{commit}>) on branch `%{branch}`."
|
||||
# - rooms:
|
||||
# - secure: ***********
|
||||
# if: branch = develop
|
||||
# template:
|
||||
# - "Repo `%{repository_slug}` *%{result}* build (<%{build_url}|#%{build_number}>) for commit (<%{compare_url}|%{commit}>) on branch `%{branch}`."
|
||||
46
vactor_rest/app/api_request_sample.py
Normal file
46
vactor_rest/app/api_request_sample.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: api_request_sample.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: test api key
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def parse_params_to_str(params):
|
||||
url = "?"
|
||||
for key, value in params.items():
|
||||
url = url + str(key) + '=' + str(value) + '&'
|
||||
return url[1:-1]
|
||||
|
||||
|
||||
def hash_string(qs, secret_key):
|
||||
mac = hmac.new(bytes(secret_key, encoding='utf8'), bytes(qs, encoding='utf-8'), digestmod='sha256')
|
||||
d = mac.digest()
|
||||
validating_secret = str(base64.b64encode(d).decode('utf-8'))
|
||||
return validating_secret
|
||||
|
||||
|
||||
def sample_request():
|
||||
access_key = 'c0883231-4aa9-4a1f-a77b-3ef250af-e449-42e9-856a-b3ada17c426b'
|
||||
secret_key = 'QhOaeXTAAkW6yWt31jWDeERkBsZ3X4UmPds656YD'
|
||||
cur_time = datetime.utcnow()+timedelta(hours=9)
|
||||
cur_timestamp = int(cur_time.timestamp())
|
||||
qs = dict(key=access_key, timestamp=cur_timestamp)
|
||||
header_secret = hash_string(parse_params_to_str(qs), secret_key)
|
||||
|
||||
url = f'http://127.0.0.1:8080/api/services?{parse_params_to_str(qs)}'
|
||||
res = requests.get(url, headers=dict(secret=header_secret))
|
||||
return res
|
||||
|
||||
|
||||
print(sample_request().json())
|
||||
111
vactor_rest/app/common/config.py
Normal file
111
vactor_rest/app/common/config.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: config.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: Configurations
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from os import path, environ
|
||||
|
||||
from vactor_rest.app.common import consts
|
||||
from vactor_rest.app.models import UserInfo
|
||||
|
||||
base_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__))))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""
|
||||
기본 Configuration
|
||||
"""
|
||||
BASE_DIR: str = base_dir
|
||||
DB_POOL_RECYCLE: int = 900
|
||||
DB_ECHO: bool = True
|
||||
DEBUG: bool = False
|
||||
TEST_MODE: bool = False
|
||||
DEV_TEST_CONNECT_ACCOUNT: str | None = None
|
||||
|
||||
# NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행)
|
||||
SERVICE_AUTH_API_KEY: bool = False
|
||||
|
||||
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}?charset={consts.DB_CHARSET}')
|
||||
REST_SERVER_PORT = consts.REST_SERVER_PORT
|
||||
|
||||
SW_TITLE = consts.SW_TITLE
|
||||
SW_VERSION = consts.SW_VERSION
|
||||
SW_DESCRIPTION = consts.SW_DESCRIPTION
|
||||
TERMS_OF_SERVICE = consts.TERMS_OF_SERVICE
|
||||
CONTEACT = consts.CONTEACT
|
||||
LICENSE_INFO = consts.LICENSE_INFO
|
||||
|
||||
GLOBAL_TOKEN = consts.ADMIN_INIT_ACCOUNT_INFO.connect_token
|
||||
|
||||
COOKIES_AUTH = 'Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MTQsImVtYWlsIjoia29hbGFAZGluZ3JyLmNvbSIsIm5hbWUiOm51bGwsInBob25lX251bWJlciI6bnVsbCwicHJvZmlsZV9pbWciOm51bGwsInNuc190eXBlIjpudWxsfQ.4vgrFvxgH8odoXMvV70BBqyqXOFa2NDQtzYkGywhV48'
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalConfig(Config):
|
||||
TRUSTED_HOSTS = ['*']
|
||||
ALLOW_SITE = ['*']
|
||||
DEBUG: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProdConfig(Config):
|
||||
TRUSTED_HOSTS = ['*']
|
||||
ALLOW_SITE = ['*']
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig(Config):
|
||||
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}/{consts.DB_NAME}_test?charset={consts.DB_CHARSET}')
|
||||
TRUSTED_HOSTS = ['*']
|
||||
ALLOW_SITE = ['*']
|
||||
TEST_MODE: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class DevConfig(Config):
|
||||
TRUSTED_HOSTS = ['*']
|
||||
ALLOW_SITE = ['*']
|
||||
DEBUG: bool = True
|
||||
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}/{consts.DB_NAME}_dev?charset={consts.DB_CHARSET}')
|
||||
REST_SERVER_PORT = consts.REST_SERVER_PORT + 1
|
||||
|
||||
SW_TITLE = '[Dev] ' + consts.SW_TITLE
|
||||
|
||||
|
||||
@dataclass
|
||||
class MyConfig(Config):
|
||||
TRUSTED_HOSTS = ['*']
|
||||
ALLOW_SITE = ['*']
|
||||
DEBUG: bool = True
|
||||
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}/{consts.DB_NAME}_my?charset={consts.DB_CHARSET}')
|
||||
REST_SERVER_PORT = consts.REST_SERVER_PORT + 2
|
||||
|
||||
# NOTE(hsj100): DEV_TEST_CONNECT_ACCOUNT
|
||||
# DEV_TEST_CONNECT_ACCOUNT: UserInfo = UserInfo(**consts.ADMIN_INIT_ACCOUNT_INFO)
|
||||
# DEV_TEST_CONNECT_ACCOUNT: str = None
|
||||
|
||||
# NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행)
|
||||
SERVICE_AUTH_API_KEY: bool = False
|
||||
|
||||
SW_TITLE = '[My] ' + consts.SW_TITLE
|
||||
|
||||
|
||||
def conf():
|
||||
"""
|
||||
환경 불러오기
|
||||
:return:
|
||||
"""
|
||||
config = dict(prod=ProdConfig, local=LocalConfig, test=TestConfig, dev=DevConfig, my=MyConfig)
|
||||
return config[environ.get('API_ENV', 'local')]()
|
||||
return config[environ.get('API_ENV', 'dev')]()
|
||||
return config[environ.get('API_ENV', 'my')]()
|
||||
return config[environ.get('API_ENV', 'test')]()
|
||||
|
||||
104
vactor_rest/app/common/consts.py
Normal file
104
vactor_rest/app/common/consts.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: consts.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: consts
|
||||
"""
|
||||
|
||||
# SUPPORT PROJECT
|
||||
SUPPORT_PROJECT_BASIC = 'PROJECT_BASIC'
|
||||
|
||||
PROJECT_NAME = 'FERMAT-TEST(Vactor REST API)'
|
||||
SW_TITLE= f'{PROJECT_NAME} - REST API'
|
||||
SW_VERSION = '0.1.0'
|
||||
SW_DESCRIPTION = f'''
|
||||
### FERMAT-TEST(Vactor REST API) REST API
|
||||
|
||||
## API 이용법
|
||||
- 개별 API 설명과 Request/Response schema 참조
|
||||
|
||||
|
||||
'''
|
||||
TERMS_OF_SERVICE = 'http://www.a2tec.co.kr'
|
||||
CONTEACT={
|
||||
'name': 'A2TEC (주)에이투텍',
|
||||
'url': 'http://www.a2tec.co.kr',
|
||||
'email': 'marketing@a2tec.co.kr'
|
||||
}
|
||||
LICENSE_INFO = {
|
||||
'name': 'Copyright by A2TEC', 'url': 'http://www.a2tec.co.kr'
|
||||
}
|
||||
|
||||
REST_SERVER_PORT = 51002
|
||||
DEFAULT_USER_ACCOUNT_PW = '1234'
|
||||
|
||||
|
||||
class AdminInfo:
|
||||
def __init__(self):
|
||||
self.id: int = 1
|
||||
self.user_type: str = 'admin'
|
||||
self.account: str = 'a2d2_lc_manager@naver.com' # !ekdnfeldpsdptm1 다울디엔에스
|
||||
self.pw: str = '$2b$12$PklBvVXdLhOQnIiNanlnIu.DJh5MspRARVChJQfFu1qg35vBoIuX2'
|
||||
self.name: str = 'administrator'
|
||||
self.email: str = 'a2d2_lc_manager@naver.com' # daool1020
|
||||
self.email_pw: str = 'gAAAAABioV5NucuS9nQugZJnz-KjVG_FGnaowB9KAfhOoWjjiQ4jGLuYJh4Qe94mT_lCm6m3HhuOJqUeOgjppwREDpIQYzrUXA=='
|
||||
self.address: str = '대구광역시 동구 동촌로351 에이스빌딩 4F'
|
||||
self.phone_number: str = '053-384-3010'
|
||||
self.connect_token: str = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MSwiYWNjb3VudCI6ImEyZDJfbGNfbWFuYWdlckBuYXZlci5jb20iLCJuYW1lIjoiYWRtaW5pc3RyYXRvciIsInBob25lX251bWJlciI6IjA1My0zODQtMzAxMCIsInByb2ZpbGVfaW1nIjpudWxsLCJhY2NvdW50X3R5cGUiOiJlbWFpbCJ9.SlQSCfAof1bv2YxmW2DO4dIBrbHLg1jPO3AJsX6xKbw'
|
||||
|
||||
def get_dict(self):
|
||||
info = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if type(v) is tuple:
|
||||
info[k] = v[0]
|
||||
else:
|
||||
info[k] = v
|
||||
return info
|
||||
|
||||
|
||||
ADMIN_INIT_ACCOUNT_INFO = AdminInfo()
|
||||
|
||||
FERNET_SECRET_KEY = b'wQjpSYkmc4kX8MaAovk1NIHF02R2wZX760eeBTeIHW4='
|
||||
AES_CBC_PUBLIC_KEY = b'daooldns12345678'
|
||||
AES_CBC_IV = b'daooldns12345678'
|
||||
|
||||
COOKIES_AUTH = 'Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MTQsImVtYWlsIjoia29hbGFAZGluZ3JyLmNvbSIsIm5hbWUiOm51bGwsInBob25lX251bWJlciI6bnVsbCwicHJvZmlsZV9pbWciOm51bGwsInNuc190eXBlIjpudWxsfQ.4vgrFvxgH8odoXMvV70BBqyqXOFa2NDQtzYkGywhV48'
|
||||
|
||||
JWT_SECRET = 'ABCD1234!'
|
||||
JWT_ALGORITHM = 'HS256'
|
||||
EXCEPT_PATH_LIST = ['/', '/openapi.json']
|
||||
EXCEPT_PATH_REGEX = '^(/docs|/redoc|/api/auth' +\
|
||||
'|/api/user/check_account_exist' +\
|
||||
'|/api/services' + \
|
||||
'|/api/temp' + \
|
||||
'|/api/dev' + \
|
||||
'|/static' + \
|
||||
')'
|
||||
MAX_API_KEY = 3
|
||||
MAX_API_WHITELIST = 10
|
||||
|
||||
NUM_RETRY_UUID_GEN = 3
|
||||
|
||||
# DATABASE
|
||||
DB_ADDRESS = "localhost"
|
||||
DB_PORT = 53306
|
||||
DB_USER_ID = 'root'
|
||||
DB_USER_PW = '1234'
|
||||
DB_NAME = 'FM_TEST'
|
||||
DB_CHARSET = 'utf8mb4'
|
||||
|
||||
# MAIL
|
||||
# SMTP_HOST = 'smtp.gmail.com'
|
||||
# SMTP_PORT = 587
|
||||
SMTP_HOST = 'smtp.naver.com'
|
||||
SMTP_PORT = 587
|
||||
MAIL_REG_TITLE = f'{PROJECT_NAME} - Registration'
|
||||
MAIL_REG_CONTENTS = '''
|
||||
안녕하세요.
|
||||
감사합니다.
|
||||
|
||||
'''
|
||||
142
vactor_rest/app/database/conn.py
Normal file
142
vactor_rest/app/database/conn.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: conn.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: DB Connections
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
import logging
|
||||
|
||||
|
||||
def _database_exist(engine, schema_name):
|
||||
query = f'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = "{schema_name}"'
|
||||
with engine.connect() as conn:
|
||||
result_proxy = conn.execute(query)
|
||||
result = result_proxy.scalar()
|
||||
return bool(result)
|
||||
|
||||
|
||||
def _drop_database(engine, schema_name):
|
||||
with engine.connect() as conn:
|
||||
conn.execute(f'DROP DATABASE {schema_name};')
|
||||
|
||||
|
||||
def _create_database(engine, schema_name):
|
||||
with engine.connect() as conn:
|
||||
conn.execute(f'CREATE DATABASE {schema_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_bin;')
|
||||
|
||||
|
||||
class SQLAlchemy:
|
||||
def __init__(self, app: FastAPI = None, **kwargs):
|
||||
self._engine = None
|
||||
self._session = None
|
||||
if app is not None:
|
||||
self.init_app(app=app, **kwargs)
|
||||
|
||||
def init_app(self, app: FastAPI, **kwargs):
|
||||
"""
|
||||
DB 초기화 함수
|
||||
:param app: FastAPI 인스턴스
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
database_url = kwargs.get('DB_URL')
|
||||
pool_recycle = kwargs.setdefault('DB_POOL_RECYCLE', 900)
|
||||
is_testing = kwargs.setdefault('TEST_MODE', False)
|
||||
echo = kwargs.setdefault('DB_ECHO', True)
|
||||
|
||||
self._engine = create_engine(
|
||||
database_url,
|
||||
echo=echo,
|
||||
pool_recycle=pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
if is_testing: # create schema
|
||||
db_url = self._engine.url
|
||||
if db_url.host != 'localhost':
|
||||
raise Exception('db host must be \'localhost\' in test environment')
|
||||
except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
|
||||
schema_name = db_url.database
|
||||
temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
|
||||
if _database_exist(temp_engine, schema_name):
|
||||
_drop_database(temp_engine, schema_name)
|
||||
_create_database(temp_engine, schema_name)
|
||||
temp_engine.dispose()
|
||||
else:
|
||||
db_url = self._engine.url
|
||||
except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
|
||||
schema_name = db_url.database
|
||||
temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
|
||||
if not _database_exist(temp_engine, schema_name):
|
||||
_create_database(temp_engine, schema_name)
|
||||
Base.metadata.create_all(db.engine)
|
||||
temp_engine.dispose()
|
||||
|
||||
self._session = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
|
||||
|
||||
# NOTE(hsj100): ADMINISTRATOR
|
||||
create_admin(self._session)
|
||||
|
||||
@app.on_event('startup')
|
||||
def startup():
|
||||
self._engine.connect()
|
||||
logging.info('DB connected.')
|
||||
|
||||
@app.on_event('shutdown')
|
||||
def shutdown():
|
||||
self._session.close_all()
|
||||
self._engine.dispose()
|
||||
logging.info('DB disconnected')
|
||||
|
||||
def get_db(self):
|
||||
"""
|
||||
요청마다 DB 세션 유지 함수
|
||||
:return:
|
||||
"""
|
||||
if self._session is None:
|
||||
raise Exception('must be called \'init_app\'')
|
||||
db_session = None
|
||||
try:
|
||||
db_session = self._session()
|
||||
yield db_session
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.get_db
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
return self._engine
|
||||
|
||||
|
||||
db = SQLAlchemy()
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# NOTE(hsj100): ADMINISTRATOR
|
||||
def create_admin(db_session):
|
||||
import bcrypt
|
||||
from vactor_rest.app.database.schema import Users
|
||||
from vactor_rest.app.common.consts import ADMIN_INIT_ACCOUNT_INFO
|
||||
|
||||
session = db_session()
|
||||
|
||||
if not session:
|
||||
raise Exception('cat`t create account of admin')
|
||||
|
||||
if Users.get(account=ADMIN_INIT_ACCOUNT_INFO.account):
|
||||
return
|
||||
|
||||
admin = {**ADMIN_INIT_ACCOUNT_INFO.get_dict()}
|
||||
|
||||
Users.create(session=session, auto_commit=True, **admin)
|
||||
300
vactor_rest/app/database/crud.py
Normal file
300
vactor_rest/app/database/crud.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: crud.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: CRUD
|
||||
"""
|
||||
|
||||
import math
|
||||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from sqlalchemy import func, desc
|
||||
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from sqlalchemy.orm import Session
|
||||
from vactor_rest.app import models as M
|
||||
from vactor_rest.app.database.conn import Base, db
|
||||
from vactor_rest.app.database.schema import Users, UserLog
|
||||
from vactor_rest.app.utils.extra import query_to_groupby, query_to_groupby_date
|
||||
|
||||
|
||||
def get_month_info_list(start: datetime, end: datetime):
|
||||
delta = relativedelta(end, start)
|
||||
month_delta = 12 * delta.years + delta.months + (1 if delta.days > 0 else 0)
|
||||
|
||||
month_list = list()
|
||||
for i in range(month_delta):
|
||||
count = start + relativedelta(months=i)
|
||||
month_info = dict()
|
||||
month_info['period_title'] = count.strftime('%Y-%m')
|
||||
month_info['year'] = count.year
|
||||
month_info['month'] = count.month
|
||||
month_info['start'] = (start + relativedelta(months=i)).replace(day=1)
|
||||
month_info['end'] = (start + relativedelta(months=i + 1)).replace(day=1)
|
||||
month_list.append(month_info)
|
||||
return month_list
|
||||
|
||||
|
||||
def request_parser(request_data: dict = None) -> dict:
|
||||
"""
|
||||
request information -> dict
|
||||
|
||||
:param request_data:
|
||||
:return:
|
||||
"""
|
||||
result_dict = dict()
|
||||
if not request_data:
|
||||
return result_dict
|
||||
for key, val in request_data.items():
|
||||
if val is not None:
|
||||
result_dict[key] = val if val != 'null' else None
|
||||
return result_dict
|
||||
|
||||
|
||||
def dict_to_filter_fmt(dict_data, get_attributes_callback=None):
|
||||
"""
|
||||
dict -> sqlalchemy filter (criterion: sql expression)
|
||||
|
||||
:param dict_data:
|
||||
:param get_attributes_callback:
|
||||
:return:
|
||||
"""
|
||||
if get_attributes_callback is None:
|
||||
raise Exception('invalid get_attributes_callback')
|
||||
|
||||
criterion = list()
|
||||
|
||||
for key, val in dict_data.items():
|
||||
key = key.split('__')
|
||||
if len(key) > 2:
|
||||
raise Exception('length of split(key) should be no more than 2.')
|
||||
|
||||
key_length = len(key)
|
||||
col = get_attributes_callback(key[0])
|
||||
|
||||
if col is None:
|
||||
continue
|
||||
|
||||
if key_length == 1:
|
||||
criterion.append((col == val))
|
||||
elif key_length == 2 and key[1] == 'gt':
|
||||
criterion.append((col > val))
|
||||
elif key_length == 2 and key[1] == 'gte':
|
||||
criterion.append((col >= val))
|
||||
elif key_length == 2 and key[1] == 'lt':
|
||||
criterion.append((col < val))
|
||||
elif key_length == 2 and key[1] == 'lte':
|
||||
criterion.append((col <= val))
|
||||
elif key_length == 2 and key[1] == 'in':
|
||||
criterion.append((col.in_(val)))
|
||||
elif key_length == 2 and key[1] == 'like':
|
||||
criterion.append((col.like(val)))
|
||||
|
||||
return criterion
|
||||
|
||||
|
||||
async def table_select(accessor_info, target_table, request_body_info, response_model, response_model_data):
|
||||
"""
|
||||
table_read
|
||||
"""
|
||||
try:
|
||||
# parameter
|
||||
if not accessor_info:
|
||||
raise Exception('invalid accessor')
|
||||
|
||||
if not target_table:
|
||||
raise Exception(f'invalid table_name:{target_table}')
|
||||
|
||||
# if not request_body_info:
|
||||
# raise Exception('invalid request_body_info')
|
||||
|
||||
if not response_model:
|
||||
raise Exception('invalid response_model')
|
||||
|
||||
if not response_model_data:
|
||||
raise Exception('invalid response_model_data')
|
||||
|
||||
# paging - request
|
||||
paging_request = None
|
||||
if request_body_info:
|
||||
if hasattr(request_body_info, 'paging'):
|
||||
paging_request = request_body_info.paging
|
||||
request_body_info = request_body_info.search
|
||||
|
||||
# request
|
||||
request_info = request_parser(request_body_info.dict())
|
||||
|
||||
# search
|
||||
criterion = None
|
||||
if isinstance(request_body_info, M.UserLogDaySearchReq):
|
||||
# UserLog search
|
||||
# request
|
||||
request_info = request_parser(request_body_info.dict())
|
||||
|
||||
# search UserLog
|
||||
def get_attributes_callback(key: str):
|
||||
return getattr(UserLog, key)
|
||||
criterion = dict_to_filter_fmt(request_info, get_attributes_callback)
|
||||
|
||||
# search
|
||||
session = next(db.session())
|
||||
search_info = session.query(UserLog)\
|
||||
.filter(UserLog.mac.isnot(None)
|
||||
, UserLog.api == '/api/auth/login'
|
||||
, UserLog.type == M.UserLogMessageType.info
|
||||
, UserLog.message == 'ok'
|
||||
, *criterion)\
|
||||
.order_by(desc(UserLog.created_at), desc(UserLog.updated_at))\
|
||||
.all()
|
||||
if not search_info:
|
||||
raise Exception('not found data')
|
||||
|
||||
group_by_day = query_to_groupby_date(search_info, 'updated_at')
|
||||
|
||||
# result
|
||||
result_info = list()
|
||||
for day, info_list in group_by_day.items():
|
||||
info_by_mac = query_to_groupby(info_list, 'mac', first=True)
|
||||
for log_info in info_by_mac.values():
|
||||
result_info.append(response_model_data.from_orm(log_info))
|
||||
else:
|
||||
# basic search (single table)
|
||||
# request
|
||||
request_info = request_parser(request_body_info.dict())
|
||||
|
||||
# search
|
||||
search_info = target_table.filter(**request_info).all()
|
||||
if not search_info:
|
||||
raise Exception('not found data')
|
||||
|
||||
# result
|
||||
result_info = list()
|
||||
for purchase_info in search_info:
|
||||
result_info.append(response_model_data.from_orm(purchase_info))
|
||||
|
||||
# response - paging
|
||||
paging_response = None
|
||||
if paging_request:
|
||||
total_contents_num = len(result_info)
|
||||
total_page_num = math.ceil(total_contents_num / paging_request.page_contents_num)
|
||||
start_contents_index = (paging_request.start_page - 1) * paging_request.page_contents_num
|
||||
end_contents_index = start_contents_index + paging_request.page_contents_num
|
||||
if end_contents_index < total_contents_num:
|
||||
result_info = result_info[start_contents_index: end_contents_index]
|
||||
else:
|
||||
result_info = result_info[start_contents_index:]
|
||||
|
||||
paging_response = M.PagingRes()
|
||||
paging_response.total_page_num = total_page_num
|
||||
paging_response.total_contents_num = total_contents_num
|
||||
paging_response.start_page = paging_request.start_page
|
||||
paging_response.search_contents_num = len(result_info)
|
||||
|
||||
return response_model(data=result_info, paging=paging_response)
|
||||
except Exception as e:
|
||||
return response_model.set_error(str(e))
|
||||
|
||||
|
||||
async def table_update(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
|
||||
search_info = None
|
||||
|
||||
try:
|
||||
|
||||
# parameter
|
||||
if not accessor_info:
|
||||
raise Exception('invalid accessor')
|
||||
|
||||
if not target_table:
|
||||
raise Exception(f'invalid table_name:{target_table}')
|
||||
|
||||
if not request_body_info:
|
||||
raise Exception('invalid request_body_info')
|
||||
|
||||
if not response_model:
|
||||
raise Exception('invalid response_model')
|
||||
|
||||
# if not response_model_data:
|
||||
# raise Exception('invalid response_model_data')
|
||||
|
||||
# request
|
||||
if not request_body_info.search_info:
|
||||
raise Exception('invalid request_body: search_info')
|
||||
|
||||
request_search_info = request_parser(request_body_info.search_info.dict())
|
||||
if not request_search_info:
|
||||
raise Exception('invalid request_body: search_info')
|
||||
|
||||
if not request_body_info.update_info:
|
||||
raise Exception('invalid request_body: update_info')
|
||||
|
||||
request_update_info = request_parser(request_body_info.update_info.dict())
|
||||
if not request_update_info:
|
||||
raise Exception('invalid request_body: update_info')
|
||||
|
||||
# search
|
||||
search_info = target_table.filter(**request_search_info)
|
||||
|
||||
# process
|
||||
search_info.update(auto_commit=True, synchronize_session=False, **request_update_info)
|
||||
|
||||
# result
|
||||
return response_model()
|
||||
except Exception as e:
|
||||
if search_info:
|
||||
search_info.close()
|
||||
return response_model.set_error(str(e))
|
||||
|
||||
|
||||
async def table_delete(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
|
||||
search_info = None
|
||||
|
||||
try:
|
||||
# request
|
||||
if not accessor_info:
|
||||
raise Exception('invalid accessor')
|
||||
|
||||
if not target_table:
|
||||
raise Exception(f'invalid table_name:{target_table}')
|
||||
|
||||
if not request_body_info:
|
||||
raise Exception('invalid request_body_info')
|
||||
|
||||
if not response_model:
|
||||
raise Exception('invalid response_model')
|
||||
|
||||
# if not response_model_data:
|
||||
# raise Exception('invalid response_model_data')
|
||||
|
||||
# request
|
||||
request_search_info = request_parser(request_body_info.dict())
|
||||
if not request_search_info:
|
||||
raise Exception('invalid request_body')
|
||||
|
||||
# search
|
||||
search_info = target_table.filter(**request_search_info)
|
||||
temp_search = search_info.all()
|
||||
|
||||
# process
|
||||
search_info.delete(auto_commit=True, synchronize_session=False)
|
||||
|
||||
# update license num
|
||||
uuid_list = list()
|
||||
for _license in temp_search:
|
||||
if not hasattr(temp_search, 'uuid'):
|
||||
# case: license
|
||||
break
|
||||
if _license.uuid not in uuid_list:
|
||||
uuid_list.append(_license.uuid)
|
||||
license_num = target_table.filter(uuid=_license.uuid).count()
|
||||
target_table.filter(uuid=_license.uuid).update(auto_commit=True, synchronize_session=False, num=license_num)
|
||||
|
||||
# result
|
||||
return response_model()
|
||||
except Exception as e:
|
||||
if search_info:
|
||||
search_info.close()
|
||||
return response_model.set_error(str(e))
|
||||
304
vactor_rest/app/database/schema.py
Normal file
304
vactor_rest/app/database/schema.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: schema.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: database schema
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
String,
|
||||
DateTime,
|
||||
func,
|
||||
Enum,
|
||||
Boolean,
|
||||
ForeignKey,
|
||||
)
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
|
||||
from vactor_rest.app.database.conn import Base, db
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
from vactor_rest.app.models import (
|
||||
SexType,
|
||||
UserType,
|
||||
MemberType,
|
||||
AccountType,
|
||||
UserStatusType,
|
||||
UserLoginType,
|
||||
UserLogMessageType
|
||||
)
|
||||
|
||||
|
||||
class BaseMixin:
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
created_at = Column(DateTime, nullable=False, default=D.datetime())
|
||||
updated_at = Column(DateTime, nullable=False, default=D.datetime(), onupdate=D.datetime())
|
||||
|
||||
def __init__(self):
|
||||
self._q = None
|
||||
self._session = None
|
||||
self.served = None
|
||||
|
||||
def all_columns(self):
|
||||
return [c for c in self.__table__.columns if c.primary_key is False and c.name != 'created_at']
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
@classmethod
|
||||
def create(cls, session: Session, auto_commit=False, **kwargs):
|
||||
"""
|
||||
테이블 데이터 적재 전용 함수
|
||||
:param session:
|
||||
:param auto_commit: 자동 커밋 여부
|
||||
:param kwargs: 적재 할 데이터
|
||||
:return:
|
||||
"""
|
||||
obj = cls()
|
||||
|
||||
# NOTE(hsj100) : FIX_DATETIME
|
||||
if 'created_at' not in kwargs:
|
||||
obj.created_at = D.datetime()
|
||||
if 'updated_at' not in kwargs:
|
||||
obj.updated_at = D.datetime()
|
||||
|
||||
for col in obj.all_columns():
|
||||
col_name = col.name
|
||||
if col_name in kwargs:
|
||||
setattr(obj, col_name, kwargs.get(col_name))
|
||||
session.add(obj)
|
||||
session.flush()
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def get(cls, session: Session = None, **kwargs):
|
||||
"""
|
||||
Simply get a Row
|
||||
:param session:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
sess = next(db.session()) if not session else session
|
||||
query = sess.query(cls)
|
||||
for key, val in kwargs.items():
|
||||
col = getattr(cls, key)
|
||||
query = query.filter(col == val)
|
||||
|
||||
if query.count() > 1:
|
||||
raise Exception('Only one row is supposed to be returned, but got more than one.')
|
||||
result = query.first()
|
||||
if not session:
|
||||
sess.close()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def filter(cls, session: Session = None, **kwargs):
|
||||
"""
|
||||
Simply get a Row
|
||||
:param session:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
cond = []
|
||||
for key, val in kwargs.items():
|
||||
key = key.split('__')
|
||||
if len(key) > 2:
|
||||
raise Exception('length of split(key) should be no more than 2.')
|
||||
col = getattr(cls, key[0])
|
||||
if len(key) == 1: cond.append((col == val))
|
||||
elif len(key) == 2 and key[1] == 'gt': cond.append((col > val))
|
||||
elif len(key) == 2 and key[1] == 'gte': cond.append((col >= val))
|
||||
elif len(key) == 2 and key[1] == 'lt': cond.append((col < val))
|
||||
elif len(key) == 2 and key[1] == 'lte': cond.append((col <= val))
|
||||
elif len(key) == 2 and key[1] == 'in': cond.append((col.in_(val)))
|
||||
elif len(key) == 2 and key[1] == 'like': cond.append((col.like(val)))
|
||||
|
||||
obj = cls()
|
||||
if session:
|
||||
obj._session = session
|
||||
obj.served = True
|
||||
else:
|
||||
obj._session = next(db.session())
|
||||
obj.served = False
|
||||
query = obj._session.query(cls)
|
||||
query = query.filter(*cond)
|
||||
obj._q = query
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def cls_attr(cls, col_name=None):
|
||||
if col_name:
|
||||
col = getattr(cls, col_name)
|
||||
return col
|
||||
else:
|
||||
return cls
|
||||
|
||||
def order_by(self, *args: str):
|
||||
for a in args:
|
||||
if a.startswith('-'):
|
||||
col_name = a[1:]
|
||||
is_asc = False
|
||||
else:
|
||||
col_name = a
|
||||
is_asc = True
|
||||
col = self.cls_attr(col_name)
|
||||
self._q = self._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc())
|
||||
return self
|
||||
|
||||
def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
|
||||
# NOTE(hsj100) : FIX_DATETIME
|
||||
if 'updated_at' not in kwargs:
|
||||
kwargs['updated_at'] = D.datetime()
|
||||
|
||||
qs = self._q.update(kwargs, synchronize_session=synchronize_session)
|
||||
get_id = self.id
|
||||
ret = None
|
||||
|
||||
self._session.flush()
|
||||
if qs > 0 :
|
||||
ret = self._q.first()
|
||||
if auto_commit:
|
||||
self._session.commit()
|
||||
return ret
|
||||
|
||||
def first(self):
|
||||
result = self._q.first()
|
||||
self.close()
|
||||
return result
|
||||
|
||||
def delete(self, auto_commit: bool = False, synchronize_session='evaluate'):
|
||||
self._q.delete(synchronize_session=synchronize_session)
|
||||
if auto_commit:
|
||||
self._session.commit()
|
||||
|
||||
def all(self):
|
||||
print(self.served)
|
||||
result = self._q.all()
|
||||
self.close()
|
||||
return result
|
||||
|
||||
def count(self):
|
||||
result = self._q.count()
|
||||
self.close()
|
||||
return result
|
||||
|
||||
def close(self):
|
||||
if not self.served:
|
||||
self._session.close()
|
||||
else:
|
||||
self._session.flush()
|
||||
|
||||
|
||||
class ApiKeys(Base, BaseMixin):
|
||||
__tablename__ = 'api_keys'
|
||||
__table_args__ = {'extend_existing': True}
|
||||
|
||||
access_key = Column(String(length=64), nullable=False, index=True)
|
||||
secret_key = Column(String(length=64), nullable=False)
|
||||
user_memo = Column(String(length=40), nullable=True)
|
||||
status = Column(Enum('active', 'stopped', 'deleted'), default='active')
|
||||
is_whitelisted = Column(Boolean, default=False)
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||
whitelist = relationship('ApiWhiteLists', backref='api_keys')
|
||||
users = relationship('Users', back_populates='keys')
|
||||
|
||||
|
||||
class ApiWhiteLists(Base, BaseMixin):
|
||||
__tablename__ = 'api_whitelists'
|
||||
__table_args__ = {'extend_existing': True}
|
||||
|
||||
ip_addr = Column(String(length=64), nullable=False)
|
||||
api_key_id = Column(Integer, ForeignKey('api_keys.id'), nullable=False)
|
||||
|
||||
|
||||
class Users(Base, BaseMixin):
|
||||
__tablename__ = 'users'
|
||||
__table_args__ = {'extend_existing': True}
|
||||
|
||||
status = Column(Enum(UserStatusType), nullable=False, default=UserStatusType.active)
|
||||
user_type = Column(Enum(UserType), nullable=False, default=UserType.user)
|
||||
account_type = Column(Enum(AccountType), nullable=False, default=AccountType.email)
|
||||
account = Column(String(length=255), nullable=False, unique=True)
|
||||
pw = Column(String(length=2000), nullable=True)
|
||||
email = Column(String(length=255), nullable=True)
|
||||
name = Column(String(length=255), nullable=False)
|
||||
sex = Column(Enum(SexType), nullable=False, default=SexType.male)
|
||||
rrn = Column(String(length=255), nullable=True)
|
||||
address = Column(String(length=2000), nullable=True)
|
||||
phone_number = Column(String(length=20), nullable=True)
|
||||
picture = Column(String(length=1000), nullable=True)
|
||||
marketing_agree = Column(Boolean, nullable=False, default=False)
|
||||
keys = relationship('ApiKeys', back_populates='users')
|
||||
|
||||
# extra
|
||||
login = Column(Enum(UserLoginType), nullable=False, default=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
|
||||
member_type = Column(Enum(MemberType), nullable=False, default=MemberType.personal)
|
||||
# uuid = Column(String(length=36), nullable=True, unique=True)
|
||||
|
||||
|
||||
class UserLog(Base, BaseMixin):
|
||||
__tablename__ = 'userlog'
|
||||
__table_args__ = {'extend_existing': True}
|
||||
|
||||
account = Column(String(length=255), nullable=False)
|
||||
mac = Column(String(length=255), nullable=True)
|
||||
type = Column(Enum(UserLogMessageType), nullable=False, default=UserLogMessageType.info)
|
||||
api = Column(String(length=511), nullable=False)
|
||||
message = Column(String(length=5000), nullable=True)
|
||||
# NOTE(hsj100): 다단계 자동 삭제
|
||||
# user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'), nullable=False)
|
||||
# user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||
|
||||
# users = relationship('Users', back_populates='userlog')
|
||||
# users = relationship('Users', back_populates='userlog', lazy=False)
|
||||
|
||||
|
||||
# ===============================================================================================
|
||||
class CustomBaseMixin(BaseMixin):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
def all_columns(self):
|
||||
return [c for c in self.__table__.columns if c.primary_key is False]
|
||||
|
||||
|
||||
@classmethod
|
||||
def create(cls, session: Session, auto_commit=False, **kwargs):
|
||||
"""
|
||||
테이블 데이터 적재 전용 함수
|
||||
:param session:
|
||||
:param auto_commit: 자동 커밋 여부
|
||||
:param kwargs: 적재 할 데이터
|
||||
:return:
|
||||
"""
|
||||
obj = cls()
|
||||
|
||||
for col in obj.all_columns():
|
||||
col_name = col.name
|
||||
if col_name in kwargs:
|
||||
setattr(obj, col_name, kwargs.get(col_name))
|
||||
session.add(obj)
|
||||
session.flush()
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
return obj
|
||||
|
||||
def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
|
||||
|
||||
qs = self._q.update(kwargs, synchronize_session=synchronize_session)
|
||||
get_id = self.id
|
||||
ret = None
|
||||
|
||||
self._session.flush()
|
||||
if qs > 0 :
|
||||
ret = self._q.first()
|
||||
if auto_commit:
|
||||
self._session.commit()
|
||||
return ret
|
||||
|
||||
188
vactor_rest/app/errors/exceptions.py
Normal file
188
vactor_rest/app/errors/exceptions.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from vactor_rest.app.common.consts import MAX_API_KEY, MAX_API_WHITELIST
|
||||
|
||||
|
||||
class StatusCode:
|
||||
HTTP_500 = 500
|
||||
HTTP_400 = 400
|
||||
HTTP_401 = 401
|
||||
HTTP_403 = 403
|
||||
HTTP_404 = 404
|
||||
HTTP_405 = 405
|
||||
|
||||
|
||||
class APIException(Exception):
|
||||
status_code: int
|
||||
code: str
|
||||
msg: str
|
||||
detail: str
|
||||
ex: Exception
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
status_code: int = StatusCode.HTTP_500,
|
||||
code: str = '000000',
|
||||
msg: str | None = None,
|
||||
detail: str | None = None,
|
||||
ex: Exception = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.code = code
|
||||
self.msg = msg
|
||||
self.detail = detail
|
||||
self.ex = ex
|
||||
super().__init__(ex)
|
||||
|
||||
|
||||
class NotFoundUserEx(APIException):
|
||||
def __init__(self, user_id: int = None, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_404,
|
||||
msg=f'해당 유저를 찾을 수 없습니다.',
|
||||
detail=f'Not Found User ID : {user_id}',
|
||||
code=f'{StatusCode.HTTP_400}{"1".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class NotAuthorized(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_401,
|
||||
msg=f'로그인이 필요한 서비스 입니다.',
|
||||
detail='Authorization Required',
|
||||
code=f'{StatusCode.HTTP_401}{"1".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class TokenExpiredEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'세션이 만료되어 로그아웃 되었습니다.',
|
||||
detail='Token Expired',
|
||||
code=f'{StatusCode.HTTP_400}{"1".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class TokenDecodeEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'비정상적인 접근입니다.',
|
||||
detail='Token has been compromised.',
|
||||
code=f'{StatusCode.HTTP_400}{"2".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class NoKeyMatchEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_404,
|
||||
msg=f'해당 키에 대한 권한이 없거나 해당 키가 없습니다.',
|
||||
detail='No Keys Matched',
|
||||
code=f'{StatusCode.HTTP_404}{"3".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class MaxKeyCountEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'API 키 생성은 {MAX_API_KEY}개 까지 가능합니다.',
|
||||
detail='Max Key Count Reached',
|
||||
code=f'{StatusCode.HTTP_400}{"4".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class MaxWLCountEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'화이트리스트 생성은 {MAX_API_WHITELIST}개 까지 가능합니다.',
|
||||
detail='Max Whitelist Count Reached',
|
||||
code=f'{StatusCode.HTTP_400}{"5".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class InvalidIpEx(APIException):
|
||||
def __init__(self, ip: str, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'{ip}는 올바른 IP 가 아닙니다.',
|
||||
detail=f'invalid IP : {ip}',
|
||||
code=f'{StatusCode.HTTP_400}{"6".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class SqlFailureEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_500,
|
||||
msg=f'이 에러는 서버측 에러 입니다. 자동으로 리포팅 되며, 빠르게 수정하겠습니다.',
|
||||
detail='Internal Server Error',
|
||||
code=f'{StatusCode.HTTP_500}{"2".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class APIQueryStringEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'쿼리스트링은 key, timestamp 2개만 허용되며, 2개 모두 요청시 제출되어야 합니다.',
|
||||
detail='Query String Only Accept key and timestamp.',
|
||||
code=f'{StatusCode.HTTP_400}{"7".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class APIHeaderInvalidEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'헤더에 키 해싱된 Secret 이 없거나, 유효하지 않습니다.',
|
||||
detail='Invalid HMAC secret in Header',
|
||||
code=f'{StatusCode.HTTP_400}{"8".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class APITimestampEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'쿼리스트링에 포함된 타임스탬프는 KST 이며, 현재 시간보다 작아야 하고, 현재시간 - 10초 보다는 커야 합니다.',
|
||||
detail='timestamp in Query String must be KST, Timestamp must be less than now, and greater than now - 10.',
|
||||
code=f'{StatusCode.HTTP_400}{"9".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class NotFoundAccessKeyEx(APIException):
|
||||
def __init__(self, api_key: str, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_404,
|
||||
msg=f'API 키를 찾을 수 없습니다.',
|
||||
detail=f'Not found such API Access Key : {api_key}',
|
||||
code=f'{StatusCode.HTTP_404}{"10".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
|
||||
|
||||
class KakaoSendFailureEx(APIException):
|
||||
def __init__(self, ex: Exception = None):
|
||||
super().__init__(
|
||||
status_code=StatusCode.HTTP_400,
|
||||
msg=f'카카오톡 전송에 실패했습니다.',
|
||||
detail=f'Failed to send KAKAO MSG.',
|
||||
code=f'{StatusCode.HTTP_400}{"11".zfill(4)}',
|
||||
ex=ex,
|
||||
)
|
||||
116
vactor_rest/app/main.py
Normal file
116
vactor_rest/app/main.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: main.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: Main
|
||||
"""
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from vactor_rest.app.common import consts
|
||||
|
||||
from vactor_rest.app.database.conn import db
|
||||
from vactor_rest.app.common.config import conf
|
||||
from vactor_rest.app.middlewares.token_validator import access_control
|
||||
from vactor_rest.app.middlewares.trusted_hosts import TrustedHostMiddleware
|
||||
from vactor_rest.app.routes import dev, index, auth, users, services
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from custom_logger.vactor_log import vactor_logger as LOG
|
||||
|
||||
|
||||
API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# When service starts.
|
||||
LOG.info("REST start")
|
||||
|
||||
yield
|
||||
|
||||
# When service is stopped.
|
||||
LOG.info("REST shutdown")
|
||||
|
||||
|
||||
def create_app():
|
||||
"""
|
||||
fast_api_app 생성
|
||||
|
||||
:return: fast_api_app
|
||||
"""
|
||||
configurations = conf()
|
||||
fast_api_app = FastAPI(
|
||||
title=configurations.SW_TITLE,
|
||||
version=configurations.SW_VERSION,
|
||||
description=configurations.SW_DESCRIPTION,
|
||||
terms_of_service=configurations.TERMS_OF_SERVICE,
|
||||
contact=configurations.CONTEACT,
|
||||
license_info={'name': 'Copyright by A2TEC', 'url': 'http://www.a2tec.co.kr'},
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 데이터 베이스 이니셜라이즈
|
||||
conf_dict = asdict(configurations)
|
||||
db.init_app(fast_api_app, **conf_dict)
|
||||
|
||||
# 레디스 이니셜라이즈
|
||||
|
||||
# 미들웨어 정의
|
||||
fast_api_app.add_middleware(middleware_class=BaseHTTPMiddleware, dispatch=access_control)
|
||||
fast_api_app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=conf().ALLOW_SITE,
|
||||
allow_credentials=True,
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
fast_api_app.add_middleware(TrustedHostMiddleware, allowed_hosts=conf().TRUSTED_HOSTS, except_path=['/health'])
|
||||
|
||||
# 라우터 정의
|
||||
fast_api_app.include_router(index.router, tags=['Defaults'])
|
||||
|
||||
# if conf().DEBUG:
|
||||
# fast_api_app.include_router(auth.router, tags=['Authentication'], prefix='/api')
|
||||
|
||||
fast_api_app.include_router(services.router, tags=['Services'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
|
||||
|
||||
# if conf().DEBUG:
|
||||
# fast_api_app.include_router(users.router, tags=['Users'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
|
||||
# fast_api_app.include_router(dev.router, tags=['Developments'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
|
||||
|
||||
import os
|
||||
# fast_api_app.mount('/static', StaticFiles(directory=os.path.abspath('./rest/app/static')), name="static")
|
||||
|
||||
return fast_api_app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
#TODO(jwkim): 422 error handler
|
||||
# @app.exception_handler(RequestValidationError)
|
||||
# async def validation_exception_handler(request, exc):
|
||||
# err = exc.errors()
|
||||
# return JSONResponse(
|
||||
# status_code=422,
|
||||
# content={
|
||||
# "result": False,
|
||||
# "error": f"VALIDATION ERROR : {err}",
|
||||
# "data": None},
|
||||
# )
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run('main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)
|
||||
166
vactor_rest/app/middlewares/token_validator.py
Normal file
166
vactor_rest/app/middlewares/token_validator.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import base64
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
import typing
|
||||
import re
|
||||
|
||||
import jwt
|
||||
import sqlalchemy.exc
|
||||
|
||||
from jwt.exceptions import ExpiredSignatureError, DecodeError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from vactor_rest.app.common.consts import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX
|
||||
from vactor_rest.app.database.conn import db
|
||||
from vactor_rest.app.database.schema import Users, ApiKeys
|
||||
from vactor_rest.app.errors import exceptions as ex
|
||||
|
||||
from vactor_rest.app.common import consts
|
||||
from vactor_rest.app.common.config import conf
|
||||
from vactor_rest.app.errors.exceptions import APIException, SqlFailureEx, APIQueryStringEx
|
||||
from vactor_rest.app.models import UserToken
|
||||
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
from vactor_rest.app.utils.logger import api_logger
|
||||
from vactor_rest.app.utils.query_utils import to_dict
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
|
||||
async def access_control(request: Request, call_next):
|
||||
request.state.req_time = D.datetime()
|
||||
request.state.start = time.time()
|
||||
request.state.inspect = None
|
||||
request.state.user = None
|
||||
request.state.service = None
|
||||
|
||||
ip = request.headers['x-forwarded-for'] if 'x-forwarded-for' in request.headers.keys() else request.client.host
|
||||
request.state.ip = ip.split(',')[0] if ',' in ip else ip
|
||||
headers = request.headers
|
||||
cookies = request.cookies
|
||||
|
||||
url = request.url.path
|
||||
if await url_pattern_check(url, EXCEPT_PATH_REGEX) or url in EXCEPT_PATH_LIST:
|
||||
response = await call_next(request)
|
||||
if url != '/':
|
||||
await api_logger(request=request, response=response)
|
||||
return response
|
||||
|
||||
try:
|
||||
if url.startswith('/api'):
|
||||
# api 인경우 헤더로 토큰 검사
|
||||
# NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행)
|
||||
if url.startswith('/api/services') and conf().SERVICE_AUTH_API_KEY:
|
||||
qs = str(request.query_params)
|
||||
qs_list = qs.split('&')
|
||||
session = next(db.session())
|
||||
if not conf().DEBUG:
|
||||
try:
|
||||
qs_dict = {qs_split.split('=')[0]: qs_split.split('=')[1] for qs_split in qs_list}
|
||||
except Exception:
|
||||
raise ex.APIQueryStringEx()
|
||||
|
||||
qs_keys = qs_dict.keys()
|
||||
|
||||
if 'key' not in qs_keys or 'timestamp' not in qs_keys:
|
||||
raise ex.APIQueryStringEx()
|
||||
|
||||
if 'secret' not in headers.keys():
|
||||
raise ex.APIHeaderInvalidEx()
|
||||
|
||||
api_key = ApiKeys.get(session=session, access_key=qs_dict['key'])
|
||||
|
||||
if not api_key:
|
||||
raise ex.NotFoundAccessKeyEx(api_key=qs_dict['key'])
|
||||
mac = hmac.new(bytes(api_key.secret_key, encoding='utf8'), bytes(qs, encoding='utf-8'), digestmod='sha256')
|
||||
d = mac.digest()
|
||||
validating_secret = str(base64.b64encode(d).decode('utf-8'))
|
||||
|
||||
if headers['secret'] != validating_secret:
|
||||
raise ex.APIHeaderInvalidEx()
|
||||
|
||||
now_timestamp = int(D.datetime(diff=9).timestamp())
|
||||
if now_timestamp - 10 > int(qs_dict['timestamp']) or now_timestamp < int(qs_dict['timestamp']):
|
||||
raise ex.APITimestampEx()
|
||||
|
||||
user_info = to_dict(api_key.users)
|
||||
request.state.user = UserToken(**user_info)
|
||||
|
||||
else:
|
||||
# Request User 가 필요함
|
||||
if 'authorization' in headers.keys():
|
||||
key = headers.get('Authorization')
|
||||
api_key_obj = ApiKeys.get(session=session, access_key=key)
|
||||
user_info = to_dict(Users.get(session=session, id=api_key_obj.user_id))
|
||||
request.state.user = UserToken(**user_info)
|
||||
# 토큰 없음
|
||||
else:
|
||||
if 'Authorization' not in headers.keys():
|
||||
raise ex.NotAuthorized()
|
||||
session.close()
|
||||
response = await call_next(request)
|
||||
return response
|
||||
else:
|
||||
if 'authorization' in headers.keys():
|
||||
# 토큰 존재
|
||||
token_info = await token_decode(access_token=headers.get('Authorization'))
|
||||
request.state.user = UserToken(**token_info)
|
||||
elif conf().DEV_TEST_CONNECT_ACCOUNT:
|
||||
# NOTE(hsj100): DEV_TEST_CONNECT_ACCOUNT
|
||||
request.state.user = UserToken.from_orm(conf().DEV_TEST_CONNECT_ACCOUNT)
|
||||
else:
|
||||
# 토큰 없음
|
||||
if 'Authorization' not in headers.keys():
|
||||
raise ex.NotAuthorized()
|
||||
else:
|
||||
# 템플릿 렌더링인 경우 쿠키에서 토큰 검사
|
||||
cookies['Authorization'] = conf().COOKIES_AUTH
|
||||
|
||||
if 'Authorization' not in cookies.keys():
|
||||
raise ex.NotAuthorized()
|
||||
|
||||
token_info = await token_decode(access_token=cookies.get('Authorization'))
|
||||
request.state.user = UserToken(**token_info)
|
||||
response = await call_next(request)
|
||||
await api_logger(request=request, response=response)
|
||||
except Exception as e:
|
||||
|
||||
error = await exception_handler(e)
|
||||
error_dict = dict(status=error.status_code, msg=error.msg, detail=error.detail, code=error.code)
|
||||
response = JSONResponse(status_code=error.status_code, content=error_dict)
|
||||
await api_logger(request=request, error=error)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def url_pattern_check(path, pattern):
|
||||
result = re.match(pattern, path)
|
||||
if result:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def token_decode(access_token):
|
||||
"""
|
||||
:param access_token:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
access_token = access_token.replace('Bearer ', "")
|
||||
payload = jwt.decode(access_token, key=consts.JWT_SECRET, algorithms=[consts.JWT_ALGORITHM])
|
||||
except ExpiredSignatureError:
|
||||
raise ex.TokenExpiredEx()
|
||||
except DecodeError:
|
||||
raise ex.TokenDecodeEx()
|
||||
return payload
|
||||
|
||||
|
||||
async def exception_handler(error: Exception):
|
||||
print(error)
|
||||
if isinstance(error, sqlalchemy.exc.OperationalError):
|
||||
error = SqlFailureEx(ex=error)
|
||||
if not isinstance(error, APIException):
|
||||
error = APIException(ex=error, detail=str(error))
|
||||
return error
|
||||
63
vactor_rest/app/middlewares/trusted_hosts.py
Normal file
63
vactor_rest/app/middlewares/trusted_hosts.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import typing
|
||||
|
||||
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse, Response
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
ENFORCE_DOMAIN_WILDCARD = 'Domain wildcard patterns must be lNo module named \'app\'ike \'*.example.com\'.'
|
||||
|
||||
|
||||
class TrustedHostMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
allowed_hosts: typing.Sequence[str] = None,
|
||||
except_path: typing.Sequence[str] = None,
|
||||
www_redirect: bool = True,
|
||||
) -> None:
|
||||
if allowed_hosts is None:
|
||||
allowed_hosts = ['*']
|
||||
if except_path is None:
|
||||
except_path = []
|
||||
for pattern in allowed_hosts:
|
||||
assert '*' not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
|
||||
if pattern.startswith('*') and pattern != '*':
|
||||
assert pattern.startswith('*.'), ENFORCE_DOMAIN_WILDCARD
|
||||
self.app = app
|
||||
self.allowed_hosts = list(allowed_hosts)
|
||||
self.allow_any = '*' in allowed_hosts
|
||||
self.www_redirect = www_redirect
|
||||
self.except_path = list(except_path)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if self.allow_any or scope['type'] not in ('http', 'websocket',): # pragma: no cover
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
headers = Headers(scope=scope)
|
||||
host = headers.get('host', "").split(':')[0]
|
||||
is_valid_host = False
|
||||
found_www_redirect = False
|
||||
for pattern in self.allowed_hosts:
|
||||
if (
|
||||
host == pattern
|
||||
or (pattern.startswith('*') and host.endswith(pattern[1:]))
|
||||
or URL(scope=scope).path in self.except_path
|
||||
):
|
||||
is_valid_host = True
|
||||
break
|
||||
elif 'www.' + host == pattern:
|
||||
found_www_redirect = True
|
||||
|
||||
if is_valid_host:
|
||||
await self.app(scope, receive, send)
|
||||
else:
|
||||
if found_www_redirect and self.www_redirect:
|
||||
url = URL(scope=scope)
|
||||
redirect_url = url.replace(netloc='www.' + url.netloc)
|
||||
response = RedirectResponse(url=str(redirect_url)) # type: Response
|
||||
else:
|
||||
response = PlainTextResponse('Invalid host header', status_code=400)
|
||||
|
||||
await response(scope, receive, send)
|
||||
581
vactor_rest/app/models.py
Normal file
581
vactor_rest/app/models.py
Normal file
@@ -0,0 +1,581 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: models.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: data models
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from pydantic import Field, ConfigDict
|
||||
from pydantic.main import BaseModel
|
||||
from pydantic.networks import EmailStr, IPvAnyAddress
|
||||
from typing import Optional
|
||||
|
||||
from vactor_rest.app.common.consts import (
|
||||
SW_TITLE,
|
||||
SW_VERSION,
|
||||
MAIL_REG_TITLE,
|
||||
MAIL_REG_CONTENTS,
|
||||
SMTP_HOST,
|
||||
SMTP_PORT,
|
||||
ADMIN_INIT_ACCOUNT_INFO,
|
||||
DEFAULT_USER_ACCOUNT_PW
|
||||
)
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
|
||||
|
||||
class SWInfo(BaseModel):
|
||||
"""
|
||||
### 서비스 정보
|
||||
"""
|
||||
name: str = Field(SW_TITLE, description='SW 이름', example=SW_VERSION)
|
||||
version: str = Field(SW_VERSION, description='SW 버전', example=SW_VERSION)
|
||||
date: str = Field(D.date_str(), description='현재날짜', example='%Y.%m.%dT%H:%M:%S')
|
||||
data1: str = Field(None, description='TestData1', example=None)
|
||||
data2: str = Field(None, description='TestData2', example=None)
|
||||
data3: str = Field(None, description='TestData3', example=None)
|
||||
data4: str = Field(None, description='TestData4', example=None)
|
||||
|
||||
data5: str = Field(None, description='TestData1', example=None)
|
||||
data6: str = Field(None, description='TestData2', example=None)
|
||||
data7: str = Field(None, description='TestData3', example=None)
|
||||
data8: str = Field(None, description='TestData4', example=None)
|
||||
|
||||
|
||||
class CustomEnum(Enum):
|
||||
@classmethod
|
||||
def get_elements_str(cls, is_key=True):
|
||||
if is_key:
|
||||
result_str = cls.__members__.keys()
|
||||
else:
|
||||
result_str = cls.__members__.values()
|
||||
return '[' + ', '.join(result_str) + ']'
|
||||
|
||||
|
||||
class SexType(str, CustomEnum):
|
||||
male: str = 'male'
|
||||
female: str = 'female'
|
||||
|
||||
|
||||
class AccountType(str, CustomEnum):
|
||||
email: str = 'email'
|
||||
# facebook: str = 'facebook'
|
||||
# google: str = 'google'
|
||||
# kakao: str = 'kakao'
|
||||
|
||||
|
||||
class UserStatusType(str, CustomEnum):
|
||||
active: str = 'active'
|
||||
deleted: str = 'deleted'
|
||||
blocked: str = 'blocked'
|
||||
|
||||
|
||||
class UserLoginType(str, CustomEnum):
|
||||
login: str = 'login'
|
||||
logout: str = 'logout'
|
||||
|
||||
|
||||
class MemberType(str, CustomEnum):
|
||||
personal: str = 'personal'
|
||||
company: str = 'company'
|
||||
|
||||
|
||||
class UserType(str, CustomEnum):
|
||||
admin: str = 'admin'
|
||||
user: str = 'user'
|
||||
|
||||
|
||||
class UserLogMessageType(str, CustomEnum):
|
||||
info: str = 'info'
|
||||
error: str = 'error'
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
Authorization: str = Field(None, description='인증키', example='Bearer [token]')
|
||||
|
||||
|
||||
class EmailRecipients(BaseModel):
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class SendEmail(BaseModel):
|
||||
email_to: List[EmailRecipients] = None
|
||||
|
||||
|
||||
class KakaoMsgBody(BaseModel):
|
||||
msg: str | None = None
|
||||
|
||||
|
||||
class MessageOk(BaseModel):
|
||||
message: str = Field(default='OK')
|
||||
|
||||
|
||||
class UserToken(BaseModel):
|
||||
id: int
|
||||
account: str | None = None
|
||||
name: str | None = None
|
||||
phone_number: str | None = None
|
||||
profile_img: str | None = None
|
||||
account_type: str | None = None
|
||||
|
||||
# model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AddApiKey(BaseModel):
|
||||
user_memo: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GetApiKeyList(AddApiKey):
|
||||
id: int = None
|
||||
access_key: str | None = None
|
||||
created_at: datetime = None
|
||||
|
||||
|
||||
class GetApiKeys(GetApiKeyList):
|
||||
secret_key: str | None = None
|
||||
|
||||
|
||||
class CreateAPIWhiteLists(BaseModel):
|
||||
ip_addr: str | None = None
|
||||
|
||||
|
||||
class GetAPIWhiteLists(CreateAPIWhiteLists):
|
||||
id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ResponseBase(BaseModel):
|
||||
"""
|
||||
### [Response] API End-Point
|
||||
|
||||
**정상처리**\n
|
||||
- result: true\n
|
||||
- error: null\n
|
||||
|
||||
**오류발생**\n
|
||||
- result: false\n
|
||||
- error: 오류내용\n
|
||||
"""
|
||||
result: bool = Field(True, description='처리상태(성공: true, 실패: false)', example=True)
|
||||
error: str | None = Field(None, description='오류내용(성공: null, 실패: 오류내용)', example=None)
|
||||
|
||||
@staticmethod
|
||||
def set_error(error):
|
||||
ResponseBase.result = False
|
||||
ResponseBase.error = str(error)
|
||||
|
||||
return ResponseBase
|
||||
|
||||
@staticmethod
|
||||
def set_message():
|
||||
ResponseBase.result = True
|
||||
ResponseBase.error = None
|
||||
return ResponseBase
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PagingReq(BaseModel):
|
||||
"""
|
||||
### [Request] 페이징 정보
|
||||
"""
|
||||
start_page: int = Field(None, description='시작 페이지 번호(base: 1)', example=1)
|
||||
page_contents_num: int = Field(None, description='페이지 내용 개수', example=2)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PagingRes(BaseModel):
|
||||
"""
|
||||
### [Response] 페이징 정보
|
||||
"""
|
||||
total_page_num: int = Field(None, description='전체 페이지 개수', example=100)
|
||||
total_contents_num: int = Field(None, description='전체 내용 개수', example=100)
|
||||
start_page: int = Field(None, description='시작 페이지 번호(base: 1)', example=1)
|
||||
search_contents_num: int = Field(None, description='검색된 내용 개수', example=100)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TokenRes(ResponseBase):
|
||||
"""
|
||||
### [Response] 토큰 정보
|
||||
"""
|
||||
Authorization: str = Field(None, description='인증키', example='Bearer [token]')
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLogInfo(BaseModel):
|
||||
"""
|
||||
### 사용자 로그 정보
|
||||
"""
|
||||
id: int = Field(None, description='Table Index', example='1')
|
||||
created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
|
||||
updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
|
||||
|
||||
account: str = Field(None, description='계정', example='user1@test.com')
|
||||
mac: str = Field(None, description='MAC(네트워크 인터페이스 식별자)', example='11:22:33:44:55:66')
|
||||
type: str = Field(None, description='로그 타입' + UserLogMessageType.get_elements_str(), example=UserLogMessageType.info)
|
||||
api: str = Field(None, description='API 이름', example='/api/auth/login')
|
||||
message: str = Field(None, description='로그 내용', example='ok')
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
"""
|
||||
### 유저 정보
|
||||
"""
|
||||
id: int = Field(None, description='Table Index', example='1')
|
||||
created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
|
||||
updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
|
||||
|
||||
status: UserStatusType = Field(None, description='계정상태' + UserStatusType.get_elements_str(), example=UserStatusType.active)
|
||||
user_type: UserType = Field(None, description='유저 타입' + UserType.get_elements_str(), example=UserType.user)
|
||||
account_type: AccountType = Field(None, description='계정종류' + AccountType.get_elements_str(), example=AccountType.email)
|
||||
account: str = Field(None, description='계정', example='user1@test.com')
|
||||
email: str = Field(None, description='전자메일', example='user1@test.com')
|
||||
name: str = Field(None, description='이름', example='user1')
|
||||
sex: str = Field(None, description='성별' + SexType.get_elements_str(), example=SexType.male)
|
||||
rrn: str = Field(None, description='주민등록번호', example='123456-1234567')
|
||||
address: str = Field(None, description='주소', example='대구1')
|
||||
phone_number: str = Field(None, description='연락처', example='010-1234-1234')
|
||||
picture: str = Field(None, description='프로필사진', example='profile1.png')
|
||||
marketing_agree: bool = Field(False, description='마케팅동의 여부', example=False)
|
||||
# extra
|
||||
login: UserLoginType = Field(None, description='로그인 상태' + UserLoginType.get_elements_str(), example=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
|
||||
member_type: MemberType = Field(None, description='회원 타입' + MemberType.get_elements_str(), example=MemberType.personal)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SendMailReq(BaseModel):
|
||||
"""
|
||||
### [Request] 메일 전송
|
||||
"""
|
||||
smtp_host: str = Field(None, description='SMTP 서버 주소', example=SMTP_HOST)
|
||||
smtp_port: int = Field(None, description='SMTP 서버 포트', example=SMTP_PORT)
|
||||
title: str = Field(None, description='제목', example=MAIL_REG_TITLE)
|
||||
recipient: str = Field(None, description='수신자', example='user1@test.com')
|
||||
cc_list: list = Field(None, description='참조 리스트', example='["user2@test.com", "user3@test.com"]')
|
||||
# recipient: str = Field(None, description='수신자', example='hsj100@a2tec.co.kr')
|
||||
# cc_list: list = Field(None, description='참조 리스트', example=None)
|
||||
contents_plain: str = Field(None, description='내용', example=MAIL_REG_CONTENTS.format('user1@test.com', 10))
|
||||
contents_html: str = Field(None, description='내용', example='<p>내 고양이는 <strong>아주 고약해.</p></strong>')
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserInfoRes(ResponseBase):
|
||||
"""
|
||||
### [Response] 유저 정보
|
||||
"""
|
||||
data: UserInfo = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLoginReq(BaseModel):
|
||||
"""
|
||||
### 유저 로그인 정보
|
||||
"""
|
||||
account: str = Field(description='계정', example='user1@test.com')
|
||||
pw: str = Field(None, description='비밀번호 [관리자 필수]', example='1234')
|
||||
# SW-LICENSE
|
||||
sw: str = Field(None, description='라이선스 SW 이름 [유저 필수]', example='저작도구')
|
||||
mac: str = Field(None, description='MAC [유저 필수]', example='11:22:33:44:55:01')
|
||||
|
||||
|
||||
class UserRegisterReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저 등록
|
||||
"""
|
||||
status: Optional[str] = Field(UserStatusType.active, description='계정상태' + UserStatusType.get_elements_str(), example=UserStatusType.active)
|
||||
user_type: Optional[UserType] = Field(UserType.user, description='유저 타입' + UserType.get_elements_str(), example=UserType.user)
|
||||
account: str = Field(description='계정', example='test@test.com')
|
||||
pw: Optional[str] = Field(DEFAULT_USER_ACCOUNT_PW, description='비밀번호', example='1234')
|
||||
email: Optional[str] = Field(None, description='전자메일', example='test@test.com')
|
||||
name: str = Field(description='이름', example='test')
|
||||
sex: Optional[SexType] = Field(SexType.male, description='성별' + SexType.get_elements_str(), example=SexType.male)
|
||||
rrn: Optional[str] = Field(None, description='주민등록번호', example='19910101-1234567')
|
||||
address: Optional[str] = Field(None, description='주소', example='대구광역시 동구 동촌로 351, 4층 (용계동, 에이스빌딩)')
|
||||
phone_number: Optional[str] = Field(None, description='휴대전화', example='010-1234-1234')
|
||||
picture: Optional[str] = Field(None, description='사진', example='profile1.png')
|
||||
marketing_agree: Optional[bool] = Field(False, description='마케팅동의 여부', example=False)
|
||||
# extra
|
||||
member_type: Optional[MemberType] = Field(MemberType.personal, description='회원 타입' + MemberType.get_elements_str(), example=MemberType.personal)
|
||||
# relationship:license
|
||||
license_sw: str = Field(description='라이선스 SW 이름', example='저작도구')
|
||||
license_start: datetime = Field(description='라이선스 시작날짜', example='2022-02-10T15:00:00')
|
||||
license_end: datetime = Field(description='라이선스 시작날짜', example='2023-02-10T15:00:00')
|
||||
license_num: Optional[int] = Field(1, description='계약된 라이선스 개수', example=1)
|
||||
license_manager_name: Optional[str] = Field(ADMIN_INIT_ACCOUNT_INFO.name, description='담당자 이름', example=ADMIN_INIT_ACCOUNT_INFO.name)
|
||||
license_manager_phone: Optional[str] = Field(ADMIN_INIT_ACCOUNT_INFO.phone_number, description='담당자 연락처', example=ADMIN_INIT_ACCOUNT_INFO.phone_number)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserSearchReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저 검색 (기본)
|
||||
"""
|
||||
# basic
|
||||
id: Optional[int] = Field(None, description='등록번호', example='1')
|
||||
id__in: Optional[list] = Field(None, description='등록번호 리스트', example=[1,])
|
||||
created_at: Optional[str] = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
|
||||
created_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
|
||||
created_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
|
||||
created_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
|
||||
created_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
|
||||
updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
|
||||
updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
|
||||
updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
|
||||
updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
|
||||
updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
|
||||
|
||||
status: Optional[UserStatusType] = Field(None, description='계정상태' + UserStatusType.get_elements_str(), example=UserStatusType.active)
|
||||
user_type: Optional[UserType] = Field(None, description='유저 타입' + UserType.get_elements_str(), example=UserType.user)
|
||||
account_type: Optional[AccountType] = Field(None, description='계정종류' + AccountType.get_elements_str(), example=AccountType.email)
|
||||
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
|
||||
account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
|
||||
email: Optional[str] = Field(None, description='전자메일', example='user1@test.com')
|
||||
email__like: Optional[str] = Field(None, description='전자메일 부분검색', example='%user%')
|
||||
name: Optional[str] = Field(None, description='이름', example='test')
|
||||
name__like: Optional[str] = Field(None, description='이름 부분검색', example='%user%')
|
||||
sex: Optional[SexType] = Field(None, description='성별' + SexType.get_elements_str(), example=SexType.male)
|
||||
rrn: Optional[str] = Field(None, description='주민등록번호', example='123456-1234567')
|
||||
address: Optional[str] = Field(None, description='주소', example='대구1')
|
||||
address__like: Optional[str] = Field(None, description='주소 부분검색', example='%대구%')
|
||||
phone_number: Optional[str] = Field(None, description='연락처', example='010-1234-1234')
|
||||
picture: Optional[str] = Field(None, description='프로필사진', example='profile1.png')
|
||||
marketing_agree: Optional[bool] = Field(None, description='마케팅동의 여부', example=False)
|
||||
# extra
|
||||
login: Optional[UserLoginType] = Field(None, description='로그인 상태' + UserLoginType.get_elements_str(), example=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
|
||||
member_type: Optional[MemberType] = Field(None, description='회원 타입' + MemberType.get_elements_str(), example=MemberType.personal)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserSearchRes(ResponseBase):
|
||||
"""
|
||||
### [Response] 유저 검색
|
||||
"""
|
||||
data: List[UserInfo] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserSearchPagingReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저 페이징 검색
|
||||
"""
|
||||
# paging
|
||||
paging: Optional[PagingReq] = None
|
||||
search: Optional[UserSearchReq] = None
|
||||
|
||||
|
||||
class UserSearchPagingRes(ResponseBase):
|
||||
"""
|
||||
### [Response] 유저 페이징 검색
|
||||
"""
|
||||
paging: PagingRes = None
|
||||
data: List[UserInfo] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserUpdateReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저 변경
|
||||
"""
|
||||
status: Optional[UserStatusType] = Field(None, description='계정상태' + UserStatusType.get_elements_str(), example=UserStatusType.active)
|
||||
user_type: Optional[UserType] = Field(None, description='유저 타입' + UserType.get_elements_str(), example=UserType.user)
|
||||
account_type: Optional[AccountType] = Field(None, description='계정종류' + AccountType.get_elements_str(), example=AccountType.email)
|
||||
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
|
||||
email: Optional[str] = Field(None, description='전자메일', example='user1@test.com')
|
||||
name: Optional[str] = Field(None, description='이름', example='test')
|
||||
sex: Optional[SexType] = Field(None, description='성별' + SexType.get_elements_str(), example=SexType.male)
|
||||
rrn: Optional[str] = Field(None, description='주민등록번호', example='123456-1234567')
|
||||
address: Optional[str] = Field(None, description='주소', example='대구1')
|
||||
phone_number: Optional[str] = Field(None, description='연락처', example='010-1234-1234')
|
||||
picture: Optional[str] = Field(None, description='프로필사진', example='profile1.png')
|
||||
marketing_agree: Optional[bool] = Field(False, description='마케팅동의 여부', example=False)
|
||||
# extra
|
||||
login: Optional[UserLoginType] = Field(None, description='로그인 상태' + UserLoginType.get_elements_str(), example=UserLoginType.logout) # TODO(hsj100): LOGIN_STATUS
|
||||
# uuid: Optional[str] = Field(None, description='UUID', example='12345678-1234-5678-1234-567800000001')
|
||||
member_type: Optional[MemberType] = Field(None, description='회원 타입' + MemberType.get_elements_str(), example=MemberType.personal)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserUpdateMultiReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저 변경 (multi)
|
||||
"""
|
||||
search_info: UserSearchReq = None
|
||||
update_info: UserUpdateReq = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserUpdatePWReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저 비밀번호 변경
|
||||
"""
|
||||
account: str = Field(None, description='계정', example='user1@test.com')
|
||||
current_pw: str = Field(None, description='현재 비밀번호', example='1234')
|
||||
new_pw: str = Field(None, description='신규 비밀번호', example='5678')
|
||||
|
||||
|
||||
class UserLogSearchReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저로그 검색
|
||||
"""
|
||||
# basic
|
||||
id: Optional[int] = Field(None, description='등록번호', example='1')
|
||||
id__in: Optional[list] = Field(None, description='등록번호 리스트', example=[1,])
|
||||
created_at: Optional[str] = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
|
||||
created_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
|
||||
created_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
|
||||
created_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
|
||||
created_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
|
||||
updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
|
||||
updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
|
||||
updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
|
||||
updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
|
||||
updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
|
||||
|
||||
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
|
||||
account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
|
||||
mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
|
||||
mac__like: Optional[str] = Field(None, description='MAC 부분검색', example='%33%')
|
||||
type: Optional[UserLogMessageType] = Field(None, description='유저로그 메시지 타입' + UserLogMessageType.get_elements_str(), example=UserLogMessageType.error)
|
||||
api: Optional[str] = Field(None, description='API 이름', example='/api/auth/login')
|
||||
api__like: Optional[str] = Field(None, description='API 이름 부분검색', example='%login%')
|
||||
message: Optional[str] = Field(None, description='로그내용', example='invalid password')
|
||||
message__like: Optional[str] = Field(None, description='로그내용 부분검색', example='%invalid%')
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLogDayInfo(BaseModel):
|
||||
"""
|
||||
### 유저로그 일별 접속 정보 (출석)
|
||||
"""
|
||||
account: str = Field(None, description='계정', example='user1@test.com')
|
||||
mac: str = Field(None, description='MAC(네트워크 인터페이스 식별자)', example='11:22:33:44:55:66')
|
||||
updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
|
||||
message: str = Field(None, description='로그 내용', example='ok')
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLogDaySearchReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저로그 일별 마지막 접속 검색
|
||||
"""
|
||||
updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
|
||||
updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
|
||||
updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
|
||||
updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
|
||||
updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
|
||||
|
||||
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
|
||||
account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
|
||||
mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
|
||||
mac__like: Optional[str] = Field(None, description='MAC 부분검색', example='%33%')
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLogDaySearchPagingReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저로그 일별 마지막 접속 페이징 검색
|
||||
"""
|
||||
# paging
|
||||
paging: Optional[PagingReq] = None
|
||||
search: Optional[UserLogDaySearchReq] = None
|
||||
|
||||
|
||||
class UserLogDaySearchPagingRes(ResponseBase):
|
||||
"""
|
||||
### [Response] 유저로그 일별 마지막 접속 검색
|
||||
"""
|
||||
paging: PagingRes = None
|
||||
data: List[UserLogDayInfo] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLogPagingReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저로그 페이징 검색
|
||||
"""
|
||||
# paging
|
||||
paging: Optional[PagingReq] = None
|
||||
search: Optional[UserLogSearchReq] = None
|
||||
|
||||
|
||||
class UserLogPagingRes(ResponseBase):
|
||||
"""
|
||||
### [Response] 유저로그 페이징 검색
|
||||
"""
|
||||
paging: PagingRes = None
|
||||
data: List[UserLogInfo] = []
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLogUpdateReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저로그 변경
|
||||
"""
|
||||
account: Optional[str] = Field(None, description='계정', example='user1@test.com')
|
||||
mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
|
||||
type: Optional[UserLogMessageType] = Field(None, description='유저로그 메시지 타입' + UserLogMessageType.get_elements_str(), example=UserLogMessageType.error)
|
||||
api: Optional[str] = Field(None, description='API 이름', example='/api/auth/login')
|
||||
message: Optional[str] = Field(None, description='로그내용', example='invalid password')
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLogUpdateMultiReq(BaseModel):
|
||||
"""
|
||||
### [Request] 유저로그 변경 (multi)
|
||||
"""
|
||||
search_info: UserLogSearchReq = None
|
||||
update_info: UserLogUpdateReq = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
#===============================================================================
|
||||
|
||||
class IndexType(str, Enum):
|
||||
hnsw = "hnsw"
|
||||
l2 = "l2"
|
||||
|
||||
class VactorSearchReq(BaseModel):
|
||||
"""
|
||||
### [Request] vactor 검색
|
||||
"""
|
||||
quary_image_path : str = Field(description='quary image', example='path')
|
||||
index_type : IndexType = Field(IndexType.hnsw, description='인덱스 타입', example=IndexType.hnsw)
|
||||
search_num : int = Field(4, description='검색결과 이미지 갯수', example=4)
|
||||
|
||||
|
||||
93
vactor_rest/app/routes/auth.py
Normal file
93
vactor_rest/app/routes/auth.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: auth.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: authentication api
|
||||
"""
|
||||
|
||||
from itertools import groupby
|
||||
from operator import attrgetter
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import bcrypt
|
||||
import jwt
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from vactor_rest.app.common import consts
|
||||
from vactor_rest.app import models as M
|
||||
from vactor_rest.app.database.conn import db
|
||||
from vactor_rest.app.common.config import conf
|
||||
from vactor_rest.app.database.schema import Users, UserLog
|
||||
from vactor_rest.app.utils.extra import query_to_groupby, AESCryptoCBC
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
|
||||
router = APIRouter(prefix='/auth')
|
||||
|
||||
|
||||
@router.get('/find-account/{account}', response_model=M.ResponseBase, summary='계정유무 검사')
|
||||
async def find_account(account: str):
|
||||
"""
|
||||
## 계정유무 검사
|
||||
|
||||
주어진 계정이 존재하면 true, 없으면 false 처리
|
||||
|
||||
**결과**
|
||||
- ResponseBase
|
||||
"""
|
||||
try:
|
||||
search_info = Users.get(account=account)
|
||||
if not search_info:
|
||||
raise Exception(f'not found data: {account}')
|
||||
return M.ResponseBase()
|
||||
except Exception as e:
|
||||
return M.ResponseBase.set_error(str(e))
|
||||
|
||||
|
||||
@router.post('/logout/{account}', status_code=200, response_model=M.TokenRes, summary='사용자 접속종료')
|
||||
async def logout(account: str):
|
||||
"""
|
||||
## 사용자 접속종료
|
||||
|
||||
현재 버전에서는 로그인/로그아웃의 상태를 유지하지 않고 상태값만을 서버에서 사용하기 때문에,\n
|
||||
***로그상태는 실제상황과 다를 수 있다.***
|
||||
|
||||
정상처리시 Authorization(null) 반환
|
||||
|
||||
**결과**
|
||||
- TokenRes
|
||||
"""
|
||||
user_info = None
|
||||
|
||||
try:
|
||||
# TODO(hsj100): LOGIN_STATUS
|
||||
user_info = Users.filter(account=account)
|
||||
if not user_info:
|
||||
raise Exception('not found user')
|
||||
|
||||
user_info.update(auto_commit=True, login='logout')
|
||||
return M.TokenRes()
|
||||
except Exception as e:
|
||||
if user_info:
|
||||
user_info.close()
|
||||
return M.ResponseBase.set_error(e)
|
||||
|
||||
|
||||
async def is_account_exist(account: str):
|
||||
get_account = Users.get(account=account)
|
||||
return True if get_account else False
|
||||
|
||||
|
||||
def create_access_token(*, data: dict = None, expires_delta: int = None):
|
||||
|
||||
if conf().GLOBAL_TOKEN:
|
||||
return conf().GLOBAL_TOKEN
|
||||
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
to_encode.update({'exp': datetime.utcnow() + timedelta(hours=expires_delta)})
|
||||
encoded_jwt = jwt.encode(to_encode, consts.JWT_SECRET, algorithm=consts.JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
137
vactor_rest/app/routes/dev.py
Normal file
137
vactor_rest/app/routes/dev.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: dev.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: Developments Test
|
||||
"""
|
||||
import struct
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import bcrypt
|
||||
from starlette.requests import Request
|
||||
|
||||
from vactor_rest.app.common import consts
|
||||
from vactor_rest.app import models as M
|
||||
from vactor_rest.app.database.conn import db, Base
|
||||
from vactor_rest.app.database.schema import Users, UserLog
|
||||
|
||||
from vactor_rest.app.utils.extra import FernetCrypto, AESCryptoCBC, AESCipher
|
||||
from custom_logger.vactor_log import vactor_logger as LOG
|
||||
|
||||
|
||||
# mail test
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
|
||||
def send_mail():
|
||||
"""
|
||||
구글 계정사용시 : 보안 수준이 낮은 앱에서의 접근 활성화
|
||||
|
||||
:return:
|
||||
"""
|
||||
sender = 'jolimola@gmail.com'
|
||||
sender_pw = '!ghkdtmdwns1'
|
||||
# recipient = 'hsj100@a2tec.co.kr'
|
||||
recipient = 'jwkim@daooldns.co.kr'
|
||||
list_cc = ['cc1@gmail.com', 'cc2@naver.com']
|
||||
str_cc = ','.join(list_cc)
|
||||
|
||||
title = 'Test mail'
|
||||
contents = '''
|
||||
This is test mail
|
||||
using smtplib.
|
||||
'''
|
||||
|
||||
smtp_server = smtplib.SMTP( # 1
|
||||
host='smtp.gmail.com',
|
||||
port=587
|
||||
)
|
||||
|
||||
smtp_server.ehlo() # 2
|
||||
smtp_server.starttls() # 2
|
||||
smtp_server.ehlo() # 2
|
||||
smtp_server.login(sender, sender_pw) # 3
|
||||
|
||||
msg = MIMEMultipart() # 4
|
||||
msg['From'] = sender # 5
|
||||
msg['To'] = recipient # 5
|
||||
# msg['Cc'] = str_cc # 5
|
||||
msg['Subject'] = contents # 5
|
||||
msg.attach(MIMEText(contents, 'plain')) # 6
|
||||
|
||||
smtp_server.send_message(msg) # 7
|
||||
smtp_server.quit() # 8
|
||||
|
||||
|
||||
router = APIRouter(prefix='/dev')
|
||||
|
||||
|
||||
@router.get('/test', summary='테스트', response_model=M.SWInfo)
|
||||
async def test(request: Request):
|
||||
"""
|
||||
## ELB 상태 체크용 API
|
||||
|
||||
**결과**
|
||||
- SWInfo
|
||||
"""
|
||||
|
||||
a = M.SWInfo()
|
||||
|
||||
a.name = '!ekdnfeldpsdptm1'
|
||||
# a.name = 'testtesttest123'
|
||||
simpleEnDecrypt = FernetCrypto()
|
||||
a.data1 = simpleEnDecrypt.encrypt(a.name)
|
||||
a.data2 = bcrypt.hashpw(a.name.encode('utf-8'), bcrypt.gensalt())
|
||||
|
||||
t = bytes(a.name.encode('utf-8'))
|
||||
|
||||
enc = AESCryptoCBC().encrypt(t)
|
||||
dec = AESCryptoCBC().decrypt(enc)
|
||||
|
||||
t = enc.decode('utf-8')
|
||||
|
||||
# enc = AESCipher('daooldns12345678').encrypt(a.name).decode('utf-8')
|
||||
# enc = 'E563ZFt+yJL8YY5yYlYyk602MSscPP2SCCD8UtXXpMI='
|
||||
# dec = AESCipher('daooldns12345678').decrypt(enc).decode('utf-8')
|
||||
|
||||
a.data3 = f'enc: {enc}, {t}'
|
||||
a.data4 = f'dec: {dec}'
|
||||
|
||||
a.name = '!ekdnfeldpsdptm1'
|
||||
|
||||
|
||||
# simpleEnDecrypt = SimpleEnDecrypt()
|
||||
a.data5 = simpleEnDecrypt.encrypt(a.name)
|
||||
a.data6 = bcrypt.hashpw(a.name.encode('utf-8'), bcrypt.gensalt())
|
||||
|
||||
key = consts.ADMIN_INIT_ACCOUNT_INFO['aes_cbc_key']
|
||||
|
||||
t = bytes(a.name.encode('utf-8'))
|
||||
|
||||
enc = AESCryptoCBC(key).encrypt(t)
|
||||
dec = AESCryptoCBC(key).decrypt(enc)
|
||||
|
||||
t = enc.decode('utf-8')
|
||||
|
||||
# enc = AESCipher('daooldns12345678').encrypt(a.name).decode('utf-8')
|
||||
# enc = 'E563ZFt+yJL8YY5yYlYyk602MSscPP2SCCD8UtXXpMI='
|
||||
# dec = AESCipher('daooldns12345678').decrypt(enc).decode('utf-8')
|
||||
|
||||
print(f'key: {key}')
|
||||
a.data7 = f'enc: {enc}, {t}'
|
||||
a.data8 = f'dec: {dec}'
|
||||
|
||||
|
||||
eee = "gAAAAABioV5NucuS9nQugZJnz-KjVG_FGnaowB9KAfhOoWjjiQ4jGLuYJh4Qe94mT_lCm6m3HhuOJqUeOgjppwREDpIQYzrUXA=="
|
||||
a.data8 = simpleEnDecrypt.decrypt(eee)
|
||||
|
||||
return a
|
||||
|
||||
|
||||
32
vactor_rest/app/routes/index.py
Normal file
32
vactor_rest/app/routes/index.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: index.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: basic & test api
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
from vactor_rest.app.models import SWInfo
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get('/', summary='서비스 정보', response_model=SWInfo)
|
||||
async def index():
|
||||
"""
|
||||
## 서비스 정보
|
||||
소프트웨어 이름, 버전정보, 현재시간
|
||||
|
||||
**결과**
|
||||
- SWInfo
|
||||
"""
|
||||
sw_info = SWInfo()
|
||||
sw_info.date = D.date_str()
|
||||
return sw_info
|
||||
49
vactor_rest/app/routes/services.py
Normal file
49
vactor_rest/app/routes/services.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: services.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: services api
|
||||
"""
|
||||
|
||||
import requests, json, traceback, os
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from starlette.requests import Request
|
||||
from typing import Annotated, List
|
||||
|
||||
from vactor_rest.app.common import consts
|
||||
from vactor_rest.app import models as M
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
from custom_logger.vactor_log import vactor_logger as LOG
|
||||
|
||||
from custom_apps.faiss.main import search_idxs
|
||||
|
||||
router = APIRouter(prefix="/services")
|
||||
|
||||
|
||||
@router.post("/faiss/vactor/search", summary="vactor search", response_model=M.ResponseBase)
|
||||
async def vactor_search(request: Request, request_body_info: M.VactorSearchReq):
|
||||
"""
|
||||
## 벡터검색
|
||||
|
||||
## 정보
|
||||
>
|
||||
|
||||
"""
|
||||
response = M.ResponseBase()
|
||||
try:
|
||||
if os.path.exists(request_body_info.quary_image_path):
|
||||
search_idxs(image_path=request_body_info.quary_image_path,
|
||||
index_type=request_body_info.index_type,
|
||||
search_num=request_body_info.search_num)
|
||||
else:
|
||||
raise Exception(f"File {request_body_info.quary_image_path} does not exist.")
|
||||
|
||||
return response.set_message()
|
||||
|
||||
except Exception as e:
|
||||
LOG.error(traceback.format_exc())
|
||||
return response.set_error(e)
|
||||
298
vactor_rest/app/routes/users.py
Normal file
298
vactor_rest/app/routes/users.py
Normal file
@@ -0,0 +1,298 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: users.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: users api
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from starlette.requests import Request
|
||||
import bcrypt
|
||||
|
||||
from vactor_rest.app.common import consts
|
||||
from vactor_rest.app import models as M
|
||||
from vactor_rest.app.common.config import conf
|
||||
from vactor_rest.app.database.schema import Users
|
||||
from vactor_rest.app.database.crud import table_select, table_update, table_delete
|
||||
|
||||
from vactor_rest.app.utils.extra import AESCryptoCBC
|
||||
|
||||
|
||||
router = APIRouter(prefix='/user')
|
||||
|
||||
|
||||
@router.get('/me', response_model=M.UserSearchRes, summary='접속자 정보')
|
||||
async def get_me(request: Request):
|
||||
"""
|
||||
## 현재 접속된 자신정보 확인
|
||||
|
||||
***현 버전 미지원(추후 상세 처리)***
|
||||
|
||||
**결과**
|
||||
- UserSearchRes
|
||||
"""
|
||||
target_table = Users
|
||||
search_info = None
|
||||
|
||||
try:
|
||||
# request
|
||||
if conf().GLOBAL_TOKEN:
|
||||
raise Exception('not supported: use search api!')
|
||||
|
||||
accessor_info = request.state.user
|
||||
if not accessor_info:
|
||||
raise Exception('invalid accessor')
|
||||
|
||||
# search
|
||||
search_info = target_table.get(account=accessor_info.account)
|
||||
if not search_info:
|
||||
raise Exception('not found data')
|
||||
|
||||
# result
|
||||
result_info = list()
|
||||
result_info.append(M.UserInfo.from_orm(search_info))
|
||||
return M.UserSearchRes(data=result_info)
|
||||
except Exception as e:
|
||||
if search_info:
|
||||
search_info.close()
|
||||
return M.ResponseBase.set_error(str(e))
|
||||
|
||||
|
||||
@router.post('/search', response_model=M.UserSearchPagingRes, summary='유저정보 검색')
|
||||
async def user_search(request: Request, request_body_info: M.UserSearchPagingReq):
|
||||
"""
|
||||
## 유저정보 검색 (기본)
|
||||
검색 조건은 유저 테이블 항목만 가능 (Request body: Schema 참조)\n
|
||||
관련 정보 연동 검색은 별도 API 사용
|
||||
|
||||
**세부항목**
|
||||
- **paging**\n
|
||||
항목 미사용시에는 페이징기능 없이 검색조건(search) 결과 모두 반환
|
||||
|
||||
- **search**\n
|
||||
검색에 필요한 항목들을 search 에 포함시킨다.\n
|
||||
검색에 사용된 각 항목들은 AND 조건으로 처리된다.
|
||||
|
||||
- **전체검색**\n
|
||||
empty object 사용 ( {} )\n
|
||||
예) "search": {}
|
||||
|
||||
- **검색항목**\n
|
||||
- 부분검색 항목\n
|
||||
SQL 문법( %, _ )을 사용한다.\n
|
||||
__like: 시작포함(X%), 중간포함(%X%), 끝포함(%X)
|
||||
- 구간검색 항목
|
||||
* __lt: 주어진 값보다 작은값
|
||||
* __lte: 주어진 값보다 같거나 작은 값
|
||||
* __gt: 주어진 값보다 큰값
|
||||
* __gte: 주어진 값보다 같거나 큰값
|
||||
|
||||
**결과**
|
||||
- UserSearchPagingRes
|
||||
"""
|
||||
return await table_select(request.state.user, Users, request_body_info, M.UserSearchPagingRes, M.UserInfo)
|
||||
|
||||
|
||||
@router.put('/update', response_model=M.ResponseBase, summary='유저정보 변경')
|
||||
async def user_update(request: Request, request_body_info: M.UserUpdateMultiReq):
|
||||
"""
|
||||
## 유저정보 변경
|
||||
|
||||
**search_info**: 변경대상\n
|
||||
|
||||
**update_info**: 변경내용\n
|
||||
- **비밀번호** 제외
|
||||
|
||||
**결과**
|
||||
- ResponseBase
|
||||
"""
|
||||
return await table_update(request.state.user, Users, request_body_info, M.ResponseBase)
|
||||
|
||||
|
||||
@router.put('/update_pw', response_model=M.ResponseBase, summary='유저 비밀번호 변경')
|
||||
async def user_update_pw(request: Request, request_body_info: M.UserUpdatePWReq):
|
||||
"""
|
||||
## 유저정보 비밀번호 변경
|
||||
|
||||
**account**의 **비밀번호**를 변경한다.
|
||||
|
||||
**결과**
|
||||
- ResponseBase
|
||||
"""
|
||||
target_table = Users
|
||||
search_info = None
|
||||
|
||||
try:
|
||||
# request
|
||||
accessor_info = request.state.user
|
||||
if not accessor_info:
|
||||
raise Exception('invalid accessor')
|
||||
if not request_body_info.account:
|
||||
raise Exception('invalid account')
|
||||
|
||||
# decrypt pw
|
||||
try:
|
||||
decode_cur_pw = request_body_info.current_pw.encode('utf-8')
|
||||
desc_cur_pw = AESCryptoCBC().decrypt(decode_cur_pw)
|
||||
except Exception as e:
|
||||
raise Exception(f'failed decryption [current_pw]: {e}')
|
||||
try:
|
||||
decode_new_pw = request_body_info.new_pw.encode('utf-8')
|
||||
desc_new_pw = AESCryptoCBC().decrypt(decode_new_pw)
|
||||
except Exception as e:
|
||||
raise Exception(f'failed decryption [new_pw]: {e}')
|
||||
|
||||
# search
|
||||
target_user = target_table.get(account=request_body_info.account)
|
||||
is_verified = bcrypt.checkpw(desc_cur_pw, target_user.pw.encode('utf-8'))
|
||||
if not is_verified:
|
||||
raise Exception('invalid password')
|
||||
|
||||
search_info = target_table.filter(id=target_user.id)
|
||||
if not search_info.first():
|
||||
raise Exception('not found data')
|
||||
|
||||
# process
|
||||
hash_pw = bcrypt.hashpw(desc_new_pw, bcrypt.gensalt())
|
||||
result_info = search_info.update(auto_commit=True, pw=hash_pw)
|
||||
if not result_info or not result_info.id:
|
||||
raise Exception('failed update')
|
||||
|
||||
# result
|
||||
return M.ResponseBase()
|
||||
except Exception as e:
|
||||
if search_info:
|
||||
search_info.close()
|
||||
return M.ResponseBase.set_error(str(e))
|
||||
|
||||
|
||||
@router.delete('/delete', response_model=M.ResponseBase, summary='유저정보 삭제')
|
||||
async def user_delete(request: Request, request_body_info: M.UserSearchReq):
|
||||
"""
|
||||
## 유저정보 삭제
|
||||
조건에 해당하는 정보를 모두 삭제한다.\n
|
||||
- **본 API는 DB에서 완적삭제를 하는 함수이며, 서버관리자가 사용하는 것을 권장한다.**
|
||||
- **update API를 사용하여 상태 항목을 변경해서 사용하는 것을 권장.**
|
||||
|
||||
`유저삭제시 관계 테이블의 정보도 같이 삭제된다.`
|
||||
|
||||
**결과**
|
||||
- ResponseBase
|
||||
"""
|
||||
return await table_delete(request.state.user, Users, request_body_info, M.ResponseBase)
|
||||
|
||||
|
||||
# NOTE(hsj100): apikey
|
||||
"""
|
||||
"""
|
||||
# @router.get('/apikeys', response_model=List[M.GetApiKeyList])
|
||||
# async def get_api_keys(request: Request):
|
||||
# """
|
||||
# API KEY 조회
|
||||
# :param request:
|
||||
# :return:
|
||||
# """
|
||||
# user = request.state.user
|
||||
# api_keys = ApiKeys.filter(user_id=user.id).all()
|
||||
# return api_keys
|
||||
#
|
||||
#
|
||||
# @router.post('/apikeys', response_model=M.GetApiKeys)
|
||||
# async def create_api_keys(request: Request, key_info: M.AddApiKey, session: Session = Depends(db.session)):
|
||||
# """
|
||||
# API KEY 생성
|
||||
# :param request:
|
||||
# :param key_info:
|
||||
# :param session:
|
||||
# :return:
|
||||
# """
|
||||
# user = request.state.user
|
||||
#
|
||||
# api_keys = ApiKeys.filter(session, user_id=user.id, status='active').count()
|
||||
# if api_keys == MAX_API_KEY:
|
||||
# raise ex.MaxKeyCountEx()
|
||||
#
|
||||
# alphabet = string.ascii_letters + string.digits
|
||||
# s_key = ''.join(secrets.choice(alphabet) for _ in range(40))
|
||||
# uid = None
|
||||
# while not uid:
|
||||
# uid_candidate = f'{str(uuid4())[:-12]}{str(uuid4())}'
|
||||
# uid_check = ApiKeys.get(access_key=uid_candidate)
|
||||
# if not uid_check:
|
||||
# uid = uid_candidate
|
||||
#
|
||||
# key_info = key_info.dict()
|
||||
# new_key = ApiKeys.create(session, auto_commit=True, secret_key=s_key, user_id=user.id, access_key=uid, **key_info)
|
||||
# return new_key
|
||||
#
|
||||
#
|
||||
# @router.put('/apikeys/{key_id}', response_model=M.GetApiKeyList)
|
||||
# async def update_api_keys(request: Request, key_id: int, key_info: M.AddApiKey):
|
||||
# """
|
||||
# API KEY User Memo Update
|
||||
# :param request:
|
||||
# :param key_id:
|
||||
# :param key_info:
|
||||
# :return:
|
||||
# """
|
||||
# user = request.state.user
|
||||
# key_data = ApiKeys.filter(id=key_id)
|
||||
# if key_data and key_data.first().user_id == user.id:
|
||||
# return key_data.update(auto_commit=True, **key_info.dict())
|
||||
# raise ex.NoKeyMatchEx()
|
||||
#
|
||||
#
|
||||
# @router.delete('/apikeys/{key_id}')
|
||||
# async def delete_api_keys(request: Request, key_id: int, access_key: str):
|
||||
# user = request.state.user
|
||||
# await check_api_owner(user.id, key_id)
|
||||
# search_by_key = ApiKeys.filter(access_key=access_key)
|
||||
# if not search_by_key.first():
|
||||
# raise ex.NoKeyMatchEx()
|
||||
# search_by_key.delete(auto_commit=True)
|
||||
# return MessageOk()
|
||||
#
|
||||
#
|
||||
# @router.get('/apikeys/{key_id}/whitelists', response_model=List[M.GetAPIWhiteLists])
|
||||
# async def get_api_keys(request: Request, key_id: int):
|
||||
# user = request.state.user
|
||||
# await check_api_owner(user.id, key_id)
|
||||
# whitelists = ApiWhiteLists.filter(api_key_id=key_id).all()
|
||||
# return whitelists
|
||||
#
|
||||
#
|
||||
# @router.post('/apikeys/{key_id}/whitelists', response_model=M.GetAPIWhiteLists)
|
||||
# async def create_api_keys(request: Request, key_id: int, ip: M.CreateAPIWhiteLists, session: Session = Depends(db.session)):
|
||||
# user = request.state.user
|
||||
# await check_api_owner(user.id, key_id)
|
||||
# import ipaddress
|
||||
# try:
|
||||
# _ip = ipaddress.ip_address(ip.ip_addr)
|
||||
# except Exception as e:
|
||||
# raise ex.InvalidIpEx(ip.ip_addr, e)
|
||||
# if ApiWhiteLists.filter(api_key_id=key_id).count() == MAX_API_WHITELIST:
|
||||
# raise ex.MaxWLCountEx()
|
||||
# ip_dup = ApiWhiteLists.get(api_key_id=key_id, ip_addr=ip.ip_addr)
|
||||
# if ip_dup:
|
||||
# return ip_dup
|
||||
# ip_reg = ApiWhiteLists.create(session=session, auto_commit=True, api_key_id=key_id, ip_addr=ip.ip_addr)
|
||||
# return ip_reg
|
||||
#
|
||||
#
|
||||
# @router.delete('/apikeys/{key_id}/whitelists/{list_id}')
|
||||
# async def delete_api_keys(request: Request, key_id: int, list_id: int):
|
||||
# user = request.state.user
|
||||
# await check_api_owner(user.id, key_id)
|
||||
# ApiWhiteLists.filter(id=list_id, api_key_id=key_id).delete()
|
||||
#
|
||||
# return MessageOk()
|
||||
#
|
||||
#
|
||||
# async def check_api_owner(user_id, key_id):
|
||||
# api_keys = ApiKeys.get(id=key_id, user_id=user_id)
|
||||
# if not api_keys:
|
||||
# raise ex.NoKeyMatchEx()
|
||||
58
vactor_rest/app/utils/date_utils.py
Normal file
58
vactor_rest/app/utils/date_utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: date_utils.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: date-utility functions
|
||||
"""
|
||||
|
||||
from datetime import datetime, date, timedelta
|
||||
|
||||
_TIMEDELTA = 9
|
||||
|
||||
|
||||
class D:
|
||||
def __init__(self, *args):
|
||||
self.utc_now = datetime.utcnow()
|
||||
# NOTE(hsj100): utc->kst
|
||||
self.timedelta = _TIMEDELTA
|
||||
|
||||
@classmethod
|
||||
def datetime(cls, diff: int=_TIMEDELTA) -> datetime:
|
||||
return datetime.utcnow() + timedelta(hours=diff) if diff > 0 else datetime.now()
|
||||
|
||||
@classmethod
|
||||
def date(cls, diff: int=_TIMEDELTA) -> date:
|
||||
return cls.datetime(diff=diff).date()
|
||||
|
||||
@classmethod
|
||||
def date_num(cls, diff: int=_TIMEDELTA) -> int:
|
||||
return int(cls.date(diff=diff).strftime('%Y%m%d'))
|
||||
|
||||
@classmethod
|
||||
def validate(cls, date_text):
|
||||
try:
|
||||
datetime.strptime(date_text, '%Y-%m-%d')
|
||||
except ValueError:
|
||||
raise ValueError('Incorrect data format, should be YYYY-MM-DD')
|
||||
|
||||
@classmethod
|
||||
def date_str(cls, diff: int = _TIMEDELTA) -> str:
|
||||
return cls.datetime(diff=diff).strftime('%Y-%m-%dT%H:%M:%S')
|
||||
|
||||
@classmethod
|
||||
def check_expire_date(cls, expire_date: datetime):
|
||||
td = expire_date - datetime.now()
|
||||
timestamp = td.total_seconds()
|
||||
return timestamp
|
||||
|
||||
@classmethod
|
||||
def date_file_name(cls):
|
||||
date = datetime.now()
|
||||
return date.strftime('%y%m%d_%H%M%S.%s')
|
||||
|
||||
if __name__ == "__main__":
|
||||
_date = D()
|
||||
205
vactor_rest/app/utils/extra.py
Normal file
205
vactor_rest/app/utils/extra.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: extra.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: extra functions
|
||||
"""
|
||||
|
||||
from hashlib import md5
|
||||
from base64 import b64decode
|
||||
from base64 import b64encode
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from cryptography.fernet import Fernet # symmetric encryption
|
||||
|
||||
# mail test
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
from itertools import groupby
|
||||
from operator import attrgetter
|
||||
import uuid
|
||||
|
||||
from vactor_rest.app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT
|
||||
from vactor_rest.app.utils.date_utils import D
|
||||
from vactor_rest.app import models as M
|
||||
from vactor_rest.app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY
|
||||
|
||||
|
||||
async def send_mail(sender, sender_pw, title, recipient, contents_plain, contents_html, cc_list, smtp_host=SMTP_HOST, smtp_port=SMTP_PORT):
|
||||
"""
|
||||
구글 계정사용시 : 보안 수준이 낮은 앱에서의 접근 활성화
|
||||
|
||||
:return:
|
||||
None: success
|
||||
Str. Message: error
|
||||
"""
|
||||
try:
|
||||
# check parameters
|
||||
if not sender:
|
||||
raise Exception('invalid sender')
|
||||
if not title:
|
||||
raise Exception('invalid title')
|
||||
if not recipient:
|
||||
raise Exception('invalid recipient')
|
||||
|
||||
# sender info.
|
||||
# sender = consts.ADMIN_INIT_ACCOUNT_INFO.email
|
||||
# sender_pw = consts.ADMIN_INIT_ACCOUNT_INFO.email_pw
|
||||
|
||||
# message
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = sender
|
||||
msg['To'] = recipient
|
||||
if cc_list:
|
||||
list_cc = cc_list
|
||||
str_cc = ','.join(list_cc)
|
||||
msg['Cc'] = str_cc
|
||||
msg['Subject'] = title
|
||||
|
||||
if contents_plain:
|
||||
msg.attach(MIMEText(contents_plain, 'plain'))
|
||||
if contents_html:
|
||||
msg.attach(MIMEText(contents_html, 'html'))
|
||||
|
||||
# smtp server
|
||||
smtp_server = smtplib.SMTP(host=smtp_host, port=smtp_port)
|
||||
smtp_server.ehlo()
|
||||
smtp_server.starttls()
|
||||
smtp_server.ehlo()
|
||||
smtp_server.login(sender, sender_pw)
|
||||
smtp_server.send_message(msg)
|
||||
smtp_server.quit()
|
||||
return None
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
||||
|
||||
def query_to_groupby(query_result, key, first=False):
|
||||
"""
|
||||
쿼리 결과물(list)을 항목값(key)으로 그룹화한다.
|
||||
|
||||
:param query_result: 쿼리 결과 리스트
|
||||
:param key: 그룹 항목값
|
||||
:return: dict
|
||||
"""
|
||||
group_info = dict()
|
||||
for k, g in groupby(query_result, attrgetter(key)):
|
||||
if k not in group_info:
|
||||
if not first:
|
||||
group_info[k] = list(g)
|
||||
else:
|
||||
group_info[k] = list(g)[0]
|
||||
else:
|
||||
if not first:
|
||||
group_info[k].extend(list(g))
|
||||
return group_info
|
||||
|
||||
|
||||
def query_to_groupby_date(query_result, key):
|
||||
"""
|
||||
쿼리 결과물(list)을 항목값(key)으로 그룹화한다.
|
||||
|
||||
:param query_result: 쿼리 결과 리스트
|
||||
:param key: 그룹 항목값
|
||||
:return: dict
|
||||
"""
|
||||
group_info = dict()
|
||||
for k, g in groupby(query_result, attrgetter(key)):
|
||||
day_str = k.strftime("%Y-%m-%d")
|
||||
if day_str not in group_info:
|
||||
group_info[day_str] = list(g)
|
||||
else:
|
||||
group_info[day_str].extend(list(g))
|
||||
return group_info
|
||||
|
||||
|
||||
class FernetCrypto:
|
||||
def __init__(self, key=FERNET_SECRET_KEY):
|
||||
self.key = key
|
||||
self.f = Fernet(self.key)
|
||||
|
||||
def encrypt(self, data, is_out_string=True):
|
||||
if isinstance(data, bytes):
|
||||
ou = self.f.encrypt(data) # 바이트형태이면 바로 암호화
|
||||
else:
|
||||
ou = self.f.encrypt(data.encode('utf-8')) # 인코딩 후 암호화
|
||||
if is_out_string is True:
|
||||
return ou.decode('utf-8') # 출력이 문자열이면 디코딩 후 반환
|
||||
else:
|
||||
return ou
|
||||
|
||||
def decrypt(self, data, is_out_string=True):
|
||||
if isinstance(data, bytes):
|
||||
ou = self.f.decrypt(data) # 바이트형태이면 바로 복호화
|
||||
else:
|
||||
ou = self.f.decrypt(data.encode('utf-8')) # 인코딩 후 복호화
|
||||
if is_out_string is True:
|
||||
return ou.decode('utf-8') # 출력이 문자열이면 디코딩 후 반환
|
||||
else:
|
||||
return ou
|
||||
|
||||
|
||||
class AESCryptoCBC:
|
||||
def __init__(self, key=AES_CBC_PUBLIC_KEY, iv=AES_CBC_IV):
|
||||
# Initial vector를 0으로 초기화하여 16바이트 할당함
|
||||
# iv = chr(0) * 16 #pycrypto 기준
|
||||
# iv = bytes([0x00] * 16) #pycryptodomex 기준
|
||||
# aes cbc 생성
|
||||
self.key = key
|
||||
self.iv = iv
|
||||
self.crypto = AES.new(self.key, AES.MODE_CBC, self.iv)
|
||||
|
||||
def encrypt(self, data):
|
||||
# 암호화 message는 16의 배수여야 한다.
|
||||
# enc = self.crypto.encrypt(data)
|
||||
# return enc
|
||||
enc = self.crypto.encrypt(pad(data, AES.block_size))
|
||||
return b64encode(enc)
|
||||
|
||||
def decrypt(self, enc):
|
||||
# 복호화 enc는 16의 배수여야 한다.
|
||||
# dec = self.crypto.decrypt(enc)
|
||||
# return dec
|
||||
enc = b64decode(enc)
|
||||
dec = self.crypto.decrypt(enc)
|
||||
return unpad(dec, AES.block_size)
|
||||
|
||||
|
||||
class AESCipher:
|
||||
def __init__(self, key):
|
||||
# self.key = md5(key.encode('utf8')).digest()
|
||||
self.key = bytes(key.encode('utf-8'))
|
||||
|
||||
def encrypt(self, data):
|
||||
# iv = get_random_bytes(AES.block_size)
|
||||
iv = bytes('daooldns12345678'.encode('utf-8'))
|
||||
|
||||
self.cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
t = b64encode(self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size)))
|
||||
return b64encode(iv + self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size)))
|
||||
|
||||
def decrypt(self, data):
|
||||
raw = b64decode(data)
|
||||
self.cipher = AES.new(self.key, AES.MODE_CBC, raw[:AES.block_size])
|
||||
return unpad(self.cipher.decrypt(raw[AES.block_size:]), AES.block_size)
|
||||
|
||||
def cls_list_to_dict_list(list):
|
||||
"""
|
||||
list 내부 element가 dict로 변환 가능한 class일경우
|
||||
내부 element를 dict 로 변경
|
||||
"""
|
||||
_result = []
|
||||
for i in list:
|
||||
if isinstance(i, dict):
|
||||
_result = list
|
||||
break
|
||||
|
||||
_result.append(i.dict())
|
||||
return _result
|
||||
65
vactor_rest/app/utils/logger.py
Normal file
65
vactor_rest/app/utils/logger.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: logger.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: logger
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import timedelta, datetime
|
||||
from time import time
|
||||
from fastapi.requests import Request
|
||||
from fastapi import Body
|
||||
from fastapi.logger import logger
|
||||
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
async def api_logger(request: Request, response=None, error=None):
|
||||
time_format = '%Y/%m/%d %H:%M:%S'
|
||||
t = time() - request.state.start
|
||||
status_code = error.status_code if error else response.status_code
|
||||
error_log = None
|
||||
user = request.state.user
|
||||
if error:
|
||||
if request.state.inspect:
|
||||
frame = request.state.inspect
|
||||
error_file = frame.f_code.co_filename
|
||||
error_func = frame.f_code.co_name
|
||||
error_line = frame.f_lineno
|
||||
else:
|
||||
error_func = error_file = error_line = 'UNKNOWN'
|
||||
|
||||
error_log = dict(
|
||||
errorFunc=error_func,
|
||||
location='{} line in {}'.format(str(error_line), error_file),
|
||||
raised=str(error.__class__.__name__),
|
||||
msg=str(error.ex),
|
||||
)
|
||||
|
||||
account = user.account.split('@') if user and user.account else None
|
||||
user_log = dict(
|
||||
client=request.state.ip,
|
||||
user=user.id if user and user.id else None,
|
||||
account='**' + account[0][2:-1] + '*@' + account[1] if user and user.account else None,
|
||||
)
|
||||
|
||||
log_dict = dict(
|
||||
url=request.url.hostname + request.url.path,
|
||||
method=str(request.method),
|
||||
statusCode=status_code,
|
||||
errorDetail=error_log,
|
||||
client=user_log,
|
||||
processedTime=str(round(t * 1000, 5)) + 'ms',
|
||||
datetimeUTC=datetime.utcnow().strftime(time_format),
|
||||
datetimeKST=(datetime.utcnow() + timedelta(hours=9)).strftime(time_format),
|
||||
)
|
||||
if error and error.status_code >= 500:
|
||||
logger.error(json.dumps(log_dict))
|
||||
else:
|
||||
logger.info(json.dumps(log_dict))
|
||||
25
vactor_rest/app/utils/parsing_utils.py
Normal file
25
vactor_rest/app/utils/parsing_utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from const import ILLEGAL_FILE_NAME
|
||||
|
||||
def prompt_to_filenames(prompt):
|
||||
"""
|
||||
prompt 에 사용할 수 없는 문자가 있으면 '_' 로 치환
|
||||
"""
|
||||
filename = ''
|
||||
for i in prompt:
|
||||
if i in ILLEGAL_FILE_NAME:
|
||||
filename += '_'
|
||||
else:
|
||||
filename += i
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def download_range(download_count:int,max=4):
|
||||
_min = 1
|
||||
_max = max
|
||||
|
||||
if _min <= download_count and download_count <= _max:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
22
vactor_rest/app/utils/query_utils.py
Normal file
22
vactor_rest/app/utils/query_utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: query_utils.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: query-utility functions
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def to_dict(model, *args, exclude: List = None):
|
||||
q_dict = {}
|
||||
for c in model.__table__.columns:
|
||||
if not args or c.name in args:
|
||||
if not exclude or c.name not in exclude:
|
||||
q_dict[c.name] = getattr(model, c.name)
|
||||
|
||||
return q_dict
|
||||
222
vactor_rest/gunicorn.conf.py
Normal file
222
vactor_rest/gunicorn.conf.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Gunicorn configuration file.
|
||||
#
|
||||
# Server socket
|
||||
#
|
||||
# bind - The socket to bind.
|
||||
#
|
||||
# A string of the form: 'HOST', 'HOST:PORT', 'unix:PATH'.
|
||||
# An IP is a valid HOST.
|
||||
#
|
||||
# backlog - The number of pending connections. This refers
|
||||
# to the number of clients that can be waiting to be
|
||||
# served. Exceeding this number results in the client
|
||||
# getting an error when attempting to connect. It should
|
||||
# only affect servers under significant load.
|
||||
#
|
||||
# Must be a positive integer. Generally set in the 64-2048
|
||||
# range.
|
||||
#
|
||||
|
||||
bind = "0.0.0.0:5000"
|
||||
backlog = 2048
|
||||
|
||||
#
|
||||
# Worker processes
|
||||
#
|
||||
# workers - The number of worker processes that this server
|
||||
# should keep alive for handling requests.
|
||||
#
|
||||
# A positive integer generally in the 2-4 x $(NUM_CORES)
|
||||
# range. You'll want to vary this a bit to find the best
|
||||
# for your particular application's work load.
|
||||
#
|
||||
# worker_class - The type of workers to use. The default
|
||||
# sync class should handle most 'normal' types of work
|
||||
# loads. You'll want to read
|
||||
# http://docs.gunicorn.org/en/latest/design.html#choosing-a-worker-type
|
||||
# for information on when you might want to choose one
|
||||
# of the other worker classes.
|
||||
#
|
||||
# A string referring to a Python path to a subclass of
|
||||
# gunicorn.workers.base.Worker. The default provided values
|
||||
# can be seen at
|
||||
# http://docs.gunicorn.org/en/latest/settings.html#worker-class
|
||||
#
|
||||
# worker_connections - For the eventlet and gevent worker classes
|
||||
# this limits the maximum number of simultaneous clients that
|
||||
# a single process can handle.
|
||||
#
|
||||
# A positive integer generally set to around 1000.
|
||||
#
|
||||
# timeout - If a worker does not notify the master process in this
|
||||
# number of seconds it is killed and a new worker is spawned
|
||||
# to replace it.
|
||||
#
|
||||
# Generally set to thirty seconds. Only set this noticeably
|
||||
# higher if you're sure of the repercussions for sync workers.
|
||||
# For the non sync workers it just means that the worker
|
||||
# process is still communicating and is not tied to the length
|
||||
# of time required to handle a single request.
|
||||
#
|
||||
# keepalive - The number of seconds to wait for the next request
|
||||
# on a Keep-Alive HTTP connection.
|
||||
#
|
||||
# A positive integer. Generally set in the 1-5 seconds range.
|
||||
#
|
||||
# reload - Restart workers when code changes.
|
||||
#
|
||||
# This setting is intended for development. It will cause
|
||||
# workers to be restarted whenever application code changes.
|
||||
workers = 3
|
||||
threads = 3
|
||||
worker_class = "uvicorn.workers.UvicornWorker"
|
||||
worker_connections = 1000
|
||||
timeout = 60
|
||||
keepalive = 2
|
||||
reload = True
|
||||
|
||||
#
|
||||
# spew - Install a trace function that spews every line of Python
|
||||
# that is executed when running the server. This is the
|
||||
# nuclear option.
|
||||
#
|
||||
# True or False
|
||||
#
|
||||
|
||||
spew = False
|
||||
|
||||
#
|
||||
# Server mechanics
|
||||
#
|
||||
# daemon - Detach the main Gunicorn process from the controlling
|
||||
# terminal with a standard fork/fork sequence.
|
||||
#
|
||||
# True or False
|
||||
#
|
||||
# raw_env - Pass environment variables to the execution environment.
|
||||
#
|
||||
# pidfile - The path to a pid file to write
|
||||
#
|
||||
# A path string or None to not write a pid file.
|
||||
#
|
||||
# user - Switch worker processes to run as this user.
|
||||
#
|
||||
# A valid user id (as an integer) or the name of a user that
|
||||
# can be retrieved with a call to pwd.getpwnam(value) or None
|
||||
# to not change the worker process user.
|
||||
#
|
||||
# group - Switch worker process to run as this group.
|
||||
#
|
||||
# A valid group id (as an integer) or the name of a user that
|
||||
# can be retrieved with a call to pwd.getgrnam(value) or None
|
||||
# to change the worker processes group.
|
||||
#
|
||||
# umask - A mask for file permissions written by Gunicorn. Note that
|
||||
# this affects unix socket permissions.
|
||||
#
|
||||
# A valid value for the os.umask(mode) call or a string
|
||||
# compatible with int(value, 0) (0 means Python guesses
|
||||
# the base, so values like "0", "0xFF", "0022" are valid
|
||||
# for decimal, hex, and octal representations)
|
||||
#
|
||||
# tmp_upload_dir - A directory to store temporary request data when
|
||||
# requests are read. This will most likely be disappearing soon.
|
||||
#
|
||||
# A path to a directory where the process owner can write. Or
|
||||
# None to signal that Python should choose one on its own.
|
||||
#
|
||||
|
||||
daemon = False
|
||||
pidfile = None
|
||||
umask = 0
|
||||
user = None
|
||||
group = None
|
||||
tmp_upload_dir = None
|
||||
|
||||
#
|
||||
# Logging
|
||||
#
|
||||
# logfile - The path to a log file to write to.
|
||||
#
|
||||
# A path string. "-" means log to stdout.
|
||||
#
|
||||
# loglevel - The granularity of log output
|
||||
#
|
||||
# A string of "debug", "info", "warning", "error", "critical"
|
||||
#
|
||||
|
||||
errorlog = "-"
|
||||
loglevel = "info"
|
||||
accesslog = None
|
||||
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
|
||||
|
||||
#
|
||||
# Process naming
|
||||
#
|
||||
# proc_name - A base to use with setproctitle to change the way
|
||||
# that Gunicorn processes are reported in the system process
|
||||
# table. This affects things like 'ps' and 'top'. If you're
|
||||
# going to be running more than one instance of Gunicorn you'll
|
||||
# probably want to set a name to tell them apart. This requires
|
||||
# that you install the setproctitle module.
|
||||
#
|
||||
# A string or None to choose a default of something like 'gunicorn'.
|
||||
#
|
||||
|
||||
proc_name = "NotificationAPI"
|
||||
|
||||
|
||||
#
|
||||
# Server hooks
|
||||
#
|
||||
# post_fork - Called just after a worker has been forked.
|
||||
#
|
||||
# A callable that takes a server and worker instance
|
||||
# as arguments.
|
||||
#
|
||||
# pre_fork - Called just prior to forking the worker subprocess.
|
||||
#
|
||||
# A callable that accepts the same arguments as after_fork
|
||||
#
|
||||
# pre_exec - Called just prior to forking off a secondary
|
||||
# master process during things like config reloading.
|
||||
#
|
||||
# A callable that takes a server instance as the sole argument.
|
||||
#
|
||||
|
||||
|
||||
def post_fork(server, worker):
|
||||
server.log.info("Worker spawned (pid: %s)", worker.pid)
|
||||
|
||||
|
||||
def pre_fork(server, worker):
|
||||
pass
|
||||
|
||||
|
||||
def pre_exec(server):
|
||||
server.log.info("Forked child, re-executing.")
|
||||
|
||||
|
||||
def when_ready(server):
|
||||
server.log.info("Server is ready. Spawning workers")
|
||||
|
||||
|
||||
def worker_int(worker):
|
||||
worker.log.info("worker received INT or QUIT signal")
|
||||
|
||||
# get traceback info
|
||||
import threading, sys, traceback
|
||||
|
||||
id2name = {th.ident: th.name for th in threading.enumerate()}
|
||||
code = []
|
||||
for threadId, stack in sys._current_frames().items():
|
||||
code.append("\n# Thread: %s(%d)" % (id2name.get(threadId, ""), threadId))
|
||||
for filename, lineno, name, line in traceback.extract_stack(stack):
|
||||
code.append('File: "%s", line %d, in %s' % (filename, lineno, name))
|
||||
if line:
|
||||
code.append(" %s" % (line.strip()))
|
||||
worker.log.debug("\n".join(code))
|
||||
|
||||
|
||||
def worker_abort(worker):
|
||||
worker.log.info("worker received SIGABRT signal")
|
||||
0
vactor_rest/tests/__init__.py
Normal file
0
vactor_rest/tests/__init__.py
Normal file
79
vactor_rest/tests/conftest.py
Normal file
79
vactor_rest/tests/conftest.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: conftest.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: test config
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from os import path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database.schema import Users
|
||||
from app.main import create_app
|
||||
from app.database.conn import db, Base
|
||||
from app.models import UserToken
|
||||
from app.routes.auth import create_access_token
|
||||
|
||||
|
||||
"""
|
||||
1. DB 생성
|
||||
2. 테이블 생성
|
||||
3. 테스트 코드 작동
|
||||
4. 테이블 레코드 삭제
|
||||
"""
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def app():
|
||||
os.environ['API_ENV'] = 'test'
|
||||
return create_app()
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def client(app):
|
||||
# Create tables
|
||||
Base.metadata.create_all(db.engine)
|
||||
return TestClient(app=app)
|
||||
|
||||
|
||||
@pytest.fixture(scope='function', autouse=True)
|
||||
def session():
|
||||
sess = next(db.session())
|
||||
yield sess
|
||||
clear_all_table_data(
|
||||
session=sess,
|
||||
metadata=Base.metadata,
|
||||
except_tables=[]
|
||||
)
|
||||
sess.rollback()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def login(session):
|
||||
"""
|
||||
테스트전 사용자 미리 등록
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
db_user = Users.create(session=session, email='ryan_test@dingrr.com', pw='123')
|
||||
session.commit()
|
||||
access_token = create_access_token(data=UserToken.from_orm(db_user).dict(exclude={'pw', 'marketing_agree'}),)
|
||||
return dict(Authorization=f'Bearer {access_token}')
|
||||
|
||||
|
||||
def clear_all_table_data(session: Session, metadata, except_tables: List[str] = None):
|
||||
session.execute('SET FOREIGN_KEY_CHECKS = 0;')
|
||||
for table in metadata.sorted_tables:
|
||||
if table.name not in except_tables:
|
||||
session.execute(table.delete())
|
||||
session.execute('SET FOREIGN_KEY_CHECKS = 1;')
|
||||
session.commit()
|
||||
44
vactor_rest/tests/test_auth.py
Normal file
44
vactor_rest/tests/test_auth.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: test_auth.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: test auth.
|
||||
"""
|
||||
|
||||
from app.database.conn import db
|
||||
from app.database.schema import Users
|
||||
|
||||
|
||||
def test_registration(client, session):
|
||||
"""
|
||||
레버 로그인
|
||||
:param client:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
user = dict(email='ryan@dingrr.com', pw='123', name='라이언', phone='01099999999')
|
||||
res = client.post('api/auth/register/email', json=user)
|
||||
res_body = res.json()
|
||||
print(res.json())
|
||||
assert res.status_code == 201
|
||||
assert 'Authorization' in res_body.keys()
|
||||
|
||||
|
||||
def test_registration_exist_email(client, session):
|
||||
"""
|
||||
레버 로그인
|
||||
:param client:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
user = dict(email='Hello@dingrr.com', pw='123', name='라이언', phone='01099999999')
|
||||
db_user = Users.create(session=session, **user)
|
||||
session.commit()
|
||||
res = client.post('api/auth/register/email', json=user)
|
||||
res_body = res.json()
|
||||
assert res.status_code == 400
|
||||
assert 'EMAIL_EXISTS' == res_body['msg']
|
||||
33
vactor_rest/tests/test_user.py
Normal file
33
vactor_rest/tests/test_user.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File: test_user.py
|
||||
@Date: 2020-09-14
|
||||
@author: A2TEC
|
||||
@section MODIFYINFO 수정정보
|
||||
- 수정자/수정일 : 수정내역
|
||||
- 2022-01-14/hsj100@a2tec.co.kr : refactoring
|
||||
@brief: test user
|
||||
"""
|
||||
|
||||
from app.database.conn import db
|
||||
from app.database.schema import Users
|
||||
|
||||
|
||||
def test_create_get_apikey(client, session, login):
|
||||
"""
|
||||
레버 로그인
|
||||
:param client:
|
||||
:param session:
|
||||
:return:
|
||||
"""
|
||||
key = dict(user_memo='ryan__key')
|
||||
res = client.post('api/user/apikeys', json=key, headers=login)
|
||||
res_body = res.json()
|
||||
assert res.status_code == 200
|
||||
assert 'secret_key' in res_body
|
||||
|
||||
res = client.get('api/user/apikeys', headers=login)
|
||||
res_body = res.json()
|
||||
assert res.status_code == 200
|
||||
assert 'ryan__key' in res_body[0]['user_memo']
|
||||
|
||||
Reference in New Issue
Block a user