diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index aa8c64bacc3..ff6047310ff 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -795,6 +795,51 @@ def input_spec(self) -> TensorSpec:
         input_spec = self.__dict__.get("_input_spec", None)
         return input_spec
 
+    def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict:
+        """Sample a random action, deferring to the base env when it customizes sampling.
+
+        When the class of ``base_env`` defines its own ``rand_action`` (i.e. it is
+        not ``EnvBase.rand_action``), that method is called, provided that
+        ``full_action_spec`` of this env equals the one of the base env. In every
+        other case the parent class implementation is used.
+
+        Args:
+            tensordict (TensorDictBase, optional): forwarded unchanged to whichever
+                ``rand_action`` implementation ends up being called.
+
+        Returns:
+            The output of the ``rand_action`` implementation that was called.
+
+        Raises:
+            RuntimeError: if the base env overrides ``rand_action`` and the action
+                spec of this env differs from the action spec of the base env.
+        """
+        if type(self.base_env).rand_action is EnvBase.rand_action:
+            # The base env does not override rand_action: use the parent sampler.
+            return super().rand_action(tensordict)
+
+        # TODO: an override on the base env cannot be honoured when a transform
+        #  alters the input. Example:
+        #      env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
+        #  Assuming PendulumEnv overrides rand_action, env.rand_action would NOT
+        #  yield a discrete action: producing one would require the inverse of the
+        #  action transform (float->int rather than int->float) to be implemented
+        #  within ActionDiscretizer.
+        #  Comparing the action specs is a loose safeguard: identical specs do not
+        #  prove the action is left intact, but differing specs expose part of
+        #  these alterations.
+        #
+        # NOTE: this comparison may be expensive to run and could be cached.
+        if self.full_action_spec != self.base_env.full_action_spec:
+            error_message = (
+                f"The rand_action method from the base env {self.base_env.__class__.__name__} "
+                "has been overwritten, but the transforms appended to the environment modify "
+                "the action. To call the base env rand_action method, we should then invert the "
+                "action transform, which is (in general) not doable."
+            )
+            raise RuntimeError(error_message)
+        return self.base_env.rand_action(tensordict)
+
     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         # No need to clone here because inv does it already
         # tensordict = tensordict.clone(False)