Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 19, 2024
1 parent e021391 commit 7a7c1aa
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 53 deletions.
39 changes: 20 additions & 19 deletions .github/workflows/nightly_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ on:
branches:
- "nightly"

env:
ACTIONS_RUNNER_FORCED_INTERNAL_NODE_VERSION: node16
ACTIONS_RUNNER_FORCE_ACTIONS_NODE_VERSION: node16
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true # https://github.com/actions/checkout/issues/1809

concurrency:
# Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
# On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke.
Expand All @@ -41,12 +36,15 @@ jobs:
matrix:
python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]]
cuda_support: [["", "cpu", "cpu"]]
container: pytorch/manylinux-${{ matrix.cuda_support[2] }}
steps:
- name: Checkout torchrl
uses: actions/checkout@v3
uses: actions/checkout@v4
env:
AGENT_TOOLSDIRECTORY: "/opt/hostedtoolcache"
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version[0] }}
- name: Install PyTorch nightly
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
Expand All @@ -67,7 +65,7 @@ jobs:
python3 -mpip install auditwheel
auditwheel show dist/*
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl
path: dist/*.whl
Expand All @@ -81,12 +79,15 @@ jobs:
matrix:
python_version: [["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"]]
cuda_support: [["", "cpu", "cpu"]]
container: pytorch/manylinux-${{ matrix.cuda_support[2] }}
steps:
- name: Checkout torchrl
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version[0] }}
- name: Download built wheels
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl
path: /tmp/wheels
Expand Down Expand Up @@ -121,7 +122,7 @@ jobs:
env:
AGENT_TOOLSDIRECTORY: "/opt/hostedtoolcache"
- name: Checkout torchrl
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install PyTorch Nightly
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
Expand All @@ -138,7 +139,7 @@ jobs:
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
python3 -mpip install numpy pytest pillow>=4.1.1 scipy networkx expecttest pyyaml
- name: Download built wheels
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: torchrl-linux-${{ matrix.python_version[0] }}_${{ matrix.cuda_support[2] }}.whl
path: /tmp/wheels
Expand Down Expand Up @@ -179,7 +180,7 @@ jobs:
with:
python-version: ${{ matrix.python_version[1] }}
- name: Checkout torchrl
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install PyTorch nightly
shell: bash
run: |
Expand All @@ -193,7 +194,7 @@ jobs:
--package_name torchrl-nightly \
--python-tag=${{ matrix.python-tag }}
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: torchrl-win-${{ matrix.python_version[0] }}.whl
path: dist/*.whl
Expand All @@ -212,7 +213,7 @@ jobs:
with:
python-version: ${{ matrix.python_version[1] }}
- name: Checkout torchrl
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Install PyTorch Nightly
shell: bash
run: |
Expand All @@ -229,7 +230,7 @@ jobs:
run: |
python3 -mpip install git+https://github.com/pytorch/tensordict.git
- name: Download built wheels
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: torchrl-win-${{ matrix.python_version[0] }}.whl
path: wheels
Expand Down Expand Up @@ -265,9 +266,9 @@ jobs:
python_version: [["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
steps:
- name: Checkout torchrl
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Download built wheels
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: torchrl-win-${{ matrix.python_version[0] }}.whl
path: wheels
Expand Down
105 changes: 80 additions & 25 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import functools
import os

import pytest
Expand All @@ -12,6 +13,7 @@
import torchrl.modules
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from tensordict.utils import assert_close
from torch import nn
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs import (
Expand Down Expand Up @@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based):
@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
@pytest.mark.parametrize("within", [False, True])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous, within):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
num_envs = 3
device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
lstm_module = LSTMModule(
Expand All @@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
else:
cls = SerialEnv

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
if within:

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

else:
create_transformed_env = functools.partial(
DiscreteActionVecMockEnv,
categorical_action_encoding=True,
device=device,
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env) for _ in range(num_envs)
]
env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
num_workers=num_envs,
)
if not within:
env = env.append_transform(InitTracker())
env.append_transform(lstm_module.make_tensordict_primer())

mlp = TensorDictModule(
MLP(
Expand All @@ -1002,6 +1017,19 @@ def create_transformed_env():
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
assert (data.get("recurrent_state_c") != 0.0).any()
return data

@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_within(self, python_based, parallel, heterogeneous):
out_within = self.test_lstm_parallel_env(
python_based, parallel, heterogeneous, within=True
)
out_not_within = self.test_lstm_parallel_env(
python_based, parallel, heterogeneous, within=False
)
assert_close(out_within, out_not_within)

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
Expand Down Expand Up @@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based):
@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
@pytest.mark.parametrize("within", [False, True])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous, within):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
num_workers = 3

device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
Expand All @@ -1347,30 +1377,42 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
python_based=python_based,
)

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
if within:

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

else:
create_transformed_env = functools.partial(
DiscreteActionVecMockEnv,
categorical_action_encoding=True,
device=device,
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env) for _ in range(num_workers)
]

env = cls(
env: ParallelEnv | SerialEnv = cls(
create_env_fn=create_transformed_env,
num_workers=2,
num_workers=num_workers,
)
if not within:
primer = gru_module.make_tensordict_primer()
env = env.append_transform(InitTracker())
env.append_transform(primer)

mlp = TensorDictModule(
MLP(
Expand All @@ -1396,6 +1438,19 @@ def create_transformed_env():
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get("recurrent_state") != 0.0).any()
assert (data.get(("next", "recurrent_state")) != 0.0).all()
return data

@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_within(self, python_based, parallel, heterogeneous):
out_within = self.test_gru_parallel_env(
python_based, parallel, heterogeneous, within=True
)
out_not_within = self.test_gru_parallel_env(
python_based, parallel, heterogeneous, within=False
)
assert_close(out_within, out_not_within)

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
Expand Down
36 changes: 32 additions & 4 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1744,14 +1744,39 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We keep track of which keys are present to let the worker know what
# should be passed to the env (we don't want to pass done states for instance)
next_td_keys = list(next_td_passthrough.keys(True, True))
next_shared_tensordict_parent = shared_tensordict_parent.get("next")

# We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset.
# The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of
# the batched env but part of the specs of a transformed batched env.
# If that is the case, `update_` will fail to find the entries to update.
# What we do instead is keeping the tensors on the side and putting them back after completing _step.
keys_to_update, keys_to_copy = zip(
*[
(key, None)
if key in next_shared_tensordict_parent.keys(True, True)
else (None, key)
for key in next_td_keys
]
)
keys_to_update = [key for key in keys_to_update if key is not None]
keys_to_copy = [key for key in keys_to_copy if key is not None]
data = [
{"next_td_passthrough_keys": next_td_keys}
{"next_td_passthrough_keys": keys_to_update}
for _ in range(self.num_workers)
]
shared_tensordict_parent.get("next").update_(
next_td_passthrough, non_blocking=self.non_blocking
)
if keys_to_update:
next_shared_tensordict_parent.update_(
next_td_passthrough,
non_blocking=self.non_blocking,
keys_to_update=keys_to_update,
)
if keys_to_copy:
next_td_passthrough = next_td_passthrough.select(*keys_to_copy)
else:
next_td_passthrough = None
else:
next_td_passthrough = None
data = [{} for _ in range(self.num_workers)]

if self._non_tensor_keys:
Expand Down Expand Up @@ -1807,6 +1832,9 @@ def select_and_clone(name, tensor):
LazyStackedTensorDict(*non_tensor_tds),
keys_to_update=self._non_tensor_keys,
)
if next_td_passthrough is not None:
out.update(next_td_passthrough)

self._sync_w2m()
if partial_steps is not None:
result = out.new_zeros(tensordict_save.shape)
Expand Down
14 changes: 9 additions & 5 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5089,12 +5089,16 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
)

if self.primers.shape != observation_spec.shape:
try:
# We try to set the primer shape to the observation spec shape
self.primers.shape = observation_spec.shape
except ValueError:
# If we fail, we expand them to that shape
if self.primers.shape == () and self.parent.batch_size != ():
self.primers = self._expand_shape(self.primers)
else:
try:
# We try to set the primer shape to the observation spec shape
self.primers.shape = observation_spec.shape
except ValueError:
# If we fail, we expand them to that shape
self.primers = self._expand_shape(self.primers)

device = observation_spec.device
observation_spec.update(self.primers.clone().to(device))
return observation_spec
Expand Down
Loading

0 comments on commit 7a7c1aa

Please sign in to comment.