Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 21, 2025
2 parents 759ea27 + 9b67242 commit d6b5d68
Show file tree
Hide file tree
Showing 24 changed files with 350 additions and 10 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux_sota/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ export SDL_VIDEODRIVER=dummy
export MUJOCO_GL=egl
export PYOPENGL_PLATFORM=egl
export LAZY_LEGACY_OP=False
export COMPOSITE_LP_AGGREGATE=0

conda env config vars set MUJOCO_PY_MUJOCO_PATH=$root_dir/.mujoco/mujoco210 \
DISPLAY=unix:0.0 \
Expand Down
6 changes: 6 additions & 0 deletions .github/unittest/linux_sota/scripts/test_sota.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
from pathlib import Path

import pytest
from tensordict.nn import composite_lp_aggregate

# Check that we're using the new behavior
assert (
not composite_lp_aggregate()
), "Composite LP must be set to False. Run this test with COMPOSITE_LP_AGGREGATE=0"

commands = {
"dt": """python sota-implementations/decision_transformer/dt.py \
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ on:
workflow_dispatch:

permissions:
id-token: write
deployments: write
contents: write

Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
build-docs:
strategy:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
python-source-and-configs:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/nightly_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
build-wheel-linux:
# Don't run on forked repos.
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test-linux-habitat.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
tests:
strategy:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test-linux-libs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:

unittests-atari-dqn:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test-linux-rlhf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
unittests:
strategy:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test-linux-sota.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
tests:
strategy:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
tests-cpu:
strategy:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test-windows-optdepts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
unittests-cpu:
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/wheels-legacy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:

build-wheel-windows:
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ to be able to create this other composition:
SelectTransform
SignTransform
SqueezeTransform
Stack
StepCounter
TargetReturn
TensorDictPrimer
Expand Down
215 changes: 215 additions & 0 deletions examples/collectors/collector_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Using the SyncDataCollector with Different Device Combinations
==============================================================
TorchRL's SyncDataCollector allows you to specify the devices on which different components of the data collection
process are executed. This example demonstrates how to use the collector with various device combinations.
Understanding Device Precedence
-------------------------------
When creating a SyncDataCollector, you can specify the devices for the environment (env_device), policy (policy_device),
and data collection (device). The device argument serves as a default value for any unspecified devices. However, if you
provide env_device or policy_device, they take precedence over the device argument for their respective components.
For example:
- If you set device="cuda", all components will be executed on the CUDA device unless you specify otherwise.
- If you set env_device="cpu" and device="cuda", the environment will be executed on the CPU, while the policy and data
collection will be executed on the CUDA device.
Keeping Policy Parameters in Sync
---------------------------------
When using a policy with buffers or other attributes that are not automatically updated when moving the policy's
parameters to a different device, it's essential to keep the policy's parameters in sync between the main workspace and
the collector.
To do this, call update_policy_weights_() anytime the policy's parameters (and buffers!) are updated. This ensures that
the policy used by the collector has the same parameters as the policy in the main workspace.
Example Use Cases
-----------------
This script demonstrates the SyncDataCollector with the following device combinations:
- Collector on CUDA
- Collector on CPU
- Mixed collector: policy on CUDA, env untouched (ie, unmarked CPU, env.device == None)
- Mixed collector: policy on CUDA, env on CPU (env.device == "cpu")
- Mixed collector: all on CUDA, except env on CPU.
For each configuration, we run a DQN algorithm and check that it converges.
By following this example, you can learn how to use the SyncDataCollector with different device combinations and ensure
that your policy's parameters are kept in sync.
"""

import logging
import time

import torch.cuda
import torch.nn as nn
import torch.optim as optim

from tensordict.nn import TensorDictSequential as TDSeq

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import Compose, GymEnv, RewardSum, StepCounter, TransformedEnv
from torchrl.modules import EGreedyModule, QValueActor
from torchrl.objectives import DQNLoss, SoftUpdate


logging.basicConfig(level=logging.INFO)
my_logger = logging.getLogger(__name__)

ENV_NAME = "CartPole-v1"

INIT_RND_STEPS = 5_120
FRAMES_PER_BATCH = 128
BUFFER_SIZE = 100_000

GAMMA = 0.98
OPTIM_STEPS = 10
BATCH_SIZE = 128

SOFTU_EPS = 0.99
LR = 0.02


class Net(nn.Module):
def __init__(self, obs_size: int, n_actions: int) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_size, 128),
nn.ReLU(),
nn.Linear(128, n_actions),
)

def forward(self, x):
orig_shape_unbatched = len(x.shape) == 1
if orig_shape_unbatched:
x = x.unsqueeze(0)

out = self.net(x)

if orig_shape_unbatched:
out = out.squeeze(0)
return out


def make_env(env_name: str):
return TransformedEnv(GymEnv(env_name), Compose(StepCounter(), RewardSum()))


if __name__ == "__main__":

for env_device, policy_device, device in (
(None, None, "cuda"),
(None, None, "cpu"),
(None, "cuda", None),
("cpu", "cuda", None),
("cpu", None, "cuda"),
# These configs don't run because the collector needs to know that the policy is on CUDA
# This is not true for the env which has specs that are associated with a device, we can
# automatically transfer the data. The policy does not, in general, have a spec indicating
# what the input and output devices are, so this must be told to the collector.
# (None, None, None),
# ("cpu", None, None),
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)

env = make_env(ENV_NAME)
env.set_seed(0)

n_obs = env.observation_spec["observation"].shape[-1]
n_act = env.action_spec.shape[-1]

net = Net(n_obs, n_act).to(device="cuda:0")
agent = QValueActor(net, spec=env.action_spec.to("cuda:0"))

# policy_explore has buffers on CPU - we will need to call collector.update_policy_weights_()
# to sync them during data collection.
policy_explore = EGreedyModule(env.action_spec)
agent_explore = TDSeq(agent, policy_explore)

collector = SyncDataCollector(
env,
agent_explore,
frames_per_batch=FRAMES_PER_BATCH,
init_random_frames=INIT_RND_STEPS,
device=device,
env_device=env_device,
policy_device=policy_device,
)
exp_buffer = ReplayBuffer(
storage=LazyTensorStorage(BUFFER_SIZE, device="cuda:0")
)

loss = DQNLoss(
value_network=agent, action_space=env.action_spec, delay_value=True
)
loss.make_value_estimator(gamma=GAMMA)
target_updater = SoftUpdate(loss, eps=SOFTU_EPS)
optimizer = optim.Adam(loss.parameters(), lr=LR)

total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
# Check the data devices
if device is None:
assert data["action"].device == torch.device("cuda:0")
assert data["observation"].device == torch.device("cpu")
assert data["done"].device == torch.device("cpu")
elif device == "cpu":
assert data["action"].device == torch.device("cpu")
assert data["observation"].device == torch.device("cpu")
assert data["done"].device == torch.device("cpu")
else:
assert data["action"].device == torch.device("cuda:0")
assert data["observation"].device == torch.device("cuda:0")
assert data["done"].device == torch.device("cuda:0")

exp_buffer.extend(data)
max_length = exp_buffer["next", "step_count"].max()
max_reward = exp_buffer["next", "episode_reward"].max()
if len(exp_buffer) > INIT_RND_STEPS:
for _ in range(OPTIM_STEPS):
optimizer.zero_grad()
sample = exp_buffer.sample(batch_size=BATCH_SIZE)

loss_vals = loss(sample)
loss_vals["loss"].backward()
optimizer.step()

agent_explore[1].step(data.numel())
target_updater.step()

total_count += data.numel()
total_episodes += data["next", "done"].sum()

if i % 10 == 0:
my_logger.info(
f"Step: {i}, max. count / epi reward: {max_length} / {max_reward}."
)
collector.update_policy_weights_()
if max_length > 200:
t1 = time.time()
my_logger.info(f"SOLVED in {t1 - t0}s!! MaxLen: {max_length}!")
my_logger.info(f"With {max_reward} Reward!")
my_logger.info(f"In {total_episodes} Episodes!")
my_logger.info(f"Using devices {(env_device, policy_device, device)}")
break
else:
raise RuntimeError(
f"Failed to converge with config {(env_device, policy_device, device)}"
)
Loading

0 comments on commit d6b5d68

Please sign in to comment.