Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix skip_first for resumption #2986

Merged
merged 6 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions composer/profiler/profiler_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ def cyclic_schedule(
This function returns a schedule function that uses a cyclic profiling window. The resulting function can be
passed as the ``prof_schedule`` argument to the :class:`.Trainer`.

The cyclic window skips the first ``skip_first`` batches in every epoch. Then, it performs a cycle of
skipping ``wait`` batches, warming up for ``warmup`` batches, and recording ``active`` batches.
It repeats this cycle up to ``repeat`` times per epoch (or for the entire epoch, if ``repeat`` is 0).
This logic repeats every epoch.
The cyclic window skips the first ``skip_first`` + ``resumption_batch_idx`` batches in every epoch.
``resumption_batch_idx`` is passed to the schedule function. It is the ``state.timestamp.batch_in_epoch``
when resuming training. Then, it performs a cycle of skipping ``wait`` batches, warming up for ``warmup``
batches, and recording ``active`` batches. It repeats this cycle up to ``repeat`` times per epoch (or
for the entire epoch, if ``repeat`` is 0). This logic repeats every epoch.

Args:
skip_first (int, optional): Number of batches to skip profiling at epoch start. Defaults to ``0``.
Expand All @@ -39,19 +40,20 @@ def cyclic_schedule(
Defaults to ``1``.

Returns:
(State -> ProfilerAction): A ``prof_schedule`` for the :class:`.Trainer`.
((State, int) -> ProfilerAction): A ``prof_schedule`` for the :class:`.Trainer`.
"""

def schedule(state: State):
def schedule(state: State, resumption_batch_idx: int = 0):
# do wait, then warmup, then active, up to repeat times per cycle
cycle_len = wait + warmup + active
batch_idx = int(state.timestamp.batch_in_epoch)
if batch_idx < skip_first:
skip_first_after_resumption = skip_first + resumption_batch_idx
if batch_idx < skip_first_after_resumption:
return ProfilerAction.SKIP
if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first:
if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first_after_resumption:
# exhausted the repeat
return ProfilerAction.SKIP
position_in_cycle = (batch_idx - skip_first) % cycle_len
position_in_cycle = (batch_idx - skip_first_after_resumption) % cycle_len
if position_in_cycle < wait:
return ProfilerAction.SKIP
if position_in_cycle < wait + warmup:
Expand Down
10 changes: 9 additions & 1 deletion composer/profiler/torch_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ def __init__(
self.num_traces_to_keep = num_traces_to_keep
self.saved_traces = OrderedDict()
self.profiler: Optional[torch.profiler.profile] = None
# This is used in schedule function to update the skip_first, so
# that we count the skip_first starting from the 1st batch after resumption
bigning marked this conversation as resolved.
Show resolved Hide resolved
self.resumption_batch_idx: int = 0

def init(self, state: State, logger: Logger) -> None:
if state.profiler is None:
Expand All @@ -242,7 +245,7 @@ def scheduler_fn(torch_profiler_step: int) -> TorchProfilerAction:
del torch_profiler_step # the torch profiler step is unused. Using the composer timestamp instead.

assert state.profiler is not None
composer_profiler_action = state.profiler.schedule(state)
composer_profiler_action = state.profiler.schedule(state, self.resumption_batch_idx)
if composer_profiler_action == ProfilerAction.ACTIVE_AND_SAVE:
return TorchProfilerAction.RECORD_AND_SAVE
if composer_profiler_action == ProfilerAction.ACTIVE:
Expand Down Expand Up @@ -328,6 +331,11 @@ def handler_fn(prof: torch.profiler.profiler.profile):
)
self.profiler.__enter__()

def after_load(self, state: State, logger: Logger) -> None:
    """Record the batch index at which training resumed from a checkpoint.

    The stored index is handed to the profiler schedule on every scheduler
    call, so that ``skip_first`` is counted from the first batch after
    resumption rather than from the start of the epoch.

    Args:
        state (State): The training state; its ``timestamp.batch_in_epoch`` is read.
        logger (Logger): Unused.
    """
    del logger  # unused
    assert self.profiler is not None
    # Cast to a plain ``int``: the attribute is declared as ``int`` and the
    # schedule does integer arithmetic with it (it casts ``batch_in_epoch``
    # the same way before comparing).
    self.resumption_batch_idx = int(state.timestamp.batch_in_epoch)

def batch_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
assert self.profiler is not None
Expand Down
40 changes: 39 additions & 1 deletion tests/profiler/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import pytest
import torch
from packaging import version
from torch.profiler.profiler import ProfilerAction as TorchProfilerAction

from composer.core import State
from composer.core import Engine, Event, State, Timestamp
from composer.loggers import Logger
from composer.profiler import Profiler, ProfilerAction, SystemProfiler, TorchProfiler, cyclic_schedule
from composer.profiler.utils import export_memory_timeline_html

Expand Down Expand Up @@ -170,3 +172,39 @@ def test_memory_timeline(tmp_path: pathlib.Path) -> None:
assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True'
_, end = fig.gca().get_ylim()
assert round(end, 2) == 0.06


def test_skip_first_after_resumption(minimal_state: State) -> None:
    """Check that ``skip_first`` is counted from the resumption batch after a checkpoint load."""
    # One cycle: 1 skipped batch, then 2 wait + 3 warmup + 4 active batches.
    prof_schedule = cyclic_schedule(skip_first=1, wait=2, warmup=3, active=4, repeat=1)
    composer_profiler = Profiler(
        trace_handlers=[MagicMock()],
        schedule=prof_schedule,
    )
    composer_profiler.bind_to_state(minimal_state)
    minimal_state.profiler = composer_profiler

    registered = composer_profiler._callbacks
    assert len(registered) >= 1
    assert isinstance(registered[-1], TorchProfiler)
    torch_profiler = registered[-1]

    # Running INIT makes the TorchProfiler callback build its torch.profiler.profile.
    event_engine = Engine(state=minimal_state, logger=Logger(minimal_state))
    event_engine.run_event(Event.INIT)
    assert torch_profiler.profiler is not None

    # Without resumption, batch 7 sits at cycle position (7 - 1) % 9 == 6,
    # which is past wait + warmup == 5, so the profiler is recording.
    minimal_state.timestamp = Timestamp(batch_in_epoch=7)
    assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.RECORD

    # Simulate loading a checkpoint taken at batch 4.
    minimal_state.timestamp = Timestamp(batch_in_epoch=4)
    for load_event in (Event.BEFORE_LOAD, Event.AFTER_LOAD):
        event_engine.run_event(load_event)

    # After resumption the effective skip is 1 + 4 == 5, so batch 7 sits at
    # cycle position 2, which falls in the warmup window [2, 5).
    minimal_state.timestamp = Timestamp(batch_in_epoch=7)
    assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.WARMUP
Loading