Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Aug 3, 2024
2 parents 8470093 + b66da41 commit c82f6e3
Show file tree
Hide file tree
Showing 63 changed files with 2,118 additions and 610 deletions.
3 changes: 3 additions & 0 deletions .github/scripts/version_script.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
@echo off
set TORCHRL_BUILD_VERSION=0.5.0
echo TORCHRL_BUILD_VERSION is set to %TORCHRL_BUILD_VERSION%
3 changes: 2 additions & 1 deletion .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies:
- mlflow
- av
- coverage
- ray<2.8.0
- ray
- transformers
- ninja
- timm
3 changes: 2 additions & 1 deletion .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ conda deactivate
conda activate "${env_dir}"

echo "installing gymnasium"
pip3 install "gymnasium[atari,ale-py,accept-rom-license]"
pip3 install "gymnasium"
pip3 install ale_py
pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
pip3 install mujoco -U

Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux_distributed/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ dependencies:
- mlflow
- av
- coverage
- ray<2.8.0
- ray
- virtualenv
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ dependencies:
- dm_control -e git+https://github.com/deepmind/dm_control.git@c053360edea6170acfd9c8f65446703307d9d352#egg={dm_control}
- patchelf
- pyopengl==3.1.4
- ray<2.8.0
- ray
- av
2 changes: 1 addition & 1 deletion .github/unittest/linux_optdeps/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ dependencies:
- pyyaml
- scipy
- coverage
- ray<2.8.0
- ray
1 change: 1 addition & 0 deletions .github/unittest/windows_optdepts/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ fi

# submodules
git submodule sync && git submodule update --init --recursive
python -m pip install "numpy<2.0"

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/build-wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ jobs:
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/td_script.sh
pre-script: .github/scripts/td_script.sh
env-script: .github/scripts/version_script.bat
4 changes: 2 additions & 2 deletions .github/workflows/test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
tests-cpu:
strategy:
matrix:
python_version: ["3.8", "3.9", "3.10", "3.11"]
python_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
Expand Down Expand Up @@ -51,7 +51,7 @@ jobs:
tests-gpu:
strategy:
matrix:
python_version: ["3.10"]
python_version: ["3.11"]
cuda_arch_version: ["12.1"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
Expand Down
103 changes: 103 additions & 0 deletions .github/workflows/wheels-legacy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
name: Wheels
on:
pull_request:
types: [opened, synchronize, reopened]
push:
branches:
- release/*

concurrency:
# Documentation suggests ${{ github.head_ref }}, but that's only available on pull_request/pull_request_target triggers, so using ${{ github.ref }}.
# On master, we want all builds to complete even if merging happens faster to make it easier to discover at which point something broke.
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }}
cancel-in-progress: true

jobs:

build-wheel-windows:
runs-on: windows-latest
strategy:
matrix:
python_version: [["3.8", "3.8"], ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"]]
steps:
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python_version[1] }}
- name: Checkout torchrl
uses: actions/checkout@v2
- name: Install PyTorch RC
shell: bash
run: |
python3 -mpip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build wheel
shell: bash
run: |
python3 -mpip install wheel
TORCHRL_BUILD_VERSION=0.5.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
with:
name: torchrl-win-${{ matrix.python_version[0] }}.whl
path: dist/torchrl-*.whl
- name: Upload wheel for download
uses: actions/upload-artifact@v2
with:
name: torchrl-batch.whl
path: dist/*.whl

test-wheel-windows:
needs: build-wheel-windows
strategy:
matrix:
python_version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
runs-on: windows-latest
steps:
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python_version }}
- name: Checkout torchrl
uses: actions/checkout@v2
- name: Install PyTorch RC
shell: bash
run: |
python3 -mpip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
- name: Upgrade pip
shell: bash
run: |
python3 -mpip install --upgrade pip
- name: Install tensordict
shell: bash
run: |
python3 -mpip install git+https://github.com/pytorch/tensordict.git
- name: Install test dependencies
shell: bash
run: |
python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml
- name: Download built wheels
uses: actions/download-artifact@v2
with:
name: torchrl-win-${{ matrix.python_version }}.whl
path: wheels
- name: Install built wheels
shell: bash
run: |
python3 -mpip install wheels/*
- name: Log version string
shell: bash
run: |
# Avoid ambiguity of "import torchrl" by deleting the source files.
rm -rf torchrl/
python -c "import torchrl; print(torchrl.__version__)"
- name: Run tests
shell: bash
run: |
set -e
export IN_CI=1
mkdir test-reports
python -m torch.utils.collect_env
python -c "import torchrl; print(torchrl.__version__)"
EXIT_STATUS=0
pytest test/smoke_test.py -v --durations 200
exit $EXIT_STATUS
42 changes: 34 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,33 @@

**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.

It provides pytorch and **python-first**, low and high level abstractions for RL that are intended to be **efficient**, **modular**, **documented** and properly **tested**.
The code is aimed at supporting research in RL. Most of it is written in python in a highly modular way, such that researchers can easily swap components, transform them or write new ones with little effort.
## Key features

This repo attempts to align with the existing pytorch ecosystem libraries in that it has a dataset pillar ([torchrl/envs](https://github.com/pytorch/rl/blob/main/torchrl/envs)), [transforms](https://github.com/pytorch/rl/blob/main/torchrl/envs/transforms), [models](https://github.com/pytorch/rl/blob/main/torchrl/modules), data utilities (e.g. collectors and containers), etc.
TorchRL aims at having as few dependencies as possible (python standard library, numpy and pytorch). Common environment libraries (e.g. OpenAI gym) are only optional.
- 🐍 **Python-first**: Designed with Python as the primary language for ease of use and flexibility
- ⏱️ **Efficient**: Optimized for performance to support demanding RL research applications
- 🧮 **Modular, customizable, extensible**: Highly modular architecture allows for easy swapping, transformation, or creation of new components
- 📚 **Documented**: Thorough documentation ensures that users can quickly understand and utilize the library
- ✅ **Tested**: Rigorously tested to ensure reliability and stability
- ⚙️ **Reusable functionals**: Provides a set of highly reusable functions for cost functions, returns, and data processing

On the low-level end, torchrl comes with a set of highly re-usable functionals for cost functions, returns and data processing.
### Design Principles

TorchRL aims at (1) a high modularity and (2) good runtime performance. Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library.
- 🔥 **Aligns with PyTorch ecosystem**: Follows the structure and conventions of popular PyTorch libraries
(e.g., dataset pillar, transforms, models, data utilities)
- ➖ **Minimal dependencies**: Only requires Python standard library, NumPy, and PyTorch; optional dependencies for
common environment libraries (e.g., OpenAI Gym) and datasets (D4RL, OpenX...)

Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library.

## Getting started

Check our [Getting Started tutorials](https://pytorch.org/rl/stable/index.html#getting-started) to quickly ramp up with the basic
features of the library!

<p align="center">
<img src="docs/ppo.png" width="800" >
</p>

## Documentation and knowledge base

The TorchRL documentation can be found [here](https://pytorch.org/rl).
Expand All @@ -48,9 +60,23 @@ learn the basics of RL. Check it out [here](https://pytorch.org/rl/stable/refere

We have some introductory videos for you to get to know the library better, check them out:

- [TalkRL podcast](https://www.talkrl.com/episodes/vincent-moens-on-torchrl)
- [TorchRL intro at PyTorch day 2022](https://youtu.be/cIKMhZoykEE)
- [PyTorch 2.0 Q&A: TorchRL](https://www.youtube.com/live/myEfUoYrbts?feature=share)

## Spotlight publications

Because TorchRL is domain-agnostic, you can use it across many different fields. Here are a few examples:

- [ACEGEN](https://pubs.acs.org/doi/10.1021/acs.jcim.4c00895): Reinforcement Learning of Generative Chemical Agents
for Drug Discovery
- [BenchMARL](https://www.jmlr.org/papers/v25/23-1612.html): Benchmarking Multi-Agent Reinforcement Learning
- [BricksRL](https://arxiv.org/abs/2406.17490): A Platform for Democratizing Robotics and Reinforcement Learning
Research and Education with LEGO
- [OmniDrones](https://ieeexplore.ieee.org/abstract/document/10409589): An Efficient and Flexible Platform for Reinforcement Learning in Drone Control
- [RL4CO](https://arxiv.org/abs/2306.17100): an Extensive Reinforcement Learning for Combinatorial Optimization Benchmark
- [Robohive](https://proceedings.neurips.cc/paper_files/paper/2023/file/8a84a4341c375b8441b36836bb343d4e-Paper-Datasets_and_Benchmarks.pdf): A unified framework for robot learning

## Writing simplified and portable RL codebase with `TensorDict`

RL algorithms are very heterogeneous, and it can be hard to recycle a codebase
Expand Down Expand Up @@ -559,7 +585,7 @@ On certain Windows machines (Windows 11), one should install the library locally

The **nightly build** can be installed via
```bash
pip3install torchrl-nightly
pip3 install torchrl-nightly
```
which we currently only ship for Linux and OsX (Intel) machines.
Importantly, the nightly builds require the nightly builds of PyTorch too.
Expand Down Expand Up @@ -590,7 +616,7 @@ Go to the directory where you have cloned the torchrl repo and install it (after
installing `ninja`)
```bash
cd /path/to/torchrl/
pip3install ninja -U
pip3 install ninja -U
python setup.py develop
```

Expand Down
Binary file added docs/ppo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ to be able to create this other composition:
CenterCrop
ClipTransform
Compose
Crop
DTypeCastTransform
DeviceCastTransform
DiscreteActionProjection
Expand Down
9 changes: 7 additions & 2 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -355,13 +355,18 @@ algorithms, such as DQN, DDPG or Dreamer.
Multi-agent-specific modules
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

These networks implement models that can be used in
multi-agent contexts.
These networks implement models that can be used in multi-agent contexts.
They use :func:`~torch.vmap` to execute multiple networks all at once on the
network inputs. Because the parameters are batched, initialization may differ
from what is usually done with other PyTorch modules, see
:meth:`~torchrl.modules.MultiAgentNetBase.get_stateful_net`
for more information.

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

MultiAgentNetBase
MultiAgentMLP
MultiAgentConvNet
QMixer
Expand Down
29 changes: 29 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,35 @@ The main characteristics of TorchRL losses are:

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

.. note::
Initializing parameters in losses can be done via a query to :meth:`~torchrl.objectives.LossModule.get_stateful_net`
which will return a stateful version of the network that can be initialized like any other module.
If the modification is done in-place, it will propagate to any other module that uses the same parameter
set (within and outside of the loss): for instance, modifying the ``actor_network`` parameters from the loss
will also modify the actor in the collector.
If the parameters are modified out-of-place, :meth:`~torchrl.objectives.LossModule.from_stateful_net` can be
used to reset the parameters in the loss to the new value.

torch.vmap and randomness
-------------------------

TorchRL loss modules have plenty of calls to :func:`~torch.vmap` to amortize the cost of calling multiple similar models
in a loop, and instead vectorize these operations. `vmap` needs to be told explicitly what to do when random numbers
need to be generated within the call. To do this, a randomness mode needs to be set and must be one of `"error"` (default,
errors when dealing with pseudo-random functions), `"same"` (replicates the results across the batch) or `"different"`
(each element of the batch is treated separately).
Relying on the default will typically result in an error such as this one:

>>> RuntimeError: vmap: called random operation while in randomness error mode.

Since the calls to `vmap` are buried down the loss modules, TorchRL
provides an interface to set that vmap mode from the outside through `loss.vmap_randomness = str_value`, see
:meth:`~torchrl.objectives.LossModule.vmap_randomness` for more information.

``LossModule.vmap_randomness`` defaults to `"error"` if no random module is detected, and to `"different"` in
other cases. By default, only a limited number of modules are listed as random, but the list can be extended
using the :func:`~torchrl.objectives.common.add_random_module` function.

Training value functions
------------------------

Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def _get_pytorch_version(is_nightly, is_local):
# if "PYTORCH_VERSION" in os.environ:
# return f"torch=={os.environ['PYTORCH_VERSION']}"
if is_nightly:
return "torch>=2.4.0.dev"
return "torch>=2.5.0.dev"
elif is_local:
return "torch"
return "torch>=2.3.0"
return "torch>=2.4.0"


def _get_packages():
Expand Down Expand Up @@ -274,6 +274,7 @@ def _main(argv):
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Development Status :: 4 - Beta",
Expand Down
10 changes: 6 additions & 4 deletions sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
ActorCriticOperator,
ActorValueOperator,
NoisyLinear,
NormalParamWrapper,
NormalParamExtractor,
SafeModule,
SafeSequential,
)
Expand Down Expand Up @@ -483,10 +483,12 @@ def make_redq_model(
}

if not gSDE:
actor_net = NormalParamWrapper(
actor_net = nn.Sequential(
actor_net,
scale_mapping=f"biased_softplus_{default_policy_scale}",
scale_lb=cfg.network.scale_lb,
NormalParamExtractor(
scale_mapping=f"biased_softplus_{default_policy_scale}",
scale_lb=cfg.network.scale_lb,
),
)
actor_module = SafeModule(
actor_net,
Expand Down
Loading

0 comments on commit c82f6e3

Please sign in to comment.