non use ai

This commit is contained in:
2024-03-04 14:22:56 +09:00
parent 3d12dfe64d
commit 209ba8345f
92 changed files with 9130 additions and 3 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 MiB

BIN
AI_ENGINE/DATA/CON.mp4 Normal file

Binary file not shown.

BIN
AI_ENGINE/DATA/FR.mov Normal file

Binary file not shown.

BIN
AI_ENGINE/DATA/PPE.mp4 Normal file

Binary file not shown.

BIN
AI_ENGINE/DATA/WD_2.mp4 Normal file

Binary file not shown.

BIN
AI_ENGINE/DATA/agics.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 65 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 77 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 77 KiB

BIN
AI_ENGINE/DATA/jangys.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 81 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
AI_ENGINE/DATA/kepco1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

BIN
AI_ENGINE/DATA/kepco1_1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.7 KiB

BIN
AI_ENGINE/DATA/kepco1_2.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

BIN
AI_ENGINE/DATA/kepco2.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

BIN
AI_ENGINE/DATA/kepco2_1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

BIN
AI_ENGINE/DATA/kepco2_10.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

BIN
AI_ENGINE/DATA/kepco2_11.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

BIN
AI_ENGINE/DATA/kepco2_12.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 74 KiB

BIN
AI_ENGINE/DATA/kepco2_13.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

BIN
AI_ENGINE/DATA/kepco2_2.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.9 KiB

BIN
AI_ENGINE/DATA/kepco2_3.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 KiB

BIN
AI_ENGINE/DATA/kepco2_4.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

BIN
AI_ENGINE/DATA/kepco2_5.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

BIN
AI_ENGINE/DATA/kepco2_6.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

BIN
AI_ENGINE/DATA/kepco2_7.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

BIN
AI_ENGINE/DATA/kepco2_8.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

BIN
AI_ENGINE/DATA/kepco2_9.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.0 KiB

BIN
AI_ENGINE/DATA/kimjw.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 282 KiB

BIN
AI_ENGINE/DATA/kimjw_re.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

BIN
AI_ENGINE/DATA/ksy_re.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 159 KiB

BIN
AI_ENGINE/DATA/whangsj.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

BIN
AI_ENGINE/DATA/yunikim.jpg Executable file

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

52
AI_ENGINE/demo_utils.py Normal file
View File

@@ -0,0 +1,52 @@
import threading
import ai_engine_const as AI_CONST
import cv2
import os
import paramiko
# demo shared state: last work-detect broadcast index, guarded by demo_lock
DEMO_WD_BI_CONST = 0
demo_lock = threading.Lock()


def demo_wd_bi(i):
    """Publish *i* as the current demo WD index (thread-safe) and echo it."""
    global DEMO_WD_BI_CONST
    with demo_lock:
        DEMO_WD_BI_CONST = i
    print(DEMO_WD_BI_CONST)
def bi_snap_shot():
    """Grab one frame from the configured RTSP stream, save it locally, upload via SFTP.

    Any previous snapshot at AI_CONST.FTP_BI_RESULT is removed first.
    Side effects: writes a jpg to disk, performs an SFTP upload, prints progress.
    """
    rtsp = AI_CONST.RTSP
    local_path = AI_CONST.FTP_BI_RESULT
    if os.path.exists(local_path):
        os.remove(local_path)
    input_movie = cv2.VideoCapture(rtsp)
    try:
        ret, frame = input_movie.read()
    finally:
        input_movie.release()  # release the capture handle (previously leaked)
    if not ret:
        # stream unavailable / no frame: bail out instead of writing a None frame
        print("bi snapshot failed: could not read frame from stream")
        return
    print(local_path)
    cv2.imwrite(local_path, frame)
    _bi_sftp_upload()
    cv2.destroyAllWindows()
    print("bi uploaded")  # was an f-string with no placeholders
def _bi_sftp_upload():
    """Upload the BI snapshot over SFTP (best effort).

    :return: the remote path on success, "" on any failure (callers treat
             this as a best-effort demo upload, so errors are not raised).
    """
    transport = None
    try:
        transport = paramiko.Transport((AI_CONST.FTP_IP, AI_CONST.FTP_PORT))
        transport.connect(username=AI_CONST.FTP_ID, password=AI_CONST.FTP_PW)
        sftp = paramiko.SFTPClient.from_transport(transport)
        remotepath = AI_CONST.FTP_LOCATION + os.sep + AI_CONST.FTP_BI_FILE_NAME + '.jpg'
        # sftp.put(AI_CONST.FTP_BI_RESULT, remotepath)  # actual upload disabled in demo
        sftp.close()
        return remotepath
    except Exception as e:
        # swallow but report: previous version failed completely silently
        print(f"bi sftp upload failed: {e}")
        return ""
    finally:
        if transport is not None:
            transport.close()  # always release the socket, even on error (was leaked)

693
AI_ENGINE/instance_queue.py Normal file
View File

@@ -0,0 +1,693 @@
# -*- coding: utf-8 -*-
"""
@file : instance_queue.py
@author: jwkim
@license: A2TEC & DAOOLDNS
@section Modify History
- 2023-01-22 오전 11:31 jwkim base
"""
import os, sys
import time
import threading
from queue import Queue
DL_BASE_PATH = os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
sys.path.append(DL_BASE_PATH)
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) # const
import json
import ai_engine_const as AI_CONST
from mqtt_publish import client
from DL import d2_od_detect
from DL.FR.d2_face_detect import FaceDetect
from REST_AI_ENGINE_CONTROL.app.utils.date_utils import D
from REST_AI_ENGINE_CONTROL.app import models as M
from DL.d2_od_detect import D2_ROOT
class ThreadValues:
    """
    State shared between a detector thread and its manager.
    """
    # status strings published over MQTT alongside results
    STATUS_START = 'start'
    STATUS_COMPLETE = 'complete'
    STATUS_NEW = 'new'
    STATUS_TIMEOUT = 'timeout'
    STATUS_STOP = 'stop'
    STATUS_ERROR = '[error] ai model error : '
    STATUS_NONE = 'None'
    STATUS_DETECT = 'detect'
    # shared date/time formatter (one instance for all ThreadValues objects)
    date_utill = D()

    def __init__(self):
        # per-run state, created fresh for every inference
        self.current_status = self.STATUS_NONE
        self.result = []
        self.report_unit_result = []
        self.frame_result = {}
        self.timeout_status = False  # True only when a real timeout fired
        self.wd_ri = 0
class WorkQueue:
    """Entry-point queue: routes incoming requests to the per-model managers."""

    message_queue = Queue()

    def __init__(self):
        self.message_queue = Queue()
        # background router; daemon so it dies with the process
        router = threading.Thread(target=self.classifier, daemon=True)
        router.start()

    def sender(self, data):
        """Enqueue one request dict for routing."""
        self.message_queue.put(data)

    def classifier(self):
        """Forever pop requests and forward each to its model's manager queue."""
        while True:
            item = self.message_queue.get()
            model = item['ai_model']
            if model == AI_CONST.MODEL_CON:
                CON.con_sender(item)
            elif model == AI_CONST.MODEL_PPE:
                PPE.ppe_sender(item)
            elif model == AI_CONST.MODEL_WORK_DETECT:
                WD.wd_sender(item)
            elif model == AI_CONST.MODEL_FACE_RECOGNIZE:
                FR.fr_sender(item)
            elif model == AI_CONST.MODEL_BIO_INFO:
                pass  # bio-info has no dedicated manager yet
class PPEManager:
    """
    PPE (personal protective equipment) detection manager.

    Runs at most one inference worker at a time (sequential mode); a new
    start request stops any running worker before launching a new one.
    """
    TYPE_DETECT_SEQ = 0  # sequential processing
    TYPE_DETECT_PAL = 1  # parallel processing (not implemented)

    def __init__(self, detect_type: int = TYPE_DETECT_SEQ):
        self.ppe_queue = Queue()
        self.detect_type = detect_type
        self.ppe_receiver_thread = threading.Thread(target=self.ppe_receiver, daemon=True, name='ppe_receiver')
        self.ppe_receiver_thread.start()
        self.worker_event = threading.Event()  # worker check
        self.stop_event = threading.Event()  # inference stop trigger
        self.timeout_event = threading.Event()  # timeout check
        self.current_worker = ''  # '' until the first worker Thread is created

    def ppe_sender(self, data):
        # enqueue a signal dict for ppe_receiver
        self.ppe_queue.put(data)

    def ppe_receiver(self):
        """
        Receive data from ppe_queue and start/stop inference.
        If another request arrives while a ppe_thread is running, the running
        thread is stopped first and then a new ppe_thread is started.
        Parallel processing is not implemented.
        """
        while True:
            queue_data = self.ppe_queue.get()
            if queue_data:
                if self.detect_type == PPEManager.TYPE_DETECT_SEQ:
                    # start inference
                    if queue_data["signal"] == AI_CONST.SIGNAL_INFERENCE:
                        if self.current_worker:
                            if self.current_worker.is_alive():
                                # ask the running worker to stop and wait until it signals back
                                self.stop_event.set()
                                self.worker_event.wait()
                                self.worker_event.clear()
                        # NOTE(review): Thread(...).start() returns None, so ppe_thread is always None
                        ppe_thread = threading.Thread(target=self.ppe_start, daemon=True, args=(queue_data,),
                                                      name='ppe_thread').start()
                    # stop inference
                    elif queue_data["signal"] == AI_CONST.SIGNAL_STOP:
                        if self.current_worker:
                            if self.current_worker.is_alive():
                                self.stop_event.set()
                                self.worker_event.wait()
                                self.worker_event.clear()

    def ppe_start(self, data):
        """
        Run one inference together with its timeout watchdog.
        On timeout the inference ends; when inference ends, the timeout
        thread terminates with it.
        :param data: queue_data
        """
        try:
            share_value = ThreadValues()
            data['argument']['source'] = self._input_source(data['engine_info'].input_video)
            ppe_detect = d2_od_detect.PPEDetect(
                request=data['request'],
                engine_info=data['engine_info'],
                thread_value=share_value,
                yolo_argument=data['argument'],
                worker_event=self.worker_event,
                stop_event=self.stop_event,
                timeout_event=self.timeout_event,
                queue_info=self.ppe_queue,
            )
            self.current_worker = threading.Thread(target=ppe_detect.run, daemon=True, kwargs=data['argument'],
                                                   name='ppe_inference')
            self.current_worker.start()
            timeout_value = data['request'].limit_time_min * 60  # minutes -> seconds
            timeout_thread = threading.Thread(target=self._timeout, daemon=True, args=(timeout_value, share_value,),
                                              name='timeout')
            timeout_thread.start()
            # TODO(jwkim) : worker check
        except Exception as e:
            # publish the failure over MQTT instead of dying silently
            result_message = {
                "datetime": share_value.date_utill.date_str_micro_sec(),
                "status": share_value.STATUS_ERROR + str(e),
                "result": share_value.result
            }
            client.publish(AI_CONST.MQTT_PPE_TOPIC, json.dumps(result_message), 0)
            # NOTE(review): indentation was lost in this copy; this cleanup is
            # assumed to belong to the except branch (error recovery) — confirm
            if not self.timeout_event.is_set():
                self.timeout_event.set()
            self.timeout_event.clear()
            self.stop_event.clear()
            self.worker_event.set()

    def _timeout(self, timeout_value, share_value):
        """
        Timeout handling.
        -> when a real timeout fires, timeout_event is NOT in the set state.
        :param timeout_value: timeout in seconds
        :param share_value: values shared between threads
        """
        self.timeout_event.wait(timeout=timeout_value)
        # set from outside (inference finished or was stopped)
        if self.timeout_event.is_set():
            share_value.timeout_status = False
        # stop
        elif share_value.current_status == share_value.STATUS_STOP:
            share_value.timeout_status = False
        # time out
        else:
            share_value.timeout_status = True

    def _input_source(self, args):
        """
        Parse the input source list.
        :param args: input_video info
        :return: the connect URL registered for the PPE model (None if absent)
        """
        for i in args:
            if i.model == M.AEAIModelType.PPE:
                return i.connect_url
class CONManager:
    """
    CON_SETUP detection manager.

    Runs at most one inference worker at a time (sequential mode); a new
    start request stops any running worker before launching a new one.
    Mirrors PPEManager with CON-specific detector/topic.
    """
    TYPE_DETECT_SEQ = 0  # sequential processing
    TYPE_DETECT_PAL = 1  # parallel processing (not implemented)

    def __init__(self, detect_type: int = TYPE_DETECT_SEQ):
        self.con_queue = Queue()
        self.detect_type = detect_type
        self.con_receiver_thread = threading.Thread(target=self.con_receiver, daemon=True, name='con_receiver')
        self.con_receiver_thread.start()
        self.worker_event = threading.Event()  # worker check
        self.stop_event = threading.Event()  # inference stop trigger
        self.timeout_event = threading.Event()  # timeout check
        self.current_worker = ''  # '' until the first worker Thread is created

    def con_sender(self, data):
        # enqueue a signal dict for con_receiver
        self.con_queue.put(data)

    def con_receiver(self):
        """
        Receive data from con_queue and start/stop inference.
        If another request arrives while a con_thread is running, the running
        thread is stopped first and then a new con_thread is started.
        Parallel processing is not implemented.
        """
        while True:
            queue_data = self.con_queue.get()
            if queue_data:
                if self.detect_type == CONManager.TYPE_DETECT_SEQ:
                    # start inference
                    if queue_data["signal"] == AI_CONST.SIGNAL_INFERENCE:
                        if self.current_worker:
                            if self.current_worker.is_alive():
                                # ask the running worker to stop and wait until it signals back
                                self.stop_event.set()
                                self.worker_event.wait()
                                self.worker_event.clear()
                        # NOTE(review): Thread(...).start() returns None, so con_thread is always None
                        con_thread = threading.Thread(target=self.con_start, daemon=True, args=(queue_data,),
                                                      name='con_thread').start()
                    # stop inference
                    elif queue_data["signal"] == AI_CONST.SIGNAL_STOP:
                        if self.current_worker:
                            if self.current_worker.is_alive():
                                self.stop_event.set()
                                self.worker_event.wait()
                                self.worker_event.clear()

    def con_start(self, data):
        """
        Run one inference together with its timeout watchdog.
        On timeout the inference ends; when inference ends, the timeout
        thread terminates with it.
        :param data: queue_data
        """
        try:
            share_value = ThreadValues()
            data['argument']['source'] = self._input_source(data['engine_info'].input_video)
            con_detect = d2_od_detect.CONDetect(
                request=data['request'],
                engine_info=data['engine_info'],
                thread_value=share_value,
                yolo_argument=data['argument'],
                worker_event=self.worker_event,
                stop_event=self.stop_event,
                timeout_event=self.timeout_event,
                queue_info=self.con_queue,
            )
            self.current_worker = threading.Thread(target=con_detect.run, daemon=True, kwargs=data['argument'],
                                                   name='con_inference')
            self.current_worker.start()
            timeout_value = data['request'].limit_time_min * 60  # minutes -> seconds
            timeout_thread = threading.Thread(target=self._timeout, daemon=True, args=(timeout_value, share_value,),
                                              name='timeout')
            timeout_thread.start()
            # TODO(jwkim) : worker check
        except Exception as e:
            # publish the failure over MQTT instead of dying silently
            result_message = {
                "datetime": share_value.date_utill.date_str_micro_sec(),
                "status": share_value.STATUS_ERROR + str(e),
                "result": share_value.result
            }
            client.publish(AI_CONST.MQTT_CON_TOPIC, json.dumps(result_message), 0)
            # NOTE(review): indentation was lost in this copy; this cleanup is
            # assumed to belong to the except branch (error recovery) — confirm
            if not self.timeout_event.is_set():
                self.timeout_event.set()
            self.timeout_event.clear()
            self.stop_event.clear()
            self.worker_event.set()

    def _timeout(self, timeout_value, share_value):
        """
        Timeout handling.
        -> when a real timeout fires, timeout_event is NOT in the set state.
        :param timeout_value: timeout in seconds
        :param share_value: values shared between threads
        """
        self.timeout_event.wait(timeout=timeout_value)
        # set from outside (inference finished or was stopped)
        if self.timeout_event.is_set():
            share_value.timeout_status = False
        # stop
        elif share_value.current_status == share_value.STATUS_STOP:
            share_value.timeout_status = False
        # time out
        else:
            share_value.timeout_status = True

    def _input_source(self, args):
        """
        Parse the input source list.
        :param args: input_video info
        :return: the connect URL registered for the CON model (None if absent)
        """
        for i in args:
            if i.model == M.AEAIModelType.CON:
                return i.connect_url
class WDManager:
    """
    WorkDetect manager.

    Runs at most one inference worker at a time (sequential mode); a new
    start request stops any running worker first. Unlike PPE/CON there is
    no timeout watchdog for WD.
    """
    TYPE_DETECT_SEQ = 0  # sequential processing
    TYPE_DETECT_PAL = 1  # parallel processing (not implemented)

    def __init__(self, detect_type: int = TYPE_DETECT_SEQ):
        self.wd_queue = Queue()
        self.detect_type = detect_type
        self.wd_receiver_thread = threading.Thread(target=self.wd_receiver, daemon=True, name='wd_receiver')
        self.wd_receiver_thread.start()
        self.worker_event = threading.Event()  # worker check
        self.stop_event = threading.Event()  # inference stop trigger
        # self.timeout_event = threading.Event()  # timeout check
        self.lock = threading.Lock()
        self.current_worker = ''  # '' until the first worker Thread is created

    def wd_sender(self, data):
        # enqueue a signal dict for wd_receiver
        self.wd_queue.put(data)

    def wd_receiver(self):
        """
        Receive data from wd_queue and start/stop inference.
        If another request arrives while a wd_thread is running, the running
        thread is stopped first and then a new wd_thread is started.
        Parallel processing is not implemented.
        """
        while True:
            queue_data = self.wd_queue.get()
            if queue_data:
                if self.detect_type == WDManager.TYPE_DETECT_SEQ:
                    # start inference
                    if queue_data["signal"] == AI_CONST.SIGNAL_INFERENCE:
                        if self.current_worker:
                            if self.current_worker.is_alive():
                                # ask the running worker to stop and wait until it signals back
                                self.stop_event.set()
                                self.worker_event.wait()
                                self.worker_event.clear()
                        wd_thread = threading.Thread(target=self.wd_start, daemon=True, args=(queue_data,),
                                                     name='wd_thread')
                        wd_thread.start()
                    # stop inference
                    elif queue_data["signal"] == AI_CONST.SIGNAL_STOP:
                        if self.current_worker:
                            if self.current_worker.is_alive():
                                self.stop_event.set()
                                self.worker_event.wait()
                                self.worker_event.clear()

    def wd_start(self, data):
        """
        Run one inference; an external stop terminates it.
        :param data: queue_data
        """
        try:
            share_value = ThreadValues()
            data['argument']['source'] = self._input_source(data['engine_info'].input_video)
            wd_detect = d2_od_detect.WDDetect(
                request=data['request'],
                engine_info=data['engine_info'],
                thread_value=share_value,
                yolo_argument=data['argument'],
                worker_event=self.worker_event,
                stop_event=self.stop_event,
                queue_info=self.wd_queue
            )
            self.current_worker = threading.Thread(target=wd_detect.run, daemon=True, kwargs=data['argument'],
                                                   name='wd_inference')
            self.current_worker.start()
            # TODO(jwkim) : worker check
        except Exception as e:
            # publish the failure (with work identifiers) over MQTT
            result_message = {
                "datetime": share_value.date_utill.date_str_micro_sec(),
                "status": share_value.STATUS_ERROR + str(e),
                "result": {
                    'construction_type': data['request'].ri.construction_code,
                    'procedure_no': data['request'].ri.work_no,
                    'procedure_ri': data['request'].ri.work_define_ri,
                    'ri': share_value.wd_ri,
                    'detect_list': share_value.result
                }
            }
            client.publish(AI_CONST.MQTT_WD_TOPIC, json.dumps(result_message), 0)
            # NOTE(review): indentation was lost in this copy; this cleanup is
            # assumed to belong to the except branch (error recovery) — confirm
            self.stop_event.clear()
            self.worker_event.set()

    def _input_source(self, args):
        """
        Parse the WD input sources from the model info.
        If both stream and file inputs are present, the streams win: they are
        written one URL per line to AI_CONST.STREAMS_PATH and that path is
        returned; otherwise the first file path is returned.
        :param args: input_video info
        :return: parsed source ('' when nothing matches)
        """
        result = ''
        streams = []  # webcam, RTSP
        files = []
        for i in args:
            if i.model == M.AEAIModelType.WORK:
                if os.path.isfile(i.connect_url):
                    files.append(i.connect_url)
                elif i.connect_url.isnumeric() or i.connect_url.lower().startswith(
                        ('rtsp://', 'rtmp://', 'http://', 'https://')):
                    streams.append(i.connect_url)
        if streams:
            f = open(AI_CONST.STREAMS_PATH, 'w')
            for k in streams:
                f.write(k + "\n")
            f.close()
            result = AI_CONST.STREAMS_PATH
        elif files:
            result = files[0]
        return result
class FRManager:
    """
    Face Recognition manager.

    Runs at most one inference worker at a time plus a timeout watchdog;
    a new start request stops any running worker first.
    """
    TYPE_DETECT_SEQ = 0  # sequential processing
    TYPE_DETECT_PAL = 1  # parallel processing (not implemented)

    def __init__(self, detect_type: int = TYPE_DETECT_SEQ):
        self.fr_queue = Queue()
        self.detect_type = detect_type  # 0: sequential 1: parallel
        self.stop_event = threading.Event()  # inference stop trigger
        self.worker_event = threading.Event()  # worker check
        self.timeout_event = threading.Event()  # timeout check
        self.fr_receiver_thread = threading.Thread(target=self._fr_receiver, daemon=True, name='fr_receiver')
        self.fr_receiver_thread.start()
        self.fr_thread = ''  # '' until the first controller thread is created
        self.current_worker = ''  # '' until the first worker Thread is created
        self.fr_id_info = {}  # shared with FaceDetect instances

    def fr_sender(self, data):
        # enqueue a signal dict for _fr_receiver
        self.fr_queue.put(data)

    def _fr_receiver(self):
        """
        Handle messages received through fr_sender.
        A signal payload either starts inference or force-stops the running
        worker.
        TODO(jwkim): parallel processing
        """
        while True:
            # message get
            queue_data = self.fr_queue.get()
            if queue_data:
                if self.detect_type == FRManager.TYPE_DETECT_SEQ:
                    if queue_data["signal"] == AI_CONST.SIGNAL_INFERENCE:
                        # start inference (stop any running worker first)
                        if self.current_worker:
                            if self.current_worker.is_alive():
                                self.stop_event.set()
                                self.worker_event.wait()
                                self.worker_event.clear()
                        self.fr_thread = threading.Thread(target=self._fr_start, daemon=True, args=(queue_data,),
                                                          name='fr_thread')
                        self.fr_thread.start()
                    elif queue_data["signal"] == AI_CONST.SIGNAL_STOP:
                        # stop inference
                        if self.current_worker:
                            if self.current_worker.is_alive():
                                self.stop_event.set()
                                self.worker_event.wait()
                                self.worker_event.clear()

    def _fr_start(self, data):
        """
        Manage the inference thread and its timeout thread.
        :param data: queue_data
        """
        try:
            self.timeout_event.clear()
            share_value = ThreadValues()
            input_video = self._input_source(data['engine_info'].input_video)
            face_detect = FaceDetect(
                request=data['request'],
                thread_value=share_value,
                stop_event=self.stop_event,
                worker_event=self.worker_event,
                timeout_event=self.timeout_event,
                queue_info=self.fr_queue,
                input_video=input_video,
                engine_info=data['engine_info'],
                fr_manager=self
            )
            self.current_worker = threading.Thread(target=face_detect.inference, daemon=True, name='fr_inference')
            self.current_worker.start()
            # NOTE(review): value is SECONDS despite the variable name
            limit_time_min = data['request'].limit_time_min * 60
            self.timeout_thread = threading.Thread(target=self._timeout, daemon=True,
                                                   args=(limit_time_min, share_value,), name='timeout')
            self.timeout_thread.start()
        except Exception as e:
            # publish the failure over MQTT instead of dying silently
            result_message = {
                "datetime": share_value.date_utill.date_str_micro_sec(),
                "status": share_value.STATUS_ERROR + str(e),
                "result": share_value.result
            }
            client.publish(AI_CONST.MQTT_FR_TOPIC, json.dumps(result_message), 0)

    def _timeout(self, limit_min: int, share_value):
        """
        Raise a timeout against current_worker after the given interval.
        :param limit_min: wait interval — passed in SECONDS despite the name
        :param share_value: values shared between threads
        """
        self.timeout_event.wait(timeout=limit_min)
        # set from outside (inference finished or was stopped)
        if self.timeout_event.is_set():
            share_value.timeout_status = False
        # stop
        elif share_value.current_status == share_value.STATUS_STOP:
            share_value.timeout_status = False
        # time out
        else:
            share_value.timeout_status = True

    def _input_source(self, args):
        # return the connect URL registered for the FR model (None if absent)
        for i in args:
            if i.model == M.AEAIModelType.FR:
                return i.connect_url
# module-level singletons: one router plus one manager per AI model;
# WorkQueue.classifier dispatches into these by name
manager = WorkQueue()
CON = CONManager()
PPE = PPEManager()
WD = WDManager()
FR = FRManager()
if __name__ == '__main__':
    pass  # import-only module; nothing to run directly

44
AI_ENGINE/mqtt_publish.py Normal file
View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
"""
@file : mqtt_publish.py
@author: jwkim
@license: A2TEC & DAOOLDNS
@section Modify History
- 2023-01-11 오전 11:31 jwkim base
"""
import os, sys
import paho.mqtt.client as mqtt
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
import ai_engine_const as AI_CONST
import json
def on_connect(client, userdata, flags, rc):
    """Connect callback: report whether the broker accepted the session (rc == 0)."""
    if rc != 0:
        print("Bad connection Returned code=", rc)
    else:
        print("MQTT connected OK")
def on_disconnect(client, userdata, flags, rc=0):
    # disconnect callback: log the result code
    print(str(rc))
def on_publish(client, userdata, mid):
    # publish callback: log the message id of the delivered publish
    print("publish = ", mid)
# create a new MQTT client
client = mqtt.Client()
# register callbacks: on_connect (broker connect), on_disconnect (broker disconnect), on_publish (message published)
client.on_connect = on_connect
client.on_disconnect = on_disconnect
# client.on_publish = on_publish
client.username_pw_set(AI_CONST.MQTT_USER_ID,AI_CONST.MQTT_USER_PW)
# connect to the configured broker address/port
client.connect(AI_CONST.MQTT_HOST, AI_CONST.MQTT_PORT)

View File

@@ -0,0 +1,64 @@
import os, sys
import paho.mqtt.client as mqtt
import json
import cv2
from base64 import b64decode
import numpy as np
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
import ai_engine_const as AI_CONST
def on_connect(mqtt_object, userdata, flags, rc):
    """Connect callback for the subscriber: report the broker's answer (rc == 0 is OK)."""
    if rc != 0:
        print("Bad connection Returned code=", rc)
    else:
        print("connected OK")
def on_disconnect(mqtt_object, userdata, flags, rc=0):
    # disconnect callback: log the result code
    print(str(rc))
def on_subscribe(mqtt_object, userdata, mid, granted_qos):
    # subscribe callback: log the message id and the QoS granted by the broker
    print("subscribed: " + str(mid) + " " + str(granted_qos))
def on_message(mqtt_object, userdata, msg):
    """
    Message callback: print payloads from the subscribed result topics.
    :param mqtt_object: the mqtt client instance
    :param userdata: user data set on the client (unused)
    :param msg: received message (topic + payload)
    """
    if msg.topic == AI_CONST.MQTT_PPE_TOPIC:
        # PPE results are JSON; decode before printing
        object_message = json.loads(msg.payload.decode("utf-8"))
        # if object_message["img_data"]:
        #     img_data = b64decode(object_message["img_data"])
        #
        #     save_image = cv2.imdecode(np.frombuffer(img_data,dtype=np.uint8), cv2.IMREAD_COLOR)
        #     cv2.imwrite("./test.jpg", save_image)
        print(object_message)
    elif msg.topic == AI_CONST.MQTT_FR_TOPIC:
        # FR results are printed raw
        print(msg.payload.decode("utf-8"))
# subscriber client: wire up callbacks, connect, subscribe to the PPE and FR
# result topics, then block forever processing messages
client = mqtt.Client()
client.on_connect = on_connect
client.on_disconnect = on_disconnect
client.on_subscribe = on_subscribe
client.on_message = on_message
client.username_pw_set(AI_CONST.MQTT_USER_ID, AI_CONST.MQTT_USER_PW)
client.connect(AI_CONST.MQTT_HOST, AI_CONST.MQTT_PORT)
client.subscribe(AI_CONST.MQTT_PPE_TOPIC)  # ppe
client.subscribe(AI_CONST.MQTT_FR_TOPIC)  # FR
client.loop_forever()

63
AI_ENGINE/old_queue.py Normal file
View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""
@File: queue.py
@Date: 2022-06-30
@author: jwkim
@section MODIFYINFO 수정정보
- 수정자/수정일 : 수정내역
@brief: queue
"""
import os, sys
import time
import threading
from queue import Queue
DL_BASE_PATH = os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
sys.path.append(DL_BASE_PATH)
from DL import d2_od_detect
class WorkQueue:
    """Legacy bounded worker pool fed through a single message queue.

    At most ``max_workers`` inferences run concurrently; extra requests
    busy-wait (1 s poll) until a slot frees up.
    """

    max_workers = 4
    message_queue = Queue()

    def __init__(self):
        self.current_workers = 0
        self.message_queue = Queue()
        self.lock = threading.Lock()  # guards current_workers
        # background dispatcher; daemon so it exits with the process
        receiver_start_thread = threading.Thread(target=self.classifier, daemon=True)
        receiver_start_thread.start()

    def sender(self, data):
        """Enqueue one inference request (a kwargs dict for d2_od_detect.run)."""
        self.message_queue.put(data)

    def classifier(self):
        """Pop requests forever and hand each to its own worker thread."""
        while True:
            queue_data = self.message_queue.get()
            threading.Thread(target=self.check_worker, args=(queue_data,), daemon=True).start()

    def check_worker(self, args):
        """Wait (1 s poll) until a worker slot is free, then run the job."""
        while self.max_workers <= self.current_workers:
            time.sleep(1)
        self.start_worker(args)

    def start_worker(self, inference_args):
        """Run one inference, tracking the active-worker count under the lock.

        The decrement now happens in ``finally`` so a raising inference no
        longer leaks its slot (previously the counter was only decremented
        on success, and the > 0 check was read outside the lock).
        """
        with self.lock:
            self.current_workers += 1
        try:
            d2_od_detect.run(**inference_args)
        finally:
            with self.lock:
                self.current_workers -= 1
if __name__ == '__main__':
    pass  # import-only module; nothing to run directly

581
DL/FR/d2_face_detect.py Normal file
View File

@@ -0,0 +1,581 @@
# -*- coding: utf-8 -*-
"""
@file : d2_face_detect.py
@author: jwkim
@license: A2TEC & DAOOLDNS
@section Modify History
- 2023-01-11 오전 11:31 jwkim base
"""
import face_recognition
import cv2
import os, sys
import copy
import json
import threading
import numpy as np
import paramiko
AI_ENGINE_PATH = "/AI_ENGINE"
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))) + AI_ENGINE_PATH)  # mqtt
# NOTE(review): the result of this call is discarded — it looks like it was
# meant to be wrapped in sys.path.append(...) so ai_engine_const resolves; confirm
os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))  # ai_const
from mqtt_publish import client
import ai_engine_const as AI_CONST
from REST_AI_ENGINE_CONTROL.app import models as M
from REST_AI_ENGINE_CONTROL.app.utils.extra import file_size_check
import project_config
# This is a demo of running face recognition on a video file and saving the results to a new video file.
#
# PLEASE NOTE: This example requires OpenCV (the `cv2` library) to be installed only to read from your webcam.
# OpenCV is *not* required to use the face_recognition library. It's only required if you want to run this
# specific demo. If you have trouble installing it, try any of the other demos that don't require it instead.
class FaceDetect:
"""
안면인식
외부 event thread를 이용하여 제어
inference 동작중 해당상황에 맞는 메세지 publish
"""
def __init__(self,
             request,
             thread_value,
             stop_event: threading.Event,
             worker_event: threading.Event,
             timeout_event: threading.Event,
             queue_info,
             engine_info,
             input_video,
             fr_manager
             ):
    """
    Initializer.
    :param request: external request message
        report_unit : whether to publish each time a new target is recognized
        targets : list of identities to detect
    :param stop_event: event ending the instance's internal loop
    :param worker_event: event signalling current_worker (thread) termination
    :param timeout_event: event ending the internal loop on timeout
    """
    self.report_unit = request.report_unit
    self.targets = request.targets
    self.thread_value = thread_value
    self.mqtt_status = self.thread_value.STATUS_NONE
    self.result = self.thread_value.result
    # NOTE(review): immediately rebound to a fresh list, so thread_value.result
    # is NOT aliased here — the line above appears to be dead; confirm
    self.result = []
    self.result.extend([None] * (len(self.targets)))  # one slot per target
    self.report_unit_result = self.thread_value.report_unit_result
    self.report_unit_result = []  # same rebinding pattern as result
    self.worker_event = worker_event  # external worker control
    self.stop_event = stop_event  # stop trigger
    self.stop_event.clear()
    self.timeout_event = timeout_event
    self.thread_value = thread_value
    self.stop_sign = False
    self.fr_manager = fr_manager
    self.id_info = self.fr_manager.fr_id_info
    # self.worker_names = list(self.id_info.keys()) #TODO(jwkim): manual worker registration
    self.worker_names = self.targets
    self.queue_info = queue_info
    # a numeric string means a local webcam index for cv2
    self.input_video = int(input_video) if input_video.isnumeric() else input_video
    self.model_info = engine_info
    self.ftp_info = self.model_info.demo.ftp
    self.ri_info = request.ri or self.model_info.fr_model_info.ri
    self.snapshot_path = None  # set once a snapshot has been uploaded
    self.result_frame = []
# def _ai_rbi(self, detected):
# # ri 정보 self.ri_info
# pass
def _sftp_upload(self):
    """Upload the face snapshot over SFTP (best effort).

    :return: the remote path on success, "" on any failure (callers treat
             this as a best-effort demo upload, so errors are not raised).
    """
    transport = None
    try:
        transport = paramiko.Transport((self.ftp_info.ip, self.ftp_info.port))
        transport.connect(username=self.ftp_info.id, password=self.ftp_info.pw)
        sftp = paramiko.SFTPClient.from_transport(transport)
        remotepath = self.ftp_info.location + os.sep + self.ftp_info.file_face + '.jpg'
        # sftp.put(AI_CONST.FTP_FR_RESULT, remotepath)  # actual upload disabled in demo
        sftp.close()
        return remotepath
    except Exception:
        return ""
    finally:
        if transport is not None:
            transport.close()  # always release the socket, even on error (was leaked)
def _mqtt_publish(self, status: str):
    """
    Publish the current recognition state over MQTT.
    For STATUS_NEW the per-report-unit result is sent; otherwise the full
    accumulated result. A snapshot path is attached when one exists.
    :param status: face-recognition status string (ThreadValues.STATUS_*)
    """
    client.loop_start()
    if status == self.thread_value.STATUS_NEW:
        result = self.report_unit_result
    else:
        result = self.result
    mqtt_msg_dict = {
        "datetime": self.thread_value.date_utill.date_str_micro_sec(),
        "status": status,
        "result": result
    }
    if self.snapshot_path:
        mqtt_msg_dict[AI_CONST.DEMO_KEY_NAME_SNAPSHOT_SFTP] = self.snapshot_path
    client.publish(AI_CONST.MQTT_FR_TOPIC, json.dumps(mqtt_msg_dict), 0)
    # client.loop_stop()
def _result_update(self, matched_list: list):
    """
    Merge newly matched names into the running result lists.
    Sets self.mqtt_status to STATUS_NEW when report_unit mode sees fresh
    matches, and to STATUS_COMPLETE (plus stop_event) once every target has
    been found.
    :param matched_list: names detected in the current frame
    """
    self.mqtt_status = self.thread_value.STATUS_NONE  # status init
    if self.report_unit:
        self.report_unit_result = []
        self.report_unit_result.extend([None] * (len(self.targets)))
    origin_matched_list = copy.deepcopy(matched_list)
    # first-ever detections (result is still all None)
    if self.result.count(None) == len(self.targets):
        for i in matched_list:
            if i in self.worker_names and i in self.targets:
                self.result[self.targets.index(i)] = i
        if self.report_unit and self.result.count(None) != len(self.targets):
            self.mqtt_status = self.thread_value.STATUS_NEW
            # NOTE(review): this aliases result (no copy) — confirm intended
            self.report_unit_result = self.result
    # merge names not recorded before
    for i in origin_matched_list:
        if i not in self.result and i in self.targets:
            if self.report_unit:
                self.mqtt_status = self.thread_value.STATUS_NEW
                self.report_unit_result[self.targets.index(i)] = i
            self.result[self.targets.index(i)] = i
    # TODO(jwkim): split out of this function
    if self.result.count(None) == 0 and self.targets:
        self.mqtt_status = self.thread_value.STATUS_COMPLETE
        self.stop_event.set()
def _encoding(self):
    """
    Build the face encodings for the currently active recognition targets.

    Only the workers actually used by the demo are loaded.  The previous
    version eagerly loaded ~16 reference images while using only 6 of
    them; a single unused image with no detectable face made
    ``face_encodings(...)[0]`` raise and crashed startup.

    :return: tuple ``(encodings, names)``, index-aligned.
    """
    # (reference image path, worker id) for every active target.
    active_workers = [
        (AI_CONST.WORKER4_IMG_PATH, 'no000004'),    # jangys
        (AI_CONST.WORKER5_IMG_PATH, 'no000005'),    # whangsj
        (AI_CONST.WORKER8_IMG_PATH, 'no000008'),    # agics
        (AI_CONST.WORKER19_IMG_PATH, 'no000019'),   # Helmet off kepco 2
        (AI_CONST.WORKER20_IMG_PATH, 'no000020'),   # Helmet off kepco 2
        (AI_CONST.WORKER21_IMG_PATH, 'no000021'),
    ]
    encodings = []
    names = []
    for img_path, name in active_workers:
        image = face_recognition.load_image_file(img_path)
        # [0]: assumes exactly one face per reference image -- TODO confirm
        encodings.append(face_recognition.face_encodings(image)[0])
        names.append(name)
    return encodings, names
def _new_encoding(self):
    """
    Encode the manually registered worker images held in ``self.id_info``.

    :raises Exception: AI_CONST.INVALID_IMG_MSG when an image's size
        differs from the size recorded at registration time.
    :return: list of face encodings, one per registered worker.
    """
    encodings = []
    for info in self.id_info.values():
        if info['size'] != file_size_check(info['image']):
            raise Exception(AI_CONST.INVALID_IMG_MSG)
        loaded = face_recognition.load_image_file(info['image'].file)
        encodings.append(face_recognition.face_encodings(loaded)[0])
    return encodings
def inference(self):
    """
    Run the (demo) face-recognition pass once and publish results over MQTT.

    Demo flow: remove any stale snapshot, publish START, grab a single
    frame from the input source, write and SFTP-upload it, report a
    fixed identity and publish COMPLETE.  On any failure an ERROR status
    carrying the exception text is published instead.
    """
    input_movie = None
    try:
        # Drop a stale snapshot left over from a previous run.
        if os.path.exists(AI_CONST.FTP_FR_RESULT):
            os.remove(AI_CONST.FTP_FR_RESULT)
        self._mqtt_publish(status=self.thread_value.STATUS_START)
        input_movie = cv2.VideoCapture(self.input_video)
        # Encodings are built for parity with the full pipeline even
        # though the demo below reports a fixed identity.
        known_faces, known_names = self._encoding()
        # known_faces = self._new_encoding()  # TODO(jwkim): manual worker registration
        # Grab a single frame of video.
        ret, frame = input_movie.read()
        # NOTE(review): if the read fails, frame is None and imwrite
        # raises -- that surfaces as an ERROR publish below.
        cv2.imwrite(AI_CONST.FTP_FR_RESULT, frame)
        self._sftp_upload()
        # TODO(jwkim): demo only -- fixed recognition result.
        self.result = ["no000010"]
        self._mqtt_publish(status=self.thread_value.STATUS_COMPLETE)
    except Exception as e:
        print(e)
        # Publish the failure so subscribers can react, then drain queue.
        self._mqtt_publish(status=self.thread_value.STATUS_ERROR + str(e))
        self._queue_empty()
    finally:
        # Pulse the timeout event so any waiter wakes up.
        if not self.timeout_event.is_set():
            self.timeout_event.set()
        self.timeout_event.clear()
        # Release the capture device; the previous version leaked it.
        if input_movie is not None:
            input_movie.release()
        cv2.destroyAllWindows()
        self.stop_event.clear()
        self.worker_event.set()
def _source_check(self):
    """
    Raise when the running input source no longer matches the FR model's
    configured source.

    :raises Exception: AI_CONST.SOURCE_CHANGED_MSG on mismatch.
    """
    current_source = self.input_video
    model_source = None
    for entry in self.model_info.input_video:
        if entry.model == M.AEAIModelType.FR:
            # Keep the last FR entry, matching the original behaviour.
            model_source = entry.connect_url
    if model_source is None:
        # No FR model configured: nothing to compare.  (The previous
        # version raised NameError here because model_source was unbound.)
        return
    if model_source.isnumeric():
        model_source = int(model_source)  # e.g. '0' -> local webcam index
    if current_source != model_source:
        raise Exception(AI_CONST.SOURCE_CHANGED_MSG)
def _queue_empty(self):
"""
queue_info에 있는 데이터를 비운다.
"""
while True:
if self.queue_info.empty():
break
else:
self.queue_info.get()
if __name__ == '__main__':
pass

113
DL/custom_utils.py Normal file
View File

@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
"""
@file : custom_utils.py
@author: Ultralytics , jwkim
@license: GPL-3.0 license
@section Modify History
-
"""
import torch
import cv2
import math
import time
import os
import numpy as np
from threading import Thread
from pathlib import Path
from urllib.parse import urlparse
from OD.utils.augmentations import letterbox
from OD.utils.general import clean_str, check_requirements, is_colab, is_kaggle, LOGGER
class LoadStreams:
    # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
    def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1,
                 event=None):
        """
        :param sources: stream URL/index, or a text file listing one source per line
        :param event: threading.Event used to stop the reader threads early
        """
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.event = event  # TODO(jwkim) thread shutdown signalling
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride
        self.vid_stride = vid_stride  # video frame-rate stride
        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f'{i + 1}/{n}: {s}... '
            if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            # SECURITY: was `eval(s)` (arbitrary code execution on the source
            # string); int() is all that is needed for webcam indices.
            s = int(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0:
                assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
                assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'{st}Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback
            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info('')  # newline
        # check for common shapes
        s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        self.auto = auto and self.rect
        self.transforms = transforms  # optional
        if not self.rect:
            LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
    def update(self, i, cap, stream):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]  # frame number, frame array
        while cap.isOpened() and n < f:
            n += 1
            cap.grab()  # .read() = .grab() followed by .retrieve()
            if n % self.vid_stride == 0:
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
                    self.imgs[i] = np.zeros_like(self.imgs[i])
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(0.0)  # wait time
            if self.event.is_set():  # TODO(jwkim) stop reading when signalled
                break
    def __iter__(self):
        self.count = -1
        return self
    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration
        im0 = self.imgs.copy()
        if self.transforms:
            im = np.stack([self.transforms(x) for x in im0])  # transforms
        else:
            im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0])  # resize
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous
        return self.sources, im, im0, None, ''
    def __len__(self):
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years

1764
DL/d2_od_detect.py Normal file

File diff suppressed because it is too large Load Diff

BIN
DL/index_78.pt Normal file

Binary file not shown.

1
DL/wd.streams Normal file
View File

@@ -0,0 +1 @@
rtsp://admin:admin1263!@10.20.10.99:28554/onvif/media?profile=Profile2

View File

@@ -1,3 +1,36 @@
# AI_ENGINE_0.5
AI_ENGINE_0.5
# MG_ENGINE_SIM
위험성 평가기반 작업자 안전관리 및 감독 협조 기술 개발 (시뮬레이터)
### System Requirements
- ubuntu 20.04
- python 3.8
- python packages : requirements.txt
### Pytorch, Torchvision install
- conda(CUDA 11.7 기준)
```
conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia
```
- pip
```
pip3 install torch torchvision torchaudio
```
- ARM architecture
JetPack 5.0.2(L4T R35.1.0) 기준
1. pytorch(version 1.12)
```
wget https://developer.download.nvidia.com/compute/redist/jp/v50/pytorch/torch-1.12.0a0+2c916ef.nv22.3-cp38-cp38-linux_aarch64.whl
apt-get install python3-pip libopenblas-base libopenmpi-dev libomp-dev
pip3 install Cython
pip3 install torch-1.12.0a0+2c916ef.nv22.3-cp38-cp38-linux_aarch64.whl
```
2. torchvision(version 0.13)
```
sudo apt-get install libjpeg-dev zlib1g-dev libpython3-dev libavcodec-dev libavformat-dev libswscale-dev
git clone --branch release/0.13 https://github.com/pytorch/vision torchvision
cd torchvision
export BUILD_VERSION=0.13.0
python3 setup.py install --user
```

142
REST_AI_ENGINE_CONTROL/.gitignore vendored Normal file
View File

@@ -0,0 +1,142 @@
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# YOLOv5
/DL/OD

View File

@@ -0,0 +1,27 @@
# ------------------------------------------------------------------------------
# Base image
# ------------------------------------------------------------------------------
FROM python:3.9-slim
# ------------------------------------------------------------------------------
# Informations
# ------------------------------------------------------------------------------
LABEL maintainer="hsj100 <hsj100@daooldns.co.kr>"
LABEL title="REST_AI_ENGINE_CONTROL"
LABEL description="Rest API Server with Fast API"
# ------------------------------------------------------------------------------
# Install dependencies
# ------------------------------------------------------------------------------
WORKDIR /FAST_API
COPY ./requirements.txt /FAST_API/requirements.txt
RUN apt update > /dev/null && \
apt install -y build-essential && \
pip install --no-cache-dir --upgrade -r /FAST_API/requirements.txt
# ------------------------------------------------------------------------------
# Source
# ------------------------------------------------------------------------------
COPY ./app /FAST_API/app
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "50220"]

View File

@@ -0,0 +1,4 @@
# D2.KEPCO.AI_RBI.REST_AI_ENGINE_CONTROL
REST SERVER
AI ENGINE 제어

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""
@file: api_request_sample.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: test api key
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import base64
import hmac
from datetime import datetime, timedelta
import requests
def parse_params_to_str(params):
    """Serialize *params* into a 'k=v&k2=v2' query string (no leading '?')."""
    return '&'.join(f'{key}={value}' for key, value in params.items())
def hash_string(qs, secret_key):
    """Return the base64-encoded HMAC-SHA256 of *qs* keyed by *secret_key*."""
    digest = hmac.new(secret_key.encode('utf8'),
                      qs.encode('utf-8'),
                      digestmod='sha256').digest()
    return base64.b64encode(digest).decode('utf-8')
def sample_request():
    """
    Build a signed request and call the local services endpoint.

    Signs the key+timestamp query params with HMAC-SHA256 and sends the
    signature in the ``secret`` request header.

    :return: the ``requests.Response`` from the API.
    """
    # NOTE(review): hard-coded credentials in a sample script -- never
    # ship real keys like this.
    access_key = 'c0883231-4aa9-4a1f-a77b-3ef250af-e449-42e9-856a-b3ada17c426b'
    secret_key = 'QhOaeXTAAkW6yWt31jWDeERkBsZ3X4UmPds656YD'
    cur_time = datetime.utcnow()+timedelta(hours=9)  # KST (UTC+9)
    cur_timestamp = int(cur_time.timestamp())
    qs = dict(key=access_key, timestamp=cur_timestamp)
    header_secret = hash_string(parse_params_to_str(qs), secret_key)
    url = f'http://127.0.0.1:8080/api/services?{parse_params_to_str(qs)}'
    res = requests.get(url, headers=dict(secret=header_secret))
    return res
print(sample_request().json())

View File

@@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
"""
@File: config.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: 실행 환경 설정
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from dataclasses import dataclass
from os import path, environ
from app.common import consts
from app.models import UserInfo
base_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__))))
@dataclass
class Config:
"""
기본 Configuration
"""
BASE_DIR: str = base_dir
DB_POOL_RECYCLE: int = 900
DB_ECHO: bool = True
DEBUG: bool = False
TEST_MODE: bool = False
DEV_TEST_CONNECT_ACCOUNT: str = None
# NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행)
SERVICE_AUTH_API_KEY: bool = False
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}?charset={consts.DB_CHARSET}')
REST_SERVER_PORT = consts.REST_SERVER_PORT
SW_TITLE = consts.SW_TITLE
SW_VERSION = consts.SW_VERSION
SW_DESCRIPTION = consts.SW_DESCRIPTION
TERMS_OF_SERVICE = consts.TERMS_OF_SERVICE
CONTEACT = consts.CONTEACT
LICENSE_INFO = consts.LICENSE_INFO
GLOBAL_TOKEN = consts.ADMIN_INIT_ACCOUNT_INFO.connect_token
@dataclass
class LocalConfig(Config):
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
DEBUG: bool = False
@dataclass
class ProdConfig(Config):
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
@dataclass
class TestConfig(Config):
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}_test?charset={consts.DB_CHARSET}')
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
TEST_MODE: bool = True
@dataclass
class DevConfig(Config):
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
DEBUG: bool = True
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}_dev?charset={consts.DB_CHARSET}')
REST_SERVER_PORT = consts.REST_SERVER_PORT + 1
SW_TITLE = '[Dev] ' + consts.SW_TITLE
@dataclass
class MyConfig(Config):
TRUSTED_HOSTS = ['*']
ALLOW_SITE = ['*']
DEBUG: bool = True
DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}_my?charset={consts.DB_CHARSET}')
REST_SERVER_PORT = consts.REST_SERVER_PORT + 2
# NOTE(hsj100): DEV_TEST_CONNECT_ACCOUNT
DEV_TEST_CONNECT_ACCOUNT: UserInfo = UserInfo(**consts.ADMIN_INIT_ACCOUNT_INFO.get_dict())
# DEV_TEST_CONNECT_ACCOUNT: str = None
# NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행)
SERVICE_AUTH_API_KEY: bool = False
SW_TITLE = '[My] ' + consts.SW_TITLE
# GLOBAL_TOKEN = None
def conf():
    """
    Load the runtime configuration selected by the ``API_ENV``
    environment variable (default: ``local``).

    :return: an instance of the matching Config subclass.
    """
    config = dict(prod=ProdConfig, local=LocalConfig, test=TestConfig, dev=DevConfig, my=MyConfig)
    return config[environ.get('API_ENV', 'local')]()
    # Developer toggles -- these were left as unreachable return
    # statements after the return above; kept here as documentation:
    # return config[environ.get('API_ENV', 'dev')]()
    # return config[environ.get('API_ENV', 'my')]()
    # return config[environ.get('API_ENV', 'test')]()

View File

@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
"""
@File: consts.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: 상수 선언
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import sys,os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))))
import project_config
# SUPPORT PROJECT
SUPPORT_PROJECT_BASIC = 'PROJECT_BASIC'
SUPPORT_PROJECT_MEDICAL_METAVERSE = 'MEDICAL METAVERSE'
SUPPORT_PROJECT_SW_LICENSE_MANAGER = 'AI-KEPCO REST SERVER'
PROJECT_NAME = '[SIM] AI Engine Control'
SW_TITLE= f'{PROJECT_NAME} - REST API'
SW_VERSION = project_config.PROJECT_VERSION
SW_DESCRIPTION = f'''
# 위험성 평가기반 작업자 안전관리 및 감독 협조 기술 개발
## KEPCO AI RBI(risk based inspection)
[Simulator manual v0.2](http://219.250.188.208/docs/AI_Engine_Simuilator_0.2_20230206.pptx)\n
[MQTT viewer](http://219.250.188.208:50777/)\n
## API 이용법
- **사용자 접속**시 토큰정보(***Authorization*** 항목) 획득 (이후 API 이용가능)
- 수신받은 토큰정보(***Authorization*** 항목)를 ***Request Header*** 에 포함해서 기타 API 사용
- 개별 API 설명과 Request/Response schema 참조
**시스템에서 사용하는 날짜형식**\n
YYYY-mm-ddTHH:MM:SS
YYYY-mm-dd HH:MM:SS
**검색/변경 API 사용시, 항목값 null 사용방법(string 항목에서만 유효)**
- 해당 항목의 값을 문자열("null")로 사용**\n
"mac": "null"
- 해당 항목의 값을 null로 사용할 경우 처리 대상에서 제외됨\n
"mac": null <= 처리 대상에서 제외
![This image does not work](http://219.250.188.208/image/mqtt_layer.jpg)\n
'''
TERMS_OF_SERVICE = 'http://www.daooldns.co.kr'
CONTEACT={
'name': 'DAOOLDNS (주)다울디엔에스',
'url': 'http://www.daooldns.co.kr',
'email': 'marketing@daooldns.co.kr'
}
LICENSE_INFO = {
'name': 'Copyright by DAOOLDNS',
'url': 'http://www.daooldns.co.kr'
}
REST_SERVER_PORT = 50770
DEFAULT_USER_ACCOUNT_PW = '1234'
class AdminInfo:
    """Built-in administrator account seed data (consumed by create_admin)."""
    def __init__(self):
        # WARNING(review): credentials and tokens are hard-coded in source;
        # acceptable only for a closed demo environment.
        self.id: int = 1
        self.user_grade: str = 'admin'
        self.account: str = 'a2d2_lc_manager@naver.com'
        # self.pw: str = '$2b$12$PklBvVXdLhOQnIiNanlnIu.DJh5MspRARVChJQfFu1qg35vBoIuX2' # bcrypy hash
        self.pw: str = 'E563ZFt+yJL8YY5yYlYyk602MSscPP2SCCD8UtXXpMI=' # AESCryptoCBC encrypt
        self.name: str = 'administrator'
        self.email: str = 'a2d2_lc_manager@naver.com'
        self.email_pw: str = 'gAAAAABioV5NucuS9nQugZJnz-KjVG_FGnaowB9KAfhOoWjjiQ4jGLuYJh4Qe94mT_lCm6m3HhuOJqUeOgjppwREDpIQYzrUXA=='
        self.address: str = '대구광역시 동구 동촌로351 에이스빌딩 4F'
        self.phone_number: str = '053-384-3010'
        self.connect_token: str = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MSwiYWNjb3VudCI6ImEyZDJfbGNfbWFuYWdlckBuYXZlci5jb20iLCJuYW1lIjoiYWRtaW5pc3RyYXRvciIsInBob25lX251bWJlciI6IjA1My0zODQtMzAxMCIsInByb2ZpbGVfaW1nIjpudWxsLCJhY2NvdW50X3R5cGUiOiJlbWFpbCJ9.SlQSCfAof1bv2YxmW2DO4dIBrbHLg1jPO3AJsX6xKbw'
    def get_dict(self):
        """Return instance attributes as a plain dict (tuples unwrapped to their first element)."""
        info = {}
        for k, v in self.__dict__.items():
            if type(v) is tuple:
                # presumably guards against an accidental trailing comma in
                # the assignments above -- TODO confirm
                info[k] = v[0]
            else:
                info[k] = v
        return info
ADMIN_INIT_ACCOUNT_INFO = AdminInfo()
FERNET_SECRET_KEY = b'wQjpSYkmc4kX8MaAovk1NIHF02R2wZX760eeBTeIHW4='
AES_CBC_PUBLIC_KEY = b'daooldns12345678'
AES_CBC_IV = b'daooldns12345678'
COOKIES_AUTH = 'Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MTQsImVtYWlsIjoia29hbGFAZGluZ3JyLmNvbSIsIm5hbWUiOm51bGwsInBob25lX251bWJlciI6bnVsbCwicHJvZmlsZV9pbWciOm51bGwsInNuc190eXBlIjpudWxsfQ.4vgrFvxgH8odoXMvV70BBqyqXOFa2NDQtzYkGywhV48'
JWT_SECRET = 'ABCD1234!'
JWT_ALGORITHM = 'HS256'
EXCEPT_PATH_LIST = ['/', '/openapi.json']
EXCEPT_PATH_REGEX = '^(/docs|/redoc' +\
'|/api/dev' +\
'|/api/services' +\
')'
MAX_API_KEY = 3
MAX_API_WHITELIST = 10
NUM_RETRY_UUID_GEN = 3
if project_config.CONFIG == project_config.CONFIG_AISERVER or \
project_config.CONFIG == project_config.CONFIG_MG:
# DATABASE(MG/safe6)
DB_ADDRESS = 'localhost'
DB_PORT = '50701'
DB_USER_ID = 'aikepco'
DB_USER_PW = '!ekdnfeldpsdptm1'
DB_NAME = 'AI_ENGINE_CONTROL'
DB_CHARSET = 'utf8mb4'
else:
# DATABASE(fermat)
DB_ADDRESS = 'localhost'
DB_PORT = '3306'
DB_USER_ID = 'root'
DB_USER_PW = '1234'
DB_NAME = 'AI_ENGINE_CONTROL'
DB_CHARSET = 'utf8mb4'
# MAIL
SMTP_HOST = 'smtp.gmail.com'
SMTP_PORT = 587
SMTP_HOST = 'smtp.naver.com'
SMTP_PORT = 587
MAIL_REG_TITLE = f'{PROJECT_NAME} - Registration'
MAIL_REG_CONTENTS = '''
안녕하세요.
[DAOOLDNS] AI-KEPCO REST SERVER 입니다.
Account: {}
Number of License: {}
감사합니다.
'''
AI_ENGINE_FR_DETECT_TIME_MIN = 5

View File

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
"""
@File: conn.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: DB 정보
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from fastapi import FastAPI
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
import logging
def _database_exist(engine, schema_name):
    """Return True when *schema_name* exists (INFORMATION_SCHEMA lookup)."""
    # NOTE(review): schema_name is interpolated directly into SQL; callers
    # only pass the configured DB name, but parameter binding would be safer.
    query = f'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = "{schema_name}"'
    with engine.connect() as conn:
        result_proxy = conn.execute(query)
        result = result_proxy.scalar()
        return bool(result)
def _drop_database(engine, schema_name):
    """Drop *schema_name* -- destructive, used only by the test bootstrap."""
    with engine.connect() as conn:
        conn.execute(f'DROP DATABASE {schema_name};')
def _create_database(engine, schema_name):
    """Create *schema_name* with the project's standard utf8mb4 collation."""
    with engine.connect() as conn:
        conn.execute(f'CREATE DATABASE {schema_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_bin;')
class SQLAlchemy:
    """FastAPI + SQLAlchemy glue: owns the engine and per-request sessions."""
    def __init__(self, app: FastAPI = None, **kwargs):
        # Lazy init: engine/session are created in init_app().
        self._engine = None
        self._session = None
        if app is not None:
            self.init_app(app=app, **kwargs)
    def init_app(self, app: FastAPI, **kwargs):
        """
        Initialise the database connection and (re)create the schema.

        :param app: FastAPI instance
        :param kwargs: DB_URL, DB_POOL_RECYCLE, TEST_MODE, DB_ECHO
        :return:
        """
        database_url = kwargs.get('DB_URL')
        pool_recycle = kwargs.setdefault('DB_POOL_RECYCLE', 900)
        is_testing = kwargs.setdefault('TEST_MODE', False)
        echo = kwargs.setdefault('DB_ECHO', True)
        self._engine = create_engine(
            database_url,
            echo=echo,
            pool_recycle=pool_recycle,
            pool_pre_ping=True,
        )
        if is_testing:  # create schema
            # Test mode: drop and recreate the schema from scratch.
            db_url = self._engine.url
            if db_url.host != 'localhost':
                raise Exception('db host must be \'localhost\' in test environment')
            except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
            schema_name = db_url.database
            temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
            if _database_exist(temp_engine, schema_name):
                _drop_database(temp_engine, schema_name)
            _create_database(temp_engine, schema_name)
            temp_engine.dispose()
        else:
            # Normal mode: create the schema and tables only when missing.
            db_url = self._engine.url
            except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}'
            schema_name = db_url.database
            temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True)
            if not _database_exist(temp_engine, schema_name):
                _create_database(temp_engine, schema_name)
            # NOTE(review): uses the module-level `db`/`Base` globals, so
            # this assumes init_app is called on that singleton -- confirm.
            Base.metadata.create_all(db.engine)
            temp_engine.dispose()
        self._session = sessionmaker(autocommit=False, autoflush=False, bind=self._engine)
        # NOTE(hsj100): ADMINISTRATOR
        create_admin(self._session)
        @app.on_event('startup')
        def startup():
            self._engine.connect()
            logging.info('DB connected.')
        @app.on_event('shutdown')
        def shutdown():
            self._session.close_all()
            self._engine.dispose()
            logging.info('DB disconnected')
    def get_db(self):
        """
        Yield a DB session per request and always close it afterwards.
        :return:
        """
        if self._session is None:
            raise Exception('must be called \'init_app\'')
        db_session = None
        try:
            db_session = self._session()
            yield db_session
        finally:
            db_session.close()
    @property
    def session(self):
        # Convenience alias so routes can use Depends(db.session).
        return self.get_db
    @property
    def engine(self):
        return self._engine
db = SQLAlchemy()
Base = declarative_base()
# NOTE(hsj100): ADMINISTRATOR
def create_admin(db_session):
    """
    Ensure the built-in administrator account exists (idempotent).

    :param db_session: sessionmaker produced by SQLAlchemy.init_app
    :raises Exception: when a session cannot be created.
    """
    # Local imports avoid a circular dependency at module load time.
    # (Removed an unused `import bcrypt` left over from the hashed-pw era.)
    from app.database.schema import Users
    from app.common.consts import ADMIN_INIT_ACCOUNT_INFO
    session = db_session()
    if not session:
        raise Exception("can't create account of admin")  # was: "cat`t ..."
    if Users.get(account=ADMIN_INIT_ACCOUNT_INFO.account):
        return  # already provisioned
    admin = {**ADMIN_INIT_ACCOUNT_INFO.get_dict()}
    Users.create(session=session, auto_commit=True, **admin)

View File

@@ -0,0 +1,284 @@
# -*- coding: utf-8 -*-
"""
@File: crud.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: DB CRUD
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import math
from datetime import datetime
from dateutil.relativedelta import relativedelta
from sqlalchemy import func, desc
from fastapi import APIRouter, Depends, Body
from sqlalchemy.orm import Session
from app import models as M
from app.database.conn import Base, db
from app.database.schema import Users, UserLog
from app.utils.extra import query_to_groupby, query_to_groupby_date
def request_parser(request_data: dict = None) -> dict:
    """
    Normalize a request payload into a plain dict.

    Keys whose value is None are dropped; the literal string 'null'
    is translated into a real None.
    """
    if not request_data:
        return dict()
    return {key: (None if value == 'null' else value)
            for key, value in request_data.items()
            if value is not None}
def dict_to_filter_fmt(dict_data, get_attributes_callback=None):
    """
    Translate a request dict into a list of SQLAlchemy filter criteria.

    Keys may carry an operator suffix ('field__gt', 'field__in', ...);
    a plain key means equality. Unknown suffixes are silently skipped,
    as are keys the callback cannot resolve (returns None).

    :param dict_data: {key[__op]: value} mapping
    :param get_attributes_callback: resolves a column from a field name
    :return: list of SQL expressions usable in Query.filter(*criterion)
    """
    if get_attributes_callback is None:
        raise Exception('invalid get_attributes_callback')
    operators = {
        'gt': lambda c, v: c > v,
        'gte': lambda c, v: c >= v,
        'lt': lambda c, v: c < v,
        'lte': lambda c, v: c <= v,
        'in': lambda c, v: c.in_(v),
        'like': lambda c, v: c.like(v),
    }
    criterion = list()
    for raw_key, val in dict_data.items():
        parts = raw_key.split('__')
        if len(parts) > 2:
            raise Exception('length of split(key) should be no more than 2.')
        col = get_attributes_callback(parts[0])
        if col is None:
            continue
        if len(parts) == 1:
            criterion.append(col == val)
        elif parts[1] in operators:
            criterion.append(operators[parts[1]](col, val))
    return criterion
async def table_select(accessor_info, target_table, request_body_info, response_model, response_model_data):
    """
    Generic table read with optional paging.

    :param accessor_info: authenticated requester (only checked for presence)
    :param target_table: schema class exposing .filter(**kwargs)
    :param request_body_info: request model; may carry .paging and .search
    :param response_model: response class; set_error() used on failure
    :param response_model_data: per-row response model (built via from_orm)
    :return: response_model with data (and paging info when requested)
    """
    try:
        # parameter
        if not accessor_info:
            raise Exception('invalid accessor')
        if not target_table:
            raise Exception(f'invalid table_name:{target_table}')
        # if not request_body_info:
        #     raise Exception('invalid request_body_info')
        if not response_model:
            raise Exception('invalid response_model')
        if not response_model_data:
            raise Exception('invalid response_model_data')
        # paging - request
        paging_request = None
        if request_body_info:
            if hasattr(request_body_info, 'paging'):
                paging_request = request_body_info.paging
                request_body_info = request_body_info.search
        # request
        # NOTE(review): when request_body_info is None this raises AttributeError,
        # which is reported through the generic except below — confirm intended.
        request_info = request_parser(request_body_info.dict())
        # search
        criterion = None
        if isinstance(request_body_info, M.UserLogDaySearchReq):
            # UserLog search: latest successful login per MAC address, grouped per day
            # request
            request_info = request_parser(request_body_info.dict())
            # search UserLog
            def get_attributes_callback(key: str):
                return getattr(UserLog, key)
            criterion = dict_to_filter_fmt(request_info, get_attributes_callback)
            # search — fixed filters select successful login events only
            session = next(db.session())
            search_info = session.query(UserLog)\
                .filter(UserLog.mac.isnot(None)
                        , UserLog.api == '/api/auth/login'
                        , UserLog.type == M.TypeUserLog.info
                        , UserLog.message == 'ok'
                        , *criterion)\
                .order_by(desc(UserLog.created_at), desc(UserLog.updated_at))\
                .all()
            if not search_info:
                raise Exception('not found data')
            group_by_day = query_to_groupby_date(search_info, 'updated_at')
            # result
            result_info = list()
            for day, info_list in group_by_day.items():
                # first=True keeps only the newest row per MAC within the day
                info_by_mac = query_to_groupby(info_list, 'mac', first=True)
                for log_info in info_by_mac.values():
                    result_info.append(response_model_data.from_orm(log_info))
        else:
            # basic search (single table)
            # request
            request_info = request_parser(request_body_info.dict())
            # search
            search_info = target_table.filter(**request_info).all()
            if not search_info:
                raise Exception('not found data')
            # result
            result_info = list()
            for purchase_info in search_info:
                result_info.append(response_model_data.from_orm(purchase_info))
        # response - paging (slicing is done in memory over the full result set)
        paging_response = None
        if paging_request:
            total_contents_num = len(result_info)
            total_page_num = math.ceil(total_contents_num / paging_request.page_contents_num)
            start_contents_index = (paging_request.start_page - 1) * paging_request.page_contents_num
            end_contents_index = start_contents_index + paging_request.page_contents_num
            if end_contents_index < total_contents_num:
                result_info = result_info[start_contents_index: end_contents_index]
            else:
                result_info = result_info[start_contents_index:]
            paging_response = M.PagingRes()
            paging_response.total_page_num = total_page_num
            paging_response.total_contents_num = total_contents_num
            paging_response.start_page = paging_request.start_page
            paging_response.search_contents_num = len(result_info)
        return response_model(data=result_info, paging=paging_response)
    except Exception as e:
        return response_model.set_error(str(e))
async def table_update(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
    """
    Generic table update: rows matched by search_info get update_info applied.

    :param accessor_info: authenticated requester (presence check only)
    :param target_table: schema class exposing .filter(**kwargs)
    :param request_body_info: model with .search_info and .update_info
    :param response_model: response class; set_error() used on failure
    :param response_model_data: unused for updates
    :return: empty response_model on success, error populated on failure
    """
    search_info = None
    try:
        # parameter
        if not accessor_info:
            raise Exception('invalid accessor')
        if not target_table:
            raise Exception(f'invalid table_name:{target_table}')
        if not request_body_info:
            raise Exception('invalid request_body_info')
        if not response_model:
            raise Exception('invalid response_model')
        # if not response_model_data:
        #     raise Exception('invalid response_model_data')
        # request
        if not request_body_info.search_info:
            raise Exception('invalid request_body: search_info')
        request_search_info = request_parser(request_body_info.search_info.dict())
        if not request_search_info:
            raise Exception('invalid request_body: search_info')
        if not request_body_info.update_info:
            raise Exception('invalid request_body: update_info')
        request_update_info = request_parser(request_body_info.update_info.dict())
        if not request_update_info:
            raise Exception('invalid request_body: update_info')
        # search
        search_info = target_table.filter(**request_search_info)
        # process — bulk UPDATE; commit handled by the mixin (auto_commit)
        search_info.update(auto_commit=True, synchronize_session=False, **request_update_info)
        # result
        return response_model()
    except Exception as e:
        # release the query's session when the failure happened after filter()
        if search_info:
            search_info.close()
        return response_model.set_error(str(e))
async def table_delete(accessor_info, target_table, request_body_info, response_model, response_model_data=None):
    """
    Generic table delete; for license-like tables (rows with a 'uuid'
    attribute) the per-uuid 'num' counters are refreshed afterwards.

    :param accessor_info: authenticated requester (presence check only)
    :param target_table: schema class exposing .filter(**kwargs)
    :param request_body_info: search model describing rows to delete
    :param response_model: response class; set_error() used on failure
    :param response_model_data: unused for deletes
    :return: empty response_model on success, error populated on failure
    """
    search_info = None
    try:
        # parameter
        if not accessor_info:
            raise Exception('invalid accessor')
        if not target_table:
            raise Exception(f'invalid table_name:{target_table}')
        if not request_body_info:
            raise Exception('invalid request_body_info')
        if not response_model:
            raise Exception('invalid response_model')
        # request
        request_search_info = request_parser(request_body_info.dict())
        if not request_search_info:
            raise Exception('invalid request_body')
        # search — snapshot the rows before deleting them
        search_info = target_table.filter(**request_search_info)
        temp_search = search_info.all()
        # process
        search_info.delete(auto_commit=True, synchronize_session=False)
        # update license num
        uuid_list = list()
        for _license in temp_search:
            # BUGFIX: hasattr must be checked on the row, not on the result list.
            if not hasattr(_license, 'uuid'):
                # case: not a license-type table -> nothing to recount
                break
            if _license.uuid not in uuid_list:
                uuid_list.append(_license.uuid)
                license_num = target_table.filter(uuid=_license.uuid).count()
                target_table.filter(uuid=_license.uuid).update(auto_commit=True, synchronize_session=False, num=license_num)
        # result
        return response_model()
    except Exception as e:
        if search_info:
            search_info.close()
        return response_model.set_error(str(e))

View File

@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
"""
@File: schema.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: DB schema
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
func,
Enum,
Boolean,
ForeignKey,
)
from sqlalchemy.orm import Session, relationship
from app.database.conn import Base, db
from app.utils.date_utils import D
from app.models import (
TypeUserSex,
TypeUserGrade,
TypeUserMembership,
TypeUserAccount,
TypeUserAccountStatus,
TyepUserLoginStatus,
TypeUserLog
)
class BaseMixin:
    """Shared columns plus a minimal chainable query helper for all ORM models."""
    # NOTE(review): D.datetime() is evaluated once at import time, so these column
    # defaults are a frozen timestamp; create()/update() below re-stamp explicitly
    # (see FIX_DATETIME notes) — consider passing the callable instead. Confirm.
    id = Column(Integer, primary_key=True, index=True, comment='DB INDEX')
    created_at = Column(DateTime, nullable=False, default=D.datetime(), comment='생성 날짜')
    updated_at = Column(DateTime, nullable=False, default=D.datetime(), onupdate=D.datetime(), comment='변경 날짜')

    def __init__(self):
        # _q: pending Query; _session: Session executing it;
        # served: True when the session is caller-owned (must not be closed here).
        self._q = None
        self._session = None
        self.served = None

    def all_columns(self):
        # Writable columns only: excludes the primary key and created_at.
        return [c for c in self.__table__.columns if c.primary_key is False and c.name != 'created_at']

    def __hash__(self):
        return hash(self.id)

    @classmethod
    def create(cls, session: Session, auto_commit=False, **kwargs):
        """
        Insert a single row.

        :param session: Session the new row is added to
        :param auto_commit: commit immediately when True
        :param kwargs: column values to persist
        :return: the persisted object
        """
        obj = cls()
        # NOTE(hsj100) : FIX_DATETIME
        if 'created_at' not in kwargs:
            obj.created_at = D.datetime()
        if 'updated_at' not in kwargs:
            obj.updated_at = D.datetime()
        for col in obj.all_columns():
            col_name = col.name
            if col_name in kwargs:
                setattr(obj, col_name, kwargs.get(col_name))
        session.add(obj)
        session.flush()
        if auto_commit:
            session.commit()
        return obj

    @classmethod
    def get(cls, session: Session = None, **kwargs):
        """
        Fetch a single row matching the given column equalities.

        :param session: optional caller-owned Session (a temporary one otherwise)
        :param kwargs: column == value filters
        :return: the row or None
        :raises Exception: when more than one row matches
        """
        sess = next(db.session()) if not session else session
        query = sess.query(cls)
        for key, val in kwargs.items():
            col = getattr(cls, key)
            query = query.filter(col == val)
        if query.count() > 1:
            raise Exception('Only one row is supposed to be returned, but got more than one.')
        result = query.first()
        if not session:
            sess.close()
        return result

    @classmethod
    def filter(cls, session: Session = None, **kwargs):
        """
        Build a chainable filtered query; keys may carry a
        '__gt/__gte/__lt/__lte/__in/__like' operator suffix.

        :param session: optional caller-owned Session
        :param kwargs: filter specifications
        :return: a BaseMixin wrapper holding the query
        """
        cond = []
        for key, val in kwargs.items():
            key = key.split('__')
            if len(key) > 2:
                raise Exception('length of split(key) should be no more than 2.')
            col = getattr(cls, key[0])
            if len(key) == 1: cond.append((col == val))
            elif len(key) == 2 and key[1] == 'gt': cond.append((col > val))
            elif len(key) == 2 and key[1] == 'gte': cond.append((col >= val))
            elif len(key) == 2 and key[1] == 'lt': cond.append((col < val))
            elif len(key) == 2 and key[1] == 'lte': cond.append((col <= val))
            elif len(key) == 2 and key[1] == 'in': cond.append((col.in_(val)))
            elif len(key) == 2 and key[1] == 'like': cond.append((col.like(val)))
        obj = cls()
        if session:
            obj._session = session
            obj.served = True
        else:
            obj._session = next(db.session())
            obj.served = False
        query = obj._session.query(cls)
        query = query.filter(*cond)
        obj._q = query
        return obj

    @classmethod
    def cls_attr(cls, col_name=None):
        # Resolve a column by name, or return the class itself when no name given.
        if col_name:
            col = getattr(cls, col_name)
            return col
        else:
            return cls

    def order_by(self, *args: str):
        # '-name' sorts descending, 'name' ascending; chainable.
        for a in args:
            if a.startswith('-'):
                col_name = a[1:]
                is_asc = False
            else:
                col_name = a
                is_asc = True
            col = self.cls_attr(col_name)
            self._q = self._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc())
        return self

    def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs):
        # Bulk UPDATE on the current query; returns the first updated row or None.
        # NOTE(hsj100) : FIX_DATETIME
        if 'updated_at' not in kwargs:
            kwargs['updated_at'] = D.datetime()
        qs = self._q.update(kwargs, synchronize_session=synchronize_session)
        get_id = self.id  # NOTE(review): unused — presumably leftover; confirm before removing
        ret = None
        self._session.flush()
        if qs > 0 :
            ret = self._q.first()
        if auto_commit:
            self._session.commit()
        return ret

    def first(self):
        # Return the first row and release/flush the session.
        result = self._q.first()
        self.close()
        return result

    def delete(self, auto_commit: bool = False, synchronize_session='evaluate'):
        # Bulk DELETE on the current query.
        self._q.delete(synchronize_session=synchronize_session)
        if auto_commit:
            self._session.commit()

    def all(self):
        # NOTE(review): debug print left in — consider removing or routing to logging.
        print(self.served)
        result = self._q.all()
        self.close()
        return result

    def count(self):
        result = self._q.count()
        self.close()
        return result

    def close(self):
        # Close only sessions we created ourselves; just flush caller-owned ones.
        if not self.served:
            self._session.close()
        else:
            self._session.flush()
class ApiKeys(Base, BaseMixin):
    """API access/secret key pairs issued to a user."""
    __tablename__ = 'api_keys'
    access_key = Column(String(length=64), nullable=False, index=True)
    secret_key = Column(String(length=64), nullable=False)
    user_memo = Column(String(length=40), nullable=True)
    status = Column(Enum('active', 'stopped', 'deleted'), default='active')
    is_whitelisted = Column(Boolean, default=False)
    user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    # one key -> many whitelist entries; key belongs to one user
    whitelist = relationship('ApiWhiteLists', backref='api_keys')
    users = relationship('Users', back_populates='keys')
class ApiWhiteLists(Base, BaseMixin):
    """IP whitelist entries attached to an API key."""
    __tablename__ = 'api_whitelists'
    ip_addr = Column(String(length=64), nullable=False)
    api_key_id = Column(Integer, ForeignKey('api_keys.id'), nullable=False)
class Users(Base, BaseMixin):
    """User accounts (credentials, profile, grade/membership and login state)."""
    __tablename__ = 'users'
    status = Column(Enum(TypeUserAccountStatus), nullable=False, default=TypeUserAccountStatus.active, comment='계정 상태')
    user_grade = Column(Enum(TypeUserGrade), nullable=False, default=TypeUserGrade.user, comment='등급')
    membership = Column(Enum(TypeUserMembership), nullable=False, default=TypeUserMembership.personal, comment='회원 종류')
    account_type = Column(Enum(TypeUserAccount), nullable=False, default=TypeUserAccount.email, comment='계정 종류')
    account = Column(String(length=255), nullable=False, unique=True, comment='계정')
    pw = Column(String(length=2000), nullable=True, comment='계정 비밀번호')
    email = Column(String(length=255), nullable=True, comment='전자 메일')
    name = Column(String(length=255), nullable=False, comment='이름')
    sex = Column(Enum(TypeUserSex), nullable=False, default=TypeUserSex.male, comment='성별')
    rrn = Column(String(length=255), nullable=True, comment='주민등록번호')
    address = Column(String(length=2000), nullable=True, comment='주소')
    mobile_number = Column(String(length=20), nullable=True, comment='휴대폰 번호')
    landline_number = Column(String(length=20), nullable=True, comment='유선 번호')
    marketing_agree = Column(Boolean, nullable=False, default=False, comment='마켓팅 수신 동의')
    picture = Column(String(length=1000), nullable=True, comment='프로필 사진')
    keys = relationship('ApiKeys', back_populates='users')
    # extra: current login state of the account
    login = Column(Enum(TyepUserLoginStatus), nullable=False, default=TyepUserLoginStatus.logout, comment='사용자 접속 상태')  # TODO(hsj100): LOGIN_STATUS
    # uuid = Column(String(length=36), nullable=True, unique=True)
    # relationship
    # NOTE(hsj100): Users:userlog => 1:N
    # userlog = relationship('UserLog', back_populates='users', lazy=False)
class UserLog(Base, BaseMixin):
    """Per-request audit log entries (account, API path, result message)."""
    __tablename__ = 'userlog'
    account = Column(String(length=255), nullable=False, comment='계정')
    type = Column(Enum(TypeUserLog), nullable=False, default=TypeUserLog.info, comment='로그 종류')
    api = Column(String(length=511), nullable=False, comment='API')
    message = Column(String(length=5000), nullable=True, comment='로그 메시지')
    # NOTE(review): crud.table_select filters on UserLog.mac, but the column is
    # commented out here — confirm which schema revision is current.
    # mac = Column(String(length=255), nullable=True, comment='사용자 MAC 주소')
    # NOTE(hsj100): cascading delete chain
    # user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'), nullable=False)
    # user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
    # users = relationship('Users', back_populates='userlog')
    # users = relationship('Users', back_populates='userlog', lazy=False)

View File

@@ -0,0 +1,188 @@
from app.common.consts import MAX_API_KEY, MAX_API_WHITELIST
class StatusCode:
    """HTTP status codes used by the application's exception hierarchy."""
    HTTP_500 = 500
    HTTP_400 = 400
    HTTP_401 = 401
    HTTP_403 = 403
    HTTP_404 = 404
    HTTP_405 = 405
class APIException(Exception):
    """
    Base application exception.

    Carries the HTTP status code, an application error code, a user-facing
    message, a technical detail string, and the original exception (if any).
    """
    status_code: int
    code: str
    msg: str
    detail: str
    ex: Exception

    def __init__(self, *, status_code: int = StatusCode.HTTP_500,
                 code: str = '000000', msg: str = None,
                 detail: str = None, ex: Exception = None):
        # Keep the original exception both as an attribute and as Exception args.
        self.status_code = status_code
        self.code = code
        self.msg = msg
        self.detail = detail
        self.ex = ex
        super().__init__(ex)
class NotFoundUserEx(APIException):
    """404: the requested user does not exist."""
    def __init__(self, user_id: int = None, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_404,
            msg=f'해당 유저를 찾을 수 없습니다.',
            detail=f'Not Found User ID : {user_id}',
            # NOTE(review): code prefix uses HTTP_400 while status_code is 404 — confirm intended.
            code=f'{StatusCode.HTTP_400}{"1".zfill(4)}',
            ex=ex,
        )
class NotAuthorized(APIException):
    """401: request requires a logged-in user."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_401,
            msg=f'로그인이 필요한 서비스 입니다.',
            detail='Authorization Required',
            code=f'{StatusCode.HTTP_401}{"1".zfill(4)}',
            ex=ex,
        )
class TokenExpiredEx(APIException):
    """400: the JWT has expired; the session is considered logged out."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'세션이 만료되어 로그아웃 되었습니다.',
            detail='Token Expired',
            code=f'{StatusCode.HTTP_400}{"1".zfill(4)}',
            ex=ex,
        )
class TokenDecodeEx(APIException):
    """400: the JWT could not be decoded (possibly tampered with)."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'비정상적인 접근입니다.',
            detail='Token has been compromised.',
            code=f'{StatusCode.HTTP_400}{"2".zfill(4)}',
            ex=ex,
        )
class NoKeyMatchEx(APIException):
    """404: the API key does not exist or the caller has no right to it."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_404,
            msg=f'해당 키에 대한 권한이 없거나 해당 키가 없습니다.',
            detail='No Keys Matched',
            code=f'{StatusCode.HTTP_404}{"3".zfill(4)}',
            ex=ex,
        )
class MaxKeyCountEx(APIException):
    """400: the per-user API key limit (MAX_API_KEY) was reached."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'API 키 생성은 {MAX_API_KEY}개 까지 가능합니다.',
            detail='Max Key Count Reached',
            code=f'{StatusCode.HTTP_400}{"4".zfill(4)}',
            ex=ex,
        )
class MaxWLCountEx(APIException):
    """400: the per-key whitelist limit (MAX_API_WHITELIST) was reached."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'화이트리스트 생성은 {MAX_API_WHITELIST}개 까지 가능합니다.',
            detail='Max Whitelist Count Reached',
            code=f'{StatusCode.HTTP_400}{"5".zfill(4)}',
            ex=ex,
        )
class InvalidIpEx(APIException):
    """400: the supplied string is not a valid IP address."""
    def __init__(self, ip: str, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'{ip}는 올바른 IP 가 아닙니다.',
            detail=f'invalid IP : {ip}',
            code=f'{StatusCode.HTTP_400}{"6".zfill(4)}',
            ex=ex,
        )
class SqlFailureEx(APIException):
    """500: database operation failed (wraps SQLAlchemy OperationalError)."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_500,
            msg=f'이 에러는 서버측 에러 입니다. 자동으로 리포팅 되며, 빠르게 수정하겠습니다.',
            detail='Internal Server Error',
            code=f'{StatusCode.HTTP_500}{"2".zfill(4)}',
            ex=ex,
        )
class APIQueryStringEx(APIException):
    """400: service API query string must contain exactly key and timestamp."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'쿼리스트링은 key, timestamp 2개만 허용되며, 2개 모두 요청시 제출되어야 합니다.',
            detail='Query String Only Accept key and timestamp.',
            code=f'{StatusCode.HTTP_400}{"7".zfill(4)}',
            ex=ex,
        )
class APIHeaderInvalidEx(APIException):
    """400: the HMAC secret header is missing or does not match."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'헤더에 키 해싱된 Secret 이 없거나, 유효하지 않습니다.',
            detail='Invalid HMAC secret in Header',
            code=f'{StatusCode.HTTP_400}{"8".zfill(4)}',
            ex=ex,
        )
class APITimestampEx(APIException):
    """400: the request timestamp is outside the allowed 10-second window."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'쿼리스트링에 포함된 타임스탬프는 KST 이며, 현재 시간보다 작아야 하고, 현재시간 - 10초 보다는 커야 합니다.',
            detail='timestamp in Query String must be KST, Timestamp must be less than now, and greater than now - 10.',
            code=f'{StatusCode.HTTP_400}{"9".zfill(4)}',
            ex=ex,
        )
class NotFoundAccessKeyEx(APIException):
    """404: no API key matches the supplied access key."""
    def __init__(self, api_key: str, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_404,
            msg=f'API 키를 찾을 수 없습니다.',
            detail=f'Not found such API Access Key : {api_key}',
            code=f'{StatusCode.HTTP_404}{"10".zfill(4)}',
            ex=ex,
        )
class KakaoSendFailureEx(APIException):
    """400: sending a KakaoTalk message failed."""
    def __init__(self, ex: Exception = None):
        super().__init__(
            status_code=StatusCode.HTTP_400,
            msg=f'카카오톡 전송에 실패했습니다.',
            detail=f'Failed to send KAKAO MSG.',
            code=f'{StatusCode.HTTP_400}{"11".zfill(4)}',
            ex=ex,
        )

View File

@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""
@File: main.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: main module
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from dataclasses import asdict
import uvicorn
from fastapi import FastAPI, Depends
from fastapi.security import APIKeyHeader
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from app.common import consts
from app.database.conn import db
from app.common.config import conf
from app.middlewares.token_validator import access_control
from app.middlewares.trusted_hosts import TrustedHostMiddleware
from app.routes import dev, index, auth, users, services
# Header scheme used by the Swagger UI to show the Authorization input.
API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False)

def create_app():
    """
    Build and configure the FastAPI application.

    Wires up the database, middleware (auth, CORS, trusted hosts) and routers.
    :return: configured FastAPI instance
    """
    configurations = conf()
    fast_api_app = FastAPI(
        title=configurations.SW_TITLE,
        version=configurations.SW_VERSION,
        description=configurations.SW_DESCRIPTION,
        terms_of_service=configurations.TERMS_OF_SERVICE,
        contact=configurations.CONTEACT,
        license_info=configurations.LICENSE_INFO
    )
    # initialize the database
    conf_dict = asdict(configurations)
    db.init_app(fast_api_app, **conf_dict)
    # redis initialization (not wired up yet)
    # middleware — note: middlewares added later run earlier in the chain
    fast_api_app.add_middleware(middleware_class=BaseHTTPMiddleware, dispatch=access_control)
    fast_api_app.add_middleware(
        CORSMiddleware,
        allow_origins=conf().ALLOW_SITE,
        allow_credentials=True,
        allow_methods=['*'],
        allow_headers=['*'],
    )
    fast_api_app.add_middleware(TrustedHostMiddleware, allowed_hosts=conf().TRUSTED_HOSTS, except_path=['/health'])
    # routers
    fast_api_app.include_router(index.router, tags=['Defaults'])
    fast_api_app.include_router(auth.router, tags=['Authentication'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
    # fast_api_app.include_router(services.router, tags=['Services'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
    fast_api_app.include_router(services.router, tags=['Services'], prefix='/api')
    fast_api_app.include_router(users.router, tags=['Users'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)])
    # dev-only endpoints are exposed only when DEBUG is enabled
    if conf().DEBUG:
        fast_api_app.include_router(dev.router, tags=['Developments'], prefix='/api')
    return fast_api_app
# ASGI application instance (served as 'main:app').
app = create_app()
if __name__ == '__main__':
    # Development entry point; reload=True restarts on code changes.
    uvicorn.run('main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)

View File

@@ -0,0 +1,166 @@
import base64
import hmac
import json
import time
import typing
import re
import jwt
import sqlalchemy.exc
from jwt.exceptions import ExpiredSignatureError, DecodeError
from starlette.requests import Request
from starlette.responses import JSONResponse
from app.common.consts import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX
from app.database.conn import db
from app.database.schema import Users, ApiKeys
from app.errors import exceptions as ex
from app.common import consts
from app.common.config import conf
from app.errors.exceptions import APIException, SqlFailureEx, APIQueryStringEx
from app.models import UserToken
from app.utils.date_utils import D
from app.utils.logger import api_logger
from app.utils.query_utils import to_dict
from dataclasses import asdict
async def access_control(request: Request, call_next):
    """
    Global auth middleware.

    Routes fall into three classes:
    - except paths: passed straight through (logged unless '/')
    - /api/services/*: HMAC key+timestamp scheme (or token lookup in DEBUG)
    - other /api/*: JWT Bearer token in the Authorization header
    - non-/api: token taken from cookies (template rendering)
    Any failure is converted into a JSON error response via exception_handler.
    """
    request.state.req_time = D.datetime()
    request.state.start = time.time()
    request.state.inspect = None
    request.state.user = None
    request.state.service = None
    # Prefer the proxy-forwarded client IP; keep only the first hop.
    ip = request.headers['x-forwarded-for'] if 'x-forwarded-for' in request.headers.keys() else request.client.host
    request.state.ip = ip.split(',')[0] if ',' in ip else ip
    headers = request.headers
    cookies = request.cookies
    url = request.url.path
    if await url_pattern_check(url, EXCEPT_PATH_REGEX) or url in EXCEPT_PATH_LIST:
        response = await call_next(request)
        if url != '/':
            await api_logger(request=request, response=response)
        return response
    try:
        if url.startswith('/api'):
            # API request: validate the token from headers
            # NOTE(hsj100): SERVICE_AUTH_API_KEY (token scheme)
            if url.startswith('/api/services') and conf().SERVICE_AUTH_API_KEY:
                qs = str(request.query_params)
                qs_list = qs.split('&')
                session = next(db.session())
                if not conf().DEBUG:
                    # Full HMAC verification: key + timestamp + hashed secret header.
                    try:
                        qs_dict = {qs_split.split('=')[0]: qs_split.split('=')[1] for qs_split in qs_list}
                    except Exception:
                        raise ex.APIQueryStringEx()
                    qs_keys = qs_dict.keys()
                    if 'key' not in qs_keys or 'timestamp' not in qs_keys:
                        raise ex.APIQueryStringEx()
                    if 'secret' not in headers.keys():
                        raise ex.APIHeaderInvalidEx()
                    api_key = ApiKeys.get(session=session, access_key=qs_dict['key'])
                    if not api_key:
                        raise ex.NotFoundAccessKeyEx(api_key=qs_dict['key'])
                    # Recompute the HMAC-SHA256 of the raw query string with the stored secret.
                    mac = hmac.new(bytes(api_key.secret_key, encoding='utf8'), bytes(qs, encoding='utf-8'), digestmod='sha256')
                    d = mac.digest()
                    validating_secret = str(base64.b64encode(d).decode('utf-8'))
                    if headers['secret'] != validating_secret:
                        raise ex.APIHeaderInvalidEx()
                    # Timestamp must be within the last 10 seconds (KST, diff=9).
                    now_timestamp = int(D.datetime(diff=9).timestamp())
                    if now_timestamp - 10 > int(qs_dict['timestamp']) or now_timestamp < int(qs_dict['timestamp']):
                        raise ex.APITimestampEx()
                    user_info = to_dict(api_key.users)
                    request.state.user = UserToken(**user_info)
                else:
                    # DEBUG: resolve the user straight from the Authorization access key
                    if 'authorization' in headers.keys():
                        key = headers.get('Authorization')
                        api_key_obj = ApiKeys.get(session=session, access_key=key)
                        user_info = to_dict(Users.get(session=session, id=api_key_obj.user_id))
                        request.state.user = UserToken(**user_info)
                    # no token
                    else:
                        if 'Authorization' not in headers.keys():
                            raise ex.NotAuthorized()
                session.close()
                # NOTE(review): this early return skips api_logger for service calls — confirm intended.
                response = await call_next(request)
                return response
            else:
                if 'authorization' in headers.keys():
                    # token present
                    token_info = await token_decode(access_token=headers.get('Authorization'))
                    request.state.user = UserToken(**token_info)
                elif conf().DEV_TEST_CONNECT_ACCOUNT:
                    # NOTE(hsj100): DEV_TEST_CONNECT_ACCOUNT
                    request.state.user = UserToken.from_orm(conf().DEV_TEST_CONNECT_ACCOUNT)
                else:
                    # no token
                    if 'Authorization' not in headers.keys():
                        raise ex.NotAuthorized()
        else:
            # Template rendering: check the token from cookies.
            # NOTE(review): the cookie is unconditionally overwritten with
            # conf().COOKIES_AUTH — looks like a dev shortcut; confirm.
            cookies['Authorization'] = conf().COOKIES_AUTH
            if 'Authorization' not in cookies.keys():
                raise ex.NotAuthorized()
            token_info = await token_decode(access_token=cookies.get('Authorization'))
            request.state.user = UserToken(**token_info)
        response = await call_next(request)
        await api_logger(request=request, response=response)
    except Exception as e:
        error = await exception_handler(e)
        error_dict = dict(status=error.status_code, msg=error.msg, detail=error.detail, code=error.code)
        response = JSONResponse(status_code=error.status_code, content=error_dict)
        await api_logger(request=request, error=error)
    return response
async def url_pattern_check(path, pattern):
    """Return True when *path* matches *pattern* from its start (re.match)."""
    return re.match(pattern, path) is not None
async def token_decode(access_token):
    """
    Decode a 'Bearer <jwt>' header value into its payload dict.

    :param access_token: raw Authorization header value
    :raises TokenExpiredEx: when the token has expired
    :raises TokenDecodeEx: when the token cannot be decoded
    :return: decoded JWT payload
    """
    token = access_token.replace('Bearer ', "")
    try:
        return jwt.decode(token, key=consts.JWT_SECRET, algorithms=[consts.JWT_ALGORITHM])
    except ExpiredSignatureError:
        raise ex.TokenExpiredEx()
    except DecodeError:
        raise ex.TokenDecodeEx()
async def exception_handler(error: Exception):
    """Normalize any exception into an APIException for the response layer."""
    print(error)
    # Database connectivity failures get their dedicated 500 wrapper.
    if isinstance(error, sqlalchemy.exc.OperationalError):
        error = SqlFailureEx(ex=error)
    if isinstance(error, APIException):
        return error
    # Anything else becomes a generic 500 carrying the original message.
    return APIException(ex=error, detail=str(error))

View File

@@ -0,0 +1,63 @@
import typing
from starlette.datastructures import URL, Headers
from starlette.responses import PlainTextResponse, RedirectResponse, Response
from starlette.types import ASGIApp, Receive, Scope, Send
# Assertion message for malformed wildcard host patterns.
# BUGFIX: the literal was corrupted by a stray paste ("lNo module named 'app'ike").
ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'."
class TrustedHostMiddleware:
    """
    Host-header validation middleware (starlette variant) with an extra
    except_path list: requests to those paths bypass the host check.
    """
    def __init__(
        self,
        app: ASGIApp,
        allowed_hosts: typing.Sequence[str] = None,
        except_path: typing.Sequence[str] = None,
        www_redirect: bool = True,
    ) -> None:
        if allowed_hosts is None:
            allowed_hosts = ['*']
        if except_path is None:
            except_path = []
        # Wildcards may appear only as a leading '*.' prefix.
        for pattern in allowed_hosts:
            assert '*' not in pattern[1:], ENFORCE_DOMAIN_WILDCARD
            if pattern.startswith('*') and pattern != '*':
                assert pattern.startswith('*.'), ENFORCE_DOMAIN_WILDCARD
        self.app = app
        self.allowed_hosts = list(allowed_hosts)
        self.allow_any = '*' in allowed_hosts
        self.www_redirect = www_redirect
        self.except_path = list(except_path)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Pass through when any host is allowed or for non-HTTP/websocket scopes.
        if self.allow_any or scope['type'] not in ('http', 'websocket',):  # pragma: no cover
            await self.app(scope, receive, send)
            return
        headers = Headers(scope=scope)
        host = headers.get('host', "").split(':')[0]
        is_valid_host = False
        found_www_redirect = False
        for pattern in self.allowed_hosts:
            if (
                host == pattern
                or (pattern.startswith('*') and host.endswith(pattern[1:]))
                or URL(scope=scope).path in self.except_path
            ):
                is_valid_host = True
                break
            elif 'www.' + host == pattern:
                # Host matches only with a www. prefix -> candidate for redirect.
                found_www_redirect = True
        if is_valid_host:
            await self.app(scope, receive, send)
        else:
            if found_www_redirect and self.www_redirect:
                url = URL(scope=scope)
                redirect_url = url.replace(netloc='www.' + url.netloc)
                response = RedirectResponse(url=str(redirect_url))  # type: Response
            else:
                response = PlainTextResponse('Invalid host header', status_code=400)
            await response(scope, receive, send)

View File

@@ -0,0 +1,963 @@
# -*- coding: utf-8 -*-
"""
@File: models.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: data models
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from datetime import datetime
from enum import Enum
from typing import List, Dict
from pydantic import Field
from pydantic.main import BaseModel
from pydantic.networks import EmailStr, IPvAnyAddress
from typing import Optional
from app.common.consts import (
SW_TITLE,
SW_VERSION,
MAIL_REG_TITLE,
MAIL_REG_CONTENTS,
SMTP_HOST,
SMTP_PORT,
ADMIN_INIT_ACCOUNT_INFO,
DEFAULT_USER_ACCOUNT_PW,
AI_ENGINE_FR_DETECT_TIME_MIN
)
from app.utils.date_utils import D
class SWInfo(BaseModel):
    """
    ### 서비스 정보 (software name, version and current date)
    """
    name: str = Field(SW_TITLE, description='SW 이름', example=SW_TITLE)
    version: str = Field(SW_VERSION, description='SW 버전', example=SW_VERSION)
    date: str = Field(D.date_str(), description='현재 날짜', example='%Y.%m.%dT%H:%M:%S')
class CustomEnum(Enum):
    """Enum base exposing a bracketed, comma-separated member listing."""
    @classmethod
    def get_elements_str(cls, is_key=True):
        # Member names by default; member values when is_key is False
        # (works for str-mixin enums, whose members are str instances).
        members = cls.__members__
        elements = members.keys() if is_key else members.values()
        return '[{}]'.format(', '.join(elements))
class TypeUserSex(str, CustomEnum):
    """
    ### 사용자 성별 타입 (user sex)
    """
    male: str = 'male'
    female: str = 'female'
class TypeUserAccount(str, CustomEnum):
    """
    ### 사용자 계정 타입 (account provider; only email is active)
    """
    email: str = 'email'
    # facebook: str = 'facebook'
    # google: str = 'google'
    # kakao: str = 'kakao'
class TypeUserAccountStatus(str, CustomEnum):
    """
    ### 사용자 계정 상태 타입 (account status)
    """
    active: str = 'active'
    deleted: str = 'deleted'
    blocked: str = 'blocked'
class TyepUserLoginStatus(str, CustomEnum):
    """
    ### 유저 로그인 상태 타입 (login status)
    """
    # NOTE(review): class name misspells 'Type' — renaming would touch all callers.
    login: str = 'login'
    logout: str = 'logout'
class TypeUserMembership(str, CustomEnum):
    """
    ### 사용자 회원 타입 (membership kind)
    """
    personal: str = 'personal'
    company: str = 'company'
class TypeUserGrade(str, CustomEnum):
    """
    ### 사용자 등급 타입 (user grade)
    """
    admin: str = 'admin'
    user: str = 'user'
class TypeUserLog(str, CustomEnum):
    """
    ### 사용자 로그 타입 (log level)
    """
    info: str = 'info'
    error: str = 'error'
class Token(BaseModel):
    """Bearer token wrapper returned by the auth endpoints."""
    Authorization: str = Field(None, description='인증키', example='Bearer [token]')
class EmailRecipients(BaseModel):
    """Single mail recipient (display name + address)."""
    name: str
    email: str
class SendEmail(BaseModel):
    """Mail-send request: list of recipients."""
    email_to: List[EmailRecipients] = None
class KakaoMsgBody(BaseModel):
    """KakaoTalk message payload."""
    msg: str = None
class MessageOk(BaseModel):
    """Simple acknowledgement response."""
    message: str = Field(default='OK')
class UserToken(BaseModel):
    """User identity carried inside the JWT / request.state.user."""
    id: int
    account: str = None
    name: str = None
    phone_number: str = None
    profile_img: str = None
    account_type: str = None
    class Config:
        orm_mode = True
class AddApiKey(BaseModel):
    """Request body for creating an API key (optional memo only)."""
    user_memo: str = None
    class Config:
        orm_mode = True
class GetApiKeyList(AddApiKey):
    """API key listing entry (no secret exposed)."""
    id: int = None
    access_key: str = None
    created_at: datetime = None
class GetApiKeys(GetApiKeyList):
    """Full API key view including the secret (returned on creation only)."""
    secret_key: str = None
class CreateAPIWhiteLists(BaseModel):
    """Request body for adding a whitelist IP to an API key."""
    ip_addr: str = None
class GetAPIWhiteLists(CreateAPIWhiteLists):
    """Whitelist entry as returned by the API (with row id)."""
    id: int
    class Config:
        orm_mode = True
class ResponseBase(BaseModel):
    """
    ### [Response] API End-Point
    **정상처리**\n
    - result: true\n
    - error: null\n
    **오류발생**\n
    - result: false\n
    - error: 오류내용\n
    """
    result: bool = Field(True, description='처리상태(성공: true, 실패: false)', example=False)
    error: str = Field(None, description='오류내용(성공: null, 실패: 오류내용)', example='invalid data')

    @classmethod
    def set_error(cls, error):
        """Return an error-response instance of the concrete subclass.

        BUGFIX: the previous @staticmethod mutated ResponseBase class
        attributes (shared, cross-request state) and returned the class
        object itself, always as the base class regardless of the caller.
        """
        return cls(result=False, error=str(error))

    class Config:
        orm_mode = True
class PagingReq(BaseModel):
    """
    ### [Request] 페이징 정보 (requested page and page size)
    """
    start_page: int = Field(None, description='시작 페이지 번호(base: 1)', example=1)
    page_contents_num: int = Field(None, description='페이지 내용 개수', example=2)
    class Config:
        orm_mode = True
class PagingRes(BaseModel):
    """
    ### [Response] 페이징 정보 (totals plus the returned slice size)
    """
    total_page_num: int = Field(None, description='전체 페이지 개수', example=100)
    total_contents_num: int = Field(None, description='전체 내용 개수', example=100)
    start_page: int = Field(None, description='시작 페이지 번호(base: 1)', example=1)
    search_contents_num: int = Field(None, description='검색된 내용 개수', example=100)
    class Config:
        orm_mode = True
class TokenRes(ResponseBase):
    """
    ### [Response] 토큰 정보 (bearer token on successful auth)
    """
    Authorization: str = Field(None, description='인증키', example='Bearer [token]')
    class Config:
        orm_mode = True
# One user-log row (ORM-mapped; see UserLog schema elsewhere in the project).
class UserLogInfo(BaseModel):
    """
    ### 사용자 로그 정보
    """
    id: int = Field(None, description='Table Index', example='1')
    created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
    updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    account: str = Field(None, description='계정', example='user1@test.com')
    mac: str = Field(None, description='MAC(네트워크 인터페이스 식별자)', example='11:22:33:44:55:66')
    type: str = Field(None, description='로그 타입' + TypeUserLog.get_elements_str(), example=TypeUserLog.info)
    api: str = Field(None, description='API 이름', example='/api/auth/login')
    message: str = Field(None, description='로그 내용', example='ok')
    class Config:
        orm_mode = True
# Full user record as exposed by the API (ORM-mapped from the Users table).
class UserInfo(BaseModel):
    """
    ### 유저 정보
    """
    id: int = Field(None, description='Table Index', example='1')
    created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
    updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    status: TypeUserAccountStatus = Field(None, description='계정상태' + TypeUserAccountStatus.get_elements_str(), example=TypeUserAccountStatus.active)
    user_type: TypeUserGrade = Field(None, description='유저 타입' + TypeUserGrade.get_elements_str(), example=TypeUserGrade.user)
    account_type: TypeUserAccount = Field(None, description='계정종류' + TypeUserAccount.get_elements_str(), example=TypeUserAccount.email)
    account: str = Field(None, description='계정', example='user1@test.com')
    email: str = Field(None, description='전자메일', example='user1@test.com')
    name: str = Field(None, description='이름', example='user1')
    sex: str = Field(None, description='성별' + TypeUserSex.get_elements_str(), example=TypeUserSex.male)
    rrn: str = Field(None, description='주민등록번호', example='123456-1234567')
    address: str = Field(None, description='주소', example='대구1')
    phone_number: str = Field(None, description='연락처', example='010-1234-1234')
    picture: str = Field(None, description='프로필사진', example='profile1.png')
    marketing_agree: bool = Field(False, description='마케팅동의 여부', example=False)
    # extra
    # NOTE(review): "Tyep" typo is in the imported enum's name; renaming it would break other modules.
    login: TyepUserLoginStatus = Field(None, description='로그인 상태' + TyepUserLoginStatus.get_elements_str(), example=TyepUserLoginStatus.logout)  # TODO(hsj100): LOGIN_STATUS
    member_type: TypeUserMembership = Field(None, description='회원 타입' + TypeUserMembership.get_elements_str(), example=TypeUserMembership.personal)
    class Config:
        orm_mode = True
# Request body for the SMTP mail service; both plain and HTML bodies may be supplied.
class SendMailReq(BaseModel):
    """
    ### [Request] 메일 전송
    """
    smtp_host: str = Field(None, description='SMTP 서버 주소', example=SMTP_HOST)
    smtp_port: int = Field(None, description='SMTP 서버 포트', example=SMTP_PORT)
    title: str = Field(None, description='제목', example=MAIL_REG_TITLE)
    recipient: str = Field(None, description='수신자', example='user1@test.com')
    cc_list: list = Field(None, description='참조 리스트', example='["user2@test.com", "user3@test.com"]')
    # recipient: str = Field(None, description='수신자', example='hsj100@daooldns.co.kr')
    # cc_list: list = Field(None, description='참조 리스트', example=None)
    contents_plain: str = Field(None, description='내용', example=MAIL_REG_CONTENTS.format('user1@test.com', 10))
    contents_html: str = Field(None, description='내용', example='<p>내 고양이는 <strong>아주 고약해.</p></strong>')
    class Config:
        orm_mode = True
# Response wrapping a single UserInfo record.
class UserInfoRes(ResponseBase):
    """
    ### [Response] 유저 정보
    """
    data: UserInfo = None
    class Config:
        orm_mode = True
# Login request: admins authenticate with account+pw, regular users with account+sw+mac.
class UserLoginReq(BaseModel):
    """
    ### 유저 로그인 정보
    """
    account: str = Field(description='계정', example='user1@test.com')
    pw: str = Field(None, description='비밀번호 [관리자 필수]', example='1234')
    # SW-LICENSE
    sw: str = Field(None, description='라이선스 SW 이름 [유저 필수]', example='저작도구')
    mac: str = Field(None, description='MAC [유저 필수]', example='11:22:33:44:55:01')
# Login response: bearer token plus the authenticated user's record.
class UserLoginRes(ResponseBase):
    """
    ### [Response] 토큰 정보
    """
    Authorization: str = Field(None, description='인증키', example='Bearer [token]')
    users: UserInfo = None
    class Config:
        orm_mode = True
# Registration request. Only `account` and `name` are required; everything
# else falls back to the documented defaults (pw defaults to DEFAULT_USER_ACCOUNT_PW).
class UserRegisterReq(BaseModel):
    """
    ### [Request] 유저 등록
    """
    # NOTE(review): declared Optional[str] while sibling models use
    # Optional[TypeUserAccountStatus]; kept as-is since tightening the type
    # would reject values current callers may send.
    status: Optional[str] = Field(TypeUserAccountStatus.active, description='계정상태' + TypeUserAccountStatus.get_elements_str(), example=TypeUserAccountStatus.active)
    user_type: Optional[TypeUserGrade] = Field(TypeUserGrade.user, description='유저 타입' + TypeUserGrade.get_elements_str(), example=TypeUserGrade.user)
    account: str = Field(description='계정', example='test@test.com')
    pw: Optional[str] = Field(DEFAULT_USER_ACCOUNT_PW, description='비밀번호', example='1234')
    email: Optional[str] = Field(None, description='전자메일', example='test@test.com')
    name: str = Field(description='이름', example='test')
    sex: Optional[TypeUserSex] = Field(TypeUserSex.male, description='성별' + TypeUserSex.get_elements_str(), example=TypeUserSex.male)
    rrn: Optional[str] = Field(None, description='주민등록번호', example='19910101-1234567')
    address: Optional[str] = Field(None, description='주소', example='대구광역시 동구 동촌로 351, 4층 (용계동, 에이스빌딩)')
    phone_number: Optional[str] = Field(None, description='휴대전화', example='010-1234-1234')
    picture: Optional[str] = Field(None, description='사진', example='profile1.png')
    marketing_agree: Optional[bool] = Field(False, description='마케팅동의 여부', example=False)
    # extra
    member_type: Optional[TypeUserMembership] = Field(TypeUserMembership.personal, description='회원 타입' + TypeUserMembership.get_elements_str(), example=TypeUserMembership.personal)
    class Config:
        orm_mode = True
# User search filter. Field-name suffixes map to query operators
# (presumably consumed by the ORM filter layer — __in: IN-list, __gt/__gte/
# __lt/__lte: range compare, __like: SQL LIKE pattern with % wildcards).
class UserSearchReq(BaseModel):
    """
    ### [Request] 유저 검색 (기본)
    """
    # basic
    id: Optional[int] = Field(None, description='등록번호', example='1')
    id__in: Optional[list] = Field(None, description='등록번호 리스트', example=[1,])
    created_at: Optional[str] = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
    created_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
    created_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
    created_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
    created_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
    updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
    updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
    updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
    updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
    status: Optional[TypeUserAccountStatus] = Field(None, description='계정상태' + TypeUserAccountStatus.get_elements_str(), example=TypeUserAccountStatus.active)
    user_type: Optional[TypeUserGrade] = Field(None, description='유저 타입' + TypeUserGrade.get_elements_str(), example=TypeUserGrade.user)
    account_type: Optional[TypeUserAccount] = Field(None, description='계정종류' + TypeUserAccount.get_elements_str(), example=TypeUserAccount.email)
    account: Optional[str] = Field(None, description='계정', example='user1@test.com')
    account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
    email: Optional[str] = Field(None, description='전자메일', example='user1@test.com')
    email__like: Optional[str] = Field(None, description='전자메일 부분검색', example='%user%')
    name: Optional[str] = Field(None, description='이름', example='test')
    name__like: Optional[str] = Field(None, description='이름 부분검색', example='%user%')
    sex: Optional[TypeUserSex] = Field(None, description='성별' + TypeUserSex.get_elements_str(), example=TypeUserSex.male)
    rrn: Optional[str] = Field(None, description='주민등록번호', example='123456-1234567')
    address: Optional[str] = Field(None, description='주소', example='대구1')
    address__like: Optional[str] = Field(None, description='주소 부분검색', example='%대구%')
    phone_number: Optional[str] = Field(None, description='연락처', example='010-1234-1234')
    picture: Optional[str] = Field(None, description='프로필사진', example='profile1.png')
    marketing_agree: Optional[bool] = Field(None, description='마케팅동의 여부', example=False)
    # extra
    login: Optional[TyepUserLoginStatus] = Field(None, description='로그인 상태' + TyepUserLoginStatus.get_elements_str(), example=TyepUserLoginStatus.logout)  # TODO(hsj100): LOGIN_STATUS
    member_type: Optional[TypeUserMembership] = Field(None, description='회원 타입' + TypeUserMembership.get_elements_str(), example=TypeUserMembership.personal)
    class Config:
        orm_mode = True
# User search response: plain list of matches.
class UserSearchRes(ResponseBase):
    """
    ### [Response] 유저 검색
    """
    data: List[UserInfo] = []
    class Config:
        orm_mode = True
# Paged user search: paging parameters plus the filter above.
class UserSearchPagingReq(BaseModel):
    """
    ### [Request] 유저 페이징 검색
    """
    # paging
    paging: Optional[PagingReq] = None
    search: Optional[UserSearchReq] = None
# Paged user search response: paging metadata plus one page of matches.
class UserSearchPagingRes(ResponseBase):
    """
    ### [Response] 유저 페이징 검색
    """
    paging: PagingRes = None
    data: List[UserInfo] = []
    class Config:
        orm_mode = True
# User update request: every field optional; only supplied fields are changed.
class UserUpdateReq(BaseModel):
    """
    ### [Request] 유저 변경
    """
    status: Optional[TypeUserAccountStatus] = Field(None, description='계정상태' + TypeUserAccountStatus.get_elements_str(), example=TypeUserAccountStatus.active)
    user_type: Optional[TypeUserGrade] = Field(None, description='유저 타입' + TypeUserGrade.get_elements_str(), example=TypeUserGrade.user)
    account_type: Optional[TypeUserAccount] = Field(None, description='계정종류' + TypeUserAccount.get_elements_str(), example=TypeUserAccount.email)
    account: Optional[str] = Field(None, description='계정', example='user1@test.com')
    email: Optional[str] = Field(None, description='전자메일', example='user1@test.com')
    name: Optional[str] = Field(None, description='이름', example='test')
    sex: Optional[TypeUserSex] = Field(None, description='성별' + TypeUserSex.get_elements_str(), example=TypeUserSex.male)
    rrn: Optional[str] = Field(None, description='주민등록번호', example='123456-1234567')
    address: Optional[str] = Field(None, description='주소', example='대구1')
    phone_number: Optional[str] = Field(None, description='연락처', example='010-1234-1234')
    picture: Optional[str] = Field(None, description='프로필사진', example='profile1.png')
    # NOTE(review): default False (not None) means an update request that omits
    # this field may overwrite marketing_agree to False — confirm intended.
    marketing_agree: Optional[bool] = Field(False, description='마케팅동의 여부', example=False)
    # extra
    login: Optional[TyepUserLoginStatus] = Field(None, description='로그인 상태' + TyepUserLoginStatus.get_elements_str(), example=TyepUserLoginStatus.logout)  # TODO(hsj100): LOGIN_STATUS
    # uuid: Optional[str] = Field(None, description='UUID', example='12345678-1234-5678-1234-567800000001')
    member_type: Optional[TypeUserMembership] = Field(None, description='회원 타입' + TypeUserMembership.get_elements_str(), example=TypeUserMembership.personal)
    class Config:
        orm_mode = True
# Bulk update: apply update_info to every user matched by search_info.
class UserUpdateMultiReq(BaseModel):
    """
    ### [Request] 유저 변경 (multi)
    """
    search_info: UserSearchReq = None
    update_info: UserUpdateReq = None
    class Config:
        orm_mode = True
# Password change request (current password required for verification).
class UserUpdatePWReq(BaseModel):
    """
    ### [Request] 유저 비밀번호 변경
    """
    account: str = Field(None, description='계정', example='user1@test.com')
    current_pw: str = Field(None, description='현재 비밀번호', example='1234')
    new_pw: str = Field(None, description='신규 비밀번호', example='5678')
# User-log search filter; suffix conventions match UserSearchReq
# (__in, __gt/__gte/__lt/__lte, __like).
class UserLogSearchReq(BaseModel):
    """
    ### [Request] 유저로그 검색
    """
    # basic
    id: Optional[int] = Field(None, description='등록번호', example='1')
    id__in: Optional[list] = Field(None, description='등록번호 리스트', example=[1,])
    created_at: Optional[str] = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
    created_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
    created_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
    created_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
    created_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
    updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
    updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
    updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
    updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
    account: Optional[str] = Field(None, description='계정', example='user1@test.com')
    account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
    mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
    mac__like: Optional[str] = Field(None, description='MAC 부분검색', example='%33%')
    type: Optional[TypeUserLog] = Field(None, description='유저로그 메시지 타입' + TypeUserLog.get_elements_str(), example=TypeUserLog.error)
    api: Optional[str] = Field(None, description='API 이름', example='/api/auth/login')
    api__like: Optional[str] = Field(None, description='API 이름 부분검색', example='%login%')
    message: Optional[str] = Field(None, description='로그내용', example='invalid password')
    message__like: Optional[str] = Field(None, description='로그내용 부분검색', example='%invalid%')
    class Config:
        orm_mode = True
# Per-day last-access record (attendance view) for a user/MAC pair.
class UserLogDayInfo(BaseModel):
    """
    ### 유저로그 일별 접속 정보 (출석)
    """
    account: str = Field(None, description='계정', example='user1@test.com')
    mac: str = Field(None, description='MAC(네트워크 인터페이스 식별자)', example='11:22:33:44:55:66')
    updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    message: str = Field(None, description='로그 내용', example='ok')
    class Config:
        orm_mode = True
# Filter for the per-day last-access search.
class UserLogDaySearchReq(BaseModel):
    """
    ### [Request] 유저로그 일별 마지막 접속 검색
    """
    updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00')
    updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00')
    updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00')
    updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00')
    account: Optional[str] = Field(None, description='계정', example='user1@test.com')
    account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%')
    mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
    mac__like: Optional[str] = Field(None, description='MAC 부분검색', example='%33%')
    class Config:
        orm_mode = True
# Paged variant of the per-day last-access search.
class UserLogDaySearchPagingReq(BaseModel):
    """
    ### [Request] 유저로그 일별 마지막 접속 페이징 검색
    """
    # paging
    paging: Optional[PagingReq] = None
    search: Optional[UserLogDaySearchReq] = None
# Paged per-day last-access response.
class UserLogDaySearchPagingRes(ResponseBase):
    """
    ### [Response] 유저로그 일별 마지막 접속 검색
    """
    paging: PagingRes = None
    data: List[UserLogDayInfo] = []
    class Config:
        orm_mode = True
# Paged user-log search request.
class UserLogPagingReq(BaseModel):
    """
    ### [Request] 유저로그 페이징 검색
    """
    # paging
    paging: Optional[PagingReq] = None
    search: Optional[UserLogSearchReq] = None
# Paged user-log search response.
class UserLogPagingRes(ResponseBase):
    """
    ### [Response] 유저로그 페이징 검색
    """
    paging: PagingRes = None
    data: List[UserLogInfo] = []
    class Config:
        orm_mode = True
# User-log update request: only supplied fields are changed.
class UserLogUpdateReq(BaseModel):
    """
    ### [Request] 유저로그 변경
    """
    account: Optional[str] = Field(None, description='계정', example='user1@test.com')
    mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01')
    type: Optional[TypeUserLog] = Field(None, description='유저로그 메시지 타입' + TypeUserLog.get_elements_str(), example=TypeUserLog.error)
    api: Optional[str] = Field(None, description='API 이름', example='/api/auth/login')
    message: Optional[str] = Field(None, description='로그내용', example='invalid password')
    class Config:
        orm_mode = True
# Bulk user-log update: apply update_info to all logs matched by search_info.
class UserLogUpdateMultiReq(BaseModel):
    """
    ### [Request] 유저로그 변경 (multi)
    """
    search_info: UserLogSearchReq = None
    update_info: UserLogUpdateReq = None
    class Config:
        orm_mode = True
##############################################################################################
# APPEND INFORMATION
##############################################################################################
# Lifecycle states of the AI engine as a whole.
class AEStatusType(str, CustomEnum):
    """
    ### AI Engine 상태 타입
    """
    init: str = 'init'
    running: str = 'running'
    stop: str = 'stop'
    error: str = 'error'
# Model families the AI engine can run (FR: face recognition, CON: construction-
# site facility detection, PPE: protective equipment, WORK: work detection,
# BI: BI anomaly detection).
class AEAIModelType(str, CustomEnum):
    """
    ### AI Engine에서 사용하는 AI 모델 타입
    """
    FR: str = 'FR'
    # OD: str = 'OD'
    CON: str = 'CON'
    PPE: str = 'PPE'
    WORK: str = 'WORK'
    BI: str = 'BI'
# Per-model runtime states.
class AEAIModelStatusType(str, CustomEnum):
    """
    ### AI Engine 모델 상태 타입
    """
    init: str = 'init'
    idle: str = 'idle'
    detect: str = 'detect'
    error: str = 'error'
# One weighted parameter of the Risk Index calculation.
class AERIParameterInfo(BaseModel):
    """
    ### RI(Risk Index) 매개 변수 정보
    """
    name: str = Field(description='RI 변수명', example='작업자 숙련도')
    ratio: float = Field(description='RI 변수의 비중', example=0.60)
    class Config:
        orm_mode = True
# Risk Index context for one work step of a construction type:
# static defined risk plus the model-evaluated risk.
class AEWorkRIInfo(BaseModel):
    """
    ### 작업 RI(Risk Index) 정보
    """
    construction_code: str = Field(description='작업 공종 코드', example='D54')
    construction_name: Optional[str] = Field(None, description='작업 공종명', example='간접활선 점퍼선 절단')
    work_no: int = Field(description='해당 공종의 작업 절차 번호', example=1)
    work_name: Optional[str] = Field(None, description='해당 작업 절차의 작업명', example='점퍼선 절단')
    work_define_ri: float = Field(description='해당 작업에 정의된 위험도', example=0.77)
    ri_parameter_list: List[AERIParameterInfo] = Field(description='RI 정보 리스트', example=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])
    evaluation_work_ri: Optional[float] = Field(0.00, description='AI 모델에서 평가한 작업(work_no) 위험도', example=0.52)
    class Config:
        orm_mode = True
class AEFTPInfo(BaseModel):
    """
    ### FTP 정보
    """
    # Connection parameters for the SFTP drop used by the demo pipeline;
    # the file_* fields name the per-model result files at `location`.
    ip: str = Field(description='IP address', example='106.255.245.242')
    port: int = Field(22, description='SFTP port', example=2022)
    id: str = Field(description='User ID', example='kepri_if_user')
    # FIX: description typo 'passward' -> 'password' (user-facing OpenAPI text).
    pw: str = Field(description='User password')
    location: str = Field(description='작업 공종 코드', example='/home/agics-dev/kepri_storage/rndpartners/')
    file_con_setup: str = Field("3", description='시설물 탐지 파일이름', example='3')
    file_face: str = Field("1", description='안면 인식 파일이름', example='1')
    file_ppe: str = Field("2", description='개인보호구 탐지(PPE) 파일이름', example='2')
    file_wd: str = Field("5", description='위험성 탐지(WD) 파일 이름', example='5')
    file_bi: str = Field("4", description='위험성 탐지(BI) 파일 이름', example='4')
    # Added for consistency with the sibling models in this module, all of
    # which enable ORM-mode construction.
    class Config:
        orm_mode = True
# Common description of an input device/source feeding the AI engine.
class AEInputBaseInfo(BaseModel):
    """
    ### AI Engine 입력 정보 (base)
    """
    name: str = Field('device_name', description='장치명', example='device_name')
    model: str = Field('model_name', description='모델명', example='model_name')
    sn: str = Field('serial_no', description='시리얼번호', example='serial_no')
    connect_url: str = Field('connect_url', description='서비스 접속 주소', example='connect_url')
    user_id: str = Field('user_id', description='서비스 접속 계정', example='user_id')
    user_pw: str = Field('user_pw', description='서비스 접속 비번', example='user_pw')
    class Config:
        orm_mode = True
# Demo-mode input configuration: results are exchanged over SFTP.
class AEDemoInfo(BaseModel):
    """
    ### AI Engine DEMO 입력 정보 (base)
    """
    ftp: AEFTPInfo = Field(description='FTP 정보')
# Video input: narrows `model` from free string to the AEAIModelType enum.
class AEInputVideoInfo(AEInputBaseInfo):
    """
    ### AI Engine 입력 정보 (video)
    """
    model:AEAIModelType = Field('model_name', description='모델명', example=AEAIModelType.FR)
    class Config:
        orm_mode = True
# BI (message-bus) input: adds the subscription topic.
class AEInputBIInfo(AEInputBaseInfo):
    """
    ### AI Engine 입력 정보 (BI)
    """
    topic: str = Field('topic', description='접속 토픽', example='test')
    class Config:
        orm_mode = True
# Shared fields of every AI model descriptor: name, version, runtime status.
class AEAIModelBase(BaseModel):
    """
    ### AI Model 정보 (base)
    """
    name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.FR)
    version: str = Field(None, description='AI Model Version (by date)', example='20220101')
    status: AEAIModelStatusType = Field(None, description='AI Model 상태' + AEAIModelStatusType.get_elements_str(), example=AEAIModelStatusType.init)
    class Config:
        orm_mode = True
# Face-recognition model descriptor with its Risk Index context.
class AEAIModelFRInfo(AEAIModelBase):
    """
    ### AI Model FR(Face Recognition) 정보
    """
    name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.FR)
    ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]))
    class Config:
        orm_mode = True
# Operating modes of the object-detection model (which detector head runs).
class AEAIModelODModeType(str, CustomEnum):
    """
    ### AI Engine 객체탐지 모델 동작 모드
    """
    ppe: str = 'ppe'
    work: str = 'work'
    con: str = 'con'
# Detector backbone sizes (presumably YOLO n/s/m/l/x variants — the .pt
# weight files below suggest PyTorch; confirm with the engine code).
class AEAIModelODModelType(str, CustomEnum):
    """
    ### AI Engine 객체탐지에서 사용하는 모델 타입
    """
    nano: str = 'nano'
    small: str = 'small'
    medium: str = 'medium'
    large: str = 'large'
    xlarge: str = 'xlarge'
# Metadata for one trained weight file of the object-detection model.
class AEAIModelODWeightInfo(BaseModel):
    """
    ### AI Model OD(Object Detection) Weight 정보
    """
    id: int = Field(0, description='Table Index', example='1')
    filename: str = Field(description='파일명', example='index_78.pt')
    version: str = Field(SW_VERSION, description='Version (by date)', example=SW_VERSION)
    date: str = Field(D.date_str(), description='Version (by date)', example=D.date_str())
    model: AEAIModelODModelType = Field(AEAIModelODModelType.small, description='학습한 환경의 객체탐지 모델 타입' + AEAIModelODModelType.get_elements_str(), example=AEAIModelODModelType.small)
    nc: int = Field(description='클래스 개수', example=26)
    class Config:
        orm_mode = True
# Generic object-detection model descriptor (weights, mode, crop option).
class AEAIModelODInfo(AEAIModelBase):
    """
    ### AI Model OD(Object Detection) 정보
    """
    name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.PPE)
    ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]))
    weights: List[AEAIModelODWeightInfo] = Field(description='가중치 정보 리스트', example=[AEAIModelODWeightInfo(filename='index_78.pt', nc=26)])
    mode: AEAIModelODModeType = Field(AEAIModelODModeType.ppe, description='모델 동작 모드' + AEAIModelODModeType.get_elements_str(), example=AEAIModelODModeType.ppe)
    crop_images: bool = Field(True, description='Detect 이미지 전송 여부', example=True)
    class Config:
        orm_mode = True
# Specializations below only re-pin `name`/`mode` defaults per model family.
class AEAIModelCONInfo(AEAIModelODInfo):
    """
    ### AI Model CON 정보
    """
    name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.CON)
    mode: AEAIModelODModeType = Field(AEAIModelODModeType.con, description='모델 동작 모드' + AEAIModelODModeType.get_elements_str(), example=AEAIModelODModeType.con)
class AEAIModelPPEInfo(AEAIModelODInfo):
    """
    ### AI Model PPE 정보
    """
    name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.PPE)
    mode: AEAIModelODModeType = Field(AEAIModelODModeType.ppe, description='모델 동작 모드' + AEAIModelODModeType.get_elements_str(), example=AEAIModelODModeType.ppe)
class AEAIModelWDInfo(AEAIModelODInfo):
    """
    ### AI Model WD 정보
    """
    name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.WORK)
    mode: AEAIModelODModeType = Field(AEAIModelODModeType.work, description='모델 동작 모드' + AEAIModelODModeType.get_elements_str(), example=AEAIModelODModeType.work)
    crop_images: bool = Field(False, description='Detect 이미지 전송 여부', example=False)
# BI anomaly-detection model descriptor (non-OD; no weights/mode fields).
class AEAIModelBIInfo(AEAIModelBase):
    """
    ### AI Model BI(BI Anomaly Detection) 정보
    """
    name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.BI)
    ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]))
    class Config:
        orm_mode = True
# Top-level AI engine configuration: inputs plus one descriptor per model family.
class AEInfo(BaseModel):
    """
    ### AI Engine 정보
    """
    # id: int = Field(None, description='Table Index', example='1')
    # created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
    # updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    version: str = Field(SW_VERSION, description='AI Engine Version', example=SW_VERSION)
    ai_engine_status: AEStatusType = Field(AEStatusType.init, description='AI Engine 상태' + AEStatusType.get_elements_str(), example=AEStatusType.init)
    demo: AEDemoInfo = Field(None, description='DEMO')
    input_video: List[AEInputVideoInfo] = Field(None, description='입력장치 정보[영상]', example=[
        AEInputVideoInfo(name='video1',model=AEAIModelType.CON),
        AEInputVideoInfo(name='video2',model=AEAIModelType.FR),
        AEInputVideoInfo(name='video3',model=AEAIModelType.PPE),
        AEInputVideoInfo(name='video4',model=AEAIModelType.WORK)])
    input_bi: AEInputBIInfo = Field(None, description='입력장치 정보[BI]', example=AEInputBIInfo())
    con_model_info: AEAIModelCONInfo = Field(None, description='시설물 AI Model 정보')
    fr_model_info: AEAIModelFRInfo = Field(None, description='안면인식 AI Model 정보')
    ppe_model_info: AEAIModelPPEInfo = Field(None, description='개인보호구 PPE AI Model 정보')
    wd_model_info: AEAIModelWDInfo = Field(None, description='객체탐지 AI Model 정보')
    bi_model_info: AEAIModelBIInfo = Field(None, description='BI 이상탐지 AI Model 정보')
    class Config:
        orm_mode = True
# Engine-configuration request: identical shape to AEInfo.
class AEInfoSetReq(AEInfo):
    """
    ### [Request] AI Engine 설정
    """
# Face-recognition run request: time limit, incremental reporting, RI context,
# and the list of identifiers to recognize.
class AEAIModelFRReq(BaseModel):
    """
    ### [Request] 안면인식
    """
    limit_time_min: Optional[int] = Field(AI_ENGINE_FR_DETECT_TIME_MIN, description='탐지 제한 시간(분)', example=AI_ENGINE_FR_DETECT_TIME_MIN)
    report_unit: Optional[bool] = Field(False, description='신규 대상을 인식할 때 마다 결과 전송', example=True)
    ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]))
    targets: list = Field(description='인식 대상 리스트 (고유번호사용: 사원번호 or 주민번호 ...)(TBD: 협의)', example='["no000001"]')
    class Config:
        orm_mode = True
# Toggle set: which PPE items the detector should look for.
class AEPPEBaseInfo(BaseModel):
    """
    ### 기본 개인보호구(PPE) 정보(TBD: 협의)
    """
    # id: int = Field(None, description='Table Index', example='1')
    # created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
    # updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    # version: str = Field(SW_VERSION, description='AI Engine Version', example=SW_VERSION)
    safety_helmet_on: bool = Field(True, description='안전모 탐지', example=True)
    safety_gloves_work_on: bool = Field(True, description='안전장갑(목장갑) 탐지', example=True)
    safety_boots_on: bool = Field(True, description='안전화 탐지', example=True)
    safety_belt_basic_on: bool = Field(False, description='안전대(벨트식) 탐지', example=False)
    # safety_suit_top: bool = Field(True, description='안전복(방염복)상의', example=True)
    # safety_suit_bottom: bool = Field(True, description='안전복(방염복)하의', example=True)
    class Config:
        orm_mode = True
# Toggle set: which site-facility objects the detector should look for.
class AEConSetupBaseInfo(BaseModel):
    """
    ### 기본 시설물 정보(TBD: 협의)
    """
    # id: int = Field(None, description='Table Index', example='1')
    # created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56')
    # updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56')
    # version: str = Field(SW_VERSION, description='AI Engine Version', example=SW_VERSION)
    signalman: bool = Field(True, description='신호수 탐지', example=True)
    sign_traffic_cone: bool = Field(True, description='트래픽콘(라바콘) 탐지', example=True)
    sign_board_information: bool = Field(False, description='안내표지판 탐지', example=False)
    sign_board_construction: bool = Field(False, description='공사안내판 탐지', example=False)
    sign_board_traffic: bool = Field(False, description='교통안전표지판 탐지', example=False)
    class Config:
        orm_mode = True
# Facility-detection run request: limit/report options, RI context, target toggles.
class AEAIModelConSetupDetectReq(BaseModel):
    """
    ### [Request] 시설물 탐지
    """
    # NOTE(review): reuses the FR time-limit constant as default — confirm intended.
    limit_time_min: Optional[int] = Field(AI_ENGINE_FR_DETECT_TIME_MIN, description='탐지 제한 시간(분)', example=AI_ENGINE_FR_DETECT_TIME_MIN)
    report_unit: Optional[bool] = Field(False, description='신규 대상을 인식할 때 마다 결과 전송', example=True)
    ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54',work_no=3,work_define_ri=0.82,ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]))
    targets: AEConSetupBaseInfo = Field(description='탐지 대상', example=AEConSetupBaseInfo())
    class Config:
        orm_mode = True
# PPE-detection run request; same shape as above with PPE target toggles.
class AEAIModelPPEDetectReq(BaseModel):
    """
    ### [Request] 개인보호구(PPE) 탐지
    """
    limit_time_min: Optional[int] = Field(AI_ENGINE_FR_DETECT_TIME_MIN, description='탐지 제한 시간(분)', example=AI_ENGINE_FR_DETECT_TIME_MIN)
    report_unit: Optional[bool] = Field(False, description='신규 대상을 인식할 때 마다 결과 전송', example=True)
    ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54',work_no=3,work_define_ri=0.82,ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]))
    targets: AEPPEBaseInfo = Field(description='탐지 대상', example=AEPPEBaseInfo())
    class Config:
        orm_mode = True
# Risk-detection (video stream) run request: RI context only.
class AEAIModelRBIVideoReq(BaseModel):
    """
    ### [Request] 위험성 탐지 (영상)
    """
    # limit_time_min: Optional[int] = Field(AI_ENGINE_FR_DETECT_TIME_MIN, description='탐지 제한 시간(분)', example=AI_ENGINE_FR_DETECT_TIME_MIN)
    ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]))
    class Config:
        orm_mode = True
# Risk-detection (BI stream) run request: RI context only.
class AEAIModelRBIBIReq(BaseModel):
    """
    ### [Request] 위험성 탐지 (BI)
    """
    ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]))
    class Config:
        orm_mode = True
# Placeholder MQTT payload for FR results (single test field only).
class AEAIModelFRMQTTInfo(BaseModel):
    test: str = Field(None, description='test', example='test')
    class Config:
        orm_mode = True

View File

@@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
"""
@File: auth.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: authentication api
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from itertools import groupby
from operator import attrgetter
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from starlette.requests import Request
import bcrypt
import jwt
from datetime import datetime, timedelta
from app.common import consts
from app import models as M
from app.database.conn import db
from app.common.config import conf
from app.database.schema import Users, UserLog
from app.utils.extra import AESCryptoCBC
from app.utils.date_utils import D
router = APIRouter(prefix='/auth')
@router.get('/find-account/{account}', response_model=M.ResponseBase, summary='계정유무 검사')
async def find_account(account: str):
    """
    ## 계정유무 검사
    주어진 계정이 존재하면 true, 없으면 false 처리
    **결과**
    - ResponseBase
    """
    # Any failure — account missing or a lookup error — is reported through
    # the common error envelope; a found account yields the default success body.
    try:
        if not Users.get(account=account):
            raise Exception(f'not found data: {account}')
    except Exception as e:
        return M.ResponseBase.set_error(str(e))
    return M.ResponseBase()
@router.post('/register/{account_type}', status_code=201, response_model=M.TokenRes, summary='회원가입')
async def register(request: Request, account_type: M.TypeUserAccount, request_body_info: M.UserRegisterReq, session: Session = Depends(db.session)):
    """
    ## 회원가입 (신규등록)
    **AI-KEPCO REST SERVER**
    - **PW** 항목은 로그인시에 사용되지 않지만, 임의값을 사용
    - 가입 완료 메일 발송은 별도 API(**/api/services/sendmail**) 이용
    - email 항목 생략시 account(email type)값이 적용
    **결과**
    - TokenRes
    """
    new_user = None
    try:
        # Only the initial admin account may register users.
        # BUGFIX: the previous code had `else: raise Exception('admin')`,
        # which raised on BOTH branches and made registration always fail.
        if request.state.user.account != consts.ADMIN_INIT_ACCOUNT_INFO.account:
            raise Exception('not admin')
        if not request_body_info.account or not request_body_info.pw:
            raise Exception('Account and PW must be provided')
        if Users.get(account=request_body_info.account):
            raise Exception(f'exist email: {request_body_info.account}')
        # Only email-type accounts are supported; other types fall through
        # to the same 'not supported' error the original raised.
        if account_type != M.TypeUserAccount.email:
            raise Exception('not supported')
        # The client sends the password AES-CBC encrypted; decrypt, then
        # store a bcrypt hash.
        hash_pw = None
        if request_body_info.pw:
            try:
                decode_pw = request_body_info.pw.encode('utf-8')
                desc_pw = AESCryptoCBC().decrypt(decode_pw)
            except Exception as e:
                raise Exception(f'failed decryption [pw]: {e}')
            hash_pw = bcrypt.hashpw(desc_pw, bcrypt.gensalt())
        # Create the user row.
        # BUGFIX: UserRegisterReq declares user_type / phone_number /
        # member_type; the previous code read .user_grade / .mobile /
        # .membership, which raised AttributeError at runtime.
        new_user = Users.create(session, auto_commit=True,
                                status=request_body_info.status,
                                user_type=request_body_info.user_type,
                                account_type=account_type,
                                account=request_body_info.account,
                                pw=hash_pw,
                                email=request_body_info.email if request_body_info.email else request_body_info.account,
                                name=request_body_info.name,
                                sex=request_body_info.sex,
                                rrn=request_body_info.rrn,
                                address=request_body_info.address,
                                phone_number=request_body_info.phone_number,
                                picture=request_body_info.picture,
                                marketing_agree=request_body_info.marketing_agree,
                                member_type=request_body_info.member_type
                                )
        # Issue an access token for the freshly created user (pw and
        # marketing flag are excluded from the JWT claims).
        token = dict(Authorization=f'Bearer {create_access_token(data=M.UserToken.from_orm(new_user).dict(exclude={"pw", "marketing_agree"}),)}')
        return token
    except Exception as e:
        return M.ResponseBase.set_error(str(e))
@router.post('/login/{account_type}', status_code=200, response_model=M.UserLoginRes, summary='사용자 접속')
async def login(account_type: M.TypeUserAccount, request_body_info: M.UserLoginReq, session: Session = Depends(db.session)):
    """
    ## User login
    **AI-KEPCO REST SERVER**
    - Admin login: account, pw\n
    - Regular account login: account, sw, mac\n
    **Result**
    - UserLoginRes
    """
    login_user = None
    update_user = None
    try:
        if not request_body_info.account:
            raise Exception('invalid request: account')
        login_user = Users.get(account=request_body_info.account)
        if not login_user:
            raise Exception('not found user')
        if account_type == M.TypeUserAccount.email:
            # basic auth
            if login_user.user_grade == M.TypeUserGrade.user:
                # Regular users pass without a password check.
                # NOTE(hsj100): CONNECT_USER_ERROR_LOG
                UserLog.create(session=session, auto_commit=True
                               , account=login_user.account
                               , mac=request_body_info.mac if login_user.user_grade == M.TypeUserGrade.user else None
                               , type=M.TypeUserLog.info
                               , api='/api/auth/login'
                               , message='ok'
                               )
            elif login_user.user_grade == M.TypeUserGrade.admin:
                # BUG FIX: validate pw presence BEFORE decrypting; previously a
                # missing pw raised AttributeError inside the decrypt step and
                # the intended 'invalid request: pw' message was never reached.
                if not request_body_info.pw:
                    raise Exception('invalid request: pw')
                try:
                    decode_pw = request_body_info.pw.encode('utf-8')
                    desc_pw = AESCryptoCBC().decrypt(decode_pw)
                except Exception as e:
                    raise Exception(f'failed decryption [pw]: {e}')
                if not bcrypt.checkpw(desc_pw, login_user.pw.encode('utf-8')):
                    raise Exception('invalid password')
            else:
                raise Exception('invalid user_type')
            # TODO(hsj100): LOGIN_STATUS
            update_user = Users.filter(account=request_body_info.account).update(auto_commit=True, login='login')
            result_info = M.UserLoginRes()
            result_info.Authorization = f'Bearer {create_access_token(data=M.UserToken.from_orm(login_user).dict(exclude={"pw", "marketing_agree"}), )}'
            return result_info
        return M.ResponseBase.set_error('not supported')
    except Exception as e:
        if update_user:
            update_user.close()
        # NOTE(hsj100): CONNECT_USER_ERROR_LOG
        if login_user:
            # BUG FIX: was `login_user.user_type`; the success path above reads
            # `user_grade` — use the same attribute consistently here.
            UserLog.create(session=session, auto_commit=True
                           , account=login_user.account
                           , mac=request_body_info.mac if login_user.user_grade == M.TypeUserGrade.user else None
                           , type=M.TypeUserLog.error
                           , api='/api/auth/login'
                           , message=str(e)
                           )
        return M.ResponseBase.set_error(str(e))
@router.post('/logout/{account}', status_code=200, response_model=M.TokenRes, summary='사용자 접속종료')
async def logout(account: str):
    """
    ## User logout
    The current version does not maintain a real login/logout session; only a
    status value is kept on the server, so ***the stored state may differ
    from reality.***
    Returns Authorization(null) on success.
    **Result**
    - TokenRes
    """
    user_info = None
    try:
        # TODO(hsj100): LOGIN_STATUS
        user_info = Users.filter(account=account)
        if not user_info:
            raise Exception('not found user')
        user_info.update(auto_commit=True, login='logout')
        return M.TokenRes()
    except Exception as e:
        if user_info:
            user_info.close()
        # BUG FIX: pass the message string, consistent with every other
        # handler in this file (was `set_error(e)` with the exception object).
        return M.ResponseBase.set_error(str(e))
async def is_account_exist(account: str):
    """Return True when a user row exists for the given account."""
    # Idiom fix: `bool(x)` instead of `True if x else False`.
    return bool(Users.get(account=account))
def create_access_token(*, data: dict = None, expires_delta: int = None):
    """Create a signed JWT access token.

    Args:
        data: claims to embed in the token; copied so the caller's dict is
            never mutated.
        expires_delta: optional lifetime in hours; when omitted the token
            carries no 'exp' claim.

    Returns:
        The configured GLOBAL_TOKEN when one is set (dev/test shortcut),
        otherwise the encoded JWT string.
    """
    if conf().GLOBAL_TOKEN:
        return conf().GLOBAL_TOKEN
    # BUG FIX: guard against data=None (the declared default); previously
    # `None.copy()` raised AttributeError.
    to_encode = data.copy() if data else {}
    if expires_delta:
        to_encode.update({'exp': datetime.utcnow() + timedelta(hours=expires_delta)})
    return jwt.encode(to_encode, consts.JWT_SECRET, algorithm=consts.JWT_ALGORITHM)

View File

@@ -0,0 +1,234 @@
# -*- coding: utf-8 -*-
"""
@File: dev.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: 개발용 API
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import struct
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
import bcrypt
from starlette.requests import Request
from app.common import consts
from app import models as M
from app.database.conn import db, Base
from app.database.schema import Users, UserLog
from app.utils.extra import FernetCrypto, AESCryptoCBC, AESCipher
# mail test
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
def send_mail():
    """Send a hard-coded test mail through Gmail SMTP (dev-only helper).

    When using a Google account, "access for less secure apps" must be
    enabled on that account for plain SMTP login to work.

    SECURITY NOTE(review): the sender credentials below are hard-coded in
    source control — move them to configuration/environment variables.
    :return: None
    """
    sender = 'jolimola@gmail.com'
    sender_pw = '!ghkdtmdwns1'
    # recipient = 'hsj100@daooldns.co.kr'
    recipient = 'jwkim@daooldns.co.kr'
    # NOTE(review): str_cc and title are currently unused — the Cc header is
    # commented out below and the Subject is set from `contents`, not `title`.
    list_cc = ['cc1@gmail.com', 'cc2@naver.com']
    str_cc = ','.join(list_cc)
    title = 'Test mail'
    contents = '''
    This is test mail
    using smtplib.
    '''
    # 1. Open the SMTP connection to Gmail (submission port 587).
    smtp_server = smtplib.SMTP(  # 1
        host='smtp.gmail.com',
        port=587
    )
    # 2. Identify, upgrade the connection to TLS, then re-identify.
    smtp_server.ehlo()  # 2
    smtp_server.starttls()  # 2
    smtp_server.ehlo()  # 2
    # 3. Authenticate over the now-encrypted channel.
    smtp_server.login(sender, sender_pw)  # 3
    # 4-5. Build the message envelope/headers.
    msg = MIMEMultipart()  # 4
    msg['From'] = sender  # 5
    msg['To'] = recipient  # 5
    # msg['Cc'] = str_cc  # 5
    msg['Subject'] = contents  # 5  NOTE(review): likely meant `title` — confirm
    # 6-8. Attach the plain-text body, send, and close the connection.
    msg.attach(MIMEText(contents, 'plain'))  # 6
    smtp_server.send_message(msg)  # 7
    smtp_server.quit()  # 8
router = APIRouter(prefix='/dev')
@router.get('/test', summary='테스트')
async def test(request: Request):
    """
    ## Development test API
    Exercises the Fernet / bcrypt / AES-CBC crypto helpers and returns their
    raw outputs for manual inspection.
    ## Result
    - SWInfo
    """
    sample_a = '!ekdnfeldpsdptm1'
    sample_b = 'testtesttest123'
    fernet = FernetCrypto()
    payload = dict()
    payload['1original_data'] = sample_a
    payload['1simpleEnDecrypt.encrypt'] = fernet.encrypt(sample_a)
    payload['1crypt.hashpw'] = bcrypt.hashpw(sample_a.encode('utf-8'), bcrypt.gensalt())
    payload['2original_data'] = sample_b
    payload['2simpleEnDecrypt.encrypt'] = fernet.encrypt(sample_b)
    payload['2bcrypt.hashpw'] = bcrypt.hashpw(sample_b.encode('utf-8'), bcrypt.gensalt())
    # Round-trip the first sample through AES-CBC.
    raw = bytes(sample_a.encode('utf-8'))
    cipher_text = AESCryptoCBC().encrypt(raw)
    plain_text = AESCryptoCBC().decrypt(cipher_text)
    decoded = cipher_text.decode('utf-8')
    payload['AESCryptoCBC().enc'] = f'enc: {cipher_text}, {decoded}'
    payload['AESCryptoCBC().dec'] = f'dec: {plain_text}'
    return payload
@router.post('/db/add_test_userlog', summary='[DB] 테스트 데이터 추가 (유저로그)', response_model=M.ResponseBase)
async def add_test_userlog(session: Session = Depends(db.session)):
    """
    ## Create test data (user log)
    Adds test user-log rows to the project database: two log records per
    existing user-grade account. The created_at dates start at January 2023
    and advance one month per user; generation stops after December.
    **Result**
    - ResponseBase
    """
    start_month = 1
    user_list = Users.filter(user_type=M.TypeUserGrade.user).all()
    for user_info in user_list:
        # Two-digit suffix taken from the numeric tail of the account's local
        # part (chars after index 4, before '@') to fake a per-user MAC.
        user_no = '%02d' % int(user_info.account.split('@')[0][4:])
        temp = UserLog.create(session=session, auto_commit=True
                              , account=user_info.account
                              , mac=f'11:22:33:44:55:{user_no}'
                              , api='/api/auth/login'
                              , message='ok'
                              )
        temp2 = UserLog.create(session=session, auto_commit=True
                               , account=user_info.account
                               , mac=f'11:22:33:44:55:{user_no}'
                               , api='/api/auth/login'
                               , message='ok'
                               )
        # Backdate the two rows to the 1st and 10th of the current month.
        UserLog.filter(id=temp.id).update(auto_commit=True
                                          , created_at=f'2023-{start_month}-01')
        UserLog.filter(id=temp2.id).update(auto_commit=True
                                           , created_at=f'2023-{start_month}-10')
        if start_month == 12:
            break
        start_month += 1
    return M.ResponseBase()
@router.post('/db/add_test_data', summary='[DB] 테스트 데이터 생성', response_model=M.ResponseBase)
async def db_add_test_data(session: Session = Depends(db.session)):
    """
    ## Create test data
    Populates the project database with test data.
    (Placeholder — currently a no-op that just acknowledges.)
    **Result**
    - ResponseBase
    """
    return M.ResponseBase()
@router.post('/db/delete_all_data', summary='[DB] 데이터 삭제 [전체]', response_model=M.ResponseBase)
async def db_delete_all_data():
    """
    ## Delete all DB rows
    Empties every table of the project database (the tables themselves are
    kept). Foreign-key enforcement is relaxed for the duration so truncation
    order does not matter.
    **Result**
    - ResponseBase
    """
    # Per-dialect statements: relax FK checks, wipe, restore FK checks.
    fk_off = {
        'mysql': 'SET FOREIGN_KEY_CHECKS=0;',
        'postgresql': 'SET CONSTRAINTS ALL DEFERRED;',
        'sqlite': 'PRAGMA foreign_keys = OFF;',
    }
    fk_on = {
        'mysql': 'SET FOREIGN_KEY_CHECKS=1;',
        'postgresql': 'SET CONSTRAINTS ALL IMMEDIATE;',
        'sqlite': 'PRAGMA foreign_keys = ON;',
    }
    wipe = {
        'mysql': 'TRUNCATE TABLE {};',
        'postgresql': 'TRUNCATE TABLE {} RESTART IDENTITY CASCADE;',
        'sqlite': 'DELETE FROM {};',
    }
    engine = db.engine
    dialect = engine.name
    with engine.begin() as conn:
        conn.execute(fk_off[dialect])
        # Children before parents (reverse dependency order).
        for table in reversed(Base.metadata.sorted_tables):
            conn.execute(wipe[dialect].format(table.name))
        conn.execute(fk_on[dialect])
    return M.ResponseBase()
@router.post('/db/delete_all_tables', summary='[DB] 테이블 삭제 [전체]', response_model=M.ResponseBase)
async def db_delete_all_tables():
    """
    ## Drop all DB tables
    Drops every table of the project database.
    **Result**
    - ResponseBase
    """
    # Same DROP statement for every supported dialect; the lookup keeps the
    # "unknown dialect -> KeyError" behavior of the original code.
    drop_sql = {
        'mysql': 'DROP TABLE IF EXISTS {};',
        'postgresql': 'DROP TABLE IF EXISTS {};',
        'sqlite': 'DROP TABLE IF EXISTS {};',
    }
    engine = db.engine
    with engine.begin() as conn:
        # Children before parents (reverse dependency order).
        for table in reversed(Base.metadata.sorted_tables):
            conn.execute(drop_sql[engine.name].format(table.name))
    return M.ResponseBase()

View File

@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
"""
@File: index.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: default route
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from fastapi import APIRouter
from app.utils.date_utils import D
from app.models import SWInfo
router = APIRouter()
@router.get('/', summary='서비스 정보', response_model=SWInfo)
async def index():
    """
    ## Service information
    Software name, version and the current server time.
    **Result**
    - SWInfo
    """
    info = SWInfo()
    info.date = D.date_str()
    return info

View File

@@ -0,0 +1,878 @@
# -*- coding: utf-8 -*-
"""
@File: services.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: services api
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import os, sys
import bcrypt
from fastapi import APIRouter, Depends, Body, UploadFile, File, Query
from typing import List
from sqlalchemy.orm import Session
from starlette.requests import Request
import copy
import time
import threading
AI_ENGINE_PATH = "/AI_ENGINE"
sys.path.append(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))) + AI_ENGINE_PATH)
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))))) # const
import ai_engine_const as AI_CONST
import AI_ENGINE.demo_utils as DEMO
from AI_ENGINE.instance_queue import manager as message_queue
from app.utils.extra import send_mail
from app.common import consts
from app import models as M
from app.database.conn import Base, db
from app.database.schema import Users, UserLog
from app.database.crud import table_select, table_update, table_delete
from app.utils.date_utils import D
from app.utils.extra import FernetCrypto,file_size_check
router = APIRouter(prefix='/services')
# NOTE(review): `global` at module scope is a no-op statement; kept as-is.
global engine_info
# Mutable module-level engine state, shared (read and written) by every
# /AE/* endpoint in this file.
engine_info = M.AEInfo(**AI_CONST.AI_ENGINE_INIT)
# Reusable doc buffer: assigned here, attached via `<fn>.__doc__` after the
# endpoint definition, then reassigned for the next endpoint.
function_doc = f"""
## AI 엔진 정보
현재 상태의 AI 엔진 정보 조회한다.
- AI Model 정보 포함
### 결과
- AEInfo
"""
@router.get('/AE/Info', response_model=M.AEInfo, summary='AI 엔진 정보')
async def ai_engine_info_get() -> M.AEInfo:
    # Return a deep copy so callers cannot mutate the live engine state, and
    # blank out the FTP password before it leaves the server.
    snapshot = copy.deepcopy(engine_info)
    snapshot.demo.ftp.pw = ""
    return snapshot
ai_engine_info_get.__doc__ = function_doc
function_doc = f"""
## AI engine configuration
Updates the AI engine settings, including AI model information.
### Result
- AEInfo
"""
@router.post('/AE/Info-Set', response_model=M.AEInfo, summary='AI 엔진 설정')
async def ai_engine_info_set(request: Request, request_body_info: M.AEInfoSetReq) -> M.AEInfo:
    # TODO(jwkim): allow updating a single key at a time
    global engine_info
    # Overwrite each field only when the request supplies a truthy value
    # (same "x if x else keep" semantics as the original per-field chain).
    for field in ('version', 'ai_engine_status', 'demo', 'input_video', 'input_bi',
                  'con_model_info', 'fr_model_info', 'ppe_model_info',
                  'wd_model_info', 'bi_model_info'):
        incoming = getattr(request_body_info, field)
        if incoming:
            setattr(engine_info, field, incoming)
    # Return a masked deep copy (never expose the FTP password).
    masked = copy.deepcopy(engine_info)
    masked.demo.ftp.pw = ""
    return masked
ai_engine_info_set.__doc__ = function_doc
# @router.post('/AE/RI-Set', response_model=M.ResponseBase, summary='RI 정보 설정')
# async def ai_engine_ri_set(request: Request, request_body_info: M.SendMailReq) -> M.ResponseBase:
# """
# ## RI 정보 설정
#
# :param request:
# :param request_body_info:
#
# **결과**
# - ResponseBase
# """
#
# return M.ResponseBase()
# @router.post('/AE/FR-Recognize/id',response_model=M.ResponseBase, summary='안면 인식 이미지 등록')
# async def create_upload_files(id:List = Query(default=[]), images: List[UploadFile] = File(...)):
# """
# TODO(jwkim):개발 진행중
# """
# try:
# from AI_ENGINE.instance_queue import FR
# if len(id) != len(images) or len(id) == 0 or len(images) == 0:
# raise IndexError("check id,images length")
# fr_id_dict = {}
# for i in range(len(images)):
# size = file_size_check(images[i])
# fr_id_dict[id[i]] ={
# "image": images[i],
# "size": size
# }
# FR.fr_id_info = fr_id_dict
# return M.ResponseBase()
# except Exception as e:
# return M.ResponseBase.set_error(str(e))
@router.post('/AE/OD-CON-SETUP', response_model=M.ResponseBase, summary='시설물 탐지 (작업전) [영상]')
async def ai_engine_od_con_setup_detect(request: Request, request_body_info: M.AEAIModelConSetupDetectReq):
    """
    ## Facility detection (pre-work)
    Detects facilities (signalman, traffic cones, sign boards, ...) on the
    AI Engine's registered RTSP video stream within a time limit (minutes).
    ### Behavior
    - Detects every facility visible on screen.
    - All targets found within the limit: publish the result over MQTT and
      stop (success). Not all found: publish only the partial result and
      stop (failure).
    - Set `report_unit` to true to publish each newly detected target.
    - A new request preempts a running one; changing the input video aborts
      detection and publishes an error message over MQTT.
    ### Open items (to agree on)
    - RI-based facility detection; per-work-type vs unified target lists;
      the final facility list; the MQTT result structure.
    ### Result (API)
    - Only parameter validation / model start status (ResponseBase) — no
      detection result here.
    ### Result (MQTT topic: /AI_KEPCO/AI_OD_CON_SETUP_DETECT/REPORT)
    - JSON with `datetime` (YY-mm-dd HH:MM:SS.sss), `status` ('start' |
      'complete' | 'new' | 'timeout' | 'stop' | error message) and `result`
      (array of DetectObjectInfo: name, cid, confidence, bbox [x1,y1,x2,y2],
      image, ...; slots not yet detected are null).
    """
    message_queue.sender({
        "request": request_body_info,
        "ai_model": AI_CONST.MODEL_CON,
        "signal": AI_CONST.SIGNAL_INFERENCE,
        "engine_info": engine_info,
        "argument": {},
    })
    return M.ResponseBase()
@router.post('/AE/OD-CON-SETUP-Stop', response_model=M.ResponseBase, summary='시설물 탐지 중지')
async def ai_engine_od_con_setup_detect_stop(request: Request):
    """
    ## Stop facility detection
    Used to manually stop (cancel/terminate) facility detection.
    - Stops the facility detection (CON) model.
    - Stops publishing to the MQTT topic
      /AI_KEPCO/AI_OD_CON_SETUP_DETECT/REPORT.
      (The original docstring named the PPE topic — copy-paste error.)
    ### Result
    - ResponseBase
    """
    # BUG FIX: forward the request INSTANCE; the original passed the
    # starlette `Request` class object instead of `request`.
    message_queue.sender({
        "request": request,
        "ai_model": AI_CONST.MODEL_CON,
        "signal": AI_CONST.SIGNAL_STOP
    })
    return M.ResponseBase()
@router.post('/AE/FR-Recognize', response_model=M.ResponseBase, summary='안면 인식 [영상]')
async def ai_engine_fr_recognize(request: Request, request_body_info: M.AEAIModelFRReq) -> M.ResponseBase:
    """
    ## Face recognition
    Recognizes workers' faces on the AI Engine's registered RTSP video
    stream within a time limit (minutes).
    ### Behavior
    - Multiple faces can be recognized simultaneously.
    - All targets recognized within the limit: publish the result over MQTT
      and stop (success). Not all recognized: publish only the partial
      result and stop (failure).
    - Set `report_unit` to true to publish each newly recognized target.
    - A new request preempts a running one; changing the input video aborts
      recognition and publishes an error message over MQTT.
    ### Open items (to agree on)
    - How the reference face data is supplied (image upload vs a worker ID
      the server resolves to an image); the MQTT result structure.
    ### Result (API)
    - Only parameter validation / model start status (ResponseBase).
    ### Result (MQTT topic: /AI_KEPCO/AI_FACE_RECOGNIZE/REPORT)
    - JSON with `datetime` (YY-mm-dd HH:MM:SS.sss), `status` ('start' |
      'complete' | 'new' | 'timeout' | 'stop' | error message) and `result`
      (array of recognized worker IDs; not-yet-recognized slots are null).
    """
    message_queue.sender({
        "request": request_body_info,
        "ai_model": AI_CONST.MODEL_FACE_RECOGNIZE,
        "signal": AI_CONST.SIGNAL_INFERENCE,
        "engine_info": engine_info
    })
    return M.ResponseBase()
@router.post('/AE/FR-Recognize-Stop', response_model=M.ResponseBase, summary='안면 인식 중지')
async def ai_engine_fr_recognize_stop(request: Request):
    """
    ## Stop face recognition
    Used to manually stop (cancel/terminate) face recognition.
    - Stops the face recognition model.
    - Stops publishing to the MQTT topic /AI_KEPCO/AI_FACE_RECOGNIZE/REPORT.
    ### Result
    - ResponseBase
    """
    # BUG FIX: forward the request INSTANCE; the original passed the
    # starlette `Request` class object instead of `request`.
    message_queue.sender({
        "request": request,
        "ai_model": AI_CONST.MODEL_FACE_RECOGNIZE,
        "signal": AI_CONST.SIGNAL_STOP
    })
    return M.ResponseBase()
@router.post('/AE/OD-PPE-Detect', response_model=M.ResponseBase, summary='개인보호구(PPE) 탐지 (작업전) [영상]')
async def ai_engine_od_ppe_detect(request: Request, request_body_info: M.AEAIModelPPEDetectReq):
    """
    ## PPE detection (pre-work)
    Detects a worker's personal protective equipment (helmet, gloves, boots,
    harness, ...) on the AI Engine's registered RTSP video stream within a
    time limit (minutes).
    ### Behavior
    - One worker at a time (multiple workers not supported).
    - All required PPE found within the limit: publish the result over MQTT
      and stop (success). Not all found: publish only the partial result and
      stop (failure).
    - Set `report_unit` to true to publish each newly detected item.
    - A new request preempts a running one; changing the input video aborts
      detection and publishes an error message over MQTT.
    ### Open items (to agree on)
    - RI-based PPE detection; per-work-type vs unified PPE lists; the final
      PPE item list; the MQTT result structure.
    ### Result (API)
    - Only parameter validation / model start status (ResponseBase).
    ### Result (MQTT topic: /AI_KEPCO/AI_OD_PPE_DETECT/REPORT)
    - JSON with `datetime` (YY-mm-dd HH:MM:SS.sss), `status` ('start' |
      'complete' | 'new' | 'timeout' | 'stop' | error message) and `result`
      (array of DetectObjectInfo: name, cid, confidence, bbox [x1,y1,x2,y2],
      image, ...; slots not yet detected are null).
    """
    message_queue.sender({
        "request": request_body_info,
        "ai_model": AI_CONST.MODEL_PPE,
        "signal": AI_CONST.SIGNAL_INFERENCE,
        "engine_info": engine_info,
        "argument": {},
    })
    return M.ResponseBase()
@router.post('/AE/OD-PPE-Detect-Stop', response_model=M.ResponseBase, summary='개인보호구(PPE) 탐지 중지')
async def ai_engine_od_ppe_detect_stop(request: Request):
    """
    ## Stop PPE detection
    Used to manually stop (cancel/terminate) PPE detection.
    - Stops the PPE detection model.
    - Stops publishing to the MQTT topic /AI_KEPCO/AI_OD_PPE_DETECT/REPORT.
    ### Result
    - ResponseBase
    """
    # BUG FIX: forward the request INSTANCE; the original passed the
    # starlette `Request` class object instead of `request`.
    message_queue.sender({
        "request": request,
        "ai_model": AI_CONST.MODEL_PPE,
        "signal": AI_CONST.SIGNAL_STOP
    })
    return M.ResponseBase()
@router.post('/AE/OD-Work-Detect', response_model=M.ResponseBase, summary='위험성 탐지 (작업중) [영상]')
async def ai_engine_od_work_detect(request: Request, request_body_info: M.AEAIModelRBIVideoReq):
    """
    ## Risk detection (during work)
    Runs object detection on the linked multi-channel RTSP video service and
    evaluates the detections against the RI (risk index) rules.
    ### Behavior
    - Publishes a result over MQTT every time a risk is detected; each
      result reflects only that moment (not cumulative).
    - Runs until /api/services/AE/OD-Work-Detect-Stop is called.
    - A new request preempts a running one; changing the input video aborts
      detection and publishes an error message over MQTT.
    ### Open items (to agree on)
    - The RI definition; the MQTT result structure.
    ### Result (API)
    - Only parameter validation / model start status (ResponseBase).
    ### Result (MQTT topic: /AI_KEPCO/AI_OD_WORK_DETECT/REPORT)
    - JSON with `datetime` (YY-mm-dd HH:MM:SS.sss), `status` ('start' |
      'detect' | 'stop' | error message) and `result`
      (ConstructionRiskIndexInfo: construction_type, procedure_no,
      procedure_ri, ri, detect_list of DetectObjectInfo).
    """
    message_queue.sender({
        "request": request_body_info,
        "ai_model": AI_CONST.MODEL_WORK_DETECT,
        "signal": AI_CONST.SIGNAL_INFERENCE,
        "engine_info": engine_info,
        "argument": {},
    })
    return M.ResponseBase()
@router.post('/AE/OD-Work-Detect-Stop', response_model=M.ResponseBase, summary='위험성 탐지 중지')
async def ai_engine_od_work_detect_stop(request: Request):
    """
    ## Stop risk detection
    Used to stop (cancel/terminate) the during-work risk detection.
    - Stops the risk detection model.
    - Stops publishing to the MQTT topic /AI_KEPCO/AI_OD_WORK_DETECT/REPORT.
    ### Result
    - ResponseBase
    """
    # BUG FIX: forward the request INSTANCE; the original passed the
    # starlette `Request` class object instead of `request`.
    message_queue.sender({
        "request": request,
        "ai_model": AI_CONST.MODEL_WORK_DETECT,
        "signal": AI_CONST.SIGNAL_STOP
    })
    # Reset the WD/BI demo state machine.
    DEMO.demo_wd_bi(0)
    return M.ResponseBase()
@router.post('/AE/BI-Detect', response_model=M.ResponseBase, summary='위험성 탐지 [BI]')
async def ai_engine_bi_detect(request: Request, request_body_info: M.AEAIModelRBIBIReq):
    """
    ## Risk detection [BI]
    Evaluates biometric information (BI) collected from the linked sensor
    devices against the RI rules (time-series anomaly detection).
    ### Behavior
    - Publishes a result over MQTT whenever a risk is detected; each result
      reflects only that moment (not cumulative).
    - Runs until /api/services/AE/BI-Detect-Stop is called.
    ### Open items (to agree on)
    - The RI definition; the MQTT result structure.
    ### Result (API)
    - Only parameter validation / model start status (ResponseBase).
    ### Result (MQTT topic: /AI_KEPCO/AI_BI_DETECT/REPORT) (not wired yet)
    - JSON with `datetime` (YY-mm-dd HH:MM:SS.sss), `status` ('start' |
      'detect' | 'stop' | error message) and `result` (risk info list).
    """
    # NOTE(review): the demo hooks (DEMO.demo_wd_bi / DEMO.bi_snap_shot) are
    # intentionally disabled — this endpoint currently only acknowledges.
    return M.ResponseBase()
@router.post('/AE/BI-Detect-Stop', response_model=M.ResponseBase, summary='위험성 탐지(BI) 중지')
async def ai_engine_bi_detect_stop(request: Request):
    """
    ## Stop risk detection (BI)
    Used to stop (cancel/terminate) BI risk detection.
    - Stops the BI risk detection model.
    - Stops publishing to the MQTT topic /AI_KEPCO/AI_BI_DETECT/REPORT.
    ### Result
    - ResponseBase
    """
    return M.ResponseBase()

View File

@@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-
"""
@File: users.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: users api
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from fastapi import APIRouter
from starlette.requests import Request
import bcrypt
from app.common import consts
from app import models as M
from app.common.config import conf
from app.database.schema import Users
from app.database.crud import table_select, table_update, table_delete
from app.utils.extra import AESCryptoCBC
router = APIRouter(prefix='/user')
@router.get('/me', response_model=M.UserSearchRes, summary='접속자 정보')
async def get_me(request: Request):
    """
    ## Return the currently authenticated user's own record
    ***Not supported when a global token is configured (detail handling
    deferred to a later version).***
    **Result**
    - UserSearchRes
    """
    found = None
    try:
        if conf().GLOBAL_TOKEN:
            raise Exception('not supported: use search api!')
        accessor = request.state.user
        if not accessor:
            raise Exception('invalid accessor')
        found = Users.get(account=accessor.account)
        if not found:
            raise Exception('not found data')
        return M.UserSearchRes(data=[M.UserInfo.from_orm(found)])
    except Exception as e:
        if found:
            found.close()
        return M.ResponseBase.set_error(str(e))
@router.post('/search', response_model=M.UserSearchPagingRes, summary='유저정보 검색')
async def user_search(request: Request, request_body_info: M.UserSearchPagingReq):
    """
    ## Search user records (basic).
    Search conditions may only use columns of the user table (see the
    Request body schema). Joined searches use dedicated APIs.
    **Details**
    - **paging**: when omitted, every row matching `search` is returned
      without paging.
    - **search**: conditions are AND-combined.
      - Match all rows with an empty object: `"search": {}`
      - Partial match uses SQL wildcards (`%`, `_`):
        `__like` — starts with (X%), contains (%X%), ends with (%X)
      - Range operators:
        * `__lt`: less than
        * `__lte`: less than or equal
        * `__gt`: greater than
        * `__gte`: greater than or equal
    **Result**
    - UserSearchPagingRes
    """
    accessor = request.state.user
    return await table_select(accessor, Users, request_body_info, M.UserSearchPagingRes, M.UserInfo)
@router.put('/update', response_model=M.ResponseBase, summary='유저정보 변경')
async def user_update(request: Request, request_body_info: M.UserUpdateMultiReq):
    """
    ## Update user records.
    **search_info**: rows to change
    **update_info**: new values
    - **passwords** are excluded (use the dedicated password API)
    **Result**
    - ResponseBase
    """
    accessor = request.state.user
    return await table_update(accessor, Users, request_body_info, M.ResponseBase)
@router.put('/update_pw', response_model=M.ResponseBase, summary='유저 비밀번호 변경')
async def user_update_pw(request: Request, request_body_info: M.UserUpdatePWReq):
    """
    ## Change a user's password.
    Changes the password of **account**. Both passwords arrive AES/CBC
    encrypted; they are decrypted, the current one is verified against the
    stored bcrypt hash, and the new one is re-hashed with bcrypt.
    **Result**
    - ResponseBase
    """
    target_table = Users
    search_info = None
    try:
        # --- validate the request ---
        accessor_info = request.state.user
        if not accessor_info:
            raise Exception('invalid accessor')
        if not request_body_info.account:
            raise Exception('invalid account')
        # --- decrypt the transported passwords (fresh cipher per value) ---
        try:
            decode_cur_pw = request_body_info.current_pw.encode('utf-8')
            desc_cur_pw = AESCryptoCBC().decrypt(decode_cur_pw)
        except Exception as e:
            raise Exception(f'failed decryption [current_pw]: {e}')
        try:
            decode_new_pw = request_body_info.new_pw.encode('utf-8')
            desc_new_pw = AESCryptoCBC().decrypt(decode_new_pw)
        except Exception as e:
            raise Exception(f'failed decryption [new_pw]: {e}')
        # --- verify the current password ---
        target_user = target_table.get(account=request_body_info.account)
        if not target_user:
            # BUGFIX: previously dereferenced target_user.pw without a None
            # check, raising an opaque AttributeError for unknown accounts.
            raise Exception('not found data')
        is_verified = bcrypt.checkpw(desc_cur_pw, target_user.pw.encode('utf-8'))
        if not is_verified:
            raise Exception('invalid password')
        search_info = target_table.filter(id=target_user.id)
        if not search_info.first():
            raise Exception('not found data')
        # --- store the new bcrypt hash ---
        hash_pw = bcrypt.hashpw(desc_new_pw, bcrypt.gensalt())
        result_info = search_info.update(auto_commit=True, pw=hash_pw)
        if not result_info or not result_info.id:
            raise Exception('failed update')
        # result
        return M.ResponseBase()
    except Exception as e:
        if search_info:
            search_info.close()
        return M.ResponseBase.set_error(str(e))
@router.delete('/delete', response_model=M.ResponseBase, summary='유저정보 삭제')
async def user_delete(request: Request, request_body_info: M.UserSearchReq):
    """
    ## Delete user records.
    Deletes every record matching the given conditions.
    - **This API permanently removes rows from the DB; intended for server admins.**
    - **For normal use, prefer the update API (change a status column instead).**
    `Rows in related tables are deleted together with the user.`
    **Result**
    - ResponseBase
    """
    accessor = request.state.user
    return await table_delete(accessor, Users, request_body_info, M.ResponseBase)
# NOTE(hsj100): apikey
"""
"""
# @router.get('/apikeys', response_model=List[M.GetApiKeyList])
# async def get_api_keys(request: Request):
# """
# API KEY 조회
# :param request:
# :return:
# """
# user = request.state.user
# api_keys = ApiKeys.filter(user_id=user.id).all()
# return api_keys
#
#
# @router.post('/apikeys', response_model=M.GetApiKeys)
# async def create_api_keys(request: Request, key_info: M.AddApiKey, session: Session = Depends(db.session)):
# """
# API KEY 생성
# :param request:
# :param key_info:
# :param session:
# :return:
# """
# user = request.state.user
#
# api_keys = ApiKeys.filter(session, user_id=user.id, status='active').count()
# if api_keys == MAX_API_KEY:
# raise ex.MaxKeyCountEx()
#
# alphabet = string.ascii_letters + string.digits
# s_key = ''.join(secrets.choice(alphabet) for _ in range(40))
# uid = None
# while not uid:
# uid_candidate = f'{str(uuid4())[:-12]}{str(uuid4())}'
# uid_check = ApiKeys.get(access_key=uid_candidate)
# if not uid_check:
# uid = uid_candidate
#
# key_info = key_info.dict()
# new_key = ApiKeys.create(session, auto_commit=True, secret_key=s_key, user_id=user.id, access_key=uid, **key_info)
# return new_key
#
#
# @router.put('/apikeys/{key_id}', response_model=M.GetApiKeyList)
# async def update_api_keys(request: Request, key_id: int, key_info: M.AddApiKey):
# """
# API KEY User Memo Update
# :param request:
# :param key_id:
# :param key_info:
# :return:
# """
# user = request.state.user
# key_data = ApiKeys.filter(id=key_id)
# if key_data and key_data.first().user_id == user.id:
# return key_data.update(auto_commit=True, **key_info.dict())
# raise ex.NoKeyMatchEx()
#
#
# @router.delete('/apikeys/{key_id}')
# async def delete_api_keys(request: Request, key_id: int, access_key: str):
# user = request.state.user
# await check_api_owner(user.id, key_id)
# search_by_key = ApiKeys.filter(access_key=access_key)
# if not search_by_key.first():
# raise ex.NoKeyMatchEx()
# search_by_key.delete(auto_commit=True)
# return MessageOk()
#
#
# @router.get('/apikeys/{key_id}/whitelists', response_model=List[M.GetAPIWhiteLists])
# async def get_api_keys(request: Request, key_id: int):
# user = request.state.user
# await check_api_owner(user.id, key_id)
# whitelists = ApiWhiteLists.filter(api_key_id=key_id).all()
# return whitelists
#
#
# @router.post('/apikeys/{key_id}/whitelists', response_model=M.GetAPIWhiteLists)
# async def create_api_keys(request: Request, key_id: int, ip: M.CreateAPIWhiteLists, session: Session = Depends(db.session)):
# user = request.state.user
# await check_api_owner(user.id, key_id)
# import ipaddress
# try:
# _ip = ipaddress.ip_address(ip.ip_addr)
# except Exception as e:
# raise ex.InvalidIpEx(ip.ip_addr, e)
# if ApiWhiteLists.filter(api_key_id=key_id).count() == MAX_API_WHITELIST:
# raise ex.MaxWLCountEx()
# ip_dup = ApiWhiteLists.get(api_key_id=key_id, ip_addr=ip.ip_addr)
# if ip_dup:
# return ip_dup
# ip_reg = ApiWhiteLists.create(session=session, auto_commit=True, api_key_id=key_id, ip_addr=ip.ip_addr)
# return ip_reg
#
#
# @router.delete('/apikeys/{key_id}/whitelists/{list_id}')
# async def delete_api_keys(request: Request, key_id: int, list_id: int):
# user = request.state.user
# await check_api_owner(user.id, key_id)
# ApiWhiteLists.filter(id=list_id, api_key_id=key_id).delete()
#
# return MessageOk()
#
#
# async def check_api_owner(user_id, key_id):
# api_keys = ApiKeys.get(id=key_id, user_id=user_id)
# if not api_keys:
# raise ex.NoKeyMatchEx()

View File

@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
"""
@File: date_utils.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: utility [date]
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from datetime import datetime, date, timedelta
_TIMEDELTA = 9  # hour offset from UTC to KST (UTC+9)
class D:
    """Date/time helpers that normalise "now" to KST by shifting UTC."""
    def __init__(self, *args):
        self.utc_now = datetime.utcnow()
        # NOTE(hsj100): utc->kst
        self.timedelta = _TIMEDELTA
    @classmethod
    def datetime(cls, diff: int=_TIMEDELTA) -> datetime:
        # diff > 0: shift UTC by `diff` hours; otherwise fall back to local now.
        return datetime.utcnow() + timedelta(hours=diff) if diff > 0 else datetime.now()
    @classmethod
    def date(cls, diff: int=_TIMEDELTA) -> date:
        return cls.datetime(diff=diff).date()
    @classmethod
    def date_num(cls, diff: int=_TIMEDELTA) -> int:
        # e.g. 20240304
        return int(cls.date(diff=diff).strftime('%Y%m%d'))
    @classmethod
    def validate(cls, date_text):
        """Raise ValueError unless date_text is a valid YYYY-MM-DD string."""
        try:
            datetime.strptime(date_text, '%Y-%m-%d')
        except ValueError:
            raise ValueError('Incorrect data format, should be YYYY-MM-DD')
    @classmethod
    def date_str(cls, diff: int = _TIMEDELTA) -> str:
        return cls.datetime(diff=diff).strftime('%Y-%m-%dT%H:%M:%S')
    @classmethod
    def date_version_str(cls, diff: int = _TIMEDELTA) -> str:
        # BUGFIX: the format used '%2S', a non-portable GNU width extension
        # (undefined behaviour on non-glibc platforms); '%S' is the standard
        # zero-padded seconds directive and produces the same output on Linux.
        return cls.datetime(diff=diff).strftime('%Y%m%dT%H%M%S')
    @classmethod
    def check_expire_date(cls, expire_date: datetime) -> float:
        # Seconds until expire_date relative to local now (negative if past).
        td = expire_date - datetime.now()
        timestamp = td.total_seconds()
        return timestamp
    @classmethod
    def date_str_micro_sec(cls, diff: int = _TIMEDELTA) -> str:
        # Millisecond precision: format microseconds, then drop the last 3 digits.
        return cls.datetime(diff=diff).strftime('%y-%m-%d %H:%M:%S.%f')[:-3]

View File

@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
"""
@File: extra.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: utility [extra]
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import os
from hashlib import md5
from base64 import b64decode
from base64 import b64encode
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from cryptography.fernet import Fernet # symmetric encryption
# mail test
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from itertools import groupby
from operator import attrgetter
import uuid
from app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT
from app.utils.date_utils import D
from app import models as M
from app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY
async def send_mail(sender, sender_pw, title, recipient, contents_plain, contents_html, cc_list, smtp_host=SMTP_HOST, smtp_port=SMTP_PORT):
    """
    Send an e-mail over an authenticated STARTTLS SMTP session.
    For Google accounts: enable "less secure app access".
    :return:
        None: success
        str message: error
    """
    try:
        # Required fields.
        if not sender:
            raise Exception('invalid sender')
        if not title:
            raise Exception('invalid title')
        if not recipient:
            raise Exception('invalid recipient')
        # Build the (possibly multipart) message.
        msg = MIMEMultipart()
        msg['From'] = sender
        msg['To'] = recipient
        if cc_list:
            msg['Cc'] = ','.join(cc_list)
        msg['Subject'] = title
        if contents_plain:
            msg.attach(MIMEText(contents_plain, 'plain'))
        if contents_html:
            msg.attach(MIMEText(contents_html, 'html'))
        # Deliver: EHLO -> STARTTLS -> EHLO -> LOGIN -> SEND -> QUIT.
        server = smtplib.SMTP(host=smtp_host, port=smtp_port)
        server.ehlo()
        server.starttls()
        server.ehlo()
        server.login(sender, sender_pw)
        server.send_message(msg)
        server.quit()
        return None
    except Exception as e:
        return str(e)
def query_to_groupby(query_result, key, first=False):
    """
    Group a query result list by the value of attribute `key`.
    NOTE: itertools.groupby only merges *adjacent* rows, so the input should
    be sorted by `key`; non-adjacent repeats of a key are appended (or, with
    first=True, ignored after the first occurrence).
    :param query_result: list of query rows
    :param key: attribute name to group on
    :param first: keep only the first row per key instead of a list
    :return: dict
    """
    grouped = dict()
    for value, rows in groupby(query_result, attrgetter(key)):
        rows = list(rows)
        if value not in grouped:
            grouped[value] = rows[0] if first else rows
        elif not first:
            grouped[value].extend(rows)
    return grouped
def query_to_groupby_date(query_result, key):
    """
    Group a query result list by the calendar date (YYYY-MM-DD) of attribute `key`.
    :param query_result: list of query rows
    :param key: attribute name holding a datetime/date value
    :return: dict keyed by 'YYYY-MM-DD' strings
    """
    grouped = dict()
    for value, rows in groupby(query_result, attrgetter(key)):
        day = value.strftime("%Y-%m-%d")
        # Distinct timestamps on the same day land in the same bucket.
        grouped.setdefault(day, []).extend(rows)
    return grouped
class FernetCrypto:
    """Thin wrapper around cryptography.fernet.Fernet accepting str or bytes."""
    def __init__(self, key=FERNET_SECRET_KEY):
        self.key = key
        self.f = Fernet(self.key)
    def _as_bytes(self, data):
        # Accept bytes directly, otherwise UTF-8 encode the string.
        return data if isinstance(data, bytes) else data.encode('utf-8')
    def encrypt(self, data, is_out_string=True):
        token = self.f.encrypt(self._as_bytes(data))
        # Decode to str when the caller wants text output.
        return token.decode('utf-8') if is_out_string is True else token
    def decrypt(self, data, is_out_string=True):
        plain = self.f.decrypt(self._as_bytes(data))
        return plain.decode('utf-8') if is_out_string is True else plain
class AESCryptoCBC:
    """
    AES-CBC helper with PKCS#7 padding and base64 transport encoding.
    """
    def __init__(self, key=AES_CBC_PUBLIC_KEY, iv=AES_CBC_IV):
        # Initial vector: 16 bytes (see the fixed project constants).
        self.key = key
        self.iv = iv
        # Kept for backward compatibility with code touching .crypto;
        # do not reuse it across operations (see _new_cipher).
        self.crypto = AES.new(self.key, AES.MODE_CBC, self.iv)
    def _new_cipher(self):
        # BUGFIX: pycryptodome CBC cipher objects are stateful and single-use.
        # Reusing one object for a second encrypt/decrypt continues the CBC
        # chain from the previous call and corrupts the result, so a fresh
        # cipher is created per operation.
        return AES.new(self.key, AES.MODE_CBC, self.iv)
    def encrypt(self, data):
        # data must be bytes; it is padded to the 16-byte block size.
        enc = self._new_cipher().encrypt(pad(data, AES.block_size))
        return b64encode(enc)
    def decrypt(self, enc):
        # enc is base64 text/bytes holding a multiple of the block size.
        enc = b64decode(enc)
        dec = self._new_cipher().decrypt(enc)
        return unpad(dec, AES.block_size)
class AESCipher:
    """
    AES-CBC helper that prepends the IV to the base64-encoded ciphertext.
    NOTE(review): the IV is a hard-coded constant; for real security it
    should be random per message. decrypt() already reads the IV from the
    payload, so random IVs would work without further changes.
    """
    def __init__(self, key):
        # Key must be 16/24/32 bytes once UTF-8 encoded.
        # self.key = md5(key.encode('utf8')).digest()
        self.key = bytes(key.encode('utf-8'))
    def encrypt(self, data):
        # iv = get_random_bytes(AES.block_size)
        iv = bytes('daooldns12345678'.encode('utf-8'))
        self.cipher = AES.new(self.key, AES.MODE_CBC, iv)
        # BUGFIX: the old code encrypted twice with the same stateful cipher
        # object (the first result was discarded); the returned ciphertext was
        # chained off the discarded one, so the prepended IV no longer matched
        # and decrypt() garbled the first block. Encrypt exactly once.
        ct = self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size))
        return b64encode(iv + ct)
    def decrypt(self, data):
        raw = b64decode(data)
        # First block of the payload is the IV used at encryption time.
        self.cipher = AES.new(self.key, AES.MODE_CBC, raw[:AES.block_size])
        return unpad(self.cipher.decrypt(raw[AES.block_size:]), AES.block_size)
def file_size_check(data):
    """
    Return the size in bytes of an upload's file object without disturbing
    the caller's read position.
    :param data: object exposing a seekable binary `.file` (e.g. UploadFile)
    :return: size in bytes
    """
    stream = data.file
    original_pos = stream.tell()
    stream.seek(0, os.SEEK_END)
    try:
        return stream.tell()
    finally:
        # Restore the position so subsequent reads see the same data.
        stream.seek(original_pos, os.SEEK_SET)

View File

@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
"""
@File: logger.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: logger
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import json
import logging
from datetime import timedelta, datetime
from time import time
from fastapi.requests import Request
from fastapi import Body
from fastapi.logger import logger
logger.setLevel(logging.INFO)
async def api_logger(request: Request, response=None, error=None):
    """
    Emit one structured JSON log line per request.
    Called from middleware with either `response` (success path) or `error`
    (an exception wrapper carrying .status_code and .ex).
    Errors with status >= 500 are logged at ERROR level, everything else at INFO.
    """
    time_format = '%Y/%m/%d %H:%M:%S'
    t = time() - request.state.start
    # BUGFIX: guard response being None alongside error (previously raised
    # AttributeError when both were absent).
    status_code = error.status_code if error else (response.status_code if response else None)
    error_log = None
    user = request.state.user
    if error:
        # Locate where the exception was raised, if the middleware kept a frame.
        if request.state.inspect:
            frame = request.state.inspect
            error_file = frame.f_code.co_filename
            error_func = frame.f_code.co_name
            error_line = frame.f_lineno
        else:
            error_func = error_file = error_line = 'UNKNOWN'
        error_log = dict(
            errorFunc=error_func,
            location='{} line in {}'.format(str(error_line), error_file),
            raised=str(error.__class__.__name__),
            msg=str(error.ex),
        )
    # Mask the account for the log: keep 2 leading chars and the domain.
    masked_account = None
    if user and user.account:
        local, sep, domain = user.account.partition('@')
        if sep:
            masked_account = '**' + local[2:-1] + '*@' + domain
        else:
            # BUGFIX: accounts without '@' used to raise IndexError here.
            masked_account = '**' + local[2:-1] + '*'
    user_log = dict(
        client=request.state.ip,
        user=user.id if user and user.id else None,
        account=masked_account,
    )
    log_dict = dict(
        url=request.url.hostname + request.url.path,
        method=str(request.method),
        statusCode=status_code,
        errorDetail=error_log,
        client=user_log,
        processedTime=str(round(t * 1000, 5)) + 'ms',
        datetimeUTC=datetime.utcnow().strftime(time_format),
        datetimeKST=(datetime.utcnow() + timedelta(hours=9)).strftime(time_format),
    )
    if error and error.status_code >= 500:
        logger.error(json.dumps(log_dict))
    else:
        logger.info(json.dumps(log_dict))

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""
@File: query_utils.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: utility [query]
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from typing import List
def to_dict(model, *args, exclude: List = None):
    """
    Convert an ORM row to a plain dict of its column values.
    :param model: ORM instance (anything exposing __table__.columns)
    :param args: optional whitelist of column names; empty means all columns
    :param exclude: optional blacklist of column names
    :return: dict mapping column name -> value
    """
    return {
        c.name: getattr(model, c.name)
        for c in model.__table__.columns
        if (not args or c.name in args) and (not exclude or c.name not in exclude)
    }

View File

@@ -0,0 +1,2 @@
docker build -t kepco_ai_rbi/rest_ai_engine_control:latest .

View File

@@ -0,0 +1,42 @@
version: '3.7'
services:
mysql:
image: mysql:latest
container_name: mysql
restart: always
environment:
TZ: Asia/Seoul
MYSQL_PORT: 3306
MYSQL_ROOT_PASSWORD: "!ekdnfeldpsdptm1"
MYSQL_DATABASE: "AI_ENGINE_CONTROL"
MYSQL_USER: "aikepco"
MYSQL_PASSWORD: "!ekdnfeldpsdptm1"
healthcheck:
test: "mysql -uroot -p$$MYSQL_ROOT_PASSWORD -e \"SHOW DATABASES;\""
timeout: 5s
retries: 5
command:
- --character-set-server=utf8mb4
- --collation-server=utf8mb4_unicode_ci
volumes:
- /MYSQL/data/:/var/lib/mysql
- /MYSQL/backup:/backupfiles
ports:
- 3306:3306
rest_api:
depends_on:
mysql:
condition: service_healthy
container_name: REST_AI_ENGINE_CONTROL
image: kepco_ai_rbi/rest_ai_engine_control:latest
build:
context: .
dockerfile: ./Dockerfile
environment:
- TZ=Asia/Seoul
volumes:
- ./app:/FAST_API
ports:
- 50520-50522:50520-50522
command: ["uvicorn", "app.main:app", "--host", '0.0.0.0', "--port", "50520"]

View File

@@ -0,0 +1,10 @@
version: '3'
services:
api:
container_name: REST_AI_ENGINE_CONTROL
image: kepco_ai_rbi/rest_ai_engine_control:latest
build:
context: .
ports:
- 50220:50220

View File

@@ -0,0 +1,11 @@
name: REST_AI_ENGINE_CONTROL
channels:
- conda-forge
dependencies:
- python=3.9
- pip
- pip:
- -r requirements.txt

View File

@@ -0,0 +1,222 @@
# Gunicorn configuration file.
#
# Server socket
#
# bind - The socket to bind.
#
# A string of the form: 'HOST', 'HOST:PORT', 'unix:PATH'.
# An IP is a valid HOST.
#
# backlog - The number of pending connections. This refers
# to the number of clients that can be waiting to be
# served. Exceeding this number results in the client
# getting an error when attempting to connect. It should
# only affect servers under significant load.
#
# Must be a positive integer. Generally set in the 64-2048
# range.
#
bind = "0.0.0.0:5000"
backlog = 2048
#
# Worker processes
#
# workers - The number of worker processes that this server
# should keep alive for handling requests.
#
# A positive integer generally in the 2-4 x $(NUM_CORES)
# range. You'll want to vary this a bit to find the best
# for your particular application's work load.
#
# worker_class - The type of workers to use. The default
# sync class should handle most 'normal' types of work
# loads. You'll want to read
# http://docs.gunicorn.org/en/latest/design.html#choosing-a-worker-type
# for information on when you might want to choose one
# of the other worker classes.
#
# A string referring to a Python path to a subclass of
# gunicorn.workers.base.Worker. The default provided values
# can be seen at
# http://docs.gunicorn.org/en/latest/settings.html#worker-class
#
# worker_connections - For the eventlet and gevent worker classes
# this limits the maximum number of simultaneous clients that
# a single process can handle.
#
# A positive integer generally set to around 1000.
#
# timeout - If a worker does not notify the master process in this
# number of seconds it is killed and a new worker is spawned
# to replace it.
#
# Generally set to thirty seconds. Only set this noticeably
# higher if you're sure of the repercussions for sync workers.
# For the non sync workers it just means that the worker
# process is still communicating and is not tied to the length
# of time required to handle a single request.
#
# keepalive - The number of seconds to wait for the next request
# on a Keep-Alive HTTP connection.
#
# A positive integer. Generally set in the 1-5 seconds range.
#
# reload - Restart workers when code changes.
#
# This setting is intended for development. It will cause
# workers to be restarted whenever application code changes.
workers = 3
threads = 3
worker_class = "uvicorn.workers.UvicornWorker"
worker_connections = 1000
timeout = 60
keepalive = 2
reload = True
#
# spew - Install a trace function that spews every line of Python
# that is executed when running the server. This is the
# nuclear option.
#
# True or False
#
spew = False
#
# Server mechanics
#
# daemon - Detach the main Gunicorn process from the controlling
# terminal with a standard fork/fork sequence.
#
# True or False
#
# raw_env - Pass environment variables to the execution environment.
#
# pidfile - The path to a pid file to write
#
# A path string or None to not write a pid file.
#
# user - Switch worker processes to run as this user.
#
# A valid user id (as an integer) or the name of a user that
# can be retrieved with a call to pwd.getpwnam(value) or None
# to not change the worker process user.
#
# group - Switch worker process to run as this group.
#
# A valid group id (as an integer) or the name of a user that
# can be retrieved with a call to pwd.getgrnam(value) or None
# to change the worker processes group.
#
# umask - A mask for file permissions written by Gunicorn. Note that
# this affects unix socket permissions.
#
# A valid value for the os.umask(mode) call or a string
# compatible with int(value, 0) (0 means Python guesses
# the base, so values like "0", "0xFF", "0022" are valid
# for decimal, hex, and octal representations)
#
# tmp_upload_dir - A directory to store temporary request data when
# requests are read. This will most likely be disappearing soon.
#
# A path to a directory where the process owner can write. Or
# None to signal that Python should choose one on its own.
#
daemon = False
pidfile = None
umask = 0
user = None
group = None
tmp_upload_dir = None
#
# Logging
#
# logfile - The path to a log file to write to.
#
# A path string. "-" means log to stdout.
#
# loglevel - The granularity of log output
#
# A string of "debug", "info", "warning", "error", "critical"
#
errorlog = "-"
loglevel = "info"
accesslog = None
access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
#
# Process naming
#
# proc_name - A base to use with setproctitle to change the way
# that Gunicorn processes are reported in the system process
# table. This affects things like 'ps' and 'top'. If you're
# going to be running more than one instance of Gunicorn you'll
# probably want to set a name to tell them apart. This requires
# that you install the setproctitle module.
#
# A string or None to choose a default of something like 'gunicorn'.
#
proc_name = "NotificationAPI"
#
# Server hooks
#
# post_fork - Called just after a worker has been forked.
#
# A callable that takes a server and worker instance
# as arguments.
#
# pre_fork - Called just prior to forking the worker subprocess.
#
# A callable that accepts the same arguments as after_fork
#
# pre_exec - Called just prior to forking off a secondary
# master process during things like config reloading.
#
# A callable that takes a server instance as the sole argument.
#
def post_fork(server, worker):
    # Gunicorn hook: runs right after a worker process has been forked.
    server.log.info("Worker spawned (pid: %s)", worker.pid)
def pre_fork(server, worker):
    # Gunicorn hook: runs just before a worker is forked; nothing to do here.
    pass
def pre_exec(server):
    # Gunicorn hook: runs just before a new master process is exec'd
    # (e.g. during config reloading).
    server.log.info("Forked child, re-executing.")
def when_ready(server):
    # Gunicorn hook: the arbiter is initialized and about to spawn workers.
    server.log.info("Server is ready. Spawning workers")
def worker_int(worker):
    """Gunicorn hook for INT/QUIT: log a stack trace of every live thread."""
    worker.log.info("worker received INT or QUIT signal")
    # get traceback info
    import threading, sys, traceback
    thread_names = {th.ident: th.name for th in threading.enumerate()}
    lines = []
    for thread_id, frame in sys._current_frames().items():
        lines.append("\n# Thread: %s(%d)" % (thread_names.get(thread_id, ""), thread_id))
        for filename, lineno, func, src in traceback.extract_stack(frame):
            lines.append('File: "%s", line %d, in %s' % (filename, lineno, func))
            if src:
                lines.append(" %s" % (src.strip()))
    worker.log.debug("\n".join(lines))
def worker_abort(worker):
    # Gunicorn hook: the worker got SIGABRT (e.g. killed by the timeout watchdog).
    worker.log.info("worker received SIGABRT signal")

View File

@@ -0,0 +1,13 @@
fastapi==0.70.1
uvicorn==0.16.0
pymysql==1.0.2
sqlalchemy==1.4.29
bcrypt==3.2.0
pyjwt==2.3.0
yagmail==0.14.261
boto3==1.20.32
pytest==6.2.5
cryptography
pycryptodomex
pycryptodome
email-validator

View File

@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
"""
@file : test_main.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: 개발시 모듈(main) 실행
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import uvicorn
from app.common.config import conf
if __name__ == '__main__':
    # Dev entry point: run the app with auto-reload on the configured port.
    print('test_main.py run')
    uvicorn.run('app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)

View File

View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
"""
@File: conftest.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: test config
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
import asyncio
import os
from os import path
from typing import List
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.database.schema import Users
from app.main import create_app
from app.database.conn import db, Base
from app.models import UserToken
from app.routes.auth import create_access_token
"""
1. DB 생성
2. 테이블 생성
3. 테스트 코드 작동
4. 테이블 레코드 삭제
"""
@pytest.fixture(scope='session')
def app():
    # Build the FastAPI app once per test session, using the 'test' config.
    os.environ['API_ENV'] = 'test'
    return create_app()
@pytest.fixture(scope='session')
def client(app):
    # Session-wide TestClient; make sure the DB schema exists first.
    # Create tables
    Base.metadata.create_all(db.engine)
    return TestClient(app=app)
@pytest.fixture(scope='function', autouse=True)
def session():
    # Per-test DB session; wipes every table afterwards so tests stay isolated.
    sess = next(db.session())
    yield sess
    clear_all_table_data(
        session=sess,
        metadata=Base.metadata,
        except_tables=[]
    )
    sess.rollback()
@pytest.fixture(scope='function')
def login(session):
    """
    Register a user up-front and return an Authorization header for them.
    :param session:
    :return: dict with a Bearer access-token header
    """
    db_user = Users.create(session=session, email='ryan_test@dingrr.com', pw='123')
    session.commit()
    access_token = create_access_token(data=UserToken.from_orm(db_user).dict(exclude={'pw', 'marketing_agree'}),)
    return dict(Authorization=f'Bearer {access_token}')
def clear_all_table_data(session: Session, metadata, except_tables: List[str] = None):
    """
    Delete every row from every table, with FK checks temporarily disabled.
    :param session: active SQLAlchemy session
    :param metadata: MetaData holding the tables to clear
    :param except_tables: table names to leave untouched (optional)
    """
    # BUGFIX: `table.name not in None` raised TypeError when except_tables
    # was omitted; treat None as "no exceptions".
    skip = except_tables or []
    session.execute('SET FOREIGN_KEY_CHECKS = 0;')
    for table in metadata.sorted_tables:
        if table.name not in skip:
            session.execute(table.delete())
    session.execute('SET FOREIGN_KEY_CHECKS = 1;')
    session.commit()

View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
"""
@File: test_auth.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: test auth.
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from app.database.conn import db
from app.database.schema import Users
def test_registration(client, session):
    """
    Email registration: happy path returns 201 and an Authorization token.
    :param client:
    :param session:
    :return:
    """
    user = dict(email='ryan@dingrr.com', pw='123', name='라이언', phone='01099999999')
    res = client.post('api/auth/register/email', json=user)
    res_body = res.json()
    print(res.json())
    assert res.status_code == 201
    assert 'Authorization' in res_body.keys()
def test_registration_exist_email(client, session):
    """
    Registering an already-registered e-mail must fail with EMAIL_EXISTS (400).
    :param client:
    :param session:
    :return:
    """
    user = dict(email='Hello@dingrr.com', pw='123', name='라이언', phone='01099999999')
    db_user = Users.create(session=session, **user)
    session.commit()
    res = client.post('api/auth/register/email', json=user)
    res_body = res.json()
    assert res.status_code == 400
    assert 'EMAIL_EXISTS' == res_body['msg']

View File

@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""
@File: test_user.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: test user
@section Modify History
- 2022-01-14 오전 11:31 hsj100 base
"""
from app.database.conn import db
from app.database.schema import Users
def test_create_get_apikey(client, session, login):
    """
    Create an API key, then list keys and verify the new key is present.
    :param client:
    :param session:
    :return:
    """
    key = dict(user_memo='ryan__key')
    res = client.post('api/user/apikeys', json=key, headers=login)
    res_body = res.json()
    assert res.status_code == 200
    assert 'secret_key' in res_body
    res = client.get('api/user/apikeys', headers=login)
    res_body = res.json()
    assert res.status_code == 200
    assert 'ryan__key' in res_body[0]['user_memo']

409
ai_engine_const.py Normal file
View File

@@ -0,0 +1,409 @@
# -*- coding: utf-8 -*-
from REST_AI_ENGINE_CONTROL.app import models as M
import project_config
import os
# MQTT
MQTT_HOST = 'localhost'
# source
PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
DAOOL_RTSP = "rtsp://daool:Ekdnfeldpsdptm1@211.63.236.6:52554/axis-media/media.amp"
RTSP = "rtsp://223.171.144.245:8554/cam/0/low"
STREAMS_PATH = PROJECT_PATH + "/DL/wd.streams"
CON_VIDEO_PATH = PROJECT_PATH +"/AI_ENGINE/DATA/CON.mp4"
FACE_VIDEO_PATH = PROJECT_PATH +"/AI_ENGINE/DATA/FR.mov"
PPE_VIDEO_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/PPE.mp4"
WD_VIDEO_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/WD_2.mp4"
if project_config.CONFIG == project_config.CONFIG_AISERVER:
# <aiserver>
#MQTT
MQTT_PORT = 50083
MQTT_USER_ID = 'kepco'
MQTT_USER_PW = '!kepco1234'
#source
RTSP = "rtsp://223.171.144.245:8554/cam/0/low"
CON_SOURCE = CON_VIDEO_PATH
FR_SOURCE = FACE_VIDEO_PATH
PPE_SOURCE = PPE_VIDEO_PATH
WD_SOURCE = WD_VIDEO_PATH
#SFTP
FTP_IP = "106.255.245.242"
FTP_PORT = 2022
FTP_ID = "kepri_if_user"
FTP_PW = "kepri!123"
FTP_LOCATION = "/home/agics-dev/kepri_storage/rndpartners/"
FTP_CON_FILE_NAME = "3"
FTP_FR_FILE_NAME = "1"
FTP_PPE_FILE_NAME = "2"
FTP_WD_FILE_NAME = "5"
FTP_BI_FILE_NAME = "4"
elif project_config.CONFIG == project_config.CONFIG_MG:
#<KEPCO-MG>
# #MQTT
# MQTT_PORT = 1883
# MQTT_USER_ID = 'admin'
# MQTT_USER_PW = 'admin'
MQTT_PORT = 11883
MQTT_USER_ID = 'kepco'
MQTT_USER_PW = '!kepco1234'
# #source
# RTSP = "rtsp://10.20.10.1:8554/cam/0/low"
# RTSP = "rtsp://192.168.39.20:8554/cam/0/low"
RTSP = "rtsp://admin:admin1263!@10.20.10.99:28554/onvif/media?profile=Profile2"
CON_SOURCE = RTSP
FR_SOURCE = RTSP
PPE_SOURCE = RTSP
WD_SOURCE = RTSP
# ----------- <test> ----------
#CON_SOURCE = "rtsp://219.250.188.204:8554/con"
#FR_SOURCE = "rtsp://219.250.188.205:8554/fr"
#PPE_SOURCE = "rtsp://219.250.188.206:8554/ppe"
#WD_SOURCE = "rtsp://219.250.188.207:8554/wd"
# ----------- <test> ----------
if project_config.DEBUG_MODE:
CON_SOURCE = CON_VIDEO_PATH
FR_SOURCE = FACE_VIDEO_PATH
PPE_SOURCE = PPE_VIDEO_PATH
WD_SOURCE = WD_VIDEO_PATH
# #SFTP
FTP_IP = "106.255.245.242"
FTP_PORT = 2022
FTP_ID = "kepri_if_user"
FTP_PW = "kepri!123"
FTP_LOCATION = "/home/agics-dev/kepri_storage/rndpartners/"
# ----------- <test> ----------
# FTP_IP = "211.63.236.6"
# FTP_PORT = 50002
# FTP_ID = "fermat"
# FTP_PW = "1234"
# FTP_LOCATION = "/home/fermat/work/rest_ftp_test"
# ----------- <test> ----------
FTP_CON_FILE_NAME = "3"
FTP_FR_FILE_NAME = "1"
FTP_PPE_FILE_NAME = "2"
FTP_WD_FILE_NAME = "5"
FTP_BI_FILE_NAME = "4"
else :
# <fermat>
# MQTT(dev3)
MQTT_PORT = 1883
MQTT_USER_ID = 'admin'
MQTT_USER_PW = '12341234'
RTSP = DAOOL_RTSP
CON_SOURCE = CON_VIDEO_PATH
FR_SOURCE = FACE_VIDEO_PATH
PPE_SOURCE = PPE_VIDEO_PATH
WD_SOURCE = WD_VIDEO_PATH
# FTP
FTP_IP = "192.168.200.232"
FTP_PORT = 22
FTP_ID = "fermat"
FTP_PW = "1234"
FTP_LOCATION = "/home/fermat/work/rest_ftp_test"
FTP_CON_FILE_NAME = "c"
FTP_FR_FILE_NAME = "a"
FTP_PPE_FILE_NAME = "b"
FTP_WD_FILE_NAME = "e"
FTP_BI_FILE_NAME = "d"
# FTP result path — local staging files for result snapshot images before SFTP upload.
FTP_CON_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/con_setup.jpg'
FTP_FR_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/fr.jpg'
FTP_PPE_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/ppe.jpg'
FTP_WD_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/wd.jpg'
FTP_BI_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/bi.jpg'

# TOPIC — MQTT report topics, one per AI model.
MQTT_CON_TOPIC = '/AI_KEPCO/AI_OD_CON_SETUP_DETECT/REPORT'
MQTT_FR_TOPIC = '/AI_KEPCO/AI_FACE_RECOGNIZE/REPORT'
MQTT_PPE_TOPIC = '/AI_KEPCO/AI_OD_PPE_DETECT/REPORT'
MQTT_PPE_FR_TOPIC = '/AI_KEPCO/AI_OD_PPE_FR_DETECT/REPORT' # test topic
MQTT_WD_TOPIC = '/AI_KEPCO/AI_OD_WORK_DETECT/REPORT'
MQTT_BI_TOPIC = '/AI_KEPCO/AI_BI_DETECT/REPORT'

# AI_MODEL — short string identifiers for each model type.
MODEL_CON = 'CON'
MODEL_PPE = 'PPE'
MODEL_WORK_DETECT = 'WD'
MODEL_FACE_RECOGNIZE = 'FR'
MODEL_BIO_INFO = 'BI'

# YOLO bounding-box coordinate formats.
BBOX_XYXY = 'XYXY'
BBOX_XYWH = 'XYWH'

# WD
# NOTE(review): presumably the number of frames per work-detection evaluation
# window — confirm against the WD model code.
WD_FRAME_COUNT = 80

# FACE_RECOGNITION
# NOTE(review): presumably a face-embedding distance threshold for a match —
# confirm against the FR model code.
FACE_EVOLUTION_DISTANCE = 0.4
# Registered worker reference images (production set).
WORKER1_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco1.jpg"
WORKER2_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco1_1.jpg"
WORKER3_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco1_2.jpg"
if project_config.DEBUG_MODE:
    # Debug set: overrides workers 1-3 and registers additional test faces.
    # NOTE(review): indentation reconstructed from an unindented dump — verify
    # exactly which WORKER*_IMG_PATH assignments belong inside this branch.
    WORKER1_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/facerec_worker1.png"
    WORKER2_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/facerec_worker2.png"
    WORKER3_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/facerec_worker3.png"
    WORKER4_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/jangys_re.jpg"
    WORKER5_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/whangsj.jpg"
    WORKER6_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kimjw_re.jpg"
    WORKER7_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/ksy_re.jpg"
    WORKER8_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/agics.jpg"
    WORKER9_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco1.jpg"
    WORKER10_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_1.jpg"
    WORKER11_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_2.jpg"
    WORKER12_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_3.jpg"
    WORKER13_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_4.jpg"
    WORKER14_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_5.jpg"
    WORKER15_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_6.jpg"
    WORKER16_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_7.jpg"
    WORKER17_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_8.jpg"
    WORKER18_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_9.jpg"
    WORKER19_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_10.jpg"
    WORKER20_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_11.jpg"
    WORKER21_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_12.jpg"
    WORKER22_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_13.jpg"
    # WORKER16_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/yunikim.jpg"
    # WORKER17_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/facerec_worker3.png"

# Engine control signals.
SIGNAL_INFERENCE = 'inference'
SIGNAL_STOP = 'stop'
# Status messages reported when an input source/image changes or is invalid.
SOURCE_CHANGED_MSG = "INPUT SOURCE CHANGED"
IMG_CHANGED_MSG = "INPUT IMAGE CHANGED"
INVALID_IMG_MSG = "INVALID IMAGE SIZE"
DEMO_KEY_NAME_SNAPSHOT_SFTP = "SnapshotSFTP"
# Class-id lists; per the TODO below, ids 2 and 4 are helmet and gloves.
OFF_CLASS_LIST = [2,4,7,9,11,15,17,19]
OFF_TRIGGER_CLASS_LIST = [2,4] #TODO(jwkim):BIXPO test (helmet,gloves)
# Initial AI-engine status payload sent at startup.
# Mirrors the runtime configuration chosen above (SFTP target, per-model video
# sources) plus per-model metadata blocks.
# NOTE(review): the "ri"/"weights" entries look like placeholder defaults —
# confirm they are overwritten at runtime.
AI_ENGINE_INIT = {
    "version": project_config.PROJECT_VERSION,
    "ai_engine_status": "init",
    # SFTP destination used for demo result snapshots.
    "demo": {
        "ftp": {
            "ip": FTP_IP,
            "port": FTP_PORT,
            "id": FTP_ID,
            "pw": FTP_PW,
            "location": FTP_LOCATION,
            "file_con_setup": FTP_CON_FILE_NAME,
            "file_face": FTP_FR_FILE_NAME,
            "file_ppe": FTP_PPE_FILE_NAME,
            "file_wd": FTP_WD_FILE_NAME,
            "file_bi": FTP_BI_FILE_NAME
        }
    },
    # One input stream definition per vision model (CON / FR / PPE / WD).
    "input_video": [
        {
            "name": "CON_video",
            "model": M.AEAIModelType.CON,
            "sn": "serial_no",
            "connect_url": CON_SOURCE,
            "user_id": "user_id",
            "user_pw": "user_pw"
        },
        {
            "name": "FR_video",
            "model": M.AEAIModelType.FR,
            "sn": "serial_no",
            "connect_url": FR_SOURCE,
            "user_id": "user_id",
            "user_pw": "user_pw"
        },
        {
            "name": "PPE_video",
            "model": M.AEAIModelType.PPE,
            "sn": "serial_no",
            "connect_url": PPE_SOURCE,
            "user_id": "user_id",
            "user_pw": "user_pw"
        },
        {
            "name": "WD_video_1",
            "model": M.AEAIModelType.WORK,
            "sn": "serial_no",
            "connect_url": WD_SOURCE,
            "user_id": "user_id",
            "user_pw": "user_pw"
        }
    ],
    # Bio-information input device (placeholder values).
    "input_bi": {
        "name": "device_name",
        "model": "model_name",
        "sn": "serial_no",
        "connect_url": "connect_url",
        "user_id": "user_id",
        "user_pw": "user_pw",
        "topic": "topic"
    },
    # Construction-setup detection model metadata.
    "con_model_info": {
        "name": "CON",
        "version": "20220101",
        "status": "init",
        "ri": {
            "construction_code": "D54",
            "work_no": 3,
            "work_define_ri": 0.82,
            "ri_parameter_list": [
                {
                    "name": "작업자 숙련도",
                    "ratio": 0.6
                },
                {
                    "name": "작업자 교육레벨",
                    "ratio": 0.5
                }
            ],
            "evaluation_work_ri": 0
        },
        "weights": [
            {
                "id": 0,
                "filename": "index_78.pt",
                "version": "0.1",
                "date": "2023-01-31T10:12:17",
                "model": "small",
                "nc": 26
            }
        ],
        "mode": "con",
        "crop_images": False
    },
    # Face-recognition model metadata.
    "fr_model_info": {
        "name": "FR",
        "version": "20220101",
        "status": "init",
        "ri": {
            "construction_code": "D54",
            "work_no": 3,
            "work_define_ri": 0.82,
            "ri_parameter_list": [
                {
                    "name": "작업자 숙련도",
                    "ratio": 0.6
                },
                {
                    "name": "작업자 교육레벨",
                    "ratio": 0.5
                }
            ],
            "evaluation_work_ri": 0
        }
    },
    # Personal-protective-equipment detection model metadata.
    "ppe_model_info": {
        "name": "PPE",
        "version": "20220101",
        "status": "init",
        "ri": {
            "construction_code": "D54",
            "work_no": 3,
            "work_define_ri": 0.82,
            "ri_parameter_list": [
                {
                    "name": "작업자 숙련도",
                    "ratio": 0.6
                },
                {
                    "name": "작업자 교육레벨",
                    "ratio": 0.5
                }
            ],
            "evaluation_work_ri": 0
        },
        "weights": [
            {
                "id": 0,
                "filename": "index_78.pt",
                "version": "0.1",
                "date": "2023-01-31T10:12:17",
                "model": "small",
                "nc": 26
            }
        ],
        "mode": "ppe",
        "crop_images": False
    },
    # Work-detection model metadata.
    "wd_model_info": {
        "name": "WORK",
        "version": "20220101",
        "status": "init",
        "ri": {
            "construction_code": "D54",
            "work_no": 3,
            "work_define_ri": 0.82,
            "ri_parameter_list": [
                {
                    "name": "작업자 숙련도",
                    "ratio": 0.6
                },
                {
                    "name": "작업자 교육레벨",
                    "ratio": 0.5
                }
            ],
            "evaluation_work_ri": 0
        },
        "weights": [
            {
                "id": 0,
                "filename": "index_78.pt",
                "version": "0.1",
                "date": "2023-01-31T10:12:17",
                "model": "small",
                "nc": 26
            }
        ],
        "mode": "work",
        "crop_images": False
    },
    # Bio-information model metadata.
    "bi_model_info": {
        "name": "BI",
        "version": "20220101",
        "status": "init",
        "ri": {
            "construction_code": "D54",
            "work_no": 3,
            "work_define_ri": 0.82,
            "ri_parameter_list": [
                {
                    "name": "작업자 숙련도",
                    "ratio": 0.6
                },
                {
                    "name": "작업자 교육레벨",
                    "ratio": 0.5
                }
            ],
            "evaluation_work_ri": 0
        }
    }
}

27
ai_engine_main.py Normal file
View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""
@file : ai_engine_main.py
@author: hsj100
@license: A2TEC & DAOOLDNS
@brief: Development entry point that launches the REST server module.
@section Modify History
    - 2022-01-14 11:31 AM  hsj100  base
"""
import os, sys
import uvicorn

# AI_ENGINE_PATH = "/AI_ENGINE"
REST_SERVER_PATH = "/REST_AI_ENGINE_CONTROL"
# Make the REST server package importable when running from this file's directory.
# sys.path.append(os.path.abspath(os.path.dirname(__file__)) + AI_ENGINE_PATH)
sys.path.append(os.path.abspath(os.path.dirname(__file__)) + REST_SERVER_PATH)

from REST_AI_ENGINE_CONTROL.app.common.config import conf

if __name__ == '__main__':
    print('main.py run')
    # os.system(". ./rtsp_start.sh")
    # reload=True enables auto-reload on code change (development only).
    uvicorn.run('REST_AI_ENGINE_CONTROL.app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True)

47
demo_bi.py Normal file
View File

@@ -0,0 +1,47 @@
import cv2
import os
import paramiko
# Local staging path for the BI snapshot (duplicated inside bi_snap_shot — kept in sync manually).
# NOTE(review): f-string has no placeholders; a plain string literal would do.
LOCAL_PATH = f"/home/kepco/daooldns/KEPCO_AI_RBI_SIM/ENGINE/AI_ENGINE/DATA/ftp_data/local.jpg"
def bi_snap_shot():
    """Grab one frame from the BI RTSP camera, save it locally, and upload it via SFTP.

    Side effects: deletes any previous local snapshot, writes a new JPEG to
    ``local_path``, and calls :func:`_bi_sftp_upload`.
    """
    # Alternate camera endpoints used during testing (10.20.10.1 / 192.168.39.20):
    #   rtsp://10.20.10.1:8554/cam/0/low
    rtsp = "rtsp://admin:admin1263!@10.20.10.99:554/onvif/media?profile=Profile2"
    local_path = f"/home/kepco/daooldns/KEPCO_AI_RBI_SIM/ENGINE/AI_ENGINE/DATA/ftp_data/local.jpg"
    if os.path.exists(local_path):
        os.remove(local_path)
    input_movie = cv2.VideoCapture(rtsp)
    try:
        ret, frame = input_movie.read()
    finally:
        # BUGFIX: the capture was never released, leaking the RTSP connection.
        input_movie.release()
    if not ret or frame is None:
        # BUGFIX: previously an unreadable stream passed frame=None straight to
        # cv2.imwrite, which raises; bail out with a message instead.
        print("bi snapshot failed: could not read a frame from the stream")
        return
    print(local_path)
    cv2.imwrite(local_path, frame)
    _bi_sftp_upload()
    cv2.destroyAllWindows()
    print("bi uploaded")
def _bi_sftp_upload():
    """Upload the local BI snapshot to the KEPRI SFTP server.

    Returns:
        str: the remote path on success, or "" on any failure.

    NOTE(review): the actual sftp.put call is commented out, so this currently
    only opens and closes the connection; hard-coded credentials should move to
    configuration.
    """
    transport = None  # typo fix: was "transprot"
    try:
        IP = "106.255.245.242"
        transport = paramiko.Transport((IP, 2022))
        transport.connect(username="kepri_if_user", password="kepri!123")
        sftp = paramiko.SFTPClient.from_transport(transport)
        remotepath = "/home/agics-dev/kepri_storage/rndpartners/" + os.sep + "remote" + '.jpg'
        #sftp.put(LOCAL_PATH, remotepath)
        sftp.close()
        return remotepath
    except Exception as e:
        # BUGFIX: failures were silently swallowed; at least report them before
        # returning the original "" sentinel.
        print(f"bi sftp upload failed: {e}")
        return ""
    finally:
        # BUGFIX: the transport was leaked whenever an exception fired before
        # the explicit close(); always release it here.
        if transport is not None:
            transport.close()
if __name__ == '__main__':
    # Manual smoke test: capture one frame and upload it.
    bi_snap_shot()

16
main.py Normal file
View File

@@ -0,0 +1,16 @@
# Sample Python script.
# Press Ctrl+F5 to run it, or replace it with your own code.
# Press Shift twice to search everywhere for classes, files, tool windows, actions, and settings.
def print_hi(name):
    """Print a greeting for *name* (PyCharm starter-script sample)."""
    # Set a breakpoint on the line below to try out the debugger.
    greeting = f'Hi, {name}'
    print(greeting)
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    print_hi('PyCharm')

# See PyCharm help at https://www.jetbrains.com/help/pycharm/

15
project_config.py Normal file
View File

@@ -0,0 +1,15 @@
# Project-wide build/deployment configuration flags.
PROJECT_VERSION = 0.5

# Deployment target identifiers.
CONFIG_AISERVER = "AISERVER"
CONFIG_MG = "MG"
CONFIG_FERMAT = "FERMAT"

# Active deployment target (switch by commenting/uncommenting).
# CONFIG = CONFIG_AISERVER
CONFIG = CONFIG_MG
# CONFIG = CONFIG_FERMAT

# When True, video sources fall back to local test clips and debug worker
# face images are registered instead of the production set.
DEBUG_MODE = False
# NOTE(review): usage of these two flags is not visible here — presumably they
# gate SFTP snapshot upload and face-recognition image upload; confirm in the
# modules that read them.
SFTP_UPLOAD = False
FR_UPLOAD = True

54
requirements.txt Normal file
View File

@@ -0,0 +1,54 @@
#1. ADTK (python>3.5)
adtk
#2. Face_recognition (python>3.3)
# cmake ,dlib 설치 필요
face-recognition
opencv-python
#3. Yolov5
# torch,torchvision 별도 설치
# --- base ---
gitpython
ipython # interactive notebook
matplotlib>=3.2.2
numpy>=1.18.5
opencv-python>=4.1.1
Pillow>=7.1.2
psutil # system resources
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
thop>=0.1.1 # FLOPs computation
#torch>=1.7.0 # see https://pytorch.org/get-started/locally (recommended)
#torchvision>=0.8.1
tqdm>=4.64.0
# --- Logging ---
tensorboard>=2.4.1
# --- Plotting ---
pandas>=1.1.4
seaborn>=0.11.0
#4. REST server
fastapi==0.70.1
uvicorn==0.16.0
pymysql==1.0.2
sqlalchemy==1.4.29
bcrypt==3.2.0
pyjwt==2.3.0
yagmail==0.14.261
boto3==1.20.32
pytest==6.2.5
cryptography
pycryptodomex
pycryptodome
email-validator
python-multipart
#5. MQTT
paho-mqtt
#6. FTP
paramiko

1
rtsp_start.sh Normal file
View File

@@ -0,0 +1 @@
# Publish a PTZ command (topic "ptz", payload "1,0.95,0") to the local MQTT broker on port 11883.
# NOTE(review): despite the filename, this only publishes an MQTT message — it does not start RTSP; confirm intent.
mosquitto_pub -p 11883 -h localhost -t ptz -m 1,0.95,0 -u admin -P admin