[Feature] timeit.printevery #2653

Merged · 25 commits · Dec 16, 2024
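
This PR replaces the manual `if i % N == 0: timeit.print(); timeit.erase()` blocks scattered across the SOTA scripts with a single `timeit.printevery(...)` call at the top of each training loop, and attaches the `timeit.todict(prefix="time")` metrics to the logged info on every iteration. A minimal sketch of the resulting pattern, assuming the `timeit` utility from `torchrl._utils` exactly as used in the diffs below (the collector here is a stand-in list, not a real TorchRL collector):

    from torchrl._utils import timeit

    collector = [{"step": i} for i in range(5000)]  # stand-in for a data collector
    c_iter = iter(collector)

    total_iter = len(collector)
    for i in range(total_iter):
        # Print the accumulated timings every 1000 iterations and erase them afterwards,
        # replacing the old per-script `if i % N == 0: timeit.print(); timeit.erase()` blocks.
        timeit.printevery(1000, total_iter, erase=True)

        with timeit("collecting"):
            data = next(c_iter)

        log_info = {}
        # Timings are now attached to the logged metrics on every iteration.
        log_info.update(timeit.todict(prefix="time"))
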
10 changes: 5 additions & 5 deletions sota-implementations/a2c/a2c_atari.py
@@ -182,7 +182,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
lr = cfg.optim.lr

c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for i in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)

with timeit("collecting"):
data = next(c_iter)

@@ -261,10 +264,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
"test/reward": test_rewards.mean(),
}
)
if i % 200 == 0:
log_info.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
log_info.update(timeit.todict(prefix="time"))

if logger:
for key, value in log_info.items():
10 changes: 5 additions & 5 deletions sota-implementations/a2c/a2c_mujoco.py
@@ -179,7 +179,10 @@ def update(batch):
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for i in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)

with timeit("collecting"):
data = next(c_iter)

@@ -257,10 +260,7 @@ def update(batch):
)
actor.train()

if i % 200 == 0:
log_info.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
log_info.update(timeit.todict(prefix="time"))

if logger:
for key, value in log_info.items():
12 changes: 3 additions & 9 deletions sota-implementations/cql/cql_offline.py
@@ -11,7 +11,6 @@
"""
from __future__ import annotations

import time
import warnings

import hydra
@@ -21,7 +20,7 @@
import tqdm
from tensordict.nn import CudaGraphModule

from torchrl._utils import logger as torchrl_logger, timeit
from torchrl._utils import timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -156,9 +155,9 @@ def update(data, policy_eval_start, iteration):
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
policy_eval_start = torch.tensor(policy_eval_start, device=device)
for i in range(gradient_steps):
timeit.printevery(1000, gradient_steps, erase=True)
pbar.update(1)
# sample data
with timeit("sample"):
@@ -192,15 +191,10 @@ def update(data, policy_eval_start, iteration):
to_log["evaluation_reward"] = eval_reward

with timeit("log"):
if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, to_log, i)
if i % 200 == 0:
timeit.print()
timeit.erase()

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
if not eval_env.is_closed:
eval_env.close()

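
Note that the offline scripts (cql_offline.py here, and online_dt.py below) also drop the ad-hoc wall-clock timing: `import time`, `start_time = time.time()` and the final "Training time" log disappear, since `timeit` already accumulates per-section timings. A hypothetical sketch of how a rough total could still be recovered, assuming `timeit.todict(percall=False)` returns accumulated seconds per section (as the dreamer.py change at the end of this diff suggests):

    from torchrl._utils import logger as torchrl_logger, timeit

    # Sum the per-section totals instead of keeping a separate time.time() counter.
    total_seconds = sum(timeit.todict(percall=False).values())
    torchrl_logger.info(f"Approximate training time: {total_seconds:.1f}s")
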
10 changes: 4 additions & 6 deletions sota-implementations/cql/cql_online.py
@@ -170,7 +170,9 @@ def update(sampled_tensordict):
eval_rollout_steps = cfg.logger.eval_steps

c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for i in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)
with timeit("collecting"):
tensordict = next(c_iter)
pbar.update(tensordict.numel())
@@ -222,8 +224,7 @@ def update(sampled_tensordict):
"loss_alpha_prime"
).mean()
metrics_to_log["train/entropy"] = log_loss_td.get("entropy").mean()
if i % 10 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log.update(timeit.todict(prefix="time"))

# Evaluation
with timeit("eval"):
@@ -245,9 +246,6 @@ def update(sampled_tensordict):
metrics_to_log["eval/reward"] = eval_reward

log_metrics(logger, metrics_to_log, collected_frames)
if i % 10 == 0:
timeit.print()
timeit.erase()

collector.shutdown()
if not eval_env.is_closed:
11 changes: 4 additions & 7 deletions sota-implementations/cql/discrete_cql_online.py
@@ -151,7 +151,9 @@ def update(sampled_tensordict):
frames_per_batch = cfg.collector.frames_per_batch

c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for _ in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)
with timeit("collecting"):
torch.compiler.cudagraph_mark_step_begin()
tensordict = next(c_iter)
@@ -224,12 +226,7 @@ def update(sampled_tensordict):
tds = torch.stack(tds, dim=0).mean()
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/cql_loss"] = tds["loss_cql"]
if i % 100 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))

if i % 100 == 0:
timeit.print()
timeit.erase()
metrics_to_log.update(timeit.todict(prefix="time"))

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
10 changes: 4 additions & 6 deletions sota-implementations/crossq/crossq.py
@@ -192,7 +192,9 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
update_counter = 0
delayed_updates = cfg.optim.policy_update_delay
c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for _ in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)
with timeit("collecting"):
torch.compiler.cudagraph_mark_step_begin()
tensordict = next(c_iter)
@@ -258,18 +260,14 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
metrics_to_log["train/episode_length"] = episode_length.sum().item() / len(
episode_length
)
if i % 20 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
metrics_to_log.update(timeit.todict(prefix="time"))
if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = tds["loss_qvalue"]
metrics_to_log["train/actor_loss"] = tds["loss_actor"]
metrics_to_log["train/alpha_loss"] = tds["loss_alpha"]

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
if i % 20 == 0:
timeit.print()
timeit.erase()

collector.shutdown()
if not eval_env.is_closed:
9 changes: 4 additions & 5 deletions sota-implementations/ddpg/ddpg.py
@@ -156,7 +156,9 @@ def update(sampled_tensordict):
eval_rollout_steps = cfg.env.max_episode_steps

c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for _ in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)
with timeit("collecting"):
tensordict = next(c_iter)
# Update exploration policy
@@ -226,10 +228,7 @@ def update(sampled_tensordict):
eval_env.apply(dump_video)
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
if i % 20 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
metrics_to_log.update(timeit.todict(prefix="time"))

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
6 changes: 2 additions & 4 deletions sota-implementations/decision_transformer/dt.py
@@ -128,6 +128,7 @@ def update(data: TensorDict) -> TensorDict:
# Pretraining
pbar = tqdm.tqdm(range(pretrain_gradient_steps))
for i in pbar:
timeit.printevery(1000, pretrain_gradient_steps, erase=True)
# Sample data
with timeit("rb - sample"):
data = offline_buffer.sample().to(model_device)
@@ -151,10 +152,7 @@ def update(data: TensorDict) -> TensorDict:
to_log["eval/reward"] = (
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)
if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
to_log.update(timeit.todict(prefix="time"))

if logger is not None:
log_metrics(logger, to_log, i)
9 changes: 2 additions & 7 deletions sota-implementations/decision_transformer/online_dt.py
@@ -8,7 +8,6 @@
"""
from __future__ import annotations

import time
import warnings

import hydra
@@ -130,8 +129,8 @@ def update(data):

torchrl_logger.info(" ***Pretraining*** ")
# Pretraining
start_time = time.time()
for i in range(pretrain_gradient_steps):
timeit.printevery(1000, pretrain_gradient_steps, erase=True)
pbar.update(1)
with timeit("sample"):
# Sample data
@@ -170,18 +169,14 @@ def update(data):
eval_td["next", "reward"].sum(1).mean().item() / reward_scaling
)

if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
to_log.update(timeit.todict(prefix="time"))

if logger is not None:
log_metrics(logger, to_log, i)

pbar.close()
if not test_env.is_closed:
test_env.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
9 changes: 4 additions & 5 deletions sota-implementations/discrete_sac/discrete_sac.py
@@ -155,7 +155,9 @@ def update(sampled_tensordict):
frames_per_batch = cfg.collector.frames_per_batch

c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for i in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)
with timeit("collecting"):
collected_data = next(c_iter)

@@ -229,10 +231,7 @@ def update(sampled_tensordict):
eval_env.apply(dump_video)
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
if i % 50 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()
metrics_to_log.update(timeit.todict(prefix="time"))
if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)

9 changes: 4 additions & 5 deletions sota-implementations/dqn/dqn_atari.py
@@ -173,7 +173,9 @@ def update(sampled_tensordict):
pbar = tqdm.tqdm(total=total_frames)

c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for i in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)
with timeit("collecting"):
data = next(c_iter)
log_info = {}
@@ -241,10 +243,7 @@ def update(sampled_tensordict):
)
model.train()

if i % 200 == 0:
timeit.print()
log_info.update(timeit.todict(prefix="time"))
timeit.erase()
log_info.update(timeit.todict(prefix="time"))

# Log all the information
if logger:
9 changes: 4 additions & 5 deletions sota-implementations/dqn/dqn_cartpole.py
@@ -156,7 +156,9 @@ def update(sampled_tensordict):
q_losses = torch.zeros(num_updates, device=device)

c_iter = iter(collector)
for i in range(len(collector)):
total_iter = len(collector)
for i in range(total_iter):
timeit.printevery(1000, total_iter, erase=True)
with timeit("collecting"):
data = next(c_iter)

@@ -226,10 +228,7 @@ def update(sampled_tensordict):
}
)

if i % 200 == 0:
timeit.print()
log_info.update(timeit.todict(prefix="time"))
timeit.erase()
log_info.update(timeit.todict(prefix="time"))

# Log all the information
if logger:
3 changes: 1 addition & 2 deletions sota-implementations/dreamer/dreamer.py
@@ -275,9 +275,8 @@ def compile_rssms(module):
"t_sample": t_sample,
"t_preproc": t_preproc,
"t_collect": t_collect,
**timeit.todict(percall=False),
**timeit.todict(prefix="time"),
}
timeit.erase()
metrics_to_log.update(loss_metrics)

if logger is not None: