Commit
[BugFix] Fix trajectory replay saving the wrong first observation when using --use-first-env-state (#369)

* work

* fixes

* motion plan in gpu env with pd joint pos, fix bug with recording actions

* Update run.py

* work

* work

* Update sac_rfcl.py

* Update README.md

* fixes

* attempts

* undo changes

* Update pick_cube.py

* Delete examples/baselines/sac-rfcl/ppo_rfcl.py
StoneT2000 authored Jun 6, 2024
1 parent a11bd08 commit 5f4d7a1
Showing 5 changed files with 61 additions and 41 deletions.
55 changes: 26 additions & 29 deletions mani_skill/examples/benchmarking/gpu_sim.py
@@ -96,31 +96,32 @@ def main(args):
     env.reset()
     profiler.log_stats("env.step+env.reset")
     env.close()
-    # append results to csv
-    try:
-        assert (
-            args.save_video == False
-        ), "Saving video slows down speed a lot and it will distort results"
-        Path("benchmark_results").mkdir(parents=True, exist_ok=True)
-        data = dict(
-            env_id=args.env_id,
-            obs_mode=args.obs_mode,
-            num_envs=args.num_envs,
-            control_mode=args.control_mode,
-            gpu_type=torch.cuda.get_device_name()
-        )
-        if args.env_id in BENCHMARK_ENVS:
-            data.update(
-                num_cameras=args.num_cams,
-                camera_width=args.cam_width,
-                camera_height=args.cam_height,
-            )
-        profiler.update_csv(
-            "benchmark_results/maniskill.csv",
-            data,
-        )
-    except:
-        pass
+    if args.save_results:
+        # append results to csv
+        try:
+            assert (
+                args.save_video == False
+            ), "Saving video slows down speed a lot and it will distort results"
+            Path("benchmark_results").mkdir(parents=True, exist_ok=True)
+            data = dict(
+                env_id=args.env_id,
+                obs_mode=args.obs_mode,
+                num_envs=args.num_envs,
+                control_mode=args.control_mode,
+                gpu_type=torch.cuda.get_device_name()
+            )
+            if args.env_id in BENCHMARK_ENVS:
+                data.update(
+                    num_cameras=args.num_cams,
+                    camera_width=args.cam_width,
+                    camera_height=args.cam_height,
+                )
+            profiler.update_csv(
+                "benchmark_results/maniskill.csv",
+                data,
+            )
+        except:
+            pass


def parse_args():
@@ -145,11 +146,7 @@ def parse_args():
         "--save-video", action="store_true", help="whether to save videos"
     )
     parser.add_argument(
-        "-f",
-        "--format",
-        type=str,
-        default="stdout",
-        help="format of results. Can be stdout or json.",
+        "--save-results", action="store_true", help="whether to save results to a csv file"
     )
     args = parser.parse_args()
     return args
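Note: results are now only appended to benchmark_results/maniskill.csv when --save-results is passed. The commit does not show profiler.update_csv itself; a minimal sketch of the behavior the call site relies on, assuming a plain csv.DictWriter (the real implementation lives in the profiling module and may also reconcile columns across runs):

import csv
from pathlib import Path

def update_csv(path: str, data: dict) -> None:
    # append one row of results, writing a header only when the file is new
    file = Path(path)
    file.parent.mkdir(parents=True, exist_ok=True)
    is_new = not file.exists()
    with file.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(data.keys()))
        if is_new:
            writer.writeheader()
        writer.writerow(data)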
4 changes: 3 additions & 1 deletion mani_skill/examples/motionplanning/panda/run.py
@@ -17,6 +17,7 @@ def parse_args(args=None):
     parser.add_argument("-o", "--obs-mode", type=str, default="none", help="Observation mode to use. Usually this is kept as 'none' as observations are not necessary to be stored, they can be replayed later via the mani_skill.trajectory.replay_trajectory script.")
     parser.add_argument("-n", "--num-traj", type=int, default=10, help="Number of trajectories to generate.")
     parser.add_argument("--reward-mode", type=str)
+    parser.add_argument("-b", "--sim-backend", type=str, default="auto", help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'")
     parser.add_argument("--render-mode", type=str, default="rgb_array", help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos")
     parser.add_argument("--vis", action="store_true", help="whether or not to open a GUI to visualize the solution live")
     parser.add_argument("--save-video", action="store_true", help="whether or not to save videos locally")
@@ -33,7 +34,8 @@ def main(args):
         control_mode="pd_joint_pos",
         render_mode=args.render_mode,
         reward_mode="dense" if args.reward_mode is None else args.reward_mode,
-        shader_dir=args.shader
+        shader_dir=args.shader,
+        sim_backend=args.sim_backend
     )
     if env_id not in MP_SOLUTIONS:
         raise RuntimeError(f"No already written motion planning solutions for {env_id}. Available options are {list(MP_SOLUTIONS.keys())}")
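The new --sim-backend flag is forwarded straight into gym.make, which is what lets the motion-planning solutions run inside a GPU-simulated environment (per the "motion plan in gpu env with pd joint pos" commit message). A hypothetical invocation; "PickCube-v1" stands in for any task in MP_SOLUTIONS:

import gymnasium as gym
import mani_skill.envs  # noqa: F401  (registers ManiSkill environments)

env = gym.make(
    "PickCube-v1",
    obs_mode="none",
    control_mode="pd_joint_pos",
    render_mode="rgb_array",
    sim_backend="gpu",  # the new argument: "auto", "cpu", or "gpu"
)
obs, _ = env.reset(seed=0)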
6 changes: 3 additions & 3 deletions pick_cube.py
@@ -26,7 +26,7 @@ def solve(env: PickCubeEnv, seed=None, debug=False, vis=False):
 
     approaching = np.array([0, 0, -1])
     # get transformation matrix of the tcp pose, is default batched and on torch
-    target_closing = env.agent.tcp.pose.to_transformation_matrix()[0, :3, 1].numpy()
+    target_closing = env.agent.tcp.pose.to_transformation_matrix()[0, :3, 1].cpu().numpy()
     # we can build a simple grasp pose using this information for Panda
     grasp_info = compute_grasp_info_by_obb(
         obb,
@@ -35,7 +35,7 @@ def solve(env: PickCubeEnv, seed=None, debug=False, vis=False):
         depth=FINGER_LENGTH,
     )
     closing, center = grasp_info["closing"], grasp_info["center"]
-    grasp_pose = env.agent.build_grasp_pose(approaching, closing, center)
+    grasp_pose = env.agent.build_grasp_pose(approaching, closing, env.cube.pose.sp.p)
 
     # -------------------------------------------------------------------------- #
     # Reach
@@ -52,7 +52,7 @@ def solve(env: PickCubeEnv, seed=None, debug=False, vis=False):
     # -------------------------------------------------------------------------- #
     # Move to goal pose
     # -------------------------------------------------------------------------- #
-    goal_pose = sapien.Pose(env.goal_site.pose.p[0], grasp_pose.q)
+    goal_pose = sapien.Pose(env.goal_site.pose.sp.p, grasp_pose.q)
     res = planner.move_to_pose_with_screw(goal_pose)
 
     planner.close()
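Both fixes here deal with GPU-backend batching: pose data lives in batched torch tensors on CUDA, so .numpy() must be preceded by .cpu(), and the .sp accessor yields an unbatched sapien.Pose that the CPU motion planner can consume. A minimal sketch of both conversions, where first_pose is a hypothetical helper illustrating what pose.sp provides, assuming batched (N, 3) position and (N, 4) quaternion tensors:

import torch
import sapien

def to_host(x: torch.Tensor):
    # CUDA tensors cannot be converted via .numpy() directly; detach and move to CPU first
    return x.detach().cpu().numpy()

def first_pose(p: torch.Tensor, q: torch.Tensor) -> sapien.Pose:
    # take the first pose in the batch as a plain, unbatched sapien.Pose
    return sapien.Pose(p=to_host(p[0]), q=to_host(q[0]))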
30 changes: 27 additions & 3 deletions mani_skill/trajectory/replay_trajectory.py
@@ -24,6 +24,7 @@
 from mani_skill.trajectory import utils as trajectory_utils
 from mani_skill.trajectory.merge_trajectory import merge_h5
 from mani_skill.utils import common, gym_utils, io_utils, wrappers
+from mani_skill.utils.structs.link import Link
 
 
 def qpos_to_pd_joint_delta_pos(controller: PDJointPosController, qpos):
@@ -103,10 +104,15 @@ def from_pd_joint_pos_to_ee(
     # given target joint positions instead of current joint positions.
     # Thus, we need to compute forward kinematics
     pin_model = ori_controller.articulation.create_pinocchio_model()
+    assert (
+        "arm" in ori_controller.controllers
+    ), "Could not find the controller for the robot arm. This controller conversion tool requires there to be a key called 'arm' in the controller"
     ori_arm_controller: PDJointPosController = ori_controller.controllers["arm"]
     arm_controller: PDEEPoseController = controller.controllers["arm"]
-    assert arm_controller.config.frame == "ee"
-    ee_link: sapien.Link = arm_controller.ee_link
+    assert (
+        arm_controller.config.frame == "root_translation:root_aligned_body_rotation"
+    ), "Currently only support the 'root_translation:root_aligned_body_rotation' ee control frame"
+    ee_link: Link = arm_controller.ee_link
 
     info = {}
 
@@ -132,7 +138,7 @@
 
     flag = True
 
-    for _ in range(2):
+    for _ in range(4):
         if target_mode:
             prev_ee_pose_at_base = arm_controller._target_pose
         else:
@@ -485,9 +491,13 @@ def _main(args, proc_id: int = 0, num_procs=1, pbar=None):
         ori_control_mode = ep["control_mode"]
 
         for _ in range(args.max_retry + 1):
+            # Each trial for each trajectory to replay, we reset the environment
+            # and optionally set the first environment state
             env.reset(seed=seed, **reset_kwargs)
             if ori_env is not None:
                 ori_env.reset(seed=seed, **reset_kwargs)
+
+            # set first environment state and update recorded env state
             if args.use_first_env_state or args.use_env_states:
                 ori_env_states = trajectory_utils.dict_to_list_of_dicts(
                     ori_h5_file[traj_id]["env_states"]
@@ -509,11 +519,21 @@ def recursive_replace(x, y):
                 recursive_replace(
                     env._trajectory_buffer.state, common.batch(ori_env_states[0])
                 )
+                fixed_obs = env.base_env.get_obs()
+                recursive_replace(
+                    env._trajectory_buffer.observation,
+                    common.to_numpy(common.batch(fixed_obs)),
+                )
             # Original actions to replay
             ori_actions = ori_h5_file[traj_id]["actions"][:]
             info = {}
 
             # Without conversion between control modes
+            assert not (
+                target_control_mode is not None and args.use_env_states
+            ), "Cannot use env states when trying to \
+                convert from one control mode to another. This is because control mode conversion causes there to be changes \
+                in how many actions are taken to achieve the same states"
             if target_control_mode is None:
                 n = len(ori_actions)
                 if pbar is not None:
@@ -550,6 +570,10 @@
                     pbar=pbar,
                     verbose=args.verbose,
                 )
+            else:
+                raise NotImplementedError(
+                    f"Script currently does not support converting {ori_control_mode} to {target_control_mode}"
+                )
 
         success = info.get("success", False)
         if args.discard_timeout:
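This is the core of the bugfix named in the commit title: when the replay overwrites the environment with the trajectory's first recorded state, the observation captured at reset no longer matches the true initial state, so the script now recomputes it (fixed_obs) and patches frame 0 of the recording buffer in place. A condensed sketch of that patching step, assuming observations are numpy arrays or nested dicts of arrays as in the wrapper's buffer (this mirrors the recursive_replace helper visible in the diff):

import numpy as np

def patch_first_frame(recorded, fixed):
    # overwrite only frame 0; later frames are regenerated as actions are replayed
    if isinstance(recorded, np.ndarray):
        recorded[0] = fixed[0]
    else:  # nested dict observations (e.g. state dicts or visual obs modes)
        for key in recorded:
            patch_first_frame(recorded[key], fixed[key])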
7 changes: 2 additions & 5 deletions mani_skill/utils/wrappers/record.py
@@ -331,15 +331,14 @@ def reset(
         # if we reconfigure, there is the possibility that state dictionary looks different now
         # so trajectory buffer must be wiped
         self._trajectory_buffer = None
-
         if self.save_trajectory:
             state_dict = self.base_env.get_state_dict()
             action = common.batch(self.action_space.sample())
             first_step = Step(
                 state=common.to_numpy(common.batch(state_dict)),
                 observation=common.to_numpy(common.batch(obs)),
                 # note first reward/action etc. are ignored when saving trajectories to disk
-                action=action,
+                action=common.to_numpy(common.batch(action)),
                 reward=np.zeros(
                     (
                         1,
@@ -356,9 +355,6 @@ def reset(
                 fail=np.zeros((1, self.num_envs), dtype=bool),
                 env_episode_ptr=np.zeros((self.num_envs,), dtype=int),
             )
-            if self.num_envs == 1:
-                first_step.observation = common.batch(first_step.observation)
-                first_step.action = common.batch(first_step.action)
             env_idx = np.arange(self.num_envs)
             if "env_idx" in options:
                 env_idx = common.to_numpy(options["env_idx"])
@@ -579,6 +575,7 @@ def recursive_add_to_h5py(group: h5py.Group, data: dict, key):
         episode_info.update(reset_kwargs=dict())
 
         # slice some data to remove the first dummy frame.
+
         actions = common.index_dict_array(
             self._trajectory_buffer.action, (slice(start_ptr + 1, end_ptr), env_idx)
         )
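The record.py change fixes the "bug with recording actions" from the commit message: the dummy first action is now normalized with the same to_numpy(batch(...)) pipeline as every later buffer entry, so the special-case re-batching for num_envs == 1 at reset could be deleted. Simplified sketches of the two helpers, assuming array inputs only (the real common helpers also recurse into nested dicts):

import numpy as np

def batch(x):
    # prepend a batch axis so single-env data matches the (num_envs, ...) layout
    return np.asarray(x)[None, ...]

def to_numpy(x):
    # torch tensors expose .detach(); everything else goes through np.asarray
    return x.detach().cpu().numpy() if hasattr(x, "detach") else np.asarray(x)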
