diff --git a/test/test_env.py b/test/test_env.py index 12bdc0bc9ad..d8bf36cdf98 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -131,6 +131,7 @@ from torchrl.envs.transforms.transforms import ( AutoResetEnv, AutoResetTransform, + Tokenizer, Transform, ) from torchrl.envs.utils import ( @@ -3346,10 +3347,6 @@ def test_batched_dynamic(self, break_when_any_done): ) del env_no_buffers gc.collect() - # print(dummy_rollouts) - # print(rollout_no_buffers_serial) - # # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)): - # assert_allclose_td(a, b) assert_allclose_td( dummy_rollouts.exclude("action"), rollout_no_buffers_serial.exclude("action"), @@ -3463,6 +3460,8 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san include_hash=include_hash, include_san=include_san, ) + # Because we always use mask_actions=True + assert isinstance(env, TransformedEnv) check_env_specs(env) if include_hash: if include_fen: @@ -3560,8 +3559,8 @@ def test_reset_white_to_move(self, stateful, include_pgn, include_fen): ) fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1" td = env.reset(TensorDict({"fen": fen})) - assert td["fen"] == fen if include_fen: + assert td["fen"] == fen assert env.board.fen() == fen assert td["turn"] == env.lib.WHITE assert not td["done"] @@ -3666,6 +3665,27 @@ def test_reward( assert td["reward"] == expected_reward assert td["turn"] == (not expected_turn) + def test_chess_tokenized(self): + env = ChessEnv(include_fen=True, stateful=True, include_san=True) + assert isinstance(env.observation_spec["fen"], NonTensor) + env = env.append_transform( + Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"]) + ) + assert isinstance(env.observation_spec["fen"], NonTensor) + env.transform.transform_output_spec(env.base_env.output_spec) + env.transform.transform_input_spec(env.base_env.input_spec) + r = env.rollout(10, return_contiguous=False) + assert "fen_tokenized" in r + assert "fen" in r + assert "fen_tokenized" in r["next"] + assert "fen" in r["next"] + ftd = env.fake_tensordict() + assert "fen_tokenized" in ftd + assert "fen" in ftd + assert "fen_tokenized" in ftd["next"] + assert "fen" in ftd["next"] + env.check_env_specs() + class TestCustomEnvs: def test_tictactoe_env(self): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 3d4198ae234..95aaaebd936 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -5042,7 +5042,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase: def __eq__(self, other): return ( - type(self) is type(other) + type(self) == type(other) and self.shape == other.shape and self._device == other._device and set(self._specs.keys()) == set(other._specs.keys()) diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py index 1446d105ae9..cf9a79e2de3 100644 --- a/torchrl/envs/custom/chess.py +++ b/torchrl/envs/custom/chess.py @@ -7,12 +7,12 @@ import importlib.util import io import pathlib -from typing import Dict, Optional +from typing import Dict import torch from PIL import Image from tensordict import TensorDict, TensorDictBase -from torchrl.data import Bounded, Categorical, Composite, NonTensor, Unbounded +from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded from torchrl.envs import EnvBase from torchrl.envs.common import _EnvPostInit @@ -20,7 +20,7 @@ from torchrl.envs.utils import _classproperty -class _HashMeta(_EnvPostInit): +class _ChessMeta(_EnvPostInit): def __call__(cls, *args, **kwargs): instance = super().__call__(*args, **kwargs) if kwargs.get("include_hash"): @@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs): if instance.include_pgn: in_keys.append("pgn") out_keys.append("pgn_hash") - return instance.append_transform(Hash(in_keys, out_keys)) + instance = instance.append_transform(Hash(in_keys, out_keys)) + if kwargs.get("mask_actions", True): + from torchrl.envs import ActionMask + + instance = instance.append_transform(ActionMask()) return instance -class ChessEnv(EnvBase, metaclass=_HashMeta): +class ChessEnv(EnvBase, metaclass=_ChessMeta): r"""A chess environment that follows the TorchRL API. This environment simulates a chess game using the `chess` library. It supports various state representations @@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta): include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``. include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``. include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``. + mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended + to the env to make sure that the actions are properly masked. Default: ``True``. pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``. .. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves. @@ -200,16 +206,15 @@ def _legal_moves_to_index( ) -> torch.Tensor: if not self.stateful: if tensordict is None: - raise RuntimeError( - "rand_action requires a tensordict when stateful is False." - ) - if self.include_fen: - fen = self._get_fen(tensordict) + # trust the board + pass + elif self.include_fen: + fen = tensordict.get("fen", None) fen = fen.data self.board.set_fen(fen) board = self.board elif self.include_pgn: - pgn = self._get_pgn(tensordict) + pgn = tensordict.get("pgn") pgn = pgn.data board = self._pgn_to_board(pgn, self.board) @@ -222,15 +227,19 @@ def _legal_moves_to_index( ) if return_mask: - return torch.zeros(len(self.san_moves), dtype=torch.bool).index_fill_( - 0, indices, True - ) + return self._move_index_to_mask(indices) if pad: indices = torch.nn.functional.pad( indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves) ) return indices + @classmethod + def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor: + return torch.zeros(len(cls.san_moves), dtype=torch.bool).index_fill_( + 0, indices, True + ) + def __init__( self, *, @@ -240,6 +249,7 @@ def __init__( include_pgn: bool = False, include_legal_moves: bool = False, include_hash: bool = False, + mask_actions: bool = True, pixels: bool = False, ): chess = self.lib @@ -250,6 +260,7 @@ def __init__( self.include_san = include_san self.include_fen = include_fen self.include_pgn = include_pgn + self.mask_actions = mask_actions self.include_legal_moves = include_legal_moves if include_legal_moves: # 218 max possible legal moves per chess board position @@ -274,8 +285,10 @@ def __init__( self.stateful = stateful - if not self.stateful: - self.full_state_spec = self.full_observation_spec.clone() + # state_spec is loosely defined as such - it's not really an issue that extra keys + # can go missing but it allows us to reset the env using fen passed to the reset + # method. + self.full_state_spec = self.full_observation_spec.clone() self.pixels = pixels if pixels: @@ -295,16 +308,16 @@ def __init__( self.full_reward_spec = Composite( reward=Unbounded(shape=(1,), dtype=torch.float32) ) + if self.mask_actions: + self.full_observation_spec["action_mask"] = Binary( + n=len(self.san_moves), dtype=torch.bool + ) + # done spec generated automatically self.board = chess.Board() if self.stateful: self.action_spec.set_provisional_n(len(list(self.board.legal_moves))) - def rand_action(self, tensordict: Optional[TensorDictBase] = None): - mask = self._legal_moves_to_index(tensordict, return_mask=True) - self.action_spec.update_mask(mask) - return super().rand_action(tensordict) - def _is_done(self, board): return board.is_game_over() | board.is_fifty_moves() @@ -314,11 +327,11 @@ def _reset(self, tensordict=None): if tensordict is not None: dest = tensordict.empty() if self.include_fen: - fen = self._get_fen(tensordict) + fen = tensordict.get("fen", None) if fen is not None: fen = fen.data elif self.include_pgn: - pgn = self._get_pgn(tensordict) + pgn = tensordict.get("pgn", None) if pgn is not None: pgn = pgn.data else: @@ -358,13 +371,18 @@ def _reset(self, tensordict=None): if self.include_legal_moves: moves_idx = self._legal_moves_to_index(board=self.board, pad=True) dest.set("legal_moves", moves_idx) + if self.mask_actions: + dest.set("action_mask", self._move_index_to_mask(moves_idx)) + elif self.mask_actions: + dest.set( + "action_mask", + self._legal_moves_to_index( + board=self.board, pad=True, return_mask=True + ), + ) + if self.pixels: dest.set("pixels", self._get_tensor_image(board=self.board)) - - if self.stateful: - mask = self._legal_moves_to_index(dest, return_mask=True) - self.action_spec.update_mask(mask) - return dest _cairosvg_lib = None @@ -435,16 +453,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821 pgn_string = str(game) return pgn_string - @classmethod - def _get_fen(cls, tensordict): - fen = tensordict.get("fen", None) - return fen - - @classmethod - def _get_pgn(cls, tensordict): - pgn = tensordict.get("pgn", None) - return pgn - def get_legal_moves(self, tensordict=None, uci=False): """List the legal moves in a position. @@ -468,7 +476,7 @@ def get_legal_moves(self, tensordict=None, uci=False): raise ValueError( "tensordict must be given since this env is not stateful" ) - fen = self._get_fen(tensordict).data + fen = tensordict.get("fen").data board.set_fen(fen) moves = board.legal_moves @@ -486,10 +494,10 @@ def _step(self, tensordict): fen = None if not self.stateful: if self.include_fen: - fen = self._get_fen(tensordict).data + fen = tensordict.get("fen").data board.set_fen(fen) elif self.include_pgn: - pgn = self._get_pgn(tensordict).data + pgn = tensordict.get("pgn").data board = self._pgn_to_board(pgn, board) else: raise RuntimeError( @@ -519,6 +527,15 @@ def _step(self, tensordict): if self.include_legal_moves: moves_idx = self._legal_moves_to_index(board=board, pad=True) dest.set("legal_moves", moves_idx) + if self.mask_actions: + dest.set("action_mask", self._move_index_to_mask(moves_idx)) + elif self.mask_actions: + dest.set( + "action_mask", + self._legal_moves_to_index( + board=self.board, pad=True, return_mask=True + ), + ) turn = torch.tensor(board.turn) done = self._is_done(board) @@ -538,11 +555,6 @@ def _step(self, tensordict): dest.set("terminated", [done]) if self.pixels: dest.set("pixels", self._get_tensor_image(board=self.board)) - - if self.stateful: - mask = self._legal_moves_to_index(dest, return_mask=True) - self.action_spec.update_mask(mask) - return dest def _set_seed(self, *args, **kwargs): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index d32b845a0f8..491cf295a03 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -862,7 +862,9 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict f"The rand_action method from the base env {self.base_env.__class__.__name__} " "has been overwritten, but the transforms appended to the environment modify " "the action. To call the base env rand_action method, we should then invert the " - "action transform, which is (in general) not doable." + "action transform, which is (in general) not doable. " + f"The full action spec of the base env is: {self.base_env.full_action_spec}, \n" + f"the full action spec of the transformed env is {self.full_action_spec}." ) return self.base_env.rand_action(tensordict) return super().rand_action(tensordict) @@ -5071,23 +5073,73 @@ def transform_input_spec(self, input_spec: Composite) -> Composite: # We need to cap the spec to generate valid random strings for out_key in self.out_keys_inv: if out_key in input_spec["full_state_spec"].keys(True, True): + new_shape = input_spec["full_state_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) input_spec["full_state_spec"][out_key] = Bounded( 0, self.tokenizer.vocab_size, - shape=input_spec["full_state_spec"][out_key].shape, + shape=new_shape, device=input_spec["full_state_spec"][out_key].device, dtype=input_spec["full_state_spec"][out_key].dtype, ) elif out_key in input_spec["full_action_spec"].keys(True, True): + new_shape = input_spec["full_action_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) input_spec["full_action_spec"][out_key] = Bounded( 0, self.tokenizer.vocab_size, - shape=input_spec["full_action_spec"][out_key].shape, + shape=new_shape, device=input_spec["full_action_spec"][out_key].device, dtype=input_spec["full_action_spec"][out_key].dtype, ) return input_spec + def transform_output_spec(self, output_spec: Composite) -> Composite: + output_spec = super().transform_output_spec(output_spec) + # We need to cap the spec to generate valid random strings + for out_key in self.out_keys: + if out_key in output_spec["full_observation_spec"].keys(True, True): + new_shape = output_spec["full_observation_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) + output_spec["full_observation_spec"][out_key] = Bounded( + 0, + self.tokenizer.vocab_size, + shape=new_shape, + device=output_spec["full_observation_spec"][out_key].device, + dtype=output_spec["full_observation_spec"][out_key].dtype, + ) + elif out_key in output_spec["full_reward_spec"].keys(True, True): + new_shape = output_spec["full_reward_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) + output_spec["full_reward_spec"][out_key] = Bounded( + 0, + self.tokenizer.vocab_size, + shape=new_shape, + device=output_spec["full_reward_spec"][out_key].device, + dtype=output_spec["full_reward_spec"][out_key].dtype, + ) + elif out_key in output_spec["full_done_spec"].keys(True, True): + new_shape = output_spec["full_done_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) + output_spec["full_done_spec"][out_key] = Bounded( + 0, + self.tokenizer.vocab_size, + shape=new_shape, + device=output_spec["full_done_spec"][out_key].device, + dtype=output_spec["full_done_spec"][out_key].dtype, + ) + return output_spec + class Stack(Transform): """Stacks tensors and tensordicts.