Skip to content

Commit

Permalink
Merge pull request #2 from AmitMY/patch-1
Browse files Browse the repository at this point in the history
simplify pose2text processor
  • Loading branch information
GerrySant authored Nov 19, 2024
2 parents 1a0f894 + be06981 commit b006c1b
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions multimodalhugs/processors/pose2text_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from multimodalhugs.processors import MultimodalSecuence2TextTranslationProcessor
from pose_format import Pose
from pose_format.utils.generic import reduce_holistic, pose_hide_legs
from pose_format.utils.generic import reduce_holistic, pose_hide_legs, pose_normalization_info

logger = logging.getLogger(__name__)

Expand All @@ -37,21 +37,18 @@ def __init__(
super().__init__(tokenizer=tokenizer, **kwargs)

def _pose_file_to_tensor(self, pose_file: Union[str, Path]):
pose_file = open(pose_file, "rb").read()
pose = Pose.read(pose_file) # [t, people, d, xyz]

P1 = ("POSE_LANDMARKS", "RIGHT_SHOULDER") if pose.header.components[0].name == "POSE_LANDMARKS" else ("BODY_135", "RShoulder")
P2 = ("POSE_LANDMARKS", "LEFT_SHOULDER") if pose.header.components[0].name == "POSE_LANDMARKS" else ("BODY_135", "LShoulder")

with open(pose_file, "rb") as pose_file:
pose = Pose.read(pose_file.read()) # [t, people, d, xyz]

pose_hide_legs(pose)

if self.reduce_holistic_poses:
# This will be skipped if the pose is not holistic
pose = reduce_holistic(pose) # [t, people, d', xyz]

pose = pose.normalize(pose.header.normalization_info(p1=P1, p2=P2))
pose = pose.torch().body.data.squeeze(1) # [t, d', xyz]
pose = pose.reshape(shape=(pose.shape[0], -1))
return pose.zero_filled()
pose = pose.normalize(pose_normalization_info(pose.header))
tensor = pose.torch().body.data.zero_filled()
return tensor.contiguous().view(tensor.size(0), -1)

def _obtain_multimodal_input_and_masks(self, batch, **kwargs):
tensor_secuences = [self._pose_file_to_tensor(sample["source"]) for sample in batch]
Expand Down

0 comments on commit b006c1b

Please sign in to comment.