version: v0.0.1
This commit is contained in:
147
predict.py
Normal file
147
predict.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import cv2
|
||||
import demo_const as AI_CONST
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
# from DL.custom_utils import crop_image,image_encoding
|
||||
from hpe_classification.hpe_classification import HPEClassification
|
||||
from hpe_classification import config as hpe_config
|
||||
|
||||
|
||||
class HistoryInfo(BaseModel):
|
||||
"""
|
||||
history 에 저장될 정보
|
||||
"""
|
||||
class_name: str
|
||||
class_id: int
|
||||
object_id: int = None
|
||||
bbox: list
|
||||
bbox_conf: float
|
||||
keypoint: list = None
|
||||
kpt_conf: list = None
|
||||
|
||||
class ObjectDetect:
|
||||
"""
|
||||
ObjectDetect
|
||||
"""
|
||||
MODEL_CONFIDENCE = AI_CONST.MODEL_CONFIDENCE
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.image = ''
|
||||
self.model = ''
|
||||
self.class_info = ''
|
||||
|
||||
def predict(self,crop_image = True, class_name=True):
|
||||
od_predict = self.model.model.predict(source=self.image, show=False, stream=True, save=False, conf=self.MODEL_CONFIDENCE, verbose=False, imgsz=AI_CONST.MODEL_IMAGE_SIZE)
|
||||
message_list = []
|
||||
for object_list in od_predict:
|
||||
for box in object_list.boxes:
|
||||
_parsing_data = self.parsing(bbox=box)
|
||||
_name = _parsing_data.class_name if class_name else False
|
||||
message_list.append(self._od_mqtt_message(_parsing_data, image= self.image, crop=crop_image, name=_name))
|
||||
return message_list
|
||||
|
||||
def set_image(self,image):
|
||||
self.image = image
|
||||
|
||||
def set_model(self,model):
|
||||
self.model = model
|
||||
self.class_info = self.model.get_class_info()
|
||||
|
||||
def _od_mqtt_message(self,data,image,crop,name):
|
||||
message ={}
|
||||
|
||||
# if crop:
|
||||
# crop_img = image_encoding(crop_image(xyxy=data.bbox,img=image))
|
||||
# else:
|
||||
crop_img = None
|
||||
|
||||
if name:
|
||||
message = {
|
||||
"class_id" : data.class_id,
|
||||
"class_name" : name,
|
||||
"confidence" : data.bbox_conf,
|
||||
"bbox" : data.bbox,
|
||||
"object_id": data.object_id,
|
||||
"parent_object_id": None,
|
||||
"image": crop_img
|
||||
}
|
||||
|
||||
else:
|
||||
message = {
|
||||
"class_id" : data.class_id,
|
||||
"confidence" : data.bbox_conf,
|
||||
"bbox" : data.bbox,
|
||||
"object_id": data.object_id,
|
||||
"parent_object_id": None,
|
||||
"image": crop_img
|
||||
}
|
||||
|
||||
return message
|
||||
|
||||
def parsing(self, bbox, kpt=None):
|
||||
"""
|
||||
데이터 파싱
|
||||
kpt(pose) 정보가 없을시 keypoint 관련정보 None으로 처리
|
||||
object id가 없을시 None으로 처리
|
||||
|
||||
:param bbox: object detect 정보
|
||||
:param kpt: pose detect 정보
|
||||
:return: HistoryInfo
|
||||
"""
|
||||
|
||||
_cls_id = int(bbox.cls[0].item())
|
||||
_history_info = HistoryInfo(
|
||||
class_id = _cls_id,
|
||||
class_name = self.class_info[_cls_id],
|
||||
bbox = list(map(round, bbox.xyxy[0].detach().cpu().tolist())),
|
||||
bbox_conf = round(bbox.conf[0].item(), 2),
|
||||
)
|
||||
if bbox.id is not None:
|
||||
_history_info.object_id = int(bbox.id[0].item())
|
||||
|
||||
if kpt:
|
||||
_history_info.keypoint = kpt.xy[0].detach().cpu().tolist()
|
||||
_history_info.kpt_conf = kpt.conf[0].detach().cpu().tolist()
|
||||
|
||||
return _history_info
|
||||
|
||||
class PoseDetect(ObjectDetect):
|
||||
"""
|
||||
PoseDetect
|
||||
"""
|
||||
MODEL_CONFIDENCE = 0.5
|
||||
|
||||
FALLDOWN_TILT_RATIO = hpe_config.FALLDOWN_TILT_RATIO
|
||||
FALLDOWN_TILT_ANGLE = hpe_config.FALLDOWN_TILT_ANGLE
|
||||
CROSS_RATIO_THRESHOLD = hpe_config.CROSS_ARM_RATIO_THRESHOLD
|
||||
CROSS_ANGLE_THRESHOLD = hpe_config.CROSS_ARM_ANGLE_THRESHOLD
|
||||
|
||||
def predict(self,working,crop_image=True):
|
||||
|
||||
# pose_predict = self.model.model.predict(source=self.image, show=False, stream=True, save=False, conf=self.MODEL_CONFIDENCE, verbose=False, imgsz=AI_CONST.MODEL_IMAGE_SIZE)
|
||||
#NOTE(jwkim): pose predict
|
||||
pose_predict = self.model.model.track(source=self.image, show=False, stream=True, save=False, conf=self.MODEL_CONFIDENCE, verbose=False, imgsz=AI_CONST.MODEL_IMAGE_SIZE, persist=True)
|
||||
|
||||
message_list = []
|
||||
for object_list in pose_predict:
|
||||
for box, pose in zip(object_list.boxes, object_list.keypoints):
|
||||
_parsing_data = self.parsing(bbox=box, kpt=pose)
|
||||
current_pose={}
|
||||
current_pose["person"]= _parsing_data.bbox
|
||||
current_pose["keypoints"]= _parsing_data.keypoint
|
||||
current_pose["kpt_conf"] = _parsing_data.kpt_conf
|
||||
|
||||
# HPEClassification
|
||||
hpe_classification = HPEClassification(pose_info=current_pose)
|
||||
|
||||
current_pose["result"]=self._od_mqtt_message(_parsing_data, self.image, crop=crop_image, name=_parsing_data.class_name)
|
||||
current_pose["result"]['pose_type'] = hpe_classification.get_hpe_type(is_working_on=working)
|
||||
current_pose["result"]['pose_level'] = hpe_classification.get_hpe_level(is_working_on=working)
|
||||
|
||||
message_list.append(current_pose)
|
||||
return message_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user