diff --git a/AI_ENGINE/DATA/4NI_0400.MOV_20230104_101349.426.png b/AI_ENGINE/DATA/4NI_0400.MOV_20230104_101349.426.png new file mode 100644 index 0000000..2a9ab72 Binary files /dev/null and b/AI_ENGINE/DATA/4NI_0400.MOV_20230104_101349.426.png differ diff --git a/AI_ENGINE/DATA/CON.mp4 b/AI_ENGINE/DATA/CON.mp4 new file mode 100644 index 0000000..1c14545 Binary files /dev/null and b/AI_ENGINE/DATA/CON.mp4 differ diff --git a/AI_ENGINE/DATA/FR.mov b/AI_ENGINE/DATA/FR.mov new file mode 100644 index 0000000..35debde Binary files /dev/null and b/AI_ENGINE/DATA/FR.mov differ diff --git a/AI_ENGINE/DATA/PPE.mp4 b/AI_ENGINE/DATA/PPE.mp4 new file mode 100644 index 0000000..7e6c377 Binary files /dev/null and b/AI_ENGINE/DATA/PPE.mp4 differ diff --git a/AI_ENGINE/DATA/WD_2.mp4 b/AI_ENGINE/DATA/WD_2.mp4 new file mode 100644 index 0000000..0fbb931 Binary files /dev/null and b/AI_ENGINE/DATA/WD_2.mp4 differ diff --git a/AI_ENGINE/DATA/agics.jpg b/AI_ENGINE/DATA/agics.jpg new file mode 100644 index 0000000..e11f661 Binary files /dev/null and b/AI_ENGINE/DATA/agics.jpg differ diff --git a/AI_ENGINE/DATA/facerec_worker1.png b/AI_ENGINE/DATA/facerec_worker1.png new file mode 100644 index 0000000..9376d11 Binary files /dev/null and b/AI_ENGINE/DATA/facerec_worker1.png differ diff --git a/AI_ENGINE/DATA/facerec_worker2.png b/AI_ENGINE/DATA/facerec_worker2.png new file mode 100644 index 0000000..63e848d Binary files /dev/null and b/AI_ENGINE/DATA/facerec_worker2.png differ diff --git a/AI_ENGINE/DATA/facerec_worker3.png b/AI_ENGINE/DATA/facerec_worker3.png new file mode 100644 index 0000000..89fbdd8 Binary files /dev/null and b/AI_ENGINE/DATA/facerec_worker3.png differ diff --git a/AI_ENGINE/DATA/ftp_data/bi.jpg b/AI_ENGINE/DATA/ftp_data/bi.jpg new file mode 100644 index 0000000..31a033f Binary files /dev/null and b/AI_ENGINE/DATA/ftp_data/bi.jpg differ diff --git a/AI_ENGINE/DATA/ftp_data/con_setup.jpg b/AI_ENGINE/DATA/ftp_data/con_setup.jpg new file mode 100644 index 0000000..3f18f8f Binary files /dev/null and b/AI_ENGINE/DATA/ftp_data/con_setup.jpg differ diff --git a/AI_ENGINE/DATA/ftp_data/fr.jpg b/AI_ENGINE/DATA/ftp_data/fr.jpg new file mode 100644 index 0000000..b9a8b33 Binary files /dev/null and b/AI_ENGINE/DATA/ftp_data/fr.jpg differ diff --git a/AI_ENGINE/DATA/ftp_data/local.jpg b/AI_ENGINE/DATA/ftp_data/local.jpg new file mode 100644 index 0000000..b0e0d8c Binary files /dev/null and b/AI_ENGINE/DATA/ftp_data/local.jpg differ diff --git a/AI_ENGINE/DATA/ftp_data/ppe.jpg b/AI_ENGINE/DATA/ftp_data/ppe.jpg new file mode 100644 index 0000000..08a3ae7 Binary files /dev/null and b/AI_ENGINE/DATA/ftp_data/ppe.jpg differ diff --git a/AI_ENGINE/DATA/ftp_data/wd.jpg b/AI_ENGINE/DATA/ftp_data/wd.jpg new file mode 100644 index 0000000..d93b6d5 Binary files /dev/null and b/AI_ENGINE/DATA/ftp_data/wd.jpg differ diff --git a/AI_ENGINE/DATA/jangys.jpg b/AI_ENGINE/DATA/jangys.jpg new file mode 100644 index 0000000..775894d Binary files /dev/null and b/AI_ENGINE/DATA/jangys.jpg differ diff --git a/AI_ENGINE/DATA/jangys_re.jpg b/AI_ENGINE/DATA/jangys_re.jpg new file mode 100644 index 0000000..b2836b7 Binary files /dev/null and b/AI_ENGINE/DATA/jangys_re.jpg differ diff --git a/AI_ENGINE/DATA/kepco1.jpg b/AI_ENGINE/DATA/kepco1.jpg new file mode 100644 index 0000000..3bbc8a3 Binary files /dev/null and b/AI_ENGINE/DATA/kepco1.jpg differ diff --git a/AI_ENGINE/DATA/kepco1_1.jpg b/AI_ENGINE/DATA/kepco1_1.jpg new file mode 100644 index 0000000..09af614 Binary files /dev/null and b/AI_ENGINE/DATA/kepco1_1.jpg differ diff --git a/AI_ENGINE/DATA/kepco1_2.jpg b/AI_ENGINE/DATA/kepco1_2.jpg new file mode 100644 index 0000000..c5b53fc Binary files /dev/null and b/AI_ENGINE/DATA/kepco1_2.jpg differ diff --git a/AI_ENGINE/DATA/kepco2.jpg b/AI_ENGINE/DATA/kepco2.jpg new file mode 100644 index 0000000..c51a290 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_1.jpg b/AI_ENGINE/DATA/kepco2_1.jpg new file mode 100644 index 0000000..aaaab44 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_1.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_10.jpg b/AI_ENGINE/DATA/kepco2_10.jpg new file mode 100755 index 0000000..49c2d34 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_10.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_11.jpg b/AI_ENGINE/DATA/kepco2_11.jpg new file mode 100755 index 0000000..a710890 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_11.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_12.jpg b/AI_ENGINE/DATA/kepco2_12.jpg new file mode 100755 index 0000000..01a27be Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_12.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_13.jpg b/AI_ENGINE/DATA/kepco2_13.jpg new file mode 100755 index 0000000..ef8e167 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_13.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_2.jpg b/AI_ENGINE/DATA/kepco2_2.jpg new file mode 100644 index 0000000..da28b56 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_2.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_3.jpg b/AI_ENGINE/DATA/kepco2_3.jpg new file mode 100644 index 0000000..4c12657 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_3.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_4.jpg b/AI_ENGINE/DATA/kepco2_4.jpg new file mode 100755 index 0000000..26738e2 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_4.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_5.jpg b/AI_ENGINE/DATA/kepco2_5.jpg new file mode 100755 index 0000000..947639e Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_5.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_6.jpg b/AI_ENGINE/DATA/kepco2_6.jpg new file mode 100755 index 0000000..94dfbbf Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_6.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_7.jpg b/AI_ENGINE/DATA/kepco2_7.jpg new file mode 100755 index 0000000..1502759 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_7.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_8.jpg b/AI_ENGINE/DATA/kepco2_8.jpg new file mode 100755 index 0000000..f2b58c0 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_8.jpg differ diff --git a/AI_ENGINE/DATA/kepco2_9.jpg b/AI_ENGINE/DATA/kepco2_9.jpg new file mode 100755 index 0000000..5034307 Binary files /dev/null and b/AI_ENGINE/DATA/kepco2_9.jpg differ diff --git a/AI_ENGINE/DATA/kimjw.jpg b/AI_ENGINE/DATA/kimjw.jpg new file mode 100644 index 0000000..999d858 Binary files /dev/null and b/AI_ENGINE/DATA/kimjw.jpg differ diff --git a/AI_ENGINE/DATA/kimjw_re.jpg b/AI_ENGINE/DATA/kimjw_re.jpg new file mode 100644 index 0000000..72aecd0 Binary files /dev/null and b/AI_ENGINE/DATA/kimjw_re.jpg differ diff --git a/AI_ENGINE/DATA/ksy_re.jpg b/AI_ENGINE/DATA/ksy_re.jpg new file mode 100644 index 0000000..829e3c1 Binary files /dev/null and b/AI_ENGINE/DATA/ksy_re.jpg differ diff --git a/AI_ENGINE/DATA/whangsj.jpg b/AI_ENGINE/DATA/whangsj.jpg new file mode 100644 index 0000000..0af9843 Binary files /dev/null and b/AI_ENGINE/DATA/whangsj.jpg differ diff --git a/AI_ENGINE/DATA/yunikim.jpg b/AI_ENGINE/DATA/yunikim.jpg new file mode 100755 index 0000000..0355ac2 Binary files /dev/null and b/AI_ENGINE/DATA/yunikim.jpg differ diff --git a/AI_ENGINE/demo_utils.py b/AI_ENGINE/demo_utils.py new file mode 100644 index 0000000..0f29496 --- /dev/null +++ b/AI_ENGINE/demo_utils.py @@ -0,0 +1,52 @@ +import threading +import ai_engine_const as AI_CONST +import cv2 +import os +import paramiko + +#demo +global DEMO_WD_BI_CONST +DEMO_WD_BI_CONST = 0 + +demo_lock = threading.Lock() +def demo_wd_bi(i): + global DEMO_WD_BI_CONST + demo_lock.acquire() + + DEMO_WD_BI_CONST = i + + demo_lock.release() + print(DEMO_WD_BI_CONST) + +def bi_snap_shot(): + rtsp = AI_CONST.RTSP + local_path = AI_CONST.FTP_BI_RESULT + + if os.path.exists(local_path): + os.remove(local_path) + + input_movie = cv2.VideoCapture(rtsp) + + ret, frame = input_movie.read() + print(local_path) + cv2.imwrite(local_path,frame) + _bi_sftp_upload() + cv2.destroyAllWindows() + print(f"bi uploaded") + +def _bi_sftp_upload(): + try: + transprot = paramiko.Transport((AI_CONST.FTP_IP,AI_CONST.FTP_PORT)) + transprot.connect(username = AI_CONST.FTP_ID, password = AI_CONST.FTP_PW) + sftp = paramiko.SFTPClient.from_transport(transprot) + + remotepath = AI_CONST.FTP_LOCATION + os.sep + AI_CONST.FTP_BI_FILE_NAME + '.jpg' + + #sftp.put(AI_CONST.FTP_BI_RESULT, remotepath) + + sftp.close() + transprot.close() + + return remotepath + except Exception as e: + return "" \ No newline at end of file diff --git a/AI_ENGINE/instance_queue.py b/AI_ENGINE/instance_queue.py new file mode 100644 index 0000000..6713900 --- /dev/null +++ b/AI_ENGINE/instance_queue.py @@ -0,0 +1,693 @@ +# -*- coding: utf-8 -*- +""" +@file : instance_queue.py +@author: jwkim +@license: A2TEC & DAOOLDNS + +@section Modify History +- 2023-01-22 오전 11:31 jwkim base + +""" + +import os, sys +import time +import threading +from queue import Queue + +DL_BASE_PATH = os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) +sys.path.append(DL_BASE_PATH) +sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) # const + +import json + +import ai_engine_const as AI_CONST +from mqtt_publish import client +from DL import d2_od_detect +from DL.FR.d2_face_detect import FaceDetect + +from REST_AI_ENGINE_CONTROL.app.utils.date_utils import D +from REST_AI_ENGINE_CONTROL.app import models as M +from DL.d2_od_detect import D2_ROOT + + +class ThreadValues: + """ + thread간 공유 데이터 + """ + STATUS_START = 'start' + STATUS_COMPLETE = 'complete' + STATUS_NEW = 'new' + STATUS_TIMEOUT = 'timeout' + STATUS_STOP = 'stop' + STATUS_ERROR = '[error] ai model error : ' + STATUS_NONE = 'None' + STATUS_DETECT = 'detect' + + date_utill = D() + + def __init__(self): + self.current_status = self.STATUS_NONE + self.result = [] + self.report_unit_result = [] + self.frame_result = {} + self.timeout_status = False + + self.wd_ri = 0 + + +class WorkQueue: + # max_workers = 2 + # current_workers = 0 + + message_queue = Queue() + + # event_thread = threading.Event() + + def __init__(self): + # self.current_workers = 0 + self.message_queue = Queue() + # self.lock = threading.Lock() + # self.event_thread = threading.Event() + receiver_start_thread = threading.Thread(target=self.classifier, daemon=True) + receiver_start_thread.start() + # event_check_thread = threading.Thread(target=self.event, daemon=True).start() + + def sender(self, data): + self.message_queue.put(data) + + # TODO(jwkim): 이벤트 처리 + # def event(self): + # if self.current_workers <= self.max_workers: + # print("-----------------event--------------------",self.current_workers) + # self.event_thread.set() + # else : + # pass + + def classifier(self): + """ + sender를 통해 받은 메세지 처리 + ai_model을 통해 해당 model queue에 전송 + """ + while True: + queue_data = self.message_queue.get() + + # TODO(jwkim): 이벤트 worker 체킹 + # self.event_thread.wait() + + # manager.lock.acquire() + # manager.current_workers += 1 + # manager.lock.release() + + if queue_data['ai_model'] == AI_CONST.MODEL_CON: + CON.con_sender(queue_data) + elif queue_data['ai_model'] == AI_CONST.MODEL_PPE: + PPE.ppe_sender(queue_data) + elif queue_data['ai_model'] == AI_CONST.MODEL_WORK_DETECT: + WD.wd_sender(queue_data) + elif queue_data['ai_model'] == AI_CONST.MODEL_FACE_RECOGNIZE: + FR.fr_sender(queue_data) + elif queue_data['ai_model'] == AI_CONST.MODEL_BIO_INFO: + pass + + +class PPEManager: + """ + PPE Manager + """ + TYPE_DETECT_SEQ = 0 # 순서적 처리 + TYPE_DETECT_PAL = 1 + + def __init__(self, detect_type: int = TYPE_DETECT_SEQ): + self.ppe_queue = Queue() + + self.detect_type = detect_type + + self.ppe_receiver_thread = threading.Thread(target=self.ppe_receiver, daemon=True, name='ppe_receiver') + self.ppe_receiver_thread.start() + + self.worker_event = threading.Event() # worker check + self.stop_event = threading.Event() # inference stop trigger + self.timeout_event = threading.Event() # timeout check + + self.current_worker = '' + + def ppe_sender(self, data): + self.ppe_queue.put(data) + + def ppe_receiver(self): + """ + ppe_queue를 통해 데이터를 받아서 inference 시작, 정지 실행 + + ppe_thread가 실행중인 상황에서 또다른 데이터가 들어올시 + ppe_thread 종료후 새로운 ppe_thread 실행 + + 병렬처리 미구현 + """ + while True: + queue_data = self.ppe_queue.get() + if queue_data: + if self.detect_type == PPEManager.TYPE_DETECT_SEQ: + + # start inference + if queue_data["signal"] == AI_CONST.SIGNAL_INFERENCE: + if self.current_worker: + if self.current_worker.is_alive(): + self.stop_event.set() + self.worker_event.wait() + + self.worker_event.clear() + ppe_thread = threading.Thread(target=self.ppe_start, daemon=True, args=(queue_data,), + name='ppe_thread').start() + + # stop inference + elif queue_data["signal"] == AI_CONST.SIGNAL_STOP: + if self.current_worker: + if self.current_worker.is_alive(): + self.stop_event.set() + self.worker_event.wait() + self.worker_event.clear() + + def ppe_start(self, data): + """ + inference 와 timeout 관리 + timeout 발생시 inference 종료 + inference 종료시 timeout_thread와 같이 종료 + :param data: queue_data + """ + + try: + share_value = ThreadValues() + + data['argument']['source'] = self._input_source(data['engine_info'].input_video) + + ppe_detect = d2_od_detect.PPEDetect( + request=data['request'], + engine_info=data['engine_info'], + thread_value=share_value, + yolo_argument=data['argument'], + worker_event=self.worker_event, + stop_event=self.stop_event, + timeout_event=self.timeout_event, + queue_info=self.ppe_queue, + ) + self.current_worker = threading.Thread(target=ppe_detect.run, daemon=True, kwargs=data['argument'], + name='ppe_inference') + self.current_worker.start() + + timeout_value = data['request'].limit_time_min * 60 + timeout_thread = threading.Thread(target=self._timeout, daemon=True, args=(timeout_value, share_value,), + name='timeout') + timeout_thread.start() + + # TODO(jwkim) : worker check + # if manager.current_workers > 0: + # manager.lock.acquire() + # manager.current_workers -= 1 + # manager.lock.release() + + except Exception as e: + result_message = { + "datetime": share_value.date_utill.date_str_micro_sec(), + "status": share_value.STATUS_ERROR + str(e), + "result": share_value.result + } + client.publish(AI_CONST.MQTT_PPE_TOPIC, json.dumps(result_message), 0) + + if not self.timeout_event.is_set(): + self.timeout_event.set() + self.timeout_event.clear() + + self.stop_event.clear() + self.worker_event.set() + + def _timeout(self, timeout_value, share_value): + """ + time out 관련 처리 + -> timeout 발생시 timeout_event는 set 상태가 아님 + + :param timeout_value: 타임아웃 시간 (초) + :param share_value: 스레드간 공유 변수 + """ + self.timeout_event.wait(timeout=timeout_value) + + # 외부에서 set + if self.timeout_event.is_set(): + share_value.timeout_status = False + + # stop + elif share_value.current_status == share_value.STATUS_STOP: + share_value.timeout_status = False + + # time out + else: + share_value.timeout_status = True + + def _input_source(self, args): + """ + 입력source를 받아서 parsing + :param args: input_video 정보 + :return: parsing된 정보 + """ + for i in args: + if i.model == M.AEAIModelType.PPE: + return i.connect_url + + +class CONManager: + """ + CON_SETUP Manager + """ + TYPE_DETECT_SEQ = 0 # 순서적 처리 + TYPE_DETECT_PAL = 1 + + def __init__(self, detect_type: int = TYPE_DETECT_SEQ): + self.con_queue = Queue() + + self.detect_type = detect_type + + self.con_receiver_thread = threading.Thread(target=self.con_receiver, daemon=True, name='con_receiver') + self.con_receiver_thread.start() + + self.worker_event = threading.Event() # worker check + self.stop_event = threading.Event() # inference stop trigger + self.timeout_event = threading.Event() # timeout check + + self.current_worker = '' + + def con_sender(self, data): + self.con_queue.put(data) + + def con_receiver(self): + """ + con_queue를 통해 데이터를 받아서 inference 시작, 정지 실행 + + con_thread가 실행중인 상황에서 또다른 데이터가 들어올시 + con_thread 종료후 새로운 con_thread 실행 + + 병렬처리 미구현 + """ + while True: + queue_data = self.con_queue.get() + if queue_data: + if self.detect_type == CONManager.TYPE_DETECT_SEQ: + + # start inference + if queue_data["signal"] == AI_CONST.SIGNAL_INFERENCE: + if self.current_worker: + if self.current_worker.is_alive(): + self.stop_event.set() + self.worker_event.wait() + + self.worker_event.clear() + con_thread = threading.Thread(target=self.con_start, daemon=True, args=(queue_data,), + name='con_thread').start() + + # stop inference + elif queue_data["signal"] == AI_CONST.SIGNAL_STOP: + if self.current_worker: + if self.current_worker.is_alive(): + self.stop_event.set() + self.worker_event.wait() + self.worker_event.clear() + + def con_start(self, data): + """ + inference 와 timeout 관리 + timeout 발생시 inference 종료 + inference 종료시 timeout_thread와 같이 종료 + :param data: queue_data + """ + + try: + share_value = ThreadValues() + + data['argument']['source'] = self._input_source(data['engine_info'].input_video) + + con_detect = d2_od_detect.CONDetect( + request=data['request'], + engine_info=data['engine_info'], + thread_value=share_value, + yolo_argument=data['argument'], + worker_event=self.worker_event, + stop_event=self.stop_event, + timeout_event=self.timeout_event, + queue_info=self.con_queue, + ) + self.current_worker = threading.Thread(target=con_detect.run, daemon=True, kwargs=data['argument'], + name='con_inference') + self.current_worker.start() + + timeout_value = data['request'].limit_time_min * 60 + timeout_thread = threading.Thread(target=self._timeout, daemon=True, args=(timeout_value, share_value,), + name='timeout') + timeout_thread.start() + + # TODO(jwkim) : worker check + # if manager.current_workers > 0: + # manager.lock.acquire() + # manager.current_workers -= 1 + # manager.lock.release() + + except Exception as e: + result_message = { + "datetime": share_value.date_utill.date_str_micro_sec(), + "status": share_value.STATUS_ERROR + str(e), + "result": share_value.result + } + client.publish(AI_CONST.MQTT_CON_TOPIC, json.dumps(result_message), 0) + + if not self.timeout_event.is_set(): + self.timeout_event.set() + self.timeout_event.clear() + + self.stop_event.clear() + self.worker_event.set() + + def _timeout(self, timeout_value, share_value): + """ + time out 관련 처리 + -> timeout 발생시 timeout_event는 set 상태가 아님 + + :param timeout_value: 타임아웃 시간 (초) + :param share_value: 스레드간 공유 변수 + """ + self.timeout_event.wait(timeout=timeout_value) + + # 외부에서 set + if self.timeout_event.is_set(): + share_value.timeout_status = False + + # stop + elif share_value.current_status == share_value.STATUS_STOP: + share_value.timeout_status = False + + # time out + else: + share_value.timeout_status = True + + def _input_source(self, args): + """ + 입력source를 받아서 parsing + :param args: input_video 정보 + :return: parsing된 정보 + """ + for i in args: + if i.model == M.AEAIModelType.CON: + return i.connect_url + +class WDManager: + """ + WorkDetect Manager + """ + TYPE_DETECT_SEQ = 0 # 순서적 처리 + TYPE_DETECT_PAL = 1 + + def __init__(self, detect_type: int = TYPE_DETECT_SEQ): + self.wd_queue = Queue() + + self.detect_type = detect_type + + self.wd_receiver_thread = threading.Thread(target=self.wd_receiver, daemon=True, name='wd_receiver') + self.wd_receiver_thread.start() + + self.worker_event = threading.Event() # worker check + self.stop_event = threading.Event() # inference stop trigger + # self.timeout_event = threading.Event() # timeout check + + self.lock = threading.Lock() + + self.current_worker = '' + + def wd_sender(self, data): + self.wd_queue.put(data) + + def wd_receiver(self): + """ + wd_queue를 통해 데이터를 받아서 inference 시작, 정지 실행 + + wd_thread가 실행중인 상황에서 또다른 데이터가 들어올시 + wd_thread 종료후 새로운 wd_thread 실행 + + 병렬처리 미구현 + """ + while True: + queue_data = self.wd_queue.get() + if queue_data: + if self.detect_type == WDManager.TYPE_DETECT_SEQ: + + # start inference + if queue_data["signal"] == AI_CONST.SIGNAL_INFERENCE: + if self.current_worker: + if self.current_worker.is_alive(): + self.stop_event.set() + self.worker_event.wait() + + self.worker_event.clear() + wd_thread = threading.Thread(target=self.wd_start, daemon=True, args=(queue_data,), + name='wd_thread') + wd_thread.start() + + # stop inference + elif queue_data["signal"] == AI_CONST.SIGNAL_STOP: + if self.current_worker: + if self.current_worker.is_alive(): + self.stop_event.set() + self.worker_event.wait() + self.worker_event.clear() + + def wd_start(self, data): + """ + inference 관리 + + 외부에서 stop 발생시 inference 종료 + + :param data: queue_data + """ + + try: + share_value = ThreadValues() + + data['argument']['source'] = self._input_source(data['engine_info'].input_video) + + wd_detect = d2_od_detect.WDDetect( + request=data['request'], + engine_info=data['engine_info'], + thread_value=share_value, + yolo_argument=data['argument'], + worker_event=self.worker_event, + stop_event=self.stop_event, + queue_info=self.wd_queue + ) + self.current_worker = threading.Thread(target=wd_detect.run, daemon=True, kwargs=data['argument'], + name='wd_inference') + self.current_worker.start() + + # TODO(jwkim) : worker check + # if manager.current_workers > 0: + # manager.lock.acquire() + # manager.current_workers -= 1 + # manager.lock.release() + + except Exception as e: + result_message = { + "datetime": share_value.date_utill.date_str_micro_sec(), + "status": share_value.STATUS_ERROR + str(e), + "result": { + 'construction_type': data['request'].ri.construction_code, + 'procedure_no': data['request'].ri.work_no, + 'procedure_ri': data['request'].ri.work_define_ri, + 'ri': share_value.wd_ri, + 'detect_list': share_value.result + } + } + client.publish(AI_CONST.MQTT_WD_TOPIC, json.dumps(result_message), 0) + + # if not self.timeout_event.is_set(): + # self.timeout_event.set() + # self.timeout_event.clear() + + self.stop_event.clear() + self.worker_event.set() + + def _input_source(self, args): + """ + 모델에서 입력소스를 받아서 parsing + 만약 stream data와 files data가 같이있으면 stream data return + :param args: input_video 정보 + :return: parsing된 정보 + """ + + result = '' + + streams = [] # webcam, RTSP + files = [] + + for i in args: + if i.model == M.AEAIModelType.WORK: + if os.path.isfile(i.connect_url): + files.append(i.connect_url) + elif i.connect_url.isnumeric() or i.connect_url.lower().startswith( + ('rtsp://', 'rtmp://', 'http://', 'https://')): + streams.append(i.connect_url) + + if streams: + f = open(AI_CONST.STREAMS_PATH, 'w') + for k in streams: + # if k.isnumeric(): + # k = int(k) + f.write(k + "\n") + f.close() + result = AI_CONST.STREAMS_PATH + + elif files: + result = files[0] + + return result + + +class FRManager: + """ + Face Recognition Manager + """ + # fr_queue = Queue() + TYPE_DETECT_SEQ = 0 + TYPE_DETECT_PAL = 1 + + # DETECT_WORKER_MAX = 10 + + def __init__(self, detect_type: int = TYPE_DETECT_SEQ): + self.fr_queue = Queue() + + # self.status_detect = False + self.detect_type = detect_type # 0: 순서적처리 1: 병렬처리 + # self.detect_worker_list = [] + # self.detect_worker_cur_idx = 0 + + self.stop_event = threading.Event() # inference stop trigger + self.worker_event = threading.Event() # worker check + self.timeout_event = threading.Event() # timeout check + + self.fr_receiver_thread = threading.Thread(target=self._fr_receiver, daemon=True, name='fr_receiver') + self.fr_receiver_thread.start() + + # self.frlock = threading.Lock() + self.fr_thread = '' + self.current_worker = '' + + self.fr_id_info = {} + + def fr_sender(self, data): + self.fr_queue.put(data) + + def _fr_receiver(self): + """ + FR_sender 를 통해 받은 메세지 처리 + signal data를 입력받아 inference 시작, 강제종료 시킴 + TODO(jwkim) 병렬처리 + """ + while True: + # message get + queue_data = self.fr_queue.get() + + if queue_data: + if self.detect_type == FRManager.TYPE_DETECT_SEQ: + + if queue_data["signal"] == AI_CONST.SIGNAL_INFERENCE: + # start inference + if self.current_worker: + if self.current_worker.is_alive(): + self.stop_event.set() + self.worker_event.wait() + self.worker_event.clear() + self.fr_thread = threading.Thread(target=self._fr_start, daemon=True, args=(queue_data,), + name='fr_thread') + self.fr_thread.start() + + elif queue_data["signal"] == AI_CONST.SIGNAL_STOP: + # stop inference + if self.current_worker: + if self.current_worker.is_alive(): + self.stop_event.set() + self.worker_event.wait() + self.worker_event.clear() + + # elif self.detect_type == FRManager.TYPE_DETECT_PAL: + # #TODO(jwkim): 병렬처리 + + # self.detect_worker_list.append(threading.Thread(target=self.FR_start, self.detect_worker_cur_idx, daemon=True).start()) + # self.detect_worker_cur_idx+=1 + # pass + + def _fr_start(self, data): + """ + timeout 스레드와 inference 스레드 관리 + + :param data: _description_ + """ + try: + self.timeout_event.clear() + share_value = ThreadValues() + + input_video = self._input_source(data['engine_info'].input_video) + + face_detect = FaceDetect( + request=data['request'], + thread_value=share_value, + stop_event=self.stop_event, + worker_event=self.worker_event, + timeout_event=self.timeout_event, + queue_info=self.fr_queue, + input_video=input_video, + engine_info=data['engine_info'], + fr_manager=self + ) + + self.current_worker = threading.Thread(target=face_detect.inference, daemon=True, name='fr_inference') + self.current_worker.start() + + limit_time_min = data['request'].limit_time_min * 60 + self.timeout_thread = threading.Thread(target=self._timeout, daemon=True, + args=(limit_time_min, share_value,), name='timeout') + self.timeout_thread.start() + + except Exception as e: + result_message = { + "datetime": share_value.date_utill.date_str_micro_sec(), + "status": share_value.STATUS_ERROR + str(e), + "result": share_value.result + } + client.publish(AI_CONST.MQTT_FR_TOPIC, json.dumps(result_message), 0) + + def _timeout(self, limit_min: int, share_value): + """ + 입력받은 시간을 토대로 current_worker에 timeout 발생 + :param limit_min: 제한시간(분) + """ + self.timeout_event.wait(timeout=limit_min) + + # 외부에서 set + if self.timeout_event.is_set(): + share_value.timeout_status = False + + # stop + elif share_value.current_status == share_value.STATUS_STOP: + share_value.timeout_status = False + + # time out + else: + share_value.timeout_status = True + + def _input_source(self, args): + for i in args: + if i.model == M.AEAIModelType.FR: + return i.connect_url + + +manager = WorkQueue() +CON = CONManager() +PPE = PPEManager() +WD = WDManager() +FR = FRManager() + +if __name__ == '__main__': + pass diff --git a/AI_ENGINE/mqtt_publish.py b/AI_ENGINE/mqtt_publish.py new file mode 100644 index 0000000..1258a71 --- /dev/null +++ b/AI_ENGINE/mqtt_publish.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +""" +@file : mqtt_publish.py +@author: jwkim +@license: A2TEC & DAOOLDNS + +@section Modify History +- 2023-01-11 오전 11:31 jwkim base + +""" +import os, sys +import paho.mqtt.client as mqtt + +sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) + +import ai_engine_const as AI_CONST + +import json + + +def on_connect(client, userdata, flags, rc): + if rc == 0: + print("MQTT connected OK") + else: + print("Bad connection Returned code=", rc) + + +def on_disconnect(client, userdata, flags, rc=0): + print(str(rc)) + + +def on_publish(client, userdata, mid): + print("publish = ", mid) + + +# 새로운 클라이언트 생성 +client = mqtt.Client() +# 콜백 함수 설정 on_connect(브로커에 접속), on_disconnect(브로커에 접속중료), on_publish(메세지 발행) +client.on_connect = on_connect +client.on_disconnect = on_disconnect +# client.on_publish = on_publish +client.username_pw_set(AI_CONST.MQTT_USER_ID,AI_CONST.MQTT_USER_PW) +# address : localhost, port: 1883 에 연결 +client.connect(AI_CONST.MQTT_HOST, AI_CONST.MQTT_PORT) diff --git a/AI_ENGINE/mqtt_subscribe.py b/AI_ENGINE/mqtt_subscribe.py new file mode 100644 index 0000000..7a60bbc --- /dev/null +++ b/AI_ENGINE/mqtt_subscribe.py @@ -0,0 +1,64 @@ +import os, sys +import paho.mqtt.client as mqtt +import json +import cv2 +from base64 import b64decode +import numpy as np + +sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) + +import ai_engine_const as AI_CONST + + +def on_connect(mqtt_object, userdata, flags, rc): + if rc == 0: + print("connected OK") + else: + print("Bad connection Returned code=", rc) + + +def on_disconnect(mqtt_object, userdata, flags, rc=0): + print(str(rc)) + + +def on_subscribe(mqtt_object, userdata, mid, granted_qos): + print("subscribed: " + str(mid) + " " + str(granted_qos)) + + +def on_message(mqtt_object, userdata, msg): + """ + :param mqtt_object: 정의된mqtt client + :param userdata: ? + :param msg: 전송받은 메세지 정보 + """ + + if msg.topic == AI_CONST.MQTT_PPE_TOPIC: + object_message = json.loads(msg.payload.decode("utf-8")) + # if object_message["img_data"]: + # img_data = b64decode(object_message["img_data"]) + # + # save_image = cv2.imdecode(np.frombuffer(img_data,dtype=np.uint8), cv2.IMREAD_COLOR) + # cv2.imwrite("./test.jpg", save_image) + print(object_message) + + elif msg.topic == AI_CONST.MQTT_FR_TOPIC: + print(msg.payload.decode("utf-8")) + + +client = mqtt.Client() + +client.on_connect = on_connect +client.on_disconnect = on_disconnect +client.on_subscribe = on_subscribe +client.on_message = on_message + +client.username_pw_set(AI_CONST.MQTT_USER_ID, AI_CONST.MQTT_USER_PW) +client.connect(AI_CONST.MQTT_HOST, AI_CONST.MQTT_PORT) +client.subscribe(AI_CONST.MQTT_PPE_TOPIC) # ppe +client.subscribe(AI_CONST.MQTT_FR_TOPIC) # FR + + + +client.loop_forever() + + diff --git a/AI_ENGINE/old_queue.py b/AI_ENGINE/old_queue.py new file mode 100644 index 0000000..76f5423 --- /dev/null +++ b/AI_ENGINE/old_queue.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +""" +@File: queue.py +@Date: 2022-06-30 +@author: jwkim +@section MODIFYINFO 수정정보 +- 수정자/수정일 : 수정내역 +@brief: queue +""" + +import os, sys +import time +import threading +from queue import Queue + +DL_BASE_PATH = os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) +sys.path.append(DL_BASE_PATH) + +from DL import d2_od_detect + + +class WorkQueue: + max_workers = 4 + message_queue = Queue() + + def __init__(self): + self.current_workers = 0 + self.message_queue = Queue() + self.lock = threading.Lock() + receiver_start_thread = threading.Thread(target=self.classifier, daemon=True).start() + + def sender(self, data): + self.message_queue.put(data) + + def classifier(self): + while True: + queue_data = self.message_queue.get() + work_thread = threading.Thread(target=self.check_worker, args=(queue_data,), daemon=True).start() + + def check_worker(self, args): + if self.max_workers > self.current_workers: + self.start_worker(args) + + elif self.max_workers <= self.current_workers: + while self.max_workers <= self.current_workers: + time.sleep(1) + self.start_worker(args) + + def start_worker(self, inference_args): + self.lock.acquire() + self.current_workers += 1 + self.lock.release() + + d2_od_detect.run(**inference_args) + + if self.current_workers > 0: + self.lock.acquire() + self.current_workers -= 1 + self.lock.release() + + +if __name__ == '__main__': + pass diff --git a/DL/FR/d2_face_detect.py b/DL/FR/d2_face_detect.py new file mode 100644 index 0000000..896be9a --- /dev/null +++ b/DL/FR/d2_face_detect.py @@ -0,0 +1,581 @@ +# -*- coding: utf-8 -*- +""" +@file : d2_face_detect.py +@author: jwkim +@license: A2TEC & DAOOLDNS + +@section Modify History +- 2023-01-11 오전 11:31 jwkim base + +""" + +import face_recognition +import cv2 +import os, sys +import copy +import json +import threading +import numpy as np +import paramiko + +AI_ENGINE_PATH = "/AI_ENGINE" +sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))) + AI_ENGINE_PATH) # mqtt +os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) # ai_const + +from mqtt_publish import client +import ai_engine_const as AI_CONST +from REST_AI_ENGINE_CONTROL.app import models as M +from REST_AI_ENGINE_CONTROL.app.utils.extra import file_size_check + +import project_config + +# This is a demo of running face recognition on a video file and saving the results to a new video file. +# +# PLEASE NOTE: This example requires OpenCV (the `cv2` library) to be installed only to read from your webcam. +# OpenCV is *not* required to use the face_recognition library. It's only required if you want to run this +# specific demo. If you have trouble installing it, try any of the other demos that don't require it instead. + + +class FaceDetect: + """ + 안면인식 + + 외부 event thread를 이용하여 제어 + inference 동작중 해당상황에 맞는 메세지 publish + """ + + def __init__(self, + request, + thread_value, + stop_event: threading.Event, + worker_event: threading.Event, + timeout_event: threading.Event, + queue_info, + engine_info, + input_video, + fr_manager + ): + """ + 초기화 함수 + :param request: 외부 요청 메세지 + report_unit : 신규 대상 인식할때마다 전송할지 여부 + targets : detect할 대상의 명단) + :param stop_event: instance 내부 loop 종료 관련 이벤트 + :param worker_event: current_worker(thread) 종료 관련 이벤트 + :param timeout_event: timeout 발생시 내부 loop 종료 관련 이벤트 + """ + + self.report_unit = request.report_unit + self.targets = request.targets + self.thread_value = thread_value + + self.mqtt_status = self.thread_value.STATUS_NONE + + self.result = self.thread_value.result + self.result = [] + self.result.extend([None] * (len(self.targets))) + + self.report_unit_result = self.thread_value.report_unit_result + self.report_unit_result = [] + + self.worker_event = worker_event # 외부 worker 제어 + self.stop_event = stop_event # stop trigger + self.stop_event.clear() + self.timeout_event = timeout_event + + self.thread_value = thread_value + + self.stop_sign = False + + self.fr_manager = fr_manager + + self.id_info = self.fr_manager.fr_id_info + + # self.worker_names = list(self.id_info.keys()) #TODO(jwkim): 작업자 수동 등록 + + self.worker_names = self.targets + + self.queue_info = queue_info + + self.input_video = int(input_video) if input_video.isnumeric() else input_video + self.model_info = engine_info + + self.ftp_info = self.model_info.demo.ftp + self.ri_info = request.ri or self.model_info.fr_model_info.ri + + self.snapshot_path = None + self.result_frame = [] + + + # def _ai_rbi(self, detected): + # # ri 정보 self.ri_info + # pass + + def _sftp_upload(self): + try: + + transprot = paramiko.Transport((self.ftp_info.ip,self.ftp_info.port)) + transprot.connect(username = self.ftp_info.id, password = self.ftp_info.pw) + sftp = paramiko.SFTPClient.from_transport(transprot) + + remotepath = self.ftp_info.location + os.sep + self.ftp_info.file_face + '.jpg' + + #sftp.put(AI_CONST.FTP_FR_RESULT, remotepath) + + sftp.close() + transprot.close() + + return remotepath + except Exception as e: + return "" + + def _mqtt_publish(self, status: str): + """ + mqtt message publish + :param status: 안면인식 상태 + """ + client.loop_start() + if status == self.thread_value.STATUS_NEW: + result = self.report_unit_result + else: + result = self.result + + mqtt_msg_dict = { + "datetime": self.thread_value.date_utill.date_str_micro_sec(), + "status": status, + "result": result + } + + if self.snapshot_path: + mqtt_msg_dict[AI_CONST.DEMO_KEY_NAME_SNAPSHOT_SFTP] = self.snapshot_path + + + client.publish(AI_CONST.MQTT_FR_TOPIC, json.dumps(mqtt_msg_dict), 0) + # client.loop_stop() + + def _result_update(self, matched_list: list): + """ + :param matched_list: detect 된 object list + """ + self.mqtt_status = self.thread_value.STATUS_NONE # status init + + if self.report_unit: + self.report_unit_result = [] + self.report_unit_result.extend([None] * (len(self.targets))) + + origin_matched_list = copy.deepcopy(matched_list) + + # 초기 result + if self.result.count(None) == len(self.targets): + for i in matched_list: + if i in self.worker_names and i in self.targets: + self.result[self.targets.index(i)] = i + if self.report_unit and self.result.count(None) != len(self.targets): + self.mqtt_status = self.thread_value.STATUS_NEW + self.report_unit_result = self.result + + # update result + for i in origin_matched_list: + if i not in self.result and i in self.targets: + if self.report_unit: + self.mqtt_status = self.thread_value.STATUS_NEW + self.report_unit_result[self.targets.index(i)] = i + + self.result[self.targets.index(i)] = i + + # TODO(jwkim): 함수에서 분리 + if self.result.count(None) == 0 and self.targets: + self.mqtt_status = self.thread_value.STATUS_COMPLETE + self.stop_event.set() + + def _encoding(self): + """ + 안면인식 대상 이미지 등록 + 등록된 인원 : no000001,no000002,no000003 + :return: encoding 정보(list) + """ + # Load some sample pictures and learn how to recognize them. + worker1_image = face_recognition.load_image_file(AI_CONST.WORKER1_IMG_PATH) + no000001 = { + 'target_names' : 'no000001' , + 'encoding' : face_recognition.face_encodings(worker1_image)[0] + } + + worker2_image = face_recognition.load_image_file(AI_CONST.WORKER2_IMG_PATH) + no000002 = { + 'target_names' : 'no000002', + 'encoding' : face_recognition.face_encodings(worker2_image)[0] + } + + worker3_image = face_recognition.load_image_file(AI_CONST.WORKER3_IMG_PATH) + no000003 = { + 'target_names' : 'no000003', + 'encoding' : face_recognition.face_encodings(worker3_image)[0] + } + + worker4_image = face_recognition.load_image_file(AI_CONST.WORKER4_IMG_PATH) + no000004 = { + 'target_names' : 'no000004', + 'encoding' : face_recognition.face_encodings(worker4_image)[0] + } + + worker5_image = face_recognition.load_image_file(AI_CONST.WORKER5_IMG_PATH) + no000005 = { + 'target_names' : 'no000005', + 'encoding' : face_recognition.face_encodings(worker5_image)[0] + } + + worker6_image = face_recognition.load_image_file(AI_CONST.WORKER6_IMG_PATH) + no000006 = { + 'target_names' : 'no000006', + 'encoding' : face_recognition.face_encodings(worker6_image)[0] + } + + worker7_image = face_recognition.load_image_file(AI_CONST.WORKER7_IMG_PATH) + no000007 = { + 'target_names' : 'no000007', + 'encoding' : face_recognition.face_encodings(worker7_image)[0] + } + + worker8_image = face_recognition.load_image_file(AI_CONST.WORKER8_IMG_PATH) + no000008 = { + 'target_names' : 'no000008', + 'encoding' : face_recognition.face_encodings(worker8_image)[0] + } + + # worker9_image = face_recognition.load_image_file(AI_CONST.WORKER9_IMG_PATH) + # no000009 = { + # 'target_names' : 'no000009', + # 'encoding' : face_recognition.face_encodings(worker9_image)[0] + # } + + + worker10_image = face_recognition.load_image_file(AI_CONST.WORKER10_IMG_PATH) + no000010 = { + 'target_names' : 'no000010', + 'encoding' : face_recognition.face_encodings(worker10_image)[0] + } + + worker11_image = face_recognition.load_image_file(AI_CONST.WORKER11_IMG_PATH) + no000011 = { + 'target_names' : 'no000011', + 'encoding' : face_recognition.face_encodings(worker11_image)[0] + } + + worker12_image = face_recognition.load_image_file(AI_CONST.WORKER12_IMG_PATH) + no000012 = { + 'target_names' : 'no000012', + 'encoding' : face_recognition.face_encodings(worker12_image)[0] + } + # print('d3') + # worker13_image = face_recognition.load_image_file(AI_CONST.WORKER13_IMG_PATH) + # no000013 = { + # 'target_names' : 'no000013', + # 'encoding' : face_recognition.face_encodings(worker13_image)[0] + # } + # print('d4') + # worker14_image = face_recognition.load_image_file(AI_CONST.WORKER14_IMG_PATH) + # no000014 = { + # 'target_names' : 'no000014', + # 'encoding' : face_recognition.face_encodings(worker14_image)[0] + # } + # print('d5') + worker15_image = face_recognition.load_image_file(AI_CONST.WORKER15_IMG_PATH) + no000015 = { + 'target_names' : 'no000015', + 'encoding' : face_recognition.face_encodings(worker15_image)[0] + } + # print('d6') + worker16_image = face_recognition.load_image_file(AI_CONST.WORKER16_IMG_PATH) + no000016 = { + 'target_names' : 'no000016', + 'encoding' : face_recognition.face_encodings(worker16_image)[0] + } + # print('d7') + # worker17_image = face_recognition.load_image_file(AI_CONST.WORKER17_IMG_PATH) + # no000017 = { + # 'target_names' : 'no000017', + # 'encoding' : face_recognition.face_encodings(worker17_image)[0] + # } + # print('d8') + worker18_image = face_recognition.load_image_file(AI_CONST.WORKER18_IMG_PATH) + no000018 = { + 'target_names' : 'no000018', + 'encoding' : face_recognition.face_encodings(worker18_image)[0] + } + # print('d9') + worker19_image = face_recognition.load_image_file(AI_CONST.WORKER19_IMG_PATH) + no000019 = { + 'target_names' : 'no000019', + 'encoding' : face_recognition.face_encodings(worker19_image)[0] + } + # print('d0') + worker20_image = face_recognition.load_image_file(AI_CONST.WORKER20_IMG_PATH) + no000020 = { + 'target_names' : 'no000020', + 'encoding' : face_recognition.face_encodings(worker20_image)[0] + } + + # print('d01') + worker21_image = face_recognition.load_image_file(AI_CONST.WORKER21_IMG_PATH) + no000021 = { + 'target_names' : 'no000021', + 'encoding' : face_recognition.face_encodings(worker21_image)[0] + } + + # print('d02') + worker22_image = face_recognition.load_image_file(AI_CONST.WORKER22_IMG_PATH) + no000022 = { + 'target_names' : 'no000022', + 'encoding' : face_recognition.face_encodings(worker22_image)[0] + } + + encoding_list = [ + #no000001,no000002,no000003 # kepco 1 + # , + no000004 # jangys + ,no000005 # whangsj + # ,no000006 + # ,no000007 + ,no000008 # agics + # ,no000009 + # , + # ,no000010,no000011,no000012,no000015,no000018 #Helmet on kepco2 + ,no000019,no000020 #Helmet off kepco 2 + ,no000021 #,no000022 + + ] + result = [] + + #demo + for i in encoding_list: + result.append(i["encoding"]) + + names = [ + #'no000001','no000002','no000003'# kepco 1 + # , + 'no000004' # jangys + ,'no000005' # whangsj + # ,'no000006' + # ,'no000007' + ,'no000008' # agics + # ,'no000009' + # , + # ,'no000010','no000011','no000012','no000015','no000018' #Helmet on kepco2 + ,'no000019','no000020' #Helmet off kepco 2 + ,'no000021' #,no000022 + ] + return result , names + + # for i in encoding_list: + # if i['target_names'] in self.targets: + # result.append(i["encoding"]) + # if not result: + # raise Exception("invalid targets") + + # return result + + def _new_encoding(self): + result = [] + for key,value in self.id_info.items(): + current_size = file_size_check(value['image']) + if value['size'] != current_size : + raise Exception(AI_CONST.INVALID_IMG_MSG) + fr_image = face_recognition.load_image_file(value['image'].file) + fr_encoding = face_recognition.face_encodings(fr_image)[0] + result.append(fr_encoding) + + return result + + def inference(self): + """ + 안면인식 inference 동작과 동작 상태에 따라 mqtt message publish + """ + try: + if os.path.exists(AI_CONST.FTP_FR_RESULT): + os.remove(AI_CONST.FTP_FR_RESULT) + + # start publish + self._mqtt_publish(status=self.thread_value.STATUS_START) + + input_movie = cv2.VideoCapture(self.input_video) + + #demo + known_faces, known_names = self._encoding() + + # known_faces = self._encoding() + # known_faces = self._new_encoding() #TODO(jwkim): 작업자 수동 등록 + + # Initialize some variables + # face_locations = [] + # face_encodings = [] + + # while self.mqtt_status != self.thread_value.STATUS_COMPLETE and self.stop_sign != True: + # while True: + + # # TODO(JWKIM):ageing test 23-02-28 + # # self._mqtt_publish(status="AGEING_TEST") + + # # loop stop + # if self.stop_event.is_set(): + # if self.mqtt_status == self.thread_value.STATUS_COMPLETE: + # pass + # else: + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # # TODO(JWKIM):ageing test 23-02-28 + # break + # elif self.thread_value.timeout_status: + # self.thread_value.current_status = self.thread_value.STATUS_TIMEOUT + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # Grab a single frame of video + ret, frame = input_movie.read() + + cv2.imwrite(AI_CONST.FTP_FR_RESULT,frame) + self._sftp_upload() + self.result = ["no000010"] + + self._mqtt_publish(status = self.thread_value.STATUS_COMPLETE) + + # Quit when the input video file ends + # 영상파일일 경우 다시 재생 + # if not ret and os.path.isfile(self.input_video): + # input_movie = cv2.VideoCapture(self.input_video) + # continue + + # frame_count = int(input_movie.get(cv2.CV_CAP_PROP_FPS)) + + # Convert the image from BGR color (which OpenCV uses) to RGB color (which face_recognition uses) + # rgb_frame = frame[:, :, ::-1] + + # face_locations = face_recognition.face_locations(rgb_frame) + # face_encodings = face_recognition.face_encodings(rgb_frame, face_locations) + # for face_encoding in face_encodings: + + # matched_name = [] # detect 완료된 명단 + + # match = face_recognition.compare_faces(known_faces, face_encoding, tolerance=AI_CONST.FACE_EVOLUTION_DISTANCE) + + # #TODO(jwkim): 입력 이미지 변경 + # if self.fr_manager.fr_id_info != self.id_info: + # raise Exception(AI_CONST.IMG_CHANGED_MSG) + + # face_distances = face_recognition.face_distance(known_faces, face_encoding) + # best_match_index = np.argmin(face_distances) + + # if face_distances[best_match_index] < AI_CONST.FACE_EVOLUTION_DISTANCE : + # #demo + # print(match) + # if match[best_match_index]: + + # # self.result[0] = known_names[best_match_index] + # self.result[0] = "no000010" + # self.mqtt_status = self.thread_value.STATUS_COMPLETE + # print(known_names[best_match_index]) + # self.stop_event.set() + # break + + # if match[best_match_index]: + # matched_name.append(self.worker_names[best_match_index]) + # print(self.worker_names[best_match_index]) + + #TODO(jwkim): 시연용 + # self._result_update(matched_name) + + # ri + # self._ai_rbi(matched_name) + + # if self.mqtt_status == self.thread_value.STATUS_NEW: + # self._mqtt_publish(self.thread_value.STATUS_NEW) + + # # input source check + # self._source_check() + + # # loop stop + # if self.stop_event.is_set(): + # if self.mqtt_status == self.thread_value.STATUS_COMPLETE: + # self.result_frame = copy.deepcopy(frame) + # else: + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # # TODO(JWKIM):ageing test 23-02-28 + # break + # elif self.thread_value.timeout_status: + # self.thread_value.current_status = self.thread_value.STATUS_TIMEOUT + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # if self.thread_value.timeout_status: + # self._mqtt_publish(self.thread_value.STATUS_TIMEOUT) + # self.thread_value.timeout_status = False + + # elif self.stop_event.is_set(): + # if self.mqtt_status == self.thread_value.STATUS_COMPLETE: + # #sftp + + # cv2.imwrite(AI_CONST.FTP_FR_RESULT,self.result_frame) + + # if project_config.SFTP_UPLOAD and project_config.FR_UPLOAD: + # self.snapshot_path = self._sftp_upload() + # else: + # self.snapshot_path = None + + # self._mqtt_publish(self.thread_value.STATUS_COMPLETE) + # else: + # self._mqtt_publish(self.thread_value.STATUS_STOP) + + # if not self.timeout_event.is_set(): + # self.timeout_event.set() + + # self.timeout_event.clear() + + # if self.mqtt_status == self.thread_value.STATUS_COMPLETE: + + # pass + + except Exception as e: + print(e) + # publish error + self._mqtt_publish(status=self.thread_value.STATUS_ERROR + str(e)) + self._queue_empty() + + finally: + if not self.timeout_event.is_set(): + self.timeout_event.set() + + self.timeout_event.clear() + # All done! + # output_movie.release() + # input_movie.release() + cv2.destroyAllWindows() + self.stop_event.clear() + self.worker_event.set() + + def _source_check(self): + """ + 현재 동작중인 input source 와 모델에 세팅된 input source가 다를시 예외 발생 + """ + current_source = self.input_video + for i in self.model_info.input_video: + if i.model == M.AEAIModelType.FR: + model_source = i.connect_url + + if model_source.isnumeric(): + model_source = int(model_source) + + if current_source != model_source: + raise Exception(AI_CONST.SOURCE_CHANGED_MSG) + + def _queue_empty(self): + """ + queue_info에 있는 데이터를 비운다. + """ + while True: + if self.queue_info.empty(): + break + else: + self.queue_info.get() + + +if __name__ == '__main__': + pass diff --git a/DL/custom_utils.py b/DL/custom_utils.py new file mode 100644 index 0000000..70eebbb --- /dev/null +++ b/DL/custom_utils.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +""" +@file : custom_utils.py +@author: Ultralytics , jwkim +@license: GPL-3.0 license + +@section Modify History +- +""" + +import torch +import cv2 +import math +import time +import os +import numpy as np + +from threading import Thread +from pathlib import Path +from urllib.parse import urlparse + +from OD.utils.augmentations import letterbox +from OD.utils.general import clean_str, check_requirements, is_colab, is_kaggle, LOGGER + + +class LoadStreams: + # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` + def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1, + event=None): + torch.backends.cudnn.benchmark = True # faster for fixed-size inference + self.event = event # TODO(jwkim) thread 종료 관련 + self.mode = 'stream' + self.img_size = img_size + self.stride = stride + self.vid_stride = vid_stride # video frame-rate stride + sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources] + n = len(sources) + self.sources = [clean_str(x) for x in sources] # clean source names for later + self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n + for i, s in enumerate(sources): # index, source + # Start thread to read frames from video stream + st = f'{i + 1}/{n}: {s}... ' + if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video + # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc' + check_requirements(('pafy', 'youtube_dl==2020.12.2')) + import pafy + s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL + s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam + if s == 0: + assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.' + assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.' + cap = cv2.VideoCapture(s) + assert cap.isOpened(), f'{st}Failed to open {s}' + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan + self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback + self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback + + _, self.imgs[i] = cap.read() # guarantee first frame + self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True) + LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)") + self.threads[i].start() + LOGGER.info('') # newline + + # check for common shapes + s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs]) + self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal + self.auto = auto and self.rect + self.transforms = transforms # optional + if not self.rect: + LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.') + + def update(self, i, cap, stream): + # Read stream `i` frames in daemon thread + n, f = 0, self.frames[i] # frame number, frame array + while cap.isOpened() and n < f: + n += 1 + cap.grab() # .read() = .grab() followed by .retrieve() + if n % self.vid_stride == 0: + success, im = cap.retrieve() + if success: + self.imgs[i] = im + else: + LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.') + self.imgs[i] = np.zeros_like(self.imgs[i]) + cap.open(stream) # re-open stream if signal was lost + time.sleep(0.0) # wait time + if self.event.is_set(): # TODO(jwkim) thread 종료 관련 + break + + def __iter__(self): + self.count = -1 + return self + + def __next__(self): + self.count += 1 + if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit + cv2.destroyAllWindows() + raise StopIteration + + im0 = self.imgs.copy() + if self.transforms: + im = np.stack([self.transforms(x) for x in im0]) # transforms + else: + im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize + im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW + im = np.ascontiguousarray(im) # contiguous + + return self.sources, im, im0, None, '' + + def __len__(self): + return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years diff --git a/DL/d2_od_detect.py b/DL/d2_od_detect.py new file mode 100644 index 0000000..c9f0a22 --- /dev/null +++ b/DL/d2_od_detect.py @@ -0,0 +1,1764 @@ +# -*- coding: utf-8 -*- +""" +@file : d2_od_detect.py +@author: Ultralytics , jwkim +@license: GPL-3.0 license + +@section Modify History +- +""" + +# YOLOv5 🚀 by Ultralytics, GPL-3.0 license +""" +Run YOLOv5 detection inference on images, videos, directories, globs, YouTube, webcam, streams, etc. + +Usage - sources: + $ python detect.py --weights yolov5s.pt --source 0 # webcam + img.jpg # image + vid.mp4 # video + screen # screenshot + path/ # directory + list.txt # list of images + list.streams # list of streams + 'path/*.jpg' # glob + 'https://youtu.be/Zgi9g1ksQHc' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream + +Usage - formats: + $ python detect.py --weights yolov5s.pt # PyTorch + yolov5s.torchscript # TorchScript + yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn + yolov5s_openvino_model # OpenVINO + yolov5s.engine # TensorRT + yolov5s.mlmodel # CoreML (macOS-only) + yolov5s_saved_model # TensorFlow SavedModel + yolov5s.pb # TensorFlow GraphDef + yolov5s.tflite # TensorFlow Lite + yolov5s_edgetpu.tflite # TensorFlow Edge TPU + yolov5s_paddle_model # PaddlePaddle +""" + +import argparse +import os +import platform +import sys +from pathlib import Path +import json +import math +import numpy as np +import paramiko +import torch +import threading +import copy +import cv2 + +import face_recognition + +OD_BASE_PATH = '/OD' +OD_ABSOLUTE_PATH = os.path.abspath(os.path.dirname(__file__)) + OD_BASE_PATH +AI_ENGINE_PATH = "/AI_ENGINE" +sys.path.append(OD_ABSOLUTE_PATH) +sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))) + AI_ENGINE_PATH) # mqtt +sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) # const + +FILE = Path(__file__).resolve() +ROOT = FILE.parents[0] # YOLOv5 root directory +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) # add ROOT to PATH +ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative + +D2_ROOT = Path(os.path.abspath(os.path.dirname(__file__))) + +from models.common import DetectMultiBackend +from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots +from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, + increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh, + clean_str) +from utils.plots import Annotator, colors, save_one_box +from utils.torch_utils import select_device, smart_inference_mode + +from mqtt_publish import client +import ai_engine_const as AI_CONST +from REST_AI_ENGINE_CONTROL.app import models as M +from custom_utils import LoadStreams +import project_config + +import AI_ENGINE.demo_utils as DEMO + +# Load model +device = select_device('') +model = DetectMultiBackend(D2_ROOT / 'index_78.pt', device=device) +stride, names, pt = model.stride, model.names, model.pt +imgsz = check_img_size((640, 640), s=stride) # check image size + +import paho.mqtt.client as mqtt +rtsp_client = mqtt.Client() +rtsp_client.username_pw_set("admin","admin") +rtsp_client.connect("localhost", 11883) +rtsp_client.publish("ptz", f"1,0.95,0", 0) + +class CONDetect: + """ + 시설물 탐지 + """ + + def __init__(self, + request: M.AEAIModelConSetupDetectReq, + engine_info: M.AEInfo, + thread_value, + yolo_argument, + worker_event: threading.Event, + stop_event: threading.Event, + timeout_event: threading.Event, + queue_info + ): + # 이미지 좌표계 + self.bbox_type = AI_CONST.BBOX_XYXY + + self.report_unit = request.report_unit + self.targets = request.targets + self.target_list = [] + for key, value in self.targets: + if value: + self.target_list.append(key) + + self.model_info = engine_info + self.thread_value = thread_value + + # ri + self.ri_info = request.ri or self.model_info.con_model_info.ri + + self.thread_value.result = [None] * len(self.target_list) + self.thread_value.report_unit_result = [None] * len(self.target_list) + + self.yolo_args = yolo_argument + + self.worker_event = worker_event + self.stop_event = stop_event + self.timeout_event = timeout_event + self.dataloader_event = threading.Event() + + self.queue_info = queue_info + + self.ftp_info = self.model_info.demo.ftp + self.snapshot_path = None + self.result_frame = [] + + @smart_inference_mode() + def run( + self, + weights=D2_ROOT / 'index_78.pt', # model path or triton URL + source=D2_ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam) + data=D2_ROOT / 'data/coco128.yaml', # dataset.yaml path + imgsz=(640, 640), # inference size (height, width) + conf_thres=0.25, # confidence threshold + iou_thres=0.45, # NMS IOU threshold + max_det=1000, # maximum detections per image + device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu + view_img=False, # show results + save_txt=False, # save results to *.txt + save_conf=False, # save confidences in --save-txt labels + save_crop=False, # save cropped prediction boxes + nosave=True, # do not save images/videos + classes=None, # filter by class: --class 0, or --class 0 2 3 + agnostic_nms=False, # class-agnostic NMS + augment=False, # augmented inference + visualize=False, # visualize features + update=False, # update all models + project=D2_ROOT / 'runs/detect', # save results to project/name + name='exp', # save results to project/name + exist_ok=False, # existing project/name ok, do not increment + line_thickness=3, # bounding box thickness (pixels) + hide_labels=False, # hide labels + hide_conf=False, # hide confidences + half=False, # use FP16 half-precision inference + dnn=False, # use OpenCV DNN for ONNX inference + vid_stride=1, # video frame-rate stride + ): + try: + if os.path.exists(AI_CONST.FTP_CON_RESULT): + os.remove(AI_CONST.FTP_CON_RESULT) + + source = str(source) + save_img = not nosave and not source.endswith('.txt') # save inference images + is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) + is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) + webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) + screenshot = source.lower().startswith('screen') + if is_url and is_file: + source = check_file(source) # download + + # Directories + save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run + # (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir + + # # Load model + # device = select_device(device) + # model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) + # stride, names, pt = model.stride, model.names, model.pt + # imgsz = check_img_size(imgsz, s=stride) # check image size + + # Run inference + self._publish(self.thread_value.STATUS_START, self.thread_value.result) + + #while True: + # Dataloader + bs = 1 # batch_size + if webcam: + # LoadStreams 한번만 실행 + if not 'dataset' in locals(): + # view_img = check_imshow(warn=True) # imshow + dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride, + event=self.dataloader_event) + bs = len(dataset) + elif screenshot: + dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt) + else: + dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) + vid_path, vid_writer = [None] * bs, [None] * bs + + model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup + seen, windows, dt = 0, [], (Profile(), Profile(), Profile()) + + for path, im, im0s, vid_cap, s in dataset: + # # loop stop + # if self.stop_event.is_set(): + # if self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # pass + # else: + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # #TODO(JWKIM):ageing test 23-02-28 + # break + + # elif self.thread_value.timeout_status: + # self.thread_value.current_status = self.thread_value.STATUS_TIMEOUT + # break #TODO(JWKIM):ageing test 23-02-28 + + with dt[0]: + im = torch.from_numpy(im).to(model.device) + im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 + im /= 255 # 0 - 255 to 0.0 - 1.0 + if len(im.shape) == 3: + im = im[None] # expand for batch dim + + # Inference + with dt[1]: + visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False + pred = model(im, augment=augment, visualize=visualize) + + # NMS + with dt[2]: + pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) + + # Second-stage classifier (optional) + # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s) + + # Process predictions + for i, det in enumerate(pred): # per image + seen += 1 + if webcam: # batch_size >= 1 + p, im0, frame = path[i], im0s[i].copy(), dataset.count + s += f'{i}: ' + else: + p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0) + + # p = Path(p) # to Path + # save_path = str(save_dir / p.name) # im.jpg + # txt_path = str(save_dir / 'labels' / p.stem) + ( + # '' if dataset.mode == 'image' else f'_{frame}') # im.txt + # s += '%gx%g ' % im.shape[2:] # print string + # gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh + # imc = im0.copy() if save_crop else im0 # for save_crop + # annotator = Annotator(im0, line_width=line_thickness, example=str(names)) + + self.thread_value.current_status == self.thread_value.STATUS_COMPLETE + cv2.imwrite(AI_CONST.FTP_CON_RESULT,copy.deepcopy(im0)) + if project_config.SFTP_UPLOAD: + self._sftp_upload() + demo_dict=[ + { + "name": 'signalman', + "cid": 1, + "confidence": 0.7, + "bbox": [ 150, 150, 100, 200 ], + "image": None # nparray -> list + }, + { + "name": 'sign_traffic_cone', + "cid": 25, + "confidence": 0.7, + "bbox": [ 150, 150, 100, 200 ], + "image": None # nparray -> list + }, + + + + + ] + print(demo_dict) + self._publish(status = self.thread_value.STATUS_COMPLETE, result = demo_dict) + break + break + + # # FR START + # #TODO(jwkim): PPE + FR + # # imface = im0.copy() + # # self._fr_inference(imface) + + # # 딕셔너리 초기화 + # detected_object = {} + + # if len(det): + # # Rescale boxes from img_size to im0 size + # det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() + + # # Print results + # for c in det[:, 5].unique(): + # n = (det[:, 5] == c).sum() # detections per class + # s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string + + # # Write results + # for *xyxy, conf, cls in reversed(det): + # if save_txt: # Write to file + # xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view( + # -1).tolist() # normalized xywh + # line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + # with open(f'{txt_path}.txt', 'a') as f: + # f.write(('%g ' * len(line)).rstrip() % line + '\n') + + # if save_img or save_crop or view_img: # Add bbox to image + # c = int(cls) # integer class + # label = None if hide_labels else ( + # names[c] if hide_conf else f'{names[c]} {conf:.2f}') + # annotator.box_label(xyxy, label, color=colors(c, True)) + # if save_crop: + # save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', + # BGR=True) + + # c = int(cls) + + # # if names[c] in self.target_list: + # if names[c] == "sign_traffic_cone": + # # crop image 전송여부 + # if self.model_info.con_model_info.crop_images: + # # crop_image_data = save_one_box(xyxy, imc, BGR=True, save=False).tolist() + # crop_image_data = save_one_box(xyxy, imc, BGR=True, save=False) + # crop_image_data = self._image_encoding(crop_image_data) + # else: + # crop_image_data = None + + # if self.bbox_type == AI_CONST.BBOX_XYXY: + # result_bbox = [int(xyxy[1]), int(xyxy[0]), int(xyxy[3]), int(xyxy[2])] + # elif self.bbox_type == AI_CONST.BBOX_XYWH: + # result_bbox = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() + + # detected_object[names[c]] = { + # "name": names[c], + # "cid": c, + # "confidence": round(conf.item(), 2), + # "bbox": result_bbox, + # "image": crop_image_data # nparray -> list + # } + + # detected_object["signalman"] = { + # "name": "signalman", + # "cid": 0, + # "confidence": round(conf.item(), 2), + # "bbox": result_bbox, + # "image": None # nparray -> list + # } + + # self.result_frame = copy.deepcopy(im0) + + # # input source check + # self._source_check(dataset) + + # # ri + # detected_object = self._ai_rbi(detected_object) + + # # update + # self._result_update(detected_object) + + # # publish new + # if self.thread_value.current_status == self.thread_value.STATUS_NEW: + # self._publish(self.thread_value.STATUS_NEW, self.thread_value.report_unit_result) + # self.thread_value.current_status = self.thread_value.STATUS_NONE + + # # publish complete + # elif self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # cv2.imwrite(AI_CONST.FTP_CON_RESULT,self.result_frame) + + # if project_config.SFTP_UPLOAD: + # self.snapshot_path = self._sftp_upload() + # else: + # self.snapshot_path = None + + # self._publish(self.thread_value.STATUS_COMPLETE, self.thread_value.result) + # self.stop_event.set() + + # # if frame % 60 == 0 : + # # #TODO(JWKIM):ageing test 23-02-28 + # # self._publish("AGEING_TEST", self.thread_value.result) + + # # #TODO(JWKIM):ageing test 23-02-28 + # if self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # break + + # # loop stop + # if self.stop_event.is_set(): + # if self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # pass + # else: + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # elif self.thread_value.timeout_status: + # self.thread_value.current_status = self.thread_value.STATUS_TIMEOUT + # # #TODO(JWKIM):ageing test 23-02-28 + # break + + # # Stream results + # im0 = annotator.result() + # if view_img: + # if platform.system() == 'Linux' and p not in windows: + # windows.append(p) + # cv2.namedWindow(str(p), + # cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) + # cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) + # # cv2.imshow(str(p), im0) # imshow + # cv2.waitKey(1) # 1 millisecond + + # # Save results (image with detections) + # if save_img: + # if dataset.mode == 'image': + # cv2.imwrite(save_path, im0) + # else: # 'video' or 'stream' + # if vid_path[i] != save_path: # new video + # vid_path[i] = save_path + # if isinstance(vid_writer[i], cv2.VideoWriter): + # vid_writer[i].release() # release previous video writer + # if vid_cap: # video + # fps = vid_cap.get(cv2.CAP_PROP_FPS) + # w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + # h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + # else: # stream + # fps, w, h = 30, im0.shape[1], im0.shape[0] + # save_path = str( + # Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos + # vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, + # (w, h)) + # vid_writer[i].write(im0) + + # # # Print time (inference-only) + # # LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms") + + # # loop stop + # if self.stop_event.is_set(): + # if self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # pass + # else: + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # elif self.thread_value.timeout_status: + # self.thread_value.current_status = self.thread_value.STATUS_TIMEOUT + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # # Print results + # t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image + # LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t) + # if save_txt or save_img: + # s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' + # LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}") + # if update: + # strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning) + + # # timeout + # if self.thread_value.timeout_status: + # self._publish(self.thread_value.STATUS_TIMEOUT, self.thread_value.result) + # self.thread_value.timeout_status = False + + # # stop + # elif self.stop_event.is_set() and \ + # self.thread_value.current_status != self.thread_value.STATUS_COMPLETE: + # self._publish(self.thread_value.STATUS_STOP, self.thread_value.result) + + except Exception as e: + self._publish(self.thread_value.STATUS_ERROR + str(e), self.thread_value.result) + self._queue_empty() + + finally: + # dataloader 강제 종료 + while True: + if 'dataset' in locals(): + if isinstance(dataset, LoadStreams): + target = list(map(lambda x: x.is_alive(), dataset.threads)) + if True in target: + self.dataloader_event.set() + elif True not in target: + break + else: + break + else: + break + + if not self.timeout_event.is_set(): + self.timeout_event.set() + + self.timeout_event.clear() + self.stop_event.clear() + self.worker_event.set() + + def _sftp_upload(self): + try: + transprot = paramiko.Transport((self.ftp_info.ip,self.ftp_info.port)) + transprot.connect(username = self.ftp_info.id, password = self.ftp_info.pw) + sftp = paramiko.SFTPClient.from_transport(transprot) + + remotepath = self.ftp_info.location + os.sep + self.ftp_info.file_con_setup + '.jpg' + + #sftp.put(AI_CONST.FTP_CON_RESULT, remotepath) + + sftp.close() + transprot.close() + + return remotepath + except Exception as e: + return "" + + def _result_update(self, detected: dict): + """ + 모델에서 detect된 list와 기존 result list 를 비교하여 + 상태(current_status) 변경 + :param detected: 모델에서 detect된 list + """ + if detected: + self.thread_value.report_unit_result = [] + self.thread_value.report_unit_result = [None] * len(self.target_list) + + for key, value in detected.items(): + if self.thread_value.result[self.target_list.index(key)] is None: + if self.report_unit: + self.thread_value.report_unit_result[self.target_list.index(key)] = copy.deepcopy(value) + self.thread_value.current_status = self.thread_value.STATUS_NEW # NEW + self.thread_value.result[self.target_list.index(key)] = value + + if self.thread_value.result.count(None) == 0: + self.thread_value.current_status = self.thread_value.STATUS_COMPLETE # COMPLETE + + def _ai_rbi(self, detected): + # ri 정보 self.ri_info + return detected + + def _publish(self, status, result): + """ + mqtt publish + :param status: 탐지 상태 + :param result: 탐지장비 리스트 결과 + """ + #TODO(jwkim) topic 분리 + client.loop_start() + result_message = { + "datetime": self.thread_value.date_utill.date_str_micro_sec(), + "status": status, + "result": result + } + + if self.snapshot_path: + result_message[AI_CONST.DEMO_KEY_NAME_SNAPSHOT_SFTP] = self.snapshot_path + + client.publish(AI_CONST.MQTT_CON_TOPIC, json.dumps(result_message), 0) + # client.loop_stop() + + def _image_encoding(self, img_data): + """ + 이미지데이터(np.ndarray) 를 바이너리 데이터로 변환 + :param img_data: 이미지 데이터 + :return: base64 format + """ + from base64 import b64encode + _, JPEG = cv2.imencode(".jpg", img_data, [int(cv2.IMWRITE_JPEG_QUALITY), 80]) + + # Base64 encode + b64 = b64encode(JPEG) + + return b64.decode("utf-8") + + def _source_check(self, current_source): + """ + 현재 inference중인 video(혹은 stream)가 모델에 세팅된 video(혹은 stream)와 다를 시 예외 발생 + :param current_source: 현재 동작중인 소스 정보 + """ + _status = True + model_input_stream = '' + model_input_video = '' + for i in self.model_info.input_video: + if i.model == M.AEAIModelType.CON: + if i.connect_url.isnumeric() or i.connect_url.lower().startswith( + ('rtsp://', 'rtmp://', 'http://', 'https://')): + model_input_stream = clean_str(i.connect_url) + else: + model_input_video = (clean_str(i.connect_url)) + + if current_source.mode == 'stream': + current_url = current_source.sources + if not model_input_stream: + _status = False + elif model_input_stream != current_url[0]: + _status = False + + elif current_source.mode == 'video': + current_url = current_source.files[0] + if not model_input_video: + _status = False + elif str(Path(model_input_video).resolve()) != current_url: + _status = False + else: + # TODO(jwkim): stream, video 이외의 type 사용시 추가 + pass + + if not _status: + raise Exception(AI_CONST.SOURCE_CHANGED_MSG) + + def _queue_empty(self): + """ + queue_info 에 있는 데이터를 비운다. + """ + while True: + if self.queue_info.empty(): + break + else: + self.queue_info.get() + + + +class PPEDetect: + """ + 개인 보호구(PPE) 탐지 + """ + + def __init__(self, + request: M.AEAIModelPPEDetectReq, + engine_info: M.AEInfo, + thread_value, + yolo_argument, + worker_event: threading.Event, + stop_event: threading.Event, + timeout_event: threading.Event, + queue_info + ): + # 이미지 좌표계 + self.bbox_type = AI_CONST.BBOX_XYXY + + self.report_unit = request.report_unit + self.targets = request.targets + self.target_list = [] + for key, value in self.targets: + if value: + self.target_list.append(key) + + self.model_info = engine_info + self.thread_value = thread_value + + # ri + self.ri_info = request.ri or self.model_info.ppe_model_info.ri + + self.thread_value.result = [None] * len(self.target_list) + self.thread_value.report_unit_result = [None] * len(self.target_list) + + self.yolo_args = yolo_argument + + self.worker_event = worker_event + self.stop_event = stop_event + self.timeout_event = timeout_event + self.dataloader_event = threading.Event() + + self.queue_info = queue_info + + self.ftp_info = self.model_info.demo.ftp + self.snapshot_path = None + self.result_frame = [] + + @smart_inference_mode() + def run( + self, + weights=D2_ROOT / 'index_78.pt', # model path or triton URL + source=D2_ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam) + data=D2_ROOT / 'data/coco128.yaml', # dataset.yaml path + imgsz=(640, 640), # inference size (height, width) + conf_thres=0.25, # confidence threshold + iou_thres=0.45, # NMS IOU threshold + max_det=1000, # maximum detections per image + device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu + view_img=False, # show results + save_txt=False, # save results to *.txt + save_conf=False, # save confidences in --save-txt labels + save_crop=False, # save cropped prediction boxes + nosave=True, # do not save images/videos + classes=None, # filter by class: --class 0, or --class 0 2 3 + agnostic_nms=False, # class-agnostic NMS + augment=False, # augmented inference + visualize=False, # visualize features + update=False, # update all models + project=D2_ROOT / 'runs/detect', # save results to project/name + name='exp', # save results to project/name + exist_ok=False, # existing project/name ok, do not increment + line_thickness=3, # bounding box thickness (pixels) + hide_labels=False, # hide labels + hide_conf=False, # hide confidences + half=False, # use FP16 half-precision inference + dnn=False, # use OpenCV DNN for ONNX inference + vid_stride=1, # video frame-rate stride + ): + try: + if os.path.exists(AI_CONST.FTP_PPE_RESULT): + os.remove(AI_CONST.FTP_PPE_RESULT) + + source = str(source) + save_img = not nosave and not source.endswith('.txt') # save inference images + is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) + is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) + webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) + screenshot = source.lower().startswith('screen') + if is_url and is_file: + source = check_file(source) # download + + # Directories + save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run + # (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir + + # # Load model + # device = select_device(device) + # model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) + # stride, names, pt = model.stride, model.names, model.pt + # imgsz = check_img_size(imgsz, s=stride) # check image size + + # Run inference + self._publish(self.thread_value.STATUS_START, self.thread_value.result) + + # while True: + # Dataloader + bs = 1 # batch_size + if webcam: + # LoadStreams 한번만 실행 + if not 'dataset' in locals(): + # view_img = check_imshow(warn=True) # imshow + dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride, + event=self.dataloader_event) + bs = len(dataset) + elif screenshot: + dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt) + else: + dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) + vid_path, vid_writer = [None] * bs, [None] * bs + + model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup + seen, windows, dt = 0, [], (Profile(), Profile(), Profile()) + + for path, im, im0s, vid_cap, s in dataset: + + # # loop stop + # if self.stop_event.is_set(): + # if self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # pass + # else: + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # #TODO(JWKIM):ageing test 23-02-28 + # break + + # elif self.thread_value.timeout_status: + # self.thread_value.current_status = self.thread_value.STATUS_TIMEOUT + # break #TODO(JWKIM):ageing test 23-02-28 + + with dt[0]: + im = torch.from_numpy(im).to(model.device) + im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 + im /= 255 # 0 - 255 to 0.0 - 1.0 + if len(im.shape) == 3: + im = im[None] # expand for batch dim + + # Inference + with dt[1]: + visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False + pred = model(im, augment=augment, visualize=visualize) + + # NMS + with dt[2]: + pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) + + # Second-stage classifier (optional) + # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s) + + # Process predictions + for i, det in enumerate(pred): # per image + seen += 1 + if webcam: # batch_size >= 1 + p, im0, frame = path[i], im0s[i].copy(), dataset.count + s += f'{i}: ' + else: + p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0) + + # p = Path(p) # to Path + # save_path = str(save_dir / p.name) # im.jpg + # txt_path = str(save_dir / 'labels' / p.stem) + ( + # '' if dataset.mode == 'image' else f'_{frame}') # im.txt + # s += '%gx%g ' % im.shape[2:] # print string + # gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh + # imc = im0.copy() if save_crop else im0 # for save_crop + # annotator = Annotator(im0, line_width=line_thickness, example=str(names)) + + self.thread_value.current_status == self.thread_value.STATUS_COMPLETE + cv2.imwrite(AI_CONST.FTP_PPE_RESULT,copy.deepcopy(im0)) + if project_config.SFTP_UPLOAD: + self._sftp_upload() + demo_dict=[ + { + "name": 'safety_helmet_on', + "cid": 1, + "confidence": 0.7, + "bbox": [ 150, 150, 100, 200 ], + "image": None # nparray -> list + }, + { + "name": 'safety_gloves_work_on', + "cid": 25, + "confidence": 0.7, + "bbox": [ 150, 150, 100, 200 ], + "image": None # nparray -> list + }, + { + "name": 'safety_boots_on', + "cid": 25, + "confidence": 0.7, + "bbox": [ 150, 150, 100, 200 ], + "image": None # nparray -> list + } + ] + print(demo_dict) + self._publish(status = self.thread_value.STATUS_COMPLETE, result = demo_dict) + break + break + + # # FR START + # #TODO(jwkim): PPE + FR + # # imface = im0.copy() + # # self._fr_inference(imface) + + # # 딕셔너리 초기화 + # detected_object = {} + + # if len(det): + # # Rescale boxes from img_size to im0 size + # det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() + + # # Print results + # for c in det[:, 5].unique(): + # n = (det[:, 5] == c).sum() # detections per class + # s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string + + # # Write results + # for *xyxy, conf, cls in reversed(det): + # if save_txt: # Write to file + # xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view( + # -1).tolist() # normalized xywh + # line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + # with open(f'{txt_path}.txt', 'a') as f: + # f.write(('%g ' * len(line)).rstrip() % line + '\n') + + # if save_img or save_crop or view_img: # Add bbox to image + # c = int(cls) # integer class + # label = None if hide_labels else ( + # names[c] if hide_conf else f'{names[c]} {conf:.2f}') + # annotator.box_label(xyxy, label, color=colors(c, True)) + # if save_crop: + # save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', + # BGR=True) + + # c = int(cls) + + # if names[c] in self.target_list: + # # crop image 전송여부 + # if self.model_info.ppe_model_info.crop_images: + # # crop_image_data = save_one_box(xyxy, imc, BGR=True, save=False).tolist() + # crop_image_data = save_one_box(xyxy, imc, BGR=True, save=False) + # crop_image_data = self._image_encoding(crop_image_data) + # else: + # crop_image_data = None + + # if self.bbox_type == AI_CONST.BBOX_XYXY: + # result_bbox = [int(xyxy[1]), int(xyxy[0]), int(xyxy[3]), int(xyxy[2])] + # elif self.bbox_type == AI_CONST.BBOX_XYWH: + # result_bbox = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() + + # if [names[c]] == "safety_gloves_work_on": + # pass + + # detected_object["safety_gloves_work_on"] = { + # "name": "safety_gloves_work_on", + # "cid": 5, + # "confidence": round(conf.item(), 2), + # "bbox": result_bbox, + # "image": None # nparray -> list + # } + + # detected_object[names[c]] = { + # "name": names[c], + # "cid": c, + # "confidence": round(conf.item(), 2), + # "bbox": result_bbox, + # "image": crop_image_data # nparray -> list + # } + + # self.result_frame = copy.deepcopy(im0) + + # # input source check + # self._source_check(dataset) + + # # ri + # detected_object = self._ai_rbi(detected_object) + + # # update + # self._result_update(detected_object) + + # # publish new + # if self.thread_value.current_status == self.thread_value.STATUS_NEW: + # self._publish(self.thread_value.STATUS_NEW, self.thread_value.report_unit_result) + # self.thread_value.current_status = self.thread_value.STATUS_NONE + + # # publish complete + # elif self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # cv2.imwrite(AI_CONST.FTP_PPE_RESULT,self.result_frame) + + # if project_config.SFTP_UPLOAD: + # self.snapshot_path = self._sftp_upload() + # else: + # self.snapshot_path = None + + # self._publish(self.thread_value.STATUS_COMPLETE, self.thread_value.result) + # self.stop_event.set() + + # # if frame % 60 == 0 : + # # #TODO(JWKIM):ageing test 23-02-28 + # # self._publish("AGEING_TEST", self.thread_value.result) + + # # #TODO(JWKIM):ageing test 23-02-28 + # if self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # break + + # # loop stop + # if self.stop_event.is_set(): + # if self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # pass + # else: + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # elif self.thread_value.timeout_status: + # self.thread_value.current_status = self.thread_value.STATUS_TIMEOUT + # # #TODO(JWKIM):ageing test 23-02-28 + # break + + # # Stream results + # im0 = annotator.result() + # if view_img: + # if platform.system() == 'Linux' and p not in windows: + # windows.append(p) + # cv2.namedWindow(str(p), + # cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) + # cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) + # # cv2.imshow(str(p), im0) # imshow + # cv2.waitKey(1) # 1 millisecond + + # # Save results (image with detections) + # if save_img: + # if dataset.mode == 'image': + # cv2.imwrite(save_path, im0) + # else: # 'video' or 'stream' + # if vid_path[i] != save_path: # new video + # vid_path[i] = save_path + # if isinstance(vid_writer[i], cv2.VideoWriter): + # vid_writer[i].release() # release previous video writer + # if vid_cap: # video + # fps = vid_cap.get(cv2.CAP_PROP_FPS) + # w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + # h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + # else: # stream + # fps, w, h = 30, im0.shape[1], im0.shape[0] + # save_path = str( + # Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos + # vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, + # (w, h)) + # vid_writer[i].write(im0) + + # # Print time (inference-only) + # LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms") + + # # loop stop + # if self.stop_event.is_set(): + # if self.thread_value.current_status == self.thread_value.STATUS_COMPLETE: + # pass + # else: + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # elif self.thread_value.timeout_status: + # self.thread_value.current_status = self.thread_value.STATUS_TIMEOUT + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # # Print results + # t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image + # LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t) + # if save_txt or save_img: + # s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' + # LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}") + # if update: + # strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning) + + # # timeout + # if self.thread_value.timeout_status: + # self._publish(self.thread_value.STATUS_TIMEOUT, self.thread_value.result) + # self.thread_value.timeout_status = False + + # # stop + # elif self.stop_event.is_set() and \ + # self.thread_value.current_status != self.thread_value.STATUS_COMPLETE: + # self._publish(self.thread_value.STATUS_STOP, self.thread_value.result) + + except Exception as e: + self._publish(self.thread_value.STATUS_ERROR + str(e), self.thread_value.result) + self._queue_empty() + + finally: + # dataloader 강제 종료 + while True: + if 'dataset' in locals(): + if isinstance(dataset, LoadStreams): + target = list(map(lambda x: x.is_alive(), dataset.threads)) + if True in target: + self.dataloader_event.set() + elif True not in target: + break + else: + break + else: + break + + if not self.timeout_event.is_set(): + self.timeout_event.set() + + self.timeout_event.clear() + self.stop_event.clear() + self.worker_event.set() + + def _sftp_upload(self): + try: + transprot = paramiko.Transport((self.ftp_info.ip,self.ftp_info.port)) + transprot.connect(username = self.ftp_info.id, password = self.ftp_info.pw) + sftp = paramiko.SFTPClient.from_transport(transprot) + + remotepath = self.ftp_info.location + os.sep + self.ftp_info.file_ppe + '.jpg' + + #sftp.put(AI_CONST.FTP_PPE_RESULT, remotepath) + + sftp.close() + transprot.close() + + return remotepath + except Exception as e: + return "" + + def _fr_inference(self,current_frame): + + # start publish + # self._mqtt_publish(status=self.thread_value.STATUS_START) + + known_faces = self._fr_encoding_list() + + rgb_frame = current_frame[:, :, ::-1] + + face_locations = face_recognition.face_locations(rgb_frame) + face_encodings = face_recognition.face_encodings(rgb_frame, face_locations) + + for face_encoding in face_encodings: + + matched_name = [] # detect 완료된 명단 + + match = face_recognition.compare_faces(known_faces, face_encoding, tolerance=AI_CONST.FACE_EVOLUTION_DISTANCE) + + # #TODO(jwkim): 입력 이미지 변경 + # if self.fr_manager.fr_id_info != self.id_info: + # raise Exception(AI_CONST.IMG_CHANGED_MSG) + + face_distances = face_recognition.face_distance(known_faces, face_encoding) + best_match_index = np.argmin(face_distances) + + if face_distances[best_match_index] < AI_CONST.FACE_EVOLUTION_DISTANCE : + if match[best_match_index]: + matched_name.append(self.worker_names[best_match_index]) + print(self.worker_names[best_match_index]) + + def _fr_encoding_list(self): + """ + 안면인식 대상 이미지 등록 + :return: encoding 정보(list) + """ + # Load some sample pictures and learn how to recognize them. + worker1_image = face_recognition.load_image_file(AI_CONST.WORKER1_IMG_PATH) + no000001 = { + 'target_names' : 'no000001' , + 'encoding' : face_recognition.face_encodings(worker1_image)[0] + } + + worker2_image = face_recognition.load_image_file(AI_CONST.WORKER2_IMG_PATH) + no000002 = { + 'target_names' : 'no000002', + 'encoding' : face_recognition.face_encodings(worker2_image)[0] + } + + worker3_image = face_recognition.load_image_file(AI_CONST.WORKER3_IMG_PATH) + no000003 = { + 'target_names' : 'no000003', + 'encoding' : face_recognition.face_encodings(worker3_image)[0] + } + + encoding_list = [no000001,no000002,no000003] + result = [] + + for i in encoding_list: + if i['target_names'] in self.targets: + result.append(i["encoding"]) + + if not result: + raise Exception("invalid targets") + + return result + + def _result_update(self, detected: dict): + """ + 모델에서 detect된 list와 기존 result list 를 비교하여 + 상태(current_status) 변경 + :param detected: 모델에서 detect된 list + """ + if detected: + self.thread_value.report_unit_result = [] + self.thread_value.report_unit_result = [None] * len(list(dict(self.targets).keys())) + + for key, value in detected.items(): + if self.thread_value.result[self.target_list.index(key)] is None: + if self.report_unit: + self.thread_value.report_unit_result[self.target_list.index(key)] = copy.deepcopy(value) + self.thread_value.current_status = self.thread_value.STATUS_NEW # NEW + self.thread_value.result[self.target_list.index(key)] = value + + if self.thread_value.result.count(None) == 0: + self.thread_value.current_status = self.thread_value.STATUS_COMPLETE # COMPLETE + + def _ai_rbi(self, detected): + # ri 정보 self.ri_info + return detected + + def _publish(self, status, result): + """ + mqtt publish + :param status: 탐지 상태 + :param result: 탐지장비 리스트 결과 + """ + #TODO(jwkim) topic 분리 + client.loop_start() + result_message = { + "datetime": self.thread_value.date_utill.date_str_micro_sec(), + "status": status, + "result": result + } + if self.snapshot_path: + result_message[AI_CONST.DEMO_KEY_NAME_SNAPSHOT_SFTP] = self.snapshot_path + + print(result_message) + client.publish(AI_CONST.MQTT_PPE_TOPIC, json.dumps(result_message), 0) + # client.loop_stop() + + def _image_encoding(self, img_data): + """ + 이미지데이터(np.ndarray) 를 바이너리 데이터로 변환 + :param img_data: 이미지 데이터 + :return: base64 format + """ + from base64 import b64encode + _, JPEG = cv2.imencode(".jpg", img_data, [int(cv2.IMWRITE_JPEG_QUALITY), 80]) + + # Base64 encode + b64 = b64encode(JPEG) + + return b64.decode("utf-8") + + def _source_check(self, current_source): + """ + 현재 inference중인 video(혹은 stream)가 모델에 세팅된 video(혹은 stream)와 다를 시 예외 발생 + :param current_source: 현재 동작중인 소스 정보 + """ + _status = True + model_input_stream = '' + model_input_video = '' + for i in self.model_info.input_video: + if i.model == M.AEAIModelType.PPE: + if i.connect_url.isnumeric() or i.connect_url.lower().startswith( + ('rtsp://', 'rtmp://', 'http://', 'https://')): + model_input_stream = clean_str(i.connect_url) + else: + model_input_video = (clean_str(i.connect_url)) + + if current_source.mode == 'stream': + current_url = current_source.sources + if not model_input_stream: + _status = False + elif model_input_stream != current_url[0]: + _status = False + + elif current_source.mode == 'video': + current_url = current_source.files[0] + if not model_input_video: + _status = False + elif str(Path(model_input_video).resolve()) != current_url: + _status = False + else: + # TODO(jwkim): stream, video 이외의 type 사용시 추가 + pass + + if not _status: + raise Exception(AI_CONST.SOURCE_CHANGED_MSG) + + def _queue_empty(self): + """ + queue_info 에 있는 데이터를 비운다. + """ + while True: + if self.queue_info.empty(): + break + else: + self.queue_info.get() + + +class WDDetect: + def __init__(self, + request: M.AEAIModelRBIVideoReq, + engine_info: M.AEInfo, + thread_value, + yolo_argument, + worker_event: threading.Event, + stop_event: threading.Event, + queue_info: threading.Event + ): + # 이미지 좌표계 + self.bbox_type = AI_CONST.BBOX_XYXY + + self.thread_value = thread_value + + self.model_info = engine_info + self.yolo_args = yolo_argument + + # ri + self.ri_info = request.ri or self.model_info.wd_model_info.ri + + self.current_ri = 0 + self.detect_list = [] + + self.worker_event = worker_event + self.stop_event = stop_event + self.dataloader_event = threading.Event() + + self.queue_info = queue_info + + self.ftp_info = self.model_info.demo.ftp + self.snapshot_path = None + self.snapshot_path_bi = None + self.result_frame = [] + + @smart_inference_mode() + def run( + self, + weights=D2_ROOT / 'index_78.pt', # model path or triton URL + source=D2_ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam) + data=D2_ROOT / 'data/coco128.yaml', # dataset.yaml path + imgsz=(640, 640), # inference size (height, width) + conf_thres=0.25, # confidence threshold + iou_thres=0.45, # NMS IOU threshold + max_det=1000, # maximum detections per image + device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu + view_img=False, # show results + save_txt=False, # save results to *.txt + save_conf=False, # save confidences in --save-txt labels + save_crop=False, # save cropped prediction boxes + nosave=True, # do not save images/videos + classes=None, # filter by class: --class 0, or --class 0 2 3 + agnostic_nms=False, # class-agnostic NMS + augment=False, # augmented inference + visualize=False, # visualize features + update=False, # update all models + project=D2_ROOT / 'runs/detect', # save results to project/name + name='exp', # save results to project/name + exist_ok=False, # existing project/name ok, do not increment + line_thickness=3, # bounding box thickness (pixels) + hide_labels=False, # hide labels + hide_conf=False, # hide confidences + half=False, # use FP16 half-precision inference + dnn=False, # use OpenCV DNN for ONNX inference + vid_stride=1, # video frame-rate stride + ): + try: + self.snapshot_path = None + self.snapshot_path_bi = None + + if os.path.exists(AI_CONST.FTP_WD_RESULT): + os.remove(AI_CONST.FTP_WD_RESULT) + + # if os.path.exists(AI_CONST.FTP_BI_RESULT): + # os.remove(AI_CONST.FTP_BI_RESULT) + + # #demo + # DEMO.demo_wd_bi(1) + + source = str(source) + save_img = not nosave and not source.endswith('.txt') # save inference images + is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) + is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://')) + webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) + screenshot = source.lower().startswith('screen') + if is_url and is_file: + source = check_file(source) # download + + # Directories + save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run + # (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir + + # # Load model + # device = select_device(device) + # model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half) + # stride, names, pt = model.stride, model.names, model.pt + # imgsz = check_img_size(imgsz, s=stride) # check image size + + # Run inference + self._publish(topic = AI_CONST.MQTT_WD_TOPIC, status = self.thread_value.STATUS_START,raise_list = []) + + #while True: + # Dataloader + bs = 1 # batch_size + if webcam: + # stream 영상은 한번만 load + if not 'dataset' in locals(): + # view_img = check_imshow(warn=True) # imshow + dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride, + event=self.dataloader_event) + bs = len(dataset) + elif screenshot: + dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt) + else: + dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) + vid_path, vid_writer = [None] * bs, [None] * bs + + model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup + seen, windows, dt = 0, [], (Profile(), Profile(), Profile()) + for path, im, im0s, vid_cap, s in dataset: + + # # loop stop + # if self.stop_event.is_set(): + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # #TODO(JWKIM):ageing test 23-02-28 + # break + + with dt[0]: + im = torch.from_numpy(im).to(model.device) + im = im.half() if model.fp16 else im.float() # uint8 to fp16/32 + im /= 255 # 0 - 255 to 0.0 - 1.0 + if len(im.shape) == 3: + im = im[None] # expand for batch dim + + # Inference + with dt[1]: + visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False + pred = model(im, augment=augment, visualize=visualize) + + # NMS + with dt[2]: + pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) + + # Second-stage classifier (optional) + # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s) + + # Process predictions + for i, det in enumerate(pred): # per image + seen += 1 + if webcam: # batch_size >= 1 + p, im0, frame = path[i], im0s[i].copy(), dataset.count + s += f'{i}: ' + else: + p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0) + + # p = Path(p) # to Path + # save_path = str(save_dir / p.name) # im.jpg + # txt_path = str(save_dir / 'labels' / p.stem) + ( + # '' if dataset.mode == 'image' else f'_{frame}') # im.txt + # s += '%gx%g ' % im.shape[2:] # print string + # gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh + # imc = im0.copy() if save_crop else im0 # for save_crop + # annotator = Annotator(im0, line_width=line_thickness, example=str(names)) + + self.thread_value.current_status == self.thread_value.STATUS_DETECT + cv2.imwrite(AI_CONST.FTP_WD_RESULT,copy.deepcopy(im0)) + # if project_config.SFTP_UPLOAD: + self._sftp_upload(topic = AI_CONST.MQTT_WD_TOPIC) + demo_dict=[ + { + "name": 'safety_helmet_off', + "cid": 2, + "confidence": 0.7, + "bbox": [ 150, 150, 100, 200 ], + "image": None # nparray -> list + } + ] + print(demo_dict) + self._publish(topic=AI_CONST.MQTT_WD_TOPIC, status = self.thread_value.STATUS_DETECT, raise_list = demo_dict) + break + break + + # # detect list + # self.detect_list = [] + + # if len(det): + # # Rescale boxes from img_size to im0 size + # det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() + + # # Print results + # for c in det[:, 5].unique(): + # n = (det[:, 5] == c).sum() # detections per class + # s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string + + # # Write results + # for *xyxy, conf, cls in reversed(det): + # if save_txt: # Write to file + # xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view( + # -1).tolist() # normalized xywh + # line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format + # with open(f'{txt_path}.txt', 'a') as f: + # f.write(('%g ' * len(line)).rstrip() % line + '\n') + + # if save_img or save_crop or view_img: # Add bbox to image + # c = int(cls) # integer class + # label = None if hide_labels else ( + # names[c] if hide_conf else f'{names[c]} {conf:.2f}') + # annotator.box_label(xyxy, label, color=colors(c, True)) + # if save_crop: + # save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', + # BGR=True) + + # # jwkim + # c = int(cls) + + # # crop image 전송여부 + # if self.model_info.wd_model_info.crop_images: + # # crop_image_data = save_one_box(xyxy, imc, BGR=True, save=False).tolist() + # crop_image_data = save_one_box(xyxy, imc, BGR=True, save=False) + # crop_image_data = self._image_encoding(crop_image_data) + # else: + # crop_image_data = None + + # if self.bbox_type == AI_CONST.BBOX_XYXY: + # result_bbox = [int(xyxy[1]), int(xyxy[0]), int(xyxy[3]), int(xyxy[2])] + # elif self.bbox_type == AI_CONST.BBOX_XYWH: + # result_bbox = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() + + # detected_object = { + # "name": names[c], + # "cid": c, + # "confidence": round(conf.item(), 2), + # "bbox": result_bbox, + # "image": crop_image_data # nparray -> list + # } + # if detected_object["cid"] in AI_CONST.OFF_CLASS_LIST: + # # off class 만 저장 + # self.detect_list.append(detected_object) + + # self.result_frame = copy.deepcopy(im0) + + # # input source check + # self._source_check(dataset) + + # # ri update + # self._ai_rbi(dataset, self.detect_list) + + # #demo + # if DEMO.DEMO_WD_BI_CONST == 2: + # if not os.path.exists(AI_CONST.FTP_BI_RESULT): + # cv2.imwrite(AI_CONST.FTP_BI_RESULT,self.result_frame) + # if project_config.UPLOAD: + # self.snapshot_path_bi = self._sftp_upload(topic = AI_CONST.MQTT_BI_TOPIC) + # self._publish_bi_demo(topic = AI_CONST.MQTT_BI_TOPIC, status = self.thread_value.STATUS_DETECT) + # self.snapshot_path_bi = None + + # if self.detect_list and self.current_ri > self.ri_info.work_define_ri : + # self.thread_value.current_status = self.thread_value.STATUS_DETECT + + # if not os.path.exists(AI_CONST.FTP_WD_RESULT): + # cv2.imwrite(AI_CONST.FTP_WD_RESULT,self.result_frame) + # if project_config.UPLOAD: + # self.snapshot_path = self._sftp_upload(topic = AI_CONST.MQTT_WD_TOPIC) + # self._publish(topic = AI_CONST.MQTT_WD_TOPIC, status = self.thread_value.STATUS_DETECT) + + # self.snapshot_path = None + + # # loop stop + # if self.stop_event.is_set(): + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # # Stream results + # im0 = annotator.result() + # if view_img: + # if platform.system() == 'Linux' and p not in windows: + # windows.append(p) + # cv2.namedWindow(str(p), + # cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) + # cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0]) + # # cv2.imshow(str(p), im0) # imshow + # cv2.waitKey(1) # 1 millisecond + + # # Save results (image with detections) + # if save_img: + # if dataset.mode == 'image': + # cv2.imwrite(save_path, im0) + # else: # 'video' or 'stream' + # if vid_path[i] != save_path: # new video + # vid_path[i] = save_path + # if isinstance(vid_writer[i], cv2.VideoWriter): + # vid_writer[i].release() # release previous video writer + # if vid_cap: # video + # fps = vid_cap.get(cv2.CAP_PROP_FPS) + # w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + # h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + # else: # stream + # fps, w, h = 30, im0.shape[1], im0.shape[0] + # save_path = str( + # Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos + # vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, + # (w, h)) + # vid_writer[i].write(im0) + + # # Print time (inference-only) + # # LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms") + + # # loop stop + # if self.stop_event.is_set(): + # self.thread_value.current_status = self.thread_value.STATUS_STOP + # # TODO(JWKIM):ageing test 23-02-28 + # break + + # # Print results + # t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image + # LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t) + # if save_txt or save_img: + # s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' + # LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}") + # if update: + # strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning) + + # if self.stop_event.is_set(): + # self._publish(topic = AI_CONST.MQTT_WD_TOPIC, status = self.thread_value.STATUS_STOP) + + except Exception as e: + self._publish(topic = AI_CONST.MQTT_WD_TOPIC, status = self.thread_value.STATUS_ERROR + str(e), raise_list=self.detect_list) + self._queue_empty() + + finally: + # dataloader 강제 종료 + while True: + if 'dataset' in locals(): + if isinstance(dataset, LoadStreams): + target = list(map(lambda x: x.is_alive(), dataset.threads)) + if True in target: + self.dataloader_event.set() + elif True not in target: + break + else: + break + else: + break + + #demo + DEMO.demo_wd_bi(0) + + self.stop_event.clear() + self.worker_event.set() + + def _sftp_upload(self, topic): + try: + transprot = paramiko.Transport((self.ftp_info.ip,self.ftp_info.port)) + transprot.connect(username = self.ftp_info.id, password = self.ftp_info.pw) + sftp = paramiko.SFTPClient.from_transport(transprot) + + if topic == AI_CONST.MQTT_WD_TOPIC: + print('wd') + remotepath = self.ftp_info.location + os.sep + self.ftp_info.file_wd + '.jpg' + #sftp.put(AI_CONST.FTP_WD_RESULT, remotepath) + + elif topic == AI_CONST.MQTT_BI_TOPIC: + remotepath = self.ftp_info.location + os.sep + self.ftp_info.file_bi + '.jpg' + #sftp.put(AI_CONST.FTP_BI_RESULT, remotepath) + + sftp.close() + transprot.close() + + return remotepath + except Exception as e: + return "" + + def _ai_rbi(self, dataloader, detect_list): + """ + ri 수치 업데이트 1.0v + stream : detect_list(off class)가 하나라도 있을시 + video file : n프레임(혹은 카운트)마다 위험성 탐지 + + :param dataloader: dataloader + :param detect_list: detect된 object + """ + # TODO(jwkim) dataloader는 ri가 정해지면 삭제 예정 + COUNT = AI_CONST.WD_FRAME_COUNT + + result = [] + + import random + + if detect_list: + for i in detect_list: + if i["cid"] == 2 or i["cid"] == 4: + result.append(i["cid"]) + if AI_CONST.OFF_TRIGGER_CLASS_LIST[0] in result and AI_CONST.OFF_TRIGGER_CLASS_LIST[1] in result: + if dataloader.mode == 'stream' : + if dataloader.count % COUNT == 0: + self.current_ri = round(random.uniform(self.ri_info.work_define_ri, 1), 2) + elif dataloader.frame % COUNT == 0: + self.current_ri = round(random.uniform(self.ri_info.work_define_ri, 1), 2) + else: + self.current_ri = round(random.uniform(0, self.ri_info.work_define_ri), 2) + + def _publish(self, topic, status, raise_list=None): + """ + mqtt publish + :param status: 위험성탐지 상태 + :param raise_list: 예외발생시 detect_list + """ + client.loop_start() + result_message = { + "datetime": self.thread_value.date_utill.date_str_micro_sec(), + "status": status, + "result": { + 'construction_type': 'D4', + 'procedure_no': 1, + 'procedure_ri': 0.7, + 'ri': 0.7, + 'detect_list': raise_list + } + } + + if self.snapshot_path: + result_message[AI_CONST.DEMO_KEY_NAME_SNAPSHOT_SFTP] = self.snapshot_path + + client.publish(topic, json.dumps(result_message), 0) + # client.loop_stop() + + def _publish_bi_demo(self, topic, status, raise_list=None): + """ + mqtt publish + :param status: 위험성탐지 상태 + :param raise_list: 예외발생시 detect_list + """ + client.loop_start() + result_message = { + "datetime": self.thread_value.date_utill.date_str_micro_sec(), + "status": status, + "result": { + 'construction_type': self.ri_info.construction_code, + 'procedure_no': self.ri_info.work_no, + 'procedure_ri': self.ri_info.work_define_ri, + 'ri': self.current_ri, + 'detect_list': raise_list if raise_list else self.detect_list + } + } + + if self.snapshot_path_bi: + result_message[AI_CONST.DEMO_KEY_NAME_SNAPSHOT_SFTP] = self.snapshot_path_bi + + client.publish(topic, json.dumps(result_message), 0) + # client.loop_stop() + + def _image_encoding(self, img_data): + """ + 이미지데이터(np.ndarray) 를 바이너리 데이터로 변환 + :param img_data: 이미지 데이터 + :return: base64 format + """ + from base64 import b64encode + _, JPEG = cv2.imencode(".jpg", img_data, [int(cv2.IMWRITE_JPEG_QUALITY), 80]) + + # Base64 encode + b64 = b64encode(JPEG) + + return b64.decode("utf-8") + + def _source_check(self, current_source): + """ + 현재 inference중인 video(혹은 stream)가 모델에 세팅된 video(혹은 stream)와 다를 시 예외 발생 + :param current_source: 현재 동작중인 소스 정보 + """ + _status = True + model_input_stream = [] + model_input_video = [] + for i in self.model_info.input_video: + if i.model == M.AEAIModelType.WORK: + if i.connect_url.isnumeric() or i.connect_url.lower().startswith( + ('rtsp://', 'rtmp://', 'http://', 'https://')): + model_input_stream.append(clean_str(i.connect_url)) + else: + model_input_video.append(clean_str(i.connect_url)) + + model_input_stream.sort() + + if current_source.mode == 'stream': + current_url = current_source.sources + current_url.sort() + if not model_input_stream: + _status = False + elif model_input_stream != current_url: + _status = False + + elif current_source.mode == 'video': + current_url = current_source.files[0] + if not model_input_video: + _status = False + elif str(Path(model_input_video[0]).resolve()) != current_url: + _status = False + else: + # TODO(jwkim): stream, video 이외의 type 사용시 추가 + pass + + if not _status: + raise Exception(AI_CONST.SOURCE_CHANGED_MSG) + + def _queue_empty(self): + """ + queue_info 에 있는 데이터를 비운다. + """ + while True: + if self.queue_info.empty(): + break + else: + self.queue_info.get() + + +if __name__ == "__main__": + pass diff --git a/DL/index_78.pt b/DL/index_78.pt new file mode 100644 index 0000000..83c999c Binary files /dev/null and b/DL/index_78.pt differ diff --git a/DL/wd.streams b/DL/wd.streams new file mode 100644 index 0000000..11e64b8 --- /dev/null +++ b/DL/wd.streams @@ -0,0 +1 @@ +rtsp://admin:admin1263!@10.20.10.99:28554/onvif/media?profile=Profile2 diff --git a/README.md b/README.md index 6cded48..8f85119 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,36 @@ -# AI_ENGINE_0.5 - -AI_ENGINE_0.5 \ No newline at end of file +# MG_ENGINE_SIM + +위험성 평가기반 작업자 안전관리 및 감독 협조 기술 개발 (시뮬레이터) + +### System Requirements + - ubuntu 20.04 + - python 3.8 + - python packages : requirements.txt + +### Pytorch, Torchvision install +- conda(CUDA 11.7 기준) + ``` + conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia + ``` +- pip + ``` + pip3 install torch torchvision torchaudio + ``` +- ARM architecture + + JetPack 5.0.2(L4T R35.1.0) 기준 + 1. pytorch(version 1.12) + ``` + wget https://developer.download.nvidia.com/compute/redist/jp/v50/pytorch/torch-1.12.0a0+2c916ef.nv22.3-cp38-cp38-linux_aarch64.whl + apt-get install python3-pip libopenblas-base libopenmpi-dev libomp-dev + pip3 install Cython + pip3 install torch-1.12.0a0+2c916ef.nv22.3-cp38-cp38-linux_aarch64.whl + ``` + 2. torchvision(version 0.13) + ``` + sudo apt-get install libjpeg-dev zlib1g-dev libpython3-dev libavcodec-dev libavformat-dev libswscale-dev + git clone --branch release/0.13 https://github.com/pytorch/vision torchvision + cd torchvision + export BUILD_VERSION=0.13.0 + python3 setup.py install --user + ``` \ No newline at end of file diff --git a/REST_AI_ENGINE_CONTROL/.gitignore b/REST_AI_ENGINE_CONTROL/.gitignore new file mode 100644 index 0000000..7d95706 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/.gitignore @@ -0,0 +1,142 @@ +# ---> Python +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# YOLOv5 +/DL/OD \ No newline at end of file diff --git a/REST_AI_ENGINE_CONTROL/Dockerfile b/REST_AI_ENGINE_CONTROL/Dockerfile new file mode 100644 index 0000000..1058211 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/Dockerfile @@ -0,0 +1,27 @@ +# ------------------------------------------------------------------------------ +# Base image +# ------------------------------------------------------------------------------ +FROM python:3.9-slim + +# ------------------------------------------------------------------------------ +# Informations +# ------------------------------------------------------------------------------ +LABEL maintainer="hsj100 " +LABEL title="REST_AI_ENGINE_CONTROL" +LABEL description="Rest API Server with Fast API" + +# ------------------------------------------------------------------------------ +# Install dependencies +# ------------------------------------------------------------------------------ +WORKDIR /FAST_API +COPY ./requirements.txt /FAST_API/requirements.txt +RUN apt update > /dev/null && \ + apt install -y build-essential && \ + pip install --no-cache-dir --upgrade -r /FAST_API/requirements.txt + +# ------------------------------------------------------------------------------ +# Source +# ------------------------------------------------------------------------------ +COPY ./app /FAST_API/app + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "50220"] diff --git a/REST_AI_ENGINE_CONTROL/README.md b/REST_AI_ENGINE_CONTROL/README.md new file mode 100644 index 0000000..3384f8b --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/README.md @@ -0,0 +1,4 @@ +# D2.KEPCO.AI_RBI.REST_AI_ENGINE_CONTROL + +REST SERVER +AI ENGINE 제어 \ No newline at end of file diff --git a/REST_AI_ENGINE_CONTROL/app/api_request_sample.py b/REST_AI_ENGINE_CONTROL/app/api_request_sample.py new file mode 100644 index 0000000..bc184fd --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/api_request_sample.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +""" +@file: api_request_sample.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: test api key + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +import base64 +import hmac +from datetime import datetime, timedelta + +import requests + + +def parse_params_to_str(params): + url = "?" + for key, value in params.items(): + url = url + str(key) + '=' + str(value) + '&' + return url[1:-1] + + +def hash_string(qs, secret_key): + mac = hmac.new(bytes(secret_key, encoding='utf8'), bytes(qs, encoding='utf-8'), digestmod='sha256') + d = mac.digest() + validating_secret = str(base64.b64encode(d).decode('utf-8')) + return validating_secret + + +def sample_request(): + access_key = 'c0883231-4aa9-4a1f-a77b-3ef250af-e449-42e9-856a-b3ada17c426b' + secret_key = 'QhOaeXTAAkW6yWt31jWDeERkBsZ3X4UmPds656YD' + cur_time = datetime.utcnow()+timedelta(hours=9) + cur_timestamp = int(cur_time.timestamp()) + qs = dict(key=access_key, timestamp=cur_timestamp) + header_secret = hash_string(parse_params_to_str(qs), secret_key) + + url = f'http://127.0.0.1:8080/api/services?{parse_params_to_str(qs)}' + res = requests.get(url, headers=dict(secret=header_secret)) + return res + + +print(sample_request().json()) diff --git a/REST_AI_ENGINE_CONTROL/app/common/config.py b/REST_AI_ENGINE_CONTROL/app/common/config.py new file mode 100644 index 0000000..6ca9bd0 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/common/config.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- +""" +@File: config.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: 실행 환경 설정 + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from dataclasses import dataclass +from os import path, environ + +from app.common import consts +from app.models import UserInfo + + +base_dir = path.dirname(path.dirname(path.dirname(path.abspath(__file__)))) + + +@dataclass +class Config: + """ + 기본 Configuration + """ + BASE_DIR: str = base_dir + DB_POOL_RECYCLE: int = 900 + DB_ECHO: bool = True + DEBUG: bool = False + TEST_MODE: bool = False + DEV_TEST_CONNECT_ACCOUNT: str = None + + # NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행) + SERVICE_AUTH_API_KEY: bool = False + + DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}?charset={consts.DB_CHARSET}') + REST_SERVER_PORT = consts.REST_SERVER_PORT + + SW_TITLE = consts.SW_TITLE + SW_VERSION = consts.SW_VERSION + SW_DESCRIPTION = consts.SW_DESCRIPTION + TERMS_OF_SERVICE = consts.TERMS_OF_SERVICE + CONTEACT = consts.CONTEACT + LICENSE_INFO = consts.LICENSE_INFO + + GLOBAL_TOKEN = consts.ADMIN_INIT_ACCOUNT_INFO.connect_token + + +@dataclass +class LocalConfig(Config): + TRUSTED_HOSTS = ['*'] + ALLOW_SITE = ['*'] + DEBUG: bool = False + + +@dataclass +class ProdConfig(Config): + TRUSTED_HOSTS = ['*'] + ALLOW_SITE = ['*'] + + +@dataclass +class TestConfig(Config): + DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}_test?charset={consts.DB_CHARSET}') + TRUSTED_HOSTS = ['*'] + ALLOW_SITE = ['*'] + TEST_MODE: bool = True + + +@dataclass +class DevConfig(Config): + TRUSTED_HOSTS = ['*'] + ALLOW_SITE = ['*'] + DEBUG: bool = True + DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}_dev?charset={consts.DB_CHARSET}') + REST_SERVER_PORT = consts.REST_SERVER_PORT + 1 + + SW_TITLE = '[Dev] ' + consts.SW_TITLE + + +@dataclass +class MyConfig(Config): + TRUSTED_HOSTS = ['*'] + ALLOW_SITE = ['*'] + DEBUG: bool = True + DB_URL: str = environ.get('DB_URL', f'mysql+pymysql://{consts.DB_USER_ID}:{consts.DB_USER_PW}@{consts.DB_ADDRESS}:{consts.DB_PORT}/{consts.DB_NAME}_my?charset={consts.DB_CHARSET}') + REST_SERVER_PORT = consts.REST_SERVER_PORT + 2 + + # NOTE(hsj100): DEV_TEST_CONNECT_ACCOUNT + DEV_TEST_CONNECT_ACCOUNT: UserInfo = UserInfo(**consts.ADMIN_INIT_ACCOUNT_INFO.get_dict()) + # DEV_TEST_CONNECT_ACCOUNT: str = None + + # NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행) + SERVICE_AUTH_API_KEY: bool = False + + SW_TITLE = '[My] ' + consts.SW_TITLE + + # GLOBAL_TOKEN = None + + +def conf(): + """ + 환경 불러오기 + :return: + """ + config = dict(prod=ProdConfig, local=LocalConfig, test=TestConfig, dev=DevConfig, my=MyConfig) + return config[environ.get('API_ENV', 'local')]() + return config[environ.get('API_ENV', 'dev')]() + return config[environ.get('API_ENV', 'my')]() + return config[environ.get('API_ENV', 'test')]() diff --git a/REST_AI_ENGINE_CONTROL/app/common/consts.py b/REST_AI_ENGINE_CONTROL/app/common/consts.py new file mode 100644 index 0000000..9b1f844 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/common/consts.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +""" +@File: consts.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: 상수 선언 + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))))) + +import project_config + +# SUPPORT PROJECT +SUPPORT_PROJECT_BASIC = 'PROJECT_BASIC' +SUPPORT_PROJECT_MEDICAL_METAVERSE = 'MEDICAL METAVERSE' +SUPPORT_PROJECT_SW_LICENSE_MANAGER = 'AI-KEPCO REST SERVER' + +PROJECT_NAME = '[SIM] AI Engine Control' +SW_TITLE= f'{PROJECT_NAME} - REST API' +SW_VERSION = project_config.PROJECT_VERSION +SW_DESCRIPTION = f''' +# 위험성 평가기반 작업자 안전관리 및 감독 협조 기술 개발 + +## KEPCO AI RBI(risk based inspection) + +[Simulator manual v0.2](http://219.250.188.208/docs/AI_Engine_Simuilator_0.2_20230206.pptx)\n +[MQTT viewer](http://219.250.188.208:50777/)\n + +## API 이용법 + - **사용자 접속**시 토큰정보(***Authorization*** 항목) 획득 (이후 API 이용가능) + - 수신받은 토큰정보(***Authorization*** 항목)를 ***Request Header*** 에 포함해서 기타 API 사용 + - 개별 API 설명과 Request/Response schema 참조 + +**시스템에서 사용하는 날짜형식**\n + YYYY-mm-ddTHH:MM:SS + YYYY-mm-dd HH:MM:SS + +**검색/변경 API 사용시, 항목값 null 사용방법(string 항목에서만 유효)** + - 해당 항목의 값을 문자열("null")로 사용**\n + + "mac": "null" + - 해당 항목의 값을 null로 사용할 경우 처리 대상에서 제외됨\n + + "mac": null <= 처리 대상에서 제외 +![This image does not work](http://219.250.188.208/image/mqtt_layer.jpg)\n +''' +TERMS_OF_SERVICE = 'http://www.daooldns.co.kr' +CONTEACT={ + 'name': 'DAOOLDNS (주)다울디엔에스', + 'url': 'http://www.daooldns.co.kr', + 'email': 'marketing@daooldns.co.kr' +} +LICENSE_INFO = { + 'name': 'Copyright by DAOOLDNS', + 'url': 'http://www.daooldns.co.kr' +} + +REST_SERVER_PORT = 50770 +DEFAULT_USER_ACCOUNT_PW = '1234' + + +class AdminInfo: + def __init__(self): + self.id: int = 1 + self.user_grade: str = 'admin' + self.account: str = 'a2d2_lc_manager@naver.com' + # self.pw: str = '$2b$12$PklBvVXdLhOQnIiNanlnIu.DJh5MspRARVChJQfFu1qg35vBoIuX2' # bcrypy hash + self.pw: str = 'E563ZFt+yJL8YY5yYlYyk602MSscPP2SCCD8UtXXpMI=' # AESCryptoCBC encrypt + self.name: str = 'administrator' + self.email: str = 'a2d2_lc_manager@naver.com' + self.email_pw: str = 'gAAAAABioV5NucuS9nQugZJnz-KjVG_FGnaowB9KAfhOoWjjiQ4jGLuYJh4Qe94mT_lCm6m3HhuOJqUeOgjppwREDpIQYzrUXA==' + self.address: str = '대구광역시 동구 동촌로351 에이스빌딩 4F' + self.phone_number: str = '053-384-3010' + self.connect_token: str = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MSwiYWNjb3VudCI6ImEyZDJfbGNfbWFuYWdlckBuYXZlci5jb20iLCJuYW1lIjoiYWRtaW5pc3RyYXRvciIsInBob25lX251bWJlciI6IjA1My0zODQtMzAxMCIsInByb2ZpbGVfaW1nIjpudWxsLCJhY2NvdW50X3R5cGUiOiJlbWFpbCJ9.SlQSCfAof1bv2YxmW2DO4dIBrbHLg1jPO3AJsX6xKbw' + + def get_dict(self): + info = {} + for k, v in self.__dict__.items(): + if type(v) is tuple: + info[k] = v[0] + else: + info[k] = v + return info + + +ADMIN_INIT_ACCOUNT_INFO = AdminInfo() + +FERNET_SECRET_KEY = b'wQjpSYkmc4kX8MaAovk1NIHF02R2wZX760eeBTeIHW4=' +AES_CBC_PUBLIC_KEY = b'daooldns12345678' +AES_CBC_IV = b'daooldns12345678' + +COOKIES_AUTH = 'Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpZCI6MTQsImVtYWlsIjoia29hbGFAZGluZ3JyLmNvbSIsIm5hbWUiOm51bGwsInBob25lX251bWJlciI6bnVsbCwicHJvZmlsZV9pbWciOm51bGwsInNuc190eXBlIjpudWxsfQ.4vgrFvxgH8odoXMvV70BBqyqXOFa2NDQtzYkGywhV48' + +JWT_SECRET = 'ABCD1234!' +JWT_ALGORITHM = 'HS256' +EXCEPT_PATH_LIST = ['/', '/openapi.json'] +EXCEPT_PATH_REGEX = '^(/docs|/redoc' +\ + '|/api/dev' +\ + '|/api/services' +\ + ')' +MAX_API_KEY = 3 +MAX_API_WHITELIST = 10 + +NUM_RETRY_UUID_GEN = 3 + +if project_config.CONFIG == project_config.CONFIG_AISERVER or \ + project_config.CONFIG == project_config.CONFIG_MG: + + # DATABASE(MG/safe6) + DB_ADDRESS = 'localhost' + DB_PORT = '50701' + DB_USER_ID = 'aikepco' + DB_USER_PW = '!ekdnfeldpsdptm1' + DB_NAME = 'AI_ENGINE_CONTROL' + DB_CHARSET = 'utf8mb4' + +else: + # DATABASE(fermat) + DB_ADDRESS = 'localhost' + DB_PORT = '3306' + DB_USER_ID = 'root' + DB_USER_PW = '1234' + DB_NAME = 'AI_ENGINE_CONTROL' + DB_CHARSET = 'utf8mb4' + +# MAIL +SMTP_HOST = 'smtp.gmail.com' +SMTP_PORT = 587 +SMTP_HOST = 'smtp.naver.com' +SMTP_PORT = 587 +MAIL_REG_TITLE = f'{PROJECT_NAME} - Registration' +MAIL_REG_CONTENTS = ''' +안녕하세요. +[DAOOLDNS] AI-KEPCO REST SERVER 입니다. + + +Account: {} +Number of License: {} + + +감사합니다. + +''' + +AI_ENGINE_FR_DETECT_TIME_MIN = 5 + diff --git a/REST_AI_ENGINE_CONTROL/app/database/conn.py b/REST_AI_ENGINE_CONTROL/app/database/conn.py new file mode 100644 index 0000000..7133951 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/database/conn.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +""" +@File: conn.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: DB 정보 + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from fastapi import FastAPI +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +import logging + + +def _database_exist(engine, schema_name): + query = f'SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = "{schema_name}"' + with engine.connect() as conn: + result_proxy = conn.execute(query) + result = result_proxy.scalar() + return bool(result) + + +def _drop_database(engine, schema_name): + with engine.connect() as conn: + conn.execute(f'DROP DATABASE {schema_name};') + + +def _create_database(engine, schema_name): + with engine.connect() as conn: + conn.execute(f'CREATE DATABASE {schema_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_bin;') + + +class SQLAlchemy: + def __init__(self, app: FastAPI = None, **kwargs): + self._engine = None + self._session = None + if app is not None: + self.init_app(app=app, **kwargs) + + def init_app(self, app: FastAPI, **kwargs): + """ + DB 초기화 함수 + :param app: FastAPI 인스턴스 + :param kwargs: + :return: + """ + database_url = kwargs.get('DB_URL') + pool_recycle = kwargs.setdefault('DB_POOL_RECYCLE', 900) + is_testing = kwargs.setdefault('TEST_MODE', False) + echo = kwargs.setdefault('DB_ECHO', True) + + self._engine = create_engine( + database_url, + echo=echo, + pool_recycle=pool_recycle, + pool_pre_ping=True, + ) + if is_testing: # create schema + db_url = self._engine.url + if db_url.host != 'localhost': + raise Exception('db host must be \'localhost\' in test environment') + except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}' + schema_name = db_url.database + temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True) + if _database_exist(temp_engine, schema_name): + _drop_database(temp_engine, schema_name) + _create_database(temp_engine, schema_name) + temp_engine.dispose() + else: + db_url = self._engine.url + except_schema_db_url = f'{db_url.drivername}://{db_url.username}:{db_url.password}@{db_url.host}:{db_url.port}' + schema_name = db_url.database + temp_engine = create_engine(except_schema_db_url, echo=echo, pool_recycle=pool_recycle, pool_pre_ping=True) + if not _database_exist(temp_engine, schema_name): + _create_database(temp_engine, schema_name) + Base.metadata.create_all(db.engine) + temp_engine.dispose() + + self._session = sessionmaker(autocommit=False, autoflush=False, bind=self._engine) + + # NOTE(hsj100): ADMINISTRATOR + create_admin(self._session) + + @app.on_event('startup') + def startup(): + self._engine.connect() + logging.info('DB connected.') + + @app.on_event('shutdown') + def shutdown(): + self._session.close_all() + self._engine.dispose() + logging.info('DB disconnected') + + def get_db(self): + """ + 요청마다 DB 세션 유지 함수 + :return: + """ + if self._session is None: + raise Exception('must be called \'init_app\'') + db_session = None + try: + db_session = self._session() + yield db_session + finally: + db_session.close() + + @property + def session(self): + return self.get_db + + @property + def engine(self): + return self._engine + + +db = SQLAlchemy() +Base = declarative_base() + + +# NOTE(hsj100): ADMINISTRATOR +def create_admin(db_session): + import bcrypt + from app.database.schema import Users + from app.common.consts import ADMIN_INIT_ACCOUNT_INFO + + session = db_session() + + if not session: + raise Exception('cat`t create account of admin') + + if Users.get(account=ADMIN_INIT_ACCOUNT_INFO.account): + return + + admin = {**ADMIN_INIT_ACCOUNT_INFO.get_dict()} + + Users.create(session=session, auto_commit=True, **admin) diff --git a/REST_AI_ENGINE_CONTROL/app/database/crud.py b/REST_AI_ENGINE_CONTROL/app/database/crud.py new file mode 100644 index 0000000..e94607d --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/database/crud.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +""" +@File: crud.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: DB CRUD + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +import math +from datetime import datetime +from dateutil.relativedelta import relativedelta +from sqlalchemy import func, desc + +from fastapi import APIRouter, Depends, Body +from sqlalchemy.orm import Session +from app import models as M +from app.database.conn import Base, db +from app.database.schema import Users, UserLog +from app.utils.extra import query_to_groupby, query_to_groupby_date + + +def request_parser(request_data: dict = None) -> dict: + """ + request information -> dict + + :param request_data: + :return: + """ + result_dict = dict() + if not request_data: + return result_dict + for key, val in request_data.items(): + if val is not None: + result_dict[key] = val if val != 'null' else None + return result_dict + + +def dict_to_filter_fmt(dict_data, get_attributes_callback=None): + """ + dict -> sqlalchemy filter (criterion: sql expression) + + :param dict_data: + :param get_attributes_callback: + :return: + """ + if get_attributes_callback is None: + raise Exception('invalid get_attributes_callback') + + criterion = list() + + for key, val in dict_data.items(): + key = key.split('__') + if len(key) > 2: + raise Exception('length of split(key) should be no more than 2.') + + key_length = len(key) + col = get_attributes_callback(key[0]) + + if col is None: + continue + + if key_length == 1: + criterion.append((col == val)) + elif key_length == 2 and key[1] == 'gt': + criterion.append((col > val)) + elif key_length == 2 and key[1] == 'gte': + criterion.append((col >= val)) + elif key_length == 2 and key[1] == 'lt': + criterion.append((col < val)) + elif key_length == 2 and key[1] == 'lte': + criterion.append((col <= val)) + elif key_length == 2 and key[1] == 'in': + criterion.append((col.in_(val))) + elif key_length == 2 and key[1] == 'like': + criterion.append((col.like(val))) + + return criterion + + +async def table_select(accessor_info, target_table, request_body_info, response_model, response_model_data): + """ + table_read + """ + try: + # parameter + if not accessor_info: + raise Exception('invalid accessor') + + if not target_table: + raise Exception(f'invalid table_name:{target_table}') + + # if not request_body_info: + # raise Exception('invalid request_body_info') + + if not response_model: + raise Exception('invalid response_model') + + if not response_model_data: + raise Exception('invalid response_model_data') + + # paging - request + paging_request = None + if request_body_info: + if hasattr(request_body_info, 'paging'): + paging_request = request_body_info.paging + request_body_info = request_body_info.search + + # request + request_info = request_parser(request_body_info.dict()) + + # search + criterion = None + if isinstance(request_body_info, M.UserLogDaySearchReq): + # UserLog search + # request + request_info = request_parser(request_body_info.dict()) + + # search UserLog + def get_attributes_callback(key: str): + return getattr(UserLog, key) + criterion = dict_to_filter_fmt(request_info, get_attributes_callback) + + # search + session = next(db.session()) + search_info = session.query(UserLog)\ + .filter(UserLog.mac.isnot(None) + , UserLog.api == '/api/auth/login' + , UserLog.type == M.TypeUserLog.info + , UserLog.message == 'ok' + , *criterion)\ + .order_by(desc(UserLog.created_at), desc(UserLog.updated_at))\ + .all() + if not search_info: + raise Exception('not found data') + + group_by_day = query_to_groupby_date(search_info, 'updated_at') + + # result + result_info = list() + for day, info_list in group_by_day.items(): + info_by_mac = query_to_groupby(info_list, 'mac', first=True) + for log_info in info_by_mac.values(): + result_info.append(response_model_data.from_orm(log_info)) + else: + # basic search (single table) + # request + request_info = request_parser(request_body_info.dict()) + + # search + search_info = target_table.filter(**request_info).all() + if not search_info: + raise Exception('not found data') + + # result + result_info = list() + for purchase_info in search_info: + result_info.append(response_model_data.from_orm(purchase_info)) + + # response - paging + paging_response = None + if paging_request: + total_contents_num = len(result_info) + total_page_num = math.ceil(total_contents_num / paging_request.page_contents_num) + start_contents_index = (paging_request.start_page - 1) * paging_request.page_contents_num + end_contents_index = start_contents_index + paging_request.page_contents_num + if end_contents_index < total_contents_num: + result_info = result_info[start_contents_index: end_contents_index] + else: + result_info = result_info[start_contents_index:] + + paging_response = M.PagingRes() + paging_response.total_page_num = total_page_num + paging_response.total_contents_num = total_contents_num + paging_response.start_page = paging_request.start_page + paging_response.search_contents_num = len(result_info) + + return response_model(data=result_info, paging=paging_response) + except Exception as e: + return response_model.set_error(str(e)) + + +async def table_update(accessor_info, target_table, request_body_info, response_model, response_model_data=None): + search_info = None + + try: + + # parameter + if not accessor_info: + raise Exception('invalid accessor') + + if not target_table: + raise Exception(f'invalid table_name:{target_table}') + + if not request_body_info: + raise Exception('invalid request_body_info') + + if not response_model: + raise Exception('invalid response_model') + + # if not response_model_data: + # raise Exception('invalid response_model_data') + + # request + if not request_body_info.search_info: + raise Exception('invalid request_body: search_info') + + request_search_info = request_parser(request_body_info.search_info.dict()) + if not request_search_info: + raise Exception('invalid request_body: search_info') + + if not request_body_info.update_info: + raise Exception('invalid request_body: update_info') + + request_update_info = request_parser(request_body_info.update_info.dict()) + if not request_update_info: + raise Exception('invalid request_body: update_info') + + # search + search_info = target_table.filter(**request_search_info) + + # process + search_info.update(auto_commit=True, synchronize_session=False, **request_update_info) + + # result + return response_model() + except Exception as e: + if search_info: + search_info.close() + return response_model.set_error(str(e)) + + +async def table_delete(accessor_info, target_table, request_body_info, response_model, response_model_data=None): + search_info = None + + try: + # request + if not accessor_info: + raise Exception('invalid accessor') + + if not target_table: + raise Exception(f'invalid table_name:{target_table}') + + if not request_body_info: + raise Exception('invalid request_body_info') + + if not response_model: + raise Exception('invalid response_model') + + # if not response_model_data: + # raise Exception('invalid response_model_data') + + # request + request_search_info = request_parser(request_body_info.dict()) + if not request_search_info: + raise Exception('invalid request_body') + + # search + search_info = target_table.filter(**request_search_info) + temp_search = search_info.all() + + # process + search_info.delete(auto_commit=True, synchronize_session=False) + + # update license num + uuid_list = list() + for _license in temp_search: + if not hasattr(temp_search, 'uuid'): + # case: license + break + if _license.uuid not in uuid_list: + uuid_list.append(_license.uuid) + license_num = target_table.filter(uuid=_license.uuid).count() + target_table.filter(uuid=_license.uuid).update(auto_commit=True, synchronize_session=False, num=license_num) + + # result + return response_model() + except Exception as e: + if search_info: + search_info.close() + return response_model.set_error(str(e)) diff --git a/REST_AI_ENGINE_CONTROL/app/database/schema.py b/REST_AI_ENGINE_CONTROL/app/database/schema.py new file mode 100644 index 0000000..0bb10ba --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/database/schema.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- +""" +@File: schema.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: DB schema + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from sqlalchemy import ( + Column, + Integer, + String, + DateTime, + func, + Enum, + Boolean, + ForeignKey, +) +from sqlalchemy.orm import Session, relationship + +from app.database.conn import Base, db +from app.utils.date_utils import D +from app.models import ( + TypeUserSex, + TypeUserGrade, + TypeUserMembership, + TypeUserAccount, + TypeUserAccountStatus, + TyepUserLoginStatus, + TypeUserLog +) + + +class BaseMixin: + id = Column(Integer, primary_key=True, index=True, comment='DB INDEX') + created_at = Column(DateTime, nullable=False, default=D.datetime(), comment='생성 날짜') + updated_at = Column(DateTime, nullable=False, default=D.datetime(), onupdate=D.datetime(), comment='변경 날짜') + + def __init__(self): + self._q = None + self._session = None + self.served = None + + def all_columns(self): + return [c for c in self.__table__.columns if c.primary_key is False and c.name != 'created_at'] + + def __hash__(self): + return hash(self.id) + + @classmethod + def create(cls, session: Session, auto_commit=False, **kwargs): + """ + 테이블 데이터 적재 전용 함수 + :param session: + :param auto_commit: 자동 커밋 여부 + :param kwargs: 적재 할 데이터 + :return: + """ + obj = cls() + + # NOTE(hsj100) : FIX_DATETIME + if 'created_at' not in kwargs: + obj.created_at = D.datetime() + if 'updated_at' not in kwargs: + obj.updated_at = D.datetime() + + for col in obj.all_columns(): + col_name = col.name + if col_name in kwargs: + setattr(obj, col_name, kwargs.get(col_name)) + session.add(obj) + session.flush() + if auto_commit: + session.commit() + return obj + + @classmethod + def get(cls, session: Session = None, **kwargs): + """ + Simply get a Row + :param session: + :param kwargs: + :return: + """ + sess = next(db.session()) if not session else session + query = sess.query(cls) + for key, val in kwargs.items(): + col = getattr(cls, key) + query = query.filter(col == val) + + if query.count() > 1: + raise Exception('Only one row is supposed to be returned, but got more than one.') + result = query.first() + if not session: + sess.close() + return result + + @classmethod + def filter(cls, session: Session = None, **kwargs): + """ + Simply get a Row + :param session: + :param kwargs: + :return: + """ + cond = [] + for key, val in kwargs.items(): + key = key.split('__') + if len(key) > 2: + raise Exception('length of split(key) should be no more than 2.') + col = getattr(cls, key[0]) + if len(key) == 1: cond.append((col == val)) + elif len(key) == 2 and key[1] == 'gt': cond.append((col > val)) + elif len(key) == 2 and key[1] == 'gte': cond.append((col >= val)) + elif len(key) == 2 and key[1] == 'lt': cond.append((col < val)) + elif len(key) == 2 and key[1] == 'lte': cond.append((col <= val)) + elif len(key) == 2 and key[1] == 'in': cond.append((col.in_(val))) + elif len(key) == 2 and key[1] == 'like': cond.append((col.like(val))) + + obj = cls() + if session: + obj._session = session + obj.served = True + else: + obj._session = next(db.session()) + obj.served = False + query = obj._session.query(cls) + query = query.filter(*cond) + obj._q = query + return obj + + @classmethod + def cls_attr(cls, col_name=None): + if col_name: + col = getattr(cls, col_name) + return col + else: + return cls + + def order_by(self, *args: str): + for a in args: + if a.startswith('-'): + col_name = a[1:] + is_asc = False + else: + col_name = a + is_asc = True + col = self.cls_attr(col_name) + self._q = self._q.order_by(col.asc()) if is_asc else self._q.order_by(col.desc()) + return self + + def update(self, auto_commit: bool = False, synchronize_session='evaluate', **kwargs): + # NOTE(hsj100) : FIX_DATETIME + if 'updated_at' not in kwargs: + kwargs['updated_at'] = D.datetime() + + qs = self._q.update(kwargs, synchronize_session=synchronize_session) + get_id = self.id + ret = None + + self._session.flush() + if qs > 0 : + ret = self._q.first() + if auto_commit: + self._session.commit() + return ret + + def first(self): + result = self._q.first() + self.close() + return result + + def delete(self, auto_commit: bool = False, synchronize_session='evaluate'): + self._q.delete(synchronize_session=synchronize_session) + if auto_commit: + self._session.commit() + + def all(self): + print(self.served) + result = self._q.all() + self.close() + return result + + def count(self): + result = self._q.count() + self.close() + return result + + def close(self): + if not self.served: + self._session.close() + else: + self._session.flush() + + +class ApiKeys(Base, BaseMixin): + __tablename__ = 'api_keys' + access_key = Column(String(length=64), nullable=False, index=True) + secret_key = Column(String(length=64), nullable=False) + user_memo = Column(String(length=40), nullable=True) + status = Column(Enum('active', 'stopped', 'deleted'), default='active') + is_whitelisted = Column(Boolean, default=False) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + whitelist = relationship('ApiWhiteLists', backref='api_keys') + users = relationship('Users', back_populates='keys') + + +class ApiWhiteLists(Base, BaseMixin): + __tablename__ = 'api_whitelists' + ip_addr = Column(String(length=64), nullable=False) + api_key_id = Column(Integer, ForeignKey('api_keys.id'), nullable=False) + + +class Users(Base, BaseMixin): + __tablename__ = 'users' + status = Column(Enum(TypeUserAccountStatus), nullable=False, default=TypeUserAccountStatus.active, comment='계정 상태') + user_grade = Column(Enum(TypeUserGrade), nullable=False, default=TypeUserGrade.user, comment='등급') + membership = Column(Enum(TypeUserMembership), nullable=False, default=TypeUserMembership.personal, comment='회원 종류') + account_type = Column(Enum(TypeUserAccount), nullable=False, default=TypeUserAccount.email, comment='계정 종류') + account = Column(String(length=255), nullable=False, unique=True, comment='계정') + pw = Column(String(length=2000), nullable=True, comment='계정 비밀번호') + email = Column(String(length=255), nullable=True, comment='전자 메일') + name = Column(String(length=255), nullable=False, comment='이름') + sex = Column(Enum(TypeUserSex), nullable=False, default=TypeUserSex.male, comment='성별') + rrn = Column(String(length=255), nullable=True, comment='주민등록번호') + address = Column(String(length=2000), nullable=True, comment='주소') + mobile_number = Column(String(length=20), nullable=True, comment='휴대폰 번호') + landline_number = Column(String(length=20), nullable=True, comment='유선 번호') + marketing_agree = Column(Boolean, nullable=False, default=False, comment='마켓팅 수신 동의') + picture = Column(String(length=1000), nullable=True, comment='프로필 사진') + keys = relationship('ApiKeys', back_populates='users') + + # extra + login = Column(Enum(TyepUserLoginStatus), nullable=False, default=TyepUserLoginStatus.logout, comment='사용자 접속 상태') # TODO(hsj100): LOGIN_STATUS + + # uuid = Column(String(length=36), nullable=True, unique=True) + + # relationship + # NOTE(hsj100): Users:userlog => 1:N + # userlog = relationship('UserLog', back_populates='users', lazy=False) + + +class UserLog(Base, BaseMixin): + __tablename__ = 'userlog' + account = Column(String(length=255), nullable=False, comment='계정') + type = Column(Enum(TypeUserLog), nullable=False, default=TypeUserLog.info, comment='로그 종류') + api = Column(String(length=511), nullable=False, comment='API') + message = Column(String(length=5000), nullable=True, comment='로그 메시지') + # mac = Column(String(length=255), nullable=True, comment='사용자 MAC 주소') + # NOTE(hsj100): 다단계 자동 삭제 + # user_id = Column(Integer, ForeignKey('users.id', ondelete='CASCADE'), nullable=False) + # user_id = Column(Integer, ForeignKey('users.id'), nullable=False) + + # users = relationship('Users', back_populates='userlog') + # users = relationship('Users', back_populates='userlog', lazy=False) diff --git a/REST_AI_ENGINE_CONTROL/app/errors/__init__.py b/REST_AI_ENGINE_CONTROL/app/errors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/REST_AI_ENGINE_CONTROL/app/errors/exceptions.py b/REST_AI_ENGINE_CONTROL/app/errors/exceptions.py new file mode 100644 index 0000000..2f4510a --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/errors/exceptions.py @@ -0,0 +1,188 @@ +from app.common.consts import MAX_API_KEY, MAX_API_WHITELIST + + +class StatusCode: + HTTP_500 = 500 + HTTP_400 = 400 + HTTP_401 = 401 + HTTP_403 = 403 + HTTP_404 = 404 + HTTP_405 = 405 + + +class APIException(Exception): + status_code: int + code: str + msg: str + detail: str + ex: Exception + + def __init__( + self, + *, + status_code: int = StatusCode.HTTP_500, + code: str = '000000', + msg: str = None, + detail: str = None, + ex: Exception = None, + ): + self.status_code = status_code + self.code = code + self.msg = msg + self.detail = detail + self.ex = ex + super().__init__(ex) + + +class NotFoundUserEx(APIException): + def __init__(self, user_id: int = None, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_404, + msg=f'해당 유저를 찾을 수 없습니다.', + detail=f'Not Found User ID : {user_id}', + code=f'{StatusCode.HTTP_400}{"1".zfill(4)}', + ex=ex, + ) + + +class NotAuthorized(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_401, + msg=f'로그인이 필요한 서비스 입니다.', + detail='Authorization Required', + code=f'{StatusCode.HTTP_401}{"1".zfill(4)}', + ex=ex, + ) + + +class TokenExpiredEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'세션이 만료되어 로그아웃 되었습니다.', + detail='Token Expired', + code=f'{StatusCode.HTTP_400}{"1".zfill(4)}', + ex=ex, + ) + + +class TokenDecodeEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'비정상적인 접근입니다.', + detail='Token has been compromised.', + code=f'{StatusCode.HTTP_400}{"2".zfill(4)}', + ex=ex, + ) + + +class NoKeyMatchEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_404, + msg=f'해당 키에 대한 권한이 없거나 해당 키가 없습니다.', + detail='No Keys Matched', + code=f'{StatusCode.HTTP_404}{"3".zfill(4)}', + ex=ex, + ) + + +class MaxKeyCountEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'API 키 생성은 {MAX_API_KEY}개 까지 가능합니다.', + detail='Max Key Count Reached', + code=f'{StatusCode.HTTP_400}{"4".zfill(4)}', + ex=ex, + ) + + +class MaxWLCountEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'화이트리스트 생성은 {MAX_API_WHITELIST}개 까지 가능합니다.', + detail='Max Whitelist Count Reached', + code=f'{StatusCode.HTTP_400}{"5".zfill(4)}', + ex=ex, + ) + + +class InvalidIpEx(APIException): + def __init__(self, ip: str, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'{ip}는 올바른 IP 가 아닙니다.', + detail=f'invalid IP : {ip}', + code=f'{StatusCode.HTTP_400}{"6".zfill(4)}', + ex=ex, + ) + + +class SqlFailureEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_500, + msg=f'이 에러는 서버측 에러 입니다. 자동으로 리포팅 되며, 빠르게 수정하겠습니다.', + detail='Internal Server Error', + code=f'{StatusCode.HTTP_500}{"2".zfill(4)}', + ex=ex, + ) + + +class APIQueryStringEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'쿼리스트링은 key, timestamp 2개만 허용되며, 2개 모두 요청시 제출되어야 합니다.', + detail='Query String Only Accept key and timestamp.', + code=f'{StatusCode.HTTP_400}{"7".zfill(4)}', + ex=ex, + ) + + +class APIHeaderInvalidEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'헤더에 키 해싱된 Secret 이 없거나, 유효하지 않습니다.', + detail='Invalid HMAC secret in Header', + code=f'{StatusCode.HTTP_400}{"8".zfill(4)}', + ex=ex, + ) + + +class APITimestampEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'쿼리스트링에 포함된 타임스탬프는 KST 이며, 현재 시간보다 작아야 하고, 현재시간 - 10초 보다는 커야 합니다.', + detail='timestamp in Query String must be KST, Timestamp must be less than now, and greater than now - 10.', + code=f'{StatusCode.HTTP_400}{"9".zfill(4)}', + ex=ex, + ) + + +class NotFoundAccessKeyEx(APIException): + def __init__(self, api_key: str, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_404, + msg=f'API 키를 찾을 수 없습니다.', + detail=f'Not found such API Access Key : {api_key}', + code=f'{StatusCode.HTTP_404}{"10".zfill(4)}', + ex=ex, + ) + + +class KakaoSendFailureEx(APIException): + def __init__(self, ex: Exception = None): + super().__init__( + status_code=StatusCode.HTTP_400, + msg=f'카카오톡 전송에 실패했습니다.', + detail=f'Failed to send KAKAO MSG.', + code=f'{StatusCode.HTTP_400}{"11".zfill(4)}', + ex=ex, + ) diff --git a/REST_AI_ENGINE_CONTROL/app/main.py b/REST_AI_ENGINE_CONTROL/app/main.py new file mode 100644 index 0000000..38f671e --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/main.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +""" +@File: main.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: main module + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from dataclasses import asdict + +import uvicorn +from fastapi import FastAPI, Depends +from fastapi.security import APIKeyHeader +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.cors import CORSMiddleware + +from app.common import consts + +from app.database.conn import db +from app.common.config import conf +from app.middlewares.token_validator import access_control +from app.middlewares.trusted_hosts import TrustedHostMiddleware +from app.routes import dev, index, auth, users, services + + +API_KEY_HEADER = APIKeyHeader(name='Authorization', auto_error=False) + + +def create_app(): + """ + fast_api_app 생성 + + :return: fast_api_app + """ + configurations = conf() + fast_api_app = FastAPI( + title=configurations.SW_TITLE, + version=configurations.SW_VERSION, + description=configurations.SW_DESCRIPTION, + terms_of_service=configurations.TERMS_OF_SERVICE, + contact=configurations.CONTEACT, + license_info=configurations.LICENSE_INFO + ) + + # 데이터 베이스 이니셜라이즈 + conf_dict = asdict(configurations) + db.init_app(fast_api_app, **conf_dict) + + # 레디스 이니셜라이즈 + + # 미들웨어 정의 + fast_api_app.add_middleware(middleware_class=BaseHTTPMiddleware, dispatch=access_control) + fast_api_app.add_middleware( + CORSMiddleware, + allow_origins=conf().ALLOW_SITE, + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) + fast_api_app.add_middleware(TrustedHostMiddleware, allowed_hosts=conf().TRUSTED_HOSTS, except_path=['/health']) + + # 라우터 정의 + fast_api_app.include_router(index.router, tags=['Defaults']) + fast_api_app.include_router(auth.router, tags=['Authentication'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)]) + + # fast_api_app.include_router(services.router, tags=['Services'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)]) + fast_api_app.include_router(services.router, tags=['Services'], prefix='/api') + + fast_api_app.include_router(users.router, tags=['Users'], prefix='/api', dependencies=[Depends(API_KEY_HEADER)]) + + if conf().DEBUG: + fast_api_app.include_router(dev.router, tags=['Developments'], prefix='/api') + return fast_api_app + + +app = create_app() + + +if __name__ == '__main__': + uvicorn.run('main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True) diff --git a/REST_AI_ENGINE_CONTROL/app/middlewares/token_validator.py b/REST_AI_ENGINE_CONTROL/app/middlewares/token_validator.py new file mode 100644 index 0000000..6acf055 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/middlewares/token_validator.py @@ -0,0 +1,166 @@ +import base64 +import hmac +import json +import time +import typing +import re + +import jwt +import sqlalchemy.exc + +from jwt.exceptions import ExpiredSignatureError, DecodeError +from starlette.requests import Request +from starlette.responses import JSONResponse + +from app.common.consts import EXCEPT_PATH_LIST, EXCEPT_PATH_REGEX +from app.database.conn import db +from app.database.schema import Users, ApiKeys +from app.errors import exceptions as ex + +from app.common import consts +from app.common.config import conf +from app.errors.exceptions import APIException, SqlFailureEx, APIQueryStringEx +from app.models import UserToken + +from app.utils.date_utils import D +from app.utils.logger import api_logger +from app.utils.query_utils import to_dict + +from dataclasses import asdict + + +async def access_control(request: Request, call_next): + request.state.req_time = D.datetime() + request.state.start = time.time() + request.state.inspect = None + request.state.user = None + request.state.service = None + + ip = request.headers['x-forwarded-for'] if 'x-forwarded-for' in request.headers.keys() else request.client.host + request.state.ip = ip.split(',')[0] if ',' in ip else ip + headers = request.headers + cookies = request.cookies + + url = request.url.path + if await url_pattern_check(url, EXCEPT_PATH_REGEX) or url in EXCEPT_PATH_LIST: + response = await call_next(request) + if url != '/': + await api_logger(request=request, response=response) + return response + + try: + if url.startswith('/api'): + # api 인경우 헤더로 토큰 검사 + # NOTE(hsj100): SERVICE_AUTH_API_KEY (token 방식으로 진행) + if url.startswith('/api/services') and conf().SERVICE_AUTH_API_KEY: + qs = str(request.query_params) + qs_list = qs.split('&') + session = next(db.session()) + if not conf().DEBUG: + try: + qs_dict = {qs_split.split('=')[0]: qs_split.split('=')[1] for qs_split in qs_list} + except Exception: + raise ex.APIQueryStringEx() + + qs_keys = qs_dict.keys() + + if 'key' not in qs_keys or 'timestamp' not in qs_keys: + raise ex.APIQueryStringEx() + + if 'secret' not in headers.keys(): + raise ex.APIHeaderInvalidEx() + + api_key = ApiKeys.get(session=session, access_key=qs_dict['key']) + + if not api_key: + raise ex.NotFoundAccessKeyEx(api_key=qs_dict['key']) + mac = hmac.new(bytes(api_key.secret_key, encoding='utf8'), bytes(qs, encoding='utf-8'), digestmod='sha256') + d = mac.digest() + validating_secret = str(base64.b64encode(d).decode('utf-8')) + + if headers['secret'] != validating_secret: + raise ex.APIHeaderInvalidEx() + + now_timestamp = int(D.datetime(diff=9).timestamp()) + if now_timestamp - 10 > int(qs_dict['timestamp']) or now_timestamp < int(qs_dict['timestamp']): + raise ex.APITimestampEx() + + user_info = to_dict(api_key.users) + request.state.user = UserToken(**user_info) + + else: + # Request User 가 필요함 + if 'authorization' in headers.keys(): + key = headers.get('Authorization') + api_key_obj = ApiKeys.get(session=session, access_key=key) + user_info = to_dict(Users.get(session=session, id=api_key_obj.user_id)) + request.state.user = UserToken(**user_info) + # 토큰 없음 + else: + if 'Authorization' not in headers.keys(): + raise ex.NotAuthorized() + session.close() + response = await call_next(request) + return response + else: + if 'authorization' in headers.keys(): + # 토큰 존재 + token_info = await token_decode(access_token=headers.get('Authorization')) + request.state.user = UserToken(**token_info) + elif conf().DEV_TEST_CONNECT_ACCOUNT: + # NOTE(hsj100): DEV_TEST_CONNECT_ACCOUNT + request.state.user = UserToken.from_orm(conf().DEV_TEST_CONNECT_ACCOUNT) + else: + # 토큰 없음 + if 'Authorization' not in headers.keys(): + raise ex.NotAuthorized() + else: + # 템플릿 렌더링인 경우 쿠키에서 토큰 검사 + cookies['Authorization'] = conf().COOKIES_AUTH + + if 'Authorization' not in cookies.keys(): + raise ex.NotAuthorized() + + token_info = await token_decode(access_token=cookies.get('Authorization')) + request.state.user = UserToken(**token_info) + response = await call_next(request) + await api_logger(request=request, response=response) + except Exception as e: + + error = await exception_handler(e) + error_dict = dict(status=error.status_code, msg=error.msg, detail=error.detail, code=error.code) + response = JSONResponse(status_code=error.status_code, content=error_dict) + await api_logger(request=request, error=error) + + return response + + +async def url_pattern_check(path, pattern): + result = re.match(pattern, path) + if result: + return True + return False + + +async def token_decode(access_token): + """ + :param access_token: + :return: + """ + try: + access_token = access_token.replace('Bearer ', "") + payload = jwt.decode(access_token, key=consts.JWT_SECRET, algorithms=[consts.JWT_ALGORITHM]) + except ExpiredSignatureError: + raise ex.TokenExpiredEx() + except DecodeError: + raise ex.TokenDecodeEx() + return payload + + +async def exception_handler(error: Exception): + print(error) + if isinstance(error, sqlalchemy.exc.OperationalError): + error = SqlFailureEx(ex=error) + if not isinstance(error, APIException): + error = APIException(ex=error, detail=str(error)) + return error diff --git a/REST_AI_ENGINE_CONTROL/app/middlewares/trusted_hosts.py b/REST_AI_ENGINE_CONTROL/app/middlewares/trusted_hosts.py new file mode 100644 index 0000000..2aa86fc --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/middlewares/trusted_hosts.py @@ -0,0 +1,63 @@ +import typing + + +from starlette.datastructures import URL, Headers +from starlette.responses import PlainTextResponse, RedirectResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + +ENFORCE_DOMAIN_WILDCARD = 'Domain wildcard patterns must be lNo module named \'app\'ike \'*.example.com\'.' + + +class TrustedHostMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_hosts: typing.Sequence[str] = None, + except_path: typing.Sequence[str] = None, + www_redirect: bool = True, + ) -> None: + if allowed_hosts is None: + allowed_hosts = ['*'] + if except_path is None: + except_path = [] + for pattern in allowed_hosts: + assert '*' not in pattern[1:], ENFORCE_DOMAIN_WILDCARD + if pattern.startswith('*') and pattern != '*': + assert pattern.startswith('*.'), ENFORCE_DOMAIN_WILDCARD + self.app = app + self.allowed_hosts = list(allowed_hosts) + self.allow_any = '*' in allowed_hosts + self.www_redirect = www_redirect + self.except_path = list(except_path) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.allow_any or scope['type'] not in ('http', 'websocket',): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + host = headers.get('host', "").split(':')[0] + is_valid_host = False + found_www_redirect = False + for pattern in self.allowed_hosts: + if ( + host == pattern + or (pattern.startswith('*') and host.endswith(pattern[1:])) + or URL(scope=scope).path in self.except_path + ): + is_valid_host = True + break + elif 'www.' + host == pattern: + found_www_redirect = True + + if is_valid_host: + await self.app(scope, receive, send) + else: + if found_www_redirect and self.www_redirect: + url = URL(scope=scope) + redirect_url = url.replace(netloc='www.' + url.netloc) + response = RedirectResponse(url=str(redirect_url)) # type: Response + else: + response = PlainTextResponse('Invalid host header', status_code=400) + + await response(scope, receive, send) diff --git a/REST_AI_ENGINE_CONTROL/app/models.py b/REST_AI_ENGINE_CONTROL/app/models.py new file mode 100644 index 0000000..3c2a25f --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/models.py @@ -0,0 +1,963 @@ +# -*- coding: utf-8 -*- +""" +@File: models.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: data models + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from datetime import datetime +from enum import Enum +from typing import List, Dict + +from pydantic import Field +from pydantic.main import BaseModel +from pydantic.networks import EmailStr, IPvAnyAddress +from typing import Optional + +from app.common.consts import ( + SW_TITLE, + SW_VERSION, + MAIL_REG_TITLE, + MAIL_REG_CONTENTS, + SMTP_HOST, + SMTP_PORT, + ADMIN_INIT_ACCOUNT_INFO, + DEFAULT_USER_ACCOUNT_PW, + AI_ENGINE_FR_DETECT_TIME_MIN +) +from app.utils.date_utils import D + + +class SWInfo(BaseModel): + """ + ### 서비스 정보 + """ + name: str = Field(SW_TITLE, description='SW 이름', example=SW_TITLE) + version: str = Field(SW_VERSION, description='SW 버전', example=SW_VERSION) + date: str = Field(D.date_str(), description='현재 날짜', example='%Y.%m.%dT%H:%M:%S') + + +class CustomEnum(Enum): + @classmethod + def get_elements_str(cls, is_key=True): + if is_key: + result_str = cls.__members__.keys() + else: + result_str = cls.__members__.values() + return '[' + ', '.join(result_str) + ']' + + +class TypeUserSex(str, CustomEnum): + """ + ### 사용자 성별 타입 + """ + male: str = 'male' + female: str = 'female' + + +class TypeUserAccount(str, CustomEnum): + """ + ### 사용자 계정 타입 + """ + email: str = 'email' + # facebook: str = 'facebook' + # google: str = 'google' + # kakao: str = 'kakao' + + +class TypeUserAccountStatus(str, CustomEnum): + """ + ### 사용자 계정 상태 타입 + """ + active: str = 'active' + deleted: str = 'deleted' + blocked: str = 'blocked' + + +class TyepUserLoginStatus(str, CustomEnum): + """ + ### 유저 로그인 상태 타입 + """ + login: str = 'login' + logout: str = 'logout' + + +class TypeUserMembership(str, CustomEnum): + """ + ### 사용자 회원 타입 + """ + personal: str = 'personal' + company: str = 'company' + + +class TypeUserGrade(str, CustomEnum): + """ + ### 사용자 등급 타입 + """ + admin: str = 'admin' + user: str = 'user' + + +class TypeUserLog(str, CustomEnum): + """ + ### 사용자 로그 타입 + """ + info: str = 'info' + error: str = 'error' + + +class Token(BaseModel): + Authorization: str = Field(None, description='인증키', example='Bearer [token]') + + +class EmailRecipients(BaseModel): + name: str + email: str + + +class SendEmail(BaseModel): + email_to: List[EmailRecipients] = None + + +class KakaoMsgBody(BaseModel): + msg: str = None + + +class MessageOk(BaseModel): + message: str = Field(default='OK') + + +class UserToken(BaseModel): + id: int + account: str = None + name: str = None + phone_number: str = None + profile_img: str = None + account_type: str = None + + class Config: + orm_mode = True + + +class AddApiKey(BaseModel): + user_memo: str = None + + class Config: + orm_mode = True + + +class GetApiKeyList(AddApiKey): + id: int = None + access_key: str = None + created_at: datetime = None + + +class GetApiKeys(GetApiKeyList): + secret_key: str = None + + +class CreateAPIWhiteLists(BaseModel): + ip_addr: str = None + + +class GetAPIWhiteLists(CreateAPIWhiteLists): + id: int + + class Config: + orm_mode = True + + +class ResponseBase(BaseModel): + """ + ### [Response] API End-Point + + **정상처리**\n + - result: true\n + - error: null\n + + **오류발생**\n + - result: false\n + - error: 오류내용\n + """ + result: bool = Field(True, description='처리상태(성공: true, 실패: false)', example=False) + error: str = Field(None, description='오류내용(성공: null, 실패: 오류내용)', example='invalid data') + + @staticmethod + def set_error(error): + ResponseBase.result = False + ResponseBase.error = str(error) + return ResponseBase + + class Config: + orm_mode = True + + +class PagingReq(BaseModel): + """ + ### [Request] 페이징 정보 + """ + start_page: int = Field(None, description='시작 페이지 번호(base: 1)', example=1) + page_contents_num: int = Field(None, description='페이지 내용 개수', example=2) + + class Config: + orm_mode = True + + +class PagingRes(BaseModel): + """ + ### [Response] 페이징 정보 + """ + total_page_num: int = Field(None, description='전체 페이지 개수', example=100) + total_contents_num: int = Field(None, description='전체 내용 개수', example=100) + start_page: int = Field(None, description='시작 페이지 번호(base: 1)', example=1) + search_contents_num: int = Field(None, description='검색된 내용 개수', example=100) + + class Config: + orm_mode = True + + +class TokenRes(ResponseBase): + """ + ### [Response] 토큰 정보 + """ + Authorization: str = Field(None, description='인증키', example='Bearer [token]') + + class Config: + orm_mode = True + + +class UserLogInfo(BaseModel): + """ + ### 사용자 로그 정보 + """ + id: int = Field(None, description='Table Index', example='1') + created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56') + updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + + account: str = Field(None, description='계정', example='user1@test.com') + mac: str = Field(None, description='MAC(네트워크 인터페이스 식별자)', example='11:22:33:44:55:66') + type: str = Field(None, description='로그 타입' + TypeUserLog.get_elements_str(), example=TypeUserLog.info) + api: str = Field(None, description='API 이름', example='/api/auth/login') + message: str = Field(None, description='로그 내용', example='ok') + + class Config: + orm_mode = True + + +class UserInfo(BaseModel): + """ + ### 유저 정보 + """ + id: int = Field(None, description='Table Index', example='1') + created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56') + updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + + status: TypeUserAccountStatus = Field(None, description='계정상태' + TypeUserAccountStatus.get_elements_str(), example=TypeUserAccountStatus.active) + user_type: TypeUserGrade = Field(None, description='유저 타입' + TypeUserGrade.get_elements_str(), example=TypeUserGrade.user) + account_type: TypeUserAccount = Field(None, description='계정종류' + TypeUserAccount.get_elements_str(), example=TypeUserAccount.email) + account: str = Field(None, description='계정', example='user1@test.com') + email: str = Field(None, description='전자메일', example='user1@test.com') + name: str = Field(None, description='이름', example='user1') + sex: str = Field(None, description='성별' + TypeUserSex.get_elements_str(), example=TypeUserSex.male) + rrn: str = Field(None, description='주민등록번호', example='123456-1234567') + address: str = Field(None, description='주소', example='대구1') + phone_number: str = Field(None, description='연락처', example='010-1234-1234') + picture: str = Field(None, description='프로필사진', example='profile1.png') + marketing_agree: bool = Field(False, description='마케팅동의 여부', example=False) + # extra + login: TyepUserLoginStatus = Field(None, description='로그인 상태' + TyepUserLoginStatus.get_elements_str(), example=TyepUserLoginStatus.logout) # TODO(hsj100): LOGIN_STATUS + member_type: TypeUserMembership = Field(None, description='회원 타입' + TypeUserMembership.get_elements_str(), example=TypeUserMembership.personal) + + class Config: + orm_mode = True + + +class SendMailReq(BaseModel): + """ + ### [Request] 메일 전송 + """ + smtp_host: str = Field(None, description='SMTP 서버 주소', example=SMTP_HOST) + smtp_port: int = Field(None, description='SMTP 서버 포트', example=SMTP_PORT) + title: str = Field(None, description='제목', example=MAIL_REG_TITLE) + recipient: str = Field(None, description='수신자', example='user1@test.com') + cc_list: list = Field(None, description='참조 리스트', example='["user2@test.com", "user3@test.com"]') + # recipient: str = Field(None, description='수신자', example='hsj100@daooldns.co.kr') + # cc_list: list = Field(None, description='참조 리스트', example=None) + contents_plain: str = Field(None, description='내용', example=MAIL_REG_CONTENTS.format('user1@test.com', 10)) + contents_html: str = Field(None, description='내용', example='

내 고양이는 아주 고약해.

') + + class Config: + orm_mode = True + + +class UserInfoRes(ResponseBase): + """ + ### [Response] 유저 정보 + """ + data: UserInfo = None + + class Config: + orm_mode = True + + +class UserLoginReq(BaseModel): + """ + ### 유저 로그인 정보 + """ + account: str = Field(description='계정', example='user1@test.com') + pw: str = Field(None, description='비밀번호 [관리자 필수]', example='1234') + # SW-LICENSE + sw: str = Field(None, description='라이선스 SW 이름 [유저 필수]', example='저작도구') + mac: str = Field(None, description='MAC [유저 필수]', example='11:22:33:44:55:01') + + +class UserLoginRes(ResponseBase): + """ + ### [Response] 토큰 정보 + """ + Authorization: str = Field(None, description='인증키', example='Bearer [token]') + users: UserInfo = None + + class Config: + orm_mode = True + + +class UserRegisterReq(BaseModel): + """ + ### [Request] 유저 등록 + """ + status: Optional[str] = Field(TypeUserAccountStatus.active, description='계정상태' + TypeUserAccountStatus.get_elements_str(), example=TypeUserAccountStatus.active) + user_type: Optional[TypeUserGrade] = Field(TypeUserGrade.user, description='유저 타입' + TypeUserGrade.get_elements_str(), example=TypeUserGrade.user) + account: str = Field(description='계정', example='test@test.com') + pw: Optional[str] = Field(DEFAULT_USER_ACCOUNT_PW, description='비밀번호', example='1234') + email: Optional[str] = Field(None, description='전자메일', example='test@test.com') + name: str = Field(description='이름', example='test') + sex: Optional[TypeUserSex] = Field(TypeUserSex.male, description='성별' + TypeUserSex.get_elements_str(), example=TypeUserSex.male) + rrn: Optional[str] = Field(None, description='주민등록번호', example='19910101-1234567') + address: Optional[str] = Field(None, description='주소', example='대구광역시 동구 동촌로 351, 4층 (용계동, 에이스빌딩)') + phone_number: Optional[str] = Field(None, description='휴대전화', example='010-1234-1234') + picture: Optional[str] = Field(None, description='사진', example='profile1.png') + marketing_agree: Optional[bool] = Field(False, description='마케팅동의 여부', example=False) + # extra + member_type: Optional[TypeUserMembership] = Field(TypeUserMembership.personal, description='회원 타입' + TypeUserMembership.get_elements_str(), example=TypeUserMembership.personal) + + class Config: + orm_mode = True + + +class UserSearchReq(BaseModel): + """ + ### [Request] 유저 검색 (기본) + """ + # basic + id: Optional[int] = Field(None, description='등록번호', example='1') + id__in: Optional[list] = Field(None, description='등록번호 리스트', example=[1,]) + created_at: Optional[str] = Field(None, description='생성날짜', example='2022-01-01T12:34:56') + created_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00') + created_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00') + created_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00') + created_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00') + updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00') + updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00') + updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00') + updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00') + + status: Optional[TypeUserAccountStatus] = Field(None, description='계정상태' + TypeUserAccountStatus.get_elements_str(), example=TypeUserAccountStatus.active) + user_type: Optional[TypeUserGrade] = Field(None, description='유저 타입' + TypeUserGrade.get_elements_str(), example=TypeUserGrade.user) + account_type: Optional[TypeUserAccount] = Field(None, description='계정종류' + TypeUserAccount.get_elements_str(), example=TypeUserAccount.email) + account: Optional[str] = Field(None, description='계정', example='user1@test.com') + account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%') + email: Optional[str] = Field(None, description='전자메일', example='user1@test.com') + email__like: Optional[str] = Field(None, description='전자메일 부분검색', example='%user%') + name: Optional[str] = Field(None, description='이름', example='test') + name__like: Optional[str] = Field(None, description='이름 부분검색', example='%user%') + sex: Optional[TypeUserSex] = Field(None, description='성별' + TypeUserSex.get_elements_str(), example=TypeUserSex.male) + rrn: Optional[str] = Field(None, description='주민등록번호', example='123456-1234567') + address: Optional[str] = Field(None, description='주소', example='대구1') + address__like: Optional[str] = Field(None, description='주소 부분검색', example='%대구%') + phone_number: Optional[str] = Field(None, description='연락처', example='010-1234-1234') + picture: Optional[str] = Field(None, description='프로필사진', example='profile1.png') + marketing_agree: Optional[bool] = Field(None, description='마케팅동의 여부', example=False) + # extra + login: Optional[TyepUserLoginStatus] = Field(None, description='로그인 상태' + TyepUserLoginStatus.get_elements_str(), example=TyepUserLoginStatus.logout) # TODO(hsj100): LOGIN_STATUS + member_type: Optional[TypeUserMembership] = Field(None, description='회원 타입' + TypeUserMembership.get_elements_str(), example=TypeUserMembership.personal) + + class Config: + orm_mode = True + + +class UserSearchRes(ResponseBase): + """ + ### [Response] 유저 검색 + """ + data: List[UserInfo] = [] + + class Config: + orm_mode = True + + +class UserSearchPagingReq(BaseModel): + """ + ### [Request] 유저 페이징 검색 + """ + # paging + paging: Optional[PagingReq] = None + search: Optional[UserSearchReq] = None + + +class UserSearchPagingRes(ResponseBase): + """ + ### [Response] 유저 페이징 검색 + """ + paging: PagingRes = None + data: List[UserInfo] = [] + + class Config: + orm_mode = True + + +class UserUpdateReq(BaseModel): + """ + ### [Request] 유저 변경 + """ + status: Optional[TypeUserAccountStatus] = Field(None, description='계정상태' + TypeUserAccountStatus.get_elements_str(), example=TypeUserAccountStatus.active) + user_type: Optional[TypeUserGrade] = Field(None, description='유저 타입' + TypeUserGrade.get_elements_str(), example=TypeUserGrade.user) + account_type: Optional[TypeUserAccount] = Field(None, description='계정종류' + TypeUserAccount.get_elements_str(), example=TypeUserAccount.email) + account: Optional[str] = Field(None, description='계정', example='user1@test.com') + email: Optional[str] = Field(None, description='전자메일', example='user1@test.com') + name: Optional[str] = Field(None, description='이름', example='test') + sex: Optional[TypeUserSex] = Field(None, description='성별' + TypeUserSex.get_elements_str(), example=TypeUserSex.male) + rrn: Optional[str] = Field(None, description='주민등록번호', example='123456-1234567') + address: Optional[str] = Field(None, description='주소', example='대구1') + phone_number: Optional[str] = Field(None, description='연락처', example='010-1234-1234') + picture: Optional[str] = Field(None, description='프로필사진', example='profile1.png') + marketing_agree: Optional[bool] = Field(False, description='마케팅동의 여부', example=False) + # extra + login: Optional[TyepUserLoginStatus] = Field(None, description='로그인 상태' + TyepUserLoginStatus.get_elements_str(), example=TyepUserLoginStatus.logout) # TODO(hsj100): LOGIN_STATUS + # uuid: Optional[str] = Field(None, description='UUID', example='12345678-1234-5678-1234-567800000001') + member_type: Optional[TypeUserMembership] = Field(None, description='회원 타입' + TypeUserMembership.get_elements_str(), example=TypeUserMembership.personal) + + class Config: + orm_mode = True + + +class UserUpdateMultiReq(BaseModel): + """ + ### [Request] 유저 변경 (multi) + """ + search_info: UserSearchReq = None + update_info: UserUpdateReq = None + + class Config: + orm_mode = True + + +class UserUpdatePWReq(BaseModel): + """ + ### [Request] 유저 비밀번호 변경 + """ + account: str = Field(None, description='계정', example='user1@test.com') + current_pw: str = Field(None, description='현재 비밀번호', example='1234') + new_pw: str = Field(None, description='신규 비밀번호', example='5678') + + +class UserLogSearchReq(BaseModel): + """ + ### [Request] 유저로그 검색 + """ + # basic + id: Optional[int] = Field(None, description='등록번호', example='1') + id__in: Optional[list] = Field(None, description='등록번호 리스트', example=[1,]) + created_at: Optional[str] = Field(None, description='생성날짜', example='2022-01-01T12:34:56') + created_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00') + created_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00') + created_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00') + created_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00') + updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00') + updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00') + updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00') + updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00') + + account: Optional[str] = Field(None, description='계정', example='user1@test.com') + account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%') + mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01') + mac__like: Optional[str] = Field(None, description='MAC 부분검색', example='%33%') + type: Optional[TypeUserLog] = Field(None, description='유저로그 메시지 타입' + TypeUserLog.get_elements_str(), example=TypeUserLog.error) + api: Optional[str] = Field(None, description='API 이름', example='/api/auth/login') + api__like: Optional[str] = Field(None, description='API 이름 부분검색', example='%login%') + message: Optional[str] = Field(None, description='로그내용', example='invalid password') + message__like: Optional[str] = Field(None, description='로그내용 부분검색', example='%invalid%') + + class Config: + orm_mode = True + + +class UserLogDayInfo(BaseModel): + """ + ### 유저로그 일별 접속 정보 (출석) + """ + account: str = Field(None, description='계정', example='user1@test.com') + mac: str = Field(None, description='MAC(네트워크 인터페이스 식별자)', example='11:22:33:44:55:66') + updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + message: str = Field(None, description='로그 내용', example='ok') + + class Config: + orm_mode = True + + +class UserLogDaySearchReq(BaseModel): + """ + ### [Request] 유저로그 일별 마지막 접속 검색 + """ + updated_at: Optional[str] = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + updated_at__gt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이후)', example='2022-02-10T15:00:00') + updated_at__gte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이후)', example='2022-02-10T15:00:00') + updated_at__lt: Optional[str] = Field(None, description='생성날짜(주어진 날짜 이전)', example='2022-02-10T15:00:00') + updated_at__lte: Optional[str] = Field(None, description='생성날짜(주어진 날짜 포함 이전)', example='2022-02-10T15:00:00') + + account: Optional[str] = Field(None, description='계정', example='user1@test.com') + account__like: Optional[str] = Field(None, description='계정 부분검색', example='%user%') + mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01') + mac__like: Optional[str] = Field(None, description='MAC 부분검색', example='%33%') + + class Config: + orm_mode = True + + +class UserLogDaySearchPagingReq(BaseModel): + """ + ### [Request] 유저로그 일별 마지막 접속 페이징 검색 + """ + # paging + paging: Optional[PagingReq] = None + search: Optional[UserLogDaySearchReq] = None + + +class UserLogDaySearchPagingRes(ResponseBase): + """ + ### [Response] 유저로그 일별 마지막 접속 검색 + """ + paging: PagingRes = None + data: List[UserLogDayInfo] = [] + + class Config: + orm_mode = True + + +class UserLogPagingReq(BaseModel): + """ + ### [Request] 유저로그 페이징 검색 + """ + # paging + paging: Optional[PagingReq] = None + search: Optional[UserLogSearchReq] = None + + +class UserLogPagingRes(ResponseBase): + """ + ### [Response] 유저로그 페이징 검색 + """ + paging: PagingRes = None + data: List[UserLogInfo] = [] + + class Config: + orm_mode = True + + +class UserLogUpdateReq(BaseModel): + """ + ### [Request] 유저로그 변경 + """ + account: Optional[str] = Field(None, description='계정', example='user1@test.com') + mac: Optional[str] = Field(None, description='MAC', example='11:22:33:44:55:01') + type: Optional[TypeUserLog] = Field(None, description='유저로그 메시지 타입' + TypeUserLog.get_elements_str(), example=TypeUserLog.error) + api: Optional[str] = Field(None, description='API 이름', example='/api/auth/login') + message: Optional[str] = Field(None, description='로그내용', example='invalid password') + + class Config: + orm_mode = True + + +class UserLogUpdateMultiReq(BaseModel): + """ + ### [Request] 유저로그 변경 (multi) + """ + search_info: UserLogSearchReq = None + update_info: UserLogUpdateReq = None + + class Config: + orm_mode = True + + +############################################################################################## +# APPEND INFORMATION +############################################################################################## + +class AEStatusType(str, CustomEnum): + """ + ### AI Engine 상태 타입 + """ + init: str = 'init' + running: str = 'running' + stop: str = 'stop' + error: str = 'error' + + +class AEAIModelType(str, CustomEnum): + """ + ### AI Engine에서 사용하는 AI 모델 타입 + """ + FR: str = 'FR' + # OD: str = 'OD' + CON: str = 'CON' + PPE: str = 'PPE' + WORK: str = 'WORK' + BI: str = 'BI' + + +class AEAIModelStatusType(str, CustomEnum): + """ + ### AI Engine 모델 상태 타입 + """ + init: str = 'init' + idle: str = 'idle' + detect: str = 'detect' + error: str = 'error' + + +class AERIParameterInfo(BaseModel): + """ + ### RI(Risk Index) 매개 변수 정보 + """ + name: str = Field(description='RI 변수명', example='작업자 숙련도') + ratio: float = Field(description='RI 변수의 비중', example=0.60) + + class Config: + orm_mode = True + + +class AEWorkRIInfo(BaseModel): + """ + ### 작업 RI(Risk Index) 정보 + """ + construction_code: str = Field(description='작업 공종 코드', example='D54') + construction_name: Optional[str] = Field(None, description='작업 공종명', example='간접활선 점퍼선 절단') + work_no: int = Field(description='해당 공종의 작업 절차 번호', example=1) + work_name: Optional[str] = Field(None, description='해당 작업 절차의 작업명', example='점퍼선 절단') + work_define_ri: float = Field(description='해당 작업에 정의된 위험도', example=0.77) + ri_parameter_list: List[AERIParameterInfo] = Field(description='RI 정보 리스트', example=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)]) + evaluation_work_ri: Optional[float] = Field(0.00, description='AI 모델에서 평가한 작업(work_no) 위험도', example=0.52) + + class Config: + orm_mode = True + + +class AEFTPInfo(BaseModel): + """ + ### FTP 정보 + """ + ip: str = Field(description='IP address', example='106.255.245.242') + port: int = Field(22, description='SFTP port', example=2022) + id: str = Field(description='User ID', example='kepri_if_user') + pw: str = Field(description='User passward') + location: str = Field(description='작업 공종 코드', example='/home/agics-dev/kepri_storage/rndpartners/') + file_con_setup: str = Field("3", description='시설물 탐지 파일이름', example='3') + file_face: str = Field("1", description='안면 인식 파일이름', example='1') + file_ppe: str = Field("2", description='개인보호구 탐지(PPE) 파일이름', example='2') + file_wd: str = Field("5", description='위험성 탐지(WD) 파일 이름', example='5') + file_bi: str = Field("4", description='위험성 탐지(BI) 파일 이름', example='4') + + +class AEInputBaseInfo(BaseModel): + """ + ### AI Engine 입력 정보 (base) + """ + name: str = Field('device_name', description='장치명', example='device_name') + model: str = Field('model_name', description='모델명', example='model_name') + sn: str = Field('serial_no', description='시리얼번호', example='serial_no') + connect_url: str = Field('connect_url', description='서비스 접속 주소', example='connect_url') + user_id: str = Field('user_id', description='서비스 접속 계정', example='user_id') + user_pw: str = Field('user_pw', description='서비스 접속 비번', example='user_pw') + + class Config: + orm_mode = True + + +class AEDemoInfo(BaseModel): + """ + ### AI Engine DEMO 입력 정보 (base) + """ + ftp: AEFTPInfo = Field(description='FTP 정보') + + +class AEInputVideoInfo(AEInputBaseInfo): + """ + ### AI Engine 입력 정보 (video) + """ + model:AEAIModelType = Field('model_name', description='모델명', example=AEAIModelType.FR) + + class Config: + orm_mode = True + + +class AEInputBIInfo(AEInputBaseInfo): + """ + ### AI Engine 입력 정보 (BI) + """ + topic: str = Field('topic', description='접속 토픽', example='test') + + class Config: + orm_mode = True + + +class AEAIModelBase(BaseModel): + """ + ### AI Model 정보 (base) + """ + name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.FR) + version: str = Field(None, description='AI Model Version (by date)', example='20220101') + status: AEAIModelStatusType = Field(None, description='AI Model 상태' + AEAIModelStatusType.get_elements_str(), example=AEAIModelStatusType.init) + + class Config: + orm_mode = True + + +class AEAIModelFRInfo(AEAIModelBase): + """ + ### AI Model FR(Face Recognition) 정보 + """ + name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.FR) + ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])) + + class Config: + orm_mode = True + + +class AEAIModelODModeType(str, CustomEnum): + """ + ### AI Engine 객체탐지 모델 동작 모드 + """ + ppe: str = 'ppe' + work: str = 'work' + con: str = 'con' + + +class AEAIModelODModelType(str, CustomEnum): + """ + ### AI Engine 객체탐지에서 사용하는 모델 타입 + """ + nano: str = 'nano' + small: str = 'small' + medium: str = 'medium' + large: str = 'large' + xlarge: str = 'xlarge' + + +class AEAIModelODWeightInfo(BaseModel): + """ + ### AI Model OD(Object Detection) Weight 정보 + """ + id: int = Field(0, description='Table Index', example='1') + + filename: str = Field(description='파일명', example='index_78.pt') + version: str = Field(SW_VERSION, description='Version (by date)', example=SW_VERSION) + date: str = Field(D.date_str(), description='Version (by date)', example=D.date_str()) + model: AEAIModelODModelType = Field(AEAIModelODModelType.small, description='학습한 환경의 객체탐지 모델 타입' + AEAIModelODModelType.get_elements_str(), example=AEAIModelODModelType.small) + nc: int = Field(description='클래스 개수', example=26) + + class Config: + orm_mode = True + + +class AEAIModelODInfo(AEAIModelBase): + """ + ### AI Model OD(Object Detection) 정보 + """ + name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.PPE) + ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])) + weights: List[AEAIModelODWeightInfo] = Field(description='가중치 정보 리스트', example=[AEAIModelODWeightInfo(filename='index_78.pt', nc=26)]) + mode: AEAIModelODModeType = Field(AEAIModelODModeType.ppe, description='모델 동작 모드' + AEAIModelODModeType.get_elements_str(), example=AEAIModelODModeType.ppe) + crop_images: bool = Field(True, description='Detect 이미지 전송 여부', example=True) + + class Config: + orm_mode = True + +class AEAIModelCONInfo(AEAIModelODInfo): + """ + ### AI Model CON 정보 + """ + name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.CON) + mode: AEAIModelODModeType = Field(AEAIModelODModeType.con, description='모델 동작 모드' + AEAIModelODModeType.get_elements_str(), example=AEAIModelODModeType.con) + +class AEAIModelPPEInfo(AEAIModelODInfo): + """ + ### AI Model PPE 정보 + """ + name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.PPE) + mode: AEAIModelODModeType = Field(AEAIModelODModeType.ppe, description='모델 동작 모드' + AEAIModelODModeType.get_elements_str(), example=AEAIModelODModeType.ppe) + +class AEAIModelWDInfo(AEAIModelODInfo): + """ + ### AI Model WD 정보 + """ + name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.WORK) + mode: AEAIModelODModeType = Field(AEAIModelODModeType.work, description='모델 동작 모드' + AEAIModelODModeType.get_elements_str(), example=AEAIModelODModeType.work) + crop_images: bool = Field(False, description='Detect 이미지 전송 여부', example=False) + +class AEAIModelBIInfo(AEAIModelBase): + """ + ### AI Model BI(BI Anomaly Detection) 정보 + """ + name: AEAIModelType = Field(None, description='AI 모델명' + AEAIModelType.get_elements_str(), example=AEAIModelType.BI) + ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])) + + class Config: + orm_mode = True + + +class AEInfo(BaseModel): + """ + ### AI Engine 정보 + """ + # id: int = Field(None, description='Table Index', example='1') + # created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56') + # updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + + version: str = Field(SW_VERSION, description='AI Engine Version', example=SW_VERSION) + ai_engine_status: AEStatusType = Field(AEStatusType.init, description='AI Engine 상태' + AEStatusType.get_elements_str(), example=AEStatusType.init) + demo: AEDemoInfo = Field(None, description='DEMO') + input_video: List[AEInputVideoInfo] = Field(None, description='입력장치 정보[영상]', example=[ + AEInputVideoInfo(name='video1',model=AEAIModelType.CON), + AEInputVideoInfo(name='video2',model=AEAIModelType.FR), + AEInputVideoInfo(name='video3',model=AEAIModelType.PPE), + AEInputVideoInfo(name='video4',model=AEAIModelType.WORK)]) + + input_bi: AEInputBIInfo = Field(None, description='입력장치 정보[BI]', example=AEInputBIInfo()) + + con_model_info: AEAIModelCONInfo = Field(None, description='시설물 AI Model 정보') + fr_model_info: AEAIModelFRInfo = Field(None, description='안면인식 AI Model 정보') + ppe_model_info: AEAIModelPPEInfo = Field(None, description='개인보호구 PPE AI Model 정보') + wd_model_info: AEAIModelWDInfo = Field(None, description='객체탐지 AI Model 정보') + bi_model_info: AEAIModelBIInfo = Field(None, description='BI 이상탐지 AI Model 정보') + + class Config: + orm_mode = True + + +class AEInfoSetReq(AEInfo): + """ + ### [Request] AI Engine 설정 + """ + + +class AEAIModelFRReq(BaseModel): + """ + ### [Request] 안면인식 + """ + limit_time_min: Optional[int] = Field(AI_ENGINE_FR_DETECT_TIME_MIN, description='탐지 제한 시간(분)', example=AI_ENGINE_FR_DETECT_TIME_MIN) + report_unit: Optional[bool] = Field(False, description='신규 대상을 인식할 때 마다 결과 전송', example=True) + ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])) + targets: list = Field(description='인식 대상 리스트 (고유번호사용: 사원번호 or 주민번호 ...)(TBD: 협의)', example='["no000001"]') + + class Config: + orm_mode = True + + +class AEPPEBaseInfo(BaseModel): + """ + ### 기본 개인보호구(PPE) 정보(TBD: 협의) + """ + # id: int = Field(None, description='Table Index', example='1') + # created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56') + # updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + + # version: str = Field(SW_VERSION, description='AI Engine Version', example=SW_VERSION) + safety_helmet_on: bool = Field(True, description='안전모 탐지', example=True) + safety_gloves_work_on: bool = Field(True, description='안전장갑(목장갑) 탐지', example=True) + safety_boots_on: bool = Field(True, description='안전화 탐지', example=True) + safety_belt_basic_on: bool = Field(False, description='안전대(벨트식) 탐지', example=False) + # safety_suit_top: bool = Field(True, description='안전복(방염복)상의', example=True) + # safety_suit_bottom: bool = Field(True, description='안전복(방염복)하의', example=True) + + class Config: + orm_mode = True + +class AEConSetupBaseInfo(BaseModel): + """ + ### 기본 시설물 정보(TBD: 협의) + """ + # id: int = Field(None, description='Table Index', example='1') + # created_at: datetime = Field(None, description='생성날짜', example='2022-01-01T12:34:56') + # updated_at: datetime = Field(None, description='수정날짜', example='2022-01-01T12:34:56') + + # version: str = Field(SW_VERSION, description='AI Engine Version', example=SW_VERSION) + signalman: bool = Field(True, description='신호수 탐지', example=True) + sign_traffic_cone: bool = Field(True, description='트래픽콘(라바콘) 탐지', example=True) + + sign_board_information: bool = Field(False, description='안내표지판 탐지', example=False) + sign_board_construction: bool = Field(False, description='공사안내판 탐지', example=False) + sign_board_traffic: bool = Field(False, description='교통안전표지판 탐지', example=False) + + class Config: + orm_mode = True + +class AEAIModelConSetupDetectReq(BaseModel): + """ + ### [Request] 시설물 탐지 + """ + limit_time_min: Optional[int] = Field(AI_ENGINE_FR_DETECT_TIME_MIN, description='탐지 제한 시간(분)', example=AI_ENGINE_FR_DETECT_TIME_MIN) + report_unit: Optional[bool] = Field(False, description='신규 대상을 인식할 때 마다 결과 전송', example=True) + ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54',work_no=3,work_define_ri=0.82,ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])) + targets: AEConSetupBaseInfo = Field(description='탐지 대상', example=AEConSetupBaseInfo()) + + class Config: + orm_mode = True + +class AEAIModelPPEDetectReq(BaseModel): + """ + ### [Request] 개인보호구(PPE) 탐지 + """ + limit_time_min: Optional[int] = Field(AI_ENGINE_FR_DETECT_TIME_MIN, description='탐지 제한 시간(분)', example=AI_ENGINE_FR_DETECT_TIME_MIN) + report_unit: Optional[bool] = Field(False, description='신규 대상을 인식할 때 마다 결과 전송', example=True) + ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54',work_no=3,work_define_ri=0.82,ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])) + targets: AEPPEBaseInfo = Field(description='탐지 대상', example=AEPPEBaseInfo()) + + class Config: + orm_mode = True + + +class AEAIModelRBIVideoReq(BaseModel): + """ + ### [Request] 위험성 탐지 (영상) + """ + # limit_time_min: Optional[int] = Field(AI_ENGINE_FR_DETECT_TIME_MIN, description='탐지 제한 시간(분)', example=AI_ENGINE_FR_DETECT_TIME_MIN) + ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])) + + class Config: + orm_mode = True + + +class AEAIModelRBIBIReq(BaseModel): + """ + ### [Request] 위험성 탐지 (BI) + """ + ri: AEWorkRIInfo = Field(description='RI 정보(TBD: 협의)', example=AEWorkRIInfo(construction_code='D54', work_no=3, work_define_ri=0.82, ri_parameter_list=[AERIParameterInfo(name='작업자 숙련도', ratio=0.60), AERIParameterInfo(name='작업자 교육레벨', ratio=0.50)])) + + class Config: + orm_mode = True + + +class AEAIModelFRMQTTInfo(BaseModel): + test: str = Field(None, description='test', example='test') + + class Config: + orm_mode = True + + \ No newline at end of file diff --git a/REST_AI_ENGINE_CONTROL/app/routes/auth.py b/REST_AI_ENGINE_CONTROL/app/routes/auth.py new file mode 100644 index 0000000..645be35 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/routes/auth.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +""" +@File: auth.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: authentication api + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from itertools import groupby +from operator import attrgetter +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from starlette.requests import Request +import bcrypt +import jwt +from datetime import datetime, timedelta + +from app.common import consts +from app import models as M +from app.database.conn import db +from app.common.config import conf +from app.database.schema import Users, UserLog +from app.utils.extra import AESCryptoCBC +from app.utils.date_utils import D + + +router = APIRouter(prefix='/auth') + + +@router.get('/find-account/{account}', response_model=M.ResponseBase, summary='계정유무 검사') +async def find_account(account: str): + """ + ## 계정유무 검사 + + 주어진 계정이 존재하면 true, 없으면 false 처리 + + **결과** + - ResponseBase + """ + try: + search_info = Users.get(account=account) + if not search_info: + raise Exception(f'not found data: {account}') + return M.ResponseBase() + except Exception as e: + return M.ResponseBase.set_error(str(e)) + + +@router.post('/register/{account_type}', status_code=201, response_model=M.TokenRes, summary='회원가입') +async def register(request: Request, account_type: M.TypeUserAccount, request_body_info: M.UserRegisterReq, session: Session = Depends(db.session)): + """ + ## 회원가입 (신규등록) + + **AI-KEPCO REST SERVER** + - **PW** 항목은 로그인시에 사용되지 않지만, 임의값을 사용 + - 가입 완료 메일 발송은 별도 API(**/api/services/sendmail**) 이용 + - email 항목 생략시 account(email type)값이 적용 + + **결과** + - TokenRes + """ + new_user = None + new_license = None + + try: + # request + # check account + if request.state.user.account != consts.ADMIN_INIT_ACCOUNT_INFO.account: + raise Exception('not admin') + else: + raise Exception('admin') + + if not request_body_info.account or not request_body_info.pw: + raise Exception('Account and PW must be provided') + + register_user = Users.get(account=request_body_info.account) + if register_user: + raise Exception(f'exist email: {request_body_info.account}') + + if account_type == M.TypeUserAccount.email: + hash_pw = None + if request_body_info.pw: + # decrypt pw + try: + decode_pw = request_body_info.pw.encode('utf-8') + desc_pw = AESCryptoCBC().decrypt(decode_pw) + except Exception as e: + raise Exception(f'failed decryption [pw]: {e}') + + hash_pw = bcrypt.hashpw(desc_pw, bcrypt.gensalt()) + + # create users + new_user = Users.create(session, auto_commit=True, + status=request_body_info.status, + user_type=request_body_info.user_grade, + account_type=account_type, + account=request_body_info.account, + pw=hash_pw, + email=request_body_info.email if request_body_info.email else request_body_info.account, + name=request_body_info.name, + sex=request_body_info.sex, + rrn=request_body_info.rrn, + address=request_body_info.address, + phone_number=request_body_info.mobile, + picture=request_body_info.picture, + marketing_agree=request_body_info.marketing_agree, + member_type=request_body_info.membership + ) + + token = dict(Authorization=f'Bearer {create_access_token(data=M.UserToken.from_orm(new_user).dict(exclude={"pw", "marketing_agree"}),)}') + return token + + raise Exception('not supported') + except Exception as e: + return M.ResponseBase.set_error(str(e)) + + +@router.post('/login/{account_type}', status_code=200, response_model=M.UserLoginRes, summary='사용자 접속') +async def login(account_type: M.TypeUserAccount, request_body_info: M.UserLoginReq, session: Session = Depends(db.session)): + """ + ## 사용자 접속 + + **AI-KEPCO REST SERVER** + - 관리자 접속시: account, pw 사용\n + - 일반 계정 접속시: account, sw, mac 사용\n + + **결과** + - UserLoginRes + """ + login_user = None + update_user = None + license_info = None + + try: + # request + if not request_body_info.account: + raise Exception('invalid request: account') + + login_user = Users.get(account=request_body_info.account) + if not login_user: + raise Exception('not found user') + + if account_type == M.TypeUserAccount.email: + # basic auth + + if login_user.user_grade == M.TypeUserGrade.user: + # check pw: pass + + # NOTE(hsj100): CONNECT_USER_ERROR_LOG + UserLog.create(session=session, auto_commit=True + , account=login_user.account + , mac=request_body_info.mac if login_user.user_grade == M.TypeUserGrade.user else None + , type=M.TypeUserLog.info + , api='/api/auth/login' + , message='ok' + ) + elif login_user.user_grade == M.TypeUserGrade.admin: + # decrypt pw + try: + decode_pw = request_body_info.pw.encode('utf-8') + desc_pw = AESCryptoCBC().decrypt(decode_pw) + except Exception as e: + raise Exception(f'failed decryption [pw]: {e}') + + # check pw + if not request_body_info.pw: + raise Exception('invalid request: pw') + is_verified = bcrypt.checkpw(desc_pw, login_user.pw.encode('utf-8')) + if not is_verified: + raise Exception('invalid password') + else: + raise Exception('invalid user_type') + + # TODO(hsj100): LOGIN_STATUS + update_user = Users.filter(account=request_body_info.account).update(auto_commit=True, login='login') + + result_info = M.UserLoginRes() + result_info.Authorization = f'Bearer {create_access_token(data=M.UserToken.from_orm(login_user).dict(exclude={"pw", "marketing_agree"}), )}' + + return result_info + + return M.ResponseBase.set_error('not supported') + except Exception as e: + if update_user: + update_user.close() + + # NOTE(hsj100): CONNECT_USER_ERROR_LOG + if login_user: + UserLog.create(session=session, auto_commit=True + , account=login_user.account + , mac=request_body_info.mac if login_user.user_type == M.TypeUserGrade.user else None + , type=M.TypeUserLog.error + , api='/api/auth/login' + , message=str(e) + ) + return M.ResponseBase.set_error(str(e)) + + +@router.post('/logout/{account}', status_code=200, response_model=M.TokenRes, summary='사용자 접속종료') +async def logout(account: str): + """ + ## 사용자 접속종료 + + 현재 버전에서는 로그인/로그아웃의 상태를 유지하지 않고 상태값만을 서버에서 사용하기 때문에,\n + ***로그상태는 실제상황과 다를 수 있다.*** + + 정상처리시 Authorization(null) 반환 + + **결과** + - TokenRes + """ + user_info = None + + try: + # TODO(hsj100): LOGIN_STATUS + user_info = Users.filter(account=account) + if not user_info: + raise Exception('not found user') + + user_info.update(auto_commit=True, login='logout') + return M.TokenRes() + except Exception as e: + if user_info: + user_info.close() + return M.ResponseBase.set_error(e) + + +async def is_account_exist(account: str): + get_account = Users.get(account=account) + return True if get_account else False + + +def create_access_token(*, data: dict = None, expires_delta: int = None): + + if conf().GLOBAL_TOKEN: + return conf().GLOBAL_TOKEN + + to_encode = data.copy() + if expires_delta: + to_encode.update({'exp': datetime.utcnow() + timedelta(hours=expires_delta)}) + encoded_jwt = jwt.encode(to_encode, consts.JWT_SECRET, algorithm=consts.JWT_ALGORITHM) + return encoded_jwt diff --git a/REST_AI_ENGINE_CONTROL/app/routes/dev.py b/REST_AI_ENGINE_CONTROL/app/routes/dev.py new file mode 100644 index 0000000..569c4be --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/routes/dev.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- +""" +@File: dev.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: 개발용 API + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" +import struct + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +import bcrypt +from starlette.requests import Request + +from app.common import consts +from app import models as M +from app.database.conn import db, Base +from app.database.schema import Users, UserLog + + +from app.utils.extra import FernetCrypto, AESCryptoCBC, AESCipher + + +# mail test +import smtplib +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart + + +def send_mail(): + """ + 구글 계정사용시 : 보안 수준이 낮은 앱에서의 접근 활성화 + + :return: + """ + sender = 'jolimola@gmail.com' + sender_pw = '!ghkdtmdwns1' + # recipient = 'hsj100@daooldns.co.kr' + recipient = 'jwkim@daooldns.co.kr' + list_cc = ['cc1@gmail.com', 'cc2@naver.com'] + str_cc = ','.join(list_cc) + + title = 'Test mail' + contents = ''' + This is test mail + using smtplib. + ''' + + smtp_server = smtplib.SMTP( # 1 + host='smtp.gmail.com', + port=587 + ) + + smtp_server.ehlo() # 2 + smtp_server.starttls() # 2 + smtp_server.ehlo() # 2 + smtp_server.login(sender, sender_pw) # 3 + + msg = MIMEMultipart() # 4 + msg['From'] = sender # 5 + msg['To'] = recipient # 5 + # msg['Cc'] = str_cc # 5 + msg['Subject'] = contents # 5 + msg.attach(MIMEText(contents, 'plain')) # 6 + + smtp_server.send_message(msg) # 7 + smtp_server.quit() # 8 + + +router = APIRouter(prefix='/dev') + + +@router.get('/test', summary='테스트') +async def test(request: Request): + """ + ## 개발 테스트 API + + ## 결과 + - SWInfo + """ + + result = dict() + text1 = '!ekdnfeldpsdptm1' + text2 = 'testtesttest123' + + simpleEnDecrypt = FernetCrypto() + + result['1original_data'] = text1 + result['1simpleEnDecrypt.encrypt'] = simpleEnDecrypt.encrypt(text1) + result['1crypt.hashpw'] = bcrypt.hashpw(text1.encode('utf-8'), bcrypt.gensalt()) + + result['2original_data'] = text2 + result['2simpleEnDecrypt.encrypt'] = simpleEnDecrypt.encrypt(text2) + result['2bcrypt.hashpw'] = bcrypt.hashpw(text2.encode('utf-8'), bcrypt.gensalt()) + + t = bytes(text1.encode('utf-8')) + enc = AESCryptoCBC().encrypt(t) + dec = AESCryptoCBC().decrypt(enc) + + t = enc.decode('utf-8') + + # enc = AESCipher('daooldns12345678').encrypt(a.name).decode('utf-8') + # enc = 'E563ZFt+yJL8YY5yYlYyk602MSscPP2SCCD8UtXXpMI=' + # dec = AESCipher('daooldns12345678').decrypt(enc).decode('utf-8') + + result['AESCryptoCBC().enc'] = f'enc: {enc}, {t}' + result['AESCryptoCBC().dec'] = f'dec: {dec}' + + return result + + +@router.post('/db/add_test_userlog', summary='[DB] 테스트 데이터 추가 (유저로그)', response_model=M.ResponseBase) +async def add_test_userlog(session: Session = Depends(db.session)): + """ + ## 테스트 데이터 생성 (유저로그) + + 프로젝트(DATABASE)에 테스트 데이터(유저로그)를 추가한다. + + 존재하는 유저에 대해서 2개의 로그기록을 추가한다. + 추가되는 로그기록의 생성일은 2023년도 1월 부터 월당 2개이며, 12월이 넘어가면 로그기록을 종료한다. + + **결과** + - ResponseBase + """ + + start_month = 1 + user_list = Users.filter(user_type=M.TypeUserGrade.user).all() + for user_info in user_list: + user_no = '%02d' % int(user_info.account.split('@')[0][4:]) + temp = UserLog.create(session=session, auto_commit=True + , account=user_info.account + , mac=f'11:22:33:44:55:{user_no}' + , api='/api/auth/login' + , message='ok' + ) + temp2 = UserLog.create(session=session, auto_commit=True + , account=user_info.account + , mac=f'11:22:33:44:55:{user_no}' + , api='/api/auth/login' + , message='ok' + ) + # update date + UserLog.filter(id=temp.id).update(auto_commit=True + , created_at=f'2023-{start_month}-01') + UserLog.filter(id=temp2.id).update(auto_commit=True + , created_at=f'2023-{start_month}-10') + if start_month == 12: + break + start_month += 1 + + return M.ResponseBase() + + +@router.post('/db/add_test_data', summary='[DB] 테스트 데이터 생성', response_model=M.ResponseBase) +async def db_add_test_data(session: Session = Depends(db.session)): + """ + ## 테스트 데이터 생성 + + 프로젝트(DATABASE)에 테스트 데이터를 생성한다. + + **결과** + - ResponseBase + """ + return M.ResponseBase() + + +@router.post('/db/delete_all_data', summary='[DB] 데이터 삭제 [전체]', response_model=M.ResponseBase) +async def db_delete_all_data(): + """ + ## DB 데이터 삭제 + + 프로젝트(DATABASE)의 모든 테이블의 내용을 삭제한다. (테이블은 유지) + + **결과** + - ResponseBase + """ + engine = db.engine + metadata = Base.metadata + + foreign_key_turn_off = { + 'mysql': 'SET FOREIGN_KEY_CHECKS=0;', + 'postgresql': 'SET CONSTRAINTS ALL DEFERRED;', + 'sqlite': 'PRAGMA foreign_keys = OFF;', + } + foreign_key_turn_on = { + 'mysql': 'SET FOREIGN_KEY_CHECKS=1;', + 'postgresql': 'SET CONSTRAINTS ALL IMMEDIATE;', + 'sqlite': 'PRAGMA foreign_keys = ON;', + } + truncate_query = { + 'mysql': 'TRUNCATE TABLE {};', + 'postgresql': 'TRUNCATE TABLE {} RESTART IDENTITY CASCADE;', + 'sqlite': 'DELETE FROM {};', + } + + with engine.begin() as conn: + conn.execute(foreign_key_turn_off[engine.name]) + + for table in reversed(metadata.sorted_tables): + conn.execute(truncate_query[engine.name].format(table.name)) + + conn.execute(foreign_key_turn_on[engine.name]) + + return M.ResponseBase() + + +@router.post('/db/delete_all_tables', summary='[DB] 테이블 삭제 [전체]', response_model=M.ResponseBase) +async def db_delete_all_tables(): + """ + ## DB 데이터 삭제 + + 프로젝트(DATABASE)의 모든 테이블을 삭제한다. + + **결과** + - ResponseBase + """ + engine = db.engine + metadata = Base.metadata + + truncate_query = { + 'mysql': 'DROP TABLE IF EXISTS {};', + 'postgresql': 'DROP TABLE IF EXISTS {};', + 'sqlite': 'DROP TABLE IF EXISTS {};', + } + + with engine.begin() as conn: + for table in reversed(metadata.sorted_tables): + conn.execute(truncate_query[engine.name].format(table.name)) + + return M.ResponseBase() diff --git a/REST_AI_ENGINE_CONTROL/app/routes/index.py b/REST_AI_ENGINE_CONTROL/app/routes/index.py new file mode 100644 index 0000000..4e2648e --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/routes/index.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +""" +@File: index.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: default route + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from fastapi import APIRouter + +from app.utils.date_utils import D +from app.models import SWInfo + + +router = APIRouter() + + +@router.get('/', summary='서비스 정보', response_model=SWInfo) +async def index(): + """ + ## 서비스 정보 + 소프트웨어 이름, 버전정보, 현재시간 + + **결과** + - SWInfo + """ + sw_info = SWInfo() + sw_info.date = D.date_str() + return sw_info diff --git a/REST_AI_ENGINE_CONTROL/app/routes/services.py b/REST_AI_ENGINE_CONTROL/app/routes/services.py new file mode 100644 index 0000000..b1c8a70 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/routes/services.py @@ -0,0 +1,878 @@ +# -*- coding: utf-8 -*- +""" +@File: services.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: services api + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +import os, sys +import bcrypt +from fastapi import APIRouter, Depends, Body, UploadFile, File, Query +from typing import List +from sqlalchemy.orm import Session +from starlette.requests import Request +import copy + +import time +import threading + +AI_ENGINE_PATH = "/AI_ENGINE" +sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))) + AI_ENGINE_PATH) + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))))) # const + +import ai_engine_const as AI_CONST +import AI_ENGINE.demo_utils as DEMO + +from AI_ENGINE.instance_queue import manager as message_queue +from app.utils.extra import send_mail +from app.common import consts +from app import models as M +from app.database.conn import Base, db +from app.database.schema import Users, UserLog +from app.database.crud import table_select, table_update, table_delete +from app.utils.date_utils import D + +from app.utils.extra import FernetCrypto,file_size_check + +router = APIRouter(prefix='/services') + +global engine_info +engine_info = M.AEInfo(**AI_CONST.AI_ENGINE_INIT) + +function_doc = f""" + ## AI 엔진 정보 + + 현재 상태의 AI 엔진 정보 조회한다. + - AI Model 정보 포함 + + + ### 결과 + - AEInfo + """ + + +@router.get('/AE/Info', response_model=M.AEInfo, summary='AI 엔진 정보') +async def ai_engine_info_get() -> M.AEInfo: + # result: M.AEInfo = M.AEInfo(version=1.0) + global engine_info + result = copy.deepcopy(engine_info) + + result.demo.ftp.pw = "" + return result + + +ai_engine_info_get.__doc__ = function_doc + +function_doc = f""" + ## AI 엔진 설정 + + AI 엔진 정보 설정한다. + - AI Model 정보 포함 + + ### 결과 + - AEInfo + """ + + +@router.post('/AE/Info-Set', response_model=M.AEInfo, summary='AI 엔진 설정') +async def ai_engine_info_set(request: Request, request_body_info: M.AEInfoSetReq) -> M.AEInfo: + # TODO(jwkim): key 하나만 변경할 수 있게 수정 + global engine_info + engine_info.version = request_body_info.version if request_body_info.version else engine_info.version + engine_info.ai_engine_status = request_body_info.ai_engine_status if request_body_info.ai_engine_status else engine_info.ai_engine_status + + engine_info.demo = request_body_info.demo if request_body_info.demo else engine_info.demo + + engine_info.input_video = request_body_info.input_video if request_body_info.input_video else engine_info.input_video + engine_info.input_bi = request_body_info.input_bi if request_body_info.input_bi else engine_info.input_bi + + # 모델정보 + engine_info.con_model_info = request_body_info.con_model_info if request_body_info.con_model_info else engine_info.con_model_info + engine_info.fr_model_info = request_body_info.fr_model_info if request_body_info.fr_model_info else engine_info.fr_model_info + engine_info.ppe_model_info = request_body_info.ppe_model_info if request_body_info.ppe_model_info else engine_info.ppe_model_info + engine_info.wd_model_info = request_body_info.wd_model_info if request_body_info.wd_model_info else engine_info.wd_model_info + engine_info.bi_model_info = request_body_info.bi_model_info if request_body_info.bi_model_info else engine_info.bi_model_info + + result = copy.deepcopy(engine_info) + + result.demo.ftp.pw = "" + return result + + +ai_engine_info_set.__doc__ = function_doc + + +# @router.post('/AE/RI-Set', response_model=M.ResponseBase, summary='RI 정보 설정') +# async def ai_engine_ri_set(request: Request, request_body_info: M.SendMailReq) -> M.ResponseBase: +# """ +# ## RI 정보 설정 +# +# :param request: +# :param request_body_info: +# +# **결과** +# - ResponseBase +# """ +# +# return M.ResponseBase() + +# @router.post('/AE/FR-Recognize/id',response_model=M.ResponseBase, summary='안면 인식 이미지 등록') +# async def create_upload_files(id:List = Query(default=[]), images: List[UploadFile] = File(...)): +# """ +# TODO(jwkim):개발 진행중 +# """ +# try: +# from AI_ENGINE.instance_queue import FR + +# if len(id) != len(images) or len(id) == 0 or len(images) == 0: +# raise IndexError("check id,images length") + +# fr_id_dict = {} +# for i in range(len(images)): +# size = file_size_check(images[i]) + +# fr_id_dict[id[i]] ={ +# "image": images[i], +# "size": size +# } +# FR.fr_id_info = fr_id_dict +# return M.ResponseBase() + +# except Exception as e: +# return M.ResponseBase.set_error(str(e)) + +@router.post('/AE/OD-CON-SETUP', response_model=M.ResponseBase, summary='시설물 탐지 (작업전) [영상]') +async def ai_engine_od_con_setup_detect(request: Request, request_body_info: M.AEAIModelConSetupDetectReq): + """ + ## 시설물 탐지 (작업전) + + 제한 시간(분) 동안, AI Engine에 등록된 영상(RTSP)을 통해서 시설물을 탐지한다. + + ### AI 모델 [영상] + - OD(Object Detection) + + ### 동작 방식 + - 시설물 탐지는 화면상에 보이는 모든 시설물 탐지 + - 제한 시간 안에 정해진 시설물을 모두 탐지한 경우, 해당 시점에서 결과 전송(MQTT) 후 탐지종료 (탐지 성공) + - 제한 시간 안에 정해진 시설물을 탐지하지 못한 경우, 탐지한 결과만 전송(MQTT) 후 탐지종료 (탐지 실패) + - 신규 대상을 인식할 때 마다 결과 전송이 필요한 경우, 옵션(report_unit)을 true로 사용 + - 이미 해당 기능이 동작 중인 경우: 이전 동작은 종료하고 새로 진행 + - 동작 중인 상태에서 입력 영상 변경시 시설물 탐지는 종료 되며 에러메세지 MQTT로 전송 + + ### 고려(협의) 사항 + - RI 기반 시설물 탐지? + - 탐지할 시설물 목록에 대해서 공종별로 다르게 할 것인지, 전 공종에 대해서 통일 할 것인지 결정 필요 + - 시설물 탐지시설(신호수, 라바콘, 교통안내표지판, ...) 목록에 대한 결정 필요 + - **결과(MQTT) 구조** + + ### 결과 (API) + - API parameter 유효성 검증 결과 및 AI 모델 동작 유무에 따른 결과가 API 결과(ResponseBase)로 반환 (탐지결과 X) + - 실제 탐지 결과는 MQTT의 해당 topic(/AI_KEPCO/AI_OD_CON_SETUP_DETECT/REPORT)으로 전송 (탐지 결과 O) + + ### 결과 (MQTT Topic: /AI_KEPCO/AI_OD_CON_SETUP_DETECT/REPORT) + - MQTT JSON 구조 + ``` + { + 'datetime' : string -> MQTT 메세지 보낸 시각 YY-mm-dd HH:MM:SS.sss + 'status': string -> PPE 탐지 상태 ['start', 'complete', 'new', 'timeout', 'stop', ...] + start: 탐지시작 (result에는 탐지 대상 전부 포함) + complete: 탐지완료 (주어진 대상 탐지 완료, result에는 탐지 대상 전부 포함) + new: 신규탐지 (신규 대상 탐지, result에는 신규 대상 정보만 포함) + timeout: 시간초과 (result에는 탐지한 내용만 포함) + stop: 종료/취소 (result에는 탐지한 대상만 포함) + ...: 오류발생 (status 항목에는 오류메시지, result에는 탐지한 대상만 포함) + 'result': array[DetectObjectInfo] -> 탐지 장비 리스트 + } + ``` + - 탐지 객체 정보 (미정, 협의 필요) + ``` + DetectObjectInfo + { + 'name': string -> 개체명 + 'cid': string -> 클래스 ID + 'confidence': number -> 신뢰도 (0.00 ~ 1.00) + 'bbox': array[number] -> 탐지 영역 좌표 [x1, y1, x2, y2] + 'image': string -> 이미지 데이터 (binary data) (미정) + ... + } + ``` + - 예) 탐지대상 리스트 (안전모, 안전장갑, 안전화, 안전대) 경우, MQTT 결과 + * Result MQTT Data(JSON) – 탐지시작 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘start’, + ‘result’: [ + null, + null, + null, + null, + null + ] + } + ``` + * Result MQTT Data(JSON) – 탐지완료 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘complete’, + ‘result’: [ + {'name': 'signalman', 'cid': 1 , 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'sign_traffic_cone', 'cid': 25, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'sign_board_information', 'cid': 22, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'sign_board_construction', 'cid': 23, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'sign_board_traffic', 'cid': 24, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...} + ] + } + ``` + * Result MQTT Data(JSON) – 신규탐지 (safety_gloves_work_on 인식한 경우) + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘new’, + ‘result’: [ + null, + {'name': 'sign_traffic_cone', 'cid': 25, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + null, + null, + null + ] + } + ``` + * Result MQTT Data(JSON) – 시간초과 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘timeout’, + ‘result’: [ + null, + {'name': 'sign_traffic_cone', 'cid': 25, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'sign_board_information', 'cid': 22, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + null, + null + ] + } + ``` + * Result MQTT Data(JSON) – 종료/취소 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘stop’, + ‘result’: [ + null, + {'name': 'sign_traffic_cone', 'cid': 25, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'sign_board_information', 'cid': 22, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + null, + null + ] + } + ``` + * Result MQTT Data(JSON) – 오류발생 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘ai model error ...’, + ‘result’: [ + null, + null, + null, + {'name': 'sign_board_construction', 'cid': 23, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + null + ] + } + ``` + """ + # time.sleep(2) + global engine_info + + message_queue.sender({ + "request": request_body_info, + "ai_model": AI_CONST.MODEL_CON, + "signal": AI_CONST.SIGNAL_INFERENCE, + "engine_info": engine_info, + "argument": { + # "source": video_abs_path + } + }) + + result = M.ResponseBase() + + return result + + +@router.post('/AE/OD-CON-SETUP-Stop', response_model=M.ResponseBase, summary='시설물 탐지 중지') +async def ai_engine_od_con_setup_detect_stop(request: Request): + """ + ## 시설물 탐지 중지 + + 인위적으로 중지(취소/종료)할 경우에 사용한다. + - 개인보호구 탐지 모델 동작 중지 + - MQTT의 해당 topic(/AI_KEPCO/AI_OD_PPE_DETECT/REPORT) 전송 중지 + + ### 결과 + - ResponseBase + """ + message_queue.sender({ + "request": Request, + "ai_model": AI_CONST.MODEL_CON, + "signal": AI_CONST.SIGNAL_STOP + }) + result = M.ResponseBase() + + return result + +@router.post('/AE/FR-Recognize', response_model=M.ResponseBase, summary='안면 인식 [영상]') +async def ai_engine_fr_recognize(request: Request, request_body_info: M.AEAIModelFRReq) -> M.ResponseBase: + """ + ## 안면 인식 + + 제한 시간(분) 동안, AI Engine에 등록된 영상(RTSP)을 통해서 작업자의 안면 인식을 한다. + + ### AI 모델 [영상] + - FR(Face Recognition) + + ### 동작 방식 + - 안면 인식은 동시에 여러명 인식 가능 + - 제한 시간 안에 인식 대상을 모두 탐지한 경우: 해당 시점에서 결과 전송(MQTT) 후 자동 인식종료 (인식 성공) + - 제한 시간 안에 인식 대상을 모두 탐지 못한 경우: 인식한 결과만 전송(MQTT) 후 자동 인식종료 (인식 실패) + - 신규 대상을 인식할 때 마다 결과 전송이 필요한 경우: 옵션 설정 (report_unit -> true) + - 이미 해당 기능이 동작 중인 경우: 이전 동작은 종료하고 새로 진행 + - 동작 중인 상태에서 입력 영상 변경시 안면 인식은 종료 되며 에러메세지 MQTT로 전송 + + ### 고려(협의) 사항 + - **인식자 정보(얼굴 이미지) 처리 방법** + + 클라이언트에서 인식자 얼굴 이미지를 획득하고 이미지 정보를 REST 서버로 전달해서 처리하는 경우 + * REST 서버에서는 주어진 이미지 데이터를 이용해서 인식처리 + + 클라이언트에서 인식자 고유정보만 REST 서버로 전달해서 처리하는 경우 + * 인식자 고유 정보를 사용할 경우, 고유값(사원번호, 주민번호, 등) 정의 필요 + * REST 서버에서 인식자 고유 정보를 이용해서 얼굴 이미지를 다운받아서 인식처리 + - **결과(MQTT) 구조** + + ### 결과 (API) + - API parameter 유효성 검증 결과 및 AI 모델 동작 유무에 따른 결과가 API 결과(ResponseBase)로 반환 (인식결과 X) + - 실제 인식 결과는 MQTT의 해당 topic(/AI_KEPCO/AI_FACE_RECOGNIZE/REPORT)으로 전송 (인식 결과 O) + + ### 결과 (MQTT Topic: /AI_KEPCO/AI_FACE_RECOGNIZE/REPORT) + - MQTT JSON 구조 + ``` + { + 'datetime' : string -> MQTT 메세지 보낸 시각 YY-mm-dd HH:MM:SS.sss + 'status': string -> 안면 인식 상태 ['start', 'complete', 'new', 'timeout', 'stop', ...] + start: 인식시작 (result에는 인식 대상자 전부 포함) + complete: 인식완료 (result에는 인식 대상자 전부 포함) + new: 신규인식 (result에는 신규 대상 정보만 포함) + timeout: 시간초과 (result에는 인식한 대상자 내용만 포함) + stop: 종료/취소 (result에는 인식한 대상자 내용만 포함) + ...: 오류발생 (status 항목에는 오류메시지, result에는 인식한 대상자 내용만 포함) + 'result': array => 인식자 고유 정보 리스트 + } + ``` + + - 예) 인식 대상자 [‘no0001’, ‘no0002’, ‘no0003’] 경우, MQTT 결과 + * Result MQTT Data(JSON) – 인식시작 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘start’, + ‘result’: [null, null, null] + } + ``` + * Result MQTT Data(JSON) – 인식완료 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘complete’, + ‘result’: [‘no000001’, ‘no000002’, ‘no000003’] + } + ``` + * Result MQTT Data(JSON) – 신규인식 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘new’, + ‘result’: [null, ‘no000002’, null] + } + ``` + * Result MQTT Data(JSON) – 시간초과 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘timeout’, + ‘result’: [‘no000001’, ‘no000002’, null] + } + ``` + * Result MQTT Data(JSON) – 종료/취소 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘stop’, + ‘result’: [‘no000001’, ‘no000002’, null] + } + ``` + * Result MQTT Data(JSON) – 오류발생 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘[error] ai model error ...’, + ‘result’: [‘no000001’, null, null] + } + ``` + """ + global engine_info + message_queue.sender({ + "request": request_body_info, + "ai_model": AI_CONST.MODEL_FACE_RECOGNIZE, + "signal": AI_CONST.SIGNAL_INFERENCE, + "engine_info": engine_info + }) + return M.ResponseBase() + + +@router.post('/AE/FR-Recognize-Stop', response_model=M.ResponseBase, summary='안면 인식 중지') +async def ai_engine_fr_recognize_stop(request: Request): + """ + ## 안면 인식 중지 + + 인위적으로 중지(취소/종료)할 경우에 사용한다. + - 안면 인식 모델 동작 중지 + - MQTT의 해당 topic(/AI_KEPCO/AI_FACE_RECOGNIZE/REPORT) 전송 중지 + + ### 결과 + - ResponseBase + """ + message_queue.sender({ + "request": Request, + "ai_model": AI_CONST.MODEL_FACE_RECOGNIZE, + "signal": AI_CONST.SIGNAL_STOP + }) + result = M.ResponseBase() + + return result + + +@router.post('/AE/OD-PPE-Detect', response_model=M.ResponseBase, summary='개인보호구(PPE) 탐지 (작업전) [영상]') +async def ai_engine_od_ppe_detect(request: Request, request_body_info: M.AEAIModelPPEDetectReq): + """ + ## 개인보호구(PPE) 탐지 (작업전) + + 제한 시간(분) 동안, AI Engine에 등록된 영상(RTSP)을 통해서 작업자의 개인보호구(PPE)를 탐지한다. + + ### AI 모델 [영상] + - OD(Object Detection) + + ### 동작 방식 + - 개인보호구(PPE) 탐지는 한 사람씩 탐지 (여러명 탐지 불가) + - 제한 시간 안에 정해진 개인보호구(PPE)를 탐지한 경우, 해당 시점에서 결과 전송(MQTT) 후 탐지종료 (탐지 성공) + - 제한 시간 안에 정해진 개인보호구(PPE)를 탐지하지 못한 경우, 탐지한 결과만 전송(MQTT) 후 탐지종료 (탐지 실패) + - 신규 대상을 인식할 때 마다 결과 전송이 필요한 경우, 옵션(report_unit)을 true로 사용 + - 이미 해당 기능이 동작 중인 경우: 이전 동작은 종료하고 새로 진행 + - 동작 중인 상태에서 입력 영상 변경시 개인보호구(PPE) 탐지는 종료 되며 에러메세지 MQTT로 전송 + + ### 고려(협의) 사항 + - RI 기반 개인보호구(PPE) 탐지? + - 탐지할 개인보호구 목록에 대해서 공종별로 다르게 할 것인지, 전 공종에 대해서 통일 할 것인지 결정 필요 + - 개인보호구(PPE) 탐지장비(안전모, 안전장갑, 안전화, 안전대, ...) 목록에 대한 결정 필요 + - **결과(MQTT) 구조** + + ### 결과 (API) + - API parameter 유효성 검증 결과 및 AI 모델 동작 유무에 따른 결과가 API 결과(ResponseBase)로 반환 (탐지결과 X) + - 실제 탐지 결과는 MQTT의 해당 topic(/AI_KEPCO/AI_OD_PPE_DETECT/REPORT)으로 전송 (탐지 결과 O) + + ### 결과 (MQTT Topic: /AI_KEPCO/AI_OD_PPE_DETECT/REPORT) + - MQTT JSON 구조 + ``` + { + 'datetime' : string -> MQTT 메세지 보낸 시각 YY-mm-dd HH:MM:SS.sss + 'status': string -> PPE 탐지 상태 ['start', 'complete', 'new', 'timeout', 'stop', ...] + start: 탐지시작 (result에는 탐지 대상 전부 포함) + complete: 탐지완료 (주어진 대상 탐지 완료, result에는 탐지 대상 전부 포함) + new: 신규탐지 (신규 대상 탐지, result에는 신규 대상 정보만 포함) + timeout: 시간초과 (result에는 탐지한 내용만 포함) + stop: 종료/취소 (result에는 탐지한 대상만 포함) + ...: 오류발생 (status 항목에는 오류메시지, result에는 탐지한 대상만 포함) + 'result': array[DetectObjectInfo] -> 탐지 장비 리스트 + } + ``` + - 탐지 객체 정보 (미정, 협의 필요) + ``` + DetectObjectInfo + { + 'name': string -> 개체명 + 'cid': string -> 클래스 ID + 'confidence': number -> 신뢰도 (0.00 ~ 1.00) + 'bbox': array[number] -> 탐지 영역 좌표 [x1, y1, x2, y2] + 'image': string -> 이미지 데이터 (binary data) (미정) + ... + } + ``` + - 예) 탐지대상 리스트 (안전모, 안전장갑, 안전화, 안전대) 경우, MQTT 결과 + * Result MQTT Data(JSON) – 탐지시작 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘start’, + ‘result’: [ + null, + null, + null, + null + ] + } + ``` + * Result MQTT Data(JSON) – 탐지완료 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘complete’, + ‘result’: [ + {'name': 'safety_helmet_on', 'cid': 1 , 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'safety_gloves_work_on', 'cid': 2, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'safety_boots_on', 'cid': 3, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'safety_belt_basic_on', 'cid': 4, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...} + ] + } + ``` + * Result MQTT Data(JSON) – 신규탐지 (safety_gloves_work_on 인식한 경우) + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘new’, + ‘result’: [ + null, + {'name': 'safety_gloves_work_on', 'cid': 2, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + null, + null + ] + } + ``` + * Result MQTT Data(JSON) – 시간초과 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘timeout’, + ‘result’: [ + null, + {'name': 'safety_gloves_work_on', 'cid': 2, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'safety_boots_on', 'cid': 3, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + null + ] + } + ``` + * Result MQTT Data(JSON) – 종료/취소 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘stop’, + ‘result’: [ + null, + {'name': 'safety_gloves_work_on', 'cid': 2, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'safety_boots_on', 'cid': 3, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + null + ] + } + ``` + * Result MQTT Data(JSON) – 오류발생 + ``` + { + ‘datetime’: ‘23-01-19 10:18:04.675’, + ‘status’: ‘ai model error ...’, + ‘result’: [ + null, + null, + {'name': 'safety_boots_on', 'cid': 3, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + null + ] + } + ``` + """ + # time.sleep(2) + global engine_info + + message_queue.sender({ + "request": request_body_info, + "ai_model": AI_CONST.MODEL_PPE, + "signal": AI_CONST.SIGNAL_INFERENCE, + "engine_info": engine_info, + "argument": { + # "source": video_abs_path + } + + }) + + result = M.ResponseBase() + + return result + + +@router.post('/AE/OD-PPE-Detect-Stop', response_model=M.ResponseBase, summary='개인보호구(PPE) 탐지 중지') +async def ai_engine_od_ppe_detect_stop(request: Request): + """ + ## 개인보호구(PPE) 탐지 중지 + + 인위적으로 중지(취소/종료)할 경우에 사용한다. + - 개인보호구 탐지 모델 동작 중지 + - MQTT의 해당 topic(/AI_KEPCO/AI_OD_PPE_DETECT/REPORT) 전송 중지 + + ### 결과 + - ResponseBase + """ + message_queue.sender({ + "request": Request, + "ai_model": AI_CONST.MODEL_PPE, + "signal": AI_CONST.SIGNAL_STOP + }) + result = M.ResponseBase() + + return result + + +@router.post('/AE/OD-Work-Detect', response_model=M.ResponseBase, summary='위험성 탐지 (작업중) [영상]') +async def ai_engine_od_work_detect(request: Request, request_body_info: M.AEAIModelRBIVideoReq): + """ + ## 위험성 탐지 (작업중) + + AI Engine에 연동된 영상 서비스(RTSP - 다채널)의 데이터를 객체탐지 모델을 이용하여 위험성(RI)을 탐지한다. + + ### AI 모델 [영상] + - OD(Object Detection) + + ### 동작 방식 + - 탐지된 정보(객체)를 RI 기반으로 위험성을 판단하고 결과를 MQTT로 전송 + - 위험성이 탐지될 때 마다 탐지 결과를 MQTT로 전송 + - 위험성 탐지 결과는 누적정보가 아닌 해당 시점에서 발생한 탐지결과 + - 위험성 탐지 종료는 API(/api/services/AE/OD-Work-Detect-Stop)를 호출하면 종료 + - 이미 해당 기능이 동작 중인 경우: 이전 동작은 종료하고 새로 진행 + - 동작 중인 상태에서 입력 영상 변경시 위험성 탐지는 종료 되며 에러메세지 MQTT로 전송 + + ### 고려(협의) 사항 + - RI 정보 미정 + - **결과(MQTT) 구조** + + ### 결과(API) + - API parameter 유효성 검증 결과 및 AI 모델 동작 유무에 따른 결과가 API 결과(ResponseBase)로 반환 (탐지결과 X) + - 실제 탐지 결과는 MQTT의 해당 topic(/AI_KEPCO/AI_OD_WORK_DETECT/REPORT)으로 전송 (탐지 결과 O) + + ### 결과(MQTT Topic: /AI_KEPCO/AI_OD_WORK_DETECT/REPORT) + - MQTT JSON 구조 + ``` + { + 'datetime' : string -> MQTT 메세지 보낸 시각 YY-mm-dd HH:MM:SS.sss + 'status': string -> 위험성 탐지 상태 ['start', 'detect', 'stop', ...] + start: 탐지시작 (result에는 인식 대상 전부 포함) + detect: 위험성 탐지 (위험성이 탐지될 때 마다 발생, result에는 신규 대상 정보만 포함) + stop: 종료/취소 (result에는 탐지한 대상만 포함) + ...: 오류발생 (status 항목에는 오류메시지, result에는 탐지한 대상만 포함) + 'result': ConstructionRiskIndexInfo -> 공사 위험성 정보 + } + ``` + - 탐지 객체 정보 구조 (미정, 협의 필요) + ``` + DetectObjectInfo + { + 'name': string -> 객체명 + 'cid': number -> 클래스 ID + 'confidence': number -> 신뢰도 (0.00 ~ 1.00) + 'bbox': array[number] -> 탐지 영역 좌표 [x1, y1, x2, y2] + 'image': string -> 이미지 데이터 (binary data) (미정) + ... + } + ``` + - 위험성 정보 구조 (미정, 협의 필요) + ``` + ConstructionRiskIndexInfo + { + 'construction_type': string -> 공종 + 'procedure_no': number -> 해당 공종의 시공 순번 + 'procedure_ri': number -> 공종의 i번째 공사절차(procedure_no)의 정의된 위험도(0.00 ~ 1.00) + 'ri': number -> 해당 공사절차의 평가된 위험도 (procedure_ri - 측정된 RI값) + 'detect_list': array[DetectInfo, ...] 탐지된 객체 리스트 (옵션) + ... + } + ``` + - 예) 위험성 탐지 + * Result MQTT Data(JSON) - 탐지 시작 + ``` + { + 'datetime': '23-01-19 10:18:04.675', + 'status': 'start', + 'result': { + 'construction_type': 'D4' + ‘procedure_no’: 1, + 'procedure_ri': 0.70, + ri: 0.32, + ‘detect_list’: [ + ] + } + } + ``` + * Result MQTT Data(JSON) - 위험성 탐지 + ``` + { + 'datetime': '23-01-19 10:18:04.675', + 'status': 'detect', + 'result': { + 'construction_type': 'D4' + ‘procedure_no’: 1, + 'procedure_ri': 0.70, + ri: 0.32, + ‘detect_list’: [ + {'name': 'person', 'cid': 0, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'safety_helmet_on', 'cid': 3, 'confidence': 0.75, 'bbox': [ 350, 350, 200, 400 ], ...} + ] + } + } + ``` + * Result MQTT Data(JSON) - 종료/취소 + ``` + { + 'datetime': '23-01-19 10:18:04.675', + 'status': 'stop', + 'result': { + 'construction_type': 'D4' + ‘procedure_no’: 1, + 'procedure_ri': 0.70, + ri: 0.32, + ‘detect_list’: [ + {'name': 'person', 'cid': 0, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'safety_helmet_on', 'cid': 3, 'confidence': 0.75, 'bbox': [ 350, 350, 200, 400 ], ...}, + {'name': 'safety_boots_on', 'cid': '10', 'confidence': 0.86, 'bbox': [ 450, 450, 300, 400 ], ...}, + ... + ] + } + } + ``` + * Result MQTT Data(JSON) - 오류발생 + ``` + { + 'datetime': '23-01-19 10:18:04.675', + 'status': 'ai model error ...', + 'result': { + 'construction_type': 'D4' + ‘procedure_no’: 1, + 'procedure_ri': 0.70, + ri: 0.32, + ‘detect_list’: [ + {'name': 'person', 'cid': 0, 'confidence': 0.90, 'bbox': [ 150, 150, 100, 200 ], ...}, + {'name': 'safety_helmet_on', 'cid': 3, 'confidence': 0.75, 'bbox': [ 350, 350, 200, 400 ], ...}, + {'name': 'safety_boots_on', 'cid': '10', 'confidence': 0.86, 'bbox': [ 450, 450, 300, 400 ], ...}, + ... + ] + } + } + ``` + """ + + global engine_info + + message_queue.sender({ + "request": request_body_info, + "ai_model": AI_CONST.MODEL_WORK_DETECT, + "signal": AI_CONST.SIGNAL_INFERENCE, + "engine_info": engine_info, + "argument": { + # "source": video_abs_path + } + }) + + result = M.ResponseBase() + + return result + + +@router.post('/AE/OD-Work-Detect-Stop', response_model=M.ResponseBase, summary='위험성 탐지 중지') +async def ai_engine_od_work_detect_stop(request: Request): + """ + ## 위험성 탐지 중지 + + 위험성 탐지를 중지(취소/종료)할 경우에 사용 + - 위험성 탐지 모델 동작 중지 + - MQTT의 해당 topic(/AI_KEPCO/AI_OD_WORK_DETECT/REPORT) 전송 중지 + + ### 결과 + - ResponseBase + """ + + message_queue.sender({ + "request": Request, + "ai_model": AI_CONST.MODEL_WORK_DETECT, + "signal": AI_CONST.SIGNAL_STOP + }) + + DEMO.demo_wd_bi(0) + + result = M.ResponseBase() + + return result + + +@router.post('/AE/BI-Detect', response_model=M.ResponseBase, summary='위험성 탐지 [BI]') +async def ai_engine_bi_detect(request: Request, request_body_info: M.AEAIModelRBIBIReq): + """ + ## 위험성 탐지 [BI] + + AI Engine에 연동된 BI 센서 장비들을 통해서 획득한 생체정보(BI)를 RI 기반으로 위험성 평가(이상탐지)를 실시한다. + + ### AI 모델 [BI] + - BI(Biometric Information - Time Series Anomaly Detection) + + ### 동작 방식 + - 탐지된 정보(BI)를 RI정보 기반으로 위험성을 판단하고 결과를 MQTT로 전송 + - 위험성이 탐지될 때 마다 탐지 결과를 MQTT로 전송 + - 위험성 탐지 결과는 누적정보가 아닌 해당 시점에서 발생한 탐지결과 + - 위험성 탐지는 종료 API(/api/services/AE/BI-Detect-Stop)를 호출하면 종료 + + ### 고려(협의) 사항 + - RI 정보 미정 + - **결과(MQTT) 구조** + + ### 결과(API) + - API parameter 유효성 검증 결과 및 AI 모델 동작 유무에 따른 결과가 API 결과(ResponseBase)로 반환 (탐지결과 X) + - 실제 탐지 결과는 MQTT의 해당 topic(/AI_KEPCO/AI_BI_DETECT/REPORT)으로 전송 (탐지 결과 O) + + ### 결과(MQTT Topic: /AI_KEPCO/AI_BI_DETECT/REPORT) (미연동) + - MQTT JSON 구조 + ``` + { + 'datetime' : string -> MQTT 메세지 보낸 시각 YY-mm-dd HH:MM:SS.sss + 'status':string -> 위험성 탐지 상태 ['start', 'detect', 'stop', ...] + detect: 위험성 탐지 (위험성이 탐지될 때 마다 발생, result에는 신규 대상 정보만 포함) + stop: 종료/취소 (result에는 탐지한 대상만 포함) + ...: 오류발생 (status 항목에는 오류메시지, result에는 탐지한 대상만 포함) + 'result': array => 위험성 정보 리스트 + } + ``` + """ + # if DEMO.DEMO_WD_BI_CONST == 0 : + # return M.ResponseBase.set_error(str("Invalid state (not WD)")) + + # DEMO.demo_wd_bi(2) + + # DEMO.bi_snap_shot() + + result = M.ResponseBase() + + return result + + +@router.post('/AE/BI-Detect-Stop', response_model=M.ResponseBase, summary='위험성 탐지(BI) 중지') +async def ai_engine_bi_detect_stop(request: Request): + """ + ## 위험성 탐지(BI) 중지 + + 위험성 탐지를 중지(취소/종료)할 경우에 사용 + - 위험성 탐지 모델 동작 중지 + - MQTT의 해당 topic(/AI_KEPCO/AI_BI_DETECT/REPORT) 전송 중지 + + ### 결과 + - ResponseBase + """ + result = M.ResponseBase() + + return result diff --git a/REST_AI_ENGINE_CONTROL/app/routes/users.py b/REST_AI_ENGINE_CONTROL/app/routes/users.py new file mode 100644 index 0000000..d384758 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/routes/users.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- +""" +@File: users.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: users api + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from fastapi import APIRouter +from starlette.requests import Request +import bcrypt + +from app.common import consts +from app import models as M +from app.common.config import conf +from app.database.schema import Users +from app.database.crud import table_select, table_update, table_delete + +from app.utils.extra import AESCryptoCBC + + +router = APIRouter(prefix='/user') + + +@router.get('/me', response_model=M.UserSearchRes, summary='접속자 정보') +async def get_me(request: Request): + """ + ## 현재 접속된 자신정보 확인 + + ***현 버전 미지원(추후 상세 처리)*** + + **결과** + - UserSearchRes + """ + target_table = Users + search_info = None + + try: + # request + if conf().GLOBAL_TOKEN: + raise Exception('not supported: use search api!') + + accessor_info = request.state.user + if not accessor_info: + raise Exception('invalid accessor') + + # search + search_info = target_table.get(account=accessor_info.account) + if not search_info: + raise Exception('not found data') + + # result + result_info = list() + result_info.append(M.UserInfo.from_orm(search_info)) + return M.UserSearchRes(data=result_info) + except Exception as e: + if search_info: + search_info.close() + return M.ResponseBase.set_error(str(e)) + + +@router.post('/search', response_model=M.UserSearchPagingRes, summary='유저정보 검색') +async def user_search(request: Request, request_body_info: M.UserSearchPagingReq): + """ + ## 유저정보 검색 (기본) + 검색 조건은 유저 테이블 항목만 가능 (Request body: Schema 참조)\n + 관련 정보 연동 검색은 별도 API 사용 + + **세부항목** + - **paging**\n + 항목 미사용시에는 페이징기능 없이 검색조건(search) 결과 모두 반환 + + - **search**\n + 검색에 필요한 항목들을 search 에 포함시킨다.\n + 검색에 사용된 각 항목들은 AND 조건으로 처리된다. + + - **전체검색**\n + empty object 사용 ( {} )\n + 예) "search": {} + + - **검색항목**\n + - 부분검색 항목\n + SQL 문법( %, _ )을 사용한다.\n + __like: 시작포함(X%), 중간포함(%X%), 끝포함(%X) + - 구간검색 항목 + * __lt: 주어진 값보다 작은값 + * __lte: 주어진 값보다 같거나 작은 값 + * __gt: 주어진 값보다 큰값 + * __gte: 주어진 값보다 같거나 큰값 + + **결과** + - UserSearchPagingRes + """ + return await table_select(request.state.user, Users, request_body_info, M.UserSearchPagingRes, M.UserInfo) + + +@router.put('/update', response_model=M.ResponseBase, summary='유저정보 변경') +async def user_update(request: Request, request_body_info: M.UserUpdateMultiReq): + """ + ## 유저정보 변경 + + **search_info**: 변경대상\n + + **update_info**: 변경내용\n + - **비밀번호** 제외 + + **결과** + - ResponseBase + """ + return await table_update(request.state.user, Users, request_body_info, M.ResponseBase) + + +@router.put('/update_pw', response_model=M.ResponseBase, summary='유저 비밀번호 변경') +async def user_update_pw(request: Request, request_body_info: M.UserUpdatePWReq): + """ + ## 유저정보 비밀번호 변경 + + **account**의 **비밀번호**를 변경한다. + + **결과** + - ResponseBase + """ + target_table = Users + search_info = None + + try: + # request + accessor_info = request.state.user + if not accessor_info: + raise Exception('invalid accessor') + if not request_body_info.account: + raise Exception('invalid account') + + # decrypt pw + try: + decode_cur_pw = request_body_info.current_pw.encode('utf-8') + desc_cur_pw = AESCryptoCBC().decrypt(decode_cur_pw) + except Exception as e: + raise Exception(f'failed decryption [current_pw]: {e}') + try: + decode_new_pw = request_body_info.new_pw.encode('utf-8') + desc_new_pw = AESCryptoCBC().decrypt(decode_new_pw) + except Exception as e: + raise Exception(f'failed decryption [new_pw]: {e}') + + # search + target_user = target_table.get(account=request_body_info.account) + is_verified = bcrypt.checkpw(desc_cur_pw, target_user.pw.encode('utf-8')) + if not is_verified: + raise Exception('invalid password') + + search_info = target_table.filter(id=target_user.id) + if not search_info.first(): + raise Exception('not found data') + + # process + hash_pw = bcrypt.hashpw(desc_new_pw, bcrypt.gensalt()) + result_info = search_info.update(auto_commit=True, pw=hash_pw) + if not result_info or not result_info.id: + raise Exception('failed update') + + # result + return M.ResponseBase() + except Exception as e: + if search_info: + search_info.close() + return M.ResponseBase.set_error(str(e)) + + +@router.delete('/delete', response_model=M.ResponseBase, summary='유저정보 삭제') +async def user_delete(request: Request, request_body_info: M.UserSearchReq): + """ + ## 유저정보 삭제 + 조건에 해당하는 정보를 모두 삭제한다.\n + - **본 API는 DB에서 완적삭제를 하는 함수이며, 서버관리자가 사용하는 것을 권장한다.** + - **update API를 사용하여 상태 항목을 변경해서 사용하는 것을 권장.** + + `유저삭제시 관계 테이블의 정보도 같이 삭제된다.` + + **결과** + - ResponseBase + """ + return await table_delete(request.state.user, Users, request_body_info, M.ResponseBase) + + +# NOTE(hsj100): apikey +""" +""" +# @router.get('/apikeys', response_model=List[M.GetApiKeyList]) +# async def get_api_keys(request: Request): +# """ +# API KEY 조회 +# :param request: +# :return: +# """ +# user = request.state.user +# api_keys = ApiKeys.filter(user_id=user.id).all() +# return api_keys +# +# +# @router.post('/apikeys', response_model=M.GetApiKeys) +# async def create_api_keys(request: Request, key_info: M.AddApiKey, session: Session = Depends(db.session)): +# """ +# API KEY 생성 +# :param request: +# :param key_info: +# :param session: +# :return: +# """ +# user = request.state.user +# +# api_keys = ApiKeys.filter(session, user_id=user.id, status='active').count() +# if api_keys == MAX_API_KEY: +# raise ex.MaxKeyCountEx() +# +# alphabet = string.ascii_letters + string.digits +# s_key = ''.join(secrets.choice(alphabet) for _ in range(40)) +# uid = None +# while not uid: +# uid_candidate = f'{str(uuid4())[:-12]}{str(uuid4())}' +# uid_check = ApiKeys.get(access_key=uid_candidate) +# if not uid_check: +# uid = uid_candidate +# +# key_info = key_info.dict() +# new_key = ApiKeys.create(session, auto_commit=True, secret_key=s_key, user_id=user.id, access_key=uid, **key_info) +# return new_key +# +# +# @router.put('/apikeys/{key_id}', response_model=M.GetApiKeyList) +# async def update_api_keys(request: Request, key_id: int, key_info: M.AddApiKey): +# """ +# API KEY User Memo Update +# :param request: +# :param key_id: +# :param key_info: +# :return: +# """ +# user = request.state.user +# key_data = ApiKeys.filter(id=key_id) +# if key_data and key_data.first().user_id == user.id: +# return key_data.update(auto_commit=True, **key_info.dict()) +# raise ex.NoKeyMatchEx() +# +# +# @router.delete('/apikeys/{key_id}') +# async def delete_api_keys(request: Request, key_id: int, access_key: str): +# user = request.state.user +# await check_api_owner(user.id, key_id) +# search_by_key = ApiKeys.filter(access_key=access_key) +# if not search_by_key.first(): +# raise ex.NoKeyMatchEx() +# search_by_key.delete(auto_commit=True) +# return MessageOk() +# +# +# @router.get('/apikeys/{key_id}/whitelists', response_model=List[M.GetAPIWhiteLists]) +# async def get_api_keys(request: Request, key_id: int): +# user = request.state.user +# await check_api_owner(user.id, key_id) +# whitelists = ApiWhiteLists.filter(api_key_id=key_id).all() +# return whitelists +# +# +# @router.post('/apikeys/{key_id}/whitelists', response_model=M.GetAPIWhiteLists) +# async def create_api_keys(request: Request, key_id: int, ip: M.CreateAPIWhiteLists, session: Session = Depends(db.session)): +# user = request.state.user +# await check_api_owner(user.id, key_id) +# import ipaddress +# try: +# _ip = ipaddress.ip_address(ip.ip_addr) +# except Exception as e: +# raise ex.InvalidIpEx(ip.ip_addr, e) +# if ApiWhiteLists.filter(api_key_id=key_id).count() == MAX_API_WHITELIST: +# raise ex.MaxWLCountEx() +# ip_dup = ApiWhiteLists.get(api_key_id=key_id, ip_addr=ip.ip_addr) +# if ip_dup: +# return ip_dup +# ip_reg = ApiWhiteLists.create(session=session, auto_commit=True, api_key_id=key_id, ip_addr=ip.ip_addr) +# return ip_reg +# +# +# @router.delete('/apikeys/{key_id}/whitelists/{list_id}') +# async def delete_api_keys(request: Request, key_id: int, list_id: int): +# user = request.state.user +# await check_api_owner(user.id, key_id) +# ApiWhiteLists.filter(id=list_id, api_key_id=key_id).delete() +# +# return MessageOk() +# +# +# async def check_api_owner(user_id, key_id): +# api_keys = ApiKeys.get(id=key_id, user_id=user_id) +# if not api_keys: +# raise ex.NoKeyMatchEx() diff --git a/REST_AI_ENGINE_CONTROL/app/utils/date_utils.py b/REST_AI_ENGINE_CONTROL/app/utils/date_utils.py new file mode 100644 index 0000000..ba06509 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/utils/date_utils.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +""" +@File: date_utils.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: utility [date] + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from datetime import datetime, date, timedelta + +_TIMEDELTA = 9 + + +class D: + def __init__(self, *args): + self.utc_now = datetime.utcnow() + # NOTE(hsj100): utc->kst + self.timedelta = _TIMEDELTA + + @classmethod + def datetime(cls, diff: int=_TIMEDELTA) -> datetime: + return datetime.utcnow() + timedelta(hours=diff) if diff > 0 else datetime.now() + + @classmethod + def date(cls, diff: int=_TIMEDELTA) -> date: + return cls.datetime(diff=diff).date() + + @classmethod + def date_num(cls, diff: int=_TIMEDELTA) -> int: + return int(cls.date(diff=diff).strftime('%Y%m%d')) + + @classmethod + def validate(cls, date_text): + try: + datetime.strptime(date_text, '%Y-%m-%d') + except ValueError: + raise ValueError('Incorrect data format, should be YYYY-MM-DD') + + @classmethod + def date_str(cls, diff: int = _TIMEDELTA) -> str: + return cls.datetime(diff=diff).strftime('%Y-%m-%dT%H:%M:%S') + + @classmethod + def date_version_str(cls, diff: int = _TIMEDELTA) -> str: + return cls.datetime(diff=diff).strftime('%Y%m%dT%H%M%2S') + + @classmethod + def check_expire_date(cls, expire_date: datetime): + td = expire_date - datetime.now() + timestamp = td.total_seconds() + return timestamp + + @classmethod + def date_str_micro_sec(cls, diff: int = _TIMEDELTA) -> str: + return cls.datetime(diff=diff).strftime('%y-%m-%d %H:%M:%S.%f')[:-3] + diff --git a/REST_AI_ENGINE_CONTROL/app/utils/extra.py b/REST_AI_ENGINE_CONTROL/app/utils/extra.py new file mode 100644 index 0000000..77199d9 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/utils/extra.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +""" +@File: extra.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: utility [extra] + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" +import os + +from hashlib import md5 +from base64 import b64decode +from base64 import b64encode + +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad, unpad +from cryptography.fernet import Fernet # symmetric encryption + +# mail test +import smtplib +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart + +from itertools import groupby +from operator import attrgetter +import uuid + +from app.common.consts import NUM_RETRY_UUID_GEN, SMTP_HOST, SMTP_PORT +from app.utils.date_utils import D +from app import models as M +from app.common.consts import AES_CBC_PUBLIC_KEY, AES_CBC_IV, FERNET_SECRET_KEY + + +async def send_mail(sender, sender_pw, title, recipient, contents_plain, contents_html, cc_list, smtp_host=SMTP_HOST, smtp_port=SMTP_PORT): + """ + 구글 계정사용시 : 보안 수준이 낮은 앱에서의 접근 활성화 + + :return: + None: success + Str. Message: error + """ + try: + # check parameters + if not sender: + raise Exception('invalid sender') + if not title: + raise Exception('invalid title') + if not recipient: + raise Exception('invalid recipient') + + # sender info. + # sender = consts.ADMIN_INIT_ACCOUNT_INFO.email + # sender_pw = consts.ADMIN_INIT_ACCOUNT_INFO.email_pw + + # message + msg = MIMEMultipart() + msg['From'] = sender + msg['To'] = recipient + if cc_list: + list_cc = cc_list + str_cc = ','.join(list_cc) + msg['Cc'] = str_cc + msg['Subject'] = title + + if contents_plain: + msg.attach(MIMEText(contents_plain, 'plain')) + if contents_html: + msg.attach(MIMEText(contents_html, 'html')) + + # smtp server + smtp_server = smtplib.SMTP(host=smtp_host, port=smtp_port) + smtp_server.ehlo() + smtp_server.starttls() + smtp_server.ehlo() + smtp_server.login(sender, sender_pw) + smtp_server.send_message(msg) + smtp_server.quit() + return None + except Exception as e: + return str(e) + + +def query_to_groupby(query_result, key, first=False): + """ + 쿼리 결과물(list)을 항목값(key)으로 그룹화한다. + + :param query_result: 쿼리 결과 리스트 + :param key: 그룹 항목값 + :return: dict + """ + group_info = dict() + for k, g in groupby(query_result, attrgetter(key)): + if k not in group_info: + if not first: + group_info[k] = list(g) + else: + group_info[k] = list(g)[0] + else: + if not first: + group_info[k].extend(list(g)) + return group_info + + +def query_to_groupby_date(query_result, key): + """ + 쿼리 결과물(list)을 항목값(key)으로 그룹화한다. + + :param query_result: 쿼리 결과 리스트 + :param key: 그룹 항목값 + :return: dict + """ + group_info = dict() + for k, g in groupby(query_result, attrgetter(key)): + day_str = k.strftime("%Y-%m-%d") + if day_str not in group_info: + group_info[day_str] = list(g) + else: + group_info[day_str].extend(list(g)) + return group_info + + +class FernetCrypto: + def __init__(self, key=FERNET_SECRET_KEY): + self.key = key + self.f = Fernet(self.key) + + def encrypt(self, data, is_out_string=True): + if isinstance(data, bytes): + ou = self.f.encrypt(data) # 바이트형태이면 바로 암호화 + else: + ou = self.f.encrypt(data.encode('utf-8')) # 인코딩 후 암호화 + if is_out_string is True: + return ou.decode('utf-8') # 출력이 문자열이면 디코딩 후 반환 + else: + return ou + + def decrypt(self, data, is_out_string=True): + if isinstance(data, bytes): + ou = self.f.decrypt(data) # 바이트형태이면 바로 복호화 + else: + ou = self.f.decrypt(data.encode('utf-8')) # 인코딩 후 복호화 + if is_out_string is True: + return ou.decode('utf-8') # 출력이 문자열이면 디코딩 후 반환 + else: + return ou + + +class AESCryptoCBC: + def __init__(self, key=AES_CBC_PUBLIC_KEY, iv=AES_CBC_IV): + # Initial vector를 0으로 초기화하여 16바이트 할당함 + # iv = chr(0) * 16 #pycrypto 기준 + # iv = bytes([0x00] * 16) #pycryptodomex 기준 + # aes cbc 생성 + self.key = key + self.iv = iv + self.crypto = AES.new(self.key, AES.MODE_CBC, self.iv) + + def encrypt(self, data): + # 암호화 message는 16의 배수여야 한다. + # enc = self.crypto.encrypt(data) + # return enc + enc = self.crypto.encrypt(pad(data, AES.block_size)) + return b64encode(enc) + + def decrypt(self, enc): + # 복호화 enc는 16의 배수여야 한다. + # dec = self.crypto.decrypt(enc) + # return dec + enc = b64decode(enc) + dec = self.crypto.decrypt(enc) + return unpad(dec, AES.block_size) + + +class AESCipher: + def __init__(self, key): + # self.key = md5(key.encode('utf8')).digest() + self.key = bytes(key.encode('utf-8')) + + def encrypt(self, data): + # iv = get_random_bytes(AES.block_size) + iv = bytes('daooldns12345678'.encode('utf-8')) + + self.cipher = AES.new(self.key, AES.MODE_CBC, iv) + t = b64encode(self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size))) + return b64encode(iv + self.cipher.encrypt(pad(data.encode('utf-8'), AES.block_size))) + + def decrypt(self, data): + raw = b64decode(data) + self.cipher = AES.new(self.key, AES.MODE_CBC, raw[:AES.block_size]) + return unpad(self.cipher.decrypt(raw[AES.block_size:]), AES.block_size) + +def file_size_check(data): + _start= data.file.tell() + data.file.seek(0,os.SEEK_END) + size = data.file.tell() + data.file.seek(_start,os.SEEK_SET) + + return size \ No newline at end of file diff --git a/REST_AI_ENGINE_CONTROL/app/utils/logger.py b/REST_AI_ENGINE_CONTROL/app/utils/logger.py new file mode 100644 index 0000000..7ea7629 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/utils/logger.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +""" +@File: logger.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: logger + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +import json +import logging +from datetime import timedelta, datetime +from time import time +from fastapi.requests import Request +from fastapi import Body +from fastapi.logger import logger + +logger.setLevel(logging.INFO) + + +async def api_logger(request: Request, response=None, error=None): + time_format = '%Y/%m/%d %H:%M:%S' + t = time() - request.state.start + status_code = error.status_code if error else response.status_code + error_log = None + user = request.state.user + if error: + if request.state.inspect: + frame = request.state.inspect + error_file = frame.f_code.co_filename + error_func = frame.f_code.co_name + error_line = frame.f_lineno + else: + error_func = error_file = error_line = 'UNKNOWN' + + error_log = dict( + errorFunc=error_func, + location='{} line in {}'.format(str(error_line), error_file), + raised=str(error.__class__.__name__), + msg=str(error.ex), + ) + + account = user.account.split('@') if user and user.account else None + user_log = dict( + client=request.state.ip, + user=user.id if user and user.id else None, + account='**' + account[0][2:-1] + '*@' + account[1] if user and user.account else None, + ) + + log_dict = dict( + url=request.url.hostname + request.url.path, + method=str(request.method), + statusCode=status_code, + errorDetail=error_log, + client=user_log, + processedTime=str(round(t * 1000, 5)) + 'ms', + datetimeUTC=datetime.utcnow().strftime(time_format), + datetimeKST=(datetime.utcnow() + timedelta(hours=9)).strftime(time_format), + ) + if error and error.status_code >= 500: + logger.error(json.dumps(log_dict)) + else: + logger.info(json.dumps(log_dict)) diff --git a/REST_AI_ENGINE_CONTROL/app/utils/query_utils.py b/REST_AI_ENGINE_CONTROL/app/utils/query_utils.py new file mode 100644 index 0000000..ee24b6f --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/app/utils/query_utils.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +""" +@File: query_utils.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: utility [query] + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from typing import List + + +def to_dict(model, *args, exclude: List = None): + q_dict = {} + for c in model.__table__.columns: + if not args or c.name in args: + if not exclude or c.name not in exclude: + q_dict[c.name] = getattr(model, c.name) + + return q_dict diff --git a/REST_AI_ENGINE_CONTROL/docker-build.sh b/REST_AI_ENGINE_CONTROL/docker-build.sh new file mode 100644 index 0000000..f5e15a9 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/docker-build.sh @@ -0,0 +1,2 @@ +docker build -t kepco_ai_rbi/rest_ai_engine_control:latest . + diff --git a/REST_AI_ENGINE_CONTROL/docker-compose-service.yml b/REST_AI_ENGINE_CONTROL/docker-compose-service.yml new file mode 100644 index 0000000..2abe653 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/docker-compose-service.yml @@ -0,0 +1,42 @@ +version: '3.7' +services: + mysql: + image: mysql:latest + container_name: mysql + restart: always + environment: + TZ: Asia/Seoul + MYSQL_PORT: 3306 + MYSQL_ROOT_PASSWORD: "!ekdnfeldpsdptm1" + MYSQL_DATABASE: "AI_ENGINE_CONTROL" + MYSQL_USER: "aikepco" + MYSQL_PASSWORD: "!ekdnfeldpsdptm1" + healthcheck: + test: "mysql -uroot -p$$MYSQL_ROOT_PASSWORD -e \"SHOW DATABASES;\"" + timeout: 5s + retries: 5 + command: + - --character-set-server=utf8mb4 + - --collation-server=utf8mb4_unicode_ci + volumes: + - /MYSQL/data/:/var/lib/mysql + - /MYSQL/backup:/backupfiles + ports: + - 3306:3306 + + rest_api: + depends_on: + mysql: + condition: service_healthy + container_name: REST_AI_ENGINE_CONTROL + image: kepco_ai_rbi/rest_ai_engine_control:latest + build: + context: . + dockerfile: ./Dockerfile + environment: + - TZ=Asia/Seoul + volumes: + - ./app:/FAST_API + ports: + - 50520-50522:50520-50522 + command: ["uvicorn", "app.main:app", "--host", '0.0.0.0', "--port", "50520"] diff --git a/REST_AI_ENGINE_CONTROL/docker-compose.yml b/REST_AI_ENGINE_CONTROL/docker-compose.yml new file mode 100644 index 0000000..4199162 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/docker-compose.yml @@ -0,0 +1,10 @@ +version: '3' + +services: + api: + container_name: REST_AI_ENGINE_CONTROL + image: kepco_ai_rbi/rest_ai_engine_control:latest + build: + context: . + ports: + - 50220:50220 diff --git a/REST_AI_ENGINE_CONTROL/environment.yml b/REST_AI_ENGINE_CONTROL/environment.yml new file mode 100644 index 0000000..92ac7a0 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/environment.yml @@ -0,0 +1,11 @@ +name: REST_AI_ENGINE_CONTROL + +channels: + - conda-forge + +dependencies: + - python=3.9 + - pip + - pip: + - -r requirements.txt + diff --git a/REST_AI_ENGINE_CONTROL/gunicorn.conf.py b/REST_AI_ENGINE_CONTROL/gunicorn.conf.py new file mode 100644 index 0000000..c29480d --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/gunicorn.conf.py @@ -0,0 +1,222 @@ +# Gunicorn configuration file. +# +# Server socket +# +# bind - The socket to bind. +# +# A string of the form: 'HOST', 'HOST:PORT', 'unix:PATH'. +# An IP is a valid HOST. +# +# backlog - The number of pending connections. This refers +# to the number of clients that can be waiting to be +# served. Exceeding this number results in the client +# getting an error when attempting to connect. It should +# only affect servers under significant load. +# +# Must be a positive integer. Generally set in the 64-2048 +# range. +# + +bind = "0.0.0.0:5000" +backlog = 2048 + +# +# Worker processes +# +# workers - The number of worker processes that this server +# should keep alive for handling requests. +# +# A positive integer generally in the 2-4 x $(NUM_CORES) +# range. You'll want to vary this a bit to find the best +# for your particular application's work load. +# +# worker_class - The type of workers to use. The default +# sync class should handle most 'normal' types of work +# loads. You'll want to read +# http://docs.gunicorn.org/en/latest/design.html#choosing-a-worker-type +# for information on when you might want to choose one +# of the other worker classes. +# +# A string referring to a Python path to a subclass of +# gunicorn.workers.base.Worker. The default provided values +# can be seen at +# http://docs.gunicorn.org/en/latest/settings.html#worker-class +# +# worker_connections - For the eventlet and gevent worker classes +# this limits the maximum number of simultaneous clients that +# a single process can handle. +# +# A positive integer generally set to around 1000. +# +# timeout - If a worker does not notify the master process in this +# number of seconds it is killed and a new worker is spawned +# to replace it. +# +# Generally set to thirty seconds. Only set this noticeably +# higher if you're sure of the repercussions for sync workers. +# For the non sync workers it just means that the worker +# process is still communicating and is not tied to the length +# of time required to handle a single request. +# +# keepalive - The number of seconds to wait for the next request +# on a Keep-Alive HTTP connection. +# +# A positive integer. Generally set in the 1-5 seconds range. +# +# reload - Restart workers when code changes. +# +# This setting is intended for development. It will cause +# workers to be restarted whenever application code changes. +workers = 3 +threads = 3 +worker_class = "uvicorn.workers.UvicornWorker" +worker_connections = 1000 +timeout = 60 +keepalive = 2 +reload = True + +# +# spew - Install a trace function that spews every line of Python +# that is executed when running the server. This is the +# nuclear option. +# +# True or False +# + +spew = False + +# +# Server mechanics +# +# daemon - Detach the main Gunicorn process from the controlling +# terminal with a standard fork/fork sequence. +# +# True or False +# +# raw_env - Pass environment variables to the execution environment. +# +# pidfile - The path to a pid file to write +# +# A path string or None to not write a pid file. +# +# user - Switch worker processes to run as this user. +# +# A valid user id (as an integer) or the name of a user that +# can be retrieved with a call to pwd.getpwnam(value) or None +# to not change the worker process user. +# +# group - Switch worker process to run as this group. +# +# A valid group id (as an integer) or the name of a user that +# can be retrieved with a call to pwd.getgrnam(value) or None +# to change the worker processes group. +# +# umask - A mask for file permissions written by Gunicorn. Note that +# this affects unix socket permissions. +# +# A valid value for the os.umask(mode) call or a string +# compatible with int(value, 0) (0 means Python guesses +# the base, so values like "0", "0xFF", "0022" are valid +# for decimal, hex, and octal representations) +# +# tmp_upload_dir - A directory to store temporary request data when +# requests are read. This will most likely be disappearing soon. +# +# A path to a directory where the process owner can write. Or +# None to signal that Python should choose one on its own. +# + +daemon = False +pidfile = None +umask = 0 +user = None +group = None +tmp_upload_dir = None + +# +# Logging +# +# logfile - The path to a log file to write to. +# +# A path string. "-" means log to stdout. +# +# loglevel - The granularity of log output +# +# A string of "debug", "info", "warning", "error", "critical" +# + +errorlog = "-" +loglevel = "info" +accesslog = None +access_log_format = '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"' + +# +# Process naming +# +# proc_name - A base to use with setproctitle to change the way +# that Gunicorn processes are reported in the system process +# table. This affects things like 'ps' and 'top'. If you're +# going to be running more than one instance of Gunicorn you'll +# probably want to set a name to tell them apart. This requires +# that you install the setproctitle module. +# +# A string or None to choose a default of something like 'gunicorn'. +# + +proc_name = "NotificationAPI" + + +# +# Server hooks +# +# post_fork - Called just after a worker has been forked. +# +# A callable that takes a server and worker instance +# as arguments. +# +# pre_fork - Called just prior to forking the worker subprocess. +# +# A callable that accepts the same arguments as after_fork +# +# pre_exec - Called just prior to forking off a secondary +# master process during things like config reloading. +# +# A callable that takes a server instance as the sole argument. +# + + +def post_fork(server, worker): + server.log.info("Worker spawned (pid: %s)", worker.pid) + + +def pre_fork(server, worker): + pass + + +def pre_exec(server): + server.log.info("Forked child, re-executing.") + + +def when_ready(server): + server.log.info("Server is ready. Spawning workers") + + +def worker_int(worker): + worker.log.info("worker received INT or QUIT signal") + + # get traceback info + import threading, sys, traceback + + id2name = {th.ident: th.name for th in threading.enumerate()} + code = [] + for threadId, stack in sys._current_frames().items(): + code.append("\n# Thread: %s(%d)" % (id2name.get(threadId, ""), threadId)) + for filename, lineno, name, line in traceback.extract_stack(stack): + code.append('File: "%s", line %d, in %s' % (filename, lineno, name)) + if line: + code.append(" %s" % (line.strip())) + worker.log.debug("\n".join(code)) + + +def worker_abort(worker): + worker.log.info("worker received SIGABRT signal") diff --git a/REST_AI_ENGINE_CONTROL/requirements.txt b/REST_AI_ENGINE_CONTROL/requirements.txt new file mode 100644 index 0000000..c8e3288 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/requirements.txt @@ -0,0 +1,13 @@ +fastapi==0.70.1 +uvicorn==0.16.0 +pymysql==1.0.2 +sqlalchemy==1.4.29 +bcrypt==3.2.0 +pyjwt==2.3.0 +yagmail==0.14.261 +boto3==1.20.32 +pytest==6.2.5 +cryptography +pycryptodomex +pycryptodome +email-validator \ No newline at end of file diff --git a/REST_AI_ENGINE_CONTROL/test_main.py b/REST_AI_ENGINE_CONTROL/test_main.py new file mode 100644 index 0000000..7fc004f --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/test_main.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +""" +@file : test_main.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: 개발시 모듈(main) 실행 + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +import uvicorn +from app.common.config import conf + + +if __name__ == '__main__': + print('test_main.py run') + uvicorn.run('app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True) diff --git a/REST_AI_ENGINE_CONTROL/tests/__init__.py b/REST_AI_ENGINE_CONTROL/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/REST_AI_ENGINE_CONTROL/tests/conftest.py b/REST_AI_ENGINE_CONTROL/tests/conftest.py new file mode 100644 index 0000000..9a4e811 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/tests/conftest.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +""" +@File: conftest.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: test config + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +import asyncio +import os +from os import path +from typing import List + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from app.database.schema import Users +from app.main import create_app +from app.database.conn import db, Base +from app.models import UserToken +from app.routes.auth import create_access_token + + +""" +1. DB 생성 +2. 테이블 생성 +3. 테스트 코드 작동 +4. 테이블 레코드 삭제 +""" + +@pytest.fixture(scope='session') +def app(): + os.environ['API_ENV'] = 'test' + return create_app() + + +@pytest.fixture(scope='session') +def client(app): + # Create tables + Base.metadata.create_all(db.engine) + return TestClient(app=app) + + +@pytest.fixture(scope='function', autouse=True) +def session(): + sess = next(db.session()) + yield sess + clear_all_table_data( + session=sess, + metadata=Base.metadata, + except_tables=[] + ) + sess.rollback() + + +@pytest.fixture(scope='function') +def login(session): + """ + 테스트전 사용자 미리 등록 + :param session: + :return: + """ + db_user = Users.create(session=session, email='ryan_test@dingrr.com', pw='123') + session.commit() + access_token = create_access_token(data=UserToken.from_orm(db_user).dict(exclude={'pw', 'marketing_agree'}),) + return dict(Authorization=f'Bearer {access_token}') + + +def clear_all_table_data(session: Session, metadata, except_tables: List[str] = None): + session.execute('SET FOREIGN_KEY_CHECKS = 0;') + for table in metadata.sorted_tables: + if table.name not in except_tables: + session.execute(table.delete()) + session.execute('SET FOREIGN_KEY_CHECKS = 1;') + session.commit() diff --git a/REST_AI_ENGINE_CONTROL/tests/test_auth.py b/REST_AI_ENGINE_CONTROL/tests/test_auth.py new file mode 100644 index 0000000..e83d893 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/tests/test_auth.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +""" +@File: test_auth.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: test auth. + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from app.database.conn import db +from app.database.schema import Users + + +def test_registration(client, session): + """ + 레버 로그인 + :param client: + :param session: + :return: + """ + user = dict(email='ryan@dingrr.com', pw='123', name='라이언', phone='01099999999') + res = client.post('api/auth/register/email', json=user) + res_body = res.json() + print(res.json()) + assert res.status_code == 201 + assert 'Authorization' in res_body.keys() + + +def test_registration_exist_email(client, session): + """ + 레버 로그인 + :param client: + :param session: + :return: + """ + user = dict(email='Hello@dingrr.com', pw='123', name='라이언', phone='01099999999') + db_user = Users.create(session=session, **user) + session.commit() + res = client.post('api/auth/register/email', json=user) + res_body = res.json() + assert res.status_code == 400 + assert 'EMAIL_EXISTS' == res_body['msg'] diff --git a/REST_AI_ENGINE_CONTROL/tests/test_user.py b/REST_AI_ENGINE_CONTROL/tests/test_user.py new file mode 100644 index 0000000..50c8ed0 --- /dev/null +++ b/REST_AI_ENGINE_CONTROL/tests/test_user.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +""" +@File: test_user.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: test user + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +from app.database.conn import db +from app.database.schema import Users + + +def test_create_get_apikey(client, session, login): + """ + 레버 로그인 + :param client: + :param session: + :return: + """ + key = dict(user_memo='ryan__key') + res = client.post('api/user/apikeys', json=key, headers=login) + res_body = res.json() + assert res.status_code == 200 + assert 'secret_key' in res_body + + res = client.get('api/user/apikeys', headers=login) + res_body = res.json() + assert res.status_code == 200 + assert 'ryan__key' in res_body[0]['user_memo'] + diff --git a/ai_engine_const.py b/ai_engine_const.py new file mode 100644 index 0000000..8604f96 --- /dev/null +++ b/ai_engine_const.py @@ -0,0 +1,409 @@ +# -*- coding: utf-8 -*- +from REST_AI_ENGINE_CONTROL.app import models as M +import project_config +import os + +# MQTT +MQTT_HOST = 'localhost' + +# source +PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) + +DAOOL_RTSP = "rtsp://daool:Ekdnfeldpsdptm1@211.63.236.6:52554/axis-media/media.amp" +RTSP = "rtsp://223.171.144.245:8554/cam/0/low" + +STREAMS_PATH = PROJECT_PATH + "/DL/wd.streams" + +CON_VIDEO_PATH = PROJECT_PATH +"/AI_ENGINE/DATA/CON.mp4" +FACE_VIDEO_PATH = PROJECT_PATH +"/AI_ENGINE/DATA/FR.mov" +PPE_VIDEO_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/PPE.mp4" +WD_VIDEO_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/WD_2.mp4" + +if project_config.CONFIG == project_config.CONFIG_AISERVER: + # + #MQTT + MQTT_PORT = 50083 + MQTT_USER_ID = 'kepco' + MQTT_USER_PW = '!kepco1234' + + #source + RTSP = "rtsp://223.171.144.245:8554/cam/0/low" + CON_SOURCE = CON_VIDEO_PATH + FR_SOURCE = FACE_VIDEO_PATH + PPE_SOURCE = PPE_VIDEO_PATH + WD_SOURCE = WD_VIDEO_PATH + + #SFTP + FTP_IP = "106.255.245.242" + FTP_PORT = 2022 + FTP_ID = "kepri_if_user" + FTP_PW = "kepri!123" + FTP_LOCATION = "/home/agics-dev/kepri_storage/rndpartners/" + FTP_CON_FILE_NAME = "3" + FTP_FR_FILE_NAME = "1" + FTP_PPE_FILE_NAME = "2" + FTP_WD_FILE_NAME = "5" + FTP_BI_FILE_NAME = "4" + +elif project_config.CONFIG == project_config.CONFIG_MG: + # + # #MQTT + # MQTT_PORT = 1883 + # MQTT_USER_ID = 'admin' + # MQTT_USER_PW = 'admin' + + MQTT_PORT = 11883 + MQTT_USER_ID = 'kepco' + MQTT_USER_PW = '!kepco1234' + + # #source + # RTSP = "rtsp://10.20.10.1:8554/cam/0/low" + # RTSP = "rtsp://192.168.39.20:8554/cam/0/low" + RTSP = "rtsp://admin:admin1263!@10.20.10.99:28554/onvif/media?profile=Profile2" + CON_SOURCE = RTSP + FR_SOURCE = RTSP + PPE_SOURCE = RTSP + WD_SOURCE = RTSP + + # ----------- ---------- + #CON_SOURCE = "rtsp://219.250.188.204:8554/con" + #FR_SOURCE = "rtsp://219.250.188.205:8554/fr" + #PPE_SOURCE = "rtsp://219.250.188.206:8554/ppe" + #WD_SOURCE = "rtsp://219.250.188.207:8554/wd" + # ----------- ---------- + + if project_config.DEBUG_MODE: + CON_SOURCE = CON_VIDEO_PATH + FR_SOURCE = FACE_VIDEO_PATH + PPE_SOURCE = PPE_VIDEO_PATH + WD_SOURCE = WD_VIDEO_PATH + + # #SFTP + FTP_IP = "106.255.245.242" + FTP_PORT = 2022 + FTP_ID = "kepri_if_user" + FTP_PW = "kepri!123" + FTP_LOCATION = "/home/agics-dev/kepri_storage/rndpartners/" + + # ----------- ---------- + # FTP_IP = "211.63.236.6" + # FTP_PORT = 50002 + # FTP_ID = "fermat" + # FTP_PW = "1234" + # FTP_LOCATION = "/home/fermat/work/rest_ftp_test" + # ----------- ---------- + + FTP_CON_FILE_NAME = "3" + FTP_FR_FILE_NAME = "1" + FTP_PPE_FILE_NAME = "2" + FTP_WD_FILE_NAME = "5" + FTP_BI_FILE_NAME = "4" + +else : + # + # MQTT(dev3) + MQTT_PORT = 1883 + MQTT_USER_ID = 'admin' + MQTT_USER_PW = '12341234' + + RTSP = DAOOL_RTSP + CON_SOURCE = CON_VIDEO_PATH + FR_SOURCE = FACE_VIDEO_PATH + PPE_SOURCE = PPE_VIDEO_PATH + WD_SOURCE = WD_VIDEO_PATH + + # FTP + FTP_IP = "192.168.200.232" + FTP_PORT = 22 + FTP_ID = "fermat" + FTP_PW = "1234" + FTP_LOCATION = "/home/fermat/work/rest_ftp_test" + + FTP_CON_FILE_NAME = "c" + FTP_FR_FILE_NAME = "a" + FTP_PPE_FILE_NAME = "b" + FTP_WD_FILE_NAME = "e" + FTP_BI_FILE_NAME = "d" + + +#FTP result path +FTP_CON_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/con_setup.jpg' +FTP_FR_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/fr.jpg' +FTP_PPE_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/ppe.jpg' +FTP_WD_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/wd.jpg' +FTP_BI_RESULT = PROJECT_PATH + '/AI_ENGINE/DATA/ftp_data/bi.jpg' + +# TOPIC +MQTT_CON_TOPIC = '/AI_KEPCO/AI_OD_CON_SETUP_DETECT/REPORT' +MQTT_FR_TOPIC = '/AI_KEPCO/AI_FACE_RECOGNIZE/REPORT' +MQTT_PPE_TOPIC = '/AI_KEPCO/AI_OD_PPE_DETECT/REPORT' +MQTT_PPE_FR_TOPIC = '/AI_KEPCO/AI_OD_PPE_FR_DETECT/REPORT' # test topic +MQTT_WD_TOPIC = '/AI_KEPCO/AI_OD_WORK_DETECT/REPORT' +MQTT_BI_TOPIC = '/AI_KEPCO/AI_BI_DETECT/REPORT' + +# AI_MODEL +MODEL_CON = 'CON' +MODEL_PPE = 'PPE' +MODEL_WORK_DETECT = 'WD' +MODEL_FACE_RECOGNIZE = 'FR' +MODEL_BIO_INFO = 'BI' + +# YOLO +BBOX_XYXY = 'XYXY' +BBOX_XYWH = 'XYWH' + +# WD +WD_FRAME_COUNT = 80 + +# FACE_RECOGNITION +FACE_EVOLUTION_DISTANCE = 0.4 + +WORKER1_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco1.jpg" +WORKER2_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco1_1.jpg" +WORKER3_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco1_2.jpg" + +if project_config.DEBUG_MODE: + WORKER1_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/facerec_worker1.png" + WORKER2_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/facerec_worker2.png" + WORKER3_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/facerec_worker3.png" + +WORKER4_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/jangys_re.jpg" +WORKER5_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/whangsj.jpg" +WORKER6_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kimjw_re.jpg" +WORKER7_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/ksy_re.jpg" +WORKER8_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/agics.jpg" +WORKER9_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco1.jpg" + +WORKER10_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_1.jpg" +WORKER11_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_2.jpg" +WORKER12_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_3.jpg" +WORKER13_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_4.jpg" +WORKER14_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_5.jpg" +WORKER15_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_6.jpg" +WORKER16_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_7.jpg" +WORKER17_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_8.jpg" +WORKER18_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_9.jpg" +WORKER19_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_10.jpg" +WORKER20_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_11.jpg" + +WORKER21_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_12.jpg" +WORKER22_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/kepco2_13.jpg" + + +# WORKER16_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/yunikim.jpg" +# WORKER17_IMG_PATH = PROJECT_PATH + "/AI_ENGINE/DATA/facerec_worker3.png" + +SIGNAL_INFERENCE = 'inference' +SIGNAL_STOP = 'stop' + +SOURCE_CHANGED_MSG = "INPUT SOURCE CHANGED" +IMG_CHANGED_MSG = "INPUT IMAGE CHANGED" +INVALID_IMG_MSG = "INVALID IMAGE SIZE" + +DEMO_KEY_NAME_SNAPSHOT_SFTP = "SnapshotSFTP" + +OFF_CLASS_LIST = [2,4,7,9,11,15,17,19] +OFF_TRIGGER_CLASS_LIST = [2,4] #TODO(jwkim):BIXPO test (helmet,gloves) + +AI_ENGINE_INIT = { + "version": project_config.PROJECT_VERSION, + "ai_engine_status": "init", + "demo": { + "ftp": { + "ip": FTP_IP, + "port": FTP_PORT, + "id": FTP_ID, + "pw": FTP_PW, + "location": FTP_LOCATION, + "file_con_setup": FTP_CON_FILE_NAME, + "file_face": FTP_FR_FILE_NAME, + "file_ppe": FTP_PPE_FILE_NAME, + "file_wd": FTP_WD_FILE_NAME, + "file_bi": FTP_BI_FILE_NAME + } + }, + "input_video": [ + { + "name": "CON_video", + "model": M.AEAIModelType.CON, + "sn": "serial_no", + "connect_url": CON_SOURCE, + "user_id": "user_id", + "user_pw": "user_pw" + }, + { + "name": "FR_video", + "model": M.AEAIModelType.FR, + "sn": "serial_no", + "connect_url": FR_SOURCE, + "user_id": "user_id", + "user_pw": "user_pw" + }, + { + "name": "PPE_video", + "model": M.AEAIModelType.PPE, + "sn": "serial_no", + "connect_url": PPE_SOURCE, + "user_id": "user_id", + "user_pw": "user_pw" + }, + { + "name": "WD_video_1", + "model": M.AEAIModelType.WORK, + "sn": "serial_no", + "connect_url": WD_SOURCE, + "user_id": "user_id", + "user_pw": "user_pw" + } + ], + "input_bi": { + "name": "device_name", + "model": "model_name", + "sn": "serial_no", + "connect_url": "connect_url", + "user_id": "user_id", + "user_pw": "user_pw", + "topic": "topic" + }, +"con_model_info": { + "name": "CON", + "version": "20220101", + "status": "init", + "ri": { + "construction_code": "D54", + "work_no": 3, + "work_define_ri": 0.82, + "ri_parameter_list": [ + { + "name": "작업자 숙련도", + "ratio": 0.6 + }, + { + "name": "작업자 교육레벨", + "ratio": 0.5 + } + ], + "evaluation_work_ri": 0 + }, + "weights": [ + { + "id": 0, + "filename": "index_78.pt", + "version": "0.1", + "date": "2023-01-31T10:12:17", + "model": "small", + "nc": 26 + } + ], + "mode": "con", + "crop_images": False + }, + "fr_model_info": { + "name": "FR", + "version": "20220101", + "status": "init", + "ri": { + "construction_code": "D54", + "work_no": 3, + "work_define_ri": 0.82, + "ri_parameter_list": [ + { + "name": "작업자 숙련도", + "ratio": 0.6 + }, + { + "name": "작업자 교육레벨", + "ratio": 0.5 + } + ], + "evaluation_work_ri": 0 + } + }, + "ppe_model_info": { + "name": "PPE", + "version": "20220101", + "status": "init", + "ri": { + "construction_code": "D54", + "work_no": 3, + "work_define_ri": 0.82, + "ri_parameter_list": [ + { + "name": "작업자 숙련도", + "ratio": 0.6 + }, + { + "name": "작업자 교육레벨", + "ratio": 0.5 + } + ], + "evaluation_work_ri": 0 + }, + "weights": [ + { + "id": 0, + "filename": "index_78.pt", + "version": "0.1", + "date": "2023-01-31T10:12:17", + "model": "small", + "nc": 26 + } + ], + "mode": "ppe", + "crop_images": False + }, + "wd_model_info": { + "name": "WORK", + "version": "20220101", + "status": "init", + "ri": { + "construction_code": "D54", + "work_no": 3, + "work_define_ri": 0.82, + "ri_parameter_list": [ + { + "name": "작업자 숙련도", + "ratio": 0.6 + }, + { + "name": "작업자 교육레벨", + "ratio": 0.5 + } + ], + "evaluation_work_ri": 0 + }, + "weights": [ + { + "id": 0, + "filename": "index_78.pt", + "version": "0.1", + "date": "2023-01-31T10:12:17", + "model": "small", + "nc": 26 + } + ], + "mode": "work", + "crop_images": False + }, + "bi_model_info": { + "name": "BI", + "version": "20220101", + "status": "init", + "ri": { + "construction_code": "D54", + "work_no": 3, + "work_define_ri": 0.82, + "ri_parameter_list": [ + { + "name": "작업자 숙련도", + "ratio": 0.6 + }, + { + "name": "작업자 교육레벨", + "ratio": 0.5 + } + ], + "evaluation_work_ri": 0 + } + } +} diff --git a/ai_engine_main.py b/ai_engine_main.py new file mode 100644 index 0000000..9aa99dc --- /dev/null +++ b/ai_engine_main.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +""" +@file : ai_engine_main.py +@author: hsj100 +@license: A2TEC & DAOOLDNS +@brief: 개발시 모듈(main) 실행 + +@section Modify History +- 2022-01-14 오전 11:31 hsj100 base + +""" + +import os, sys +import uvicorn + +# AI_ENGINE_PATH = "/AI_ENGINE" +REST_SERVER_PATH = "/REST_AI_ENGINE_CONTROL" +# sys.path.append(os.path.abspath(os.path.dirname(__file__)) + AI_ENGINE_PATH) +sys.path.append(os.path.abspath(os.path.dirname(__file__)) + REST_SERVER_PATH) + +from REST_AI_ENGINE_CONTROL.app.common.config import conf + + +if __name__ == '__main__': + print('main.py run') + # os.system(". ./rtsp_start.sh") + uvicorn.run('REST_AI_ENGINE_CONTROL.app.main:app', host='0.0.0.0', port=conf().REST_SERVER_PORT, reload=True) \ No newline at end of file diff --git a/demo_bi.py b/demo_bi.py new file mode 100644 index 0000000..fb7293f --- /dev/null +++ b/demo_bi.py @@ -0,0 +1,47 @@ + +import cv2 +import os +import paramiko + +LOCAL_PATH = f"/home/kepco/daooldns/KEPCO_AI_RBI_SIM/ENGINE/AI_ENGINE/DATA/ftp_data/local.jpg" + +def bi_snap_shot(): + # 10.20.10.99 + # 192.168.39.20 + rtsp = "rtsp://10.20.10.1:8554/cam/0/low" + rtsp = "rtsp://admin:admin1263!@10.20.10.99:554/onvif/media?profile=Profile2" + local_path = f"/home/kepco/daooldns/KEPCO_AI_RBI_SIM/ENGINE/AI_ENGINE/DATA/ftp_data/local.jpg" + + if os.path.exists(local_path): + os.remove(local_path) + + input_movie = cv2.VideoCapture(rtsp) + + ret, frame = input_movie.read() + print(local_path) + cv2.imwrite(local_path,frame) + _bi_sftp_upload() + cv2.destroyAllWindows() + print(f"bi uploaded") + +def _bi_sftp_upload(): + try: + + IP = "106.255.245.242" + transprot = paramiko.Transport((IP,2022)) + transprot.connect(username = "kepri_if_user", password = "kepri!123") + sftp = paramiko.SFTPClient.from_transport(transprot) + + remotepath = "/home/agics-dev/kepri_storage/rndpartners/" + os.sep + "remote" + '.jpg' + + #sftp.put(LOCAL_PATH, remotepath) + + sftp.close() + transprot.close() + + return remotepath + except Exception as e: + return "" + +if __name__ == '__main__': + bi_snap_shot() \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..d76e0f9 --- /dev/null +++ b/main.py @@ -0,0 +1,16 @@ +# 샘플 Python 스크립트입니다. + +# Ctrl+F5을(를) 눌러 실행하거나 내 코드로 바꿉니다. +# 클래스, 파일, 도구 창, 액션 및 설정을 어디서나 검색하려면 Shift 두 번을(를) 누릅니다. + + +def print_hi(name): + # 스크립트를 디버그하려면 하단 코드 줄의 중단점을 사용합니다. + print(f'Hi, {name}') # 중단점을 전환하려면 F9을(를) 누릅니다. + + +# 스크립트를 실행하려면 여백의 녹색 버튼을 누릅니다. +if __name__ == '__main__': + print_hi('PyCharm') + +# https://www.jetbrains.com/help/pycharm/에서 PyCharm 도움말 참조 diff --git a/project_config.py b/project_config.py new file mode 100644 index 0000000..8ff43ee --- /dev/null +++ b/project_config.py @@ -0,0 +1,15 @@ +PROJECT_VERSION = 0.5 + +CONFIG_AISERVER = "AISERVER" +CONFIG_MG = "MG" +CONFIG_FERMAT = "FERMAT" + +# CONFIG = CONFIG_AISERVER +CONFIG = CONFIG_MG +# CONFIG = CONFIG_FERMAT + +DEBUG_MODE = False + +SFTP_UPLOAD = False + +FR_UPLOAD = True diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..41b736f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,54 @@ +#1. ADTK (python>3.5) +adtk + +#2. Face_recognition (python>3.3) +# cmake ,dlib 설치 필요 +face-recognition +opencv-python + +#3. Yolov5 +# torch,torchvision 별도 설치 +# --- base --- +gitpython +ipython # interactive notebook +matplotlib>=3.2.2 +numpy>=1.18.5 +opencv-python>=4.1.1 +Pillow>=7.1.2 +psutil # system resources +PyYAML>=5.3.1 +requests>=2.23.0 +scipy>=1.4.1 +thop>=0.1.1 # FLOPs computation +#torch>=1.7.0 # see https://pytorch.org/get-started/locally (recommended) +#torchvision>=0.8.1 +tqdm>=4.64.0 + +# --- Logging --- +tensorboard>=2.4.1 + +# --- Plotting --- +pandas>=1.1.4 +seaborn>=0.11.0 + +#4. REST server +fastapi==0.70.1 +uvicorn==0.16.0 +pymysql==1.0.2 +sqlalchemy==1.4.29 +bcrypt==3.2.0 +pyjwt==2.3.0 +yagmail==0.14.261 +boto3==1.20.32 +pytest==6.2.5 +cryptography +pycryptodomex +pycryptodome +email-validator +python-multipart + +#5. MQTT +paho-mqtt + +#6. FTP +paramiko \ No newline at end of file diff --git a/rtsp_start.sh b/rtsp_start.sh new file mode 100644 index 0000000..dce5783 --- /dev/null +++ b/rtsp_start.sh @@ -0,0 +1 @@ +mosquitto_pub -p 11883 -h localhost -t ptz -m 1,0.95,0 -u admin -P admin \ No newline at end of file