Skip to content

Commit

Permalink
[BugFix] patch rand_action in TransformedEnv to read the base_env method
Browse files Browse the repository at this point in the history
ghstack-source-id: 04e2e85e2675cf34c349ebadb8fa85a5aff2e532
Pull Request resolved: #2699
  • Loading branch information
vmoens committed Jan 26, 2025
1 parent ec370c6 commit 2c19fcc
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,28 @@ def input_spec(self) -> TensorSpec:
input_spec = self.__dict__.get("_input_spec", None)
return input_spec

def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict:
if type(self.base_env).rand_action is not EnvBase.rand_action:
# TODO: this will fail if the transform modifies the input.
# For instance, if an env overrides rand_action and we build a
# env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
# env.rand_action will NOT have a discrete action!
# Getting a discrete action would require coding the inverse transform of an action within
# ActionDiscretizer (ie, float->int, not int->float).
# We can loosely check that the action_spec isn't altered - that doesn't mean the action is
# intact but it covers part of these alterations.
#
# The following check may be expensive to run and could be cached.
if self.full_action_spec != self.base_env.full_action_spec:
raise RuntimeError(
f"The rand_action method from the base env {self.base_env.__class__.__name__} "
"has been overwritten, but the transforms appended to the environment modify "
"the action. To call the base env rand_action method, we should then invert the "
"action transform, which is (in general) not doable."
)
return self.base_env.rand_action(tensordict)
return super().rand_action(tensordict)

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# No need to clone here because inv does it already
# tensordict = tensordict.clone(False)
Expand Down

0 comments on commit 2c19fcc

Please sign in to comment.