version: v0.0.1
This commit is contained in:
108
load_models.py
Normal file
108
load_models.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
|
||||
import demo_const as AI_CONST
|
||||
import project_config
|
||||
|
||||
# from ai_engine.custom_logger.custom_log import log
|
||||
|
||||
class AIObjectModel:
|
||||
"""
|
||||
Yolov8 OD model Load
|
||||
"""
|
||||
|
||||
MODEL_NAME = "OD"
|
||||
WEIGHTS = AI_CONST.WEIGHTS_YOLO
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.model = ""
|
||||
self.set_model(weights=self.WEIGHTS)
|
||||
|
||||
def set_model(self, weights):
|
||||
self.model = YOLO(weights)
|
||||
|
||||
# Check device status
|
||||
import torch
|
||||
device_status = "CUDA (GPU)" if torch.cuda.is_available() else "CPU"
|
||||
print(f"[{self.MODEL_NAME}] Model loaded. Using: {device_status}")
|
||||
|
||||
# warm up models (입력 영상 해상도가 다를경우 수정해야함)
|
||||
img = np.zeros([1080,1920,3],dtype=np.uint8) # black image FHD 1920x1080
|
||||
# img = np.zeros([360,640,3],dtype=np.uint8) # black image HD 640×360
|
||||
self.model.predict(source=img, verbose=False)
|
||||
|
||||
# log.info(f"YOLOV8 {self.MODEL_NAME} MODEL LOADED")
|
||||
|
||||
def get_model(self):
|
||||
return self.model
|
||||
|
||||
def get_class_info(self):
|
||||
return self.model.names
|
||||
|
||||
class AIHelmetObjectModel(AIObjectModel):
|
||||
"""
|
||||
Yolov8 OD model Load
|
||||
"""
|
||||
|
||||
MODEL_NAME = "OD-Helmet"
|
||||
WEIGHTS = AI_CONST.WEIGHTS_YOLO_HELMET
|
||||
|
||||
|
||||
class AIHPEModel(AIObjectModel):
|
||||
"""
|
||||
Yolov8 HPE model Load
|
||||
"""
|
||||
|
||||
MODEL_NAME = "HPE"
|
||||
WEIGHTS = AI_CONST.WEIGHTS_POSE
|
||||
|
||||
|
||||
class AIModelManager:
|
||||
def __init__(self) -> None:
|
||||
self.od_model_info = ""
|
||||
self.hpe_model_info = ""
|
||||
|
||||
self.helmet_model_info = ""
|
||||
|
||||
def set_od(self): # set od
|
||||
if not self.od_model_info:
|
||||
self.od_model_info = AIObjectModel()
|
||||
|
||||
def set_hpe(self): # set hpe
|
||||
if not self.hpe_model_info:
|
||||
self.hpe_model_info = AIHPEModel()
|
||||
|
||||
def set_helmet(self): # set od
|
||||
if not self.helmet_model_info:
|
||||
self.helmet_model_info = AIHelmetObjectModel()
|
||||
|
||||
def get_od(self):
|
||||
if not self.od_model_info:
|
||||
raise Exception("YOLO model not loaded")
|
||||
|
||||
return self.od_model_info
|
||||
|
||||
def get_hpe(self):
|
||||
if not self.hpe_model_info:
|
||||
raise Exception("HPE model not loaded")
|
||||
|
||||
return self.hpe_model_info
|
||||
|
||||
def get_helmet(self):
|
||||
if not self.helmet_model_info:
|
||||
raise Exception("helmet model not loaded")
|
||||
|
||||
return self.helmet_model_info
|
||||
|
||||
|
||||
model_manager = ""
|
||||
|
||||
if not model_manager:
|
||||
model_manager = AIModelManager()
|
||||
model_manager.set_od()
|
||||
model_manager.set_hpe()
|
||||
if project_config.USE_HELMET_MODEL:
|
||||
model_manager.set_helmet()
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
Reference in New Issue
Block a user