[Bug Report] MARL Quadcopter Reset function problem #1625
-
Hi everyone, I'm not sure why, but at some point during my training the `_pre_physics_step` function starts producing NaN values. I've checked and rechecked the code, and the values seem to appear out of nowhere. I checked for division by zero, but none of my values are close to zero. Has anyone seen this behavior before? I might have misunderstood how the MARL workflow is intended to be designed, as I'm new to experimenting with it. My goal is to create a multi-agent formation task for drones. Currently my drones spawn, but they disappear as soon as the NaN values appear. Here is my first attempt at designing the task. Many functions are still incomplete, as I'm just trying to figure out why the actions are producing NaN values. Could the issue come from the `_reset_idx` function? Thanks for your help!

```python
import torch
from collections.abc import Sequence
# (Isaac Lab imports for DirectMARLEnv and the task config omitted here)


class formationEnv(DirectMARLEnv):
    cfg: formationEnvCfg

    def __init__(self, cfg: formationEnvCfg, render_mode: str | None = None, **kwargs):
        super().__init__(cfg, render_mode, **kwargs)
        print("Initialization start")
        # Initialize robots, cameras, and terrain for multi-agent setup
        self._robots = [self.scene[f"robot_{i+1}"] for i in range(len(self.cfg.possible_agents))]
        print(f"Robots initialized: {self._robots}")
        self.robot1 = self.scene["robot_1"]
        self.robot2 = self.scene["robot_2"]
        self.robot3 = self.scene["robot_3"]
        self.robot4 = self.scene["robot_4"]
        self.robot5 = self.scene["robot_5"]
        self.robot6 = self.scene["robot_6"]
        self._terrain = self.scene["terrain"]
        print("Terrain and robots set")

        num_agents = len(self.cfg.possible_agents)
        self._actions = torch.zeros(self.num_envs, num_agents, self.cfg.individual_action_space, device=self.device)
        self._thrust = torch.zeros(self.num_envs, num_agents, 3, device=self.device)
        self._moment = torch.zeros(self.num_envs, num_agents, 3, device=self.device)
        print("Actions, thrust, and moments initialized")

        # Logging for each robot
        self._episode_sums = {
            key: torch.zeros(self.num_envs, num_agents, dtype=torch.float, device=self.device)
            for key in ["lin_vel", "ang_vel", "distance_to_goal"]
        }
        print("Episode sums initialized")

        self._body_id1 = self.robot1.find_bodies("body")[0]
        self._body_id2 = self.robot2.find_bodies("body")[0]
        self._body_id3 = self.robot3.find_bodies("body")[0]
        self._body_id4 = self.robot4.find_bodies("body")[0]
        self._body_id5 = self.robot5.find_bodies("body")[0]
        self._body_id6 = self.robot6.find_bodies("body")[0]
        print("Body IDs set for all robots")

        self._robot_mass = self.robot1.root_physx_view.get_masses()[0].sum()
        self._gravity_magnitude = torch.tensor(self.sim.cfg.gravity, device=self.device).norm()
        self._robot_weight = (self._robot_mass * self._gravity_magnitude).item()
        print(f"Gravity and robot weight computed: {self._robot_weight}")

        # Debug visualization
        self.set_debug_vis(self.cfg.debug_vis)
        print("Debug visualization set")
    def _get_observations(self) -> dict[str, torch.Tensor]:
        observations = {}
        for i, robot in enumerate(self._robots):
            state = torch.zeros(60, device=self.device)  # Dummy state
            if torch.any(torch.isnan(state)):
                print(f"Warning: NaN detected in observation for robot_{i+1}. Replacing with zeros.")
                state = torch.zeros_like(state)
            observations[f"robot_{i+1}"] = state
        return observations

    def _get_states(self) -> torch.Tensor:
        states = torch.zeros(self.num_envs, self.cfg.state_space, device=self.device)
        return states
    def _get_rewards(self) -> dict[str, torch.Tensor]:
        rewards = torch.zeros(self.num_envs, len(self._robots), device=self.device)
        if torch.any(torch.isnan(rewards)):
            print("Warning: NaN detected in rewards. Replacing with zeros.")
            rewards = torch.zeros_like(rewards)
        rewards_dict = {
            f"robot_{i+1}": rewards[:, i]
            for i in range(len(self._robots))
        }
        return rewards_dict

    def _get_dones(self) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        time_out = self.episode_length_buf >= self.max_episode_length - 1
        terminated = {}
        time_outs = {}
        for i, agent_name in enumerate(self.cfg.possible_agents):
            low_altitude = self._robots[i].data.root_pos_w[:, 2] < 0.1
            high_altitude = self._robots[i].data.root_pos_w[:, 2] > 100
            altitude_violation = torch.logical_or(low_altitude, high_altitude)
            x_constraint = self._robots[i].data.root_pos_w[:, 0] < -22
            done_condition = torch.logical_or(altitude_violation, x_constraint)
            terminated[agent_name] = done_condition
            time_outs[agent_name] = time_out
        return terminated, time_outs
    '''
    def _pre_physics_step(self, actions: dict[str, torch.Tensor]) -> None:
        print("Pre-physics step")
        for i, action_key in enumerate(actions.keys()):
            action = actions[action_key]
            print(f"Action for {action_key}: {action}")
            temp_act = action.clone().clamp(-1.0, 1.0)
            self._actions[:, i, :] = temp_act
            self._thrust[:, i, 2] = (self.cfg.thrust_to_weight * self._robot_weight * (temp_act[:, 0] + 1.0) / 2.0)
            self._moment[:, i, :] = self.cfg.moment_scale * temp_act[:, 1:]
            print(f"Thrust for {action_key}: {self._thrust[:, i, 2]}")
            print(f"Moment for {action_key}: {self._moment[:, i, :]}")
    '''
    def _pre_physics_step(self, actions: dict[str, torch.Tensor]) -> None:
        action = actions["robot_1"]
        temp_act = action.clone().clamp(-1.0, 1.0)
        self._actions[:, 0, :] = temp_act
        self._thrust[:, 0, 2] = (self.cfg.thrust_to_weight * self._robot_weight * (temp_act[:, 0] + 1.0) / 2.0)
        self._moment[:, 0, :] = self.cfg.moment_scale * temp_act[:, 1:]
        print("Pre-physics step")
        print("Thrust:", self._thrust[:, 0, 2])
        print("Moment:", self._moment[:, 0, :])
        print("Actions", action)
        print("Pre-physics step done")
    def _apply_action(self) -> None:
        self.robot1.set_external_force_and_torque(self._thrust[:, self._body_id1, :], self._moment[:, self._body_id1, :], body_ids=self._body_id1)
        self.robot2.set_external_force_and_torque(self._thrust[:, self._body_id2, :], self._moment[:, self._body_id2, :], body_ids=self._body_id2)
        self.robot3.set_external_force_and_torque(self._thrust[:, self._body_id3, :], self._moment[:, self._body_id3, :], body_ids=self._body_id3)
        self.robot4.set_external_force_and_torque(self._thrust[:, self._body_id4, :], self._moment[:, self._body_id4, :], body_ids=self._body_id4)
        self.robot5.set_external_force_and_torque(self._thrust[:, self._body_id5, :], self._moment[:, self._body_id5, :], body_ids=self._body_id5)
        self.robot6.set_external_force_and_torque(self._thrust[:, self._body_id6, :], self._moment[:, self._body_id6, :], body_ids=self._body_id6)
    def _reset_idx(self, env_ids: Sequence[int] | torch.Tensor | None):
        print("Reset start")
        if env_ids is None or len(env_ids) == self.num_envs:
            env_ids = self.robot1._ALL_INDICES
        super()._reset_idx(env_ids)
        for i, agent_name in enumerate(self.cfg.possible_agents):
            joint_pos = self._robots[i].data.default_joint_pos[env_ids]
            joint_vel = self._robots[i].data.default_joint_vel[env_ids]
            default_root_state = self._robots[i].data.default_root_state[env_ids]
            default_root_state[:, :3] += self._terrain.env_origins[env_ids]
            self._robots[i].write_root_pose_to_sim(default_root_state[:, :7], env_ids)
            self._robots[i].write_root_velocity_to_sim(default_root_state[:, 7:], env_ids)
            self._robots[i].write_joint_state_to_sim(joint_pos, joint_vel, None, env_ids)
        print("Reset done")
```

Below are the values of my actions. I checked again, and the other functions don't modify the action values:

```
Pre-physics step
Thrust: tensor([0.0665, 0.0665], device='cuda:0')
Moment: tensor([[-0.0083, -0.0009,  0.0012],
        [-0.0083, -0.0009,  0.0012]], device='cuda:0')
Actions tensor([[-0.7469, -0.8330, -0.0936,  0.1228]], device='cuda:0')
Pre-physics step done
Pre-physics step
Thrust: tensor([nan, nan], device='cuda:0')
Moment: tensor([[nan, nan, nan],
        [nan, nan, nan]], device='cuda:0')
Actions tensor([[nan, nan, nan, nan]], device='cuda:0')
Pre-physics step done
```
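The same kind of NaN guard I already use in `_get_observations` could also be applied to the incoming actions as a stopgap (a sketch only; it masks the symptom rather than fixing the source):

```python
# Sketch: replace NaNs in the incoming actions before thrust/moment are
# computed. This hides the symptom, it does not fix the source.
action = torch.nan_to_num(actions["robot_1"], nan=0.0).clamp(-1.0, 1.0)
```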
-
Thank you for posting this. This may be related to the batch sizes available per GPU and the learners assigned to them. I'm moving this into a discussion for the team to follow up. In the meantime, could you try fewer robots and make sure your batch sizes are reasonable?
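For example, something along these lines while debugging (assuming the standard `scene.num_envs` field on your task config):

```python
# Sketch: shrink the batch while debugging. scene.num_envs is the usual
# knob on Isaac Lab's InteractiveSceneCfg-based task configs.
cfg = formationEnvCfg()
cfg.scene.num_envs = 64  # instead of several thousand
env = formationEnv(cfg)
```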
-
I've been playing around with a custom MARL setup with two quadrupeds and have not encountered any NaN issues. Could you try stepping through with pdb, or printing after every line, to determine where the first NaN occurs?
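Something like this small helper, called after each computation in `_pre_physics_step`, `_apply_action`, and `_reset_idx`, usually pins down the first offender quickly (the helper name is illustrative):

```python
import torch

def assert_finite(name: str, t: torch.Tensor) -> None:
    # Fail loudly at the first tensor containing NaN or Inf.
    if not torch.isfinite(t).all():
        raise RuntimeError(f"Non-finite values detected in {name}: {t}")

# e.g. inside _pre_physics_step:
# assert_finite("actions/robot_1", actions["robot_1"])
# assert_finite("thrust", self._thrust)
```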
-
Thanks for following up, great suggestions @Rishi-V. @JulienHansen, let us know if you still have this issue.
I just found the error: the reset function was not correct and was causing NaN values to be generated. Here is a snippet of the corrected version:
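A simplified sketch of it, with my formation-specific spawn logic stripped out (the buffer clearing shown here is representative, not the exact diff):

```python
def _reset_idx(self, env_ids: Sequence[int] | torch.Tensor | None):
    if env_ids is None or len(env_ids) == self.num_envs:
        env_ids = self.robot1._ALL_INDICES
    super()._reset_idx(env_ids)

    # Clear the per-agent buffers for the environments being reset so
    # stale action/thrust/moment values cannot leak into the next episode.
    self._actions[env_ids] = 0.0
    self._thrust[env_ids] = 0.0
    self._moment[env_ids] = 0.0

    for robot in self._robots:
        joint_pos = robot.data.default_joint_pos[env_ids]
        joint_vel = robot.data.default_joint_vel[env_ids]
        default_root_state = robot.data.default_root_state[env_ids]
        default_root_state[:, :3] += self._terrain.env_origins[env_ids]
        robot.write_root_pose_to_sim(default_root_state[:, :7], env_ids)
        robot.write_root_velocity_to_sim(default_root_state[:, 7:], env_ids)
        robot.write_joint_state_to_sim(joint_pos, joint_vel, None, env_ids)
```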