Files
HOSPITAL_CCTV/load_models.py
2026-02-25 15:05:58 +09:00

108 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
from ultralytics import YOLO
import demo_const as AI_CONST
import project_config
# from ai_engine.custom_logger.custom_log import log
class AIObjectModel:
    """Load a YOLOv8 object-detection (OD) model and warm it up.

    Subclasses only need to override ``MODEL_NAME`` and ``WEIGHTS``
    (and optionally ``WARMUP_SHAPE``) to load a different checkpoint with
    the same loading / warm-up procedure.
    """

    # Label printed in the "model loaded" message.
    MODEL_NAME = "OD"
    # Weights passed to ``ultralytics.YOLO``.
    WEIGHTS = AI_CONST.WEIGHTS_YOLO
    # (height, width, channels) of the blank uint8 frame used for warm-up.
    # Should match the resolution of the incoming video frames; the default
    # is Full HD (1920x1080). Override in a subclass, or pass
    # ``warmup_shape`` to ``set_model``, for other resolutions
    # (e.g. (360, 640, 3) for 640x360 input).
    WARMUP_SHAPE = (1080, 1920, 3)

    def __init__(self) -> None:
        # Placeholder until set_model() assigns the loaded YOLO model.
        self.model = None
        self.set_model(weights=self.WEIGHTS)

    def set_model(self, weights, warmup_shape=None):
        """Load ``weights`` into a YOLO model and run one warm-up inference.

        Args:
            weights: Weights path/identifier accepted by ``ultralytics.YOLO``.
            warmup_shape: Optional ``(height, width, channels)`` of the blank
                warm-up frame. Defaults to ``WARMUP_SHAPE`` when ``None``.
        """
        self.model = YOLO(weights)

        # Report whether a CUDA device is visible to torch. This only checks
        # availability; it does not force the model onto a specific device.
        import torch
        device_status = "CUDA (GPU)" if torch.cuda.is_available() else "CPU"
        print(f"[{self.MODEL_NAME}] Model loaded. Using: {device_status}")

        # Warm up with an all-black frame so the first real prediction
        # does not pay the one-time initialisation cost.
        shape = self.WARMUP_SHAPE if warmup_shape is None else warmup_shape
        blank_frame = np.zeros(shape, dtype=np.uint8)
        self.model.predict(source=blank_frame, verbose=False)

    def get_model(self):
        """Return the loaded ``ultralytics.YOLO`` model."""
        return self.model

    def get_class_info(self):
        """Return ``self.model.names``, the model's label mapping.

        NOTE(review): presumably class index -> class name; confirm against
        the ultralytics version in use.
        """
        return self.model.names
class AIHelmetObjectModel(AIObjectModel):
    """Load the YOLOv8 helmet object-detection model.

    Inherits loading, device reporting and warm-up from ``AIObjectModel``;
    only the display name and the weights (``AI_CONST.WEIGHTS_YOLO_HELMET``)
    differ.
    """

    MODEL_NAME = "OD-Helmet"
    WEIGHTS = AI_CONST.WEIGHTS_YOLO_HELMET
class AIHPEModel(AIObjectModel):
    """YOLOv8 pose (HPE) model loader.

    Same load / warm-up flow as ``AIObjectModel``, pointed at the
    ``AI_CONST.WEIGHTS_POSE`` checkpoint and labelled "HPE".
    """

    MODEL_NAME = "HPE"  # label shown in the load message
    WEIGHTS = AI_CONST.WEIGHTS_POSE  # pose-model weights
class AIModelManager:
def __init__(self) -> None:
self.od_model_info = ""
self.hpe_model_info = ""
self.helmet_model_info = ""
def set_od(self): # set od
if not self.od_model_info:
self.od_model_info = AIObjectModel()
def set_hpe(self): # set hpe
if not self.hpe_model_info:
self.hpe_model_info = AIHPEModel()
def set_helmet(self): # set od
if not self.helmet_model_info:
self.helmet_model_info = AIHelmetObjectModel()
def get_od(self):
if not self.od_model_info:
raise Exception("YOLO model not loaded")
return self.od_model_info
def get_hpe(self):
if not self.hpe_model_info:
raise Exception("HPE model not loaded")
return self.hpe_model_info
def get_helmet(self):
if not self.helmet_model_info:
raise Exception("helmet model not loaded")
return self.helmet_model_info
# Module-level manager shared by importers. Importing this module loads the
# OD and HPE models immediately, plus the helmet model when
# project_config.USE_HELMET_MODEL is truthy.
model_manager = AIModelManager()
model_manager.set_od()
model_manager.set_hpe()
if project_config.USE_HELMET_MODEL:
    model_manager.set_helmet()

if __name__ == "__main__":
    # Nothing extra to do when run directly: the models above were already
    # loaded as a side effect of executing this module.
    pass