From 7a7c1aa2773dc60371f9c1ea6e88435ee64d4482 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 19 Dec 2024 12:56:01 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- .github/workflows/nightly_build.yml | 39 +++++---- test/test_tensordictmodules.py | 105 +++++++++++++++++------ torchrl/envs/batched_envs.py | 36 +++++++- torchrl/envs/transforms/transforms.py | 14 +-- torchrl/modules/tensordict_module/rnn.py | 8 ++ 5 files changed, 149 insertions(+), 53 deletions(-) diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index 08eb61bfa6c..732077f4b58 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -21,11 +21,6 @@ on: branches: - "nightly" -env: - ACTIONS_RUNNER_FORCED_INTERNAL_NODE_VERSION: node16 - ACTIONS_RUNNER_FORCE_ACTIONS_NODE_VERSION: node16 - ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true # https://github.com/actions/checkout/issues/1809 - concurrency: # Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}. # On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke. @@ -41,12 +36,15 @@ jobs: matrix: python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] - container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 env: AGENT_TOOLSDIRECTORY: "/opt/hostedtoolcache" + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version[0] }} - name: Install PyTorch nightly run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" @@ -67,7 +65,7 @@ jobs: python3 -mpip install auditwheel auditwheel show dist/* - name: Upload wheel for the test-wheel job - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: dist/*.whl @@ -81,12 +79,15 @@ jobs: matrix: python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]] cuda_support: [["", "cpu", "cpu"]] - container: pytorch/manylinux-${{ matrix.cuda_support[2] }} steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version[0] }} - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: /tmp/wheels @@ -121,7 +122,7 @@ jobs: env: AGENT_TOOLSDIRECTORY: "/opt/hostedtoolcache" - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch Nightly run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" @@ -138,7 +139,7 @@ jobs: export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install numpy pytest pillow>=4.1.1 scipy networkx expecttest pyyaml - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl path: /tmp/wheels @@ -179,7 +180,7 @@ jobs: with: python-version: ${{ matrix.python_version[1] }} - name: 
Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch nightly shell: bash run: | @@ -193,7 +194,7 @@ jobs: --package_name torchrl-nightly \ --python-tag=${{ matrix.python-tag }} - name: Upload wheel for the test-wheel job - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: dist/*.whl @@ -212,7 +213,7 @@ jobs: with: python-version: ${{ matrix.python_version[1] }} - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install PyTorch Nightly shell: bash run: | @@ -229,7 +230,7 @@ jobs: run: | python3 -mpip install git+https://github.com/pytorch/tensordict.git - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: wheels @@ -265,9 +266,9 @@ jobs: python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]] steps: - name: Checkout torchrl - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Download built wheels - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: torchrl-win-${{ matrix.python_version[0] }}.whl path: wheels diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py index d3b7b7850f4..c2a34f3797d 100644 --- a/test/test_tensordictmodules.py +++ b/test/test_tensordictmodules.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import functools import os import pytest @@ -12,6 +13,7 @@ import torchrl.modules from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential +from tensordict.utils import assert_close from torch import nn from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs import ( @@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based): @pytest.mark.parametrize("python_based", [True, False]) @pytest.mark.parametrize("parallel", [True, False]) @pytest.mark.parametrize("heterogeneous", [True, False]) - def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): + @pytest.mark.parametrize("within", [False, True]) + def test_lstm_parallel_env(self, python_based, parallel, heterogeneous, within): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv torch.manual_seed(0) + num_envs = 3 device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs lstm_module = LSTMModule( @@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous): else: cls = SerialEnv - def create_transformed_env(): - primer = lstm_module.make_tensordict_primer() - env = DiscreteActionVecMockEnv( - categorical_action_encoding=True, device=device + if within: + + def create_transformed_env(): + primer = lstm_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv( + categorical_action_encoding=True, device=device + ) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + else: + create_transformed_env = functools.partial( + DiscreteActionVecMockEnv, + categorical_action_encoding=True, + device=device, ) - env = TransformedEnv(env) - env.append_transform(InitTracker()) - env.append_transform(primer) - return env if heterogeneous: create_transformed_env = [ - 
EnvCreator(create_transformed_env), - EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env) for _ in range(num_envs) ] env = cls( create_env_fn=create_transformed_env, - num_workers=2, + num_workers=num_envs, ) + if not within: + env = env.append_transform(InitTracker()) + env.append_transform(lstm_module.make_tensordict_primer()) mlp = TensorDictModule( MLP( @@ -1002,6 +1017,19 @@ def create_transformed_env(): data = env.rollout(10, actor, break_when_any_done=break_when_any_done) assert (data.get(("next", "recurrent_state_c")) != 0.0).all() assert (data.get("recurrent_state_c") != 0.0).any() + return data + + @pytest.mark.parametrize("python_based", [True, False]) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_lstm_parallel_within(self, python_based, parallel, heterogeneous): + out_within = self.test_lstm_parallel_env( + python_based, parallel, heterogeneous, within=True + ) + out_not_within = self.test_lstm_parallel_env( + python_based, parallel, heterogeneous, within=False + ) + assert_close(out_within, out_not_within) @pytest.mark.skipif( not _has_functorch, reason="vmap can only be used with functorch" @@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based): @pytest.mark.parametrize("python_based", [True, False]) @pytest.mark.parametrize("parallel", [True, False]) @pytest.mark.parametrize("heterogeneous", [True, False]) - def test_gru_parallel_env(self, python_based, parallel, heterogeneous): + @pytest.mark.parametrize("within", [False, True]) + def test_gru_parallel_env(self, python_based, parallel, heterogeneous, within): from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv torch.manual_seed(0) + num_workers = 3 device = "cuda" if torch.cuda.device_count() else "cpu" # tests that hidden states are carried over with parallel envs @@ -1347,15 +1377,24 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous): python_based=python_based, ) - def create_transformed_env(): - primer = gru_module.make_tensordict_primer() - env = DiscreteActionVecMockEnv( - categorical_action_encoding=True, device=device + if within: + + def create_transformed_env(): + primer = gru_module.make_tensordict_primer() + env = DiscreteActionVecMockEnv( + categorical_action_encoding=True, device=device + ) + env = TransformedEnv(env) + env.append_transform(InitTracker()) + env.append_transform(primer) + return env + + else: + create_transformed_env = functools.partial( + DiscreteActionVecMockEnv, + categorical_action_encoding=True, + device=device, ) - env = TransformedEnv(env) - env.append_transform(InitTracker()) - env.append_transform(primer) - return env if parallel: cls = ParallelEnv @@ -1363,14 +1402,17 @@ def create_transformed_env(): cls = SerialEnv if heterogeneous: create_transformed_env = [ - EnvCreator(create_transformed_env), - EnvCreator(create_transformed_env), + EnvCreator(create_transformed_env) for _ in range(num_workers) ] - env = cls( + env: ParallelEnv | SerialEnv = cls( create_env_fn=create_transformed_env, - num_workers=2, + num_workers=num_workers, ) + if not within: + primer = gru_module.make_tensordict_primer() + env = env.append_transform(InitTracker()) + env.append_transform(primer) mlp = TensorDictModule( MLP( @@ -1396,6 +1438,19 @@ def create_transformed_env(): data = env.rollout(10, actor, break_when_any_done=break_when_any_done) assert (data.get("recurrent_state") != 0.0).any() assert (data.get(("next", "recurrent_state")) != 0.0).all() + 
return data + + @pytest.mark.parametrize("python_based", [True, False]) + @pytest.mark.parametrize("parallel", [True, False]) + @pytest.mark.parametrize("heterogeneous", [True, False]) + def test_gru_parallel_within(self, python_based, parallel, heterogeneous): + out_within = self.test_gru_parallel_env( + python_based, parallel, heterogeneous, within=True + ) + out_not_within = self.test_gru_parallel_env( + python_based, parallel, heterogeneous, within=False + ) + assert_close(out_within, out_not_within) @pytest.mark.skipif( not _has_functorch, reason="vmap can only be used with functorch" diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 17bd28c8390..f7a25c1bd5c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1744,14 +1744,39 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We keep track of which keys are present to let the worker know what # should be passed to the env (we don't want to pass done states for instance) next_td_keys = list(next_td_passthrough.keys(True, True)) + next_shared_tensordict_parent = shared_tensordict_parent.get("next") + + # We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset. + # The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of + # the batched env but part of the specs of a transformed batched env. + # If that is the case, `update_` will fail to find the entries to update. + # What we do instead is keeping the tensors on the side and putting them back after completing _step. + keys_to_update, keys_to_copy = zip( + *[ + (key, None) + if key in next_shared_tensordict_parent.keys(True, True) + else (None, key) + for key in next_td_keys + ] + ) + keys_to_update = [key for key in keys_to_update if key is not None] + keys_to_copy = [key for key in keys_to_copy if key is not None] data = [ - {"next_td_passthrough_keys": next_td_keys} + {"next_td_passthrough_keys": keys_to_update} for _ in range(self.num_workers) ] - shared_tensordict_parent.get("next").update_( - next_td_passthrough, non_blocking=self.non_blocking - ) + if keys_to_update: + next_shared_tensordict_parent.update_( + next_td_passthrough, + non_blocking=self.non_blocking, + keys_to_update=keys_to_update, + ) + if keys_to_copy: + next_td_passthrough = next_td_passthrough.select(*keys_to_copy) + else: + next_td_passthrough = None else: + next_td_passthrough = None data = [{} for _ in range(self.num_workers)] if self._non_tensor_keys: @@ -1807,6 +1832,9 @@ def select_and_clone(name, tensor): LazyStackedTensorDict(*non_tensor_tds), keys_to_update=self._non_tensor_keys, ) + if next_td_passthrough is not None: + out.update(next_td_passthrough) + self._sync_w2m() if partial_steps is not None: result = out.new_zeros(tensordict_save.shape) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f3329d085df..14d4133412c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5089,12 +5089,16 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: ) if self.primers.shape != observation_spec.shape: - try: - # We try to set the primer shape to the observation spec shape - self.primers.shape = observation_spec.shape - except ValueError: - # If we fail, we expand them to that shape + if self.primers.shape == () and self.parent.batch_size != (): self.primers = self._expand_shape(self.primers) + else: + try: + # We try to set the primer 
shape to the observation spec shape + self.primers.shape = observation_spec.shape + except ValueError: + # If we fail, we expand them to that shape + self.primers = self._expand_shape(self.primers) + device = observation_spec.device observation_spec.update(self.primers.clone().to(device)) return observation_spec diff --git a/torchrl/modules/tensordict_module/rnn.py b/torchrl/modules/tensordict_module/rnn.py index f4ceb648665..57bcac94cf4 100644 --- a/torchrl/modules/tensordict_module/rnn.py +++ b/torchrl/modules/tensordict_module/rnn.py @@ -592,6 +592,10 @@ def make_tensordict_primer(self): inputs and outputs (recurrent states) during rollout execution. That way, the data can be shared across processes and dealt with properly. + When using batched environments such as :class:`~torchrl.envs.ParallelEnv`, the transform can be used at the + single env instance level (i.e., a batch of transformed envs with tensordict primers set within) or at the + batched env instance level (i.e., a transformed batch of regular envs). + Not including a ``TensorDictPrimer`` in the environment may result in poorly defined behaviors, for instance in parallel settings where a step involves copying the new recurrent state from ``"next"`` to the root tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states @@ -1410,6 +1414,10 @@ def make_tensordict_primer(self): tensordict, which the meth:`~torchrl.EnvBase.step_mdp` method will not be able to do as the recurrent states are not registered within the environment specs. + When using batched environments such as :class:`~torchrl.envs.ParallelEnv`, the transform can be used at the + single env instance level (i.e., a batch of transformed envs with tensordict primers set within) or at the + batched env instance level (i.e., a transformed batch of regular envs). + See :func:`torchrl.modules.utils.get_primers_from_module` for a method to generate all primers for a given module.
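
Editor's note (not part of the patch): the docstring paragraphs added above state that the primer returned by ``make_tensordict_primer()`` can be attached either inside each worker env or on the batched env itself, which is exactly what the new ``within=True`` / ``within=False`` branches of the tests exercise. The sketch below mirrors those two branches for readers who want a standalone illustration. It is a hedged example, not code from this patch: ``GymEnv("CartPole-v1")``, its 4-dimensional observation, and ``hidden_size=64`` are illustrative assumptions.

    # Sketch only: the two primer placements described in the docstring change.
    # GymEnv("CartPole-v1"), input_size=4 and hidden_size=64 are assumptions.
    from torchrl.envs import GymEnv, InitTracker, ParallelEnv, TransformedEnv
    from torchrl.modules import LSTMModule

    lstm_module = LSTMModule(
        input_size=4, hidden_size=64, in_key="observation", out_key="features"
    )

    # Option 1: primer set within each worker env
    # (a batch of transformed envs, the tests' within=True branch).
    def make_env():
        env = TransformedEnv(GymEnv("CartPole-v1"))
        env.append_transform(InitTracker())
        env.append_transform(lstm_module.make_tensordict_primer())
        return env

    env_within = ParallelEnv(2, make_env)

    # Option 2: primer set on the batched env
    # (a transformed batch of regular envs, the tests' within=False branch).
    env_batched = ParallelEnv(2, lambda: GymEnv("CartPole-v1"))
    env_batched = env_batched.append_transform(InitTracker())
    env_batched.append_transform(lstm_module.make_tensordict_primer())

Both placements should yield rollouts that carry the recurrent states across steps; the new ``test_lstm_parallel_within`` and ``test_gru_parallel_within`` tests check this by comparing the two variants with ``assert_close``.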