Skip to content

Commit

Permalink
[Feature] Log pbar rate in SOTA implementations
Browse files Browse the repository at this point in the history
ghstack-source-id: 110e906b617f644465ae4ff1360d8b644bf5be6f
Pull Request resolved: #2662
  • Loading branch information
vmoens committed Dec 17, 2024
1 parent 91064bc commit 0a231b0
Show file tree
Hide file tree
Showing 29 changed files with 156 additions and 141 deletions.
15 changes: 8 additions & 7 deletions sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
with timeit("collecting"):
data = next(c_iter)

log_info = {}
metrics_to_log = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch * frame_skip
pbar.update(data.numel())
Expand All @@ -198,7 +198,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
episode_rewards = data["next", "episode_reward"][data["next", "terminated"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "terminated"]]
log_info.update(
metrics_to_log.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
Expand Down Expand Up @@ -242,8 +242,8 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
losses = torch.stack(losses).float().mean()

for key, value in losses.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
metrics_to_log.update({f"train/{key}": value.item()})
metrics_to_log.update(
{
"train/lr": lr * alpha,
}
Expand All @@ -259,15 +259,16 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
test_rewards = eval_model(
actor_eval, test_env, num_episodes=cfg.logger.num_test_episodes
)
log_info.update(
metrics_to_log.update(
{
"test/reward": test_rewards.mean(),
}
)
log_info.update(timeit.todict(prefix="time"))

if logger:
for key, value in log_info.items():
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
for key, value in metrics_to_log.items():
logger.log_scalar(key, value, collected_frames)

collector.shutdown()
Expand Down
18 changes: 8 additions & 10 deletions sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def update(batch):
with timeit("collecting"):
data = next(c_iter)

log_info = {}
metrics_to_log = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())
Expand All @@ -195,7 +195,7 @@ def update(batch):
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
log_info.update(
metrics_to_log.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
Expand Down Expand Up @@ -236,8 +236,8 @@ def update(batch):
# Get training losses
losses = torch.stack(losses).float().mean()
for key, value in losses.items():
log_info.update({f"train/{key}": value.item()})
log_info.update(
metrics_to_log.update({f"train/{key}": value.item()})
metrics_to_log.update(
{
"train/lr": alpha * cfg.optim.lr,
}
Expand All @@ -253,21 +253,19 @@ def update(batch):
test_rewards = eval_model(
actor, test_env, num_episodes=cfg.logger.num_test_episodes
)
log_info.update(
metrics_to_log.update(
{
"test/reward": test_rewards.mean(),
}
)
actor.train()

log_info.update(timeit.todict(prefix="time"))

if logger:
for key, value in log_info.items():
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
for key, value in metrics_to_log.items():
logger.log_scalar(key, value, collected_frames)

torch.compiler.cudagraph_mark_step_begin()

collector.shutdown()
if not test_env.is_closed:
test_env.close()
Expand Down
9 changes: 5 additions & 4 deletions sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def update(data, policy_eval_start, iteration):
)

# log metrics
to_log = {
metrics_to_log = {
"loss": loss.cpu(),
**loss_vals.cpu(),
}
Expand All @@ -188,11 +188,12 @@ def update(data, policy_eval_start, iteration):
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
metrics_to_log["evaluation_reward"] = eval_reward

with timeit("log"):
to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, to_log, i)
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, i)

pbar.close()
if not eval_env.is_closed:
Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def update(sampled_tensordict):
"loss_alpha_prime"
).mean()
metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
metrics_to_log.update(timeit.todict(prefix="time"))

# Evaluation
with timeit("eval"):
Expand All @@ -241,6 +240,8 @@ def update(sampled_tensordict):
eval_env.apply(dump_video)
metrics_to_log["eval/reward"] = eval_reward

metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,10 @@ def update(sampled_tensordict):
tds = torch.stack(tds, dim=0).mean()
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/cql_loss"] = tds["loss_cql"]
metrics_to_log.update(timeit.todict(prefix="time"))

if logger is not None:
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def make_offline_replay_buffer(rb_cfg):
dataset_id=rb_cfg.dataset,
split_trajs=False,
batch_size=rb_cfg.batch_size,
sampler=SamplerWithoutReplacement(drop_last=False),
sampler=SamplerWithoutReplacement(drop_last=True),
prefetch=4,
direct_download=True,
)
Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
metrics_to_log.update(timeit.todict(prefix="time"))
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/actor_loss"] = tds["loss_actor"]
metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

if logger is not None:
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,10 @@ def update(sampled_tensordict):
eval_env.apply(dump_video)
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log.update(timeit.todict(prefix="time"))

if logger is not None:
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
Expand Down
13 changes: 8 additions & 5 deletions sota-implementations/decision_transformer/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def main(cfg: "DictConfig"): # noqa: F821
loss_module = make_dt_loss(cfg.loss, actor, device=model_device)

# Create optimizer
transformer_optim, scheduler = make_dt_optimizer(cfg.optim, loss_module)
transformer_optim, scheduler = make_dt_optimizer(
cfg.optim, loss_module, model_device
)

# Create inference policy
inference_policy = DecisionTransformerInferenceWrapper(
Expand Down Expand Up @@ -136,7 +138,7 @@ def update(data: TensorDict) -> TensorDict:
loss_vals = update(data)
scheduler.step()
# Log metrics
to_log = {"train/loss": loss_vals["loss"]}
metrics_to_log = {"train/loss": loss_vals["loss"]}

# Evaluation
with set_exploration_type(
Expand All @@ -149,13 +151,14 @@ def update(data: TensorDict) -> TensorDict:
auto_cast_to_device=True,
)
test_env.apply(dump_video)
to_log["eval/reward"] = (
metrics_to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)
to_log.update(timeit.todict(prefix="time"))

if logger is not None:
log_metrics(logger, to_log, i)
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, i)

pbar.close()
if not test_env.is_closed:
Expand Down
10 changes: 5 additions & 5 deletions sota-implementations/decision_transformer/online_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def update(data):
scheduler.step()

# Log metrics
to_log = {
metrics_to_log = {
"train/loss_log_likelihood": loss_vals["loss_log_likelihood"],
"train/loss_entropy": loss_vals["loss_entropy"],
"train/loss_alpha": loss_vals["loss_alpha"],
Expand All @@ -165,14 +165,14 @@ def update(data):
)
test_env.apply(dump_video)
inference_policy.train()
to_log["eval/reward"] = (
metrics_to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)

to_log.update(timeit.todict(prefix="time"))

if logger is not None:
log_metrics(logger, to_log, i)
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, i)

pbar.close()
if not test_env.is_closed:
Expand Down
4 changes: 2 additions & 2 deletions sota-implementations/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,10 @@ def make_odt_optimizer(optim_cfg, loss_module):
return dt_optimizer, log_temp_optimizer, scheduler


def make_dt_optimizer(optim_cfg, loss_module):
def make_dt_optimizer(optim_cfg, loss_module, device):
dt_optimizer = torch.optim.Adam(
loss_module.actor_network_params.flatten_keys().values(),
lr=torch.as_tensor(optim_cfg.lr),
lr=torch.tensor(optim_cfg.lr, device=device),
weight_decay=optim_cfg.weight_decay,
eps=1.0e-8,
)
Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,9 @@ def update(sampled_tensordict):
eval_env.apply(dump_video)
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log.update(timeit.todict(prefix="time"))
if logger is not None:
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log["time/speed"] = pbar.format_dict["rate"]
log_metrics(logger, metrics_to_log, collected_frames)

collector.shutdown()
Expand Down
6 changes: 3 additions & 3 deletions sota-implementations/dqn/config_atari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ env:
# collector
collector:
total_frames: 40_000_100
frames_per_batch: 16
frames_per_batch: 1600
eps_start: 1.0
eps_end: 0.01
annealing_frames: 4_000_000
Expand Down Expand Up @@ -38,9 +38,9 @@ optim:
loss:
gamma: 0.99
hard_update_freq: 10_000
num_updates: 1
num_updates: 100

compile:
compile: False
compile_mode:
compile_mode: default
cudagraphs: False
4 changes: 2 additions & 2 deletions sota-implementations/dqn/config_cartpole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ env:
# collector
collector:
total_frames: 500_100
frames_per_batch: 10
frames_per_batch: 1000
eps_start: 1.0
eps_end: 0.05
annealing_frames: 250_000
Expand Down Expand Up @@ -37,7 +37,7 @@ optim:
loss:
gamma: 0.99
hard_update_freq: 50
num_updates: 1
num_updates: 100

compile:
compile: False
Expand Down
Loading

0 comments on commit 0a231b0

Please sign in to comment.