# NOTE(review): removed non-Python residue pasted from a git web viewer
# (file-listing header: "Files / 148 lines / 5.3 KiB / Python /
# Raw Permalink Normal View History / 2026-02-25 15:05:58 +09:00").
# Those lines were not valid Python and broke parsing.
from typing import Optional

import cv2
from pydantic.main import BaseModel

import demo_const as AI_CONST
# from DL.custom_utils import crop_image,image_encoding
from hpe_classification.hpe_classification import HPEClassification
from hpe_classification import config as hpe_config
class HistoryInfo(BaseModel):
    """One detection record to be stored in history.

    ``keypoint``/``kpt_conf`` stay ``None`` when no pose (kpt) info is
    available; ``object_id`` stays ``None`` when the detector assigned
    no track id.
    """
    class_name: str                    # class label resolved via class_info
    class_id: int                      # numeric class index from the model
    object_id: Optional[int] = None    # tracker id; None when not tracking
    bbox: list                         # [x1, y1, x2, y2], rounded to ints
    bbox_conf: float                   # detection confidence, 2 decimals
    keypoint: Optional[list] = None    # pose keypoints (xy pairs) or None
    kpt_conf: Optional[list] = None    # per-keypoint confidences or None
class ObjectDetect:
    """Wrapper around a YOLO-style object-detection model.

    Expected usage: ``set_model()`` then ``set_image()`` before
    calling ``predict()``.
    """

    # Minimum confidence threshold forwarded to the underlying model.
    MODEL_CONFIDENCE = AI_CONST.MODEL_CONFIDENCE

    def __init__(self) -> None:
        # Placeholders; populated via set_image()/set_model() before predict().
        self.image = ''
        self.model = ''
        self.class_info = ''

    def predict(self, crop_image=True, class_name=True):
        """Run object detection on the current image.

        :param crop_image: request a cropped object image in each message
            (cropping is currently disabled — see ``_od_mqtt_message``)
        :param class_name: include the class name in each message
        :return: list of MQTT-ready message dicts, one per detected box
        """
        od_predict = self.model.model.predict(
            source=self.image, show=False, stream=True, save=False,
            conf=self.MODEL_CONFIDENCE, verbose=False,
            imgsz=AI_CONST.MODEL_IMAGE_SIZE,
        )
        message_list = []
        for object_list in od_predict:
            for box in object_list.boxes:
                _parsing_data = self.parsing(bbox=box)
                _name = _parsing_data.class_name if class_name else False
                message_list.append(
                    self._od_mqtt_message(_parsing_data, image=self.image,
                                          crop=crop_image, name=_name))
        return message_list

    def set_image(self, image):
        """Set the image (frame) detection will run on."""
        self.image = image

    def set_model(self, model):
        """Set the detection model and cache its class-id -> name mapping."""
        self.model = model
        self.class_info = self.model.get_class_info()

    def _od_mqtt_message(self, data, image, crop, name):
        """Build one MQTT message dict from a parsed detection.

        ``class_name`` is included only when ``name`` is truthy.
        Cropping is currently disabled, so ``image`` in the message is
        always ``None`` regardless of ``crop``.
        """
        # NOTE(review): cropping was disabled here; restore with
        # image_encoding(crop_image(xyxy=data.bbox, img=image)) when needed.
        crop_img = None
        # Build the message once; only the optional class_name differs.
        message = {"class_id": data.class_id}
        if name:
            message["class_name"] = name
        message.update({
            "confidence": data.bbox_conf,
            "bbox": data.bbox,
            "object_id": data.object_id,
            "parent_object_id": None,
            "image": crop_img,
        })
        return message

    def parsing(self, bbox, kpt=None):
        """Parse raw model outputs into a HistoryInfo.

        When ``kpt`` (pose) info is absent the keypoint fields stay None;
        when the box carries no track id, ``object_id`` stays None.

        :param bbox: one detection box from the model output
        :param kpt: matching pose-keypoint data, if any
        :return: HistoryInfo
        """
        _cls_id = int(bbox.cls[0].item())
        _history_info = HistoryInfo(
            class_id=_cls_id,
            class_name=self.class_info[_cls_id],
            bbox=list(map(round, bbox.xyxy[0].detach().cpu().tolist())),
            bbox_conf=round(bbox.conf[0].item(), 2),
        )
        if bbox.id is not None:
            _history_info.object_id = int(bbox.id[0].item())
        if kpt:
            _history_info.keypoint = kpt.xy[0].detach().cpu().tolist()
            _history_info.kpt_conf = kpt.conf[0].detach().cpu().tolist()
        return _history_info
class PoseDetect(ObjectDetect):
    """Pose (keypoint) detection with tracking and HPE classification.

    NOTE(review): ``predict()`` takes a different signature from
    ``ObjectDetect.predict()`` (required ``working`` flag), so the two
    classes are not interchangeable through the base-class interface.
    """

    # Pose model uses a fixed threshold instead of AI_CONST's value.
    MODEL_CONFIDENCE = 0.5
    FALLDOWN_TILT_RATIO = hpe_config.FALLDOWN_TILT_RATIO
    FALLDOWN_TILT_ANGLE = hpe_config.FALLDOWN_TILT_ANGLE
    CROSS_RATIO_THRESHOLD = hpe_config.CROSS_ARM_RATIO_THRESHOLD
    CROSS_ANGLE_THRESHOLD = hpe_config.CROSS_ARM_ANGLE_THRESHOLD

    def predict(self, working, crop_image=True):
        """Track poses in the current image and classify each one.

        :param working: "working" state flag, forwarded to the HPE
            classifier to select type/level rules
        :param crop_image: request cropped images in the result messages
        :return: list of dicts per person: bbox, keypoints, confidences,
            and a ``result`` message with ``pose_type``/``pose_level``
        """
        # NOTE(jwkim): pose predict.
        # track() (persist=True) is used instead of predict() so object
        # ids persist across frames.
        pose_predict = self.model.model.track(
            source=self.image, show=False, stream=True, save=False,
            conf=self.MODEL_CONFIDENCE, verbose=False,
            imgsz=AI_CONST.MODEL_IMAGE_SIZE, persist=True,
        )
        message_list = []
        for object_list in pose_predict:
            for box, pose in zip(object_list.boxes, object_list.keypoints):
                _parsing_data = self.parsing(bbox=box, kpt=pose)
                current_pose = {
                    "person": _parsing_data.bbox,
                    "keypoints": _parsing_data.keypoint,
                    "kpt_conf": _parsing_data.kpt_conf,
                }
                # HPEClassification holds a reference to current_pose, so
                # keep the mutation order below identical to before.
                hpe_classification = HPEClassification(pose_info=current_pose)
                current_pose["result"] = self._od_mqtt_message(
                    _parsing_data, self.image, crop=crop_image,
                    name=_parsing_data.class_name)
                current_pose["result"]['pose_type'] = \
                    hpe_classification.get_hpe_type(is_working_on=working)
                current_pose["result"]['pose_level'] = \
                    hpe_classification.get_hpe_level(is_working_on=working)
                message_list.append(current_pose)
        return message_list
if __name__ == "__main__":
    # Library module: nothing to run when executed directly.
    pass