from typing import Optional

import cv2
import demo_const as AI_CONST
from pydantic.main import BaseModel

# from DL.custom_utils import crop_image,image_encoding
from hpe_classification.hpe_classification import HPEClassification
from hpe_classification import config as hpe_config


class HistoryInfo(BaseModel):
    """Information stored in the history for a single detection."""

    class_name: str
    class_id: int
    # Stays None when the detection carries no object id.
    object_id: Optional[int] = None
    bbox: list
    bbox_conf: float
    # Keypoint fields stay None when no pose information is supplied.
    keypoint: Optional[list] = None
    kpt_conf: Optional[list] = None


class ObjectDetect:
    """Run a detection model on an image and build one message per box."""

    MODEL_CONFIDENCE = AI_CONST.MODEL_CONFIDENCE

    def __init__(self) -> None:
        # Empty-string placeholders until set_image() / set_model() are called.
        self.image = ''
        self.model = ''
        self.class_info = ''

    def predict(self, crop_image=True, class_name=True):
        """
        Run the model on the current image and collect a message per box.

        :param crop_image: forwarded to ``_od_mqtt_message`` as ``crop``
        :param class_name: when falsy, ``class_name`` is left out of each message
        :return: list of message dictionaries
        """
        od_predict = self.model.model.predict(
            source=self.image,
            show=False,
            stream=True,
            save=False,
            conf=self.MODEL_CONFIDENCE,
            verbose=False,
            imgsz=AI_CONST.MODEL_IMAGE_SIZE,
        )

        message_list = []
        for object_list in od_predict:
            for box in object_list.boxes:
                _parsing_data = self.parsing(bbox=box)
                # False (not None) is the marker for "no class name requested".
                _name = _parsing_data.class_name if class_name else False
                message_list.append(
                    self._od_mqtt_message(
                        _parsing_data,
                        image=self.image,
                        crop=crop_image,
                        name=_name,
                    )
                )
        return message_list

    def set_image(self, image):
        """Store the image that the next ``predict`` call will use."""
        self.image = image

    def set_model(self, model):
        """Store the model and cache the class info it reports."""
        self.model = model
        self.class_info = self.model.get_class_info()

    def _od_mqtt_message(self, data, image, crop, name):
        """
        Build the message dictionary for one parsed detection.

        :param data: HistoryInfo for the detection
        :param image: source image (currently unused, see note below)
        :param crop: crop flag (currently unused, see note below)
        :param name: class name to include; a falsy value omits ``class_name``
        :return: message dictionary
        """
        # NOTE: image cropping/encoding is disabled, so "image" is always None
        # and the ``image`` / ``crop`` arguments have no effect.
        crop_img = None

        # Built in steps so that "class_name", when present, keeps its position
        # right after "class_id".
        message = {"class_id": data.class_id}
        if name:
            message["class_name"] = name
        message.update({
            "confidence": data.bbox_conf,
            "bbox": data.bbox,
            "object_id": data.object_id,
            "parent_object_id": None,
            "image": crop_img,
        })
        return message

    def parsing(self, bbox, kpt=None):
        """
        Parse detection data into a HistoryInfo.

        Keypoint-related fields are left as None when there is no kpt (pose)
        information, and the object id is left as None when the box has no id.

        :param bbox: object detect information
        :param kpt: pose detect information
        :return: HistoryInfo
        """
        _cls_id = int(bbox.cls[0].item())
        _history_info = HistoryInfo(
            class_id=_cls_id,
            class_name=self.class_info[_cls_id],
            bbox=list(map(round, bbox.xyxy[0].detach().cpu().tolist())),
            bbox_conf=round(bbox.conf[0].item(), 2),
        )

        if bbox.id is not None:
            _history_info.object_id = int(bbox.id[0].item())

        # Truthiness check: skipped for None (and presumably for an empty
        # keypoints container, assuming it defines __len__ - TODO confirm).
        if kpt:
            _history_info.keypoint = kpt.xy[0].detach().cpu().tolist()
            # conf can be None when the model provides no per-keypoint
            # confidence; leave kpt_conf as None instead of raising TypeError.
            if kpt.conf is not None:
                _history_info.kpt_conf = kpt.conf[0].detach().cpu().tolist()

        return _history_info


class PoseDetect(ObjectDetect):
    """Run a pose model with tracking and classify each detected pose."""

    MODEL_CONFIDENCE = 0.5
    FALLDOWN_TILT_RATIO = hpe_config.FALLDOWN_TILT_RATIO
    FALLDOWN_TILT_ANGLE = hpe_config.FALLDOWN_TILT_ANGLE
    CROSS_RATIO_THRESHOLD = hpe_config.CROSS_ARM_RATIO_THRESHOLD
    CROSS_ANGLE_THRESHOLD = hpe_config.CROSS_ARM_ANGLE_THRESHOLD

    def predict(self, working, crop_image=True):
        """
        Track poses on the current image and classify each one.

        :param working: forwarded as ``is_working_on`` to the HPE classifier
        :param crop_image: forwarded to ``_od_mqtt_message`` as ``crop``
        :return: list of dictionaries with the keys ``person``, ``keypoints``,
            ``kpt_conf`` and ``result`` (the message dictionary extended with
            ``pose_type`` and ``pose_level``)
        """
        # NOTE(jwkim): pose predict
        # track() with persist=True is used here rather than predict().
        pose_predict = self.model.model.track(
            source=self.image,
            show=False,
            stream=True,
            save=False,
            conf=self.MODEL_CONFIDENCE,
            verbose=False,
            imgsz=AI_CONST.MODEL_IMAGE_SIZE,
            persist=True,
        )

        message_list = []
        for object_list in pose_predict:
            for box, pose in zip(object_list.boxes, object_list.keypoints):
                _parsing_data = self.parsing(bbox=box, kpt=pose)

                current_pose = {
                    "person": _parsing_data.bbox,
                    "keypoints": _parsing_data.keypoint,
                    "kpt_conf": _parsing_data.kpt_conf,
                }

                # HPEClassification is created before "result" is added.
                hpe_classification = HPEClassification(pose_info=current_pose)

                result = self._od_mqtt_message(
                    _parsing_data,
                    self.image,
                    crop=crop_image,
                    name=_parsing_data.class_name,
                )
                current_pose["result"] = result
                result['pose_type'] = hpe_classification.get_hpe_type(is_working_on=working)
                result['pose_level'] = hpe_classification.get_hpe_level(is_working_on=working)

                message_list.append(current_pose)
        return message_list


if __name__ == "__main__":
    pass