Minor fixes to the docs and type hints
thomasbbrunner authored and vmoens committed Nov 11, 2024
1 parent 58c3847 commit 91d2c07
Showing 2 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_atari.py
@@ -5,7 +5,7 @@

"""
This script reproduces the Proximal Policy Optimization (PPO) Algorithm
-results from Schulman et al. 2017 for the on Atari Environments.
+results from Schulman et al. 2017 for the Atari Environments.
"""
import hydra
from torchrl._utils import logger as torchrl_logger
25 changes: 14 additions & 11 deletions torchrl/envs/common.py
@@ -516,7 +516,7 @@ def append_transform(
self,
transform: "Transform" # noqa: F821
| Callable[[TensorDictBase], TensorDictBase],
-) -> None:
+) -> EnvBase:
"""Returns a transformed environment where the callable/transform passed is applied.
Args:
@@ -1482,7 +1482,8 @@ def full_state_spec(self, spec: Composite) -> None:

# Single-env specs can be used to remove the batch size from the spec
@property
-def batch_dims(self):
+def batch_dims(self) -> int:
+"""Number of batch dimensions of the env."""
return len(self.batch_size)

def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
@@ -2444,11 +2445,11 @@ def rollout(
set_truncated: bool = False,
out=None,
trust_policy: bool = False,
-):
+) -> TensorDictBase:
"""Executes a rollout in the environment.
-The function will stop as soon as one of the contained environments
-returns done=True.
+The function will return as soon as any of the contained environments
+reaches any of the done states.
Args:
max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if
Expand All @@ -2464,14 +2465,16 @@ def rollout(
the call to ``rollout``.
Keyword Args:
-auto_reset (bool, optional): if ``True``, resets automatically the environment
-if it is in a done state when the rollout is initiated.
-Default is ``True``.
+auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the
+rollout. If ``False``, then the rollout will continue from a previous state, which requires the
+``tensordict`` argument to be passed with the previous rollout. Default is ``True``.
auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
policy device before the policy is used. Default is ``False``.
-break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
-called on the sub-envs that are done. Default is True.
-break_when_all_done (bool): TODO
+break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the
+done states. If ``False``, then the done environments are reset automatically. Default is ``True``.
+break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
+of the done states. If ``False``, break if at least one environment reaches any of the done states.
+Default is ``False``.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
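Note on the reworked ``rollout`` docstring above: the change mainly clarifies how ``auto_reset``, ``break_when_any_done`` and ``break_when_all_done`` interact. The following is a minimal sketch of those flags, not part of the commit; it assumes the Gym backend is installed, and uses a two-env ``SerialEnv`` (an illustrative choice) so that sub-envs can reach done states at different times.

    from torchrl.envs import GymEnv, SerialEnv

    env = SerialEnv(2, lambda: GymEnv("CartPole-v1"))

    # Defaults: sub-envs are reset first (auto_reset=True) and the rollout
    # returns as soon as any sub-env reaches a done state (break_when_any_done=True).
    data = env.rollout(max_steps=100)

    # Keep collecting past individual dones: done sub-envs are reset automatically.
    data = env.rollout(max_steps=100, break_when_any_done=False)

    # Continue from an explicit starting state: pass a tensordict and disable auto_reset.
    start = env.reset()
    data = env.rollout(max_steps=100, auto_reset=False, tensordict=start)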

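Note on the ``append_transform`` signature change earlier in this file (return annotation changed from ``None`` to ``EnvBase``): the method returns the transformed environment, so the call can be written fluently. A minimal sketch, assuming the Gym backend and the ``StepCounter`` transform are available (neither is part of this commit):

    from torchrl.envs import GymEnv
    from torchrl.envs.transforms import StepCounter

    # append_transform returns the transformed environment, so the result can
    # be used directly after construction.
    env = GymEnv("CartPole-v1").append_transform(StepCounter())
    data = env.rollout(max_steps=10)
    print(data["step_count"])  # entry written by the StepCounter transform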