From 88fc29e8e55cc6b000cf9e878fd57c0fcd7b4be8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 15 Nov 2024 16:25:57 +0000 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- sota-implementations/cql/cql_online.py | 1 - sota-implementations/discrete_sac/config.yaml | 3 + .../discrete_sac/discrete_sac.py | 170 +++++++++--------- sota-implementations/discrete_sac/utils.py | 11 +- 4 files changed, 103 insertions(+), 82 deletions(-) diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 15cf2c68142..76fca97288e 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -170,7 +170,6 @@ def update(sampled_tensordict): c_iter = iter(collector) for i in range(len(collector)): with timeit("collecting"): - torch.compiler.cudagraph_mark_step_begin() tensordict = next(c_iter) pbar.update(tensordict.numel()) # update weights of the inference policy diff --git a/sota-implementations/discrete_sac/config.yaml b/sota-implementations/discrete_sac/config.yaml index aa852ca1fc3..3f679023571 100644 --- a/sota-implementations/discrete_sac/config.yaml +++ b/sota-implementations/discrete_sac/config.yaml @@ -43,6 +43,9 @@ network: hidden_sizes: [256, 256] activation: relu device: null + compile: False + compile_mode: + cudagraphs: False # logging logger: diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index a9a08827f5d..5a0e757b1b6 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -10,17 +10,18 @@ The helper functions are coded in the utils.py associated with this script. """ -import time + +import warnings import hydra import numpy as np import torch import torch.cuda import tqdm -from torchrl._utils import logger as torchrl_logger - +from tensordict.nn import CudaGraphModule +from torchrl._utils import timeit from torchrl.envs.utils import ExplorationType, set_exploration_type - +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( dump_video, @@ -73,9 +74,6 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create TD3 loss loss_module, target_net_updater = make_loss_module(cfg, model) - # Create off-policy collector - collector = make_collector(cfg, train_env, model[0]) - # Create replay buffer replay_buffer = make_replay_buffer( batch_size=cfg.optim.batch_size, @@ -89,9 +87,57 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer( cfg, loss_module ) + optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha) + del optimizer_actor, optimizer_critic, optimizer_alpha + + def update(sampled_tensordict): + optimizer.zero_grad(set_to_none=True) + + # Compute loss + loss_out = loss_module(sampled_tensordict) + + actor_loss, q_loss, alpha_loss = ( + loss_out["loss_actor"], + loss_out["loss_qvalue"], + loss_out["loss_alpha"], + ) + + # Update critic + (q_loss + actor_loss + alpha_loss).backward() + optimizer.step() + + # Update target params + target_net_updater.step() + + return loss_out.detach() + + compile_mode = None + if cfg.network.compile: + compile_mode = cfg.network.compile_mode + if compile_mode in ("", None): + if cfg.network.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + update = torch.compile(update, mode=compile_mode) + if cfg.network.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + # Create off-policy collector + collector = make_collector( + cfg, + train_env, + model[0], + compile=compile_mode is not None, + compile_mode=compile_mode, + cudagraphs=cfg.network.cudagraphs, + ) # Main loop - start_time = time.time() collected_frames = 0 pbar = tqdm.tqdm(total=cfg.collector.total_frames) @@ -106,106 +152,72 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch - sampling_start = time.time() - for i, tensordict in enumerate(collector): - sampling_time = time.time() - sampling_start + c_iter = iter(collector) + for i in range(len(collector)): + with timeit("collecting"): + collected_data = next(c_iter) # Update weights of the inference policy collector.update_policy_weights_() + current_frames = collected_data.numel() - pbar.update(tensordict.numel()) + pbar.update(current_frames) - tensordict = tensordict.reshape(-1) - current_frames = tensordict.numel() - # Add to replay buffer - replay_buffer.extend(tensordict.cpu()) + collected_data = collected_data.reshape(-1) + with timeit("rb - extend"): + # Add to replay buffer + replay_buffer.extend(collected_data) collected_frames += current_frames # Optimization steps - training_start = time.time() if collected_frames >= init_random_frames: - ( - actor_losses, - q_losses, - alpha_losses, - ) = ([], [], []) + tds = [] for _ in range(num_updates): - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict.clone() - - # Compute loss - loss_out = loss_module(sampled_tensordict) - - actor_loss, q_loss, alpha_loss = ( - loss_out["loss_actor"], - loss_out["loss_qvalue"], - loss_out["loss_alpha"], - ) - - # Update critic - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - q_losses.append(q_loss.item()) + with timeit("rb - sample"): + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() - # Update actor - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + sampled_tensordict = sampled_tensordict.to(device) + loss_out = update(sampled_tensordict) - actor_losses.append(actor_loss.item()) - - # Update alpha - optimizer_alpha.zero_grad() - alpha_loss.backward() - optimizer_alpha.step() - - alpha_losses.append(alpha_loss.item()) - - # Update target params - target_net_updater.step() + tds.append(loss_out) # Update priority if prb: replay_buffer.update_priority(sampled_tensordict) + tds = torch.stack(tds).mean() - training_time = time.time() - training_start + # Logging episode_end = ( - tensordict["next", "done"] - if tensordict["next", "done"].any() - else tensordict["next", "truncated"] + collected_data["next", "done"] + if collected_data["next", "done"].any() + else collected_data["next", "truncated"] ) - episode_rewards = tensordict["next", "episode_reward"][episode_end] + episode_rewards = collected_data["next", "episode_reward"][episode_end] - # Logging metrics_to_log = {} if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][episode_end] + episode_length = collected_data["next", "step_count"][episode_end] metrics_to_log["train/reward"] = episode_rewards.mean().item() metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( episode_length ) if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) - metrics_to_log["train/a_loss"] = np.mean(actor_losses) - metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses) - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time + metrics_to_log["train/q_loss"] = tds["loss_qvalue"] + metrics_to_log["train/a_loss"] = tds["loss_actor"] + metrics_to_log["train/alpha_loss"] = tds["loss_alpha"] # Evaluation prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter cur_test_frame = (i * frames_per_batch) // eval_iter final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], @@ -213,22 +225,20 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + if i % 50 == 0: + metrics_to_log.update(timeit.todict(prefix="time")) + timeit.print() + timeit.erase() if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/discrete_sac/utils.py b/sota-implementations/discrete_sac/utils.py index 8051f07fe95..7055a00674b 100644 --- a/sota-implementations/discrete_sac/utils.py +++ b/sota-implementations/discrete_sac/utils.py @@ -111,7 +111,14 @@ def make_environment(cfg, logger=None): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector( + cfg, + train_env, + actor_model_explore, + compile=False, + compile_mode=None, + cudagraphs=False, +): """Make collector.""" device = cfg.collector.device if device in ("", None): @@ -129,6 +136,8 @@ def make_collector(cfg, train_env, actor_model_explore): reset_at_each_iter=cfg.collector.reset_at_each_iter, device=device, storing_device="cpu", + compile_policy=False if not compile else {"mode": compile_mode}, + cudagraph_policy=cudagraphs, ) collector.set_seed(cfg.env.seed) return collector From 1e0c02108263607949998853b70db4d00fd3b302 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 13 Dec 2024 13:38:57 -0800 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- sota-implementations/a2c/utils_atari.py | 6 +++--- sota-implementations/a2c/utils_mujoco.py | 4 ++-- sota-implementations/cql/utils.py | 2 +- sota-implementations/dreamer/dreamer_utils.py | 14 +++++++------- sota-implementations/gail/ppo_utils.py | 4 ++-- sota-implementations/impala/utils.py | 2 +- sota-implementations/iql/utils.py | 2 +- sota-implementations/ppo/utils_atari.py | 6 +++--- sota-implementations/ppo/utils_mujoco.py | 4 ++-- 9 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py index 379663ec95c..dee5d5f719b 100644 --- a/sota-implementations/a2c/utils_atari.py +++ b/sota-implementations/a2c/utils_atari.py @@ -92,12 +92,12 @@ def make_ppo_modules_pixels(proof_environment, device): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.single_action_spec.space, CategoricalBox): - num_outputs = proof_environment.single_action_spec.space.n + if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox): + num_outputs = proof_environment.action_spec_unbatched.space.n distribution_class = OneHotCategorical distribution_kwargs = {} else: # is ContinuousBox - num_outputs = proof_environment.single_action_spec.shape + num_outputs = proof_environment.action_spec_unbatched.shape distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low.to(device), diff --git a/sota-implementations/a2c/utils_mujoco.py b/sota-implementations/a2c/utils_mujoco.py index be058a07b22..953bf144080 100644 --- a/sota-implementations/a2c/utils_mujoco.py +++ b/sota-implementations/a2c/utils_mujoco.py @@ -53,7 +53,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.single_action_spec.shape[-1] + num_outputs = proof_environment.action_spec_unbatched.shape[-1] distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low.to(device), @@ -81,7 +81,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.single_action_spec.shape[-1], device=device + proof_environment.action_spec_unbatched.shape[-1], device=device ), ) diff --git a/sota-implementations/cql/utils.py b/sota-implementations/cql/utils.py index 611a31e585d..c262a46e563 100644 --- a/sota-implementations/cql/utils.py +++ b/sota-implementations/cql/utils.py @@ -296,7 +296,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"): def make_cql_modules_state(model_cfg, proof_environment): - action_spec = proof_environment.single_action_spec + action_spec = proof_environment.action_spec_unbatched actor_net_kwargs = { "num_cells": model_cfg.hidden_sizes, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 77c425d8d35..7f3fd663976 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -473,12 +473,12 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): spec=Composite( **{ "loc": Unbounded( - proof_environment.single_action_spec.shape, - device=proof_environment.single_action_spec.device, + proof_environment.action_spec_unbatched.shape, + device=proof_environment.action_spec_unbatched.device, ), "scale": Unbounded( - proof_environment.single_action_spec.shape, - device=proof_environment.single_action_spec.device, + proof_environment.action_spec_unbatched.shape, + device=proof_environment.action_spec_unbatched.device, ), } ), @@ -489,7 +489,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module): default_interaction_type=InteractionType.RANDOM, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, - spec=Composite(**{action_key: proof_environment.single_action_spec}), + spec=Composite(**{action_key: proof_environment.action_spec_unbatched}), ), ) return actor_simulator @@ -530,10 +530,10 @@ def _dreamer_make_actor_real( spec=Composite( **{ "loc": Unbounded( - proof_environment.single_action_spec.shape, + proof_environment.action_spec_unbatched.shape, ), "scale": Unbounded( - proof_environment.single_action_spec.shape, + proof_environment.action_spec_unbatched.shape, ), } ), diff --git a/sota-implementations/gail/ppo_utils.py b/sota-implementations/gail/ppo_utils.py index 2a443c11f1a..f2966ebcc7f 100644 --- a/sota-implementations/gail/ppo_utils.py +++ b/sota-implementations/gail/ppo_utils.py @@ -48,7 +48,7 @@ def make_ppo_models_state(proof_environment): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.single_action_spec.shape[-1] + num_outputs = proof_environment.action_spec_unbatched.shape[-1] distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low, @@ -74,7 +74,7 @@ def make_ppo_models_state(proof_environment): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.single_action_spec.shape[-1], scale_lb=1e-8 + proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8 ), ) diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py index 7ed16313176..d5f6157d237 100644 --- a/sota-implementations/impala/utils.py +++ b/sota-implementations/impala/utils.py @@ -68,7 +68,7 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - num_outputs = proof_environment.single_action_spec.space.n + num_outputs = proof_environment.action_spec_unbatched.space.n distribution_class = OneHotCategorical distribution_kwargs = {} diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 617974752b0..a1b9ed54db6 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -247,7 +247,7 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): def make_iql_modules_state(model_cfg, proof_environment): - action_spec = proof_environment.single_action_spec + action_spec = proof_environment.action_spec_unbatched actor_net_kwargs = { "num_cells": model_cfg.hidden_sizes, diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py index 75e13fd4d3a..ec9cc3cd1d5 100644 --- a/sota-implementations/ppo/utils_atari.py +++ b/sota-implementations/ppo/utils_atari.py @@ -91,12 +91,12 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.single_action_spec.space, CategoricalBox): - num_outputs = proof_environment.single_action_spec.space.n + if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox): + num_outputs = proof_environment.action_spec_unbatched.space.n distribution_class = OneHotCategorical distribution_kwargs = {} else: # is ContinuousBox - num_outputs = proof_environment.single_action_spec.shape + num_outputs = proof_environment.action_spec_unbatched.shape distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low, diff --git a/sota-implementations/ppo/utils_mujoco.py b/sota-implementations/ppo/utils_mujoco.py index 2a443c11f1a..f2966ebcc7f 100644 --- a/sota-implementations/ppo/utils_mujoco.py +++ b/sota-implementations/ppo/utils_mujoco.py @@ -48,7 +48,7 @@ def make_ppo_models_state(proof_environment): input_shape = proof_environment.observation_spec["observation"].shape # Define policy output distribution class - num_outputs = proof_environment.single_action_spec.shape[-1] + num_outputs = proof_environment.action_spec_unbatched.shape[-1] distribution_class = TanhNormal distribution_kwargs = { "low": proof_environment.action_spec_unbatched.space.low, @@ -74,7 +74,7 @@ def make_ppo_models_state(proof_environment): policy_mlp = torch.nn.Sequential( policy_mlp, AddStateIndependentNormalScale( - proof_environment.single_action_spec.shape[-1], scale_lb=1e-8 + proof_environment.action_spec_unbatched.shape[-1], scale_lb=1e-8 ), ) From 95ffff36596095d091e79b0eb43fb67eb721b8ea Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 13 Dec 2024 14:13:58 -0800 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- sota-implementations/discrete_sac/discrete_sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 1fa3271da82..8b3efe15102 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -180,7 +180,7 @@ def update(sampled_tensordict): with timeit("update"): torch.compiler.cudagraph_mark_step_begin() sampled_tensordict = sampled_tensordict.to(device) - loss_out = update(sampled_tensordict) + loss_out = update(sampled_tensordict).clone() tds.append(loss_out)