This repository has been archived by the owner on Oct 16, 2023. It is now read-only.
Commit
Merge pull request #88 from dujiangsu/main
update Readme
Showing 10 changed files with 676 additions and 10 deletions.
@@ -0,0 +1,154 @@
import threading

import torch
import torch.fx
import torch.nn as nn
from typing import List, Tuple, Union

from energonai.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

from .pipeline_meta import PipelineMeta
from .pipeline_msg_dict import PipelineMsgDict, CircleInt  # PipelineMsgPriorityQueue


def filter_inputs(traced: torch.fx.GraphModule):
    inputs = {}
    for node in traced.graph.nodes:
        if node.op == 'placeholder':
            inputs[node.name] = None
        else:
            break
    return inputs


# This wrapper currently supports Transformer models only.
class AutoPipelineCommWrapper:

    def __init__(self, model: nn.Module, max_batch_size: int = 1, dtype=torch.float) -> None:
        # TODO (dujiangsu): support building sample inputs of different types: iterables, tensors, and others.
        self.model = model
        self.dtype = dtype

        self.tensor_dim = 0
        self.hidden_size = 0
        self.max_batch_size = max_batch_size

        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
            input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
            # dummy all-zero mask (torch.randint's upper bound is exclusive)
            attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
            sample = dict(input_ids=input_ids, attention_mask=attention_mask)
            self._init_tensor_meta(sample)

        self.pipe_msg_queue = PipelineMsgDict()
        self.lock = threading.Lock()
        self.key = CircleInt()

    def _init_tensor_meta(self, sample):

        with torch.inference_mode():
            recv_tensor_shape = None
            if gpc.is_first_rank(ParallelMode.PIPELINE):
                # inputs = filter_inputs(self.model)
                output = self.model(sample['input_ids'],
                                    sample['attention_mask'])  # ([32, 512, 1600])
                send_tensor_meta(output)
                send_forward(output)
                self.tensor_dim = output.dim()
                self.hidden_size = output.size()[-1]
            elif gpc.is_last_rank(ParallelMode.PIPELINE):
                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                self.tensor_dim = input_tensor.dim()
                self.hidden_size = input_tensor.size()[-1]
            else:
                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                self.tensor_dim = input_tensor.dim()
                self.hidden_size = input_tensor.size()[-1]
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                send_tensor_meta(output)
                send_forward(output)

    def run(self, key, inputs):
        if gpc.is_initialized(ParallelMode.PIPELINE):
            return self.run_with_pp(key, inputs)
        # else:
        #     return self.run_without_pp(key, inputs)

    # def run_without_pp(self, key, inputs):
    #     pipe_meta = None
    #     self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

    #     self.lock.acquire()

    #     cur_key = self.key.val
    #     sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
    #     self.key.addOne()
    #     with torch.inference_mode():
    #         output = self.model(hidden_states=None,
    #                             input_ids=sample['input_ids'],
    #                             attention_mask=sample['attention_mask'],
    #                             seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
    #     self.lock.release()

    #     return output, cur_key

    '''
    hidden states shape, e.g. ([32, 512, 1600]); hidden_size is the last dimension.
    fill_meta_tensor differs for different model types.
    '''

    def fill_meta_tensor(self, inputs, pipe_meta):
        if 'seq_lens' in inputs:
            pipe_meta.get_meta_tensor()[0] = 1
            pipe_meta.get_meta_tensor()[1] = 1
            pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
        else:
            pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
            pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
            pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]

        pipe_meta.get_meta_tensor()[3] = self.hidden_size
        pipe_meta.update_meta()

    def run_with_pp(self, key, inputs):
        pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
        self.fill_meta_tensor(inputs, pipe_meta)
        self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

        self.lock.acquire()
        cur_key = self.key.val
        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
        self.key.addOne()

        with torch.inference_mode():

            if gpc.is_first_rank(ParallelMode.PIPELINE):
                output = self.model(sample['input_ids'],
                                    sample['attention_mask'])
                # seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)

                send_forward(output)
                self.lock.release()
                return None

            if gpc.is_last_rank(ParallelMode.PIPELINE):

                # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                self.lock.release()
                return output, cur_key

            else:

                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                send_forward(output)
                self.lock.release()
                return None
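
A minimal usage sketch for the wrapper above (not part of this commit): it assumes the pipeline-parallel group has already been initialized through the Colossal-AI/EnergonAI launch utilities, and that stage_model is this rank's transformer stage taking (input_ids or hidden_states, attention_mask).

# Hypothetical sketch; stage_model and the initialized pipeline group are assumptions.
import torch

wrapper = AutoPipelineCommWrapper(stage_model, max_batch_size=4, dtype=torch.float)

inputs = dict(input_ids=torch.randint(1, 10, (4, 512), dtype=torch.int64).cuda(),
              attention_mask=torch.zeros(4, 1, 512, dtype=torch.int64).cuda())

result = wrapper.run(key=0, inputs=inputs)  # key 0 assumed to be the first request key
if result is not None:
    # Only the last pipeline rank returns (output, cur_key); other ranks return None.
    output, cur_key = result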
@@ -0,0 +1 @@
from .split_method import split_transformer_into_partitions
@@ -0,0 +1,7 @@
import torch.fx
from energonai.context import mcfg


class EnergonTracer(torch.fx.Tracer):

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        leaves = mcfg["LeafSet"]  # e.g. set([BertTransformerLayer])
        return type(m) in leaves
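
A small sketch of how the tracer might be used (illustrative only, not code from this commit): mcfg["LeafSet"] is assumed to already contain the layer classes that should stay opaque, so the tracer does not descend into them and each layer appears as a single call_module node.

# Hypothetical sketch; MyTransformerLayer and the mcfg configuration are assumptions.
# mcfg["LeafSet"] is assumed to hold {MyTransformerLayer}, configured elsewhere.
graph = EnergonTracer().trace(model)       # leaf layers are not traced through
traced = torch.fx.GraphModule(model, graph)
print(traced.graph)                        # each leaf layer shows up as one call_module node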
@@ -0,0 +1,22 @@
from torch.fx.passes.split_module import split_module
from .split_policy import module_equal_partition, naive_equal_partition, transformer_partition
from .energon_tracer import EnergonTracer
import torch.fx


def filter_graph(traced: torch.fx.GraphModule, filter_type: str):
    # Count the nodes of a given op type in the traced graph.
    count = 0
    for node in traced.graph.nodes:
        if node.op == filter_type:
            count = count + 1
    return count


def split_transformer_into_partitions(model_class):
    model = model_class()
    graph = EnergonTracer().trace(model)
    traced = torch.fx.GraphModule(model, graph)
    depth = filter_graph(traced, "call_module") - 1
    submodules = split_module(traced, model, transformer_partition(depth))
    del model

    return submodules
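
A sketch of how the splitter could be called (illustrative only; GPT2Model here is a stand-in for any transformer class whose layer types are registered in mcfg["LeafSet"]):

# Hypothetical sketch; the model class and leaf configuration are assumptions.
submodules = split_transformer_into_partitions(GPT2Model)

# split_module returns a GraphModule whose children are the partitions
# (submod_0, submod_1, ...); each pipeline rank keeps only its own partition.
for name, submodule in submodules.named_children():
    print(name, sum(p.numel() for p in submodule.parameters()))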
@@ -0,0 +1,52 @@
import functools
from torch.fx.node import Node
from energonai.context import mcfg


partition_counter_0 = 0


# partition_nums: number of nodes assigned to each partition
def _naive_equal_partition(node: Node, partition_nums):
    global partition_counter_0
    partition = partition_counter_0 // partition_nums
    partition_counter_0 = partition_counter_0 + 1
    return partition


def naive_equal_partition(partition_nums):
    mod_partition = functools.partial(_naive_equal_partition, partition_nums=partition_nums)
    return mod_partition


partition_counter_1 = 0


# partition_nums: number of call_module nodes assigned to each partition
def _module_equal_partition(node: Node, partition_nums):
    global partition_counter_1
    partition = partition_counter_1 // partition_nums
    if node.op == 'call_module':
        partition_counter_1 = partition_counter_1 + 1
    return partition


def module_equal_partition(partition_nums):
    mod_partition = functools.partial(_module_equal_partition, partition_nums=partition_nums)
    return mod_partition


from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

partition_counter_2 = -1  # start at -1 so the embedding layer falls into the first partition


# depth: number of call_module nodes excluding the embedding layer
def _transformer_partition(node: Node, depth):
    global partition_counter_2
    assert gpc.is_initialized(ParallelMode.PIPELINE), "Pipeline communication group should be initialized!"

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    partition_nums = depth // pipeline_size
    partition = abs(partition_counter_2) // partition_nums
    if node.op == 'call_module':
        partition_counter_2 = partition_counter_2 + 1
    return partition


def transformer_partition(depth):
    mod_partition = functools.partial(_transformer_partition, depth=depth)
    return mod_partition
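
As a worked example (assuming the only call_module nodes are the embedding plus 24 transformer layers): depth is 24 and, with a pipeline-parallel size of 4, partition_nums is 6. The embedding (seen while the counter is -1) and layers 0-5 map to partition 0, layers 6-11 to partition 1, layers 12-17 to partition 2, and layers 18-23 to partition 3; non-module nodes are placed with whichever partition the counter currently points to.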