diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index 545bcf6a38d05..fcddda8f3eda9 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -380,6 +380,9 @@ if (onnxruntime_ENABLE_TRAINING)
   file(GLOB onnxruntime_python_ortmodule_graph_optimizers_srcs CONFIGURE_DEPENDS
     "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/graph_optimizers/*"
   )
+  file(GLOB onnxruntime_python_ortmodule_pipe_srcs CONFIGURE_DEPENDS
+    "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/experimental/pipe/*"
+  )
   file(GLOB onnxruntime_python_ort_triton_srcs CONFIGURE_DEPENDS
     "${ORTTRAINING_SOURCE_DIR}/python/training/ort_triton/*.py"
   )
@@ -756,6 +759,7 @@ if (onnxruntime_ENABLE_TRAINING)
     COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator
     COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops
     COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers
+    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/pipe
     COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton
     COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/kernel
     COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
@@ -806,6 +810,9 @@ if (onnxruntime_ENABLE_TRAINING)
     COMMAND ${CMAKE_COMMAND} -E copy
       ${onnxruntime_python_ortmodule_graph_optimizers_srcs}
       $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/graph_optimizers/
+    COMMAND ${CMAKE_COMMAND} -E copy
+      ${onnxruntime_python_ortmodule_pipe_srcs}
+      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/experimental/pipe/
     COMMAND ${CMAKE_COMMAND} -E copy
       ${onnxruntime_python_ort_triton_srcs}
       $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ort_triton/
diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 54137937ad56d..1609aa7ae02cb 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -495,3 +495,31 @@ for epoch in range(start_epoch, n_epochs):
 ```
 Check [LoadBalancingDistributedBatchSampler implementation](../orttraining/orttraining/python/training/utils/data/sampler.py) for more details.
+
+## 8 Using ORTPipelineModule for DeepSpeed Pipeline Parallel
+
+You can use `ORTPipelineModule` to support DeepSpeed pipeline parallelism. Here's how you can integrate it into your training script:
+
+```python
+from onnxruntime.training.ortmodule import DebugOptions, LogLevel
+from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
+
+# Create a debug configuration if needed.
+# Since multiple graphs are exported (one per layer), each layer's index is added
+# as a prefix to its ONNX file name to differentiate the files.
+debug_options = DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name")
+
+# Keep your DeepSpeed script the same and use ORTPipelineModule in place of PipelineModule.
+# Initialize the ORTPipelineModule.
+pipeline_module = ORTPipelineModule(
+    layers,
+    num_stages=2,  # Set your number of stages
+    base_seed=1234,
+    partition_method="parameters",
+    debug_options=debug_options,  # Pass the debug configuration if needed
+)
+
+# Keep the rest of the script as it is; see the sketch below.
+```
+
+Check [ORTPipelineModule implementation](../orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py) for more details.
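+
+The wrapped module then plugs into the standard DeepSpeed flow. Below is a minimal sketch of that "rest of the script", mirroring the test added in this PR (`args`, `train_dataset`, and `num_steps` are placeholders from your own script, not part of the API):
+
+```python
+import deepspeed
+
+# Initialize the DeepSpeed engine exactly as you would with a plain PipelineModule.
+model_engine, optimizer, _, _ = deepspeed.initialize(
+    args=args,  # parsed arguments that carry the DeepSpeed config
+    model=pipeline_module,
+    model_parameters=[p for p in pipeline_module.parameters() if p.requires_grad],
+    training_data=train_dataset,
+)
+
+# train_batch() runs one full pipeline schedule (forward, backward, optimizer step).
+for step in range(num_steps):
+    loss = model_engine.train_batch()
+```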
diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/pipe/__init__.py b/orttraining/orttraining/python/training/ortmodule/experimental/pipe/__init__.py
new file mode 100644
index 0000000000000..7d361dc2aa8c9
--- /dev/null
+++ b/orttraining/orttraining/python/training/ortmodule/experimental/pipe/__init__.py
@@ -0,0 +1,6 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# -------------------------------------------------------------------------
+
+from ._ort_pipeline_module import ORTPipelineModule  # noqa: F401
diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py b/orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py
new file mode 100644
index 0000000000000..f088228ba3495
--- /dev/null
+++ b/orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py
@@ -0,0 +1,157 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import importlib.metadata
+from functools import partial
+
+import torch.nn as nn
+from deepspeed.pipe import LayerSpec, PipelineModule, TiedLayerSpec
+from deepspeed.runtime import utils as ds_utils
+from deepspeed.runtime.activation_checkpointing import checkpointing
+from packaging.version import Version
+
+from onnxruntime.training.ortmodule import DebugOptions, ORTModule
+
+# Check if DeepSpeed is installed and meets the minimum version requirement.
+minimum_version = Version("0.9.0")
+installed_version = Version(importlib.metadata.version("deepspeed"))
+
+if installed_version < minimum_version:
+    raise ImportError(f"DeepSpeed >= {minimum_version} is required, but {installed_version} is installed.")
+
+
+class ORTPipelineModule(PipelineModule):
+    """ORTModule-enabled variant of DeepSpeed's PipelineModule.
+
+    A customized version of DeepSpeed's PipelineModule that wraps each neural network layer
+    with ONNX Runtime's ORTModule. This modification allows leveraging ONNX Runtime optimizations
+    for the forward and backward passes, potentially enhancing execution performance and efficiency.
+
+    See the "Using ORTPipelineModule for DeepSpeed Pipeline Parallel" section in
+    docs/ORTModule_Training_Guidelines.md of the ORT repository for more information.
+
+    .. note::
+        Pipeline parallelism is not compatible with ZeRO-2 and ZeRO-3.
+
+    Args:
+        layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.
+        num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.
+        topology (``deepspeed.runtime.pipe.ProcessTopology``, optional): Defines the axes of parallelism for training. Must be provided if ``num_stages`` is ``None``.
+        loss_fn (callable, optional): Loss is computed as ``loss = loss_fn(outputs, label)``.
+        seed_layers (bool, optional): Use a different seed for each layer. Defaults to False.
+        seed_fn (callable, optional): Custom seed generating function. Defaults to a random seed generator.
+        base_seed (int, optional): The starting seed. Defaults to 1234.
+        partition_method (str, optional): The method upon which the layers are partitioned. Defaults to 'parameters'.
+        activation_checkpoint_interval (int, optional): The granularity of activation checkpointing in terms of number of layers. 0 disables activation checkpointing.
+        activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. Defaults to ``deepspeed.checkpointing.checkpoint``.
+        checkpointable_layers (list, optional): Names of layer classes that may be activation checkpointed. Defaults to ``None``, which applies no additional filtering.
+        debug_options (onnxruntime.training.ortmodule.DebugOptions, optional): Debugging options for the wrapped ORTModule instances, or ``None``.
+            When provided, each layer's index is prepended to the ONNX prefix so that the exported ONNX files do not overwrite each other.
+    """
+
+    def __init__(
+        self,
+        layers,
+        num_stages=None,
+        topology=None,
+        loss_fn=None,
+        seed_layers=False,
+        seed_fn=None,
+        base_seed=1234,
+        partition_method="parameters",
+        activation_checkpoint_interval=0,
+        activation_checkpoint_func=checkpointing.checkpoint,
+        checkpointable_layers=None,
+        debug_options=None,
+    ):
+        """Initialize the ORTPipelineModule, optionally with ONNX Runtime debug options."""
+
+        self.ort_kwargs = {"debug_options": debug_options} if debug_options is not None else {}
+
+        super().__init__(
+            layers,
+            num_stages,
+            topology,
+            loss_fn,
+            seed_layers,
+            seed_fn,
+            base_seed,
+            partition_method,
+            activation_checkpoint_interval,
+            activation_checkpoint_func,
+            checkpointable_layers,
+        )
+
+    def _build(self):
+        """
+        Does the same thing as ``PipelineModule._build()``, except that each ``nn.Module``
+        layer is wrapped with ORTModule. It also adjusts the ONNX prefix in the debug
+        options so that each exported model gets a distinct file name.
+        """
+        specs = self._layer_specs
+
+        for local_idx, layer in enumerate(specs[self._local_start : self._local_stop]):
+            layer_idx = local_idx + self._local_start
+            if self.seed_layers:
+                if self.seed_fn:
+                    self.seed_fn(self.base_seed + layer_idx)
+                else:
+                    ds_utils.set_random_seed(self.base_seed + layer_idx)
+
+            # Recursively build PipelineModule objects
+            if isinstance(layer, PipelineModule):
+                raise NotImplementedError("RECURSIVE BUILD NOT YET IMPLEMENTED")
+
+            # TODO: Support wrapping for LayerSpec and TiedLayerSpec in addition to nn.Module in sequential.
+            # Currently, we only support wrapping nn.Module instances.
+
+            # nn.Module instances are wrapped with ORTModule and registered directly.
+            elif isinstance(layer, nn.Module):
+                name = str(layer_idx)
+
+                if "debug_options" in self.ort_kwargs:
+                    new_onnx_prefix = name + "_" + self.ort_kwargs["debug_options"].onnx_prefix
+                    parallel_debug_options = DebugOptions(
+                        self.ort_kwargs["debug_options"].log_level,
+                        self.ort_kwargs["debug_options"].save_onnx,
+                        new_onnx_prefix,
+                    )
+                    wrapped_layer = ORTModule(layer, parallel_debug_options)
+                else:
+                    wrapped_layer = ORTModule(layer)
+
+                self.forward_funcs.append(wrapped_layer)
+                self.fwd_map.update({name: len(self.forward_funcs) - 1})
+                self.add_module(name, wrapped_layer)
+
+            # TiedLayerSpec objects contain an nn.Module that should be allocated now.
+            elif isinstance(layer, TiedLayerSpec):
+                # Build and register the module if we haven't seen it before.
+                if layer.key not in self.tied_modules:
+                    self.tied_modules[layer.key] = layer.build()
+                    self.tied_weight_attrs[layer.key] = layer.tied_weight_attr
+
+                if layer.forward_fn is None:
+                    # Just use forward()
+                    self.forward_funcs.append(self.tied_modules[layer.key])
+                else:
+                    # User specified fn with args (module, input)
+                    self.forward_funcs.append(partial(layer.forward_fn, self.tied_modules[layer.key]))
+
+            # LayerSpec objects contain an nn.Module that should be allocated now.
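+            # As noted in the TODO above, the module built by layer.build() below is
+            # registered as-is (without the ORTModule wrapper), so it runs with plain
+            # PyTorch execution until LayerSpec wrapping is supported.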
+            elif isinstance(layer, LayerSpec):
+                module = layer.build()
+                name = str(layer_idx)
+                self.forward_funcs.append(module)
+                self.fwd_map.update({name: len(self.forward_funcs) - 1})
+                self.add_module(name, module)
+
+            # Last option: layer may be a functional (e.g., lambda). We do nothing in
+            # that case and just use it in forward().
+            else:
+                self.forward_funcs.append(layer)
+
+        # All pipeline parameters should be considered as model parallel in the context
+        # of our FP16 optimizer.
+        for p in self.parameters():
+            p.ds_pipe_replicated = False
diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py
index 80db8cd9f17a4..ad1a8842ddecd 100644
--- a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py
+++ b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py
@@ -56,6 +56,19 @@ def run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log):
     run_subprocess(command, cwd=cwd, log=log).check_returncode()
 
 
+def run_ort_pipeline_module_tests(cwd, log):
+    log.debug("Running: ORTPipelineModule tests")
+
+    command = [
+        "deepspeed",
+        "orttraining_test_ort_pipeline_module.py",
+        "--deepspeed_config",
+        "orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json",
+    ]
+
+    run_subprocess(command, cwd=cwd, log=log).check_returncode()
+
+
 def run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, data_dir):
     log.debug("Running: ORTModule fairscale sharded optimizer tests")
     command = [
@@ -94,6 +107,7 @@ def main():
 
     run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, args.mnist)
     run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log)
+    run_ort_pipeline_module_tests(cwd, log)
     run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, args.mnist)
     run_distributed_cache_test(cwd, log)
diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py b/orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py
new file mode 100644
index 0000000000000..39b7ed6e53201
--- /dev/null
+++ b/orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py
@@ -0,0 +1,90 @@
+import argparse
+
+import deepspeed
+import torch
+from torch import nn
+
+from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
+
+# USAGE:
+# pip install deepspeed
+# deepspeed orttraining_test_ort_pipeline_module.py --deepspeed_config=orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json --pipeline-parallel-size 2 --steps=100
+# expected output: steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694
+
+
+class SampleData(torch.utils.data.Dataset):
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __len__(self):
+        return self.x.size()[0]
+
+    def __getitem__(self, idx):
+        return self.x[idx], self.y[idx]
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--local_rank", type=int, default=-1, help="local rank passed from distributed launcher")
+    parser.add_argument("-s", "--steps", type=int, default=100, help="quit after this many steps")
+    parser.add_argument("-p", "--pipeline-parallel-size", type=int, default=2, help="pipeline parallelism")
+    parser.add_argument("--backend", type=str, default="nccl", help="distributed backend")
+    parser.add_argument("--seed", type=int, default=0, help="PRNG seed")
+    parser.add_argument("--fp16", action="store_true", help="run in fp16")
+
+    parser = deepspeed.add_config_arguments(parser)
+    args = parser.parse_args()
+    return args
+
+
+n = 10
+d_in = 4
+d_hidden = 8
+d_out = 3
+args = get_args()
+torch.cuda.set_device(args.local_rank)
+device = torch.device("cuda", args.local_rank)
+
+deepspeed.init_distributed(dist_backend=args.backend)
+torch.manual_seed(args.seed)
+
+# Model.
+model = nn.Sequential(
+    nn.Linear(d_in, d_hidden),  # Stage 1
+    nn.ReLU(),  # Stage 1
+    nn.Linear(d_hidden, d_hidden),  # Stage 1
+    nn.ReLU(),  # Stage 1
+    nn.Linear(d_hidden, d_hidden),  # Stage 2
+    nn.ReLU(),  # Stage 2
+    nn.Linear(d_hidden, d_out),  # Stage 2
+)
+
+model = ORTPipelineModule(
+    layers=model,
+    loss_fn=torch.nn.CrossEntropyLoss(),
+    num_stages=args.pipeline_parallel_size,
+    partition_method="uniform",  # alternative: "parameters"
+    activation_checkpoint_interval=0,
+)
+
+params = [p for p in model.parameters() if p.requires_grad]
+
+# Input.
+x = torch.rand((n, d_in))
+if args.fp16:
+    x = x.half()
+# Output.
+y = torch.randint(0, d_out, (n,))
+ds = SampleData(x, y)
+
+print("Initialize deepspeed")
+model_engine, optimizer, _, _ = deepspeed.initialize(
+    args=args, model=model, model_parameters=params, training_data=ds  # dataset yielding (x, y) pairs
+)
+
+for step in range(args.steps):
+    loss = model_engine.train_batch()
+    if step % 10 == 0:
+        print("step = ", step, ", loss = ", loss)
diff --git a/setup.py b/setup.py
index ffe2958b357b8..7f597d15c6f9b 100644
--- a/setup.py
+++ b/setup.py
@@ -486,6 +486,7 @@ def finalize_options(self):
         "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator",
         "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops",
         "onnxruntime.training.ortmodule.graph_optimizers",
+        "onnxruntime.training.ortmodule.experimental.pipe",
         "onnxruntime.training.ort_triton",
         "onnxruntime.training.ort_triton.kernel",
         "onnxruntime.training.utils",
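
As a quick sanity check of the packaging changes above, the new subpackage should be importable from an installed, training-enabled build. A minimal smoke test (assuming a training build of onnxruntime plus DeepSpeed >= 0.9.0 are installed; this snippet is not part of the PR):

```python
# Verifies that the new experimental.pipe subpackage is packaged and importable.
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule

# ORTPipelineModule derives from deepspeed.pipe.PipelineModule.
print(ORTPipelineModule.__mro__)
```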