From 4ce7bbf6f17a4c3b4252952f6539b585d8fbfbf7 Mon Sep 17 00:00:00 2001
From: Adam Louly
Date: Tue, 23 Apr 2024 17:57:08 -0700
Subject: [PATCH] Add LayerSpec Support to ORTPipelineModule (#20410)

### Description
In DeepSpeed's pipeline parallel implementation, `LayerSpec` is a class
that delays construction of a layer until the layer has been assigned
to a pipeline stage and moved to its device, which helps reduce peak
memory usage during model setup.

In this PR, we add support in ORT for wrapping layers declared with
`LayerSpec`, so that `ORTPipelineModule` builds each such layer on its
owning stage and wraps the resulting module with ORTModule.
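For illustration, a minimal sketch of the intended usage: a model described as a list of `LayerSpec` entries handed to `ORTPipelineModule`. The `Block` module, the layer sizes, and `num_stages=2` below are hypothetical, and the snippet assumes it runs under the DeepSpeed launcher; see the updated test script in this PR for the full end-to-end flow.

```python
import deepspeed
import torch
from deepspeed.pipe import LayerSpec

from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule


class Block(torch.nn.Module):
    # Hypothetical stage module; any nn.Module subclass works here.
    def __init__(self, d_in, d_out):
        super().__init__()
        self.linear = torch.nn.Linear(d_in, d_out)

    def forward(self, x):
        return torch.relu(self.linear(x))


deepspeed.init_distributed(dist_backend="nccl")

# LayerSpec records the class and its constructor args; the module itself is
# instantiated later, only on the rank that owns the stage it is assigned to,
# where ORTPipelineModule then wraps it with ORTModule.
layers = [LayerSpec(Block, 4, 8), LayerSpec(Block, 8, 8), LayerSpec(Block, 8, 3)]

model = ORTPipelineModule(
    layers=layers,
    loss_fn=torch.nn.CrossEntropyLoss(),
    num_stages=2,
    partition_method="uniform",
)
```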
---
 .../experimental/pipe/_ort_pipeline_module.py |  12 ++
 ...orttraining_ortmodule_distributed_tests.py |   9 +-
 .../orttraining_test_ort_pipeline_module.py   | 139 ++++++++++++------
 3 files changed, 112 insertions(+), 48 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py b/orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py
index f088228ba3495..aff00f283c0ff 100644
--- a/orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py
+++ b/orttraining/orttraining/python/training/ortmodule/experimental/pipe/_ort_pipeline_module.py
@@ -142,6 +142,18 @@ def _build(self):
             elif isinstance(layer, LayerSpec):
                 module = layer.build()
                 name = str(layer_idx)
+
+                if "debug_options" in self.ort_kwargs:
+                    new_onnx_prefix = name + "_" + self.ort_kwargs["debug_options"].onnx_prefix
+                    parallel_debug_options = DebugOptions(
+                        self.ort_kwargs["debug_options"].log_level,
+                        self.ort_kwargs["debug_options"].save_onnx,
+                        new_onnx_prefix,
+                    )
+                    module = ORTModule(module, parallel_debug_options)
+                else:
+                    module = ORTModule(module)
+
                 self.forward_funcs.append(module)
                 self.fwd_map.update({name: len(self.forward_funcs) - 1})
                 self.add_module(name, module)
diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py
index ad1a8842ddecd..a3010978a0be4 100644
--- a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py
+++ b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py
@@ -56,7 +56,7 @@ def run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log):
     run_subprocess(command, cwd=cwd, log=log).check_returncode()
 
 
-def run_ort_pipeline_module_tests(cwd, log):
+def run_ort_pipeline_module_tests(cwd, log, layer_spec_flag=False):
     log.debug("Running: ORTPipelineModule tests")
 
     command = [
@@ -66,6 +66,9 @@ def run_ort_pipeline_module_tests(cwd, log):
         "orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json",
     ]
 
+    if layer_spec_flag:
+        command.append("--layer_spec=True")
+
     run_subprocess(command, cwd=cwd, log=log).check_returncode()
 
 
@@ -107,7 +110,11 @@ def main():
     run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, args.mnist)
 
     run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log)
+
+    # Deepspeed ORTPipelineModule Tests
     run_ort_pipeline_module_tests(cwd, log)
+    run_ort_pipeline_module_tests(cwd, log, layer_spec_flag=True)
+
     run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, args.mnist)
 
     run_distributed_cache_test(cwd, log)
output.""" + + def __init__(self, config: Dict[str, int]): + super().__init__() + self.linear = nn.Linear(config["hidden_size"], config["output_size"]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear(x) + return x + + +def build_model(config: Dict[str, int], n: int, layer_spec: bool) -> nn.Module: + """Constructs and returns the model either using LayerSpec or nn.Sequential.""" + if layer_spec: + print("Wrapping layers with LayerSpec") + model = ( + [LayerSpec(SimpleNetPipeInput, config)] + + [LayerSpec(SimpleNetPipeBlock, config) for _ in range(n)] + + [LayerSpec(SimpleNetPipeOutput, config)] + ) + else: + print("Wrapping layers with nn.Sequential") + model = nn.Sequential( + SimpleNetPipeInput(config), + SimpleNetPipeBlock(config), + SimpleNetPipeBlock(config), + SimpleNetPipeOutput(config), + ) + return model -n = 10 -d_in = 4 -d_hidden = 8 -d_out = 3 args = get_args() torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) - -# dist.init_process_group(backend=args.backend) deepspeed.init_distributed(dist_backend=args.backend) torch.manual_seed(args.seed) -# Model. - -model = nn.Sequential( - nn.Linear(d_in, d_hidden), # Stage 1 - nn.ReLU(), # Stage 1 - nn.Linear(d_hidden, d_hidden), # Stage 1 - nn.ReLU(), # Stage 1 - nn.Linear(d_hidden, d_hidden), # Stage 2 - nn.ReLU(), # Stage 2 - nn.Linear(d_hidden, d_out), # Stage 2 -) + +model = build_model({"input_size": 4, "hidden_size": 8, "output_size": 3}, n=10, layer_spec=args.layer_spec) model = ORTPipelineModule( layers=model, loss_fn=torch.nn.CrossEntropyLoss(), num_stages=args.pipeline_parallel_size, - partition_method="uniform", #'parameters', + partition_method="uniform", activation_checkpoint_interval=0, ) -params = [p for p in model.parameters() if p.requires_grad] - -# Input. -x = torch.rand((n, d_in)) -if args.fp16: - x = x.half() -# Output. -y = torch.randint(0, d_out, (n,)) -ds = SampleData(x, y) +# Setup input data +x = torch.rand((10, 4)) +y = torch.randint(0, 3, (10,)) +dataset = SampleData(x, y) print("Initialize deepspeed") model_engine, optimizer, _, _ = deepspeed.initialize( - args=args, model=model, model_parameters=params, training_data=ds # (x,y)# + args=args, model=model, model_parameters=model.parameters(), training_data=dataset ) for step in range(args.steps): loss = model_engine.train_batch() if step % 10 == 0: - print("step = ", step, ", loss = ", loss) + print(f"step = {step}, loss = {loss}")