Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Versioning] v0.5 bump #2267

Merged
merged 29 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash

export TORCHRL_BUILD_VERSION=0.4.0
export TORCHRL_BUILD_VERSION=0.5.0

${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U
35 changes: 21 additions & 14 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,19 @@
#
#

set -e
#set -e
set -v

# Initialize an error flag
error_occurred=0
# Function to handle errors
error_handler() {
echo "Error on line $1"
error_occurred=1
}
# Trap ERR to call the error_handler function with the failing line number
trap 'error_handler $LINENO' ERR

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
Expand All @@ -24,6 +34,7 @@ lib_dir="${env_dir}/lib"
# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU
export CUDA_LAUNCH_BLOCKING=1

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200
#python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200
Expand Down Expand Up @@ -163,18 +174,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cr
env.name=Pendulum-v1 \
network.device= \
logger.backend=
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \
collector.total_frames=200 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=4 \
optimization.optim_steps_per_batch=1 \
logger.video=True \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
networks.rssm_hidden_dim=17
python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \
collector.total_frames=48 \
collector.init_random_frames=10 \
Expand Down Expand Up @@ -214,8 +213,8 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.backend=csv \
logger.video=True \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
Expand Down Expand Up @@ -312,3 +311,11 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ba

coverage combine
coverage xml -i

# Check if any errors occurred during the script execution
if [ "$error_occurred" -ne 0 ]; then
echo "Errors occurred during script execution"
exit 1
else
echo "Script executed successfully"
fi
1 change: 1 addition & 0 deletions .github/workflows/build-wheels-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ jobs:
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/td_script.sh
2 changes: 1 addition & 1 deletion .github/workflows/build-wheels-m1.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ jobs:
runner-type: macos-m1-stable
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/m1_script.sh
env-var-script: .github/scripts/td_script.sh
1 change: 1 addition & 0 deletions .github/workflows/build-wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ jobs:
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/td_script.sh
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _main(argv):
if is_nightly:
tensordict_dep = "tensordict-nightly"
else:
tensordict_dep = "tensordict>=0.4.0"
tensordict_dep = "tensordict>=0.5.0"

if is_nightly:
version = get_nightly_version()
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
if ((i - 1) * frames_in_batch * frame_skip) // test_interval < (
i * frames_in_batch * frame_skip
) // test_interval:
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get test rewards
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval
cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval
final = collected_frames >= collector.total_frames
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // evaluation_interval
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
6 changes: 4 additions & 2 deletions sota-implementations/decision_transformer/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create test environment
test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video)
test_env = make_env(
cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video, device=model_device
)
if cfg.logger.video:
test_env = test_env.append_transform(
VideoRecorder(logger, tag="rendered", in_keys=["pixels"])
Expand Down Expand Up @@ -114,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821
to_log = {"train/loss": loss_vals["loss"]}

# Evaluation
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
max_steps=eval_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/online_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821
}

# Evaluation
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
inference_policy.eval()
if i % pretrain_log_interval == 0:
eval_td = test_env.rollout(
Expand Down
26 changes: 17 additions & 9 deletions sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
# -----------------


def make_base_env(env_cfg, from_pixels=False):
def make_base_env(env_cfg, from_pixels=False, device=None):
set_gym_backend(env_cfg.backend).set()

env_library = LIBS[env_cfg.library]
Expand All @@ -73,7 +73,7 @@ def make_base_env(env_cfg, from_pixels=False):
if env_library is DMControlEnv:
env_task = env_cfg.task
env_kwargs.update({"task_name": env_task})
env = env_library(**env_kwargs)
env = env_library(**env_kwargs, device=device)
return env


Expand Down Expand Up @@ -134,18 +134,22 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
return transformed_env


def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False):
def make_parallel_env(
env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None
):
if train:
num_envs = env_cfg.num_train_envs
else:
num_envs = env_cfg.num_eval_envs

def make_env():
with set_gym_backend(env_cfg.backend):
return make_base_env(env_cfg, from_pixels=from_pixels)
return make_base_env(env_cfg, from_pixels=from_pixels, device="cpu")

env = make_transformed_env(
ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True),
ParallelEnv(
num_envs, EnvCreator(make_env), serial_for_single=True, device=device
),
env_cfg,
obs_loc,
obs_std,
Expand All @@ -154,11 +158,15 @@ def make_env():
return env


def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False):
env = make_parallel_env(
env_cfg, obs_loc, obs_std, train=train, from_pixels=from_pixels
def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None):
return make_parallel_env(
env_cfg,
obs_loc,
obs_std,
train=train,
from_pixels=from_pixels,
device=device,
)
return env


# ====================================================================
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821
cur_test_frame = (i * frames_per_batch) // eval_iter
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
eval_rollout = eval_env.rollout(
eval_rollout_steps,
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get and log evaluation rewards and eval time
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
cur_test_frame = (i * frames_per_batch) // test_interval
final = current_frames >= collector.total_frames
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dqn/dqn_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Get and log evaluation rewards and eval time
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
cur_test_frame = (i * frames_per_batch) // test_interval
final = current_frames >= collector.total_frames
Expand Down
8 changes: 1 addition & 7 deletions sota-implementations/dreamer/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,13 @@ env:
image_size : 64
horizon: 500
n_parallel_envs: 8
device:
_target_: dreamer_utils._default_device
device: null
device: cpu

collector:
total_frames: 5_000_000
init_random_frames: 3000
frames_per_batch: 1000
device:
_target_: dreamer_utils._default_device
device: null

optimization:
train_every: 1000
Expand All @@ -41,8 +37,6 @@ optimization:
networks:
exploration_noise: 0.3
device:
_target_: dreamer_utils._default_device
device: null
state_dim: 30
rssm_hidden_dim: 200
hidden_dim: 400
Expand Down
10 changes: 6 additions & 4 deletions sota-implementations/dreamer/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
import torch.cuda
import tqdm
from dreamer_utils import (
_default_device,
dump_video,
log_metrics,
make_collector,
make_dreamer,
make_environments,
make_replay_buffer,
)
from hydra.utils import instantiate

# mixed precision training
from torch.cuda.amp import GradScaler
Expand All @@ -38,7 +38,7 @@
def main(cfg: "DictConfig"): # noqa: F821
# cfg = correct_for_frame_skip(cfg)

device = torch.device(instantiate(cfg.networks.device))
device = _default_device(cfg.networks.device)

# Create logger
exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name)
Expand Down Expand Up @@ -284,7 +284,7 @@ def compile_rssms(module):
# Evaluation
if (i % eval_iter) == 0:
# Real env
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_rollout = test_env.rollout(
eval_rollout_steps,
policy,
Expand All @@ -298,7 +298,9 @@ def compile_rssms(module):
log_metrics(logger, eval_metrics, collected_frames)
# Simulated env
if model_based_env_eval is not None:
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad():
eval_rollout = model_based_env_eval.rollout(
eval_rollout_steps,
policy,
Expand Down
Loading
Loading