forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LayerSpec Support to ORTPipelineModule (microsoft#20410)
### Description In Deepspeed's Pipeline Parallel Implementation, there is a class used to instantiate the object after it's moved to the device and assigned in a stage. This approach helps reduce peak memory usage. In this PR, we're adding support to ORT for wrapping this LayerSpec.
- Loading branch information
Showing
3 changed files
with
112 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
139 changes: 92 additions & 47 deletions
139
orttraining/orttraining/test/python/orttraining_test_ort_pipeline_module.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,135 @@ | ||
import argparse | ||
from typing import Dict, Tuple | ||
|
||
import deepspeed | ||
import torch | ||
from torch import nn | ||
from deepspeed.pipe import LayerSpec | ||
from torch import nn, utils | ||
|
||
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule | ||
|
||
# This script demonstrates how to set up a pipeline parallel training session | ||
# using DeepSpeed's ORTPipelineModule for a simple neural network model. | ||
|
||
|
||
# USAGE: | ||
# pip install deepspeed | ||
# deepspeed orttraining_test_ort_pipeline_module.py --deepspeed_config=orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json --pipeline-parallel-size 2 --steps=100 | ||
# expected output : steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694 | ||
def get_args() -> argparse.Namespace: | ||
"""Parse and return command line arguments.""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--local_rank", type=int, default=-1, help="Local rank passed from distributed launcher") | ||
parser.add_argument("--steps", type=int, default=100, help="Number of training steps to run") | ||
parser.add_argument("--pipeline-parallel-size", type=int, default=2, help="Number of pipeline stages") | ||
parser.add_argument("--backend", type=str, default="nccl", help="Distributed backend") | ||
parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility") | ||
parser.add_argument("--layer_spec", type=bool, default=False, help="Use LayerSpec for layer specification") | ||
|
||
parser = deepspeed.add_config_arguments(parser) | ||
return parser.parse_args() | ||
|
||
|
||
class SampleData(utils.data.Dataset): | ||
"""Custom dataset to facilitate loading and batching of data.""" | ||
|
||
class SampleData(torch.utils.data.Dataset): | ||
def __init__(self, x, y): | ||
def __init__(self, x: torch.Tensor, y: torch.Tensor): | ||
self.x = x | ||
self.y = y | ||
|
||
def __len__(self): | ||
return x.size()[0] | ||
def __len__(self) -> int: | ||
return self.x.size(0) | ||
|
||
def __getitem__(self, idx): | ||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: | ||
return self.x[idx], self.y[idx] | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--local_rank", type=int, default=-1, help="local rank passed from distributed launcher") | ||
parser.add_argument("-s", "--steps", type=int, default=100, help="quit after this many steps") | ||
parser.add_argument("-p", "--pipeline-parallel-size", type=int, default=2, help="pipeline parallelism") | ||
parser.add_argument("--backend", type=str, default="nccl", help="distributed backend") | ||
parser.add_argument("--seed", type=int, default=0, help="PRNG seed") | ||
parser.add_argument("--fp16", type=bool, default=False, help="fp16 run") | ||
class SimpleNetPipeInput(nn.Module): | ||
"""First stage of the pipeline, responsible for initial processing.""" | ||
|
||
parser = deepspeed.add_config_arguments(parser) | ||
args = parser.parse_args() | ||
return args | ||
def __init__(self, config: Dict[str, int]): | ||
super().__init__() | ||
self.linear = nn.Linear(config["input_size"], config["hidden_size"]) | ||
self.activation = nn.ReLU() | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.linear(x) | ||
x = self.activation(x) | ||
return x | ||
|
||
|
||
class SimpleNetPipeBlock(nn.Module): | ||
"""Intermediate stage of the pipeline, can be duplicated to deepen the network.""" | ||
|
||
def __init__(self, config: Dict[str, int]): | ||
super().__init__() | ||
self.linear = nn.Linear(config["hidden_size"], config["hidden_size"]) | ||
self.activation = nn.ReLU() | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.linear(x) | ||
x = self.activation(x) | ||
return x | ||
|
||
|
||
class SimpleNetPipeOutput(nn.Module): | ||
"""Final stage of the pipeline, producing the output.""" | ||
|
||
def __init__(self, config: Dict[str, int]): | ||
super().__init__() | ||
self.linear = nn.Linear(config["hidden_size"], config["output_size"]) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.linear(x) | ||
return x | ||
|
||
|
||
def build_model(config: Dict[str, int], n: int, layer_spec: bool) -> nn.Module: | ||
"""Constructs and returns the model either using LayerSpec or nn.Sequential.""" | ||
if layer_spec: | ||
print("Wrapping layers with LayerSpec") | ||
model = ( | ||
[LayerSpec(SimpleNetPipeInput, config)] | ||
+ [LayerSpec(SimpleNetPipeBlock, config) for _ in range(n)] | ||
+ [LayerSpec(SimpleNetPipeOutput, config)] | ||
) | ||
else: | ||
print("Wrapping layers with nn.Sequential") | ||
model = nn.Sequential( | ||
SimpleNetPipeInput(config), | ||
SimpleNetPipeBlock(config), | ||
SimpleNetPipeBlock(config), | ||
SimpleNetPipeOutput(config), | ||
) | ||
return model | ||
|
||
|
||
n = 10 | ||
d_in = 4 | ||
d_hidden = 8 | ||
d_out = 3 | ||
args = get_args() | ||
torch.cuda.set_device(args.local_rank) | ||
device = torch.device("cuda", args.local_rank) | ||
|
||
# dist.init_process_group(backend=args.backend) | ||
deepspeed.init_distributed(dist_backend=args.backend) | ||
torch.manual_seed(args.seed) | ||
# Model. | ||
|
||
model = nn.Sequential( | ||
nn.Linear(d_in, d_hidden), # Stage 1 | ||
nn.ReLU(), # Stage 1 | ||
nn.Linear(d_hidden, d_hidden), # Stage 1 | ||
nn.ReLU(), # Stage 1 | ||
nn.Linear(d_hidden, d_hidden), # Stage 2 | ||
nn.ReLU(), # Stage 2 | ||
nn.Linear(d_hidden, d_out), # Stage 2 | ||
) | ||
|
||
model = build_model({"input_size": 4, "hidden_size": 8, "output_size": 3}, n=10, layer_spec=args.layer_spec) | ||
|
||
model = ORTPipelineModule( | ||
layers=model, | ||
loss_fn=torch.nn.CrossEntropyLoss(), | ||
num_stages=args.pipeline_parallel_size, | ||
partition_method="uniform", #'parameters', | ||
partition_method="uniform", | ||
activation_checkpoint_interval=0, | ||
) | ||
|
||
params = [p for p in model.parameters() if p.requires_grad] | ||
|
||
# Input. | ||
x = torch.rand((n, d_in)) | ||
if args.fp16: | ||
x = x.half() | ||
# Output. | ||
y = torch.randint(0, d_out, (n,)) | ||
ds = SampleData(x, y) | ||
# Setup input data | ||
x = torch.rand((10, 4)) | ||
y = torch.randint(0, 3, (10,)) | ||
dataset = SampleData(x, y) | ||
|
||
print("Initialize deepspeed") | ||
model_engine, optimizer, _, _ = deepspeed.initialize( | ||
args=args, model=model, model_parameters=params, training_data=ds # (x,y)# | ||
args=args, model=model, model_parameters=model.parameters(), training_data=dataset | ||
) | ||
|
||
for step in range(args.steps): | ||
loss = model_engine.train_batch() | ||
if step % 10 == 0: | ||
print("step = ", step, ", loss = ", loss) | ||
print(f"step = {step}, loss = {loss}") |