
Add LayerSpec Support to ORTPipelineModule (microsoft#20410)
### Description
In DeepSpeed's pipeline parallel implementation, `LayerSpec` is a class that defers a layer's instantiation: the module object is built only after the layer has been assigned to a pipeline stage and moved to that stage's device.

This approach helps reduce peak memory usage, since no single rank has to materialize the full model.

In this PR, we're adding support to ORT for wrapping layers declared via LayerSpec in ORTPipelineModule.
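
For readers unfamiliar with the pattern, here is a minimal sketch of DeepSpeed's deferred construction (illustration only, not taken from this PR's diff):

```python
# Minimal sketch of DeepSpeed's deferred layer construction.
from deepspeed.pipe import LayerSpec
from torch import nn

# A LayerSpec stores the layer class and its constructor arguments
# without allocating any weights.
specs = [
    LayerSpec(nn.Linear, 4, 8),
    LayerSpec(nn.ReLU),
    LayerSpec(nn.Linear, 8, 3),
]

# Each pipeline stage calls build() only for the layers assigned to it,
# so no rank ever materializes the whole model. This is what lowers peak
# memory compared to passing an already-built nn.Sequential.
stage_layers = [spec.build() for spec in specs]
```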
AdamLouly authored Apr 24, 2024 (commit 4ce7bbf, 1 parent 5055dc0)
Showing 3 changed files with 112 additions and 48 deletions.
File 1 of 3:
@@ -142,6 +142,18 @@ def _build(self):
            elif isinstance(layer, LayerSpec):
                module = layer.build()
                name = str(layer_idx)
+
+                if "debug_options" in self.ort_kwargs:
+                    new_onnx_prefix = name + "_" + self.ort_kwargs["debug_options"].onnx_prefix
+                    parallel_debug_options = DebugOptions(
+                        self.ort_kwargs["debug_options"].log_level,
+                        self.ort_kwargs["debug_options"].save_onnx,
+                        new_onnx_prefix,
+                    )
+                    module = ORTModule(module, parallel_debug_options)
+                else:
+                    module = ORTModule(module)
+
                self.forward_funcs.append(module)
                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, module)
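The new branch above gives every LayerSpec-built layer its own ONNX prefix (its layer index prepended) when debug options are supplied, so saved graphs from different layers don't overwrite one another. A construction sketch under assumptions: it assumes `ORTPipelineModule` forwards extra keyword arguments such as `debug_options` to each per-layer `ORTModule`, as the `self.ort_kwargs` lookup above suggests, and distributed initialization (as in the test script below) is still required to actually train:

```python
from deepspeed.pipe import LayerSpec
from torch import nn

from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule

# Assumption: ORTPipelineModule passes debug_options through its ort_kwargs
# to the per-layer ORTModule instances created in _build().
layer_specs = [LayerSpec(nn.Linear, 8, 8) for _ in range(4)]
pipeline = ORTPipelineModule(
    layers=layer_specs,
    loss_fn=nn.CrossEntropyLoss(),
    num_stages=2,
    debug_options=DebugOptions(LogLevel.INFO, save_onnx=True, onnx_prefix="pipe"),
)
# With the branch added above, the layer at index 3 saves its exported
# ONNX files under the prefix "3_pipe" instead of all layers sharing "pipe".
```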
File 2 of 3:
@@ -56,7 +56,7 @@ def run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log):
     run_subprocess(command, cwd=cwd, log=log).check_returncode()
 
 
-def run_ort_pipeline_module_tests(cwd, log):
+def run_ort_pipeline_module_tests(cwd, log, layer_spec_flag=False):
     log.debug("Running: ORTPipelineModule tests")
 
     command = [
@@ -66,6 +66,9 @@ def run_ort_pipeline_module_tests(cwd, log):
         "orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json",
     ]
 
+    if layer_spec_flag:
+        command.append("--layer_spec=True")
+
     run_subprocess(command, cwd=cwd, log=log).check_returncode()
 
 
@@ -107,7 +110,11 @@ def main():
     run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, args.mnist)
 
     run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log)
+
+    # Deepspeed ORTPipelineModule Tests
     run_ort_pipeline_module_tests(cwd, log)
+    run_ort_pipeline_module_tests(cwd, log, layer_spec_flag=True)
+
     run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, args.mnist)
 
     run_distributed_cache_test(cwd, log)
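One detail about the appended `--layer_spec=True` argument: the test script below declares the flag with `type=bool`, and argparse applies `bool()` to the raw string, so any non-empty value (even the string "False") would parse as True. The runner therefore only appends the flag at all when LayerSpec mode is intended. A small self-contained check of that behavior:

```python
import argparse

# Mirrors the flag declaration in the test script below: with type=bool,
# the string "True" (or any non-empty string) parses as a truthy value.
parser = argparse.ArgumentParser()
parser.add_argument("--layer_spec", type=bool, default=False)

assert parser.parse_args(["--layer_spec=True"]).layer_spec is True
assert parser.parse_args([]).layer_spec is False
```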
File 3 of 3: orttraining_test_ort_pipeline_module.py
@@ -1,90 +1,135 @@
 import argparse
+from typing import Dict, Tuple
 
 import deepspeed
 import torch
-from torch import nn
+from deepspeed.pipe import LayerSpec
+from torch import nn, utils
 
 from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
 
+# This script demonstrates how to set up a pipeline parallel training session
+# using DeepSpeed's ORTPipelineModule for a simple neural network model.
+
 
 # USAGE:
 # pip install deepspeed
 # deepspeed orttraining_test_ort_pipeline_module.py --deepspeed_config=orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json --pipeline-parallel-size 2 --steps=100
 # expected output : steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694
+def get_args() -> argparse.Namespace:
+    """Parse and return command line arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank passed from distributed launcher")
+    parser.add_argument("--steps", type=int, default=100, help="Number of training steps to run")
+    parser.add_argument("--pipeline-parallel-size", type=int, default=2, help="Number of pipeline stages")
+    parser.add_argument("--backend", type=str, default="nccl", help="Distributed backend")
+    parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility")
+    parser.add_argument("--layer_spec", type=bool, default=False, help="Use LayerSpec for layer specification")
+
+    parser = deepspeed.add_config_arguments(parser)
+    return parser.parse_args()
+
+
+class SampleData(utils.data.Dataset):
+    """Custom dataset to facilitate loading and batching of data."""
 
-class SampleData(torch.utils.data.Dataset):
-    def __init__(self, x, y):
+    def __init__(self, x: torch.Tensor, y: torch.Tensor):
         self.x = x
         self.y = y
 
-    def __len__(self):
-        return x.size()[0]
+    def __len__(self) -> int:
+        return self.x.size(0)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
         return self.x[idx], self.y[idx]
 
 
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--local_rank", type=int, default=-1, help="local rank passed from distributed launcher")
-    parser.add_argument("-s", "--steps", type=int, default=100, help="quit after this many steps")
-    parser.add_argument("-p", "--pipeline-parallel-size", type=int, default=2, help="pipeline parallelism")
-    parser.add_argument("--backend", type=str, default="nccl", help="distributed backend")
-    parser.add_argument("--seed", type=int, default=0, help="PRNG seed")
-    parser.add_argument("--fp16", type=bool, default=False, help="fp16 run")
+class SimpleNetPipeInput(nn.Module):
+    """First stage of the pipeline, responsible for initial processing."""
 
-    parser = deepspeed.add_config_arguments(parser)
-    args = parser.parse_args()
-    return args
+    def __init__(self, config: Dict[str, int]):
+        super().__init__()
+        self.linear = nn.Linear(config["input_size"], config["hidden_size"])
+        self.activation = nn.ReLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear(x)
+        x = self.activation(x)
+        return x
+
+
+class SimpleNetPipeBlock(nn.Module):
+    """Intermediate stage of the pipeline, can be duplicated to deepen the network."""
+
+    def __init__(self, config: Dict[str, int]):
+        super().__init__()
+        self.linear = nn.Linear(config["hidden_size"], config["hidden_size"])
+        self.activation = nn.ReLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear(x)
+        x = self.activation(x)
+        return x
+
+
+class SimpleNetPipeOutput(nn.Module):
+    """Final stage of the pipeline, producing the output."""
+
+    def __init__(self, config: Dict[str, int]):
+        super().__init__()
+        self.linear = nn.Linear(config["hidden_size"], config["output_size"])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear(x)
+        return x
+
+
+def build_model(config: Dict[str, int], n: int, layer_spec: bool) -> nn.Module:
+    """Constructs and returns the model either using LayerSpec or nn.Sequential."""
+    if layer_spec:
+        print("Wrapping layers with LayerSpec")
+        model = (
+            [LayerSpec(SimpleNetPipeInput, config)]
+            + [LayerSpec(SimpleNetPipeBlock, config) for _ in range(n)]
+            + [LayerSpec(SimpleNetPipeOutput, config)]
+        )
+    else:
+        print("Wrapping layers with nn.Sequential")
+        model = nn.Sequential(
+            SimpleNetPipeInput(config),
+            SimpleNetPipeBlock(config),
+            SimpleNetPipeBlock(config),
+            SimpleNetPipeOutput(config),
+        )
+    return model
 
 
-n = 10
-d_in = 4
-d_hidden = 8
-d_out = 3
 args = get_args()
 torch.cuda.set_device(args.local_rank)
 device = torch.device("cuda", args.local_rank)
 
-# dist.init_process_group(backend=args.backend)
 deepspeed.init_distributed(dist_backend=args.backend)
 torch.manual_seed(args.seed)
-# Model.
 
-model = nn.Sequential(
-    nn.Linear(d_in, d_hidden),  # Stage 1
-    nn.ReLU(),  # Stage 1
-    nn.Linear(d_hidden, d_hidden),  # Stage 1
-    nn.ReLU(),  # Stage 1
-    nn.Linear(d_hidden, d_hidden),  # Stage 2
-    nn.ReLU(),  # Stage 2
-    nn.Linear(d_hidden, d_out),  # Stage 2
-)
+model = build_model({"input_size": 4, "hidden_size": 8, "output_size": 3}, n=10, layer_spec=args.layer_spec)
 
 model = ORTPipelineModule(
     layers=model,
     loss_fn=torch.nn.CrossEntropyLoss(),
     num_stages=args.pipeline_parallel_size,
-    partition_method="uniform",  #'parameters',
+    partition_method="uniform",
     activation_checkpoint_interval=0,
 )
 
-params = [p for p in model.parameters() if p.requires_grad]
-
-# Input.
-x = torch.rand((n, d_in))
-if args.fp16:
-    x = x.half()
-# Output.
-y = torch.randint(0, d_out, (n,))
-ds = SampleData(x, y)
+# Setup input data
+x = torch.rand((10, 4))
+y = torch.randint(0, 3, (10,))
+dataset = SampleData(x, y)
 
 print("Initialize deepspeed")
 model_engine, optimizer, _, _ = deepspeed.initialize(
-    args=args, model=model, model_parameters=params, training_data=ds  # (x,y)#
+    args=args, model=model, model_parameters=model.parameters(), training_data=dataset
 )
 
 for step in range(args.steps):
     loss = model_engine.train_batch()
     if step % 10 == 0:
-        print("step = ", step, ", loss = ", loss)
+        print(f"step = {step}, loss = {loss}")
