-
Notifications
You must be signed in to change notification settings - Fork 479
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add staged train pipeline to torchrec (#1624)
Summary: Pull Request resolved: #1624 Variable stage train pipeline that can support arbitrary number of stages. Note that this is a pre-forward scheduling pipeline, this means the forward is expected executed after the last stage in the pipeline and the last stage in the pipeline has to be the start SDD stage. This is because ShardedModule forward (overwritten by pipeline) has to consume the input from pipeline context post input dist. The design is illustrated in the figure below. A pipeline is composed of K stages, each one depends on its precedence. Different stage may be executed in different or the same streams, with multiple batches concurrently executed in the same iteration. For example, in the image, batch[0] is the oldest batch that has passed H2D, preproc, and SDD stages, and will be running through comp in the current iteration. Batch[1], on the other hand, is the second oldest batch that will execute SDD in the current iteration. Similarly for other batches. {F1150156522} with this, 4 batches will be handled together in the same iteration, while each of them is under a different stage. When an iteration is done, there will be a advance step to copy newer data (slots larger index) to older data (slots with smaller index) so that they will be handled by the next stage in the next iteration. For SDD handling, we currently wrap on top of existing torchrec utilities. This part could potentially be improved in the future. Some additional things on top of this: * Adding fill callback for start sdd * Adding await sdd as a callback * Modifying progress to be walrus'able Differential Revision: D51182804 fbshipit-source-id: 90bcd3e8cb77dcf3d2e50d42ebc49d678451a1c3
- Loading branch information
1 parent
a860519
commit 51ac7d0
Showing
6 changed files
with
591 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from torchrec.distributed.train_pipeline.train_pipeline import ( # noqa | ||
_override_input_dist_forwards, # noqa | ||
_rewrite_model, # noqa | ||
_start_data_dist, # noqa | ||
_to_device, # noqa | ||
_wait_for_batch, # noqa | ||
ArgInfo, # noqa | ||
In, # noqa | ||
Out, # noqa | ||
PrefetchTrainPipelineSparseDist, # noqa | ||
Tracer, # noqa | ||
TrainPipeline, # noqa | ||
TrainPipelineBase, # noqa | ||
TrainPipelineContext, # noqa | ||
TrainPipelineSparseDist, # noqa | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
#!/usr/bin/env python3 | ||
|
||
from typing import Generic, List, Tuple, Union | ||
|
||
import torch | ||
|
||
from torchrec.distributed.model_parallel import ShardedModule | ||
|
||
from torchrec.distributed.train_pipeline.train_pipeline import ( | ||
_override_input_dist_forwards, | ||
_rewrite_model, | ||
_start_data_dist, | ||
In, | ||
TrainPipelineContext, | ||
) | ||
|
||
# TODO - make SparseDataDist stateless to support this being pushed deeper into the pipeline. | ||
class SparseDataDistUtil(Generic[In]): | ||
def __init__( | ||
self, | ||
model: torch.nn.Module, | ||
stream: torch.cuda.streams.Stream, | ||
apply_jit: bool = False, | ||
) -> None: | ||
super().__init__() | ||
self.model = model | ||
self.stream = stream | ||
self.apply_jit = apply_jit | ||
self.context = TrainPipelineContext() | ||
self.initiated = False | ||
self._pipelined_modules: List[ShardedModule] = [] | ||
|
||
# pyre-ignore | ||
self.original_forward = self.model.forward | ||
|
||
def forward_hook( | ||
module: torch.nn.Module, | ||
input: Union[torch.Tensor, Tuple[torch.Tensor]], | ||
output: Union[torch.Tensor, Tuple[torch.Tensor]], | ||
) -> None: | ||
self.wait_sparse_data_dist() | ||
|
||
self.model.register_forward_hook(forward_hook) | ||
|
||
def start_sparse_data_dist(self, batch: In) -> In: | ||
if not self.initiated: | ||
self._pipelined_modules, self.model = _rewrite_model( | ||
model=self.model, | ||
context=self.context, | ||
dist_stream=self.stream, | ||
batch=batch, | ||
apply_jit=self.apply_jit, | ||
) | ||
# initializes input dist, so we can override input dist forwards | ||
_start_data_dist(self._pipelined_modules, batch, self.context) | ||
_override_input_dist_forwards(self._pipelined_modules) | ||
self.initiated = True | ||
|
||
_start_data_dist(self._pipelined_modules, batch, self.context) | ||
|
||
return batch | ||
|
||
def wait_sparse_data_dist(self) -> None: | ||
self.context.module_contexts = self.context.module_contexts_next_batch.copy() | ||
self.context.input_dist_tensors_requests.clear() | ||
for names, awaitable in self.context.fused_splits_awaitables: | ||
for name, request in zip(names, awaitable.wait()): | ||
self.context.input_dist_tensors_requests[name] = request |
273 changes: 273 additions & 0 deletions
273
torchrec/distributed/train_pipeline/staged_train_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
#!/usr/bin/env python3 | ||
|
||
import logging | ||
from dataclasses import dataclass | ||
|
||
from typing import Callable, cast, Generic, Iterator, List, Optional, Tuple | ||
|
||
import torch | ||
|
||
from torch.profiler import record_function | ||
from torchrec.distributed.train_pipeline.train_pipeline import In, Out | ||
from torchrec.distributed.utils import none_throws | ||
from torchrec.streamable import Pipelineable | ||
|
||
logger: logging.Logger = logging.getLogger(__name__) | ||
|
||
RunnableType = Callable[..., Out] | ||
StageOutputWithEvent = Tuple[Optional[Out], Optional[torch.cuda.Event]] | ||
|
||
|
||
def get_h2d_func(batch: Pipelineable, device: torch.device) -> Pipelineable: | ||
return batch.to(device, non_blocking=True) | ||
|
||
|
||
@dataclass | ||
class PipelineStage: | ||
""" | ||
A pipeline stage represents a transform to an input that is independent of the | ||
backwards() of the model. Examples include batch H2D transfer, GPU preproc, or | ||
gradient-less model processing. | ||
Args: | ||
name (str): Name of the stage. | ||
runnable (Callable[In, Out]): Function that performs a gradient-less | ||
transform. | ||
stream (torch.cuda.streams.Stream): Stream to run on. Often each stage has a | ||
unique stream, but having different pipelines share a stream provides more | ||
synchronization semantics. | ||
""" | ||
|
||
name: str | ||
runnable: RunnableType | ||
stream: torch.cuda.streams.Stream | ||
fill_callback: Optional[Callable[[], None]] = None | ||
|
||
|
||
class StagedTrainPipeline(Generic[In, Out]): | ||
""" | ||
StagedTrainPipeline orchestrates the pipelined execution of its constitutent stages | ||
from inputs of `data_iter`. Namely scheduling the execution of stages before model | ||
forward. | ||
NOTE: the SDD stage needs to be the final stage of the pipeline so that the | ||
ShardedModule forward can properly consume the SDD output. | ||
Calling progress on a StagedTrainPipeline provides an output that is equivalent to | ||
calling each of the pipeline stages in order. | ||
In the example below a fully synchronous will expose the `data_copy` and | ||
`gpu_preproc` calls. After pipelining, the `data_copy` of batch i+2 can be | ||
overlapped with the `gpu_preproc` of batch i+1 and the main model processing of | ||
batch i. | ||
Args: | ||
data_iter (Optional[Iterator[In]]): An iterator that produces the inputs to the | ||
pipeline. | ||
pipeline (List[PipelineStage]): A list of stages to execute. | ||
debug_mode (bool): Whether to enable debug mode. | ||
Example:: | ||
train_pipeline = StagedTrainPipeline( | ||
data_iter=data_iter, | ||
pipeline=[ | ||
PipelineStage( | ||
name="data_copy", | ||
runnable=get_h2d_func("cuda"), | ||
stream=torch.cuda.Stream(), | ||
), | ||
PipelineStage( | ||
name="gpu_preproc", | ||
runnable=gpu_preproc, | ||
stream=torch.cuda.Stream(), | ||
), | ||
] | ||
) | ||
while batch_for_main_forward := train_pipeline.progress(): | ||
loss = model(batch_for_main_forward) | ||
loss.backward() | ||
optimizer.step() | ||
""" | ||
|
||
def __init__( | ||
self, | ||
data_iter: Optional[Iterator[In]], | ||
pipeline: List[PipelineStage], | ||
debug_mode: bool = False, | ||
) -> None: | ||
self.pipeline = pipeline | ||
self.batch_results: List[Optional[StageOutputWithEvent]] = cast( | ||
List[Optional[StageOutputWithEvent]], [None] * len(self.pipeline) | ||
) | ||
self.initialized = False | ||
self.data_iter = data_iter | ||
|
||
self._num_steps = 0 | ||
self._debug_mode = debug_mode | ||
|
||
@property | ||
def num_stages(self) -> int: | ||
return len(self.pipeline) | ||
|
||
def _next_batch(self) -> Optional[In]: | ||
batch = next(none_throws(self.data_iter, "`data_iter` cannot be none"), None) | ||
return batch | ||
|
||
def _get_indices(self, idx: int) -> Tuple[int, int]: | ||
""" | ||
Returns: | ||
Tuple[int, int]: (stage_idx, batch_result_idx) | ||
""" | ||
return (idx, self.num_stages - idx - 1) | ||
|
||
def _advance(self) -> Optional[Out]: | ||
# left shifts all batch results. | ||
out = self.batch_results[0] | ||
for idx in range(self.num_stages - 1): | ||
self.batch_results[idx] = self.batch_results[idx + 1] | ||
self.batch_results[-1] = None | ||
if out is None: | ||
return out | ||
|
||
return out[0] | ||
|
||
def _run_with_event( | ||
self, | ||
runnable: RunnableType, | ||
event: Optional[torch.cuda.Event], | ||
inputs: Optional[Pipelineable], | ||
stream: torch.cuda.streams.Stream, | ||
) -> StageOutputWithEvent: | ||
if inputs is None: | ||
return (None, None) | ||
with torch.cuda.stream(stream): | ||
# If there is no previous event, data is entering the pipeline | ||
if event is not None: | ||
event.wait(stream) | ||
inputs.record_stream(stream) | ||
|
||
output = runnable(inputs) | ||
new_event = torch.cuda.Event() | ||
new_event.record(stream) | ||
|
||
return (output, new_event) | ||
|
||
def _run_stage( | ||
self, | ||
batch_offset: int, | ||
stage_idx: int, | ||
fill: bool = False, | ||
) -> StageOutputWithEvent: | ||
""" | ||
Each stage of the pipeline MUST have an input and output. | ||
If the input is None, it means there is no more data to process. | ||
It will short circuit and NOT execute the runnable. | ||
""" | ||
stage = self.pipeline[stage_idx] | ||
|
||
with record_function( | ||
f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##" | ||
): | ||
batch_to_wait_with_event = None | ||
new_result = (None, None) | ||
if stage_idx == 0: | ||
batch = self._next_batch() | ||
batch_to_wait, event = batch, None | ||
else: | ||
batch_to_wait_with_event = self.batch_results[batch_offset] | ||
assert batch_to_wait_with_event is not None | ||
batch_to_wait, event = batch_to_wait_with_event | ||
|
||
new_result = self._run_with_event( | ||
runnable=stage.runnable, | ||
event=event, | ||
inputs=batch_to_wait, | ||
stream=stage.stream, | ||
) | ||
|
||
self.batch_results[batch_offset] = new_result | ||
if self._debug_mode: | ||
logger.info( | ||
"Running", | ||
f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##", | ||
) | ||
|
||
if fill and (fill_callback := stage.fill_callback) is not None: | ||
if self._debug_mode: | ||
logger.info(f"Finished callback for {stage.name}") | ||
fill_callback() | ||
|
||
return new_result | ||
|
||
def _fill_pipeline(self) -> None: | ||
""" | ||
There should always be `self.num_stages` batches in flight. This function | ||
initializes the pipeline by filling it with `self.num_stages - 1` batches. | ||
For a 5 stage pipeline during `_fill_pipeline`: | ||
batch 0: stages 0, 1, 2, 3 will be run | ||
batch 1: stages 0, 1, 2 will be run | ||
batch 2: stages 0, 1 will be run | ||
batch 3: stage 0 will be run | ||
batch 4: starts on `progress()` | ||
""" | ||
|
||
for batch_offset in range(self.num_stages): | ||
stages_to_run = self.num_stages - batch_offset | ||
for stage_idx in range(stages_to_run): | ||
self._run_stage( | ||
batch_offset=batch_offset, stage_idx=stage_idx, fill=True | ||
) | ||
|
||
self.initialized = True | ||
if self._debug_mode: | ||
logger.info("Finished fill pipeline") | ||
|
||
def progress( | ||
self, | ||
data_iter: Optional[Iterator[In]] = None, | ||
run_stage_order: Optional[List[int]] = None, | ||
) -> Optional[Out]: | ||
""" | ||
The stages process data in reverse order, so stage_0 processes the newest data. | ||
Stage order can be modified through the `run_stage_order` arg. This is useful in | ||
achieving better overlap for different stages. | ||
Args: | ||
data_iter (Optional[Iterator[In]]): An iterator that produces the inputs to | ||
the pipeline. | ||
run_stage_order (Optional[List[int]]): Specifies the order of running the | ||
stages. If `None`, the pipeline will run stages in the original order, | ||
i.e. stage_0 -> stage_1 -> ... -> stage_n. | ||
Returns: | ||
Optional[Out]: Output of the final stage. `None` signifies that the | ||
dataloader iterator is depleted | ||
""" | ||
if self.data_iter is None: | ||
self.data_iter = none_throws(data_iter, "`data_iter` cannot be none") | ||
|
||
if not self.initialized: | ||
self._fill_pipeline() | ||
|
||
self._num_steps += 1 | ||
|
||
output = self._advance() | ||
if not run_stage_order: | ||
run_stage_order = list(range(self.num_stages)) | ||
for op_idx in run_stage_order: | ||
stage_idx, batch_result_idx = self._get_indices(op_idx) | ||
self._run_stage( | ||
batch_offset=batch_result_idx, | ||
stage_idx=stage_idx, | ||
) | ||
|
||
return output |
Oops, something went wrong.