Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 27, 2025
1 parent 16d7e60 commit 8d5bd03
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,7 @@ def specs(self) -> Composite:

@property
def _has_dynamic_specs(self) -> bool:
# TODO: cache this value
return _has_dynamic_specs(self.specs)

def rollout(
Expand All @@ -2711,7 +2712,7 @@ def rollout(
auto_cast_to_device: bool = False,
break_when_any_done: bool | None = None,
break_when_all_done: bool | None = None,
return_contiguous: bool = True,
return_contiguous: bool | None = False,
tensordict: Optional[TensorDictBase] = None,
set_truncated: bool = False,
out=None,
Expand Down Expand Up @@ -2746,7 +2747,8 @@ def rollout(
break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
of the done states. If ``False``, break if at least one environment reaches any of the done states.
Default is ``False``.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is `True` if
the env does not have dynamic specs, otherwise `False`.
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
environment in those dimensions (if needed).
Expand Down Expand Up @@ -2957,7 +2959,8 @@ def rollout(
raise TypeError(
"Cannot have both break_when_all_done and break_when_any_done True at the same time."
)

if return_contiguous is None:
return_contiguous = not self._has_dynamic_specs
if policy is not None:
policy = _make_compatible_policy(
policy,
Expand Down

0 comments on commit 8d5bd03

Please sign in to comment.