
Introducing ORTPipelineModule - DeepSpeed Parallel Pipeline Support. #20287

Merged
merged 9 commits into from
Apr 18, 2024
Changes from 3 commits
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -380,6 +380,9 @@ if (onnxruntime_ENABLE_TRAINING)
  file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
    "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
  )
  file(GLOB onnxruntime_python_ortmodule_pipe_srcs CONFIGURE_DEPENDS
    "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/pipe/*"
  )
  file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
    "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
  )
@@ -756,6 +759,7 @@ if (onnxruntime_ENABLE_TRAINING)
    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/pipe
    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
@@ -806,6 +810,9 @@ if (onnxruntime_ENABLE_TRAINING)
    COMMAND ${CMAKE_COMMAND} -E copy
        ${onnxruntime_python_ortmodule_graph_optimizers_srcs}
        $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
    COMMAND ${CMAKE_COMMAND} -E copy
        ${onnxruntime_python_ortmodule_pipe_srcs}
        $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/pipe/
    COMMAND ${CMAKE_COMMAND} -E copy
        ${onnxruntime_python_ort_triton_srcs}
        $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/
6 changes: 6 additions & 0 deletions orttraining/orttraining/python/training/ortmodule/pipe/__init__.py
@@ -0,0 +1,6 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from ._ort_pipeline_module import ORTPipelineModule
138 changes: 138 additions & 0 deletions orttraining/orttraining/python/training/ortmodule/pipe/_ort_pipeline_module.py
@@ -0,0 +1,138 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import importlib.metadata
from packaging.version import Version

# Check if DeepSpeed is installed and meets the minimum version requirement
minimum_version = Version("0.12.6")
installed_version = Version(importlib.metadata.version("deepspeed"))

if installed_version < minimum_version:
    raise ImportError(f"DeepSpeed >= {minimum_version} is required, but {installed_version} is installed.")

from functools import partial

import torch.nn as nn
from deepspeed.pipe import LayerSpec, PipelineModule, TiedLayerSpec
from deepspeed.runtime import utils as ds_utils
from deepspeed.runtime.activation_checkpointing import checkpointing

from onnxruntime.training.ortmodule import DebugOptions, ORTModule


class ORTPipelineModule(PipelineModule):
"""ORTPipelineModule pipeline module.

A customized version of DeepSpeed's PipelineModule that wraps each neural network layer
with ONNX Runtime's ORTModule. This modification allows leveraging ONNX Runtime optimizations
for the forward and backward passes, potentially enhancing execution performance and efficiency.

"""

    def __init__(
        self,
        layers,
        num_stages=None,
        topology=None,
        loss_fn=None,
        seed_layers=False,
        seed_fn=None,
        base_seed=1234,
        partition_method="parameters",
        activation_checkpoint_interval=0,
        activation_checkpoint_func=checkpointing.checkpoint,
        checkpointable_layers=None,
        debug_options=None,
    ):
"""
Initialize the ORTPipelineModule with the option to include ONNX Runtime debug options.

:param debug_options: An instance of onnxruntime.training.ortmodule.DebugOptions or None.
If provided, it will be used to configure debugging options for ORTModules.
This is done so we can add the name of the layer to avoid overwriting the ONNX files.
Default is None, indicating that no special debug configuration is applied.
"""

self.ort_kwargs = {"debug_options": debug_options} if debug_options is not None else {}

        super().__init__(
            layers,
            num_stages,
            topology,
            loss_fn,
            seed_layers,
            seed_fn,
            base_seed,
            partition_method,
            activation_checkpoint_interval,
            activation_checkpoint_func,
            checkpointable_layers,
        )

    def _build(self):
        specs = self._layer_specs

        for local_idx, layer in enumerate(specs[self._local_start : self._local_stop]):
            layer_idx = local_idx + self._local_start
            if self.seed_layers:
                if self.seed_fn:
                    self.seed_fn(self.base_seed + layer_idx)
                else:
                    ds_utils.set_random_seed(self.base_seed + layer_idx)

            # Recursively build PipelineModule objects
            if isinstance(layer, PipelineModule):
                raise NotImplementedError("RECURSIVE BUILD NOT YET IMPLEMENTED")

            # TODO: Support wrapping for LayerSpec and TiedLayerSpec in addition to nn.Module in sequential.
            # Currently, we only support wrapping nn.Module instances.

            # nn.Module instances are wrapped with ORTModule here.
            elif isinstance(layer, nn.Module):
                name = str(layer_idx)

                if "debug_options" in self.ort_kwargs:
                    new_onnx_prefix = name + "_" + self.ort_kwargs["debug_options"].onnx_prefix
                    parallel_debug_options = DebugOptions(
                        self.ort_kwargs["debug_options"].log_level,
                        self.ort_kwargs["debug_options"].save_onnx,
                        new_onnx_prefix,
                    )
                    wrapped_layer = ORTModule(layer, parallel_debug_options)
                else:
                    wrapped_layer = ORTModule(layer)

                self.forward_funcs.append(wrapped_layer)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, layer)

            # TiedLayerSpec objects contain an nn.Module that should be allocated now.
            elif isinstance(layer, TiedLayerSpec):
                # Build and register the module if we haven't seen it before.
                if layer.key not in self.tied_modules:
                    self.tied_modules[layer.key] = layer.build()
                    self.tied_weight_attrs[layer.key] = layer.tied_weight_attr

                if layer.forward_fn is None:
                    # Just use forward()
                    self.forward_funcs.append(self.tied_modules[layer.key])
                else:
                    # User specified fn with args (module, input)
                    self.forward_funcs.append(partial(layer.forward_fn, self.tied_modules[layer.key]))

            # LayerSpec objects contain an nn.Module that should be allocated now.
            elif isinstance(layer, LayerSpec):
                module = layer.build()
                name = str(layer_idx)
                self.forward_funcs.append(module)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, module)

            # Last option: layer may be a functional (e.g., lambda). We do nothing in
            # that case and just use it in forward()
            else:
                self.forward_funcs.append(layer)

        # All pipeline parameters should be considered as model parallel in the context
        # of our FP16 optimizer
        for p in self.parameters():
            p.ds_pipe_replicated = False
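For illustration, a minimal sketch of handing a model built from plain nn.Module layers to ORTPipelineModule with DebugOptions (hypothetical layer sizes and ONNX prefix; assumes a process launched via deepspeed with an initialized distributed backend). Per the _build logic above, each wrapped layer prepends its index to the prefix, so saved ONNX files such as 0_toy_model_*.onnx and 1_toy_model_*.onnx do not collide:

import deepspeed
import torch.nn as nn

from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule.pipe import ORTPipelineModule

deepspeed.init_distributed(dist_backend="nccl")  # required before building a PipelineModule

# Hypothetical toy model; each nn.Module in the sequence is wrapped with ORTModule.
layers = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

pipeline = ORTPipelineModule(
    layers=layers,
    loss_fn=nn.CrossEntropyLoss(),
    num_stages=2,
    partition_method="uniform",
    debug_options=DebugOptions(log_level=LogLevel.WARNING, save_onnx=True, onnx_prefix="toy_model"),
)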
14 changes: 14 additions & 0 deletions orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py
@@ -56,6 +56,19 @@ def run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log):
    run_subprocess(command, cwd=cwd, log=log).check_returncode()


def run_ort_pipeline_module_tests(cwd, log):
    log.debug("Running: ORTPipelineModule tests")

    command = [
        "deepspeed",
        "orttraining_test_ort_pipeline_module.py",
        "--deepspeed_config",
        "orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json",
    ]

    run_subprocess(command, cwd=cwd, log=log).check_returncode()


def run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, data_dir):
    log.debug("Running: ORTModule fairscale sharded optimizer tests")
    command = [
@@ -94,6 +107,7 @@ def main():
    run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, args.mnist)

    run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log)
    run_ort_pipeline_module_tests(cwd, log)
    run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, args.mnist)

    run_distributed_cache_test(cwd, log)
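For reference, the test above points DeepSpeed at a JSON config file. A minimal illustrative config of the shape these pipeline-parallel runs expect (a sketch using standard DeepSpeed keys, not the actual contents of orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json):

{
  "train_batch_size": 10,
  "train_micro_batch_size_per_gpu": 1,
  "steps_per_print": 10,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001
    }
  },
  "fp16": {
    "enabled": false
  }
}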
90 changes: 90 additions & 0 deletions orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py
@@ -0,0 +1,90 @@
import argparse

import deepspeed
import torch
from torch import nn

from onnxruntime.training.ortmodule.pipe import ORTPipelineModule

# USAGE:
# pip install deepspeed
# deepspeed orttraining_test_ort_pipeline_module.py --deepspeed_config=orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json --pipeline-parallel-size 2 --steps=100
# expected output : steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694


class SampleData(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return self.x.size()[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1, help="local rank passed from distributed launcher")
    parser.add_argument("-s", "--steps", type=int, default=100, help="quit after this many steps")
    parser.add_argument("-p", "--pipeline-parallel-size", type=int, default=2, help="pipeline parallelism")
    parser.add_argument("--backend", type=str, default="nccl", help="distributed backend")
    parser.add_argument("--seed", type=int, default=0, help="PRNG seed")
    parser.add_argument("--fp16", action="store_true", help="run in fp16")

    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    return args


n = 10
d_in = 4
d_hidden = 8
d_out = 3
args = get_args()
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

# dist.init_process_group(backend=args.backend)
deepspeed.init_distributed(dist_backend=args.backend)
torch.manual_seed(args.seed)
# Model.

model = nn.Sequential(
    nn.Linear(d_in, d_hidden),  # Stage 1
    nn.ReLU(),  # Stage 1
    nn.Linear(d_hidden, d_hidden),  # Stage 1
    nn.ReLU(),  # Stage 1
    nn.Linear(d_hidden, d_hidden),  # Stage 2
    nn.ReLU(),  # Stage 2
    nn.Linear(d_hidden, d_out),  # Stage 2
)

model = ORTPipelineModule(
    layers=model,
    loss_fn=torch.nn.CrossEntropyLoss(),
    num_stages=args.pipeline_parallel_size,
    partition_method="uniform",  # alternative: "parameters"
    activation_checkpoint_interval=0,
)

params = [p for p in model.parameters() if p.requires_grad]

# Input.
x = torch.rand((n, d_in))
if args.fp16:
    x = x.half()
# Output.
y = torch.randint(0, d_out, (n,))
ds = SampleData(x, y)

print("Initialize deepspeed")
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args, model=model, model_parameters=params, training_data=ds
)

for step in range(args.steps):
    loss = model_engine.train_batch()
    if step % 10 == 0:
        print("step = ", step, ", loss = ", loss)
1 change: 1 addition & 0 deletions setup.py
@@ -486,6 +486,7 @@ def finalize_options(self):
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
"onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
"onnxruntime.training.ortmodule.graph_optimizers",
"onnxruntime.training.ortmodule.pipe",
"onnxruntime.training.ort_triton",
"onnxruntime.training.ort_triton.kernel",
"onnxruntime.training.utils",