[Feature] Discrete SAC compatibility with compile #2569

Merged · 42 commits · Dec 14, 2024
6 changes: 3 additions & 3 deletions sota-implementations/a2c/utils_atari.py
@@ -93,12 +93,12 @@ def make_ppo_modules_pixels(proof_environment, device):
input_shape = proof_environment.observation_spec["pixels"].shape

# Define distribution class and kwargs
if isinstance(proof_environment.single_action_spec.space, CategoricalBox):
num_outputs = proof_environment.single_action_spec.space.n
if isinstance(proof_environment.action_spec_unbatched.space, CategoricalBox):
num_outputs = proof_environment.action_spec_unbatched.space.n
distribution_class = OneHotCategorical
distribution_kwargs = {}
else: # is ContinuousBox
num_outputs = proof_environment.single_action_spec.shape
num_outputs = proof_environment.action_spec_unbatched.shape
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low.to(device),
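For reference, the accessor swap above (`single_action_spec` → `action_spec_unbatched`) can be illustrated with a small sketch. The environment choice and printed shapes below are assumptions for illustration, not part of this PR; both properties are expected to return the action spec without the environment's batch dimensions.

```python
# Minimal sketch (assumption: standard TorchRL batched-env semantics; CartPole and the
# example shapes are illustrative, not taken from this PR).
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(4, lambda: GymEnv("CartPole-v1"))
print(env.action_spec.shape)            # batched spec, e.g. torch.Size([4, 2])
print(env.action_spec_unbatched.shape)  # per-env spec, e.g. torch.Size([2])
```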
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
@@ -54,7 +54,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
input_shape = proof_environment.observation_spec["observation"].shape

# Define policy output distribution class
num_outputs = proof_environment.single_action_spec.shape[-1]
num_outputs = proof_environment.action_spec_unbatched.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low.to(device),
@@ -82,7 +82,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
policy_mlp = torch.nn.Sequential(
policy_mlp,
AddStateIndependentNormalScale(
proof_environment.single_action_spec.shape[-1], device=device
proof_environment.action_spec_unbatched.shape[-1], device=device
),
)

1 change: 0 additions & 1 deletion sota-implementations/cql/cql_online.py
@@ -172,7 +172,6 @@ def update(sampled_tensordict):
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
torch.compiler.cudagraph_mark_step_begin()
tensordict = next(c_iter)
pbar.update(tensordict.numel())
# update weights of the inference policy
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
@@ -298,7 +298,7 @@ def make_discretecql_model(cfg, train_env, eval_env, device="cpu"):


def make_cql_modules_state(model_cfg, proof_environment):
action_spec = proof_environment.single_action_spec
action_spec = proof_environment.action_spec_unbatched

actor_net_kwargs = {
"num_cells": model_cfg.hidden_sizes,
5 changes: 5 additions & 0 deletions sota-implementations/discrete_sac/config.yaml
@@ -44,6 +44,11 @@ network:
activation: relu
device: null

compile:
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
backend: wandb
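The new `compile` block above is plain Hydra/OmegaConf config. A rough sketch of how it resolves and how it could be flipped on from the command line follows; the override strings and dotlist usage are assumptions for illustration, not taken from this PR.

```python
# Rough sketch: the compile block as it loads from config.yaml, plus Hydra-style
# dotted overrides. All values here are illustrative assumptions.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    compile:
      compile: False
      compile_mode:      # null -> the training script picks a default
      cudagraphs: False
    """
)
# From the command line this would typically look like:
#   python discrete_sac.py compile.compile=True compile.compile_mode=reduce-overhead
cfg.merge_with_dotlist(["compile.compile=True", "compile.compile_mode=reduce-overhead"])
print(cfg.compile.compile, cfg.compile.compile_mode, cfg.compile.cudagraphs)
```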
170 changes: 90 additions & 80 deletions sota-implementations/discrete_sac/discrete_sac.py
@@ -10,19 +10,20 @@

The helper functions are coded in the utils.py associated with this script.
"""

from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger

from tensordict.nn import CudaGraphModule
from torchrl._utils import timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type

from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
@@ -75,9 +76,6 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create SAC loss
loss_module, target_net_updater = make_loss_module(cfg, model)

# Create off-policy collector
collector = make_collector(cfg, train_env, model[0])

# Create replay buffer
replay_buffer = make_replay_buffer(
batch_size=cfg.optim.batch_size,
@@ -91,9 +89,57 @@ def main(cfg: "DictConfig"): # noqa: F821
optimizer_actor, optimizer_critic, optimizer_alpha = make_optimizer(
cfg, loss_module
)
optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
del optimizer_actor, optimizer_critic, optimizer_alpha

def update(sampled_tensordict):
optimizer.zero_grad(set_to_none=True)

# Compute loss
loss_out = loss_module(sampled_tensordict)

actor_loss, q_loss, alpha_loss = (
loss_out["loss_actor"],
loss_out["loss_qvalue"],
loss_out["loss_alpha"],
)

# Update critic
(q_loss + actor_loss + alpha_loss).backward()
optimizer.step()

# Update target params
target_net_updater.step()

return loss_out.detach()

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Create off-policy collector
collector = make_collector(
cfg,
train_env,
model[0],
compile=compile_mode is not None,
compile_mode=compile_mode,
cudagraphs=cfg.compile.cudagraphs,
)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

@@ -108,129 +154,93 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_iter = cfg.logger.eval_iter
frames_per_batch = cfg.collector.frames_per_batch

sampling_start = time.time()
for i, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
collected_data = next(c_iter)

# Update weights of the inference policy
collector.update_policy_weights_()
current_frames = collected_data.numel()

pbar.update(tensordict.numel())
pbar.update(current_frames)

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
collected_data = collected_data.reshape(-1)
with timeit("rb - extend"):
# Add to replay buffer
replay_buffer.extend(collected_data)
collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
alpha_losses,
) = ([], [], [])
tds = []
for _ in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Compute loss
loss_out = loss_module(sampled_tensordict)

actor_loss, q_loss, alpha_loss = (
loss_out["loss_actor"],
loss_out["loss_qvalue"],
loss_out["loss_alpha"],
)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_losses.append(q_loss.item())
with timeit("rb - sample"):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()

# Update actor
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
sampled_tensordict = sampled_tensordict.to(device)
loss_out = update(sampled_tensordict).clone()

actor_losses.append(actor_loss.item())

# Update alpha
optimizer_alpha.zero_grad()
alpha_loss.backward()
optimizer_alpha.step()

alpha_losses.append(alpha_loss.item())

# Update target params
target_net_updater.step()
tds.append(loss_out)

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
tds = torch.stack(tds).mean()

training_time = time.time() - training_start
# Logging
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
else tensordict["next", "truncated"]
collected_data["next", "done"]
if collected_data["next", "done"].any()
else collected_data["next", "truncated"]
)
episode_rewards = tensordict["next", "episode_reward"][episode_end]
episode_rewards = collected_data["next", "episode_reward"][episode_end]

# Logging
metrics_to_log = {}
if len(episode_rewards) > 0:
episode_length = tensordict["next", "step_count"][episode_end]
episode_length = collected_data["next", "step_count"][episode_end]
metrics_to_log["train/reward"] = episode_rewards.mean().item()
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/alpha_loss"] = np.mean(alpha_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/a_loss"] = tds["loss_actor"]
metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

# Evaluation
prev_test_frame = ((i - 1) * frames_per_batch) // eval_iter
cur_test_frame = (i * frames_per_batch) // eval_iter
final = current_frames >= collector.total_frames
if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
model[0],
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if i % 50 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
end_time = time.time()
execution_time = end_time - start_time
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
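The core change in this script is the fused `update()` step that gets `torch.compile`'d and optionally captured by `CudaGraphModule`. Below is a condensed, self-contained sketch of that pattern; the toy module, data, and hyper-parameters are placeholders, and the fallback from `"reduce-overhead"` to `"default"` under cudagraphs mirrors the diff (presumably to avoid nesting CUDA-graph capture).

```python
# Condensed sketch of the compile/cudagraphs pattern above, on a toy regression
# problem. Only the wrapping logic mirrors the diff; everything else is a placeholder.
import warnings

import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = torch.nn.Linear(8, 1, device=device)
optimizer = torch.optim.Adam(module.parameters())


def update(td: TensorDict) -> TensorDict:
    optimizer.zero_grad(set_to_none=True)
    loss = (module(td["obs"]) - td["target"]).pow(2).mean()
    loss.backward()
    optimizer.step()
    return TensorDict({"loss": loss.detach()})


use_compile = True
use_cudagraphs = torch.cuda.is_available()
if use_compile:
    # "reduce-overhead" already uses CUDA graphs internally, so fall back to
    # "default" when also wrapping with CudaGraphModule (same choice as the diff).
    mode = "default" if use_cudagraphs else "reduce-overhead"
    update = torch.compile(update, mode=mode)
    if use_cudagraphs:
        warnings.warn("CudaGraphModule is experimental.", category=UserWarning)
        update = CudaGraphModule(update, warmup=50)

for _ in range(5):
    torch.compiler.cudagraph_mark_step_begin()
    batch = TensorDict(
        {"obs": torch.randn(32, 8), "target": torch.randn(32, 1)}, batch_size=[32]
    ).to(device)
    # Clone the output: CUDA-graph outputs are reused buffers (the diff does the same).
    print(update(batch).clone()["loss"].item())
```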
11 changes: 10 additions & 1 deletion sota-implementations/discrete_sac/utils.py
@@ -113,7 +113,14 @@ def make_environment(cfg, logger=None):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(
cfg,
train_env,
actor_model_explore,
compile=False,
compile_mode=None,
cudagraphs=False,
):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
@@ -131,6 +138,8 @@ def make_collector(cfg, train_env, actor_model_explore):
reset_at_each_iter=cfg.collector.reset_at_each_iter,
device=device,
storing_device="cpu",
compile_policy=False if not compile else {"mode": compile_mode},
cudagraph_policy=cudagraphs,
)
collector.set_seed(cfg.env.seed)
return collector
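The `compile_policy` / `cudagraph_policy` values are forwarded straight to the underlying collector. A hedged usage sketch follows; the collector class is not visible in this hunk, so `SyncDataCollector`, the environment, and the tiny actor are assumptions for illustration.

```python
# Hedged sketch of the two new collector knobs. Env, actor, and frame counts are
# placeholders; only compile_policy / cudagraph_policy mirror the diff above.
import torch
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
# Pendulum-v1: 3 observation features -> 1 bounded action.
actor = TensorDictModule(
    torch.nn.Sequential(torch.nn.Linear(3, 1), torch.nn.Tanh()),
    in_keys=["observation"],
    out_keys=["action"],
)
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=200,
    total_frames=1_000,
    compile_policy={"mode": "reduce-overhead"},  # or False to leave the policy eager
    cudagraph_policy=False,  # True only makes sense on a CUDA device
)
for batch in collector:
    print(batch["next", "reward"].mean().item())
collector.shutdown()
```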
14 changes: 7 additions & 7 deletions sota-implementations/dreamer/dreamer_utils.py
@@ -475,12 +475,12 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
spec=Composite(
**{
"loc": Unbounded(
proof_environment.single_action_spec.shape,
device=proof_environment.single_action_spec.device,
proof_environment.action_spec_unbatched.shape,
device=proof_environment.action_spec_unbatched.device,
),
"scale": Unbounded(
proof_environment.single_action_spec.shape,
device=proof_environment.single_action_spec.device,
proof_environment.action_spec_unbatched.shape,
device=proof_environment.action_spec_unbatched.device,
),
}
),
@@ -491,7 +491,7 @@ def _dreamer_make_actor_sim(action_key, proof_environment, actor_module):
default_interaction_type=InteractionType.RANDOM,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
spec=Composite(**{action_key: proof_environment.single_action_spec}),
spec=Composite(**{action_key: proof_environment.action_spec_unbatched}),
),
)
return actor_simulator
@@ -532,10 +532,10 @@ def _dreamer_make_actor_real(
spec=Composite(
**{
"loc": Unbounded(
proof_environment.single_action_spec.shape,
proof_environment.action_spec_unbatched.shape,
),
"scale": Unbounded(
proof_environment.single_action_spec.shape,
proof_environment.action_spec_unbatched.shape,
),
}
),