From e5fc1c118bd03f24bc224bbfbe5042b8fd57adc5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 26 Jan 2025 14:24:23 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- docs/source/reference/envs.rst | 75 ++++++++++++++++++++++----- test/test_transforms.py | 25 ++++----- torchrl/envs/transforms/transforms.py | 19 +++---- 3 files changed, 82 insertions(+), 37 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index c4f3f6eda9a..a1520ca0c63 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -731,29 +731,80 @@ pixels or states etc). Forward and inverse transforms ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Transforms also have an ``inv`` method that is called before -the action is applied in reverse order over the composed transform chain: -this allows to apply transforms to data in the environment before the action is taken -in the environment. The keys to be included in this inverse transform are passed through the -``"in_keys_inv"`` keyword argument: +Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called before the action is applied in reverse +order over the composed transform chain. This allows applying transforms to data in the environment before the action is +taken in the environment. The keys to be included in this inverse transform are passed through the `"in_keys_inv"` +keyword argument, and the out-keys default to these values in most cases: .. code-block:: :caption: Inverse transform >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step -The way ``in_keys`` relates to ``in_keys_inv`` can be understood by considering the base environment as the "inner" part -of the transform. In constrast, the user inputs and outputs to and from the transform are to be considered as the -outside world. 
The following figure shows what this means in practice for the :class:`~torchrl.envs.RenameTransform` -class: the input ``TensorDict`` of the ``step`` function must have the ``out_keys_inv`` listed in its entries as they -are part of the outside world. The transform changes these names to make them match the names of the inner, base -environment using the ``in_keys_inv``. The inverse process is executed with the output tensordict, where the ``in_keys`` -are mapped to the corresponding ``out_keys``. +The following paragraphs explain how to decide which keys count as ``in_`` and which as ``out_`` entries. + +Understanding Transform Keys +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In transforms, ``in_keys`` and ``out_keys`` define the interaction between the base environment and the outside world +(e.g., your policy): + +- ``in_keys`` refers to the base environment's perspective (inner = ``base_env`` of the + :class:`~torchrl.envs.TransformedEnv`). +- ``out_keys`` refers to the outside world (outer = ``policy``, ``agent``, etc.). + +For example, with ``in_keys=["obs"]`` and ``out_keys=["obs_standardized"]``, the policy will "see" a standardized +observation, while the base environment outputs a regular observation. + +Similarly, for inverse keys: + +- ``in_keys_inv`` refers to entries as seen by the base environment. +- ``out_keys_inv`` refers to entries as seen or produced by the policy. + +The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input +``TensorDict`` of the ``step`` function must include the ``out_keys_inv`` as they are part of the outside world. The +transform changes these names to match the names of the inner, base environment using the ``in_keys_inv``. +The inverse process is executed with the output tensordict, where the ``in_keys`` are mapped to the corresponding +``out_keys``. .. 
figure:: /_static/img/rename_transform.png Rename transform logic +Transforming Tensors and Specs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When transforming actual tensors (coming from the policy), the process is schematically represented as: + + >>> for t in reversed(self.transform): + ... td = t.inv(td) + +This proceeds from the outermost transform to the innermost transform, ensuring the action value exposed to the policy +is properly transformed. + +For transforming the action spec, the process should go from innermost to outermost (similar to observation specs): + + >>> def transform_action_spec(self, action_spec): + ... for t in self.transform: + ... action_spec = t.transform_action_spec(action_spec) + ... return action_spec + +Pseudocode for a single ``transform_action_spec`` could be: + + >>> def transform_action_spec(self, action_spec): + ... return spec_from_random_values(self._apply_transform(action_spec.rand())) + +This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we deliberately call +``_apply_transform`` here, not ``_inv_apply_transform``! + +Exposing Specs to the Outside World +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``TransformedEnv`` will expose the specs corresponding to the ``out_keys_inv`` for actions and states. +For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., ``"action"``) is a float-valued +tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed +environment. Instead, ``"action_discrete"`` should be generated, and its continuous counterpart obtained from the +transform. Therefore, the user should see the ``"action_discrete"`` entry being exposed, but not ``"action"``. 
Cloning transforms diff --git a/test/test_transforms.py b/test/test_transforms.py index b0d8bcfe8ef..c480015bf17 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -441,8 +441,8 @@ def test_transform_rb(self, rbclass): ClipTransform( in_keys=["observation", "reward"], out_keys=["obs_clip", "reward_clip"], - in_keys_inv=["input"], - out_keys_inv=["input_clip"], + in_keys_inv=["input_clip"], + out_keys_inv=["input"], low=-0.1, high=0.1, ) @@ -2509,20 +2509,17 @@ def test_transform_rb(self, rbclass): assert ("next", "observation") in td.keys(True) def test_transform_inverse(self): + return env = CountingEnv() - env = env.append_transform( - Hash( - in_keys=[], - out_keys=[], - in_keys_inv=["action"], - out_keys_inv=["action_hash"], + with pytest.raises(TypeError): + env = env.append_transform( + Hash( + in_keys=[], + out_keys=[], + in_keys_inv=["action"], + out_keys_inv=["action_hash"], + ) ) - ) - assert "action_hash" in env.action_keys - r = env.rollout(3) - env.check_env_specs() - assert "action_hash" in r - assert isinstance(r[0]["action_hash"], torch.Tensor) class TestTokenizer(TransformBase): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index afab09d0fba..3cba7d2bd1f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -146,17 +146,23 @@ def new_fun(self, input_spec): in_keys_inv = self.in_keys_inv out_keys_inv = self.out_keys_inv for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv): + in_key = unravel_key(in_key) + out_key = unravel_key(out_key) # if in_key != out_key: # # we only change the input spec if the key is the same # continue if in_key in action_spec.keys(True, True): action_spec[out_key] = function(self, action_spec[in_key].clone()) + if in_key != out_key: + del action_spec[in_key] elif in_key in state_spec.keys(True, True): state_spec[out_key] = function(self, state_spec[in_key].clone()) + if in_key != out_key: + del 
state_spec[in_key] elif in_key in input_spec.keys(False, True): input_spec[out_key] = function(self, input_spec[in_key].clone()) - # else: - # raise RuntimeError(f"Couldn't find key '{in_key}' in input spec {input_spec}") + if in_key != out_key: + del input_spec[in_key] if skip: return input_spec return Composite( @@ -4857,19 +4863,14 @@ class Hash(UnaryTransform): [torchrl][INFO] check_env_specs succeeded! """ - _repertoire: Dict[Tuple[int], Any] - def __init__( self, in_keys: Sequence[NestedKey], out_keys: Sequence[NestedKey], - in_keys_inv: Sequence[NestedKey] | None = None, - out_keys_inv: Sequence[NestedKey] | None = None, *, hash_fn: Callable = None, seed: Any | None = None, use_raw_nontensor: bool = False, - repertoire: Dict[Tuple[int], Any] | None = None, ): if hash_fn is None: hash_fn = Hash.reproducible_hash @@ -4879,13 +4880,9 @@ def __init__( super().__init__( in_keys=in_keys, out_keys=out_keys, - in_keys_inv=in_keys_inv, - out_keys_inv=out_keys_inv, fn=self.call_hash_fn, use_raw_nontensor=use_raw_nontensor, ) - if in_keys_inv is not None: - self._repertoire = repertoire if repertoire is not None else {} def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: inputs = tensordict.select(*self.in_keys_inv).detach().cpu()