diff --git a/mani_skill/agents/controllers/__init__.py b/mani_skill/agents/controllers/__init__.py index a0fdc446f..d2867b85e 100644 --- a/mani_skill/agents/controllers/__init__.py +++ b/mani_skill/agents/controllers/__init__.py @@ -15,6 +15,7 @@ from .pd_joint_pos_vel import PDJointPosVelController, PDJointPosVelControllerConfig from .passive_controller import PassiveController, PassiveControllerConfig from .pd_base_vel import PDBaseVelController, PDBaseVelControllerConfig +from .pd_base_vel import PDBaseForwardVelController, PDBaseForwardVelControllerConfig def deepcopy_dict(configs: dict): diff --git a/mani_skill/agents/controllers/pd_base_vel.py b/mani_skill/agents/controllers/pd_base_vel.py index 264d89cc9..7374c9004 100644 --- a/mani_skill/agents/controllers/pd_base_vel.py +++ b/mani_skill/agents/controllers/pd_base_vel.py @@ -1,7 +1,7 @@ import numpy as np import torch -from mani_skill.utils.geometry import rotate_2d_vec_by_angle +from gymnasium import spaces from mani_skill.utils.structs.types import Array from .pd_joint_vel import PDJointVelController, PDJointVelControllerConfig @@ -34,3 +34,41 @@ def set_action(self, action: Array): class PDBaseVelControllerConfig(PDJointVelControllerConfig): controller_cls = PDBaseVelController + + +class PDBaseForwardVelController(PDJointVelController): + """PDJointVelController for forward-only ego-centric base movement.""" + + def _initialize_action_space(self): + assert len(self.joints) >= 3, len(self.joints) + low = np.float32(np.broadcast_to(self.config.lower, 2)) + high = np.float32(np.broadcast_to(self.config.upper, 2)) + self.single_action_space = spaces.Box(low, high, dtype=np.float32) + + def set_action(self, action: Array): + action = self._preprocess_action(action) + # action[:, 0] should correspond to forward vel + # action[:, 1] should correspond to rotation vel + + # Convert to ego-centric action + # Assume the 3rd DoF stands for orientation + ori = self.qpos[:, 2] + rot_mat = torch.zeros(ori.shape[0], 2, 2, device=action.device) + rot_mat[:, 0, 0] = torch.cos(ori) + rot_mat[:, 0, 1] = -torch.sin(ori) + rot_mat[:, 1, 0] = torch.sin(ori) + rot_mat[:, 1, 1] = torch.cos(ori) + + # Assume the 1st DoF stands for forward movement + # make action with 0 y vel + move_action = action.clone() + move_action[:, 1] = 0 + vel = (rot_mat @ move_action.float().unsqueeze(-1)).squeeze(-1) + new_action = torch.hstack([vel, action[:, 1:]]) + self.articulation.set_joint_drive_velocity_targets( + new_action, self.joints, self.active_joint_indices + ) + + +class PDBaseForwardVelControllerConfig(PDJointVelControllerConfig): + controller_cls = PDBaseForwardVelController diff --git a/mani_skill/agents/robots/fetch/fetch.py b/mani_skill/agents/robots/fetch/fetch.py index dfd781d2a..efd6746d2 100644 --- a/mani_skill/agents/robots/fetch/fetch.py +++ b/mani_skill/agents/robots/fetch/fetch.py @@ -250,10 +250,10 @@ def _controller_configs(self): # -------------------------------------------------------------------------- # # Base # -------------------------------------------------------------------------- # - base_pd_joint_vel = PDBaseVelControllerConfig( + base_pd_joint_vel = PDBaseForwardVelControllerConfig( self.base_joint_names, - lower=[-0.5, -0.5, -3.14], - upper=[0.5, 0.5, 3.14], + lower=[-0.5, -3.14], + upper=[0.5, 3.14], damping=1000, force_limit=500, )