108 lines
2.8 KiB
Python
108 lines
2.8 KiB
Python
import numpy as np
|
||
from ultralytics import YOLO
|
||
|
||
import demo_const as AI_CONST
|
||
import project_config
|
||
|
||
# from ai_engine.custom_logger.custom_log import log
|
||
|
||
class AIObjectModel:
    """YOLOv8 object-detection (OD) model wrapper.

    Loads the weights named by ``WEIGHTS`` through ultralytics ``YOLO``,
    prints which device torch reports as available, and runs one warm-up
    prediction on a black frame. Subclasses load other weights by overriding
    ``MODEL_NAME`` and ``WEIGHTS`` (and ``WARMUP_SHAPE`` if their input
    resolution differs).
    """

    MODEL_NAME = "OD"
    WEIGHTS = AI_CONST.WEIGHTS_YOLO
    # (height, width, channels) of the black warm-up frame. Default is FHD
    # 1920x1080; override (e.g. (360, 640, 3) for 640x360 input) when the
    # input video uses a different resolution.
    WARMUP_SHAPE = (1080, 1920, 3)

    def __init__(self, weights=None) -> None:
        """Load and warm up the model.

        Args:
            weights: optional weights path/name passed to ``YOLO``; when
                omitted, the class attribute ``WEIGHTS`` is used.
        """
        self.model = None  # replaced by the YOLO instance in set_model()
        self.set_model(weights=self.WEIGHTS if weights is None else weights)

    def set_model(self, weights) -> None:
        """Create the YOLO model from ``weights`` and run a warm-up prediction.

        Args:
            weights: weights path/name accepted by ultralytics ``YOLO``.
        """
        self.model = YOLO(weights)

        # Report whether torch sees a CUDA device.
        import torch

        device_status = "CUDA (GPU)" if torch.cuda.is_available() else "CPU"
        print(f"[{self.MODEL_NAME}] Model loaded. Using: {device_status}")

        # Warm up with one all-black uint8 frame of WARMUP_SHAPE; adjust
        # WARMUP_SHAPE if the input video resolution is different.
        warmup_frame = np.zeros(self.WARMUP_SHAPE, dtype=np.uint8)
        self.model.predict(source=warmup_frame, verbose=False)

        # log.info(f"YOLOV8 {self.MODEL_NAME} MODEL LOADED")

    def get_model(self):
        """Return the underlying ultralytics ``YOLO`` instance."""
        return self.model

    def get_class_info(self):
        """Return the loaded YOLO model's ``names`` attribute (its class labels)."""
        return self.model.names


class AIHelmetObjectModel(AIObjectModel):
    """YOLOv8 helmet object-detection model.

    All loading and warm-up logic is inherited from ``AIObjectModel``;
    only the weights and the tag shown in the load message differ.
    """

    # Helmet-detector weights constant from demo_const.
    WEIGHTS = AI_CONST.WEIGHTS_YOLO_HELMET
    # Tag printed as "[OD-Helmet] Model loaded. ..." by the base class.
    MODEL_NAME = "OD-Helmet"


class AIHPEModel(AIObjectModel):
    """YOLOv8 HPE (pose) model.

    Shares the loading and warm-up procedure of ``AIObjectModel``; it only
    points at the pose weights and uses its own tag in the load message.
    """

    # Pose-model weights constant from demo_const.
    WEIGHTS = AI_CONST.WEIGHTS_POSE
    # Tag printed as "[HPE] Model loaded. ..." by the base class.
    MODEL_NAME = "HPE"


class AIModelManager:
    """Holds the YOLOv8 model wrappers, building each one at most once.

    ``set_*`` methods construct a model only if its slot is still empty.
    ``get_*`` methods return the stored model, or raise ``RuntimeError``
    (a subclass of ``Exception``, so existing ``except Exception`` handlers
    keep working) when that model has not been set.
    """

    def __init__(self) -> None:
        # None marks a model that has not been loaded yet.
        self.od_model_info = None
        self.hpe_model_info = None
        self.helmet_model_info = None

    @staticmethod
    def _require(model, message):
        """Return ``model``, or raise ``RuntimeError(message)`` if it is unset.

        Any falsy value (None, or the legacy "" placeholder) counts as unset.
        """
        if not model:
            raise RuntimeError(message)
        return model

    def set_od(self) -> None:
        """Load the general object-detection model if not loaded yet."""
        if not self.od_model_info:
            self.od_model_info = AIObjectModel()

    def set_hpe(self) -> None:
        """Load the HPE (pose) model if not loaded yet."""
        if not self.hpe_model_info:
            self.hpe_model_info = AIHPEModel()

    def set_helmet(self) -> None:
        """Load the helmet detection model if not loaded yet."""
        if not self.helmet_model_info:
            self.helmet_model_info = AIHelmetObjectModel()

    def get_od(self):
        """Return the OD model wrapper; raise RuntimeError if not loaded."""
        return self._require(self.od_model_info, "YOLO model not loaded")

    def get_hpe(self):
        """Return the HPE model wrapper; raise RuntimeError if not loaded."""
        return self._require(self.hpe_model_info, "HPE model not loaded")

    def get_helmet(self):
        """Return the helmet model wrapper; raise RuntimeError if not loaded."""
        return self._require(self.helmet_model_info, "helmet model not loaded")


# Module-level model registry, built when this module is executed (i.e. on
# first import). The OD and HPE models are always loaded; the helmet model
# only when project_config.USE_HELMET_MODEL is truthy.
model_manager = AIModelManager()
model_manager.set_od()
model_manager.set_hpe()
if project_config.USE_HELMET_MODEL:
    model_manager.set_helmet()


if __name__ == "__main__":
    # Nothing else to do when run as a script: executing the module body
    # above already loads the models.
    pass