Add staged train pipeline to torchrec (#1624)
Summary:
Pull Request resolved: #1624

Variable-stage train pipeline that supports an arbitrary number of stages. Note that this is a pre-forward scheduling pipeline: the model forward is expected to execute after the last stage in the pipeline, and the last stage has to be the start-SDD (sparse data dist) stage. This is because the ShardedModule forward (overridden by the pipeline) has to consume its input from the pipeline context after input dist.

The design is illustrated in the figure below. A pipeline is composed of K stages, each of which depends on its predecessor. Different stages may execute on different streams or share one, with multiple batches executing concurrently within the same iteration. For example, in the figure, batch[0] is the oldest batch: it has already passed the H2D, preproc, and SDD stages, and will run through comp in the current iteration. Batch[1], the second-oldest batch, will execute SDD in the current iteration, and similarly for the other batches.

[Figure: staged train pipeline design, showing 4 batches concurrently occupying the comp, SDD, preproc, and H2D stages within one iteration]

With this, 4 batches are handled together in the same iteration, each of them in a different stage. When an iteration is done, an advance step copies newer data (slots with larger indices) into older slots (smaller indices) so that each batch is handled by the next stage in the next iteration, as sketched below.
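
A toy sketch of the advance step for a 3-stage pipeline (H2D -> preproc -> start SDD). This is illustration only, with hypothetical slot labels naming what each slot holds; the real logic is `StagedTrainPipeline._advance` in the diff below.

```python
# Slot 0 holds the oldest in-flight batch (already through the last stage);
# the highest slot holds the newest.
slots = ["batch0 after SDD", "batch1 after preproc", "batch2 after H2D"]

out = slots[0]          # oldest batch: handed to the model forward this iteration
slots[:-1] = slots[1:]  # shift newer results into the older slots
slots[-1] = None        # freed newest slot receives the next input batch
print(slots)            # ['batch1 after preproc', 'batch2 after H2D', None]
```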

For SDD handling, we currently wrap existing torchrec utilities. This part could potentially be improved in the future.

Some additional things on top of this (all exercised in the sketch below):
* Add a fill callback for start SDD
* Add await SDD as a callback
* Make `progress` walrus'able
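
A hedged usage sketch wiring `SparseDataDistUtil` in as the mandatory final start-SDD stage. Import paths follow the files added in this commit; `model`, `optimizer`, `gpu_preproc`, and `data_iter` are assumed to exist in user code.

```python
import torch

from torchrec.distributed.train_pipeline.sdd_util import SparseDataDistUtil
from torchrec.distributed.train_pipeline.staged_train_pipeline import (
    get_h2d_func,
    PipelineStage,
    StagedTrainPipeline,
)

# Assumed user-provided: a sharded model, optimizer, gpu_preproc fn, data_iter.
sdd = SparseDataDistUtil(model=model, stream=torch.cuda.Stream())

train_pipeline = StagedTrainPipeline(
    data_iter=data_iter,
    pipeline=[
        PipelineStage(
            name="data_copy",
            runnable=lambda batch: get_h2d_func(batch, torch.device("cuda")),
            stream=torch.cuda.Stream(),
        ),
        PipelineStage(
            name="gpu_preproc",
            runnable=gpu_preproc,
            stream=torch.cuda.Stream(),
        ),
        # Start SDD must be the last stage: the pipelined ShardedModule forward
        # consumes its input from the pipeline context post input dist.
        PipelineStage(
            name="start_sparse_data_dist",
            runnable=sdd.start_sparse_data_dist,
            stream=sdd.stream,
            fill_callback=sdd.wait_sparse_data_dist,  # await SDD during fill
        ),
    ],
)

# progress() is walrus'able: it returns None once data_iter is depleted.
while (batch := train_pipeline.progress()) is not None:
    loss = model(batch)  # model forward runs after the last pipeline stage
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```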

Differential Revision: D51182804

fbshipit-source-id: 90bcd3e8cb77dcf3d2e50d42ebc49d678451a1c3
joshuadeng authored and facebook-github-bot committed Feb 13, 2024
1 parent a860519 commit 51ac7d0
Showing 6 changed files with 591 additions and 2 deletions.
24 changes: 24 additions & 0 deletions torchrec/distributed/train_pipeline/__init__.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from torchrec.distributed.train_pipeline.train_pipeline import (  # noqa
    _override_input_dist_forwards,  # noqa
    _rewrite_model,  # noqa
    _start_data_dist,  # noqa
    _to_device,  # noqa
    _wait_for_batch,  # noqa
    ArgInfo,  # noqa
    In,  # noqa
    Out,  # noqa
    PrefetchTrainPipelineSparseDist,  # noqa
    Tracer,  # noqa
    TrainPipeline,  # noqa
    TrainPipelineBase,  # noqa
    TrainPipelineContext,  # noqa
    TrainPipelineSparseDist,  # noqa
)
75 changes: 75 additions & 0 deletions torchrec/distributed/train_pipeline/sdd_util.py
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Generic, List, Tuple, Union

import torch

from torchrec.distributed.model_parallel import ShardedModule

from torchrec.distributed.train_pipeline.train_pipeline import (
    _override_input_dist_forwards,
    _rewrite_model,
    _start_data_dist,
    In,
    TrainPipelineContext,
)

# TODO - make SparseDataDist stateless to support this being pushed deeper into the pipeline.
class SparseDataDistUtil(Generic[In]):
    def __init__(
        self,
        model: torch.nn.Module,
        stream: torch.cuda.streams.Stream,
        apply_jit: bool = False,
    ) -> None:
        super().__init__()
        self.model = model
        self.stream = stream
        self.apply_jit = apply_jit
        self.context = TrainPipelineContext()
        self.initiated = False
        self._pipelined_modules: List[ShardedModule] = []

        # pyre-ignore
        self.original_forward = self.model.forward

        def forward_hook(
            module: torch.nn.Module,
            input: Union[torch.Tensor, Tuple[torch.Tensor]],
            output: Union[torch.Tensor, Tuple[torch.Tensor]],
        ) -> None:
            self.wait_sparse_data_dist()

        # After every model forward, wait on the input dist kicked off for the
        # next batch so its results are available in the pipeline context.
        self.model.register_forward_hook(forward_hook)

    def start_sparse_data_dist(self, batch: In) -> In:
        if not self.initiated:
            self._pipelined_modules, self.model = _rewrite_model(
                model=self.model,
                context=self.context,
                dist_stream=self.stream,
                batch=batch,
                apply_jit=self.apply_jit,
            )
            # initializes input dist, so we can override input dist forwards
            _start_data_dist(self._pipelined_modules, batch, self.context)
            _override_input_dist_forwards(self._pipelined_modules)
            self.initiated = True

        _start_data_dist(self._pipelined_modules, batch, self.context)

        return batch

    def wait_sparse_data_dist(self) -> None:
        self.context.module_contexts = self.context.module_contexts_next_batch.copy()
        self.context.input_dist_tensors_requests.clear()
        for names, awaitable in self.context.fused_splits_awaitables:
            for name, request in zip(names, awaitable.wait()):
                self.context.input_dist_tensors_requests[name] = request
273 changes: 273 additions & 0 deletions torchrec/distributed/train_pipeline/staged_train_pipeline.py
@@ -0,0 +1,273 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass

from typing import Callable, cast, Generic, Iterator, List, Optional, Tuple

import torch

from torch.profiler import record_function
from torchrec.distributed.train_pipeline.train_pipeline import In, Out
from torchrec.distributed.utils import none_throws
from torchrec.streamable import Pipelineable

logger: logging.Logger = logging.getLogger(__name__)

RunnableType = Callable[..., Out]
StageOutputWithEvent = Tuple[Optional[Out], Optional[torch.cuda.Event]]


def get_h2d_func(batch: Pipelineable, device: torch.device) -> Pipelineable:
    return batch.to(device, non_blocking=True)


@dataclass
class PipelineStage:
    """
    A pipeline stage represents a transform to an input that is independent of the
    backwards() of the model. Examples include batch H2D transfer, GPU preproc, or
    gradient-less model processing.

    Args:
        name (str): Name of the stage.
        runnable (Callable[In, Out]): Function that performs a gradient-less
            transform.
        stream (torch.cuda.streams.Stream): Stream to run on. Often each stage has a
            unique stream, but having different pipelines share a stream provides more
            synchronization semantics.
        fill_callback (Optional[Callable[[], None]]): Optional callback that runs
            after this stage executes during pipeline fill (see `_fill_pipeline`).
    """

    name: str
    runnable: RunnableType
    stream: torch.cuda.streams.Stream
    fill_callback: Optional[Callable[[], None]] = None


class StagedTrainPipeline(Generic[In, Out]):
    """
    StagedTrainPipeline orchestrates the pipelined execution of its constituent
    stages on inputs from `data_iter`, scheduling the execution of all stages
    before the model forward.

    NOTE: the SDD stage needs to be the final stage of the pipeline so that the
    ShardedModule forward can properly consume the SDD output.

    Calling progress on a StagedTrainPipeline provides an output that is equivalent
    to calling each of the pipeline stages in order.

    In the example below, a fully synchronous implementation would execute the
    `data_copy` and `gpu_preproc` calls back to back. After pipelining, the
    `data_copy` of batch i+2 can be overlapped with the `gpu_preproc` of batch i+1
    and the main model processing of batch i.

    Args:
        data_iter (Optional[Iterator[In]]): An iterator that produces the inputs to
            the pipeline.
        pipeline (List[PipelineStage]): A list of stages to execute.
        debug_mode (bool): Whether to enable debug mode.

    Example::
        train_pipeline = StagedTrainPipeline(
            data_iter=data_iter,
            pipeline=[
                PipelineStage(
                    name="data_copy",
                    runnable=lambda batch: get_h2d_func(batch, torch.device("cuda")),
                    stream=torch.cuda.Stream(),
                ),
                PipelineStage(
                    name="gpu_preproc",
                    runnable=gpu_preproc,
                    stream=torch.cuda.Stream(),
                ),
            ]
        )

        while batch_for_main_forward := train_pipeline.progress():
            loss = model(batch_for_main_forward)
            loss.backward()
            optimizer.step()
    """

    def __init__(
        self,
        data_iter: Optional[Iterator[In]],
        pipeline: List[PipelineStage],
        debug_mode: bool = False,
    ) -> None:
        self.pipeline = pipeline
        self.batch_results: List[Optional[StageOutputWithEvent]] = cast(
            List[Optional[StageOutputWithEvent]], [None] * len(self.pipeline)
        )
        self.initialized = False
        self.data_iter = data_iter

        self._num_steps = 0
        self._debug_mode = debug_mode

    @property
    def num_stages(self) -> int:
        return len(self.pipeline)

    def _next_batch(self) -> Optional[In]:
        batch = next(none_throws(self.data_iter, "`data_iter` cannot be none"), None)
        return batch

    def _get_indices(self, idx: int) -> Tuple[int, int]:
        """
        Maps a stage index onto the batch-result slot it should operate on:
        stage 0 (newest data) writes to the last slot, while the final stage
        reads the oldest batch from slot 0.

        Returns:
            Tuple[int, int]: (stage_idx, batch_result_idx)
        """
        return (idx, self.num_stages - idx - 1)

    def _advance(self) -> Optional[Out]:
        # left shifts all batch results.
        out = self.batch_results[0]
        for idx in range(self.num_stages - 1):
            self.batch_results[idx] = self.batch_results[idx + 1]
        self.batch_results[-1] = None
        if out is None:
            return out

        return out[0]

    def _run_with_event(
        self,
        runnable: RunnableType,
        event: Optional[torch.cuda.Event],
        inputs: Optional[Pipelineable],
        stream: torch.cuda.streams.Stream,
    ) -> StageOutputWithEvent:
        if inputs is None:
            return (None, None)
        with torch.cuda.stream(stream):
            # If there is no previous event, data is entering the pipeline
            if event is not None:
                event.wait(stream)
                inputs.record_stream(stream)

            output = runnable(inputs)
            new_event = torch.cuda.Event()
            new_event.record(stream)

            return (output, new_event)

    def _run_stage(
        self,
        batch_offset: int,
        stage_idx: int,
        fill: bool = False,
    ) -> StageOutputWithEvent:
        """
        Each stage of the pipeline MUST have an input and output.
        If the input is None, it means there is no more data to process.
        It will short circuit and NOT execute the runnable.
        """
        stage = self.pipeline[stage_idx]

        with record_function(
            f"## Pipeline Stage {stage_idx} : {stage.name} for batch {batch_offset + self._num_steps} ##"
        ):
            batch_to_wait_with_event = None
            new_result = (None, None)
            if stage_idx == 0:
                batch = self._next_batch()
                batch_to_wait, event = batch, None
            else:
                batch_to_wait_with_event = self.batch_results[batch_offset]
                assert batch_to_wait_with_event is not None
                batch_to_wait, event = batch_to_wait_with_event

            new_result = self._run_with_event(
                runnable=stage.runnable,
                event=event,
                inputs=batch_to_wait,
                stream=stage.stream,
            )

        self.batch_results[batch_offset] = new_result
        if self._debug_mode:
            logger.info(
                f"Running ## Pipeline Stage {stage_idx} : {stage.name} "
                f"for batch {batch_offset + self._num_steps} ##"
            )

        if fill and (fill_callback := stage.fill_callback) is not None:
            fill_callback()
            if self._debug_mode:
                logger.info(f"Finished callback for {stage.name}")

        return new_result

    def _fill_pipeline(self) -> None:
        """
        There should always be `self.num_stages` batches in flight. This function
        initializes the pipeline by filling it with `self.num_stages` batches, so
        that the first `progress()` call can already return a fully processed
        batch. For a 5 stage pipeline during `_fill_pipeline`:
            batch 0: stages 0, 1, 2, 3, 4 will be run
            batch 1: stages 0, 1, 2, 3 will be run
            batch 2: stages 0, 1, 2 will be run
            batch 3: stages 0, 1 will be run
            batch 4: stage 0 will be run
        """

        for batch_offset in range(self.num_stages):
            stages_to_run = self.num_stages - batch_offset
            for stage_idx in range(stages_to_run):
                self._run_stage(
                    batch_offset=batch_offset, stage_idx=stage_idx, fill=True
                )

        self.initialized = True
        if self._debug_mode:
            logger.info("Finished fill pipeline")

    def progress(
        self,
        data_iter: Optional[Iterator[In]] = None,
        run_stage_order: Optional[List[int]] = None,
    ) -> Optional[Out]:
        """
        The stages process data in reverse order, so stage_0 processes the newest
        data. The stage order can be modified through the `run_stage_order` arg,
        which is useful for achieving better overlap between stages.

        Args:
            data_iter (Optional[Iterator[In]]): An iterator that produces the inputs
                to the pipeline.
            run_stage_order (Optional[List[int]]): Specifies the order of running the
                stages. If `None`, the pipeline will run stages in the original order,
                i.e. stage_0 -> stage_1 -> ... -> stage_n.

        Returns:
            Optional[Out]: Output of the final stage. `None` signifies that the
                dataloader iterator is depleted.
        """
        if self.data_iter is None:
            self.data_iter = none_throws(data_iter, "`data_iter` cannot be none")

        if not self.initialized:
            self._fill_pipeline()

        self._num_steps += 1

        output = self._advance()
        if not run_stage_order:
            run_stage_order = list(range(self.num_stages))
        for op_idx in run_stage_order:
            stage_idx, batch_result_idx = self._get_indices(op_idx)
            self._run_stage(
                batch_offset=batch_result_idx,
                stage_idx=stage_idx,
            )

        return output
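
As a usage note on `run_stage_order` (a hedged sketch reusing `train_pipeline` from the summary above): since each stage within one `progress()` call operates on a different in-flight batch, and cross-iteration dependencies are guarded by the recorded CUDA events, the stages can be launched in any permutation of their indices.

```python
# Launch the comms-heavy final stage (start SDD) first so its collectives
# overlap with the stages launched after it in this iteration.
out = train_pipeline.progress(run_stage_order=[2, 1, 0])
```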