From 8d5bd03ef68f693e50f9fcbf58ac33559111628b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 26 Jan 2025 18:07:52 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/envs/common.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c3a714fcf91..35f86d78eaf 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2699,6 +2699,7 @@ def specs(self) -> Composite: @property def _has_dynamic_specs(self) -> bool: + # TODO: cache this value return _has_dynamic_specs(self.specs) def rollout( @@ -2711,7 +2712,7 @@ def rollout( auto_cast_to_device: bool = False, break_when_any_done: bool | None = None, break_when_all_done: bool | None = None, - return_contiguous: bool = True, + return_contiguous: bool | None = False, tensordict: Optional[TensorDictBase] = None, set_truncated: bool = False, out=None, @@ -2746,7 +2747,8 @@ def rollout( break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any of the done states. If ``False``, break if at least one environment reaches any of the done states. Default is ``False``. - return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. + return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is `True` if + the env does not have dynamic specs, otherwise `False`. tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial tensordict must be provided. Rollout will check if this tensordict has done flags and reset the environment in those dimensions (if needed). @@ -2957,7 +2959,8 @@ def rollout( raise TypeError( "Cannot have both break_when_all_done and break_when_any_done True at the same time." ) - + if return_contiguous is None: + return_contiguous = not self._has_dynamic_specs if policy is not None: policy = _make_compatible_policy( policy,