[Feature] UnaryTransform for input entries
ghstack-source-id: bb0ea97f47bdad6ba5e73692969fece4e2efbfb4
Pull Request resolved: #2700
vmoens committed Jan 26, 2025
1 parent 2c19fcc commit 093a159
Showing 6 changed files with 821 additions and 143 deletions.
75 changes: 63 additions & 12 deletions docs/source/reference/envs.rst
@@ -731,29 +731,80 @@ pixels or states etc).
Forward and inverse transforms
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Transforms also have an ``inv`` method that is called before
the action is applied in reverse order over the composed transform chain:
this allows to apply transforms to data in the environment before the action is taken
in the environment. The keys to be included in this inverse transform are passed through the
``"in_keys_inv"`` keyword argument:
Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called, in reverse order over the composed
transform chain, before the action is applied. This allows transforming data on its way into the environment, before the
action is taken. The keys to be included in this inverse transform are passed through the ``"in_keys_inv"``
keyword argument, and the out-keys default to these values in most cases:

.. code-block::
   :caption: Inverse transform

   >>> env.append_transform(DoubleToFloat(in_keys_inv=["action"]))  # will map the action from float32 to float64 before calling the base_env.step

The way ``in_keys`` relates to ``in_keys_inv`` can be understood by considering the base environment as the "inner" part
of the transform. In contrast, the user inputs and outputs to and from the transform are to be considered as the
outside world. The following figure shows what this means in practice for the :class:`~torchrl.envs.RenameTransform`
class: the input ``TensorDict`` of the ``step`` function must have the ``out_keys_inv`` listed in its entries as they
are part of the outside world. The transform changes these names to make them match the names of the inner, base
environment using the ``in_keys_inv``. The inverse process is executed with the output tensordict, where the ``in_keys``
are mapped to the corresponding ``out_keys``.
The following paragraphs detail how to decide what counts as an ``in_`` or an ``out_`` feature.

Understanding Transform Keys
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In transforms, ``in_keys`` and ``out_keys`` define the interaction between the base environment and the outside world
(e.g., your policy):

- ``in_keys`` refers to the base environment's perspective (inner = ``base_env`` of the
  :class:`~torchrl.envs.TransformedEnv`).
- ``out_keys`` refers to the outside world (outer = ``policy``, ``agent``, etc.).

For example, with ``in_keys=["obs"]`` and ``out_keys=["obs_standardized"]``, the policy will "see" a standardized
observation, while the base environment outputs a regular observation.

Similarly, for inverse keys:

- ``in_keys_inv`` refers to entries as seen by the base environment.
- ``out_keys_inv`` refers to entries as seen or produced by the policy.

The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input
``TensorDict`` of the ``step`` function must include the ``out_keys_inv`` as they are part of the outside world. The
transform changes these names to match the names of the inner, base environment using the ``in_keys_inv``.
The inverse process is executed with the output tensordict, where the ``in_keys`` are mapped to the corresponding
``out_keys``.

.. figure:: /_static/img/rename_transform.png

   Rename transform logic
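
The key mapping described above can be sketched with plain dictionaries. This is a toy stand-in, not the TorchRL
implementation: ``ToyRename`` and the key names ``"obs_renamed"`` and ``"action_renamed"`` are hypothetical and chosen
for illustration only.

```python
# Toy illustration of the four key lists of a rename-style transform.
# ToyRename is a hypothetical class, NOT torchrl.envs.RenameTransform.
class ToyRename:
    def __init__(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
        # Forward direction: inner (base env) names -> outer (policy) names.
        self.forward_map = dict(zip(in_keys, out_keys))
        # Inverse direction: outer (policy) names -> inner (base env) names.
        self.inverse_map = dict(zip(out_keys_inv, in_keys_inv))

    def forward(self, data):
        # Applied to what the base environment returns.
        return {self.forward_map.get(k, k): v for k, v in data.items()}

    def inv(self, data):
        # Applied to what the policy sends, before the base environment sees it.
        return {self.inverse_map.get(k, k): v for k, v in data.items()}


rename = ToyRename(
    in_keys=["obs"],
    out_keys=["obs_renamed"],
    in_keys_inv=["action"],
    out_keys_inv=["action_renamed"],
)

# The policy writes the outer name; the base environment receives the inner name.
env_input = rename.inv({"action_renamed": 1.0})
# The base environment writes the inner name; the policy reads the outer name.
policy_view = rename.forward({"obs": 0.5})
```

In this sketch the ``*_inv`` lists are read "outside in" (``out_keys_inv`` to ``in_keys_inv``), while the forward lists
are read "inside out" (``in_keys`` to ``out_keys``), which mirrors the inner/outer convention of the text.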

Transforming Tensors and Specs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When transforming actual tensors (coming from the policy), the process is schematically represented as:

    >>> for t in reversed(self.transform):
    ...     td = t.inv(td)

This goes from the outermost transform to the innermost one, ensuring that the action value exposed to the policy
is properly transformed before it reaches the base environment.
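
Why the order matters can be seen on a small numeric example. The sketch below assumes two hypothetical scalar
transforms (``ToyScale`` and ``ToyShift``) stored innermost first; it is not TorchRL code and only mirrors the loop
shown above.

```python
# Hypothetical scalar transforms used only to illustrate the iteration order.
class ToyScale:
    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):  # inner -> outer
        return x * self.factor

    def inv(self, x):  # outer -> inner
        return x / self.factor


class ToyShift:
    def __init__(self, offset):
        self.offset = offset

    def forward(self, x):  # inner -> outer
        return x + self.offset

    def inv(self, x):  # outer -> inner
        return x - self.offset


# Innermost transform first, outermost last.
chain = [ToyScale(2.0), ToyShift(1.0)]


def apply_forward(chain, x):
    for t in chain:
        x = t.forward(x)
    return x


def apply_inverse(chain, x):
    # Same loop as in the text: outermost transform first.
    for t in reversed(chain):
        x = t.inv(x)
    return x


outer_value = apply_forward(chain, 3.0)        # 3 * 2 + 1
recovered = apply_inverse(chain, outer_value)  # (7 - 1) / 2

# Applying the inverses in forward order does not undo the chain.
wrong_order = outer_value
for t in chain:
    wrong_order = t.inv(wrong_order)           # 7 / 2 - 1
```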

For transforming the action spec, the process should go from innermost to outermost (similar to observation specs):

    >>> def transform_action_spec(self, action_spec):
    ...     for t in self.transform:
    ...         action_spec = t.transform_action_spec(action_spec)
    ...     return action_spec

Pseudocode for a single ``transform_action_spec`` could be:

    >>> def transform_action_spec(self, action_spec):
    ...     return spec_from_random_values(self._apply_transform(action_spec.rand()))

This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that this deliberately calls
``_apply_transform`` and not ``_inv_apply_transform``.
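
A minimal sketch of this idea, under loud assumptions: the spec is reduced to a hypothetical ``Interval`` of bounds, the
transform is a scaling by a positive factor, and the bounds are mapped directly instead of going through random
samples. None of these names are TorchRL classes; only the method names ``_apply_transform``,
``_inv_apply_transform`` and ``transform_action_spec`` are borrowed from the text.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    """Hypothetical stand-in for a bounded action spec."""

    low: float
    high: float


class ToyScale:
    """Hypothetical transform: outer value = inner value * factor (factor > 0)."""

    def __init__(self, factor):
        self.factor = factor

    def _apply_transform(self, value):  # inner -> outer
        return value * self.factor

    def _inv_apply_transform(self, value):  # outer -> inner
        return value / self.factor

    def transform_action_spec(self, spec):
        # The outside spec is derived from the inside spec with the FORWARD map.
        return Interval(
            self._apply_transform(spec.low), self._apply_transform(spec.high)
        )


def transform_action_spec(chain, spec):
    # Innermost to outermost, as for observation specs.
    for t in chain:
        spec = t.transform_action_spec(spec)
    return spec


chain = [ToyScale(2.0), ToyScale(5.0)]  # innermost first
inner_spec = Interval(-1.0, 1.0)
outer_spec = transform_action_spec(chain, inner_spec)

# An action sampled at the edge of the outer spec maps back inside the inner spec.
value = outer_spec.high
for t in reversed(chain):
    value = t._inv_apply_transform(value)
```

With these toy numbers, the policy would sample actions in the outer interval, and the inverse pass brings them back
into the range the base environment accepts.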

Exposing Specs to the Outside World
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``TransformedEnv`` will expose the specs corresponding to the ``out_keys_inv`` for actions and states.
For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., ``"action"``) is a float-valued
tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed
environment. Instead, ``"action_discrete"`` should be generated, and its continuous counterpart obtained from the
transform. Therefore, the user should see the ``"action_discrete"`` entry being exposed, but not ``"action"``.
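
The behaviour described above can be mimicked with a toy wrapper. Everything in this sketch is hypothetical
(``ToyDiscretizer``, ``ToyTransformedEnv``, the bin-center mapping) and is not how
:class:`~torchrl.envs.ActionDiscretizer` is implemented; it only shows which key is exposed to the user and which one
reaches the inner environment.

```python
import random


class ToyDiscretizer:
    """Hypothetical transform mapping a bin index to the center of that bin."""

    def __init__(self, low, high, num_bins):
        self.low, self.high, self.num_bins = low, high, num_bins

    def inv(self, data):
        # outer ("action_discrete") -> inner ("action")
        data = dict(data)
        index = data.pop("action_discrete")
        width = (self.high - self.low) / self.num_bins
        data["action"] = self.low + (index + 0.5) * width
        return data


class ToyTransformedEnv:
    """Hypothetical wrapper exposing only the outer action key."""

    def __init__(self, transform, rng):
        self.transform = transform
        self.rng = rng

    def rand_action(self):
        # Only the discrete entry is sampled; "action" is never generated here.
        return {"action_discrete": self.rng.randrange(self.transform.num_bins)}

    def inner_input(self, data):
        # What the base environment would receive after the inverse transform.
        return self.transform.inv(data)


env = ToyTransformedEnv(ToyDiscretizer(-1.0, 1.0, 4), random.Random(0))
sampled = env.rand_action()
inner = env.inner_input(sampled)
first_bin = ToyDiscretizer(-1.0, 1.0, 4).inv({"action_discrete": 0})
```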


Cloning transforms
5 changes: 4 additions & 1 deletion test/mocking_classes.py
@@ -1070,17 +1070,20 @@ def _step(

 class CountingEnvWithString(CountingEnv):
     def __init__(self, *args, **kwargs):
+        self.max_size = kwargs.pop("max_size", 30)
+        self.min_size = kwargs.pop("min_size", 4)
         super().__init__(*args, **kwargs)
         self.observation_spec.set(
             "string",
             NonTensor(
                 shape=self.batch_size,
                 device=self.device,
                 example_data=self.get_random_string(),
             ),
         )

     def get_random_string(self):
-        size = random.randint(4, 30)
+        size = random.randint(self.min_size, self.max_size)
         return "".join(random.choice(string.ascii_lowercase) for _ in range(size))

     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: