From 60c58dece72994b2738228ae60368ab4a1c413b1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 18 Nov 2024 07:55:05 +0000 Subject: [PATCH 01/19] Update [ghstack-poisoned] --- sota-implementations/gail/config.yaml | 5 + sota-implementations/gail/gail.py | 152 ++++++++++++++----------- sota-implementations/gail/ppo_utils.py | 7 +- 3 files changed, 97 insertions(+), 67 deletions(-) diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml index cf6c8053037..2e057b08220 100644 --- a/sota-implementations/gail/config.yaml +++ b/sota-implementations/gail/config.yaml @@ -41,6 +41,11 @@ gail: gp_lambda: 10.0 device: null +compile: + compile: False + compile_mode: + cudagraphs: False + replay_buffer: dataset: halfcheetah-expert-v2 batch_size: 256 diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index a3c64693fb3..ae70c92a20c 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -18,11 +18,12 @@ from ppo_utils import eval_model, make_env, make_ppo_models from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from tensordict.nn import CudaGraphModule from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type -from torchrl.objectives import ClipPPOLoss, GAILLoss +from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers from torchrl.objectives.value.advantages import GAE from torchrl.record import VideoRecorder from torchrl.record.loggers import generate_exp_name, get_logger @@ -69,20 +70,9 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = make_ppo_models(cfg.env.env_name, compile=cfg.compile.compile) actor, critic = actor.to(device), critic.to(device) - # Create collector - collector = SyncDataCollector( - create_env_fn=make_env(cfg.env.env_name, device), - policy=actor, - frames_per_batch=cfg.ppo.collector.frames_per_batch, - total_frames=cfg.ppo.collector.total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - ) - # Create data buffer data_buffer = TensorDictReplayBuffer( storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch), @@ -111,6 +101,30 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizers actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) + optim = group_optimizers(actor_optim, critic_optim) + del actor_optim, critic_optim + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=cfg.ppo.collector.frames_per_batch, + total_frames=cfg.ppo.collector.total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + compile_policy={"mode": compile_mode} if compile_mode is not None else False, + cudagraph_policy=cfg.compile.cudagraphs, + ) # Create replay buffer replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) @@ -138,32 +152,9 @@ def main(cfg: "DictConfig"): # noqa: 
F821 VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) ) test_env.eval() + num_network_updates = torch.zeros((), dtype=torch.int64, device=device) - # Training loop - collected_frames = 0 - num_network_updates = 0 - pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) - - # extract cfg variables - cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs - cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr - cfg_optim_lr = cfg.ppo.optim.lr - cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon - cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon - cfg_logger_test_interval = cfg.logger.test_interval - cfg_logger_num_test_episodes = cfg.logger.num_test_episodes - - for i, data in enumerate(collector): - - log_info = {} - frames_in_batch = data.numel() - collected_frames += frames_in_batch - pbar.update(data.numel()) - - # Update discriminator - # Get expert data - expert_data = replay_buffer.sample() - expert_data = expert_data.to(device) + def update(data, expert_data, num_network_updates=num_network_updates): # Add collector data to expert data expert_data.set( discriminator_loss.tensor_keys.collector_action, @@ -176,9 +167,9 @@ def main(cfg: "DictConfig"): # noqa: F821 d_loss = discriminator_loss(expert_data) # Backward pass - discriminator_optim.zero_grad() d_loss.get("loss").backward() discriminator_optim.step() + discriminator_optim.zero_grad(set_to_none=True) # Compute discriminator reward with torch.no_grad(): @@ -188,32 +179,19 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set discriminator rewards to tensordict data.set(("next", "reward"), d_rewards) - # Get training rewards and episode lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( - { - "train/reward": episode_rewards.mean().item(), - "train/episode_length": episode_length.sum().item() - / len(episode_length), - } - ) # Update PPO for _ in range(cfg_loss_ppo_epochs): - # Compute GAE with torch.no_grad(): data = adv_module(data) data_reshape = data.reshape(-1) # Update the data buffer + data_buffer.empty() data_buffer.extend(data_reshape) - for _, batch in enumerate(data_buffer): - - # Get a data batch - batch = batch.to(device) + for batch in data_buffer: + optim.zero_grad(set_to_none=True) # Linearly decrease the learning rate and clip epsilon alpha = 1.0 @@ -233,20 +211,66 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_loss = loss["loss_objective"] + loss["loss_entropy"] # Backward pass - actor_loss.backward() - critic_loss.backward() + (actor_loss + critic_loss).backward() # Update the networks - actor_optim.step() - critic_optim.step() - actor_optim.zero_grad() - critic_optim.zero_grad() + optim.step() + return d_loss.detach() + + if cfg.compile.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. 
Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Training loop + collected_frames = 0 + pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames) + + # extract cfg variables + cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs + cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr + cfg_optim_lr = cfg.ppo.optim.lr + cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon + cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon + cfg_logger_test_interval = cfg.logger.test_interval + cfg_logger_num_test_episodes = cfg.logger.num_test_episodes + + for i, data in enumerate(collector): + + log_info = {} + frames_in_batch = data.numel() + collected_frames += frames_in_batch + pbar.update(data.numel()) + + # Update discriminator + # Get expert data + expert_data = replay_buffer.sample() + expert_data = expert_data.to(device) + + d_loss = update(data, expert_data) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) log_info.update( { - "train/actor_loss": actor_loss.item(), - "train/critic_loss": critic_loss.item(), - "train/discriminator_loss": d_loss["loss"].item(), + # "train/actor_loss": actor_loss.item(), + # "train/critic_loss": critic_loss.item(), + "train/discriminator_loss": d_loss["loss"], "train/lr": alpha * cfg_optim_lr, "train/clip_epsilon": ( alpha * cfg_loss_clip_epsilon diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index 635c24517e6..fe5868a08ba 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -42,7 +42,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False) # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment): +def make_ppo_models_state(proof_environment, compile): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -54,6 +54,7 @@ def make_ppo_models_state(proof_environment): "low": proof_environment.single_action_spec.space.low, "high": proof_environment.single_action_spec.space.high, "tanh_loc": False, + "safe_tanh": not compile, } # Define policy architecture @@ -116,9 +117,9 @@ def make_ppo_models_state(proof_environment): return policy_module, value_module -def make_ppo_models(env_name): +def make_ppo_models(env_name, compile): proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment) + actor, critic = make_ppo_models_state(proof_environment, compile=compile) return actor, critic From 5f7da2bedc512dc63daa539cfad1d4b47608b4fe Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 13 Dec 2024 12:49:57 -0800 Subject: [PATCH 02/19] Update [ghstack-poisoned] --- sota-implementations/cql/discrete_cql_config.yaml | 2 +- sota-implementations/cql/online_config.yaml | 2 +- sota-implementations/cql/utils.py | 8 +++++++- sota-implementations/decision_transformer/utils.py | 2 +- sota-implementations/gail/gail.py | 4 ++-- sota-implementations/iql/discrete_iql.yaml | 2 +- sota-implementations/iql/online_config.yaml | 2 +- sota-implementations/iql/utils.py | 8 +++++++- sota-implementations/sac/config.yaml | 2 +- sota-implementations/sac/utils.py | 8 
+++++++- sota-implementations/td3/config.yaml | 2 +- sota-implementations/td3/utils.py | 8 +++++++- 12 files changed, 37 insertions(+), 13 deletions(-) diff --git a/sota-implementations/cql/discrete_cql_config.yaml b/sota-implementations/cql/discrete_cql_config.yaml index 6db31a9aa81..a9fb9bfed0c 100644 --- a/sota-implementations/cql/discrete_cql_config.yaml +++ b/sota-implementations/cql/discrete_cql_config.yaml @@ -14,7 +14,7 @@ collector: multi_step: 0 init_random_frames: 1000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 annealing_frames: 10000 eps_start: 1.0 diff --git a/sota-implementations/cql/online_config.yaml b/sota-implementations/cql/online_config.yaml index 5a8be9616a0..5c9e649f17f 100644 --- a/sota-implementations/cql/online_config.yaml +++ b/sota-implementations/cql/online_config.yaml @@ -15,7 +15,7 @@ collector: multi_step: 0 init_random_frames: 5_000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 1000 diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 611a31e585d..bcb5c0e96ac 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -122,6 +122,12 @@ def make_collector( cudagraph=False, ): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -129,7 +135,7 @@ def make_collector( frames_per_batch=cfg.collector.frames_per_batch, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, compile_policy={"mode": compile_mode} if compile else False, cudagraph_policy=cudagraph, ) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 20957ec7bee..a045857d2f4 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -393,7 +393,7 @@ def make_odt_model(cfg, device: torch.device | None = None) -> TensorDictModule: with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] - actor(td) + actor(td.to(device)) return actor diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index ae70c92a20c..d8187e68635 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -16,9 +16,9 @@ from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer from ppo_utils import eval_model, make_env, make_ppo_models +from tensordict.nn import CudaGraphModule from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer -from tensordict.nn import CudaGraphModule from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import set_gym_backend @@ -262,7 +262,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() - / len(episode_length), + / len(episode_length), } ) diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml index 9245d4c4832..d28c02cf499 100644 --- a/sota-implementations/iql/discrete_iql.yaml +++ b/sota-implementations/iql/discrete_iql.yaml @@ 
-15,7 +15,7 @@ collector: total_frames: 20000 init_random_frames: 1000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 # logger diff --git a/sota-implementations/iql/online_config.yaml b/sota-implementations/iql/online_config.yaml index 1f7bb361e6c..64ad7466192 100644 --- a/sota-implementations/iql/online_config.yaml +++ b/sota-implementations/iql/online_config.yaml @@ -15,7 +15,7 @@ collector: multi_step: 0 init_random_frames: 5000 env_per_collector: 1 - device: cpu + device: max_frames_per_traj: 200 # logger diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 617974752b0..9d1b8c83a75 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -118,6 +118,12 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, @@ -125,7 +131,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/sota-implementations/sac/config.yaml b/sota-implementations/sac/config.yaml index 29586f2e9a7..5cf531a3be2 100644 --- a/sota-implementations/sac/config.yaml +++ b/sota-implementations/sac/config.yaml @@ -12,7 +12,7 @@ collector: init_random_frames: 25000 frames_per_batch: 1000 init_env_steps: 1000 - device: cpu + device: env_per_collector: 1 reset_at_each_iter: False diff --git a/sota-implementations/sac/utils.py b/sota-implementations/sac/utils.py index d1dbb2db791..e827e4e6416 100644 --- a/sota-implementations/sac/utils.py +++ b/sota-implementations/sac/utils.py @@ -105,13 +105,19 @@ def make_environment(cfg, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = SyncDataCollector( train_env, actor_model_explore, init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector diff --git a/sota-implementations/td3/config.yaml b/sota-implementations/td3/config.yaml index 7f7854b68b3..5bdf22ea6fa 100644 --- a/sota-implementations/td3/config.yaml +++ b/sota-implementations/td3/config.yaml @@ -13,7 +13,7 @@ collector: init_env_steps: 1000 frames_per_batch: 1000 reset_at_each_iter: False - device: cpu + device: env_per_collector: 1 num_workers: 1 diff --git a/sota-implementations/td3/utils.py b/sota-implementations/td3/utils.py index 665c2e0c674..d32375a7649 100644 --- a/sota-implementations/td3/utils.py +++ b/sota-implementations/td3/utils.py @@ -116,6 +116,12 @@ def make_environment(cfg, logger=None): def make_collector(cfg, train_env, actor_model_explore): """Make collector.""" + device = cfg.collector.device + if device in ("", None): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") collector = 
SyncDataCollector( train_env, actor_model_explore, @@ -123,7 +129,7 @@ def make_collector(cfg, train_env, actor_model_explore): frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, reset_at_each_iter=cfg.collector.reset_at_each_iter, - device=cfg.collector.device, + device=device, ) collector.set_seed(cfg.env.seed) return collector From 7a4cc1e7ad544f782e1182125809566ca3e9cffc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 09:47:08 -0800 Subject: [PATCH 03/19] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 28 +++++++++++++++++--------- sota-implementations/gail/ppo_utils.py | 17 +++++++++------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index d7b3dd49fcb..39eaf3a929d 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -11,6 +11,8 @@ """ from __future__ import annotations +import warnings + import hydra import numpy as np import torch @@ -18,6 +20,7 @@ from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer from ppo_utils import eval_model, make_env, make_ppo_models +from tensordict import TensorDict from tensordict.nn import CudaGraphModule from torchrl.collectors import SyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer @@ -72,8 +75,9 @@ def main(cfg: "DictConfig"): # noqa: F821 np.random.seed(cfg.env.seed) # Create models (check utils_mujoco.py) - actor, critic = make_ppo_models(cfg.env.env_name, compile=cfg.compile.compile) - actor, critic = actor.to(device), critic.to(device) + actor, critic = make_ppo_models( + cfg.env.env_name, compile=cfg.compile.compile, device=device + ) # Create data buffer data_buffer = TensorDictReplayBuffer( @@ -101,8 +105,12 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5) + actor_optim = torch.optim.Adam( + actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5 + ) + critic_optim = torch.optim.Adam( + critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5 + ) optim = group_optimizers(actor_optim, critic_optim) del actor_optim, critic_optim @@ -196,12 +204,10 @@ def update(data, expert_data, num_network_updates=num_network_updates): optim.zero_grad(set_to_none=True) # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 + alpha = torch.ones((), device=device) if cfg_optim_anneal_lr: alpha = 1 - (num_network_updates / total_network_updates) - for group in actor_optim.param_groups: - group["lr"] = cfg_optim_lr * alpha - for group in critic_optim.param_groups: + for group in optim.param_groups: group["lr"] = cfg_optim_lr * alpha if cfg_loss_anneal_clip_eps: loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha) @@ -217,7 +223,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): # Update the networks optim.step() - return d_loss.detach() + return TensorDict(dloss=d_loss, alpha=alpha).detach() if cfg.compile.compile: update = torch.compile(update, mode=compile_mode) @@ -253,7 +259,9 @@ def update(data, expert_data, num_network_updates=num_network_updates): expert_data = replay_buffer.sample() expert_data = expert_data.to(device) - d_loss = update(data, expert_data) + metadata = update(data, expert_data) + d_loss = 
metadata["d_loss"] + alpha = metadata["alpha"] # Get training rewards and episode lengths episode_rewards = data["next", "episode_reward"][data["next", "done"]] diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index b5f43a10899..6ba12acdf9c 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False) # -------------------------------------------------------------------- -def make_ppo_models_state(proof_environment, compile): +def make_ppo_models_state(proof_environment, compile, device): # Define input shape input_shape = proof_environment.observation_spec["observation"].shape @@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment, compile): num_outputs = proof_environment.action_spec_unbatched.shape[-1] distribution_class = TanhNormal distribution_kwargs = { - "low": proof_environment.action_spec_unbatched.space.low, - "high": proof_environment.action_spec_unbatched.space.high, + "low": proof_environment.action_spec_unbatched.space.low.to(device), + "high": proof_environment.action_spec_unbatched.space.high.to(device), "tanh_loc": False, "safe_tanh": not compile, } @@ -64,6 +64,7 @@ def make_ppo_models_state(proof_environment, compile): activation_class=torch.nn.Tanh, out_features=num_outputs, # predict only loc num_cells=[64, 64], + device=device, ) # Initialize policy weights @@ -88,7 +89,7 @@ def make_ppo_models_state(proof_environment, compile): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, @@ -118,9 +119,11 @@ def make_ppo_models_state(proof_environment, compile): return policy_module, value_module -def make_ppo_models(env_name, compile): - proof_environment = make_env(env_name, device="cpu") - actor, critic = make_ppo_models_state(proof_environment, compile=compile) +def make_ppo_models(env_name, compile, device): + proof_environment = make_env(env_name, device=device) + actor, critic = make_ppo_models_state( + proof_environment, compile=compile, device=device + ) return actor, critic From a4abc7bad0f1d00a477b5d9fc04f0fa557606ad0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 18:37:32 -0800 Subject: [PATCH 04/19] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 5 ++++- torchrl/data/tensor_specs.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 39eaf3a929d..e6c729eabc3 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -34,6 +34,9 @@ from torchrl.record.loggers import generate_exp_name, get_logger +torch.set_float32_matmul_precision("high") + + @hydra.main(config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 set_gym_backend(cfg.env.backend).set() @@ -260,7 +263,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): expert_data = expert_data.to(device) metadata = update(data, expert_data) - d_loss = metadata["d_loss"] + d_loss = metadata["dloss"] alpha = metadata["alpha"] # Get training rewards and episode lengths diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1898e679717..6f214fba6de 100644 --- a/torchrl/data/tensor_specs.py 
+++ b/torchrl/data/tensor_specs.py @@ -3319,9 +3319,9 @@ def __init__( self.update_mask(mask) self._provisional_n = None - @torch.compiler.assume_constant_result + @property def _undefined_n(self): - return self.space.n == -1 + return self.space.n < 0 def enumerate(self) -> torch.Tensor: dtype = self.dtype From 23cab4156ebc208e2d3f924edd4a97638951646c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 18:45:09 -0800 Subject: [PATCH 05/19] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 1 + sota-implementations/gail/ppo_utils.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index e6c729eabc3..0f72128318d 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -95,6 +95,7 @@ def main(cfg: "DictConfig"): # noqa: F821 lmbda=cfg.ppo.loss.gae_lambda, value_network=critic, average_gae=False, + device=device, ) loss_module = ClipPPOLoss( diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index 6ba12acdf9c..7dcc2db6b74 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -55,7 +55,7 @@ def make_ppo_models_state(proof_environment, compile, device): "low": proof_environment.action_spec_unbatched.space.low.to(device), "high": proof_environment.action_spec_unbatched.space.high.to(device), "tanh_loc": False, - "safe_tanh": not compile, + # "safe_tanh": not compile, } # Define policy architecture @@ -77,7 +77,9 @@ def make_ppo_models_state(proof_environment, compile, device): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8 + proof_environment.action_spec_unbatched.shape[-1], + scale_lb=1e-8, + device=device, ), ) @@ -102,6 +104,7 @@ def make_ppo_models_state(proof_environment, compile, device): activation_class=torch.nn.Tanh, out_features=1, num_cells=[64, 64], + device=device, ) # Initialize value weights From ad1f0a493bf2e9c8593ae20e1f45e9ff3dcb2716 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 18:53:44 -0800 Subject: [PATCH 06/19] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 0f72128318d..969c7fc083e 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -23,7 +23,7 @@ from tensordict import TensorDict from tensordict.nn import CudaGraphModule from torchrl.collectors import SyncDataCollector -from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer +from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import set_gym_backend @@ -84,7 +84,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create data buffer data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch), + storage=LazyTensorStorage(cfg.ppo.collector.frames_per_batch, device=device), sampler=SamplerWithoutReplacement(), batch_size=cfg.ppo.loss.mini_batch_size, ) @@ -134,7 +134,6 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch=cfg.ppo.collector.frames_per_batch, total_frames=cfg.ppo.collector.total_frames, device=device, - storing_device=device, max_frames_per_traj=-1, compile_policy={"mode": compile_mode} if 
compile_mode is not None else False, cudagraph_policy=cfg.compile.cudagraphs, From 5259e6eb934849be93a85e98f1a4eda5c8dd86eb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:12:38 -0800 Subject: [PATCH 07/19] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 10 +++++-- torchrl/_utils.py | 50 +++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 969c7fc083e..48e3b18c2c3 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -22,6 +22,8 @@ from ppo_utils import eval_model, make_env, make_ppo_models from tensordict import TensorDict from tensordict.nn import CudaGraphModule + +from torchrl._utils import compile_with_warmup from torchrl.collectors import SyncDataCollector from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -84,7 +86,11 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create data buffer data_buffer = TensorDictReplayBuffer( - storage=LazyTensorStorage(cfg.ppo.collector.frames_per_batch, device=device), + storage=LazyTensorStorage( + cfg.ppo.collector.frames_per_batch, + device=device, + compilable=cfg.compile.compile, + ), sampler=SamplerWithoutReplacement(), batch_size=cfg.ppo.loss.mini_batch_size, ) @@ -229,7 +235,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): return TensorDict(dloss=d_loss, alpha=alpha).detach() if cfg.compile.compile: - update = torch.compile(update, mode=compile_mode) + update = compile_with_warmup(update, warmup=2, mode=compile_mode) if cfg.compile.cudagraphs: warnings.warn( "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", diff --git a/torchrl/_utils.py b/torchrl/_utils.py index c81ffcc962b..73f31c8ccf5 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -851,3 +851,53 @@ def set_mode(self, type: Any | None) -> None: cm = self._lock if not is_compiling() else nullcontext() with cm: self._mode = type + + +@wraps(torch.compile) +def compile_with_warmup(*args, warmup: int, **kwargs): + """Compile a model with warm-up. + + This function wraps :func:`~torch.compile` to add a warm-up phase. During the warm-up phase, + the original model is used. After the warm-up phase, the model is compiled using + `torch.compile`. + + Args: + *args: Arguments to be passed to `torch.compile`. + warmup (int): Number of calls to the model before compiling it. + **kwargs: Keyword arguments to be passed to `torch.compile`. + + Returns: + A callable that wraps the original model. If no model is provided, returns a + lambda function that takes a model as input and returns the wrapped model. + + Notes: + If no model is provided, this function returns a lambda function that can be + used to wrap a model later. This allows for delayed compilation of the model. 
+ + Example: + >>> model = torch.nn.Linear(5, 3) + >>> compiled_model = compile_with_warmup(model, warmup=10) + >>> # First 10 calls use the original model + >>> # After 10 calls, the model is compiled and used + """ + + if len(args): + model = args[0] + else: + model = kwargs.get("model") + if model is None: + return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs) + else: + count = 0 + compiled_model = model + + @wraps(model) + def count_and_compile(*model_args, **model_kwargs): + nonlocal count + nonlocal compiled_model + count += 1 + if count == warmup: + compiled_model = torch.compile(model, *args, **kwargs) + return compiled_model(*model_args, **model_kwargs) + + return count_and_compile From b26b9d3f4919a57f18d53b273520a51235e3565c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:16:39 -0800 Subject: [PATCH 08/19] Update [ghstack-poisoned] --- torchrl/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 73f31c8ccf5..45f8c433725 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -880,11 +880,11 @@ def compile_with_warmup(*args, warmup: int, **kwargs): >>> # First 10 calls use the original model >>> # After 10 calls, the model is compiled and used """ - if len(args): model = args[0] + args = () else: - model = kwargs.get("model") + model = kwargs.pop("model", None) if model is None: return lambda model: compile_with_warmup(model, warmup=warmup, **kwargs) else: From 7cc8f3607711ac0f2cd34531499b6344fb4d1635 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:19:11 -0800 Subject: [PATCH 09/19] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 48e3b18c2c3..072063bb689 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -93,6 +93,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ), sampler=SamplerWithoutReplacement(), batch_size=cfg.ppo.loss.mini_batch_size, + compilable=cfg.compile.compile, ) # Create loss and adv modules From 918d1fc5cac69ffb56d1a5f9ac098b864121d5bd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:22:04 -0800 Subject: [PATCH 10/19] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/replay_buffers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 67113095af0..f9e76a2a282 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -1159,7 +1159,7 @@ class TensorDictReplayBuffer(ReplayBuffer): def __init__(self, *, priority_key: str = "td_error", **kwargs) -> None: writer = kwargs.get("writer", None) if writer is None: - kwargs["writer"] = TensorDictRoundRobinWriter() + kwargs["writer"] = TensorDictRoundRobinWriter(compilable=kwargs.get("compilable")) super().__init__(**kwargs) self.priority_key = priority_key From d9e9921ab405438959fe32fd2302485e3e37031c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:29:16 -0800 Subject: [PATCH 11/19] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/replay_buffers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index f9e76a2a282..c9751faf01e 100644 --- 
a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -20,9 +20,9 @@ import torch try: - from torch.compiler import is_dynamo_compiling + from torch.compiler import is_compiling except ImportError: - from torch._dynamo import is_compiling as is_dynamo_compiling + from torch._dynamo import is_compiling from tensordict import ( is_tensor_collection, @@ -617,7 +617,7 @@ def _add(self, data): return index def _extend(self, data: Sequence) -> torch.Tensor: - is_compiling = is_dynamo_compiling() + is_compiling = is_compiling() nc = contextlib.nullcontext() with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc: if self.dim_extend > 0: @@ -672,7 +672,7 @@ def update_priority( @pin_memory_output def _sample(self, batch_size: int) -> Tuple[Any, dict]: - with self._replay_lock if not is_dynamo_compiling() else contextlib.nullcontext(): + with self._replay_lock if not is_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index data = self._storage.get(index) @@ -1343,7 +1343,7 @@ def sample( @pin_memory_output def _sample(self, batch_size: int) -> Tuple[Any, dict]: - with self._replay_lock: + with self._replay_lock if not is_compiling() else contextlib.nullcontext(): index, info = self._sampler.sample(self._storage, batch_size) info["index"] = index data = self._storage.get(index) From 8d7d2c0ff2ffda84459921a00c6a245d602e2f17 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:30:46 -0800 Subject: [PATCH 12/19] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/replay_buffers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index c9751faf01e..c2b4e69389c 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -617,9 +617,9 @@ def _add(self, data): return index def _extend(self, data: Sequence) -> torch.Tensor: - is_compiling = is_compiling() + is_comp = is_compiling() nc = contextlib.nullcontext() - with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc: + with self._replay_lock if not is_comp else nc, self._write_lock if not is_comp else nc: if self.dim_extend > 0: data = self._transpose(data) index = self._writer.extend(data) From f66ad56e35fa16fbc7f0b7d9e0e96e26e709ace7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:33:55 -0800 Subject: [PATCH 13/19] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/writers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 7fb865453d6..70484c0fefa 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -332,7 +332,7 @@ def extend(self, data: Sequence) -> torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in list(self._storage._attached_entities): ent.mark_update(index) return index From e6817057d6267ecc59776d26105659896d9ca239 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 19:35:17 -0800 Subject: [PATCH 14/19] Update [ghstack-poisoned] --- sota-implementations/gail/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/sota-implementations/gail/config.yaml b/sota-implementations/gail/config.yaml index 2e057b08220..089de2c59e4 100644 --- a/sota-implementations/gail/config.yaml +++ b/sota-implementations/gail/config.yaml @@ -43,7 +43,7 @@ gail: compile: compile: False - compile_mode: + compile_mode: default cudagraphs: False replay_buffer: From f78dde4d81bc87331303e16c3ba493f39f3f8156 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 20:01:12 -0800 Subject: [PATCH 15/19] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/replay_buffers.py | 4 +++- torchrl/data/replay_buffers/storages.py | 9 +++++---- torchrl/data/replay_buffers/writers.py | 10 +++++----- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index c2b4e69389c..fbb76b5a681 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -1159,7 +1159,9 @@ class TensorDictReplayBuffer(ReplayBuffer): def __init__(self, *, priority_key: str = "td_error", **kwargs) -> None: writer = kwargs.get("writer", None) if writer is None: - kwargs["writer"] = TensorDictRoundRobinWriter(compilable=kwargs.get("compilable")) + kwargs["writer"] = TensorDictRoundRobinWriter( + compilable=kwargs.get("compilable") + ) super().__init__(**kwargs) self.priority_key = priority_key diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index ae0d97b7bab..d0ec3572784 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -86,17 +86,17 @@ def _is_full(self): return len(self) == self.max_size @property - def _attached_entities(self): + def _attached_entities(self) -> List: # RBs that use a given instance of Storage should add # themselves to this set. _attached_entities_set = getattr(self, "_attached_entities_set", None) if _attached_entities_set is None: - self._attached_entities_set = _attached_entities_set = set() + self._attached_entities_set = _attached_entities_set = [] return _attached_entities_set @torch._dynamo.assume_constant_result def _attached_entities_iter(self): - return list(self._attached_entities) + return self._attached_entities @abc.abstractmethod def set(self, cursor: int, data: Any, *, set_cursor: bool = True): @@ -123,7 +123,8 @@ def attach(self, buffer: Any) -> None: Args: buffer: the object that reads from this storage. 
""" - self._attached_entities.add(buffer) + if buffer not in self._attached_entities: + self._attached_entities.append(buffer) def __getitem__(self, item): return self.get(item) diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 70484c0fefa..ff2a00ab242 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -176,7 +176,7 @@ def add(self, data: Any) -> int | torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(_cursor, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter: ent.mark_update(index) return index @@ -302,7 +302,7 @@ def add(self, data: Any) -> int | torch.Tensor: ) self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter: ent.mark_update(index) return index @@ -332,7 +332,7 @@ def extend(self, data: Sequence) -> torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in list(self._storage._attached_entities): + for ent in self._storage._attached_entities_iter: ent.mark_update(index) return index @@ -533,7 +533,7 @@ def add(self, data: Any) -> int | torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter: ent.mark_update(index) return index @@ -567,7 +567,7 @@ def extend(self, data: TensorDictBase) -> None: device = getattr(self._storage, "device", None) out_index = torch.full(data.shape, -1, dtype=torch.long, device=device) index = self._replicate_index(out_index) - for ent in self._storage._attached_entities: + for ent in self._storage._attached_entities_iter: ent.mark_update(index) return index From 1d5484aabaacbf844013210ebfe2ef2e030ced0e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 20:03:17 -0800 Subject: [PATCH 16/19] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/storages.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index d0ec3572784..52d137208ad 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -69,7 +69,7 @@ def __init__( self.max_size = int(max_size) self.checkpointer = checkpointer self._compilable = compilable - self._attached_entities_set = set() + self._attached_entities_list = [] @property def checkpointer(self): @@ -89,10 +89,10 @@ def _is_full(self): def _attached_entities(self) -> List: # RBs that use a given instance of Storage should add # themselves to this set. 
- _attached_entities_set = getattr(self, "_attached_entities_set", None) - if _attached_entities_set is None: - self._attached_entities_set = _attached_entities_set = [] - return _attached_entities_set + _attached_entities_list = getattr(self, "_attached_entities_list", None) + if _attached_entities_list is None: + self._attached_entities_list = _attached_entities_list = [] + return _attached_entities_list @torch._dynamo.assume_constant_result def _attached_entities_iter(self): From 2813d4b5e18401da976e233bf89dc327a380f50e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 20:04:27 -0800 Subject: [PATCH 17/19] Update [ghstack-poisoned] --- torchrl/data/replay_buffers/writers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index ff2a00ab242..e7f4da9c4bb 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -176,7 +176,7 @@ def add(self, data: Any) -> int | torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(_cursor, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities_iter: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -302,7 +302,7 @@ def add(self, data: Any) -> int | torch.Tensor: ) self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities_iter: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -332,7 +332,7 @@ def extend(self, data: Sequence) -> torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities_iter: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -533,7 +533,7 @@ def add(self, data: Any) -> int | torch.Tensor: # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) index = self._replicate_index(index) - for ent in self._storage._attached_entities_iter: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index @@ -567,7 +567,7 @@ def extend(self, data: TensorDictBase) -> None: device = getattr(self._storage, "device", None) out_index = torch.full(data.shape, -1, dtype=torch.long, device=device) index = self._replicate_index(out_index) - for ent in self._storage._attached_entities_iter: + for ent in self._storage._attached_entities_iter(): ent.mark_update(index) return index From 8476c5ec870a5961dcccb3922c3beef7c19ff8a6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 20:24:02 -0800 Subject: [PATCH 18/19] Update [ghstack-poisoned] --- sota-implementations/gail/gail.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 072063bb689..1075a78eba6 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -233,7 +233,7 @@ def update(data, expert_data, num_network_updates=num_network_updates): # Update the networks optim.step() - return TensorDict(dloss=d_loss, alpha=alpha).detach() + return {"dloss": d_loss, "alpha": alpha} if cfg.compile.compile: update = compile_with_warmup(update, warmup=2, mode=compile_mode) From 73b5916974a8d3e3dc97ae800e0cd54c9da086cf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 
15 Dec 2024 09:32:48 -0800 Subject: [PATCH 19/19] Update [ghstack-poisoned] --- sota-implementations/a2c/utils_atari.py | 2 +- sota-implementations/a2c/utils_mujoco.py | 2 +- sota-implementations/dreamer/dreamer_utils.py | 2 +- sota-implementations/gail/gail.py | 1 - sota-implementations/impala/utils.py | 2 +- sota-implementations/ppo/utils_atari.py | 2 +- sota-implementations/ppo/utils_mujoco.py | 2 +- torchrl/data/tensor_specs.py | 2 +- 8 files changed, 7 insertions(+), 8 deletions(-) diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 0397f7dc5f3..6ff62bbe520 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -152,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec.to(device), + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index b78f52b7eb4..5ce5ed1902d 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -94,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec.to(device), + spec=proof_environment.full_action_spec_unbatched.to(device), distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 532fe4e1fe9..7d8b9d6d618 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -546,7 +546,7 @@ def _dreamer_make_actor_real( default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=proof_environment.single_full_action_spec.to("cpu"), + spec=proof_environment.full_action_spec_unbatched.to("cpu"), ), ), SafeModule( diff --git a/sota-implementations/gail/gail.py b/sota-implementations/gail/gail.py index 1075a78eba6..a02845cfe4d 100644 --- a/sota-implementations/gail/gail.py +++ b/sota-implementations/gail/gail.py @@ -20,7 +20,6 @@ from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer from ppo_utils import eval_model, make_env, make_ppo_models -from tensordict import TensorDict from tensordict.nn import CudaGraphModule from torchrl._utils import compile_with_warmup diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index 738bb83bf55..e174bc2e71c 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -117,7 +117,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 755c6311729..040259377ad 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -148,7 
+148,7 @@ def make_ppo_modules_pixels(proof_environment): policy_module = ProbabilisticActor( policy_module, in_keys=["logits"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index e7eb4534c45..f2e08ffb129 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -87,7 +87,7 @@ def make_ppo_models_state(proof_environment): out_keys=["loc", "scale"], ), in_keys=["loc", "scale"], - spec=proof_environment.single_full_action_spec, + spec=proof_environment.full_action_spec_unbatched, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=True, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 6f214fba6de..ad29b63db04 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -3388,7 +3388,7 @@ def set_provisional_n(self, n: int): self._provisional_n = n def rand(self, shape: torch.Size = None) -> torch.Tensor: - if self._undefined_n(): + if self._undefined_n: if self._provisional_n is None: raise RuntimeError( "Cannot generate random categorical samples for undefined cardinality (n=-1). "
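
The gail.py patches above fold the separate actor and critic Adam optimizers into one grouped object, so the PPO inner loop can do a single `(actor_loss + critic_loss).backward()`, one `optim.step()`, one `optim.zero_grad(set_to_none=True)` per mini-batch, and anneal the learning rate over all param groups at once. A minimal sketch of that pattern in plain PyTorch follows; the `GroupedOptim` class is a hypothetical stand-in invented here for illustration, not torchrl's `group_optimizers`.

import torch
from torch import nn


class GroupedOptim:
    """Hypothetical stand-in that fans step/zero_grad out to several optimizers."""

    def __init__(self, *optims: torch.optim.Optimizer):
        self.optims = optims

    @property
    def param_groups(self):
        # Flat view over every param group, so LR annealing reaches both networks.
        return [g for opt in self.optims for g in opt.param_groups]

    def zero_grad(self, set_to_none: bool = True) -> None:
        for opt in self.optims:
            opt.zero_grad(set_to_none=set_to_none)

    def step(self) -> None:
        for opt in self.optims:
            opt.step()


actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)
optim = GroupedOptim(
    torch.optim.Adam(actor.parameters(), lr=3e-4, eps=1e-5),
    torch.optim.Adam(critic.parameters(), lr=3e-4, eps=1e-5),
)

obs = torch.randn(8, 4)
optim.zero_grad(set_to_none=True)
# Dummy losses stand in for the ClipPPOLoss outputs; one backward pass and one
# step update both networks, keeping the whole update inside a single function.
actor_loss = actor(obs).pow(2).mean()
critic_loss = critic(obs).pow(2).mean()
(actor_loss + critic_loss).backward()
optim.step()

# Linear LR annealing over the grouped param groups, as in the PPO loop above.
alpha = 0.5
for group in optim.param_groups:
    group["lr"] = 3e-4 * alpha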
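
Patch 02 repeats the same device-resolution block in the cql, iql, sac, and td3 `make_collector` helpers: an empty or missing `collector.device` in the config now means "use cuda:0 if available, otherwise cpu". A small sketch of that logic as a standalone helper; the function name is ours, not part of torchrl.

import torch


def resolve_collector_device(device=None) -> torch.device:
    # Empty string / None in the YAML -> pick a sensible default automatically.
    if device in ("", None):
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Anything explicit is respected as-is.
    return torch.device(device)


print(resolve_collector_device(None))   # cuda:0 on a GPU machine, cpu otherwise
print(resolve_collector_device("cpu"))  # an explicit setting wins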
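
The `compile_with_warmup` helper added to `torchrl/_utils.py` in patch 07 runs the original callable for the first `warmup` calls and only then swaps in a `torch.compile`d version. Below is a self-contained re-implementation of the same idea for illustration (not the torchrl helper itself), runnable on CPU with PyTorch 2.x.

import functools

import torch


def compile_after_warmup(fn, *, warmup: int, **compile_kwargs):
    count = 0
    compiled = fn  # the eager version is used until the warm-up budget is spent

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal count, compiled
        count += 1
        if count == warmup:
            # From this call on, the torch.compile'd artifact is used.
            compiled = torch.compile(fn, **compile_kwargs)
        return compiled(*args, **kwargs)

    return wrapper


def step(x):
    return (2 * x + 1).sin().sum()


step = compile_after_warmup(step, warmup=3)
x = torch.randn(16)
for i in range(5):
    # Calls 1-2 run eagerly; call 3 compiles `step`, and later calls reuse it.
    print(i, step(x).item())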