From 1f6c7f267718728a13105cb6003826cfe83f1457 Mon Sep 17 00:00:00 2001
From: "Thomas B. Brunner"
Date: Mon, 11 Nov 2024 13:59:51 +0100
Subject: [PATCH] [Doc] Minor fixes to the docs and type hints (#2548)

(cherry picked from commit 50a35f69bf3b2d0930fb4dada76acf8e3b84e899)
---
 sota-implementations/ppo/ppo_atari.py |  2 +-
 torchrl/envs/common.py                | 25 ++++++++++++++-----------
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 6d8883393d5..276c706baef 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -5,7 +5,7 @@
 
 """
 This script reproduces the Proximal Policy Optimization (PPO) Algorithm
-results from Schulman et al. 2017 for the on Atari Environments.
+results from Schulman et al. 2017 for the Atari Environments.
 """
 import hydra
 from torchrl._utils import logger as torchrl_logger
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 98a35628c40..76a4bbf86dc 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -516,7 +516,7 @@ def append_transform(
         self,
         transform: "Transform"  # noqa: F821
         | Callable[[TensorDictBase], TensorDictBase],
-    ) -> None:
+    ) -> EnvBase:
         """Returns a transformed environment where the callable/transform passed is applied.
 
         Args:
@@ -1482,7 +1482,8 @@ def full_state_spec(self, spec: Composite) -> None:
 
     # Single-env specs can be used to remove the batch size from the spec
     @property
-    def batch_dims(self):
+    def batch_dims(self) -> int:
+        """Number of batch dimensions of the env."""
         return len(self.batch_size)
 
     def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
@@ -2425,11 +2426,11 @@ def rollout(
         set_truncated: bool = False,
         out=None,
         trust_policy: bool = False,
-    ):
+    ) -> TensorDictBase:
         """Executes a rollout in the environment.
 
-        The function will stop as soon as one of the contained environments
-        returns done=True.
+        The function will return as soon as any of the contained environments
+        reaches any of the done states.
 
         Args:
             max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if
@@ -2445,14 +2446,16 @@ def rollout(
                 the call to ``rollout``.
 
         Keyword Args:
-            auto_reset (bool, optional): if ``True``, resets automatically the environment
-                if it is in a done state when the rollout is initiated.
-                Default is ``True``.
+            auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the
+                rollout. If ``False``, then the rollout will continue from a previous state, which requires the
+                ``tensordict`` argument to be passed with the previous rollout. Default is ``True``.
             auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to
                 the policy device before the policy is used. Default is ``False``.
-            break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
-                called on the sub-envs that are done. Default is True.
-            break_when_all_done (bool): TODO
+            break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the
+                done states. If ``False``, then the done environments are reset automatically. Default is ``True``.
+            break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
+                of the done states. If ``False``, break if at least one environment reaches any of the done states.
+                Default is ``False``.
             return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
             tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial tensordict must be provided.
                 Rollout will check if this tensordict has done flags and reset the
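
Note (not part of the patch): a minimal usage sketch of the semantics documented
above. The "CartPole-v1" environment, the SerialEnv/StepCounter setup and the
policy-free rollouts are illustrative assumptions, not taken from the patch.

    from torchrl.envs import GymEnv, SerialEnv, StepCounter

    # append_transform() returns an EnvBase (the type hint fixed above), so the
    # call can be chained directly onto the environment constructor.
    env = SerialEnv(2, lambda: GymEnv("CartPole-v1")).append_transform(StepCounter())

    # break_when_any_done=True (the default): the rollout returns as soon as
    # any of the two sub-envs reaches any of the done states.
    td = env.rollout(max_steps=100, break_when_any_done=True)

    # break_when_any_done=False: done sub-envs are reset automatically and the
    # rollout runs for the full max_steps.
    td = env.rollout(max_steps=100, break_when_any_done=False)

    # break_when_all_done=True: break only once all of the contained sub-envs
    # have reached any of the done states.
    td = env.rollout(max_steps=100, break_when_any_done=False, break_when_all_done=True)

    # auto_reset=False: the caller resets and passes the initial tensordict in.
    init_td = env.reset()
    td = env.rollout(max_steps=100, auto_reset=False, tensordict=init_td)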