Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 25, 2025
2 parents d6b5d68 + 8e89024 commit a737d09
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 103 deletions.
5 changes: 4 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8409,7 +8409,6 @@ def test_ppo_composite_no_aggregate(
if isinstance(loss_fn, KLPENPPOLoss):
kl = loss.pop("kl_approx")
assert (kl != 0).any()

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
Expand Down Expand Up @@ -8637,12 +8636,16 @@ def test_ppo_shared_seq(
)

loss = loss_fn(td).exclude("entropy")
if composite_action_dist:
loss = loss.exclude("composite_entropy")

sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
grad = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
loss2 = loss_fn2(td).exclude("entropy")
if composite_action_dist:
loss2 = loss2.exclude("composite_entropy")

model.zero_grad()
sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
Expand Down
124 changes: 103 additions & 21 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3291,6 +3291,10 @@ def test_batched_dynamic(self, break_when_any_done):
)
del env_no_buffers
gc.collect()
# print(dummy_rollouts)
# print(rollout_no_buffers_serial)
# # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
# assert_allclose_td(a, b)
assert_allclose_td(
dummy_rollouts.exclude("action"),
rollout_no_buffers_serial.exclude("action"),
Expand Down Expand Up @@ -3386,35 +3390,107 @@ def test_partial_rest(self, batched):

# fen strings for board positions generated with:
# https://lichess.org/editor
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.skipif(not _has_chess, reason="chess not found")
class TestChessEnv:
def test_env(self, stateful):
env = ChessEnv(stateful=stateful)
check_env_specs(env)
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("stateful", [False, True])
def test_env(self, stateful, include_pgn, include_fen):
with pytest.raises(
RuntimeError, match="At least one state representation"
) if not stateful and not include_pgn and not include_fen else contextlib.nullcontext():
env = ChessEnv(
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)
check_env_specs(env)

def test_rollout(self, stateful):
env = ChessEnv(stateful=stateful)
env.rollout(5000)
def test_pgn_bijectivity(self):
    """Check that board <-> PGN conversions round-trip over ten random legal moves.

    After each pushed move the PGN must (a) differ from the previous PGN,
    (b) survive a PGN -> board -> PGN round trip unchanged, and (c) equal the
    result of appending the move to the previous PGN.
    """
    np.random.seed(0)
    previous_pgn = ChessEnv._PGN_RESTART
    board = ChessEnv._pgn_to_board(previous_pgn)
    for _step in range(10):
        legal = list(board.legal_moves)
        chosen = np.random.choice(legal)
        board.push(chosen)
        current_pgn = ChessEnv._board_to_pgn(board)
        # A pushed move must be reflected in the PGN.
        assert current_pgn != previous_pgn
        round_trip = ChessEnv._board_to_pgn(ChessEnv._pgn_to_board(current_pgn))
        assert current_pgn == round_trip
        # Incremental PGN construction must agree with the full export.
        assert current_pgn == ChessEnv._add_move_to_pgn(previous_pgn, chosen)
        previous_pgn = current_pgn

def test_consistency(self):
    """Check that seeded rollouts pick the same actions for every env configuration.

    Three state-representation configurations (0: pgn+fen, 1: fen only,
    2: pgn only) are built in both stateful and stateless flavours; each is
    rolled out for 50 steps after ``torch.manual_seed(0)`` and the resulting
    actions are compared.
    """
    configs = {
        0: {"include_pgn": True, "include_fen": True},
        1: {"include_pgn": False, "include_fen": True},
        2: {"include_pgn": True, "include_fen": False},
    }
    # Build the stateful envs first, then the stateless ones (indices 0, 1, 2).
    envs = {}
    for stateful in (True, False):
        for idx, cfg in configs.items():
            envs[idx, stateful] = ChessEnv(stateful=stateful, **cfg)

    # Re-seed before every rollout so each env sees the same random stream.
    rollouts = {}
    for idx in (1, 2, 0):
        for stateful in (False, True):
            torch.manual_seed(0)
            rollouts[idx, stateful] = envs[idx, stateful].rollout(
                50, break_when_any_done=False
            )

    reference = rollouts[0, False]["action"]
    assert (reference == rollouts[1, False]["action"]).all()
    assert (reference == rollouts[2, False]["action"]).all()
    assert (reference == rollouts[0, True]["action"]).all()
    assert (rollouts[1, False]["action"] == rollouts[1, True]["action"]).all()
    assert (rollouts[2, False]["action"] == rollouts[2, True]["action"]).all()

@pytest.mark.parametrize(
    "include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_rollout(self, stateful, include_pgn, include_fen):
    """Run a seeded 500-step rollout and check that it has exactly 500 steps."""
    torch.manual_seed(0)
    env_kwargs = {
        "stateful": stateful,
        "include_pgn": include_pgn,
        "include_fen": include_fen,
    }
    env = ChessEnv(**env_kwargs)
    # break_when_any_done=False keeps the rollout going past finished games.
    trajectory = env.rollout(500, break_when_any_done=False)
    assert trajectory.shape == (500,)

def test_reset_white_to_move(self, stateful):
env = ChessEnv(stateful=stateful)
@pytest.mark.parametrize(
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
env = ChessEnv(
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
td = env.reset(TensorDict({"fen": fen}))
assert td["fen"] == fen
if include_fen:
assert env.board.fen() == fen
assert td["turn"] == env.lib.WHITE
assert not td["done"]

def test_reset_black_to_move(self, stateful):
env = ChessEnv(stateful=stateful)
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_black_to_move(self, stateful, include_pgn, include_fen):
    """Reset from a FEN where black is to move and check the reported state."""
    start_fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
    env = ChessEnv(
        stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
    )
    reset_td = env.reset(TensorDict({"fen": start_fen}))
    # Both the returned tensordict and the env's board must reflect the FEN.
    assert reset_td["fen"] == start_fen
    assert env.board.fen() == start_fen
    assert reset_td["turn"] == env.lib.BLACK
    assert not reset_td["done"]

def test_reset_done_error(self, stateful):
env = ChessEnv(stateful=stateful)
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_done_error(self, stateful, include_pgn, include_fen):
env = ChessEnv(
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
with pytest.raises(ValueError) as e_info:
env.reset(TensorDict({"fen": fen}))
Expand All @@ -3425,12 +3501,19 @@ def test_reset_done_error(self, stateful):
@pytest.mark.parametrize(
"endstate", ["white win", "black win", "stalemate", "50 move", "insufficient"]
)
def test_reward(self, stateful, reset_without_fen, endstate):
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("include_fen", [True])
@pytest.mark.parametrize("stateful", [False, True])
def test_reward(
self, stateful, reset_without_fen, endstate, include_pgn, include_fen
):
if stateful and reset_without_fen:
# reset_without_fen is only used for stateless env
return

env = ChessEnv(stateful=stateful)
env = ChessEnv(
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)

if endstate == "white win":
fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
Expand All @@ -3443,28 +3526,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
expected_turn = env.lib.BLACK
move = "Rg1#"
expected_reward = -1
expected_reward = 1
expected_done = True

elif endstate == "stalemate":
fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
expected_turn = env.lib.BLACK
move = "Rb7"
expected_reward = 0
expected_reward = 0.5
expected_done = True

elif endstate == "insufficient":
fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
expected_turn = env.lib.WHITE
move = "Kxd4"
expected_reward = 0
expected_reward = 0.5
expected_done = True

elif endstate == "50 move":
fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
expected_turn = env.lib.BLACK
move = "Kf7"
expected_reward = 0
expected_reward = 0.5
expected_done = True

elif endstate == "not_done":
Expand All @@ -3483,8 +3566,7 @@ def test_reward(self, stateful, reset_without_fen, endstate):
td = env.reset(TensorDict({"fen": fen}))
assert td["turn"] == expected_turn

moves = env.get_legal_moves(None if stateful else td)
td["action"] = moves.index(move)
td["action"] = env._san_moves.index(move)
td = env.step(td)["next"]
assert td["done"] == expected_done
assert td["reward"] == expected_reward
Expand Down
79 changes: 79 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,85 @@ def test_single_trans_env_check(self, out_keys):
)
check_env_specs(env)

@pytest.mark.parametrize("cat_dim", [-1, -2, -3])
@pytest.mark.parametrize("cat_N", [3, 10])
@pytest.mark.parametrize("device", get_default_devices())
def test_with_permute_no_env(self, cat_dim, cat_N, device):
    """Compare CatFrames applied directly with CatFrames wrapped in a permutation.

    A random permutation that moves ``cat_dim`` is applied, CatFrames is run
    along the dimension's new position, and the inverse permutation is
    applied; the result must equal CatFrames run directly along ``cat_dim``.
    """
    torch.manual_seed(cat_dim * cat_N)
    pixels = torch.randn(8, 5, 3, 10, 4, device=device)
    batch_size = [pixels.shape[0]]
    data = TensorDict({"pixels": pixels}, batch_size, device=device)

    direct = Compose(CatFrames(N=cat_N, dim=cat_dim))

    def get_rand_perm(ndim):
        # Redraw until the permutation actually relocates cat_dim.
        while True:
            positive = torch.randperm(ndim)
            negative = positive - ndim
            moved_dim = (negative == cat_dim).nonzero().item() - ndim
            if moved_dim != cat_dim:
                break
        inverse = positive.argsort() - ndim
        return negative.tolist(), inverse.tolist(), moved_dim

    # The leading (batch) dimension is left out of the permutation.
    perm, perm_inv, cat_dim_perm = get_rand_perm(pixels.dim() - 1)

    permuted = Compose(
        PermuteTransform(perm, in_keys=["pixels"]),
        CatFrames(N=cat_N, dim=cat_dim_perm),
        PermuteTransform(perm_inv, in_keys=["pixels"]),
    )

    expected = direct._call(data.clone())
    actual = permuted._call(data.clone())
    assert (expected == actual).all()

@pytest.mark.skipif(not _has_gym, reason="Test executed on gym")
@pytest.mark.parametrize("cat_dim", [-1, -2])
def test_with_permute_env(self, cat_dim):
    """Compare CatFrames in an env with and without a surrounding permutation.

    ``env0`` stacks frames directly along ``cat_dim``. ``env1`` swaps the last
    two dimensions, stacks along the swapped position (``-3 - cat_dim``) and
    swaps back. Both envs are seeded identically and must produce equal
    tensordicts at reset and after one step with the same action.
    """
    env0 = TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(
            UnsqueezeTransform(-1, in_keys=["observation"]),
            CatFrames(N=4, dim=cat_dim, in_keys=["observation"]),
        ),
    )

    env1 = TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(
            UnsqueezeTransform(-1, in_keys=["observation"]),
            PermuteTransform((-1, -2), in_keys=["observation"]),
            CatFrames(N=4, dim=-3 - cat_dim, in_keys=["observation"]),
            PermuteTransform((-1, -2), in_keys=["observation"]),
        ),
    )

    torch.manual_seed(0)
    env0.set_seed(0)
    td0 = env0.reset()

    torch.manual_seed(0)
    env1.set_seed(0)
    td1 = env1.reset()

    assert (td0 == td1).all()

    # Sample a single action and feed it to both envs. Previously the second
    # step was taken on env0 with td0 again, so env1 was never stepped and the
    # final comparison did not exercise the permuted pipeline.
    action = env0.full_action_spec.rand()
    td0 = env0.step(td0.update(action.clone()))
    td1 = env1.step(td1.update(action.clone()))

    assert (td0 == td1).all()

def test_serial_trans_env_check(self):
env = SerialEnv(
2,
Expand Down
23 changes: 12 additions & 11 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,17 +718,13 @@ def _create_td(self) -> None:
env_output_keys = set()
env_obs_keys = set()
for meta_data in self.meta_data:
env_obs_keys = env_obs_keys.union(
key
for key in meta_data.specs["output_spec"][
"full_observation_spec"
].keys(True, True)
)
env_output_keys = env_output_keys.union(
meta_data.specs["output_spec"]["full_observation_spec"].keys(
True, True
)
keys = meta_data.specs["output_spec"]["full_observation_spec"].keys(
True, True
)
keys = list(keys)
env_obs_keys = env_obs_keys.union(keys)

env_output_keys = env_output_keys.union(keys)
env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
Expand Down Expand Up @@ -1003,7 +999,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
for i, _env in enumerate(self._envs):
if not needs_resetting[i]:
if out_tds is not None and tensordict is not None:
out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
ftd = _env.observation_spec.zero()
if self.device is None:
ftd.clear_device_()
else:
ftd = ftd.to(self.device)
out_tds[i] = ftd
continue
if tensordict is not None:
tensordict_ = tensordict[i]
Expand Down
21 changes: 18 additions & 3 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2505,11 +2505,26 @@ def reset(
Returns:
a tensordict (or the input tensordict, if any), modified in place with the resulting observations.
.. note:: `reset` should not be overwritten by :class:`~torchrl.envs.EnvBase` subclasses. The method to
modify is :meth:`~torchrl.envs.EnvBase._reset`.
"""
if tensordict is not None:
self._assert_tensordict_shape(tensordict)

tensordict_reset = self._reset(tensordict, **kwargs)
select_reset_only = kwargs.pop("select_reset_only", False)
if select_reset_only and tensordict is not None:
# When making rollouts with step_and_maybe_reset, it can happen that a tensordict has
# keys that are used by reset to optionally set the reset state (eg, the fen in chess). If that's the
# case and we don't throw them away here, reset will just be a no-op (put the env in the state reached
# during the previous step).
# Therefore, maybe_reset tells reset to temporarily hide the non-reset keys.
# To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used.
tensordict_reset = self._reset(
tensordict.select(*self.reset_keys, strict=False), **kwargs
)
else:
tensordict_reset = self._reset(tensordict, **kwargs)
# We assume that this is done properly
# if reset.device != self.device:
# reset = reset.to(self.device, non_blocking=True)
Expand Down Expand Up @@ -3292,7 +3307,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
else:
any_done = False
if any_done:
tensordict._set_str(
tensordict = tensordict._set_str(
"_reset",
done.clone(),
validated=True,
Expand All @@ -3306,7 +3321,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
key="_reset",
)
if any_done:
tensordict = self.reset(tensordict)
return self.reset(tensordict, select_reset_only=True)
return tensordict

def empty_cache(self):
Expand Down
Loading

0 comments on commit a737d09

Please sign in to comment.