Revert "fix skip_first for resumption (#2986)" (#2991)
This reverts commit b310a9b.
bigning authored Feb 9, 2024
1 parent b310a9b commit c03da27
Showing 4 changed files with 8 additions and 58 deletions.
composer/profiler/profiler.py (2 changes: 0 additions, 2 deletions)
@@ -118,8 +118,6 @@ def __init__(
         self.schedule = schedule
         self.state = None
         self._callbacks: List[Callback] = []
-        # Used to count skip_first starting from resumption timestamp
-        self.resumption_batch_idx: int = 0
         self.remote_filenames: List[str] = []
         # First, add each remote file name to self.remote_filenames to create RemoteUploaderDownloader logger in trainer. [s3://bucket/path/to/file]
         # Then modify remote file name to be a local path to pass into torch_profiler and system_profiler. e.g: path/to/file
composer/profiler/profiler_schedule.py (19 changes: 7 additions, 12 deletions)
@@ -23,11 +23,10 @@ def cyclic_schedule(
     This function returns a schedule function that uses a cyclic profiling window. The resulting function can be
     passed as the ``prof_schedule`` argument to the :class:`.Trainer`.
 
-    The cyclic window skips the first ``skip_first`` + ``resumption_batch_idx`` batches in every epoch.
-    ``resumption_batch_idx`` is accessed from state.profiler. It is the ``state.timestamp.batch_in_epoch``
-    when resuming training. Then, it performs a cycle of skipping ``wait`` batches, warming up for ``warmup``
-    batches, and recording ``active`` batches. It repeats this cycle up to ``repeat`` times per epoch (or
-    for the entire epoch, if ``repeat`` is 0). This logic repeats every epoch.
+    The cyclic window skips the first ``skip_first`` batches in every epoch. Then, it performs a cycle of
+    skipping ``wait`` batches, warming up for ``warmup`` batches, and recording ``active`` batches.
+    It repeats this cycle up to ``repeat`` times per epoch (or for the entire epoch, if ``repeat`` is 0).
+    This logic repeats every epoch.
 
     Args:
         skip_first (int, optional): Number of batches to skip profiling at epoch start. Defaults to ``0``.
@@ -47,16 +46,12 @@ def schedule(state: State):
         # do wait, then warmup, then active, up to repeat times per cycle
         cycle_len = wait + warmup + active
         batch_idx = int(state.timestamp.batch_in_epoch)
-        if state.profiler is not None:
-            skip_first_after_resumption = skip_first + state.profiler.resumption_batch_idx
-        else:
-            skip_first_after_resumption = skip_first
-        if batch_idx < skip_first_after_resumption:
+        if batch_idx < skip_first:
             return ProfilerAction.SKIP
-        if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first_after_resumption:
+        if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first:
             # exhausted the repeat
             return ProfilerAction.SKIP
-        position_in_cycle = (batch_idx - skip_first_after_resumption) % cycle_len
+        position_in_cycle = (batch_idx - skip_first) % cycle_len
         if position_in_cycle < wait:
             return ProfilerAction.SKIP
         if position_in_cycle < wait + warmup:
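As a reference for the restored behavior, the following is a minimal standalone sketch of the same cyclic logic. It is not part of this diff: it mirrors the `schedule` closure above but takes a plain batch index instead of a composer `State`, and the `Action` enum is a simplified stand-in for `composer.profiler.ProfilerAction`. The parameter values in the usage line are the ones used in the test removed below.

# Standalone sketch of the restored cyclic schedule (assumptions noted above).
from enum import Enum


class Action(Enum):  # stand-in for composer.profiler.ProfilerAction
    SKIP = 'skip'
    WARMUP = 'warmup'
    ACTIVE = 'active'


def cyclic_action(batch_idx: int, skip_first: int, wait: int, warmup: int, active: int, repeat: int) -> Action:
    cycle_len = wait + warmup + active
    if batch_idx < skip_first:
        return Action.SKIP
    if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first:
        # the repeat budget for this epoch is exhausted
        return Action.SKIP
    position_in_cycle = (batch_idx - skip_first) % cycle_len
    if position_in_cycle < wait:
        return Action.SKIP
    if position_in_cycle < wait + warmup:
        return Action.WARMUP
    return Action.ACTIVE


# With skip_first=1, wait=2, warmup=3, active=4, repeat=1: batches 0-2 skip,
# 3-5 warm up, 6-9 record, and 10+ skip again for the rest of the epoch.
print([cyclic_action(i, skip_first=1, wait=2, warmup=3, active=4, repeat=1).value for i in range(12)])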
composer/profiler/torch_profiler.py (5 changes: 0 additions, 5 deletions)
@@ -328,11 +328,6 @@ def handler_fn(prof: torch.profiler.profiler.profile):
         )
         self.profiler.__enter__()
 
-    def after_load(self, state: State, logger: Logger) -> None:
-        del logger
-        assert state.profiler is not None
-        state.profiler.resumption_batch_idx = int(state.timestamp.batch_in_epoch)
-
     def batch_end(self, state: State, logger: Logger) -> None:
         del state, logger  # unused
         assert self.profiler is not None
tests/profiler/test_profiler.py (40 changes: 1 addition, 39 deletions)
@@ -9,10 +9,8 @@
 import pytest
 import torch
 from packaging import version
-from torch.profiler.profiler import ProfilerAction as TorchProfilerAction
 
-from composer.core import Engine, Event, State, Timestamp
-from composer.loggers import Logger
+from composer.core import State
 from composer.profiler import Profiler, ProfilerAction, SystemProfiler, TorchProfiler, cyclic_schedule
 from composer.profiler.utils import export_memory_timeline_html
 
@@ -172,39 +170,3 @@ def test_memory_timeline(tmp_path: pathlib.Path) -> None:
     assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True'
     _, end = fig.gca().get_ylim()
     assert round(end, 2) == 0.06
-
-
-def test_skip_first_after_resumption(minimal_state: State) -> None:
-    skip_first = 1
-    wait = 2
-    warmup = 3
-    active = 4
-    repeat = 1
-    schedule = cyclic_schedule(skip_first=skip_first, wait=wait, warmup=warmup, active=active, repeat=repeat)
-    mock_trace_handler = MagicMock()
-    profiler = Profiler(
-        trace_handlers=[mock_trace_handler],
-        schedule=schedule,
-    )
-    profiler.bind_to_state(minimal_state)
-    minimal_state.profiler = profiler
-
-    assert len(profiler._callbacks) >= 1
-    assert isinstance(profiler._callbacks[-1], TorchProfiler)
-    torch_profiler = profiler._callbacks[-1]
-
-    # Create torch.profiler.profile
-    logger = Logger(minimal_state)
-    engine = Engine(state=minimal_state, logger=logger)
-    engine.run_event(Event.INIT)
-    assert torch_profiler.profiler is not None
-
-    minimal_state.timestamp = Timestamp(batch_in_epoch=7)
-    assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.RECORD
-
-    # Load checkpoint at batch 4
-    minimal_state.timestamp = Timestamp(batch_in_epoch=4)
-    engine.run_event(Event.BEFORE_LOAD)
-    engine.run_event(Event.AFTER_LOAD)
-    minimal_state.timestamp = Timestamp(batch_in_epoch=7)
-    assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.WARMUP