From 435ff8d8583bfe44188ea849ee65b62b2285fe11 Mon Sep 17 00:00:00 2001
From: Ning Wang
Date: Thu, 8 Feb 2024 21:07:12 -0800
Subject: [PATCH 1/3] fix skip_first for resumption

---
 composer/profiler/profiler_schedule.py |  9 +++---
 composer/profiler/torch_profiler.py    | 10 ++++++-
 tests/profiler/test_profiler.py        | 40 +++++++++++++++++++++++++-
 3 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/composer/profiler/profiler_schedule.py b/composer/profiler/profiler_schedule.py
index 02b72b8a50..cc275982d5 100644
--- a/composer/profiler/profiler_schedule.py
+++ b/composer/profiler/profiler_schedule.py
@@ -42,16 +42,17 @@ def cyclic_schedule(
         (State -> ProfilerAction): A ``prof_schedule`` for the :class:`.Trainer`.
     """

-    def schedule(state: State):
+    def schedule(state: State, resumption_batch_idx: int = 0):
         # do wait, then warump, then active, up to repeat times per cycle
         cycle_len = wait + warmup + active
         batch_idx = int(state.timestamp.batch_in_epoch)
-        if batch_idx < skip_first:
+        skip_first_after_resumption = skip_first + resumption_batch_idx
+        if batch_idx < skip_first_after_resumption:
             return ProfilerAction.SKIP
-        if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first:
+        if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first_after_resumption:
             # exhausted the repeat
             return ProfilerAction.SKIP
-        position_in_cycle = (batch_idx - skip_first) % cycle_len
+        position_in_cycle = (batch_idx - skip_first_after_resumption) % cycle_len
         if position_in_cycle < wait:
             return ProfilerAction.SKIP
         if position_in_cycle < wait + warmup:
diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py
index 0f8f1f4fb0..0b74038af7 100644
--- a/composer/profiler/torch_profiler.py
+++ b/composer/profiler/torch_profiler.py
@@ -225,6 +225,9 @@ def __init__(
         self.num_traces_to_keep = num_traces_to_keep
         self.saved_traces = OrderedDict()
         self.profiler: Optional[torch.profiler.profile] = None
+        # This is used in schedule function to update the skip_first, so
+        # that we count the skip_first starting from the 1st batch after resumption
+        self.resumption_batch_idx: int = 0

     def init(self, state: State, logger: Logger) -> None:
         if state.profiler is None:
@@ -242,7 +245,7 @@ def scheduler_fn(torch_profiler_step: int) -> TorchProfilerAction:
             del torch_profiler_step  # the torch profiler step is unused. Using the composer timestamp instead.

             assert state.profiler is not None
-            composer_profiler_action = state.profiler.schedule(state)
+            composer_profiler_action = state.profiler.schedule(state, self.resumption_batch_idx)
             if composer_profiler_action == ProfilerAction.ACTIVE_AND_SAVE:
                 return TorchProfilerAction.RECORD_AND_SAVE
             if composer_profiler_action == ProfilerAction.ACTIVE:
@@ -328,6 +331,11 @@ def handler_fn(prof: torch.profiler.profiler.profile):
         )
         self.profiler.__enter__()

+    def after_load(self, state: State, logger: Logger) -> None:
+        del logger
+        assert self.profiler is not None
+        self.resumption_batch_idx = state.timestamp.batch_in_epoch
+
     def batch_end(self, state: State, logger: Logger) -> None:
         del state, logger  # unused
         assert self.profiler is not None
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 2ae9383d79..f13be17486 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -9,8 +9,10 @@
 import pytest
 import torch
 from packaging import version
+from torch.profiler.profiler import ProfilerAction as TorchProfilerAction

-from composer.core import State
+from composer.core import Engine, Event, State, Timestamp
+from composer.loggers import Logger
 from composer.profiler import Profiler, ProfilerAction, SystemProfiler, TorchProfiler, cyclic_schedule
 from composer.profiler.utils import export_memory_timeline_html

@@ -170,3 +172,39 @@ def test_memory_timeline(tmp_path: pathlib.Path) -> None:
     assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True'
     _, end = fig.gca().get_ylim()
     assert round(end, 2) == 0.06
+
+
+def test_skip_first_after_resumption(minimal_state: State) -> None:
+    skip_first = 1
+    wait = 2
+    warmup = 3
+    active = 4
+    repeat = 1
+    schedule = cyclic_schedule(skip_first=skip_first, wait=wait, warmup=warmup, active=active, repeat=repeat)
+    mock_trace_handler = MagicMock()
+    profiler = Profiler(
+        trace_handlers=[mock_trace_handler],
+        schedule=schedule,
+    )
+    profiler.bind_to_state(minimal_state)
+    minimal_state.profiler = profiler
+
+    assert len(profiler._callbacks) >= 1
+    assert isinstance(profiler._callbacks[-1], TorchProfiler)
+    torch_profiler = profiler._callbacks[-1]
+
+    # Create torch.profiler.profile
+    logger = Logger(minimal_state)
+    engine = Engine(state=minimal_state, logger=logger)
+    engine.run_event(Event.INIT)
+    assert torch_profiler.profiler is not None
+
+    minimal_state.timestamp = Timestamp(batch_in_epoch=7)
+    assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.RECORD
+
+    # Load checkpoint at batch 4
+    minimal_state.timestamp = Timestamp(batch_in_epoch=4)
+    engine.run_event(Event.BEFORE_LOAD)
+    engine.run_event(Event.AFTER_LOAD)
+    minimal_state.timestamp = Timestamp(batch_in_epoch=7)
+    assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.WARMUP

From 7df9a879ff8bedcebbc71fc6add8606d1bdc41b4 Mon Sep 17 00:00:00 2001
From: Ning Wang
Date: Thu, 8 Feb 2024 21:22:18 -0800
Subject: [PATCH 2/3] update doc

---
 composer/profiler/profiler_schedule.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/composer/profiler/profiler_schedule.py b/composer/profiler/profiler_schedule.py
index cc275982d5..9e9409e088 100644
--- a/composer/profiler/profiler_schedule.py
+++ b/composer/profiler/profiler_schedule.py
@@ -23,10 +23,11 @@ def cyclic_schedule(
     This function returns a schedule function that uses a cyclic profiling window. The resulting function can be
     passed as the ``prof_schedule`` argument to the :class:`.Trainer`.

-    The cyclic window skips the first ``skip_first`` batches in every epoch. Then, it performs a cycle of
-    skipping ``wait`` batches, warming up for ``warmup`` batches, and recording ``active`` batches.
-    It repeats this cycle up to ``repeat`` times per epoch (or for the entire epoch, if ``repeat`` is 0).
-    This logic repeats every epoch.
+    The cyclic window skips the first ``skip_first`` + ``resumption_batch_idx`` batches in every epoch.
+    ``resumption_batch_idx`` is passed to the schedule function. It is the ``state.timestamp.batch_in_epoch``
+    when resuming training. Then, it performs a cycle of skipping ``wait`` batches, warming up for ``warmup``
+    batches, and recording ``active`` batches. It repeats this cycle up to ``repeat`` times per epoch (or
+    for the entire epoch, if ``repeat`` is 0). This logic repeats every epoch.

     Args:
         skip_first (int, optional): Number of batches to skip profiling at epoch start. Defaults to ``0``.
@@ -39,7 +40,7 @@
             Defaults to ``1``.

     Returns:
-        (State -> ProfilerAction): A ``prof_schedule`` for the :class:`.Trainer`.
+        ((State, int) -> ProfilerAction): A ``prof_schedule`` for the :class:`.Trainer`.
     """

     def schedule(state: State, resumption_batch_idx: int = 0):

From cc6b00a28b000700d669cc84d18a48a408049525 Mon Sep 17 00:00:00 2001
From: Ning Wang
Date: Fri, 9 Feb 2024 10:32:28 -0800
Subject: [PATCH 3/3] v2

---
 composer/profiler/profiler.py          |  2 ++
 composer/profiler/profiler_schedule.py | 11 +++++++----
 composer/profiler/torch_profiler.py    |  9 +++------
 3 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/composer/profiler/profiler.py b/composer/profiler/profiler.py
index c88c1f0912..824858b50a 100644
--- a/composer/profiler/profiler.py
+++ b/composer/profiler/profiler.py
@@ -118,6 +118,8 @@ def __init__(
         self.schedule = schedule
         self.state = None
         self._callbacks: List[Callback] = []
+        # Used to count skip_first starting from resumption timestamp
+        self.resumption_batch_idx: int = 0
         self.remote_filenames: List[str] = []
         # First, add each remote file name to self.remote_filenames to create RemoteUploaderDownloader logger in trainer. [s3://bucket/path/to/file]
         # Then modify remote file name to be a local path to pass into torch_profiler and system_profiler. e.g: path/to/file
diff --git a/composer/profiler/profiler_schedule.py b/composer/profiler/profiler_schedule.py
index 9e9409e088..08d2549c2b 100644
--- a/composer/profiler/profiler_schedule.py
+++ b/composer/profiler/profiler_schedule.py
@@ -24,7 +24,7 @@ def cyclic_schedule(
     passed as the ``prof_schedule`` argument to the :class:`.Trainer`.

     The cyclic window skips the first ``skip_first`` + ``resumption_batch_idx`` batches in every epoch.
-    ``resumption_batch_idx`` is passed to the schedule function. It is the ``state.timestamp.batch_in_epoch``
+    ``resumption_batch_idx`` is accessed from ``state.profiler``. It is the ``state.timestamp.batch_in_epoch``
     when resuming training. Then, it performs a cycle of skipping ``wait`` batches, warming up for ``warmup``
     batches, and recording ``active`` batches. It repeats this cycle up to ``repeat`` times per epoch (or
     for the entire epoch, if ``repeat`` is 0). This logic repeats every epoch.
@@ -40,14 +40,17 @@
             Defaults to ``1``.

     Returns:
-        ((State, int) -> ProfilerAction): A ``prof_schedule`` for the :class:`.Trainer`.
+        (State -> ProfilerAction): A ``prof_schedule`` for the :class:`.Trainer`.
""" - def schedule(state: State, resumption_batch_idx: int = 0): + def schedule(state: State): # do wait, then warump, then active, up to repeat times per cycle cycle_len = wait + warmup + active batch_idx = int(state.timestamp.batch_in_epoch) - skip_first_after_resumption = skip_first + resumption_batch_idx + if state.profiler is not None: + skip_first_after_resumption = skip_first + state.profiler.resumption_batch_idx + else: + skip_first_after_resumption = skip_first if batch_idx < skip_first_after_resumption: return ProfilerAction.SKIP if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first_after_resumption: diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py index 0b74038af7..4112ec880c 100644 --- a/composer/profiler/torch_profiler.py +++ b/composer/profiler/torch_profiler.py @@ -225,9 +225,6 @@ def __init__( self.num_traces_to_keep = num_traces_to_keep self.saved_traces = OrderedDict() self.profiler: Optional[torch.profiler.profile] = None - # This is used in schedule function to update the skip_first, so - # that we count the skip_first starting from the 1st batch after resumption - self.resumption_batch_idx: int = 0 def init(self, state: State, logger: Logger) -> None: if state.profiler is None: @@ -245,7 +242,7 @@ def scheduler_fn(torch_profiler_step: int) -> TorchProfilerAction: del torch_profiler_step # the torch profiler step is unused. Using the composer timestamp instead. assert state.profiler is not None - composer_profiler_action = state.profiler.schedule(state, self.resumption_batch_idx) + composer_profiler_action = state.profiler.schedule(state) if composer_profiler_action == ProfilerAction.ACTIVE_AND_SAVE: return TorchProfilerAction.RECORD_AND_SAVE if composer_profiler_action == ProfilerAction.ACTIVE: @@ -333,8 +330,8 @@ def handler_fn(prof: torch.profiler.profiler.profile): def after_load(self, state: State, logger: Logger) -> None: del logger - assert self.profiler is not None - self.resumption_batch_idx = state.timestamp.batch_in_epoch + assert state.profiler is not None + state.profiler.resumption_batch_idx = int(state.timestamp.batch_in_epoch) def batch_end(self, state: State, logger: Logger) -> None: del state, logger # unused