[Feature] GAIL compatibility with compile #2573

Merged Dec 15, 2024 · 65 commits
Changes from all commits
Commits 60c58de through 73b5916: 65 commits, all titled "Update", by vmoens, Nov 18 to Dec 15, 2024.
2 changes: 1 addition & 1 deletion sota-implementations/a2c/utils_atari.py
@@ -152,7 +152,7 @@ def make_ppo_modules_pixels(proof_environment, device):
policy_module = ProbabilisticActor(
policy_module,
in_keys=["logits"],
- spec=proof_environment.single_full_action_spec.to(device),
+ spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
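The same one-line rename recurs in utils_mujoco.py and dreamer_utils.py below: the batched single_full_action_spec property gives way to full_action_spec_unbatched. As a rough illustration of where that property plugs in (not code from this PR; the policy module and distribution settings here are placeholders):

# Illustrative sketch only: `policy_net` and the TanhNormal settings are
# stand-ins, not taken from this PR.
from torchrl.modules import ProbabilisticActor, TanhNormal

def wrap_policy(policy_net, proof_environment, device):
    return ProbabilisticActor(
        policy_net,
        in_keys=["loc", "scale"],
        # formerly proof_environment.single_full_action_spec
        spec=proof_environment.full_action_spec_unbatched.to(device),
        distribution_class=TanhNormal,
        return_log_prob=True,
    )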
2 changes: 1 addition & 1 deletion sota-implementations/a2c/utils_mujoco.py
@@ -94,7 +94,7 @@ def make_ppo_models_state(proof_environment, device, *, compile: bool = False):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
- spec=proof_environment.single_full_action_spec.to(device),
+ spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
2 changes: 1 addition & 1 deletion sota-implementations/cql/discrete_cql_config.yaml
@@ -14,7 +14,7 @@ collector:
multi_step: 0
init_random_frames: 1000
env_per_collector: 1
- device: cpu
+ device:
max_frames_per_traj: 200
annealing_frames: 10000
eps_start: 1.0
2 changes: 1 addition & 1 deletion sota-implementations/cql/online_config.yaml
@@ -15,7 +15,7 @@ collector:
multi_step: 0
init_random_frames: 5_000
env_per_collector: 1
- device: cpu
+ device:
max_frames_per_traj: 1000


8 changes: 7 additions & 1 deletion sota-implementations/cql/utils.py
@@ -124,14 +124,20 @@ def make_collector(
cudagraph=False,
):
"""Make collector."""
+ device = cfg.collector.device
+ if device in ("", None):
+     if torch.cuda.is_available():
+         device = torch.device("cuda:0")
+     else:
+         device = torch.device("cpu")
collector = SyncDataCollector(
train_env,
actor_model_explore,
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
max_frames_per_traj=cfg.collector.max_frames_per_traj,
total_frames=cfg.collector.total_frames,
- device=cfg.collector.device,
+ device=device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
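The two CQL configs above now leave collector.device empty, and make_collector resolves the device at run time. A minimal standalone sketch of that fallback (the helper name is ours, not the PR's):

# Sketch of the device fallback added to make_collector: an empty or missing
# config value defers the choice to runtime, preferring CUDA when available.
import torch

def resolve_collector_device(cfg_device):
    if cfg_device in ("", None):
        return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    return torch.device(cfg_device)

# e.g. resolve_collector_device(cfg.collector.device) returns cuda:0 on a GPU machine.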
2 changes: 1 addition & 1 deletion sota-implementations/dreamer/dreamer_utils.py
@@ -546,7 +546,7 @@ def _dreamer_make_actor_real(
default_interaction_type=InteractionType.DETERMINISTIC,
distribution_class=TanhNormal,
distribution_kwargs={"tanh_loc": True},
- spec=proof_environment.single_full_action_spec.to("cpu"),
+ spec=proof_environment.full_action_spec_unbatched.to("cpu"),
),
),
SafeModule(
5 changes: 5 additions & 0 deletions sota-implementations/gail/config.yaml
@@ -41,6 +41,11 @@ gail:
gp_lambda: 10.0
device: null

+ compile:
+   compile: False
+   compile_mode: default
+   cudagraphs: False

replay_buffer:
dataset: halfcheetah-expert-v2
batch_size: 256
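gail.py (below) reads this new compile block to decide how the policy and the update step are compiled. A small sketch of the mode-resolution rule it applies (the helper name is illustrative):

# Sketch of the compile-mode resolution used below: an explicit mode wins;
# otherwise "default" is enough when CUDA graphs already remove launch
# overhead, and "reduce-overhead" is used for plain torch.compile.
def resolve_compile_mode(compile_enabled, compile_mode, cudagraphs):
    if not compile_enabled:
        return None
    if compile_mode in ("", None):
        return "default" if cudagraphs else "reduce-overhead"
    return compile_mode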
185 changes: 113 additions & 72 deletions sota-implementations/gail/gail.py
@@ -11,25 +11,33 @@
"""
from __future__ import annotations

+ import warnings

import hydra
import numpy as np
import torch
import tqdm

from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
from ppo_utils import eval_model, make_env, make_ppo_models
+ from tensordict.nn import CudaGraphModule

+ from torchrl._utils import compile_with_warmup
from torchrl.collectors import SyncDataCollector
- from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+ from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
- from torchrl.objectives import ClipPPOLoss, GAILLoss
+ from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger


+ torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.backend).set()
@@ -71,25 +79,20 @@ def main(cfg: "DictConfig"): # noqa: F821
np.random.seed(cfg.env.seed)

# Create models (check utils_mujoco.py)
- actor, critic = make_ppo_models(cfg.env.env_name)
- actor, critic = actor.to(device), critic.to(device)
-
- # Create collector
- collector = SyncDataCollector(
-     create_env_fn=make_env(cfg.env.env_name, device),
-     policy=actor,
-     frames_per_batch=cfg.ppo.collector.frames_per_batch,
-     total_frames=cfg.ppo.collector.total_frames,
-     device=device,
-     storing_device=device,
-     max_frames_per_traj=-1,
+ actor, critic = make_ppo_models(
+     cfg.env.env_name, compile=cfg.compile.compile, device=device
+ )

# Create data buffer
data_buffer = TensorDictReplayBuffer(
- storage=LazyMemmapStorage(cfg.ppo.collector.frames_per_batch),
+ storage=LazyTensorStorage(
+     cfg.ppo.collector.frames_per_batch,
+     device=device,
+     compilable=cfg.compile.compile,
+ ),
sampler=SamplerWithoutReplacement(),
batch_size=cfg.ppo.loss.mini_batch_size,
+ compilable=cfg.compile.compile,
)

# Create loss and adv modules
@@ -98,6 +101,7 @@ def main(cfg: "DictConfig"): # noqa: F821
lmbda=cfg.ppo.loss.gae_lambda,
value_network=critic,
average_gae=False,
+ device=device,
)

loss_module = ClipPPOLoss(
@@ -111,8 +115,35 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create optimizers
- actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
- critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
+ actor_optim = torch.optim.Adam(
+     actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+ )
+ critic_optim = torch.optim.Adam(
+     critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
+ )
+ optim = group_optimizers(actor_optim, critic_optim)
+ del actor_optim, critic_optim

+ compile_mode = None
+ if cfg.compile.compile:
+     compile_mode = cfg.compile.compile_mode
+     if compile_mode in ("", None):
+         if cfg.compile.cudagraphs:
+             compile_mode = "default"
+         else:
+             compile_mode = "reduce-overhead"

+ # Create collector
+ collector = SyncDataCollector(
+     create_env_fn=make_env(cfg.env.env_name, device),
+     policy=actor,
+     frames_per_batch=cfg.ppo.collector.frames_per_batch,
+     total_frames=cfg.ppo.collector.total_frames,
+     device=device,
+     max_frames_per_traj=-1,
+     compile_policy={"mode": compile_mode} if compile_mode is not None else False,
+     cudagraph_policy=cfg.compile.cudagraphs,
+ )

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
@@ -140,32 +171,9 @@ def main(cfg: "DictConfig"): # noqa: F821
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
)
test_env.eval()
+ num_network_updates = torch.zeros((), dtype=torch.int64, device=device)

- # Training loop
- collected_frames = 0
- num_network_updates = 0
- pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)
-
- # extract cfg variables
- cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
- cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
- cfg_optim_lr = cfg.ppo.optim.lr
- cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
- cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
- cfg_logger_test_interval = cfg.logger.test_interval
- cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
-
- for i, data in enumerate(collector):
-
-     log_info = {}
-     frames_in_batch = data.numel()
-     collected_frames += frames_in_batch
-     pbar.update(data.numel())
-
-     # Update discriminator
-     # Get expert data
-     expert_data = replay_buffer.sample()
-     expert_data = expert_data.to(device)
+ def update(data, expert_data, num_network_updates=num_network_updates):
# Add collector data to expert data
expert_data.set(
discriminator_loss.tensor_keys.collector_action,
@@ -178,9 +186,9 @@ def main(cfg: "DictConfig"): # noqa: F821
d_loss = discriminator_loss(expert_data)

# Backward pass
- discriminator_optim.zero_grad()
d_loss.get("loss").backward()
discriminator_optim.step()
+ discriminator_optim.zero_grad(set_to_none=True)

# Compute discriminator reward
with torch.no_grad():
@@ -190,40 +198,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set discriminator rewards to tensordict
data.set(("next", "reward"), d_rewards)

- # Get training rewards and episode lengths
- episode_rewards = data["next", "episode_reward"][data["next", "done"]]
- if len(episode_rewards) > 0:
-     episode_length = data["next", "step_count"][data["next", "done"]]
-     log_info.update(
-         {
-             "train/reward": episode_rewards.mean().item(),
-             "train/episode_length": episode_length.sum().item()
-             / len(episode_length),
-         }
-     )
# Update PPO
for _ in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.empty()
data_buffer.extend(data_reshape)

- for _, batch in enumerate(data_buffer):
-
-     # Get a data batch
-     batch = batch.to(device)
+ for batch in data_buffer:
+     optim.zero_grad(set_to_none=True)

# Linearly decrease the learning rate and clip epsilon
- alpha = 1.0
+ alpha = torch.ones((), device=device)
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
- for group in actor_optim.param_groups:
-     group["lr"] = cfg_optim_lr * alpha
- for group in critic_optim.param_groups:
+ for group in optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
@@ -235,20 +228,68 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_loss = loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
- actor_loss.backward()
- critic_loss.backward()
+ (actor_loss + critic_loss).backward()

# Update the networks
- actor_optim.step()
- critic_optim.step()
- actor_optim.zero_grad()
- critic_optim.zero_grad()
+ optim.step()
+ return {"dloss": d_loss, "alpha": alpha}

+ if cfg.compile.compile:
+     update = compile_with_warmup(update, warmup=2, mode=compile_mode)
+     if cfg.compile.cudagraphs:
+         warnings.warn(
+             "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
+             category=UserWarning,
+         )
+         update = CudaGraphModule(update, warmup=50)

+ # Training loop
+ collected_frames = 0
+ pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)
+
+ # extract cfg variables
+ cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
+ cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
+ cfg_optim_lr = cfg.ppo.optim.lr
+ cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
+ cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
+ cfg_logger_test_interval = cfg.logger.test_interval
+ cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

+ for i, data in enumerate(collector):

+     log_info = {}
+     frames_in_batch = data.numel()
+     collected_frames += frames_in_batch
+     pbar.update(data.numel())

+     # Update discriminator
+     # Get expert data
+     expert_data = replay_buffer.sample()
+     expert_data = expert_data.to(device)

+     metadata = update(data, expert_data)
+     d_loss = metadata["dloss"]
+     alpha = metadata["alpha"]

+     # Get training rewards and episode lengths
+     episode_rewards = data["next", "episode_reward"][data["next", "done"]]
+     if len(episode_rewards) > 0:
+         episode_length = data["next", "step_count"][data["next", "done"]]
+
+         log_info.update(
+             {
+                 "train/reward": episode_rewards.mean().item(),
+                 "train/episode_length": episode_length.sum().item()
+                 / len(episode_length),
+             }
+         )

log_info.update(
{
"train/actor_loss": actor_loss.item(),
"train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"].item(),
# "train/actor_loss": actor_loss.item(),
# "train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"],
"train/lr": alpha * cfg_optim_lr,
"train/clip_epsilon": (
alpha * cfg_loss_clip_epsilon
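Two patterns in the gail.py diff are worth isolating. First, the replay buffer moves from a memmap-backed storage to an on-device, compile-friendly LazyTensorStorage; a condensed sketch with illustrative sizes (not the script's actual config values):

# Condensed from the diff above; the capacity and batch size are illustrative.
import torch
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1024, device=device, compilable=True),  # samples stay on device
    sampler=SamplerWithoutReplacement(),
    batch_size=256,
    compilable=True,  # lets the buffer be used inside a compiled region
)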
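Second, the optimizer handling and the compile/cudagraph wrapping of the update step. The sketch below reproduces the shape of that logic under stated assumptions: compute_loss is a stand-in for the PPO/GAIL losses built in the script, and the factory name make_update is ours.

# Sketch of the grouped-optimizer + compiled-update pattern from the diff.
import warnings
import torch
from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup
from torchrl.objectives import group_optimizers

def make_update(actor, critic, compute_loss, lr, device, compile_mode=None, cudagraphs=False):
    # lr as a device tensor, as in the diff, so it can be annealed in-graph
    actor_optim = torch.optim.Adam(actor.parameters(), lr=torch.tensor(lr, device=device))
    critic_optim = torch.optim.Adam(critic.parameters(), lr=torch.tensor(lr, device=device))
    optim = group_optimizers(actor_optim, critic_optim)  # single step()/zero_grad() for both

    def update(batch):
        optim.zero_grad(set_to_none=True)
        loss = compute_loss(batch)
        loss.backward()
        optim.step()
        return loss.detach()

    if compile_mode is not None:
        update = compile_with_warmup(update, warmup=2, mode=compile_mode)
        if cudagraphs:
            warnings.warn(
                "CudaGraphModule is experimental and may lead to silently wrong results.",
                category=UserWarning,
            )
            update = CudaGraphModule(update, warmup=50)
    return update

Keeping the whole update inside one callable is what allows torch.compile (and optionally CUDA graph capture) to cover the backward pass and optimizer step as well as the forward computation.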