148 lines
5.3 KiB
Python
148 lines
5.3 KiB
Python
|
|
from typing import Optional

import cv2
from pydantic.main import BaseModel

import demo_const as AI_CONST
# from DL.custom_utils import crop_image,image_encoding
from hpe_classification import config as hpe_config
from hpe_classification.hpe_classification import HPEClassification
|
||
|
|
|
||
|
|
|
||
|
|
class HistoryInfo(BaseModel):
    """
    Information to be stored in the detection history for one detected object.

    Instances are created by ``ObjectDetect.parsing`` from a single detection
    box and, for pose models, the matching keypoints.
    """

    # Class label looked up from the model's class table.
    class_name: str
    # Numeric class index reported by the detector.
    class_id: int
    # Tracker-assigned object id; None when the box carries no id.
    object_id: Optional[int] = None
    # Bounding box as a list of ints (parsing rounds the detector's xyxy output).
    bbox: list
    # Box confidence (parsing rounds it to 2 decimals).
    bbox_conf: float
    # Keypoint coordinates; None when no pose information was provided.
    keypoint: Optional[list] = None
    # Per-keypoint confidences; None when no pose information was provided.
    kpt_conf: Optional[list] = None
|
||
|
|
|
||
|
|
class ObjectDetect:
    """
    Object detection wrapper.

    Runs the injected model on the current image and converts every detected
    box into a message dict. Call ``set_model`` and ``set_image`` before
    ``predict``.
    """

    # Confidence threshold passed to the model's predict call.
    MODEL_CONFIDENCE = AI_CONST.MODEL_CONFIDENCE

    def __init__(self) -> None:
        # Placeholders until set_image / set_model are called.
        self.image = ''
        self.model = ''
        self.class_info = ''

    def predict(self, crop_image=True, class_name=True):
        """
        Run detection on ``self.image`` and build one message per detected box.

        :param crop_image: forwarded to ``_od_mqtt_message`` as ``crop``; the
            cropping code there is currently disabled, so it has no effect.
        :param class_name: when truthy, each message includes the class name;
            otherwise the ``class_name`` key is omitted.
        :return: list of message dicts, one per box over all results.
        """
        results = self.model.model.predict(
            source=self.image,
            show=False,
            stream=True,
            save=False,
            conf=self.MODEL_CONFIDENCE,
            verbose=False,
            imgsz=AI_CONST.MODEL_IMAGE_SIZE,
        )

        message_list = []
        for result in results:
            for box in result.boxes:
                parsed = self.parsing(bbox=box)
                name = parsed.class_name if class_name else False
                message_list.append(
                    self._od_mqtt_message(parsed, image=self.image, crop=crop_image, name=name)
                )
        return message_list

    def set_image(self, image):
        """Set the image that the next ``predict`` call will run on."""
        self.image = image

    def set_model(self, model):
        """
        Set the model wrapper and cache its class table.

        :param model: object exposing ``.model`` (the underlying predictor) and
            ``get_class_info()`` (indexable by class id).
        """
        self.model = model
        self.class_info = self.model.get_class_info()

    def _od_mqtt_message(self, data, image, crop, name):
        """
        Convert a parsed detection into a message dict.

        :param data: HistoryInfo for one detection.
        :param image: source image; unused while cropping is disabled.
        :param crop: whether to attach a cropped image; unused while cropping
            is disabled.
        :param name: class name to include; when falsy the ``class_name`` key
            is omitted.
        :return: dict with keys class_id, class_name (only when ``name`` is
            truthy), confidence, bbox, object_id, parent_object_id (always
            None) and image (currently always None).
        """
        # Cropping is disabled (its helper import is commented out at the top
        # of the file), so no image payload is attached.
        # if crop:
        #     crop_img = image_encoding(crop_image(xyxy=data.bbox, img=image))
        # else:
        crop_img = None

        # Build the dict incrementally so "class_name", when present, keeps its
        # position directly after "class_id" (key order is preserved on output).
        message = {"class_id": data.class_id}
        if name:
            message["class_name"] = name
        message.update({
            "confidence": data.bbox_conf,
            "bbox": data.bbox,
            "object_id": data.object_id,
            "parent_object_id": None,
            "image": crop_img,
        })
        return message

    def parsing(self, bbox, kpt=None):
        """
        Parse one detection into a HistoryInfo.

        If ``kpt`` (pose information) is falsy, the keypoint fields stay None.
        If the box has no object id, ``object_id`` stays None.

        :param bbox: single detection box (object detection information).
        :param kpt: keypoints for the same detection (pose information).
        :return: HistoryInfo
        """
        cls_id = int(bbox.cls[0].item())
        history_info = HistoryInfo(
            class_id=cls_id,
            class_name=self.class_info[cls_id],
            bbox=list(map(round, bbox.xyxy[0].detach().cpu().tolist())),
            bbox_conf=round(bbox.conf[0].item(), 2),
        )

        # Optional fields are assigned after construction (rather than passing
        # None to the constructor) so that only real values are ever set.
        if bbox.id is not None:
            history_info.object_id = int(bbox.id[0].item())

        if kpt:
            history_info.keypoint = kpt.xy[0].detach().cpu().tolist()
            history_info.kpt_conf = kpt.conf[0].detach().cpu().tolist()

        return history_info
|
||
|
|
|
||
|
|
class PoseDetect(ObjectDetect):
    """
    Pose detection wrapper.

    Tracks people in the current image, pairs every detection box with its
    keypoints, and classifies each pose with ``HPEClassification``.
    """

    # Confidence threshold for the pose model (overrides the base class value).
    MODEL_CONFIDENCE = 0.5

    # Thresholds mirrored from hpe_classification.config; the code in this
    # class does not read them directly.
    FALLDOWN_TILT_RATIO = hpe_config.FALLDOWN_TILT_RATIO
    FALLDOWN_TILT_ANGLE = hpe_config.FALLDOWN_TILT_ANGLE
    CROSS_RATIO_THRESHOLD = hpe_config.CROSS_ARM_RATIO_THRESHOLD
    CROSS_ANGLE_THRESHOLD = hpe_config.CROSS_ARM_ANGLE_THRESHOLD

    def predict(self, working, crop_image=True):
        """
        Track poses in ``self.image`` and classify each detected person.

        :param working: passed as ``is_working_on`` to both
            ``get_hpe_type`` and ``get_hpe_level``.
        :param crop_image: forwarded to ``_od_mqtt_message`` as ``crop``.
        :return: list with one dict per detection, holding "person" (bbox),
            "keypoints", "kpt_conf" and "result" (the detection message
            extended with "pose_type" and "pose_level").
        """
        # NOTE(jwkim): pose predict. track() with persist=True is used instead
        # of predict(); presumably so tracker ids persist between calls.
        tracked = self.model.model.track(
            source=self.image,
            show=False,
            stream=True,
            save=False,
            conf=self.MODEL_CONFIDENCE,
            verbose=False,
            imgsz=AI_CONST.MODEL_IMAGE_SIZE,
            persist=True,
        )

        poses = []
        for frame_result in tracked:
            # Boxes and keypoints are paired positionally.
            for det_box, det_kpt in zip(frame_result.boxes, frame_result.keypoints):
                info = self.parsing(bbox=det_box, kpt=det_kpt)
                pose_entry = {
                    "person": info.bbox,
                    "keypoints": info.keypoint,
                    "kpt_conf": info.kpt_conf,
                }

                # The classifier receives the same dict object that is
                # extended with "result" below.
                classifier = HPEClassification(pose_info=pose_entry)

                pose_entry["result"] = self._od_mqtt_message(
                    info, self.image, crop=crop_image, name=info.class_name
                )
                detail = pose_entry["result"]
                detail['pose_type'] = classifier.get_hpe_type(is_working_on=working)
                detail['pose_level'] = classifier.get_hpe_level(is_working_on=working)

                poses.append(pose_entry)
        return poses
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # This module only provides classes for import; running it directly does nothing.
    ...
|
||
|
|
|