# HOSPITAL_CCTV/load_models.py
# (captured from a repository web view: 108 lines, 2.8 KiB, Python;
#  snapshot 2026-02-25 15:05:58 +09:00)
import numpy as np
from ultralytics import YOLO
import demo_const as AI_CONST
import project_config
# from ai_engine.custom_logger.custom_log import log
class AIObjectModel:
    """YOLOv8 object-detection (OD) model loader.

    On construction, the weights named by ``WEIGHTS`` are loaded through
    ``ultralytics.YOLO``. One prediction is then run on a black dummy frame
    to warm the model up.

    Subclasses override ``MODEL_NAME`` / ``WEIGHTS`` to load other models.
    They may also override ``WARMUP_SHAPE`` when their input resolution
    differs.
    """

    MODEL_NAME = "OD"
    WEIGHTS = AI_CONST.WEIGHTS_YOLO
    # (height, width, channels) of the black warm-up frame. The default is
    # Full HD 1920x1080; use e.g. (360, 640, 3) for 640x360 input video.
    WARMUP_SHAPE = (1080, 1920, 3)

    def __init__(self) -> None:
        # None until set_model() has loaded the weights.
        self.model = None
        self.set_model(weights=self.WEIGHTS)

    def set_model(self, weights, warmup_shape=None):
        """Load ``weights`` and warm the model up with one dummy prediction.

        Args:
            weights: Weights accepted by ``ultralytics.YOLO`` (e.g. a file path).
            warmup_shape: Optional ``(height, width, channels)`` of the black
                warm-up frame. ``None`` uses the class's ``WARMUP_SHAPE``.
        """
        self.model = YOLO(weights)

        # Report whether torch sees a CUDA device. The local import keeps
        # torch out of the module's top-level imports.
        import torch

        device_status = "CUDA (GPU)" if torch.cuda.is_available() else "CPU"
        print(f"[{self.MODEL_NAME}] Model loaded. Using: {device_status}")

        # The warm-up frame should match the input video resolution. Override
        # WARMUP_SHAPE (or pass warmup_shape) when the input differs.
        shape = self.WARMUP_SHAPE if warmup_shape is None else tuple(warmup_shape)
        img = np.zeros(shape, dtype=np.uint8)
        self.model.predict(source=img, verbose=False)

    def get_model(self):
        """Return the loaded ``ultralytics.YOLO`` model."""
        return self.model

    def get_class_info(self):
        """Return the model's class names (``YOLO.names``)."""
        return self.model.names
class AIHelmetObjectModel(AIObjectModel):
    """YOLOv8 object-detection model loaded from the helmet weights.

    Loading and warm-up are inherited unchanged from ``AIObjectModel``.
    Only the display name and the weights (``AI_CONST.WEIGHTS_YOLO_HELMET``)
    differ.
    """

    WEIGHTS = AI_CONST.WEIGHTS_YOLO_HELMET
    MODEL_NAME = "OD-Helmet"
class AIHPEModel(AIObjectModel):
    """YOLOv8 human-pose-estimation (HPE) model loaded from the pose weights.

    Loading and warm-up are inherited unchanged from ``AIObjectModel``.
    Only the display name and the weights (``AI_CONST.WEIGHTS_POSE``) differ.
    """

    WEIGHTS = AI_CONST.WEIGHTS_POSE
    MODEL_NAME = "HPE"
class AIModelManager:
    """Create the model wrappers once and hand them out on request.

    Each ``set_*`` method builds its model only if it has not been built yet.
    The matching ``get_*`` method returns it, or raises ``RuntimeError`` if
    the corresponding ``set_*`` was never called.
    """

    def __init__(self) -> None:
        # None means "not loaded yet".
        self.od_model_info = None
        self.hpe_model_info = None
        self.helmet_model_info = None

    @staticmethod
    def _require(model, message):
        """Return ``model``, or raise ``RuntimeError(message)`` if it is unset."""
        if model is None:
            raise RuntimeError(message)
        return model

    def set_od(self):
        """Load the object-detection model unless it is already loaded."""
        if self.od_model_info is None:
            self.od_model_info = AIObjectModel()

    def set_hpe(self):
        """Load the pose-estimation model unless it is already loaded."""
        if self.hpe_model_info is None:
            self.hpe_model_info = AIHPEModel()

    def set_helmet(self):
        """Load the helmet detection model unless it is already loaded."""
        if self.helmet_model_info is None:
            self.helmet_model_info = AIHelmetObjectModel()

    def get_od(self):
        """Return the OD model wrapper; raise RuntimeError if not loaded."""
        return self._require(self.od_model_info, "YOLO model not loaded")

    def get_hpe(self):
        """Return the HPE model wrapper; raise RuntimeError if not loaded."""
        return self._require(self.hpe_model_info, "HPE model not loaded")

    def get_helmet(self):
        """Return the helmet model wrapper; raise RuntimeError if not loaded."""
        return self._require(self.helmet_model_info, "helmet model not loaded")
# Shared manager built at import time. Importing this module loads the OD and
# HPE models, plus the helmet model when project_config.USE_HELMET_MODEL is
# truthy. Because Python caches imported modules, a normal import runs this
# only once per process.
model_manager = AIModelManager()
model_manager.set_od()
model_manager.set_hpe()
if project_config.USE_HELMET_MODEL:
    model_manager.set_helmet()

if __name__ == "__main__":
    # Running the file directly only performs the model loading above.
    pass