diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index c4f3f6eda9a..a1520ca0c63 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -731,29 +731,80 @@ pixels or states etc). Forward and inverse transforms ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Transforms also have an ``inv`` method that is called before -the action is applied in reverse order over the composed transform chain: -this allows to apply transforms to data in the environment before the action is taken -in the environment. The keys to be included in this inverse transform are passed through the -``"in_keys_inv"`` keyword argument: +Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called before the action is applied in reverse +order over the composed transform chain. This allows applying transforms to data in the environment before the action is +taken in the environment. The keys to be included in this inverse transform are passed through the `"in_keys_inv"` +keyword argument, and the out-keys default to these values in most cases: .. code-block:: :caption: Inverse transform >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step -The way ``in_keys`` relates to ``in_keys_inv`` can be understood by considering the base environment as the "inner" part -of the transform. In constrast, the user inputs and outputs to and from the transform are to be considered as the -outside world. The following figure shows what this means in practice for the :class:`~torchrl.envs.RenameTransform` -class: the input ``TensorDict`` of the ``step`` function must have the ``out_keys_inv`` listed in its entries as they -are part of the outside world. The transform changes these names to make them match the names of the inner, base -environment using the ``in_keys_inv``. The inverse process is executed with the output tensordict, where the ``in_keys`` -are mapped to the corresponding ``out_keys``. +The following paragraphs detail how one can think about what is to be considered `in_` or `out_` features. + +Understanding Transform Keys +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In transforms, `in_keys` and `out_keys` define the interaction between the base environment and the outside world +(e.g., your policy): + +- `in_keys` refers to the base environment's perspective (inner = `base_env` of the + :class:`~torchrl.envs.TransformedEnv`). +- `out_keys` refers to the outside world (outer = `policy`, `agent`, etc.). + +For example, with `in_keys=["obs"]` and `out_keys=["obs_standardized"]`, the policy will "see" a standardized +observation, while the base environment outputs a regular observation. + +Similarly, for inverse keys: + +- `in_keys_inv` refers to entries as seen by the base environment. +- `out_keys_inv` refers to entries as seen or produced by the policy. + +The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input +`TensorDict` of the `step` function must include the `out_keys_inv` as they are part of the outside world. The +transform changes these names to match the names of the inner, base environment using the `in_keys_inv`. +The inverse process is executed with the output tensordict, where the `in_keys` are mapped to the corresponding +`out_keys`. .. figure:: /_static/img/rename_transform.png Rename transform logic +Transforming Tensors and Specs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When transforming actual tensors (coming from the policy), the process is schematically represented as: + + >>> for t in reversed(self.transform): + ... td = t.inv(td) + +This starts with the outermost transform to the innermost transform, ensuring the action value exposed to the policy +is properly transformed. + +For transforming the action spec, the process should go from innermost to outermost (similar to observation specs): + + >>> def transform_action_spec(self, action_spec): + ... for t in self.transform: + ... action_spec = t.transform_action_spec(action_spec) + ... return action_spec + +A pseudocode for a single transform_action_spec could be: + + >>> def transform_action_spec(self, action_spec): + ... return spec_from_random_values(self._apply_transform(action_spec.rand())) + +This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we did not call +`_inv_apply_transform` but `_apply_transform` on purpose! + +Exposing Specs to the Outside World +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +`TransformedEnv` will expose the specs corresponding to the `out_keys_inv` for actions and states. +For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., `"action"`) is a float-valued +tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed +environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the +transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`. Cloning transforms diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 0531bff10df..71375fd13a2 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1070,17 +1070,20 @@ def _step( class CountingEnvWithString(CountingEnv): def __init__(self, *args, **kwargs): + self.max_size = kwargs.pop("max_size", 30) + self.min_size = kwargs.pop("min_size", 4) super().__init__(*args, **kwargs) self.observation_spec.set( "string", NonTensor( shape=self.batch_size, device=self.device, + example_data=self.get_random_string(), ), ) def get_random_string(self): - size = random.randint(4, 30) + size = random.randint(self.min_size, self.max_size) return "".join(random.choice(string.ascii_lowercase) for _ in range(size)) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: diff --git a/test/test_transforms.py b/test/test_transforms.py index 6a57d4faa1e..c480015bf17 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -20,76 +20,26 @@ import tensordict.tensordict import torch - -from torchrl.collectors import MultiSyncDataCollector - -if os.getenv("PYTORCH_TEST_FBCODE"): - from pytorch.rl.test._utils_internal import ( # noqa - BREAKOUT_VERSIONED, - dtype_fixture, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - rand_reset, - retry, - ) - from pytorch.rl.test.mocking_classes import ( - ContinuousActionVecMockEnv, - CountingBatchedEnv, - CountingEnv, - CountingEnvCountPolicy, - CountingEnvWithString, - DiscreteActionConvMockEnv, - DiscreteActionConvMockEnvNumpy, - EnvWithScalarAction, - IncrementingEnv, - MockBatchedLockedEnv, - MockBatchedUnLockedEnv, - MultiAgentCountingEnv, - MultiKeyCountingEnv, - MultiKeyCountingEnvPolicy, - NestedCountingEnv, - ) -else: - from _utils_internal import ( # noqa - BREAKOUT_VERSIONED, - dtype_fixture, - get_default_devices, - HALFCHEETAH_VERSIONED, - PENDULUM_VERSIONED, - PONG_VERSIONED, - rand_reset, - retry, - ) - from mocking_classes import ( - ContinuousActionVecMockEnv, - CountingBatchedEnv, - CountingEnv, - CountingEnvCountPolicy, - CountingEnvWithString, - DiscreteActionConvMockEnv, - DiscreteActionConvMockEnvNumpy, - EnvWithScalarAction, - IncrementingEnv, - MockBatchedLockedEnv, - MockBatchedUnLockedEnv, - MultiAgentCountingEnv, - MultiKeyCountingEnv, - MultiKeyCountingEnvPolicy, - NestedCountingEnv, - ) -from tensordict import NonTensorData, TensorDict, TensorDictBase, unravel_key +from tensordict import ( + NonTensorData, + NonTensorStack, + TensorDict, + TensorDictBase, + unravel_key, +) from tensordict.nn import TensorDictSequential from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td from torch import multiprocessing as mp, nn, Tensor from torchrl._utils import _replace_last, prod + +from torchrl.collectors import MultiSyncDataCollector from torchrl.data import ( Bounded, BoundedContinuous, Categorical, Composite, LazyTensorStorage, + NonTensor, ReplayBuffer, TensorDictReplayBuffer, TensorSpec, @@ -147,6 +97,7 @@ TargetReturn, TensorDictPrimer, TimeMaxPool, + Tokenizer, ToTensorImage, TrajCounter, TransformedEnv, @@ -174,6 +125,63 @@ from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal from torchrl.modules.utils import get_primers_from_module +if os.getenv("PYTORCH_TEST_FBCODE"): + from pytorch.rl.test._utils_internal import ( # noqa + BREAKOUT_VERSIONED, + dtype_fixture, + get_default_devices, + HALFCHEETAH_VERSIONED, + PENDULUM_VERSIONED, + PONG_VERSIONED, + rand_reset, + retry, + ) + from pytorch.rl.test.mocking_classes import ( + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + CountingEnvCountPolicy, + CountingEnvWithString, + DiscreteActionConvMockEnv, + DiscreteActionConvMockEnvNumpy, + EnvWithScalarAction, + IncrementingEnv, + MockBatchedLockedEnv, + MockBatchedUnLockedEnv, + MultiAgentCountingEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, + NestedCountingEnv, + ) +else: + from _utils_internal import ( # noqa + BREAKOUT_VERSIONED, + dtype_fixture, + get_default_devices, + HALFCHEETAH_VERSIONED, + PENDULUM_VERSIONED, + PONG_VERSIONED, + rand_reset, + retry, + ) + from mocking_classes import ( + ContinuousActionVecMockEnv, + CountingBatchedEnv, + CountingEnv, + CountingEnvCountPolicy, + CountingEnvWithString, + DiscreteActionConvMockEnv, + DiscreteActionConvMockEnvNumpy, + EnvWithScalarAction, + IncrementingEnv, + MockBatchedLockedEnv, + MockBatchedUnLockedEnv, + MultiAgentCountingEnv, + MultiKeyCountingEnv, + MultiKeyCountingEnvPolicy, + NestedCountingEnv, + ) + IS_WIN = platform == "win32" if IS_WIN: mp_ctx = "spawn" @@ -433,8 +441,8 @@ def test_transform_rb(self, rbclass): ClipTransform( in_keys=["observation", "reward"], out_keys=["obs_clip", "reward_clip"], - in_keys_inv=["input"], - out_keys_inv=["input_clip"], + in_keys_inv=["input_clip"], + out_keys_inv=["input"], low=-0.1, high=0.1, ) @@ -567,8 +575,10 @@ def test_transform_env(self, device): def test_transform_inverse(self): t = ClipTransform( - in_keys_inv=["observation", "reward"], - out_keys_inv=["obs_clip", "reward_clip"], + # What the outside world sees + out_keys_inv=["observation", "reward"], + # What the env expects + in_keys_inv=["obs_clip", "reward_clip"], low=-0.1, high=0.1, ) @@ -2499,7 +2509,270 @@ def test_transform_rb(self, rbclass): assert ("next", "observation") in td.keys(True) def test_transform_inverse(self): - raise pytest.skip("No inverse for Hash") + return + env = CountingEnv() + with pytest.raises(TypeError): + env = env.append_transform( + Hash( + in_keys=[], + out_keys=[], + in_keys_inv=["action"], + out_keys_inv=["action_hash"], + ) + ) + + +class TestTokenizer(TransformBase): + @pytest.mark.parametrize("datatype", ["str", "NonTensorStack"]) + def test_transform_no_env(self, datatype): + if datatype == "str": + obs = "abcdefg" + elif datatype == "NonTensorStack": + obs = torch.stack( + [ + NonTensorData(data="abcde"), + NonTensorData(data="fghij"), + NonTensorData(data="klmno"), + ] + ) + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + td = TensorDict( + { + "observation": obs, + } + ) + + t = Tokenizer(in_keys=["observation"], out_keys=["tokens"]) + td_tokenized = t(td) + t_inv = Tokenizer([], [], in_keys_inv=["observation"], out_keys_inv=["tokens"]) + td_recon = t_inv.inv(td_tokenized.clone().exclude("observation")) + assert td_tokenized.get("observation") is td.get("observation") + assert td_recon["observation"] == td["observation"] + + @pytest.mark.parametrize("datatype", ["str"]) + def test_single_trans_env_check(self, datatype): + if datatype == "str": + t = Tokenizer( + in_keys=["string"], + out_keys=["tokens"], + max_length=5, + ) + base_env = CountingEnvWithString(max_size=4, min_size=4) + env = TransformedEnv(base_env, t) + check_env_specs(env, return_contiguous=False) + + @pytest.mark.parametrize("datatype", ["str"]) + def test_serial_trans_env_check(self, datatype): + def make_env(): + if datatype == "str": + t = Tokenizer( + in_keys=["string"], + out_keys=["tokens"], + max_length=5, + ) + base_env = CountingEnvWithString(max_size=4, min_size=4) + + return TransformedEnv(base_env, t) + + env = SerialEnv(2, make_env) + check_env_specs(env, return_contiguous=False) + + @pytest.mark.parametrize("datatype", ["str"]) + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype): + def make_env(): + if datatype == "str": + t = Tokenizer( + in_keys=["string"], + out_keys=["tokens"], + max_length=5, + ) + base_env = CountingEnvWithString(max_size=4, min_size=4) + return TransformedEnv(base_env, t) + + env = maybe_fork_ParallelEnv(2, make_env) + try: + check_env_specs(env, return_contiguous=False) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("datatype", ["str"]) + def test_trans_serial_env_check(self, datatype): + if datatype == "str": + t = Tokenizer( + in_keys=["string"], + out_keys=["tokens"], + max_length=5, + ) + base_env = partial(CountingEnvWithString, max_size=4, min_size=4) + + env = TransformedEnv(SerialEnv(2, base_env), t) + check_env_specs(env, return_contiguous=False) + + @pytest.mark.parametrize("datatype", ["str"]) + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype): + if datatype == "str": + t = Tokenizer( + in_keys=["string"], + out_keys=["tokens"], + max_length=5, + ) + base_env = partial(CountingEnvWithString, max_size=4, min_size=4) + + env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t) + try: + check_env_specs(env, return_contiguous=False) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("datatype", ["str"]) + def test_transform_compose(self, datatype): + if datatype == "str": + obs = "abcdefg" + + td = TensorDict( + { + "observation": obs, + } + ) + t = Tokenizer( + in_keys=["observation"], + out_keys=["tokens"], + max_length=5, + ) + t = Compose(t) + td_tokenized = t(td) + + assert td_tokenized["observation"] is td["observation"] + assert ( + td_tokenized["tokens"] + == t[0].tokenizer.encode( + obs, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=5, + ) + ).all() + + @pytest.mark.parametrize("n", [3, 5, 7]) + def test_transform_model(self, n): + t = Tokenizer( + in_keys=["observation"], + out_keys=["tokens"], + max_length=n, + ) + model = nn.Sequential(t, nn.Identity()) + td = TensorDict({"observation": "a string!"}) + td_out = model(td) + assert ( + td_out["tokens"] == torch.tensor([1037, 5164, 999] + [0] * (n - 3)) + ).all() + + def test_transform_env(self): + import random + + random.seed(0) + t = Tokenizer( + in_keys=["string"], + out_keys=["tokens"], + max_length=10, + ) + base_env = CountingEnvWithString(max_steps=10, max_size=4, min_size=4) + env = TransformedEnv(base_env, t) + policy = lambda td: env.full_action_spec.one() + r = env.rollout(100, policy) + assert r["string"] == [ + "mzjp", + "sgqe", + "eydt", + "rwzt", + "jdxc", + "prdl", + "ktug", + "oqib", + "cxmw", + "tpkh", + "wcgs", + ] + assert ( + env.transform.tokenizer.batch_decode(r["tokens"], skip_special_tokens=True) + == r["string"] + ) + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + t = Tokenizer( + in_keys=["observation"], + out_keys=["tokens"], + max_length=5, + ) + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform(t) + td = TensorDict( + { + "observation": NonTensorStack( + "mzjp", + "sgqe", + "eydt", + "rwzt", + "jdxc", + "prdl", + "ktug", + "oqib", + "cxmw", + "tpkh", + ), + }, + [10], + ) + rb.extend(td) + td = rb.sample(2) + assert ( + t.tokenizer.batch_decode(td["tokens"], skip_special_tokens=True) + == td["observation"] + ) + + def test_transform_inverse(self): + torch.manual_seed(0) + t = Tokenizer( + in_keys=[], + out_keys=[], + # The policy produces tokens + out_keys_inv=["tokens"], + # The env must see strings + in_keys_inv=["strings"], + max_length=5, + ) + base_env = CountingEnv() + + class CheckString(Transform): + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + assert "strings" in tensordict + tensordict.pop("strings") + return tensordict + + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + action_spec["strings"] = NonTensor( + shape=action_spec.shape, example_data="a string!" + ) + return action_spec + + env = TransformedEnv(base_env, Compose(CheckString(), t)) + + def policy(td): + td.set("tokens", torch.randint(0, 10000, (10,))) + td.update(env.full_action_spec.one()) + return td + + env.check_env_specs() class TestStack(TransformBase): @@ -5358,28 +5631,33 @@ def test_transform_rb(self, rbclass): @pytest.mark.parametrize( "out_key", ["observation_out", ("nested", "observation_out")] ) - def test_transform_inverse(self, out_key, out_key_inv): + @pytest.mark.parametrize("compose", [False, True]) + def test_transform_inverse(self, out_key, out_key_inv, compose): standard_normal = True out_keys = [out_key] in_keys_inv = ["action"] out_keys_inv = [out_key_inv] - t = Compose( - ObservationNorm( - loc=torch.ones(()), - scale=0.5, - in_keys=["observation"], - out_keys=out_keys, - in_keys_inv=in_keys_inv, - out_keys_inv=out_keys_inv, - standard_normal=standard_normal, - ) + t = ObservationNorm( + loc=torch.ones(()), + scale=0.5, + in_keys=["observation"], + out_keys=out_keys, + # What the env asks for + in_keys_inv=in_keys_inv, + # What the outside world sees + out_keys_inv=out_keys_inv, + standard_normal=standard_normal, ) + if compose: + t = Compose(t) base_env = GymEnv(PENDULUM_VERSIONED()) env = TransformedEnv(base_env, t) + assert out_keys_inv[0] in env.full_action_spec.keys(True, True) td = env.rollout(3) check_env_specs(env) env.set_seed(0) - assert torch.allclose(td["action"] * 0.5 + 1, t.inv(td)[out_key_inv]) + a, a_ = td[out_key_inv] * 0.5 + 1, t.inv(td)["action"] + assert torch.allclose(a, a_), (a, a_) assert torch.allclose((td["observation"] - 1) / 0.5, td[out_key]) @pytest.mark.parametrize("batch", [[], [1], [3, 2]]) @@ -6915,13 +7193,13 @@ def test_transform_inverse(self): # the order is inverted Compose( UnsqueezeTransform( - -1, in_keys_inv=["action_t"], out_keys_inv=["action"] + -1, in_keys_inv=["action"], out_keys_inv=["action_t"] ), - SqueezeTransform(-1, in_keys_inv=["action"], out_keys_inv=["action_t"]), + SqueezeTransform(-1, in_keys_inv=["action_t"], out_keys_inv=["action"]), ), ) td = env.rollout(3) - assert env.action_spec.shape[-1] == 6 + assert env.full_action_spec["action"].shape[-1] == 6 assert td["action"].shape[-1] == 6 @@ -7036,8 +7314,10 @@ def _circular_transform(self): @property def _inv_circular_transform(self): return Compose( - UnsqueezeTransform(-1, in_keys_inv=["action_un"], out_keys_inv=["action"]), - SqueezeTransform(-1, in_keys_inv=["action"], out_keys_inv=["action_un"]), + # The env wants a squeezed action - the inv of unsqueeze + UnsqueezeTransform(-1, in_keys_inv=["action"], out_keys_inv=["action_un"]), + # The outsize world has an squeezed action that we unsqueeze - the inv of squeeze + SqueezeTransform(-1, in_keys_inv=["action_un"], out_keys_inv=["action"]), ) def test_single_trans_env_check(self): @@ -7195,7 +7475,7 @@ def test_transform_inverse(self): check_env_specs(env) r = env.rollout(3) r2 = GymEnv(HALFCHEETAH_VERSIONED()).rollout(3) - assert (r.zero_() == r2.zero_()).all() + assert_allclose_td(r.zero_(), r2.zero_(), intersection=True) class TestTargetReturn(TransformBase): @@ -11890,12 +12170,14 @@ def test_transform_rb(self, rbclass): SignTransform( in_keys=["observation", "reward"], out_keys=["obs_sign", "reward_sign"], - in_keys_inv=["input"], - out_keys_inv=["input_sign"], + # What is stored within + in_keys_inv=["input_signed"], + # What the outside world sees + out_keys_inv=["input_unsigned"], ) ) rb.append_transform(t) - data = TensorDict({"observation": 1, "reward": 2, "input": 3}, []) + data = TensorDict({"observation": 1, "reward": 2, "input_unsigned": 3}, []) rb.add(data) sample = rb.sample(20) @@ -11905,8 +12187,8 @@ def test_transform_rb(self, rbclass): assert (sample["reward"] == 2).all() assert self.check_sign_applied(sample["reward_sign"]) - assert (sample["input"] == 3).all() - assert self.check_sign_applied(sample["input_sign"]) + assert (sample["input_unsigned"] == 3).all() + assert self.check_sign_applied(sample["input_signed"]) def test_single_trans_env_check(self): env = ContinuousActionVecMockEnv() @@ -11951,15 +12233,17 @@ def test_transform_env(self, device): def test_transform_inverse(self): t = SignTransform( - in_keys_inv=["observation", "reward"], - out_keys_inv=["obs_sign", "reward_sign"], + # What is seen inside + in_keys_inv=["obs_signed", "reward_signed"], + # What the outside world sees + out_keys_inv=["obs", "reward"], ) - data = TensorDict({"observation": 1, "reward": 2}, []) + data = TensorDict({"obs": 1, "reward": 2}, []) data = t.inv(data) - assert data["observation"] == 1 - assert self.check_sign_applied(data["obs_sign"]) + assert data["obs"] == 1 + assert self.check_sign_applied(data["obs_signed"]) assert data["reward"] == 2 - assert self.check_sign_applied(data["reward_sign"]) + assert self.check_sign_applied(data["reward_signed"]) def test_transform_model(self): t = nn.Sequential( diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 3a4cde38aa2..fed73755502 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -94,6 +94,7 @@ TargetReturn, TensorDictPrimer, TimeMaxPool, + Tokenizer, ToTensorImage, TrajCounter, Transform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index a25c676e378..7ee142fe811 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -55,6 +55,7 @@ TargetReturn, TensorDictPrimer, TimeMaxPool, + Tokenizer, ToTensorImage, TrajCounter, Transform, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ff6047310ff..3cba7d2bd1f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -129,23 +129,42 @@ def _apply_to_composite_inv(function): # Now since EnvBase.step ignores new inputs (ie the root level of the # tensor is not updated) an out_key that does not match the in_key has # no effect on the spec. + @wraps(function) def new_fun(self, input_spec): - action_spec = input_spec["full_action_spec"].clone() - state_spec = input_spec["full_state_spec"] - if state_spec is None: - state_spec = Composite(shape=input_spec.shape, device=input_spec.device) + if "full_action_spec" in input_spec.keys(): + skip = False + action_spec = input_spec["full_action_spec"].clone() + state_spec = input_spec["full_state_spec"] + if state_spec is None: + state_spec = Composite(shape=input_spec.shape, device=input_spec.device) + else: + state_spec = state_spec.clone() else: - state_spec = state_spec.clone() + skip = True + # In case we pass full_action_spec or full_state_spec directly + action_spec = state_spec = Composite() in_keys_inv = self.in_keys_inv out_keys_inv = self.out_keys_inv for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv): - if in_key != out_key: - # we only change the input spec if the key is the same - continue + in_key = unravel_key(in_key) + out_key = unravel_key(out_key) + # if in_key != out_key: + # # we only change the input spec if the key is the same + # continue if in_key in action_spec.keys(True, True): action_spec[out_key] = function(self, action_spec[in_key].clone()) + if in_key != out_key: + del action_spec[in_key] elif in_key in state_spec.keys(True, True): state_spec[out_key] = function(self, state_spec[in_key].clone()) + if in_key != out_key: + del state_spec[in_key] + elif in_key in input_spec.keys(False, True): + input_spec[out_key] = function(self, input_spec[in_key].clone()) + if in_key != out_key: + del input_spec[in_key] + if skip: + return input_spec return Composite( full_state_spec=state_spec, full_action_spec=action_spec, @@ -353,13 +372,12 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: if not self.in_keys_inv: return tensordict for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): - data = tensordict.get(in_key, None) + data = tensordict.get(out_key, None) if data is not None: item = self._inv_apply_transform(data) - tensordict.set(out_key, item) + tensordict.set(in_key, item) elif not self.missing_tolerance: - raise KeyError(f"'{in_key}' not found in tensordict {tensordict}") - + raise KeyError(f"'{out_key}' not found in tensordict {tensordict}") return tensordict @dispatch(source="in_keys_inv", dest="out_keys_inv") @@ -420,6 +438,13 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: expected spec after the transform """ + input_spec = input_spec.clone() + input_spec["full_state_spec"] = self.transform_state_spec( + input_spec["full_state_spec"] + ) + input_spec["full_action_spec"] = self.transform_action_spec( + input_spec["full_action_spec"] + ) return input_spec def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: @@ -458,6 +483,30 @@ def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec: """ return done_spec + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + """Transforms the action spec such that the resulting spec matches transform mapping. + + Args: + action_spec (TensorSpec): spec before the transform + + Returns: + expected spec after the transform + + """ + return action_spec + + def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec: + """Transforms the state spec such that the resulting spec matches transform mapping. + + Args: + state_spec (TensorSpec): spec before the transform + + Returns: + expected spec after the transform + + """ + return state_spec + def dump(self, **kwargs) -> None: pass @@ -1136,10 +1185,34 @@ def transform_env_batch_size(self, batch_size: torch.batch_size): return batch_size def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + # Input, action and state specs do NOT need to be reversed + # although applying these specs requires them to be called backward. + # To prove this, imagine we have 2 action transforms: t0 is an ActionDiscretizer, it maps float actions + # from the env to int actions for the policy. We add one more transform t1 that, if a == a_action_max, + # reduces its value by 1 (ie, the policy can sample actions from 0 to N + 1, and ActionDiscretizer + # has top N values). + # To apply this transform given an int action from the policy, we first call t1 to clamp the action to + # N (from N+1), then call t0 to map it to a float. + # We build this from TEnv(env, Compose(ActionDiscretizer, ActionClamp)) and call them starting with the + # last then the first. + # To know what the action spec is to the 'outside world' (ie, to the policy) we must take + # the action spec from the env, map it using t0 then t1 (going from in to out). for t in self.transforms: input_spec = t.transform_input_spec(input_spec) return input_spec + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + # To understand why we don't invert, look up at transform_input_spec + for t in self.transforms: + action_spec = t.transform_action_spec(action_spec) + return action_spec + + def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec: + # To understand why we don't invert, look up at transform_input_spec + for t in self.transforms: + state_spec = t.transform_state_spec(state_spec) + return state_spec + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: for t in self.transforms: observation_spec = t.transform_observation_spec(observation_spec) @@ -2265,19 +2338,16 @@ def _transform_spec(self, spec: TensorSpec): spec.shape = self._apply_transform(torch.zeros(spec.shape)).shape return spec - def _inv_transform_spec(self, spec: TensorSpec) -> None: - space = spec.space - if isinstance(space, ContinuousBox): - space.low = self._inv_apply_transform(space.low) - space.high = self._inv_apply_transform(space.high) - spec.shape = space.low.shape - else: - spec.shape = self._inv_apply_transform(torch.zeros(spec.shape)).shape - return spec + # To map the specs, we actually use the forward call, not the inv + _inv_transform_spec = _transform_spec @_apply_to_composite_inv - def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: - return self._inv_transform_spec(input_spec) + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + return self._inv_transform_spec(action_spec) + + @_apply_to_composite_inv + def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec: + return self._inv_transform_spec(state_spec) @_apply_to_composite def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: @@ -2827,13 +2897,29 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec space.high = self._apply_transform(space.high) return observation_spec + # @_apply_to_composite_inv + # def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + # space = input_spec.space + # if isinstance(space, ContinuousBox): + # space.low = self._apply_transform(space.low) + # space.high = self._apply_transform(space.high) + # return input_spec + @_apply_to_composite_inv - def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: - space = input_spec.space + def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec: + space = action_spec.space if isinstance(space, ContinuousBox): space.low = self._apply_transform(space.low) space.high = self._apply_transform(space.high) - return input_spec + return action_spec + + @_apply_to_composite_inv + def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec: + space = state_spec.space + if isinstance(space, ContinuousBox): + space.low = self._apply_transform(space.low) + space.high = self._apply_transform(space.high) + return state_spec def __repr__(self) -> str: if self.initialized and (self.loc.numel() == 1 and self.scale.numel() == 1): @@ -4437,10 +4523,15 @@ class UnaryTransform(Transform): Args: in_keys (sequence of NestedKey): the keys of inputs to the unary operation. out_keys (sequence of NestedKey): the keys of the outputs of the unary operation. - fn (Callable): the function to use as the unary operation. If it accepts - a non-tensor input, it must also accept ``None``. + in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during inverse call. + out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation durin inverse call. Keyword Args: + fn (Callable[[Any], Tensor | TensorDictBase]): the function to use as the unary operation. If it accepts + a non-tensor input, it must also accept ``None``. + inv_fn (Callable[[Any], Any], optional): the function to use as the unary operation during inverse calls. + If it accepts a non-tensor input, it must also accept ``None``. + Can be ommitted, in which case :attr:`fn` will be used for inverse maps. use_raw_nontensor (bool, optional): if ``False``, data is extracted from :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` @@ -4511,12 +4602,21 @@ def __init__( self, in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], - fn: Callable, + in_keys_inv: Sequence[NestedKey] | None = None, + out_keys_inv: Sequence[NestedKey] | None = None, *, + fn: Callable[[Any], Tensor | TensorDictBase], + inv_fn: Callable[[Any], Any] | None = None, use_raw_nontensor: bool = False, ): - super().__init__(in_keys=in_keys, out_keys=out_keys) + super().__init__( + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + ) self._fn = fn + self._inv_fn = inv_fn self._use_raw_nontensor = use_raw_nontensor def _apply_transform(self, value): @@ -4530,6 +4630,19 @@ def _apply_transform(self, value): value = value.tolist() return self._fn(value) + def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: + if not self._use_raw_nontensor: + if isinstance(state, NonTensorData): + if state.dim() == 0: + state = state.get("data") + else: + state = state.tolist() + elif isinstance(state, NonTensorStack): + state = state.tolist() + if self._inv_fn is not None: + return self._inv_fn(state) + return self._fn(state) + def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: @@ -4537,6 +4650,41 @@ def _reset( tensordict_reset = self._call(tensordict_reset) return tensordict_reset + def transform_input_spec(self, input_spec: Composite) -> Composite: + input_spec = input_spec.clone() + + # Make a generic input from the spec, call the transform with that + # input, and then generate the output spec from the output. + zero_input_ = input_spec.zero() + test_input = zero_input_["full_action_spec"].update( + zero_input_["full_state_spec"] + ) + # We use forward and not inv because the spec comes from the base env and + # we are trying to infer what the spec looks like from the outside. + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): + data = test_input.get(in_key, None) + if data is not None: + data = self._apply_transform(data) + test_input.set(out_key, data) + elif not self.missing_tolerance: + raise KeyError(f"'{in_key}' not found in tensordict {test_input}") + test_output = test_input + # test_output = self.inv(test_input) + test_input_spec = make_composite_from_td( + test_output, unsqueeze_null_shapes=False + ) + + input_spec["full_action_spec"] = self.transform_action_spec( + input_spec["full_action_spec"], + test_input_spec, + ) + if "full_state_spec" in input_spec.keys(): + input_spec["full_state_spec"] = self.transform_state_spec( + input_spec["full_state_spec"], + test_input_spec, + ) + return input_spec + def transform_output_spec(self, output_spec: Composite) -> Composite: output_spec = output_spec.clone() @@ -4570,14 +4718,19 @@ def transform_output_spec(self, output_spec: Composite) -> Composite: return output_spec def _transform_spec( - self, spec: TensorSpec, test_output_spec: TensorSpec + self, spec: TensorSpec, test_output_spec: TensorSpec, inverse: bool = False ) -> TensorSpec: if not isinstance(spec, Composite): raise TypeError(f"{self}: Only specs of type Composite can be transformed") spec_keys = set(spec.keys(include_nested=True)) - for in_key, out_key in zip(self.in_keys, self.out_keys): + iterator = ( + zip(self.in_keys, self.out_keys) + if not inverse + else zip(self.in_keys_inv, self.out_keys_inv) + ) + for in_key, out_key in iterator: if in_key in spec_keys: spec.set(out_key, test_output_spec[out_key]) return spec @@ -4597,6 +4750,16 @@ def transform_done_spec( ) -> TensorSpec: return self._transform_spec(done_spec, test_output_spec) + def transform_action_spec( + self, action_spec: TensorSpec, test_input_spec: TensorSpec + ) -> TensorSpec: + return self._transform_spec(action_spec, test_input_spec, inverse=True) + + def transform_state_spec( + self, state_spec: TensorSpec, test_input_spec: TensorSpec + ) -> TensorSpec: + return self._transform_spec(state_spec, test_input_spec, inverse=True) + class Hash(UnaryTransform): r"""Adds a hash value to a tensordict. @@ -4604,12 +4767,21 @@ class Hash(UnaryTransform): Args: in_keys (sequence of NestedKey): the keys of the values to hash. out_keys (sequence of NestedKey): the keys of the resulting hashes. + in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call. + + .. note:: If an inverse map is required, a repertoire ``Dict[Tuple[int], Any]`` of hash to value should be + passed alongside the list of keys to let the ``Hash`` transform know how to recover a value from a + given hash. This repertoire isn't copied, so it can be modified in the same workspace after the + transform instantiation and these modifications will be reflected in the map. Missing hashes will be + mapped to ``None``. + + out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call. + + Keyword Args: hash_fn (Callable, optional): the hash function to use. If ``seed`` is given, the hash function must accept it as its second argument. Default is ``Hash.reproducible_hash``. seed (optional): seed to use for the hash function, if it requires one. - - Keyword Args: use_raw_nontensor (bool, optional): if ``False``, data is extracted from :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` @@ -4695,9 +4867,9 @@ def __init__( self, in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], + *, hash_fn: Callable = None, seed: Any | None = None, - *, use_raw_nontensor: bool = False, ): if hash_fn is None: @@ -4712,6 +4884,35 @@ def __init__( use_raw_nontensor=use_raw_nontensor, ) + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + inputs = tensordict.select(*self.in_keys_inv).detach().cpu() + tensordict = super()._inv_call(tensordict) + + def register_outcome(td): + # We need to treat each hash independently + if td.ndim: + if td.ndim > 1: + td_r = td.reshape(-1) + elif td.ndim == 1: + td_r = td + result = torch.stack([register_outcome(_td) for _td in td_r.unbind(0)]) + if td_r is not td: + return result.reshape(td.shape) + return result + for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + inp = inputs.get(in_key) + inp = tuple(inp.tolist()) + outp = self._repertoire.get(inp) + td[out_key] = outp + return td + + return register_outcome(tensordict) + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + if self.in_keys_inv is not None: + return {"_repertoire": self._repertoire} + return {} + def call_hash_fn(self, value): if self._seed is None: return self._hash_fn(value) @@ -4736,7 +4937,7 @@ def reproducible_hash(cls, string, seed=None): if seed is not None: seeded_string = seed + string else: - seeded_string = string + seeded_string = str(string) # Create a new SHA-256 hash object hash_object = hashlib.sha256() @@ -4750,6 +4951,143 @@ def reproducible_hash(cls, string, seed=None): return torch.frombuffer(hash_bytes, dtype=torch.uint8) +class Tokenizer(UnaryTransform): + r"""Applies a tokenization operation on the specified inputs. + + Args: + in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation. + out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation. + in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call. + out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call. + + Keyword Args: + tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``, + "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a + pre-trained tokenizer. + use_raw_nontensor (bool, optional): if ``False``, data is extracted from + :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization + function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` + inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``. + additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary. + + .. note:: This transform can be used both to transform output strings into tokens and to transform back tokenized + actions or states into strings. If the environment has a string state-spec, the transformed version will have + a tokenized state-spec. If it is a string action spec, it will result in a tokenized action spec. + + """ + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_keys: Sequence[NestedKey], + in_keys_inv: Sequence[NestedKey] | None = None, + out_keys_inv: Sequence[NestedKey] | None = None, + *, + tokenizer: "transformers.PretrainedTokenizerBase" = None, # noqa: F821 + use_raw_nontensor: bool = False, + additional_tokens: List[str] | None = None, + skip_special_tokens: bool = True, + add_special_tokens: bool = False, + padding: bool = True, + max_length: int | None = None, + ): + if tokenizer is None: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + elif isinstance(tokenizer, str): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + self.tokenizer = tokenizer + self.add_special_tokens = add_special_tokens + self.skip_special_tokens = skip_special_tokens + self.padding = padding + self.max_length = max_length + if additional_tokens: + self.tokenizer.add_tokens(additional_tokens) + super().__init__( + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + fn=self.call_tokenizer_fn, + inv_fn=self.call_tokenizer_inv_fn, + use_raw_nontensor=use_raw_nontensor, + ) + + @property + def device(self): + if "_device" in self.__dict__: + return self._device + parent = self.parent + if parent is None: + return None + device = parent.device + self._device = device + return device + + def call_tokenizer_fn(self, value: str | List[str]): + device = self.device + kwargs = {"add_special_tokens": self.add_special_tokens} + if self.max_length is not None: + kwargs["padding"] = "max_length" + kwargs["max_length"] = self.max_length + if isinstance(value, str): + out = self.tokenizer.encode(value, return_tensors="pt", **kwargs)[0] + # TODO: incorporate attention mask + # attention_mask = torch.ones_like(out, dtype=torch.bool) + else: + kwargs["padding"] = ( + self.padding if self.max_length is None else "max_length" + ) + # kwargs["return_attention_mask"] = False + # kwargs["return_token_type_ids"] = False + out = self.tokenizer.batch_encode_plus(value, return_tensors="pt", **kwargs) + # attention_mask = out["attention_mask"] + out = out["input_ids"] + + if device is not None and out.device != device: + out = out.to(device) + return out + + def call_tokenizer_inv_fn(self, value: Tensor): + if value.ndim == 1: + out = self.tokenizer.decode( + value, skip_special_tokens=self.skip_special_tokens + ) + else: + out = self.tokenizer.batch_decode( + value, skip_special_tokens=self.skip_special_tokens + ) + if isinstance(out, list): + return NonTensorStack(*out) + return NonTensorData(out) + + def transform_input_spec(self, input_spec: Composite) -> Composite: + input_spec = super().transform_input_spec(input_spec) + # We need to cap the spec to generate valid random strings + for out_key in self.out_keys_inv: + if out_key in input_spec["full_state_spec"].keys(True, True): + input_spec["full_state_spec"][out_key] = Bounded( + 0, + self.tokenizer.vocab_size, + shape=input_spec["full_state_spec"][out_key].shape, + device=input_spec["full_state_spec"][out_key].device, + dtype=input_spec["full_state_spec"][out_key].dtype, + ) + elif out_key in input_spec["full_action_spec"].keys(True, True): + input_spec["full_action_spec"][out_key] = Bounded( + 0, + self.tokenizer.vocab_size, + shape=input_spec["full_action_spec"][out_key].shape, + device=input_spec["full_action_spec"][out_key].device, + dtype=input_spec["full_action_spec"][out_key].dtype, + ) + return input_spec + + class Stack(Transform): """Stacks tensors and tensordicts.