This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

update Readme #88

Merged
merged 2 commits on Jun 8, 2022
11 changes: 7 additions & 4 deletions README.md
@@ -16,7 +16,7 @@ Energon-AI provides 3 levels of abstraction for enabling the large-scale model inference
Models trained by [Colossal-AI](https://github.com/hpcaitech/ColossalAI) can be seamlessly transferred to Energon-AI.
Single-device models require manual coding work to introduce tensor parallelism and pipeline parallelism.

At present, we pre-build distributed Bert, GPT, and ViT models.
For GPT, it scales up to 175B parameters, the size of [GPT-3](https://arxiv.org/abs/2005.14165).
For Bert, Google reports a [super-large Bert with 481B parameters](https://mlcommons.org/en/training-normal-11/) in the MLPerf-Training v1.1 open division, indicating that Bert can also scale to very large models.

@@ -55,22 +55,25 @@ Method 2:
#### Scaling Ability

Here GPT3-12-layers in FP16 is adopted.
Here a node with 8 A100 80 GB GPUs is adopted. GPUs are fully connected with NVLink.
Energon-AI adopts the redundant computation elimination method. The method was first proposed in [EffectiveTransformer](https://github.com/bytedance/effective_transformer), and our implementation refers to [TurboTransformer](https://github.com/Tencent/TurboTransformers/blob/master/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu).
Here the sequence length is set to half of the padding length.
<div align="center">
<img src="https://user-images.githubusercontent.com/12018307/168971637-ffd1d6ba-44bb-4043-a275-3dc2a008c048.png" width = "600" height = "240" alt="Architecture" align=center />
</div>
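
The redundant computation elimination above is essentially padding removal. Below is a minimal PyTorch sketch of the idea, not Energon-AI's actual implementation (which adapts a CUDA gather/transpose kernel from TurboTransformer); the function names and the simple index-based packing are illustrative only.

```python
import torch


def remove_padding(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    """Pack the valid tokens of a padded batch into one contiguous tensor.

    hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 for real tokens.
    Returns the packed tokens plus the flat indices needed to restore the padded layout.
    """
    batch, seq_len, hidden = hidden_states.shape
    flat_mask = attention_mask.reshape(-1).bool()
    indices = torch.nonzero(flat_mask, as_tuple=False).squeeze(1)
    packed = hidden_states.reshape(-1, hidden).index_select(0, indices)
    return packed, indices


def restore_padding(packed: torch.Tensor, indices: torch.Tensor, batch: int, seq_len: int) -> torch.Tensor:
    """Scatter the packed tokens back into the padded (batch, seq_len, hidden) layout."""
    hidden = packed.shape[-1]
    out = packed.new_zeros(batch * seq_len, hidden)
    out.index_copy_(0, indices, packed)
    return out.reshape(batch, seq_len, hidden)
```

Packing means the dense layers process only `sum(seq_lens)` token rows instead of `batch_size * padding_len`, which is where the gain at half-length sequences comes from.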

#### Latency
Here GPT3 in FP16 is adopted.
Here a node with 8 A100 80 GB GPUs is adopted. Every two GPUs are connected with NVLink.
Here the sequence length is set to half of the padding length when using the redundant computation elimination method, denoted Energon-AI(RM).
Here FasterTransformer is adopted for comparison; it does not support the redundant computation elimination method in distributed execution.
<div align="center">
<img src="https://user-images.githubusercontent.com/12018307/169728315-8ac95e4f-3e81-44e5-b82b-5873ffe85351.png" width = "600" height = "300" alt="Architecture" align=center />
</div>

#### Batching
Energon-AI dynamically selects the batch with the highest priority, considering waiting time, batch size, and batch expansion possibility (based on the sentence length after padding).
Our dynamic batching method is inspired by the DP algorithm from [TurboTransformer](https://dl.acm.org/doi/10.1145/3437801.3441578).
Here FIFO batching is adopted for comparison.
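
A rough sketch of the selection rule is shown below, under the assumption that priority is a weighted score over waiting time, batch size, and padding waste; the real scheduler uses a dynamic-programming formulation adapted from TurboTransformer, and its exact scoring is not shown in this PR.

```python
import time
from dataclasses import dataclass, field
from typing import List


@dataclass
class Request:
    input_ids: List[int]
    arrival: float = field(default_factory=time.time)


def batch_priority(batch: List[Request], padding_len: int,
                   wait_weight: float = 1.0, size_weight: float = 0.5) -> float:
    """Toy score: favour batches that waited long, are large, and waste little padding."""
    longest_wait = max(time.time() - r.arrival for r in batch)
    fill_ratio = sum(len(r.input_ids) for r in batch) / (len(batch) * padding_len)
    return wait_weight * longest_wait + size_weight * len(batch) + fill_ratio


def pick_batch(candidate_batches: List[List[Request]], padding_len: int) -> List[Request]:
    # The scheduler serves whichever candidate batch currently scores highest;
    # FIFO would instead always serve the batch holding the oldest request.
    return max(candidate_batches, key=lambda b: batch_priority(b, padding_len))
```
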
<div align="center">
<img src="https://user-images.githubusercontent.com/12018307/170616782-18fae36f-75cd-4e7b-bc0b-c8998be1e540.png" width = "400" height = "100" alt="Architecture" align=center />
154 changes: 154 additions & 0 deletions energonai/engine/auto_pipeline_wrapper.py
@@ -0,0 +1,154 @@
import threading

import torch
import torch.fx
import torch.nn as nn
from typing import List, Tuple, Union

from energonai.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

from .pipeline_meta import PipelineMeta
from .pipeline_msg_dict import PipelineMsgDict, CircleInt # PipelineMsgPriorityQueue

def filter_inputs(traced: torch.fx.GraphModule):
    inputs = {}
    for node in traced.graph.nodes:
        if node.op == 'placeholder':
            inputs[node.name] = None
        else:
            break
    return inputs

# The Wrapper is only for Transformer Model.
class AutoPipelineCommWrapper:

    def __init__(self, model: nn.Module, max_batch_size: int = 1, dtype=torch.float) -> None:
        # TODO (dujiangsu): to make sample capability for different types. Iteration, Tensor, and others.
        self.model = model
        self.dtype = dtype

        self.tensor_dim = 0
        self.hidden_size = 0
        self.max_batch_size = max_batch_size

        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
            input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
            attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
            sample = dict(input_ids=input_ids, attention_mask=attention_mask)
            self._init_tensor_meta(sample)

        self.pipe_msg_queue = PipelineMsgDict()
        self.lock = threading.Lock()
        self.key = CircleInt()

    def _init_tensor_meta(self, sample):

        with torch.inference_mode():
            recv_tensor_shape = None
            if gpc.is_first_rank(ParallelMode.PIPELINE):
                # inputs = filter_inputs(self.model)
                output = self.model(sample['input_ids'],
                                    sample['attention_mask'])  # ([32, 512, 1600])
                send_tensor_meta(output)
                send_forward(output)
                self.tensor_dim = output.dim()
                self.hidden_size = output.size()[-1]
            elif gpc.is_last_rank(ParallelMode.PIPELINE):
                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                self.tensor_dim = input_tensor.dim()
                self.hidden_size = input_tensor.size()[-1]
            else:
                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                self.tensor_dim = input_tensor.dim()
                self.hidden_size = input_tensor.size()[-1]
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                send_tensor_meta(output)
                send_forward(output)

    def run(self, key, inputs):
        if gpc.is_initialized(ParallelMode.PIPELINE):
            return self.run_with_pp(key, inputs)
        # else:
        #     return self.run_without_pp(key, inputs)

    # def run_without_pp(self, key, inputs):
    #     pipe_meta = None
    #     self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

    #     self.lock.acquire()

    #     cur_key = self.key.val
    #     sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
    #     self.key.addOne()
    #     with torch.inference_mode():
    #         output = self.model(hidden_states=None,
    #                             input_ids=sample['input_ids'],
    #                             attention_mask=sample['attention_mask'],
    #                             seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
    #     self.lock.release()

    #     return output, cur_key

    '''
    hidden_size : ([32, 512, 1600])
    For different model type, fill_meta_tensor is different
    '''

    def fill_meta_tensor(self, inputs, pipe_meta):
        if 'seq_lens' in inputs:
            pipe_meta.get_meta_tensor()[0] = 1
            pipe_meta.get_meta_tensor()[1] = 1
            pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
        else:
            pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
            pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
            pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]

        pipe_meta.get_meta_tensor()[3] = self.hidden_size
        pipe_meta.update_meta()

    def run_with_pp(self, key, inputs):
        pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
        self.fill_meta_tensor(inputs, pipe_meta)
        self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

        self.lock.acquire()
        cur_key = self.key.val
        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
        self.key.addOne()

        with torch.inference_mode():

            if gpc.is_first_rank(ParallelMode.PIPELINE):
                output = self.model(sample['input_ids'],
                                    sample['attention_mask'])
                # seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)

                send_forward(output)
                self.lock.release()
                return None

            if gpc.is_last_rank(ParallelMode.PIPELINE):

                # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                self.lock.release()
                return output, cur_key

            else:

                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                # Intermediate stages forward the hidden states with the attention mask,
                # matching the other branches; the original diff passed sample['input_ids'] here.
                output = self.model(input_tensor,
                                    sample['attention_mask'])
                send_forward(output)
                self.lock.release()
                return None


11 changes: 7 additions & 4 deletions energonai/engine/engine.py
@@ -17,7 +17,6 @@
from colossalai.logging import get_dist_logger

from energonai.initialize import launch_from_multiprocess
from energonai.utils import ensure_directory_exists



@@ -33,9 +32,11 @@ def __init__(self,
max_batch_size: int = 1,
tp_init_size: int = -1,
pp_init_size: int = -1,
auto_pp: bool = False,
host: str = 'localhost',
port: int = 29500,
dtype=None):
dtype=None,
):
"""
Args:
model: torch.nn.Module
@@ -57,8 +58,10 @@
self.tp_size = tp_init_size
self.pp_size = pp_init_size

# for TP
# for TP, PP
self.rrefs = []
self.auto_pp = auto_pp

# for rpc
self.WORKER_NAME = "wok{}"
self._init_dist_rpc()
@@ -93,7 +96,7 @@ def _init_model(self):
rpc.remote(ob_info,
RPCWorker,
args=(self.model_class, self.model_config, self.model_type, self.dtype,
self.max_batch_size)))
self.max_batch_size, self.auto_pp)))

def run(self, inputs):
res_rref = 0
33 changes: 27 additions & 6 deletions energonai/engine/rpc_worker.py
@@ -7,6 +7,9 @@

from .pipeline_wrapper import PipelineCommWrapper
from .vit_pipeline_wrapper import ViTPipelineCommWrapper
from .auto_pipeline_wrapper import AutoPipelineCommWrapper

from energonai.pipelinable import split_transformer_into_partitions

from energonai.context import mcfg

@@ -15,9 +18,15 @@
pipe_wrapper = {
    'vit': ViTPipelineCommWrapper,
    'bert': PipelineCommWrapper,
    'gpt': PipelineCommWrapper,
    'auto': AutoPipelineCommWrapper,
}

pipe_split = {
    'bert': split_transformer_into_partitions,
    'gpt': split_transformer_into_partitions,
}


class ReturnDict:

@@ -38,22 +47,34 @@ def top(self, key):

class RPCWorker:

def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1) -> None:
def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1, auto_pp: bool = False) -> None:

self.model_class = model_class
self.model_config = model_config
self.dtype = dtype
self.max_batch_size = max_batch_size
self.model_type = model_type
# self.auto_pp = auto_pp

self.WORKER_NAME = "wok{}"
self.model = None # call the model
self.rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')

# self.trt_sample = None
self._init_self()
self.return_dict = ReturnDict()
if auto_pp:
self._auto_pp_init_model()
else:
self._init_self()
self.return_dict = ReturnDict()

def _auto_pp_init_model(self):
logger.info("Init automatic pipeline model in rank {}".format(self.rank))
submodules = pipe_split[self.model_type](self.model_class)
self.model = submodules.get_submodule(f'submod_{gpc.get_local_rank(ParallelMode.PIPELINE)}')
del submodules
self.model = pipe_wrapper['auto'](model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)


def _init_self(self):
logger.info("Init model in rank {}".format(self.rank))
Expand All @@ -69,7 +90,7 @@ def _init_self(self):
try:
logger.info('Import Torch2Trt')
from torch2trt import torch2trt
from energonai.engine import trt_converter
from energonai.engine import trt_converter
except:
logger.error("Installation Required, \n \
follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html \
@@ -100,4 +121,4 @@ def run(self, key, inputs):
self.return_dict.enqueue(cur_key, output.cpu())
return self.return_dict.top(key)

return None
# return None
1 change: 1 addition & 0 deletions energonai/pipelinable/__init__.py
@@ -0,0 +1 @@
from .split_method import split_transformer_into_partitions
7 changes: 7 additions & 0 deletions energonai/pipelinable/energon_tracer.py
@@ -0,0 +1,7 @@
import torch.fx
from energonai.context import mcfg

class EnergonTracer(torch.fx.Tracer):
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        leaves = mcfg["LeafSet"]  # set([BertTransformerLayer])
        return type(m) in leaves
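
For illustration, here is a self-contained variant of the same idea with a toy model and the leaf set passed in explicitly instead of read from `mcfg`; `LeafTracer`, `ToyLayer`, and `ToyModel` are demonstration names only, not part of Energon-AI.

```python
import torch
import torch.nn as nn
import torch.fx


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer() for _ in range(4))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class LeafTracer(torch.fx.Tracer):
    """Same mechanism as EnergonTracer, with an explicit leaf set."""

    def __init__(self, leaves):
        super().__init__()
        self.leaves = leaves

    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        return type(m) in self.leaves


graph = LeafTracer({ToyLayer}).trace(ToyModel())
# Each ToyLayer appears as a single 'call_module' node rather than being traced
# into its Linear + relu internals.
print([node.op for node in graph.nodes])
```

Because each registered layer stays a single `call_module` node, counting those nodes gives the layer count that `split_transformer_into_partitions` uses as the pipeline depth.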
22 changes: 22 additions & 0 deletions energonai/pipelinable/split_method.py
@@ -0,0 +1,22 @@
from torch.fx.passes.split_module import split_module
from .split_policy import module_equal_partition, naive_equal_partition, transformer_partition
from .energon_tracer import EnergonTracer
import torch.fx

def filter_graph(traced: torch.fx.GraphModule, filter_type: str):
    # Count the nodes of a given op type (e.g. 'call_module') in the traced graph.
    count = 0
    for node in traced.graph.nodes:
        if node.op == filter_type:
            count = count + 1
    return count


def split_transformer_into_partitions(model_class):
    model = model_class()
    graph = EnergonTracer().trace(model)
    traced = torch.fx.GraphModule(model, graph)
    depth = filter_graph(traced, "call_module") - 1
    submodules = split_module(traced, model, transformer_partition(depth))
    del model

    return submodules
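
`transformer_partition` comes from `split_policy.py`, which this diff does not include. The sketch below shows one plausible shape for such a callback, purely as an assumption: `split_module` calls it once per node in graph order and groups nodes by the returned index, yielding submodules named `submod_0`, `submod_1`, and so on, which is what `RPCWorker._auto_pp_init_model` later fetches by pipeline rank.

```python
import torch.fx
from torch.fx.passes.split_module import split_module


def example_transformer_partition(depth: int, pipeline_size: int):
    """Hypothetical split policy: assign roughly depth / pipeline_size
    transformer layers (call_module nodes) to each pipeline stage."""
    layers_per_stage = max(1, depth // pipeline_size)
    state = {'seen_call_modules': 0}

    def split_callback(node: torch.fx.Node) -> int:
        # Nodes are visited in graph order; advance the stage index every
        # layers_per_stage call_module nodes.
        stage = min(state['seen_call_modules'] // layers_per_stage, pipeline_size - 1)
        if node.op == 'call_module':
            state['seen_call_modules'] += 1
        return stage

    return split_callback


# Usage mirroring split_transformer_into_partitions, assuming a pipeline size of 2:
# submodules = split_module(traced, model, example_transformer_partition(depth, 2))
# stage_model = submodules.get_submodule('submod_0')
```
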