This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Merge pull request #88 from dujiangsu/main
update Readme
MaruyamaAya authored Jun 8, 2022
2 parents 1a92c00 + 5893537 commit 89ce337
Showing 10 changed files with 676 additions and 10 deletions.
154 changes: 154 additions & 0 deletions energonai/engine/auto_pipeline_wrapper.py
@@ -0,0 +1,154 @@
import threading

import torch
import torch.fx
import torch.nn as nn
from typing import List, Tuple, Union

from energonai.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

from .pipeline_meta import PipelineMeta
from .pipeline_msg_dict import PipelineMsgDict, CircleInt # PipelineMsgPriorityQueue

def filter_inputs(traced: torch.fx.GraphModule):
    inputs = {}
    for node in traced.graph.nodes:
        if node.op == 'placeholder':
            inputs[node.name] = None
        else:
            break
    return inputs

# This wrapper only supports Transformer models.
class AutoPipelineCommWrapper:

    def __init__(self, model: nn.Module, max_batch_size: int = 1, dtype=torch.float) -> None:
        # TODO (dujiangsu): support sample construction for different input types: iterables, tensors, and others.
        self.model = model
        self.dtype = dtype

        self.tensor_dim = 0
        self.hidden_size = 0
        self.max_batch_size = max_batch_size

        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
            input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
            attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
            sample = dict(input_ids=input_ids, attention_mask=attention_mask)
            self._init_tensor_meta(sample)

        self.pipe_msg_queue = PipelineMsgDict()
        self.lock = threading.Lock()
        self.key = CircleInt()

    def _init_tensor_meta(self, sample):

        with torch.inference_mode():
            recv_tensor_shape = None
            if gpc.is_first_rank(ParallelMode.PIPELINE):
                # inputs = filter_inputs(self.model)
                output = self.model(sample['input_ids'],
                                    sample['attention_mask'])  # ([32, 512, 1600])
                send_tensor_meta(output)
                send_forward(output)
                self.tensor_dim = output.dim()
                self.hidden_size = output.size()[-1]
            elif gpc.is_last_rank(ParallelMode.PIPELINE):
                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                self.tensor_dim = input_tensor.dim()
                self.hidden_size = input_tensor.size()[-1]
            else:
                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                self.tensor_dim = input_tensor.dim()
                self.hidden_size = input_tensor.size()[-1]
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                send_tensor_meta(output)
                send_forward(output)

    def run(self, key, inputs):
        if gpc.is_initialized(ParallelMode.PIPELINE):
            return self.run_with_pp(key, inputs)
        # else:
        #     return self.run_without_pp(key, inputs)

    # def run_without_pp(self, key, inputs):
    #     pipe_meta = None
    #     self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

    #     self.lock.acquire()

    #     cur_key = self.key.val
    #     sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
    #     self.key.addOne()
    #     with torch.inference_mode():
    #         output = self.model(hidden_states=None,
    #                             input_ids=sample['input_ids'],
    #                             attention_mask=sample['attention_mask'],
    #                             seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
    #     self.lock.release()

    #     return output, cur_key

    '''
    Hidden-state shape example: ([32, 512, 1600]).
    fill_meta_tensor differs between model types.
    '''

    def fill_meta_tensor(self, inputs, pipe_meta):
        if 'seq_lens' in inputs:
            pipe_meta.get_meta_tensor()[0] = 1
            pipe_meta.get_meta_tensor()[1] = 1
            pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
        else:
            pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
            pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
            pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]

        pipe_meta.get_meta_tensor()[3] = self.hidden_size
        pipe_meta.update_meta()
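        # Illustrative example (not part of this commit): for input_ids of shape (32, 512)
        # and hidden size 1600, the meta tensor becomes [32, 32, 512, 1600]
        # (batch size, batch size, sequence length, hidden size).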

    def run_with_pp(self, key, inputs):
        pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
        self.fill_meta_tensor(inputs, pipe_meta)
        self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

        self.lock.acquire()
        cur_key = self.key.val
        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
        self.key.addOne()

        with torch.inference_mode():

            if gpc.is_first_rank(ParallelMode.PIPELINE):
                output = self.model(sample['input_ids'],
                                    sample['attention_mask'])
                # seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)

                send_forward(output)
                self.lock.release()
                return None

            if gpc.is_last_rank(ParallelMode.PIPELINE):

                # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                self.lock.release()
                return output, cur_key

            else:

                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                send_forward(output)
                self.lock.release()
                return None
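
For orientation, a minimal, hypothetical usage sketch of this wrapper from the serving side; the batch shape, the key value, and the stage_module name are illustrative and not part of this commit.

# Hypothetical usage sketch; stage_module, the shapes, and the key are placeholders.
import torch

wrapper = AutoPipelineCommWrapper(model=stage_module, max_batch_size=4, dtype=torch.float)
inputs = dict(input_ids=torch.randint(1, 10, (4, 512), dtype=torch.int64).cuda(),
              attention_mask=torch.ones(4, 1, 512, dtype=torch.int64).cuda())
result = wrapper.run(key=0, inputs=inputs)  # (output, key) on the last pipeline rank, None on the other ranks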


11 changes: 7 additions & 4 deletions energonai/engine/engine.py
@@ -17,7 +17,6 @@
from colossalai.logging import get_dist_logger

from energonai.initialize import launch_from_multiprocess
from energonai.utils import ensure_directory_exists



@@ -33,9 +32,11 @@ def __init__(self,
                 max_batch_size: int = 1,
                 tp_init_size: int = -1,
                 pp_init_size: int = -1,
                 auto_pp: bool = False,
                 host: str = 'localhost',
                 port: int = 29500,
                 dtype=None):
                 dtype=None,
                 ):
        """
        Args:
            model: torch.nn.Module
@@ -57,8 +58,10 @@ def __init__(self,
        self.tp_size = tp_init_size
        self.pp_size = pp_init_size

        # for TP
        # for TP, PP
        self.rrefs = []
        self.auto_pp = auto_pp

        # for rpc
        self.WORKER_NAME = "wok{}"
        self._init_dist_rpc()
@@ -93,7 +96,7 @@ def _init_model(self):
                rpc.remote(ob_info,
                           RPCWorker,
                           args=(self.model_class, self.model_config, self.model_type, self.dtype,
                                 self.max_batch_size)))
                                 self.max_batch_size, self.auto_pp)))

    def run(self, inputs):
        res_rref = 0
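
A hedged sketch of how the new auto_pp flag might be enabled from user code; the engine class name, the model class, and the remaining arguments are assumptions inferred from the parameters visible in this hunk, not taken verbatim from the repository.

# Assumed example only: names other than auto_pp are not confirmed by this diff.
engine = InferenceEngine(model_class=MyTransformer,
                         model_config=model_config,
                         model_type='gpt',
                         max_batch_size=4,
                         tp_init_size=1,
                         pp_init_size=2,
                         auto_pp=True,        # new flag introduced by this commit
                         dtype=torch.float)
output = engine.run(inputs)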
33 changes: 27 additions & 6 deletions energonai/engine/rpc_worker.py
@@ -7,6 +7,9 @@

from .pipeline_wrapper import PipelineCommWrapper
from .vit_pipeline_wrapper import ViTPipelineCommWrapper
from .auto_pipeline_wrapper import AutoPipelineCommWrapper

from energonai.pipelinable import split_transformer_into_partitions

from energonai.context import mcfg

@@ -15,9 +18,15 @@
pipe_wrapper = {
    'vit': ViTPipelineCommWrapper,
    'bert': PipelineCommWrapper,
    'gpt': PipelineCommWrapper
    'gpt': PipelineCommWrapper,
    'auto': AutoPipelineCommWrapper,
}

pipe_split = {
    'bert': split_transformer_into_partitions,
    'gpt': split_transformer_into_partitions,
}


class ReturnDict:

@@ -38,22 +47,34 @@ def top(self, key):

class RPCWorker:

    def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1) -> None:
    def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1, auto_pp: bool = False) -> None:

        self.model_class = model_class
        self.model_config = model_config
        self.dtype = dtype
        self.max_batch_size = max_batch_size
        self.model_type = model_type
        # self.auto_pp = auto_pp

        self.WORKER_NAME = "wok{}"
        self.model = None  # call the model
        self.rank = gpc.get_local_rank(ParallelMode.GLOBAL)
        torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')

        # self.trt_sample = None
        self._init_self()
        self.return_dict = ReturnDict()
        if auto_pp:
            self._auto_pp_init_model()
        else:
            self._init_self()
        self.return_dict = ReturnDict()

    def _auto_pp_init_model(self):
        logger.info("Init automatic pipeline model in rank {}".format(self.rank))
        submodules = pipe_split[self.model_type](self.model_class)
        self.model = submodules.get_submodule(f'submod_{gpc.get_local_rank(ParallelMode.PIPELINE)}')
        del submodules
        self.model = pipe_wrapper['auto'](model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)


    def _init_self(self):
        logger.info("Init model in rank {}".format(self.rank))
@@ -69,7 +90,7 @@ def _init_self(self):
        try:
            logger.info('Import Torch2Trt')
            from torch2trt import torch2trt
            from energonai.engine import trt_converter
            from energonai.engine import trt_converter
        except:
            logger.error("Installation Required, \n \
                follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html \
@@ -100,4 +121,4 @@ def run(self, key, inputs):
            self.return_dict.enqueue(cur_key, output.cpu())
            return self.return_dict.top(key)

        return None
        # return None
1 change: 1 addition & 0 deletions energonai/pipelinable/__init__.py
@@ -0,0 +1 @@
from .split_method import split_transformer_into_partitions
7 changes: 7 additions & 0 deletions energonai/pipelinable/energon_tracer.py
@@ -0,0 +1,7 @@
import torch.fx
from energonai.context import mcfg

class EnergonTracer(torch.fx.Tracer):
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        leaves = mcfg["LeafSet"]  # set([BertTransformerLayer])
        return type(m) in leaves
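
The tracer stops recursion at any module type listed in mcfg["LeafSet"], so each transformer layer remains a single call_module node in the traced graph. A brief illustrative sketch, assuming mcfg has already been loaded with a LeafSet entry (for example {BertTransformerLayer}):

# Illustrative only; assumes mcfg was loaded with a LeafSet entry beforehand.
import torch.fx

graph = EnergonTracer().trace(model)          # does not descend into leaf modules
traced = torch.fx.GraphModule(model, graph)
num_call_modules = sum(1 for n in traced.graph.nodes if n.op == 'call_module')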
22 changes: 22 additions & 0 deletions energonai/pipelinable/split_method.py
@@ -0,0 +1,22 @@
from torch.fx.passes.split_module import split_module
from .split_policy import module_equal_partition, naive_equal_partition, transformer_partition
from .energon_tracer import EnergonTracer
import torch.fx

def filter_graph(traced: torch.fx.GraphModule, filter_type: str):
    len = 0
    for node in traced.graph.nodes:
        if node.op == filter_type:
            len = len + 1
    return len


def split_transformer_into_partitions(model_class):
    model = model_class()
    graph = EnergonTracer().trace(model)
    traced = torch.fx.GraphModule(model, graph)
    depth = filter_graph(traced, "call_module") - 1
    submodules = split_module(traced, model, transformer_partition(depth))
    del model

    return submodules
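
The GraphModule returned by split_module exposes one child per partition (submod_0, submod_1, ...). A short sketch of how a pipeline stage picks out its own piece, mirroring RPCWorker._auto_pp_init_model in rpc_worker.py above; MyTransformer is a placeholder class:

# Sketch only; MyTransformer stands in for a traceable model class.
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

submodules = split_transformer_into_partitions(MyTransformer)
stage = gpc.get_local_rank(ParallelMode.PIPELINE)
stage_module = submodules.get_submodule(f'submod_{stage}')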
52 changes: 52 additions & 0 deletions energonai/pipelinable/split_policy.py
@@ -0,0 +1,52 @@
import functools
from torch.fx.node import Node
from energonai.context import mcfg


partition_counter_0 = 0

# partition_nums: number of nodes per submodule
def _naive_equal_partition(node: Node, partition_nums):
    global partition_counter_0
    partition = partition_counter_0 // partition_nums
    partition_counter_0 = partition_counter_0 + 1
    return partition

def naive_equal_partition(partition_nums):
    mod_partition = functools.partial(_naive_equal_partition, partition_nums = partition_nums)
    return mod_partition

partition_counter_1 = 0

# partition_nums: number of call_module nodes per submodule
def _module_equal_partition(node: Node, partition_nums):
    global partition_counter_1
    partition = partition_counter_1 // partition_nums
    if node.op == 'call_module':
        partition_counter_1 = partition_counter_1 + 1
    return partition

def module_equal_partition(partition_nums):
    mod_partition = functools.partial(_module_equal_partition, partition_nums = partition_nums)
    return mod_partition



from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
partition_counter_2 = -1  # for embedding layer
# depth: number of transformer layers; each submodule receives depth // pipeline_size call_module nodes
def _transformer_partition(node: Node, depth):
    global partition_counter_2
    assert gpc.is_initialized(ParallelMode.PIPELINE), "Pipeline communication group should be initialized!"

    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    partition_nums = depth // pipeline_size
    partition = abs(partition_counter_2) // partition_nums
    if node.op == 'call_module':
        partition_counter_2 = partition_counter_2 + 1
    return partition

def transformer_partition(depth):
    mod_partition = functools.partial(_transformer_partition, depth = depth)
    return mod_partition
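
A worked example of the arithmetic in _transformer_partition, assuming the traced graph has one call_module node for the embedding plus one per layer (numbers are illustrative):

# Illustrative numbers: depth = 12 transformer layers and a pipeline world size of 4
# give partition_nums = 12 // 4 = 3. The counter starts at -1 and abs() folds the
# embedding layer into partition 0 together with the first three layers:
#   call_module node:    embed  l1  l2  l3  l4  l5  l6  l7  l8  l9  l10  l11  l12
#   assigned partition:    0     0   0   0   1   1   1   2   2   2   3    3    3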
