apply less heavy profiling #270

Merged
merged 2 commits on Apr 25, 2024
28 changes: 17 additions & 11 deletions torchtitan/profiling.py
@@ -6,11 +6,15 @@

 import contextlib
 import os
+import time
 
 import torch
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging_utils import logger
 
+# the number of warmup steps before the active step in each profiling cycle
+WARMUP = 3
+
 
 @contextlib.contextmanager
 def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
@@ -21,42 +25,44 @@ def maybe_enable_profiling(config: JobConfig, *pos_args, **kwargs):
         dump_dir = config.job.dump_folder
         save_trace_dir = config.profiling.save_traces_folder
         trace_dir = os.path.join(dump_dir, save_trace_dir)
-        iter_frequency = config.profiling.profile_freq
+        profile_freq = config.profiling.profile_freq
 
         _global_iter_count = 0
 
         rank = torch.distributed.get_rank()
 
         def trace_handler(prof):
             nonlocal _global_iter_count
-            _global_iter_count += iter_frequency
+            _global_iter_count += profile_freq
             curr_trace_dir_name = "iteration_" + str(_global_iter_count)
             curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
             if not os.path.exists(curr_trace_dir):
                 os.makedirs(curr_trace_dir, exist_ok=True)
 
             logger.info(f"Dumping traces at step {_global_iter_count}")
+            begin = time.monotonic()
             prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
+            logger.info(
+                f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds"
+            )
 
         logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
 
         if not os.path.exists(trace_dir):
             os.makedirs(trace_dir, exist_ok=True)
 
+        warmup, active = WARMUP, 1
+        wait = profile_freq - (active + warmup)
+        assert (
+            wait >= 0
+        ), "profile_freq must be greater than or equal to warmup + active"
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CPU,
                 torch.profiler.ProfilerActivity.CUDA,
             ],
-            schedule=torch.profiler.schedule(
-                wait=iter_frequency - 2,
-                warmup=1,
-                active=1,
-                repeat=0,
-            ),
+            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
             on_trace_ready=trace_handler,
             profile_memory=True,
             with_stack=False,
             record_shapes=True,
         ) as torch_profiler:
             yield torch_profiler
     else:
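For readers unfamiliar with torch.profiler.schedule, the sketch below walks through the wait/warmup/active arithmetic introduced in this PR. It is a standalone illustration, not part of the change: the profile_freq = 10 value, the CPU-only activity list, the print-based on_trace_ready callback, and the toy loop are assumptions made only for the example.

import torch

# Illustrative value: in torchtitan, profile_freq comes from the job config.
profile_freq = 10
WARMUP, ACTIVE = 3, 1                    # mirrors WARMUP = 3 and active = 1 in this PR
wait = profile_freq - (ACTIVE + WARMUP)  # 6 idle steps per profiling cycle
assert wait >= 0, "profile_freq must be greater than or equal to warmup + active"

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(wait=wait, warmup=WARMUP, active=ACTIVE),
    on_trace_ready=lambda prof: print("trace ready"),  # stands in for trace_handler
) as prof:
    for step in range(2 * profile_freq):  # two full cycles of profile_freq steps
        # ... one training step would run here ...
        prof.step()  # steps 0-5 wait, 6-8 warm up, step 9 is recorded, then the cycle repeats

With these numbers the profiler records exactly one iteration out of every profile_freq training steps, after WARMUP warmup steps in each cycle, so steady-state overhead is limited to a single profiled step per cycle.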