
Introducing ORTPipelineModule - DeepSpeed Parallel Pipeline Support. (microsoft#20287)

### Description
Introduces a new class, ORTPipelineModule, that handles wrapping each layer with ORTModule for DeepSpeed pipeline parallelism.


### Motivation and Context
To support pipeline parallelism with ORTModule.

This PR adds initial support for DeepSpeed pipeline parallelism (a minimal usage sketch follows the checklist below):

- [x] Support pipeline parallelism where layers are nn.Module instances in a Sequential container.
- [ ] Support LayerSpec and TiedLayerSpec
- [ ] Enable partitioning to accept List
- [ ] Full-GPU Graph Consolidation
- [ ] Subgraph Merging for Inference
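
The wrapping itself is conceptually simple. As a rough sketch (a hypothetical helper, not the actual implementation), each plain `nn.Module` layer assigned to a pipeline stage is replaced by an `ORTModule` around it before DeepSpeed registers it:

```python
import torch.nn as nn

from onnxruntime.training.ortmodule import ORTModule


def wrap_stage_layers(layers):
    """Hypothetical helper: wrap plain nn.Module layers with ORTModule so their
    forward/backward passes run through ONNX Runtime; other layer specs pass through unchanged."""
    return [ORTModule(layer) if isinstance(layer, nn.Module) else layer for layer in layers]
```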
AdamLouly authored Apr 18, 2024
1 parent f664f91 commit ee74fb6
Showing 7 changed files with 303 additions and 0 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -380,6 +380,9 @@ if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
)
file(GLOB onnxruntime_python_ortmodule_pipe_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/experimental/pipe/*"
)
file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
)
@@ -756,6 +759,7 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/pipe
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
@@ -806,6 +810,9 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_graph_optimizers_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ortmodule_pipe_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/pipe/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_ort_triton_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/
28 changes: 28 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
@@ -495,3 +495,31 @@ for epoch in range(start_epoch, n_epochs):
```
Check [LoadBalancingDistributedBatchSampler implementation](../orttraining/orttraining/python/training/utils/data/sampler.py) for more details.
## 8 Using ORTPipelineModule for DeepSpeed Pipeline Parallel
You can use `ORTPipelineModule` to enable DeepSpeed pipeline parallelism. Here's how to integrate it into your pipeline:

```python
from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
# Create a debug configuration if needed.
# Since multiple graphs are exported (one per wrapped layer), each layer's index is prepended to the ONNX prefix to differentiate the saved files.
debug_options = DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name")
# Keep your deepspeed script the same and use ORTPipelineModule instead of PipelineModule
# Initialize the ORTPipelineModule
pipeline_module = ORTPipelineModule(
layers,
num_stages=2, # Set your number of stages
base_seed=1234,
partition_method="parameters",
debug_options=debug_options # Pass the debug configuration if needed
)
# Keep the rest of the script as it is.
```

Check [ORTPipelineModule implementation](../orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py) for more details.
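
For reference, and assuming the prefixing behavior implemented in `_ort_pipeline_module.py` (shown later in this commit), each wrapped layer receives its own `DebugOptions` whose ONNX prefix is the layer index followed by the prefix you supplied, so with `onnx_prefix="model_name"` the exported file names start with `0_model_name`, `1_model_name`, and so on. A minimal sketch of that derivation:

```python
from onnxruntime.training.ortmodule import DebugOptions, LogLevel

# Sketch only: mirrors how ORTPipelineModule derives per-layer DebugOptions so that
# ONNX files exported by different layers do not overwrite each other.
base = DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name")
per_layer_options = [
    DebugOptions(base.log_level, base.save_onnx, f"{layer_idx}_{base.onnx_prefix}")
    for layer_idx in range(2)  # assuming two layers on this pipeline stage
]
```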
6 changes: 6 additions & 0 deletions orttraining/orttraining/python/training/ortmodule/experimental/pipe/__init__.py
@@ -0,0 +1,6 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from ._ort_pipeline_module import ORTPipelineModule # noqa: F401
157 changes: 157 additions & 0 deletions orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py
@@ -0,0 +1,157 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import importlib.metadata
from functools import partial

import torch.nn as nn
from deepspeed.pipe import LayerSpec, PipelineModule, TiedLayerSpec
from deepspeed.runtime import utils as ds_utils
from deepspeed.runtime.activation_checkpointing import checkpointing
from packaging.version import Version

from onnxruntime.training.ortmodule import DebugOptions, ORTModule

# Check if DeepSpeed is installed and meets the minimum version requirement
minimum_version = Version("0.9.0")
installed_version = Version(importlib.metadata.version("deepspeed"))

if installed_version < minimum_version:
raise ImportError(f"DeepSpeed >= {minimum_version} is required, but {installed_version} is installed.")


class ORTPipelineModule(PipelineModule):
"""ORTPipelineModule pipeline module.
A customized version of DeepSpeed's PipelineModule that wraps each neural network layer
with ONNX Runtime's ORTModule. This modification allows leveraging ONNX Runtime optimizations
for the forward and backward passes, potentially enhancing execution performance and efficiency.
Please locate the "Using ORTPipelineModule for Deepspeed Pipeline Parallel" section in the "docs/ORTModule_Training_Guidelines.md" file of the ORT repository for more information.
.. note::
Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.
Args:
layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism axes for training. Must be provided if ``num_stages`` is ``None``.
loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)``
seed_layers(bool, optional): Use a different seed for each layer. Defaults to False.
seed_fn(type, optional): The custom seed generating function. Defaults to random seed generator.
base_seed (int, optional): The starting seed. Defaults to 1234.
partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
checkpointable_layers(list, optional): Checkpointable layers may not be checkpointed. Defaults to None which does not additional filtering.
debug_options(onnxruntime.training.ortmodule.DebugOptions): An instance of onnxruntime.training.ortmodule.DebugOptions or None.
If provided, it will be used to configure debugging options for ORTModule, This is done so we can add the name of the layer to avoid overwriting the ONNX files.
"""

def __init__(
self,
layers,
num_stages=None,
topology=None,
loss_fn=None,
seed_layers=False,
seed_fn=None,
base_seed=1234,
partition_method="parameters",
activation_checkpoint_interval=0,
activation_checkpoint_func=checkpointing.checkpoint,
checkpointable_layers=None,
debug_options=None,
):
"""
Initialize the ORTPipelineModule with the option to include ONNX Runtime debug options.
"""

self.ort_kwargs = {"debug_options": debug_options} if debug_options is not None else {}

super().__init__(
layers,
num_stages,
topology,
loss_fn,
seed_layers,
seed_fn,
base_seed,
partition_method,
activation_checkpoint_interval,
activation_checkpoint_func,
checkpointable_layers,
)

def _build(self):
"""
This method does the same thing as PipelineModule._build() method, the only difference is that it wraps each layer with ORTModule.
It also handles saving ONNX models with debug options in case of exporting multiple models.
"""
specs = self._layer_specs
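# Only the slice of layer specs assigned to this pipeline stage
# (indices in [self._local_start, self._local_stop)) is built and wrapped here.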

for local_idx, layer in enumerate(specs[self._local_start : self._local_stop]):
layer_idx = local_idx + self._local_start
if self.seed_layers:
if self.seed_fn:
self.seed_fn(self.base_seed + layer_idx)
else:
ds_utils.set_random_seed(self.base_seed + layer_idx)

# Recursively build PipelineModule objects
if isinstance(layer, PipelineModule):
raise NotImplementedError("RECURSIVE BUILD NOT YET IMPLEMENTED")

# TODO: Support wrapping for LayerSpec and TiedLayerSpec in addition to nn.Module in sequential.
# Currently, we only support wrapping nn.Module instances.

# nn.Module layers are wrapped with ORTModule and registered on this stage.
elif isinstance(layer, nn.Module):
name = str(layer_idx)

if "debug_options" in self.ort_kwargs:
new_onnx_prefix = name + "_" + self.ort_kwargs["debug_options"].onnx_prefix
parallel_debug_options = DebugOptions(
self.ort_kwargs["debug_options"].log_level,
self.ort_kwargs["debug_options"].save_onnx,
new_onnx_prefix,
)
wrapped_layer = ORTModule(layer, parallel_debug_options)
else:
wrapped_layer = ORTModule(layer)

self.forward_funcs.append(wrapped_layer)
self.fwd_map.update({name: len(self.forward_funcs) - 1})
self.add_module(name, wrapped_layer)

# TiedLayerSpec objects contain an nn.Module that should be allocated now.
elif isinstance(layer, TiedLayerSpec):
# Build and register the module if we haven't seen it before.
if layer.key not in self.tied_modules:
self.tied_modules[layer.key] = layer.build()
self.tied_weight_attrs[layer.key] = layer.tied_weight_attr

if layer.forward_fn is None:
# Just use forward()
self.forward_funcs.append(self.tied_modules[layer.key])
else:
# User specified fn with args (module, input)
self.forward_funcs.append(partial(layer.forward_fn, self.tied_modules[layer.key]))

# LayerSpec objects contain an nn.Module that should be allocated now.
elif isinstance(layer, LayerSpec):
module = layer.build()
name = str(layer_idx)
self.forward_funcs.append(module)
self.fwd_map.update({name: len(self.forward_funcs) - 1})
self.add_module(name, module)

# Last option: layer may be a functional (e.g., lambda). We do nothing in
# that case and just use it in forward()
else:
self.forward_funcs.append(layer)

# All pipeline parameters should be considered as model parallel in the context
# of our FP16 optimizer
for p in self.parameters():
p.ds_pipe_replicated = False
@@ -56,6 +56,19 @@ def run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log):
run_subprocess(command, cwd=cwd, log=log).check_returncode()


def run_ort_pipeline_module_tests(cwd, log):
log.debug("Running: ORTPipelineModule tests")

command = [
"deepspeed",
"orttraining_test_ort_pipeline_module.py",
"--deepspeed_config",
"orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json",
]

run_subprocess(command, cwd=cwd, log=log).check_returncode()


def run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, data_dir):
log.debug("Running: ORTModule fairscale sharded optimizer tests")
command = [
@@ -94,6 +107,7 @@ def main():
run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, args.mnist)

run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log)
run_ort_pipeline_module_tests(cwd, log)
run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, args.mnist)

run_distributed_cache_test(cwd, log)
90 changes: 90 additions & 0 deletions orttraining_test_ort_pipeline_module.py
@@ -0,0 +1,90 @@
import argparse

import deepspeed
import torch
from torch import nn

from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule

# USAGE:
# pip install deepspeed
# deepspeed orttraining_test_ort_pipeline_module.py --deepspeed_config=orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json --pipeline-parallel-size 2 --steps=100
# expected output : steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694


class SampleData(torch.utils.data.Dataset):
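"""Minimal in-memory dataset of (x, y) pairs; passed to deepspeed.initialize as training_data."""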
def __init__(self, x, y):
self.x = x
self.y = y

def __len__(self):
return self.x.size()[0]

def __getitem__(self, idx):
return self.x[idx], self.y[idx]


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1, help="local rank passed from distributed launcher")
parser.add_argument("-s", "--steps", type=int, default=100, help="quit after this many steps")
parser.add_argument("-p", "--pipeline-parallel-size", type=int, default=2, help="pipeline parallelism")
parser.add_argument("--backend", type=str, default="nccl", help="distributed backend")
parser.add_argument("--seed", type=int, default=0, help="PRNG seed")
parser.add_argument("--fp16", type=bool, default=False, help="fp16 run")

parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args


n = 10
d_in = 4
d_hidden = 8
d_out = 3
args = get_args()
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

# dist.init_process_group(backend=args.backend)
deepspeed.init_distributed(dist_backend=args.backend)
torch.manual_seed(args.seed)
# Model.

model = nn.Sequential(
nn.Linear(d_in, d_hidden), # Stage 1
nn.ReLU(), # Stage 1
nn.Linear(d_hidden, d_hidden), # Stage 1
nn.ReLU(), # Stage 1
nn.Linear(d_hidden, d_hidden), # Stage 2
nn.ReLU(), # Stage 2
nn.Linear(d_hidden, d_out), # Stage 2
)

model = ORTPipelineModule(
layers=model,
loss_fn=torch.nn.CrossEntropyLoss(),
num_stages=args.pipeline_parallel_size,
partition_method="uniform", #'parameters',
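# "uniform" partitions by layer count; "parameters" (DeepSpeed's default) balances stages by trainable parameter count.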
activation_checkpoint_interval=0,
)

params = [p for p in model.parameters() if p.requires_grad]

# Input.
x = torch.rand((n, d_in))
if args.fp16:
x = x.half()
# Output.
y = torch.randint(0, d_out, (n,))
ds = SampleData(x, y)

print("Initialize deepspeed")
model_engine, optimizer, _, _ = deepspeed.initialize(
args=args, model=model, model_parameters=params, training_data=ds
)

for step in range(args.steps):
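# Each train_batch() call runs one full pipeline schedule (forward, backward, optimizer step)
# over the configured micro-batches and returns the reduced loss for the batch.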
loss = model_engine.train_batch()
if step % 10 == 0:
print("step = ", step, ", loss = ", loss)
1 change: 1 addition & 0 deletions setup.py
@@ -486,6 +486,7 @@ def finalize_options(self):
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
"onnxruntime.training.ortmodule.graph_optimizers",
"onnxruntime.training.ortmodule.experimental.pipe",
"onnxruntime.training.ort_triton",
"onnxruntime.training.ort_triton.kernel",
"onnxruntime.training.utils",
