Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 26, 2025
2 parents ab0a2df + f4d2bcf commit 2280e4b
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 55 deletions.
30 changes: 25 additions & 5 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
from torchrl.envs.transforms.transforms import (
AutoResetEnv,
AutoResetTransform,
Tokenizer,
Transform,
)
from torchrl.envs.utils import (
Expand Down Expand Up @@ -3346,10 +3347,6 @@ def test_batched_dynamic(self, break_when_any_done):
)
del env_no_buffers
gc.collect()
# print(dummy_rollouts)
# print(rollout_no_buffers_serial)
# # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
# assert_allclose_td(a, b)
assert_allclose_td(
dummy_rollouts.exclude("action"),
rollout_no_buffers_serial.exclude("action"),
Expand Down Expand Up @@ -3463,6 +3460,8 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
include_hash=include_hash,
include_san=include_san,
)
# Because we always use mask_actions=True
assert isinstance(env, TransformedEnv)
check_env_specs(env)
if include_hash:
if include_fen:
Expand Down Expand Up @@ -3560,8 +3559,8 @@ def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
)
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
td = env.reset(TensorDict({"fen": fen}))
assert td["fen"] == fen
if include_fen:
assert td["fen"] == fen
assert env.board.fen() == fen
assert td["turn"] == env.lib.WHITE
assert not td["done"]
Expand Down Expand Up @@ -3666,6 +3665,27 @@ def test_reward(
assert td["reward"] == expected_reward
assert td["turn"] == (not expected_turn)

def test_chess_tokenized(self):
env = ChessEnv(include_fen=True, stateful=True, include_san=True)
assert isinstance(env.observation_spec["fen"], NonTensor)
env = env.append_transform(
Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"])
)
assert isinstance(env.observation_spec["fen"], NonTensor)
env.transform.transform_output_spec(env.base_env.output_spec)
env.transform.transform_input_spec(env.base_env.input_spec)
r = env.rollout(10, return_contiguous=False)
assert "fen_tokenized" in r
assert "fen" in r
assert "fen_tokenized" in r["next"]
assert "fen" in r["next"]
ftd = env.fake_tensordict()
assert "fen_tokenized" in ftd
assert "fen" in ftd
assert "fen_tokenized" in ftd["next"]
assert "fen" in ftd["next"]
env.check_env_specs()


class TestCustomEnvs:
def test_tictactoe_env(self):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5042,7 +5042,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:

def __eq__(self, other):
return (
type(self) is type(other)
type(self) == type(other)
and self.shape == other.shape
and self._device == other._device
and set(self._specs.keys()) == set(other._specs.keys())
Expand Down
104 changes: 58 additions & 46 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
import importlib.util
import io
import pathlib
from typing import Dict, Optional
from typing import Dict

import torch
from PIL import Image
from tensordict import TensorDict, TensorDictBase
from torchrl.data import Bounded, Categorical, Composite, NonTensor, Unbounded
from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded

from torchrl.envs import EnvBase
from torchrl.envs.common import _EnvPostInit

from torchrl.envs.utils import _classproperty


class _HashMeta(_EnvPostInit):
class _ChessMeta(_EnvPostInit):
def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
if kwargs.get("include_hash"):
Expand All @@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs):
if instance.include_pgn:
in_keys.append("pgn")
out_keys.append("pgn_hash")
return instance.append_transform(Hash(in_keys, out_keys))
instance = instance.append_transform(Hash(in_keys, out_keys))
if kwargs.get("mask_actions", True):
from torchrl.envs import ActionMask

instance = instance.append_transform(ActionMask())
return instance


class ChessEnv(EnvBase, metaclass=_HashMeta):
class ChessEnv(EnvBase, metaclass=_ChessMeta):
r"""A chess environment that follows the TorchRL API.
This environment simulates a chess game using the `chess` library. It supports various state representations
Expand All @@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended
to the env to make sure that the actions are properly masked. Default: ``True``.
pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
.. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
Expand Down Expand Up @@ -202,16 +208,15 @@ def _legal_moves_to_index(
) -> torch.Tensor:
if not self.stateful:
if tensordict is None:
raise RuntimeError(
"rand_action requires a tensordict when stateful is False."
)
if self.include_fen:
fen = self._get_fen(tensordict)
# trust the board
pass
elif self.include_fen:
fen = tensordict.get("fen", None)
fen = fen.data
self.board.set_fen(fen)
board = self.board
elif self.include_pgn:
pgn = self._get_pgn(tensordict)
pgn = tensordict.get("pgn")
pgn = pgn.data
board = self._pgn_to_board(pgn, self.board)

Expand All @@ -224,15 +229,19 @@ def _legal_moves_to_index(
)

if return_mask:
return torch.zeros(len(self.san_moves), dtype=torch.bool).index_fill_(
0, indices, True
)
return self._move_index_to_mask(indices)
if pad:
indices = torch.nn.functional.pad(
indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
)
return indices

@classmethod
def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor:
return torch.zeros(len(cls.san_moves), dtype=torch.bool).index_fill_(
0, indices, True
)

def __init__(
self,
*,
Expand All @@ -242,6 +251,7 @@ def __init__(
include_pgn: bool = False,
include_legal_moves: bool = False,
include_hash: bool = False,
mask_actions: bool = True,
pixels: bool = False,
):
chess = self.lib
Expand All @@ -252,6 +262,7 @@ def __init__(
self.include_san = include_san
self.include_fen = include_fen
self.include_pgn = include_pgn
self.mask_actions = mask_actions
self.include_legal_moves = include_legal_moves
if include_legal_moves:
# 218 max possible legal moves per chess board position
Expand All @@ -276,8 +287,10 @@ def __init__(

self.stateful = stateful

if not self.stateful:
self.full_state_spec = self.full_observation_spec.clone()
# state_spec is loosely defined as such - it's not really an issue that extra keys
# can go missing but it allows us to reset the env using fen passed to the reset
# method.
self.full_state_spec = self.full_observation_spec.clone()

self.pixels = pixels
if pixels:
Expand All @@ -297,16 +310,16 @@ def __init__(
self.full_reward_spec = Composite(
reward=Unbounded(shape=(1,), dtype=torch.float32)
)
if self.mask_actions:
self.full_observation_spec["action_mask"] = Binary(
n=len(self.san_moves), dtype=torch.bool
)

# done spec generated automatically
self.board = chess.Board()
if self.stateful:
self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))

def rand_action(self, tensordict: Optional[TensorDictBase] = None):
mask = self._legal_moves_to_index(tensordict, return_mask=True)
self.action_spec.update_mask(mask)
return super().rand_action(tensordict)

def _is_done(self, board):
return board.is_game_over() | board.is_fifty_moves()

Expand All @@ -316,11 +329,11 @@ def _reset(self, tensordict=None):
if tensordict is not None:
dest = tensordict.empty()
if self.include_fen:
fen = self._get_fen(tensordict)
fen = tensordict.get("fen", None)
if fen is not None:
fen = fen.data
elif self.include_pgn:
pgn = self._get_pgn(tensordict)
pgn = tensordict.get("pgn", None)
if pgn is not None:
pgn = pgn.data
else:
Expand Down Expand Up @@ -360,13 +373,18 @@ def _reset(self, tensordict=None):
if self.include_legal_moves:
moves_idx = self._legal_moves_to_index(board=self.board, pad=True)
dest.set("legal_moves", moves_idx)
if self.mask_actions:
dest.set("action_mask", self._move_index_to_mask(moves_idx))
elif self.mask_actions:
dest.set(
"action_mask",
self._legal_moves_to_index(
board=self.board, pad=True, return_mask=True
),
)

if self.pixels:
dest.set("pixels", self._get_tensor_image(board=self.board))

if self.stateful:
mask = self._legal_moves_to_index(dest, return_mask=True)
self.action_spec.update_mask(mask)

return dest

_cairosvg_lib = None
Expand Down Expand Up @@ -437,16 +455,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821
pgn_string = str(game)
return pgn_string

@classmethod
def _get_fen(cls, tensordict):
fen = tensordict.get("fen", None)
return fen

@classmethod
def _get_pgn(cls, tensordict):
pgn = tensordict.get("pgn", None)
return pgn

def get_legal_moves(self, tensordict=None, uci=False):
"""List the legal moves in a position.
Expand All @@ -470,7 +478,7 @@ def get_legal_moves(self, tensordict=None, uci=False):
raise ValueError(
"tensordict must be given since this env is not stateful"
)
fen = self._get_fen(tensordict).data
fen = tensordict.get("fen").data
board.set_fen(fen)
moves = board.legal_moves

Expand All @@ -488,10 +496,10 @@ def _step(self, tensordict):
fen = None
if not self.stateful:
if self.include_fen:
fen = self._get_fen(tensordict).data
fen = tensordict.get("fen").data
board.set_fen(fen)
elif self.include_pgn:
pgn = self._get_pgn(tensordict).data
pgn = tensordict.get("pgn").data
board = self._pgn_to_board(pgn, board)
else:
raise RuntimeError(
Expand Down Expand Up @@ -521,6 +529,15 @@ def _step(self, tensordict):
if self.include_legal_moves:
moves_idx = self._legal_moves_to_index(board=board, pad=True)
dest.set("legal_moves", moves_idx)
if self.mask_actions:
dest.set("action_mask", self._move_index_to_mask(moves_idx))
elif self.mask_actions:
dest.set(
"action_mask",
self._legal_moves_to_index(
board=self.board, pad=True, return_mask=True
),
)

turn = torch.tensor(board.turn)
done = self._is_done(board)
Expand All @@ -540,11 +557,6 @@ def _step(self, tensordict):
dest.set("terminated", [done])
if self.pixels:
dest.set("pixels", self._get_tensor_image(board=self.board))

if self.stateful:
mask = self._legal_moves_to_index(dest, return_mask=True)
self.action_spec.update_mask(mask)

return dest

def _set_seed(self, *args, **kwargs):
Expand Down
Loading

0 comments on commit 2280e4b

Please sign in to comment.