Skip to content

Commit

Permalink
Fix: fetch move forward/backward only (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
arth-shukla authored Jun 22, 2024
1 parent 374d596 commit 0875df7
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
1 change: 1 addition & 0 deletions mani_skill/agents/controllers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .pd_joint_pos_vel import PDJointPosVelController, PDJointPosVelControllerConfig
from .passive_controller import PassiveController, PassiveControllerConfig
from .pd_base_vel import PDBaseVelController, PDBaseVelControllerConfig
from .pd_base_vel import PDBaseForwardVelController, PDBaseForwardVelControllerConfig


def deepcopy_dict(configs: dict):
Expand Down
40 changes: 39 additions & 1 deletion mani_skill/agents/controllers/pd_base_vel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch

from mani_skill.utils.geometry import rotate_2d_vec_by_angle
from gymnasium import spaces
from mani_skill.utils.structs.types import Array

from .pd_joint_vel import PDJointVelController, PDJointVelControllerConfig
Expand Down Expand Up @@ -34,3 +34,41 @@ def set_action(self, action: Array):

class PDBaseVelControllerConfig(PDJointVelControllerConfig):
controller_cls = PDBaseVelController


class PDBaseForwardVelController(PDJointVelController):
"""PDJointVelController for forward-only ego-centric base movement."""

def _initialize_action_space(self):
assert len(self.joints) >= 3, len(self.joints)
low = np.float32(np.broadcast_to(self.config.lower, 2))
high = np.float32(np.broadcast_to(self.config.upper, 2))
self.single_action_space = spaces.Box(low, high, dtype=np.float32)

def set_action(self, action: Array):
action = self._preprocess_action(action)
# action[:, 0] should correspond to forward vel
# action[:, 1] should correspond to rotation vel

# Convert to ego-centric action
# Assume the 3rd DoF stands for orientation
ori = self.qpos[:, 2]
rot_mat = torch.zeros(ori.shape[0], 2, 2, device=action.device)
rot_mat[:, 0, 0] = torch.cos(ori)
rot_mat[:, 0, 1] = -torch.sin(ori)
rot_mat[:, 1, 0] = torch.sin(ori)
rot_mat[:, 1, 1] = torch.cos(ori)

# Assume the 1st DoF stands for forward movement
# make action with 0 y vel
move_action = action.clone()
move_action[:, 1] = 0
vel = (rot_mat @ move_action.float().unsqueeze(-1)).squeeze(-1)
new_action = torch.hstack([vel, action[:, 1:]])
self.articulation.set_joint_drive_velocity_targets(
new_action, self.joints, self.active_joint_indices
)


class PDBaseForwardVelControllerConfig(PDJointVelControllerConfig):
controller_cls = PDBaseForwardVelController
6 changes: 3 additions & 3 deletions mani_skill/agents/robots/fetch/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,10 @@ def _controller_configs(self):
# -------------------------------------------------------------------------- #
# Base
# -------------------------------------------------------------------------- #
base_pd_joint_vel = PDBaseVelControllerConfig(
base_pd_joint_vel = PDBaseForwardVelControllerConfig(
self.base_joint_names,
lower=[-0.5, -0.5, -3.14],
upper=[0.5, 0.5, 3.14],
lower=[-0.5, -3.14],
upper=[0.5, 3.14],
damping=1000,
force_limit=500,
)
Expand Down

0 comments on commit 0875df7

Please sign in to comment.