From 234b5b886ad2a019a7921d865eff9342eba585fb Mon Sep 17 00:00:00 2001
From: dujiangsu
Date: Mon, 30 May 2022 16:52:28 +0800
Subject: [PATCH 1/2] update Readme auto_pipeline

Auto pipeline: at present, only the model split part is automatic; the
wrapper is still model-specific.
---
 README.md                                 |  11 +-
 energonai/engine/auto_pipeline_wrapper.py | 154 ++++++++++++
 energonai/engine/engine.py                |  11 +-
 energonai/engine/rpc_worker.py            |  33 ++-
 energonai/pipelinable/__init__.py         |   1 +
 energonai/pipelinable/energon_tracer.py   |   7 +
 energonai/pipelinable/split_method.py     |  22 ++
 energonai/pipelinable/split_policy.py     |  52 ++++
 examples/auto_pipeline/bert.py            | 290 ++++++++++++++++++++++
 examples/auto_pipeline/bert_config.py     |  24 ++
 examples/auto_pipeline/bert_server.py     |  92 +++++++
 examples/auto_pipeline/run.py             |  70 ++++++
 12 files changed, 753 insertions(+), 14 deletions(-)
 create mode 100644 energonai/engine/auto_pipeline_wrapper.py
 create mode 100644 energonai/pipelinable/__init__.py
 create mode 100644 energonai/pipelinable/energon_tracer.py
 create mode 100644 energonai/pipelinable/split_method.py
 create mode 100644 energonai/pipelinable/split_policy.py
 create mode 100644 examples/auto_pipeline/bert.py
 create mode 100644 examples/auto_pipeline/bert_config.py
 create mode 100644 examples/auto_pipeline/bert_server.py
 create mode 100644 examples/auto_pipeline/run.py

diff --git a/README.md b/README.md
index 308ca0f..eac5c10 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ Energon-AI provides 3 levels of abstraction for enabling the large-scale model i
 For models trained by [Colossal-AI](https://github.com/hpcaitech/ColossalAI), they can be seamlessly transferred to Energon-AI.
 For single-device models, they require manual coding works to introduce tensor parallelism and pipeline parallelism.
 
-At present, we pre-build distributed Bert, GPT, and ViT models.
+At present, we pre-build distributed Bert, GPT, and ViT models.
 For GPT, it extends to at most 175B parameters, which is called [GPT3](https://arxiv.org/abs/2005.14165).
 For Bert, Google reports a [super-large Bert with 481B parameters](https://mlcommons.org/en/training-normal-11/) in MLPerf-Training v1.1 open, indicating that Bert can also extend to large-scale.
 
@@ -55,8 +55,9 @@ Method 2:
 
 #### Scaling Ability
 Here GPT3-12-layers in FP16 is adopted.
-Here a node with 8 A100 80 GB GPUs is adopted. GPUs are fully connected with NvLink.
-Energon-AI adopts the redundant computation elimination method from [EffectiveTransformer](https://github.com/bytedance/effective_transformer) and the sequence length is set the half of the padding length.
+Here a node with 8 A100 80 GB GPUs is adopted. GPUs are fully connected with NvLink.
+Energon-AI adopts the redundant computation elimination method. The method was first proposed in [EffectiveTransformer](https://github.com/bytedance/effective_transformer), and our implementation follows [TurboTransformer](https://github.com/Tencent/TurboTransformers/blob/master/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu).
+Here the sequence length is set to half of the padding length.
Architecture

@@ -64,13 +65,15 @@ Energon-AI adopts the redundant computation elimination method from [EffectiveTr
 
 #### Latency
 Here GPT3 in FP16 is adopted. Here a node with 8 A100 80 GB GPUs is adopted. Every two GPUs are connected with NvLink.
-Here the sequence length is set the half of the padding length when using redundant computation elimination method, which is the Energon-AI(RM).
+Here the sequence length is set to half of the padding length when using the redundant computation elimination method, which is Energon-AI(RM). Here FasterTransformer is adopted for comparison; it does not support the redundant computation elimination method in distributed execution.
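The redundant computation elimination behind Energon-AI(RM) boils down to dropping padded positions before the per-token work and scattering the results back afterwards. Below is a minimal PyTorch sketch of that idea; the function names and the tiny 2x4 shapes are illustrative only, and this is not the CUDA kernel the repository actually uses (see `energonai.kernel` and the TurboTransformers kernel linked above).

```python
import torch

def remove_padding(x: torch.Tensor, seq_lens: torch.Tensor):
    """Pack a padded batch [batch, max_len, hidden] into [total_tokens, hidden].

    Only the first seq_lens[i] tokens of each sequence are kept, so later
    per-token work runs on sum(seq_lens) rows instead of batch * max_len rows.
    """
    batch, max_len, _ = x.shape
    mask = torch.arange(max_len, device=x.device)[None, :] < seq_lens[:, None]
    return x[mask], mask          # packed tokens plus the mask needed to restore them

def restore_padding(packed: torch.Tensor, mask: torch.Tensor, hidden: int):
    """Scatter packed tokens back into a zero-padded [batch, max_len, hidden] tensor."""
    out = packed.new_zeros(*mask.shape, hidden)
    out[mask] = packed
    return out

if __name__ == "__main__":
    hidden = 8
    x = torch.randn(2, 4, hidden)                 # batch=2, padded length=4
    seq_lens = torch.tensor([2, 3])               # true lengths
    packed, mask = remove_padding(x, seq_lens)    # 5 rows instead of 8
    y = restore_padding(packed * 2.0, mask, hidden)
    print(packed.shape, y.shape)
```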
Architecture

 
 #### Batching
+Energon-AI dynamically selects the batch with the highest priority, considering the waiting time, the batch size, and the possibility of batch expansion (based on the sentence length after padding).
+Our dynamic batching method is inspired by the DP algorithm from [TurboTransformer](https://dl.acm.org/doi/10.1145/3437801.3441578). Here FIFO batching is selected for comparison.
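As a rough illustration of the priority rule described above, the sketch below scores each candidate batch from its accumulated waiting time, its size, and how much room its padded length leaves for expansion, then dispatches the best-scoring batch first. The `Request` class, the `score` weights, and `pick_batch` are toy assumptions for illustration, not the DP formulation used by the engine.

```python
import time
from dataclasses import dataclass, field
from typing import List

@dataclass
class Request:
    tokens: List[int]
    arrival: float = field(default_factory=time.time)

def score(batch: List[Request], max_padding_len: int, now: float) -> float:
    """Toy priority: reward long-waiting, large batches and room left for expansion."""
    waiting = sum(now - r.arrival for r in batch)
    pad_len = max(len(r.tokens) for r in batch)
    expansion_room = max_padding_len - pad_len        # shorter padded batches can still grow
    return waiting + 0.5 * len(batch) + 0.1 * expansion_room

def pick_batch(candidates: List[List[Request]], max_padding_len: int) -> List[Request]:
    now = time.time()
    return max(candidates, key=lambda b: score(b, max_padding_len, now))

if __name__ == "__main__":
    short = [Request(tokens=[1] * 16) for _ in range(4)]
    long_ = [Request(tokens=[1] * 120) for _ in range(2)]
    best = pick_batch([short, long_], max_padding_len=128)
    print("dispatch batch of", len(best), "requests")
```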
Architecture diff --git a/energonai/engine/auto_pipeline_wrapper.py b/energonai/engine/auto_pipeline_wrapper.py new file mode 100644 index 0000000..cabb869 --- /dev/null +++ b/energonai/engine/auto_pipeline_wrapper.py @@ -0,0 +1,154 @@ +import threading + +import torch +import torch.fx +import torch.nn as nn +from typing import List, Tuple, Union + +from energonai.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + +from .pipeline_meta import PipelineMeta +from .pipeline_msg_dict import PipelineMsgDict, CircleInt # PipelineMsgPriorityQueue + +def filter_inputs(traced: torch.fx.GraphModule): + inputs = {} + for node in traced.graph.nodes: + if node.op == 'placeholder': + inputs[node.name] = None + else: + break + return inputs + +# The Wrapper is only for Transformer Model. +class AutoPipelineCommWrapper: + + def __init__(self, model: nn.Module, max_batch_size: int = 1, dtype=torch.float) -> None: + # TODO (dujiangsu): to make sample capability for different types. Iteration, Tensor, and others. + self.model = model + self.dtype = dtype + + self.tensor_dim = 0 + self.hidden_size = 0 + self.max_batch_size = max_batch_size + + if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1: + input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda() + attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda() + sample = dict(input_ids=input_ids, attention_mask=attention_mask) + self._init_tensor_meta(sample) + + self.pipe_msg_queue = PipelineMsgDict() + self.lock = threading.Lock() + self.key = CircleInt() + + def _init_tensor_meta(self, sample): + + with torch.inference_mode(): + recv_tensor_shape = None + if gpc.is_first_rank(ParallelMode.PIPELINE): + # inputs = filter_inputs(self.model) + output = self.model(sample['input_ids'], + sample['attention_mask']) # ([32, 512, 1600]) + send_tensor_meta(output) + send_forward(output) + self.tensor_dim = output.dim() + self.hidden_size = output.size()[-1] + elif gpc.is_last_rank(ParallelMode.PIPELINE): + recv_tensor_shape = recv_tensor_meta(recv_tensor_shape) + input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now + self.tensor_dim = input_tensor.dim() + self.hidden_size = input_tensor.size()[-1] + else: + recv_tensor_shape = recv_tensor_meta(recv_tensor_shape) + input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now + self.tensor_dim = input_tensor.dim() + self.hidden_size = input_tensor.size()[-1] + output = self.model(input_tensor, + sample['attention_mask']) + send_tensor_meta(output) + send_forward(output) + + def run(self, key, inputs): + if gpc.is_initialized(ParallelMode.PIPELINE): + return self.run_with_pp(key, inputs) + # else: + # return self.run_without_pp(key, inputs) + + # def run_without_pp(self, key, inputs): + # pipe_meta = None + # self.pipe_msg_queue.enqueue(key, inputs, pipe_meta) + + # self.lock.acquire() + + # cur_key = self.key.val + # sample, pipe_meta = self.pipe_msg_queue.top(cur_key) + # self.key.addOne() + # with torch.inference_mode(): + # output = self.model(hidden_states=None, + # input_ids=sample['input_ids'], + # attention_mask=sample['attention_mask'], + # seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None) + # self.lock.release() + + # return output, cur_key + + ''' + hidden_size : ([32, 512, 1600]) + For different model type, 
fill_meta_tensor is different + ''' + + def fill_meta_tensor(self, inputs, pipe_meta): + if 'seq_lens' in inputs: + pipe_meta.get_meta_tensor()[0] = 1 + pipe_meta.get_meta_tensor()[1] = 1 + pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens']) + else: + pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0] + pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0] + pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1] + + pipe_meta.get_meta_tensor()[3] = self.hidden_size + pipe_meta.update_meta() + + def run_with_pp(self, key, inputs): + pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size) + self.fill_meta_tensor(inputs, pipe_meta) + self.pipe_msg_queue.enqueue(key, inputs, pipe_meta) + + self.lock.acquire() + cur_key = self.key.val + sample, pipe_meta = self.pipe_msg_queue.top(cur_key) + self.key.addOne() + + with torch.inference_mode(): + + if gpc.is_first_rank(ParallelMode.PIPELINE): + output = self.model(sample['input_ids'], + sample['attention_mask']) + # seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None) + + send_forward(output) + self.lock.release() + return None + + if gpc.is_last_rank(ParallelMode.PIPELINE): + + # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}') + input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype) + output = self.model(input_tensor, + sample['attention_mask']) + self.lock.release() + return output, cur_key + + else: + + input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype) + output = self.model(input_tensor, + sample['input_ids']) + send_forward(output) + self.lock.release() + return None + + diff --git a/energonai/engine/engine.py b/energonai/engine/engine.py index 325f6ef..84e1cc5 100644 --- a/energonai/engine/engine.py +++ b/energonai/engine/engine.py @@ -17,7 +17,6 @@ from colossalai.logging import get_dist_logger from energonai.initialize import launch_from_multiprocess -from energonai.utils import ensure_directory_exists @@ -33,9 +32,11 @@ def __init__(self, max_batch_size: int = 1, tp_init_size: int = -1, pp_init_size: int = -1, + auto_pp: bool = False, host: str = 'localhost', port: int = 29500, - dtype=None): + dtype=None, + ): """ Args: model: torch.nn.Module @@ -57,8 +58,10 @@ def __init__(self, self.tp_size = tp_init_size self.pp_size = pp_init_size - # for TP + # for TP, PP self.rrefs = [] + self.auto_pp = auto_pp + # for rpc self.WORKER_NAME = "wok{}" self._init_dist_rpc() @@ -93,7 +96,7 @@ def _init_model(self): rpc.remote(ob_info, RPCWorker, args=(self.model_class, self.model_config, self.model_type, self.dtype, - self.max_batch_size))) + self.max_batch_size, self.auto_pp))) def run(self, inputs): res_rref = 0 diff --git a/energonai/engine/rpc_worker.py b/energonai/engine/rpc_worker.py index 5894ef9..97f25ec 100644 --- a/energonai/engine/rpc_worker.py +++ b/energonai/engine/rpc_worker.py @@ -7,6 +7,9 @@ from .pipeline_wrapper import PipelineCommWrapper from .vit_pipeline_wrapper import ViTPipelineCommWrapper +from .auto_pipeline_wrapper import AutoPipelineCommWrapper + +from energonai.pipelinable import split_transformer_into_partitions from energonai.context import mcfg @@ -15,9 +18,15 @@ pipe_wrapper = { 'vit': ViTPipelineCommWrapper, 'bert': PipelineCommWrapper, - 'gpt': PipelineCommWrapper + 'gpt': PipelineCommWrapper, + 'auto': AutoPipelineCommWrapper, } +pipe_split = { + 'bert': split_transformer_into_partitions, + 'gpt': split_transformer_into_partitions, + } + class ReturnDict: @@ -38,13 +47,14 @@ def top(self, key): class 
RPCWorker: - def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1) -> None: + def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1, auto_pp: bool = False) -> None: self.model_class = model_class self.model_config = model_config self.dtype = dtype self.max_batch_size = max_batch_size self.model_type = model_type + # self.auto_pp = auto_pp self.WORKER_NAME = "wok{}" self.model = None # call the model @@ -52,8 +62,19 @@ def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}') # self.trt_sample = None - self._init_self() - self.return_dict = ReturnDict() + if auto_pp: + self._auto_pp_init_model() + else: + self._init_self() + self.return_dict = ReturnDict() + + def _auto_pp_init_model(self): + logger.info("Init automatic pipeline model in rank {}".format(self.rank)) + submodules = pipe_split[self.model_type](self.model_class) + self.model = submodules.get_submodule(f'submod_{gpc.get_local_rank(ParallelMode.PIPELINE)}') + del submodules + self.model = pipe_wrapper['auto'](model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype) + def _init_self(self): logger.info("Init model in rank {}".format(self.rank)) @@ -69,7 +90,7 @@ def _init_self(self): try: logger.info('Import Torch2Trt') from torch2trt import torch2trt - from energonai.engine import trt_converter + from energonai.engine import trt_converter except: logger.error("Installation Required, \n \ follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html \ @@ -100,4 +121,4 @@ def run(self, key, inputs): self.return_dict.enqueue(cur_key, output.cpu()) return self.return_dict.top(key) - return None + # return None diff --git a/energonai/pipelinable/__init__.py b/energonai/pipelinable/__init__.py new file mode 100644 index 0000000..6cde0b3 --- /dev/null +++ b/energonai/pipelinable/__init__.py @@ -0,0 +1 @@ +from .split_method import split_transformer_into_partitions diff --git a/energonai/pipelinable/energon_tracer.py b/energonai/pipelinable/energon_tracer.py new file mode 100644 index 0000000..f0a8cd1 --- /dev/null +++ b/energonai/pipelinable/energon_tracer.py @@ -0,0 +1,7 @@ +import torch.fx +from energonai.context import mcfg + +class EnergonTracer(torch.fx.Tracer): + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + leaves = mcfg["LeafSet"] # set([BertTransformerLayer]) + return type(m) in leaves \ No newline at end of file diff --git a/energonai/pipelinable/split_method.py b/energonai/pipelinable/split_method.py new file mode 100644 index 0000000..7f24621 --- /dev/null +++ b/energonai/pipelinable/split_method.py @@ -0,0 +1,22 @@ +from torch.fx.passes.split_module import split_module +from .split_policy import module_equal_partition, naive_equal_partition, transformer_partition +from .energon_tracer import EnergonTracer +import torch.fx + +def filter_graph(traced: torch.fx.GraphModule, filter_type: str): + len = 0 + for node in traced.graph.nodes: + if node.op == filter_type: + len = len + 1 + return len + + +def split_transformer_into_partitions(model_class): + model = model_class() + graph = EnergonTracer().trace(model) + traced = torch.fx.GraphModule(model, graph) + depth = filter_graph(traced, "call_module") - 1 + submodules = split_module(traced, model, transformer_partition(depth)) + del model + + return submodules \ No newline at end of file diff --git a/energonai/pipelinable/split_policy.py 
b/energonai/pipelinable/split_policy.py new file mode 100644 index 0000000..5e62fb4 --- /dev/null +++ b/energonai/pipelinable/split_policy.py @@ -0,0 +1,52 @@ +import functools +from torch.fx.node import Node +from energonai.context import mcfg + + +partition_counter_0 = 0 + +# partition_nums: nums of each submodule +def _naive_equal_partition(node: Node, partition_nums): + global partition_counter_0 + partition = partition_counter_0 // partition_nums + partition_counter_0 = partition_counter_0 + 1 + return partition + +def naive_equal_partition(partition_nums): + mod_partition = functools.partial(_naive_equal_partition, partition_nums = partition_nums) + return mod_partition + +partition_counter_1 = 0 + +# partition_nums: nums of each submodule +def _module_equal_partition(node: Node, partition_nums): + global partition_counter_1 + partition = partition_counter_1 // partition_nums + if node.op == 'call_module': + partition_counter_1 = partition_counter_1 + 1 + return partition + +def module_equal_partition(partition_nums): + mod_partition = functools.partial(_module_equal_partition, partition_nums = partition_nums) + return mod_partition + + + +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode +partition_counter_2 = -1 # for embedding layer +# partition_nums: nums of each submodule +def _transformer_partition(node: Node, depth): + global partition_counter_2 + assert gpc.is_initialized(ParallelMode.PIPELINE), "Pipeline communication group should be initialized!" + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + partition_nums = depth // pipeline_size + partition = abs(partition_counter_2) // partition_nums + if node.op == 'call_module': + partition_counter_2 = partition_counter_2 + 1 + return partition + +def transformer_partition(depth): + mod_partition = functools.partial(_transformer_partition, depth = depth) + return mod_partition \ No newline at end of file diff --git a/examples/auto_pipeline/bert.py b/examples/auto_pipeline/bert.py new file mode 100644 index 0000000..ea42017 --- /dev/null +++ b/examples/auto_pipeline/bert.py @@ -0,0 +1,290 @@ +import math +from typing import Callable + +import os +import torch +from torch import nn as nn, Tensor, dtype + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from energonai.logging import get_dist_logger +from colossalai.nn.layer.utils import divide, ACT2FN +from colossalai.nn import Linear1D_Col, Linear1D_Row, Classifier1D +from colossalai.nn import LayerNorm1D +from energonai.kernel import transpose_pad, transpose_depad, depad +from energonai.nn import VocabParallelEmbedding1D +from energonai.utils import get_current_device, is_using_pp + +__all__ = [ + 'BertEmbedding1D' + 'BertMLP1D', + 'BertSelfAttention1D', + 'BertTransformerLayer1D' +] + +from energonai.utils.checkpointing import load_checkpoint + + +class BertEmbedding1D(nn.Module): + def __init__(self, + embedding_dim: int, # hidden_size + vocab_size: int, + max_position_embeddings: int, + num_tokentypes: int = 0, + padding_idx: int = 0, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None) -> None: + super().__init__() + self.word_embeddings = VocabParallelEmbedding1D(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype, + skip_tp=True) + self.position_embeddings = VocabParallelEmbedding1D(max_position_embeddings, embedding_dim, dtype=dtype, + skip_tp=True) + if num_tokentypes > 0: + self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, 
embedding_dim, dtype=dtype) + else: + self.tokentype_embeddings = None + + self.LayerNorm = LayerNorm1D(embedding_dim, eps=layernorm_epsilon) + + def forward(self, input_ids, position_ids=None, tokentype_ids=None): + # max_padding_size = input_ids.shape[1] + + # TODO: register_buffer in advance for position_ids to speedup + + # if position_ids is None: + # position_ids = torch.arange(max_padding_size, dtype=torch.long, device=get_current_device()).unsqueeze(0) + + x = self.word_embeddings(input_ids) # + self.position_embeddings(position_ids) + + if self.tokentype_embeddings is not None and tokentype_ids is not None: + x = x + self.tokentype_embeddings(tokentype_ids) + + x = self.LayerNorm(x) + + # if seq_lens is not None: + # x = depad(x, batch_size, seq_lens) + + return x + + +class BertSelfAttention1D(nn.Module): + def __init__(self, + hidden_size: int, + num_heads: int, + bias: bool = True, + fuse_scale_mask_softmax: bool = False, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None) -> None: + super().__init__() + if hidden_size % num_heads != 0: + raise ValueError( + f"The hidden size ({hidden_size}) is not a multiple of the number of attention ") + self.hidden_size = hidden_size + self.attention_head_size = divide(hidden_size, num_heads) + self.fuse_scale_mask_softmax = fuse_scale_mask_softmax + + self.query_key_value = Linear1D_Col(hidden_size, 3 * hidden_size, bias=bias, dtype=dtype) + + if fuse_scale_mask_softmax: + raise NotImplementedError + + self.dense = Linear1D_Row(hidden_size, hidden_size, bias=True, dtype=dtype, parallel_input=True) + self.LayerNorm = LayerNorm1D(hidden_size, eps=layernorm_epsilon) + + def forward(self, hidden_states, attention_mask=None): + + attention_output = self.query_key_value(hidden_states) + all_head_size = attention_output.shape[-1] // 3 + num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads + + new_qkv_shape = attention_output.shape[:-1] + (num_attention_heads, 3 * self.attention_head_size) + attention_output = attention_output.view(new_qkv_shape) + + # if seq_lens is not None: + # # TODO: use FasterTransformer's implementation. 
+ # attention_output = transpose_pad(attention_output, batch_size, max_padding_size, seq_lens, + # num_attention_heads, self.attention_head_size * 3) + # else: + attention_output = attention_output.permute(0, 2, 1, 3) + # TODO: make sure self.attention_head_size*3 is correct + + q, k, v = torch.chunk(attention_output, 3, dim=-1) + + attention_output = torch.matmul(q, k.transpose(-1, -2)) + if self.fuse_scale_mask_softmax: + raise NotImplementedError + else: + attention_output = attention_output / math.sqrt(self.attention_head_size) + # if attention_mask is not None: + # attention_output = attention_output + attention_mask + attention_output = nn.functional.softmax(attention_output, dim=-1) + + attention_output = torch.matmul(attention_output, v) + + # if seq_lens is not None: + # sum_seq = torch.sum(seq_lens) + # attention_output = transpose_depad(attention_output, batch_size, sum_seq, max_padding_size, seq_lens, + # num_attention_heads, self.attention_head_size) + # else: + attention_output = attention_output.permute(0, 2, 1, 3).contiguous() + + new_context_layer_shape = attention_output.size()[:-2] + (all_head_size,) + attention_output = attention_output.reshape(new_context_layer_shape) + attention_output = self.dense(attention_output) + + hidden_states = self.LayerNorm(attention_output + hidden_states) + + return hidden_states + + +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) + + +class BertMLP1D(nn.Module): + def __init__(self, + hidden_size: int, + mlp_ratio: float, + activation: Callable = gelu_impl, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None, + bias: bool = True): + super().__init__() + intermediate_dim = int(hidden_size * mlp_ratio) + self.layer_0 = Linear1D_Col(hidden_size, intermediate_dim, bias=bias, dtype=dtype, gather_output=False) + self.activation = activation + self.layer_1 = Linear1D_Row(intermediate_dim, hidden_size, bias=bias, dtype=dtype, parallel_input=True) + self.LayerNorm = LayerNorm1D(hidden_size, eps=layernorm_epsilon) + + def forward(self, input_tensor): + hidden_states = self.layer_0(input_tensor) + hidden_states = self.activation(hidden_states) + hidden_states = self.layer_1(hidden_states) + + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertTransformerLayer1D(nn.Module): + def __init__(self, + hidden_size: int, + num_heads: int, + mlp_ratio: float, + activation: Callable = gelu_impl, + layernorm_epsilon: float = 1e-5, + dtype: dtype = None, + bias: bool = True, + fuse_scale_mask_softmax: bool = False): + super().__init__() + + self.attention = BertSelfAttention1D(hidden_size, + num_heads, + bias, + fuse_scale_mask_softmax, + layernorm_epsilon, + dtype) + self.mlp = BertMLP1D(hidden_size, + mlp_ratio, + activation, + layernorm_epsilon, + dtype, + bias) + + def forward(self, hidden_states, attention_mask): + + batch_size = hidden_states.shape[0] + + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * - 10000.0 + + hidden_states = self.attention(hidden_states, attention_mask) + hidden_states = self.mlp(hidden_states) + + return hidden_states + +class Bert1D(nn.Module): + + def __init__(self, + vocab_size: int = 50304, + max_position_embeddings: int = 1024, + hidden_size: int = 768, 
+ num_heads: int = 12, + depth: int = 12, + mlp_ratio: float = 4.0, + layernorm_epsilon: float = 1e-5, + activation: Callable = nn.functional.gelu, + padding_idx: int = 0, + dtype: dtype = None, + bias: bool = True, + fuse_scale_mask_softmax: bool = False, + ): + super().__init__() + self.embed = BertEmbedding1D(embedding_dim=hidden_size, + vocab_size=vocab_size, + max_position_embeddings=max_position_embeddings, + padding_idx=padding_idx, + layernorm_epsilon=layernorm_epsilon, + dtype=dtype) + self.blocks = nn.ModuleList() + + for i in range(depth): + self.blocks.append(BertTransformerLayer1D( + hidden_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + activation=activation, + layernorm_epsilon=layernorm_epsilon, + dtype=dtype, + bias=bias, + fuse_scale_mask_softmax=fuse_scale_mask_softmax,) + ) + + def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None): + + # batch_size = input_ids.shape[0] + # max_padding_size = input_ids.shape[1] + + hidden_states = self.embed(input_ids=input_ids, position_ids=None, tokentype_ids=None) # , seq_lens + + for block in self.blocks: + hidden_states = block(hidden_states=hidden_states, attention_mask=attention_mask) + + hidden_states = hidden_states[:, 1, :] + + return hidden_states + + + +def _create_bert_model(model_kwargs): + model = Bert1D(**model_kwargs) + return model + + +def bert_small(**kwargs): + model_kwargs = dict(hidden_size=768, depth=12, num_heads=12, **kwargs) + return _create_bert_model(model_kwargs) + + +def bert_large(**kwargs): + model_kwargs = dict(hidden_size=1024, depth=24, num_heads=16, **kwargs) + return _create_bert_model(model_kwargs) + + +def bert_xl(**kwargs): + model_kwargs = dict(hidden_size=1600, depth=48, num_heads=16, **kwargs) + return _create_bert_model(model_kwargs) + + +def bert_8B(**kwargs): + model_kwargs = dict(hidden_size=3072, depth=72, num_heads=24, **kwargs) + return _create_bert_model(model_kwargs) + + +def bert_175B(**kwargs): + model_kwargs = dict(hidden_size=12288, depth=96, num_heads=96, **kwargs) + return _create_bert_model(model_kwargs) diff --git a/examples/auto_pipeline/bert_config.py b/examples/auto_pipeline/bert_config.py new file mode 100644 index 0000000..b373ffd --- /dev/null +++ b/examples/auto_pipeline/bert_config.py @@ -0,0 +1,24 @@ +from bert import bert_small, bert_large, bert_xl, bert_8B, bert_175B +from bert_server import launch_engine +from bert import BertEmbedding1D, BertTransformerLayer1D + +model_class = bert_8B +model_type = "bert" +engine_server = launch_engine + + +# parallel +tp_init_size = 2 +pp_init_size = 2 +auto_pp = True +LeafSet = set([BertTransformerLayer1D, BertEmbedding1D]) + + + +host = "127.0.0.1" +port = 29400 +half = False +server_host = "127.0.0.1" +server_port = 8010 +log_level = "info" +backend = "nccl" \ No newline at end of file diff --git a/examples/auto_pipeline/bert_server.py b/examples/auto_pipeline/bert_server.py new file mode 100644 index 0000000..946278b --- /dev/null +++ b/examples/auto_pipeline/bert_server.py @@ -0,0 +1,92 @@ +import os +import torch +import uvicorn +from fastapi import FastAPI +from energonai.engine import InferenceEngine +from energonai.context import mcfg + +app = FastAPI() # 创建 api 对象 + +@app.get("/") # 根路由 +def root(): + return {"200"} + +@app.get("/model_with_padding") +def run(): + # for the performance only + seq_len = 512 + batch_size = 32 + + input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), 
dtype=torch.int64) + # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly + hidden_states = None + sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask) + + output = engine.run(sample) + output = output.to_here() + print(output) + return {"To return the string result."} + +# @app.get("/model_rm_padding") +# def run(): +# # for the performance only +# seq_len = 512 +# batch_size = 32 + +# input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64) +# attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64) +# seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int) # generate seq_lens randomly +# hidden_states = None +# sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, seq_lens=seq_lens) + +# output = engine.run(sample) +# output = output.to_here() +# print(output) +# return {"To return the string result."} + + +@app.get("/shutdown") +async def shutdown(): + engine.clear() + server.should_exit = True + server.force_exit = True + await server.shutdown() + + +def launch_engine(model_class, + model_type, + max_batch_size: int = 1, + tp_init_size: int = -1, + pp_init_size: int = -1, + host: str = "localhost", + port: int = 29500, + dtype = torch.float, + checkpoint: str = None, + tokenizer_path: str = None, + server_host = "localhost", + server_port = 8005, + log_level = "info" + ): + + if checkpoint: + model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint} + else: + model_config = {'dtype': dtype} + + global engine + engine = InferenceEngine(model_class, + model_config, + model_type, + max_batch_size = max_batch_size, + tp_init_size = tp_init_size, + pp_init_size = pp_init_size, + auto_pp = mcfg['auto_pp'], + host = host, + port = port, + dtype = dtype) + + global server + config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level) + server = uvicorn.Server(config=config) + server.run() \ No newline at end of file diff --git a/examples/auto_pipeline/run.py b/examples/auto_pipeline/run.py new file mode 100644 index 0000000..2ebd8f5 --- /dev/null +++ b/examples/auto_pipeline/run.py @@ -0,0 +1,70 @@ +from bert import bert_small, bert_large, bert_xl, bert_8B, bert_175B +# from bert import BertEmbedding1D, BertTransformerLayer1D +from colossalai import launch_from_torch +from energonai.pipelinable import split_transformer_into_partitions + +config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode='1d'))) +launch_from_torch(config) + + +from energonai.context import mcfg +mcfg.load_config("/home/lcdjs/EnergonAI/examples/auto_pipeline/bert_config.py") +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode + +import torch.fx + +batch_size = 4 +seq_len = 128 + +def filter_inputs(traced: torch.fx.GraphModule): + inputs = {} + for node in traced.graph.nodes: + if node.op == 'placeholder': + inputs[node.name] = None + else: + break + + return inputs + +if gpc.get_local_rank(ParallelMode.GLOBAL) == 0: + submodules = split_transformer_into_partitions(bert_large) + + # print(submodules.code) + + # for i in enumerate(submodules.children()): + # print(i) + model0 = submodules.get_submodule('submod_0')#.graph #.print_tabular() + + model1 = submodules.get_submodule('submod_1')#.graph #.print_tabular() + + print(filter_inputs(model0)) + print(filter_inputs(model1)) + + + + # print(len(submodules.children())) + + # input_ids 
= torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64).cpu() + # attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64).cpu() + + # # model = submodules.submod_0 + # model0 = model0.cpu() + # model1 = model1.cpu() + # # # print(model.parameters().device) + # output = model0(input_ids = input_ids, attention_mask=attention_mask) + + # output = model1(blocks_11 = output, attention_mask=attention_mask) + + # submodules.to_folder("/home/lcdjs/EnergonAI/examples/auto_pipeline") + # print(submodules.submod_1.code) + +# model2 = submodules.submod_0() + +# model_config = {'dtype': torch.half} + +# model = bert_large(**model_config) + +# graph = EnergonTracer().trace(model) +# traced = torch.fx.GraphModule(model, graph) +# print(traced.code) From 58935372568e8bc7fbf5e362451a9eefdc8176c6 Mon Sep 17 00:00:00 2001 From: dujiangsu Date: Wed, 8 Jun 2022 17:11:33 +0800 Subject: [PATCH 2/2] del demo run --- examples/auto_pipeline/run.py | 70 ----------------------------------- 1 file changed, 70 deletions(-) delete mode 100644 examples/auto_pipeline/run.py diff --git a/examples/auto_pipeline/run.py b/examples/auto_pipeline/run.py deleted file mode 100644 index 2ebd8f5..0000000 --- a/examples/auto_pipeline/run.py +++ /dev/null @@ -1,70 +0,0 @@ -from bert import bert_small, bert_large, bert_xl, bert_8B, bert_175B -# from bert import BertEmbedding1D, BertTransformerLayer1D -from colossalai import launch_from_torch -from energonai.pipelinable import split_transformer_into_partitions - -config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode='1d'))) -launch_from_torch(config) - - -from energonai.context import mcfg -mcfg.load_config("/home/lcdjs/EnergonAI/examples/auto_pipeline/bert_config.py") -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode - -import torch.fx - -batch_size = 4 -seq_len = 128 - -def filter_inputs(traced: torch.fx.GraphModule): - inputs = {} - for node in traced.graph.nodes: - if node.op == 'placeholder': - inputs[node.name] = None - else: - break - - return inputs - -if gpc.get_local_rank(ParallelMode.GLOBAL) == 0: - submodules = split_transformer_into_partitions(bert_large) - - # print(submodules.code) - - # for i in enumerate(submodules.children()): - # print(i) - model0 = submodules.get_submodule('submod_0')#.graph #.print_tabular() - - model1 = submodules.get_submodule('submod_1')#.graph #.print_tabular() - - print(filter_inputs(model0)) - print(filter_inputs(model1)) - - - - # print(len(submodules.children())) - - # input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64).cpu() - # attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64).cpu() - - # # model = submodules.submod_0 - # model0 = model0.cpu() - # model1 = model1.cpu() - # # # print(model.parameters().device) - # output = model0(input_ids = input_ids, attention_mask=attention_mask) - - # output = model1(blocks_11 = output, attention_mask=attention_mask) - - # submodules.to_folder("/home/lcdjs/EnergonAI/examples/auto_pipeline") - # print(submodules.submod_1.code) - -# model2 = submodules.submod_0() - -# model_config = {'dtype': torch.half} - -# model = bert_large(**model_config) - -# graph = EnergonTracer().trace(model) -# traced = torch.fx.GraphModule(model, graph) -# print(traced.code)
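Since the run.py demo is removed in patch 2/2, here is a small self-contained sketch of the torch.fx mechanism that energonai/pipelinable relies on: trace the model with a tracer that keeps whole transformer layers as leaf modules (as EnergonTracer does), then give torch.fx.passes.split_module.split_module a callback that maps every call_module node to a pipeline stage (as split_policy.py does). The TinyLayer/TinyModel/LeafTracer names and the fixed two layers per stage are toy assumptions; the real code derives the stage count from the pipeline world size and takes the leaf set from the config file.

```python
import torch
import torch.nn as nn
import torch.fx
from torch.fx.passes.split_module import split_module


class TinyLayer(nn.Module):                          # stand-in for BertTransformerLayer1D
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.ff(x)) + x


class TinyModel(nn.Module):                          # an embedding followed by a layer stack
    def __init__(self, dim=16, depth=4):
        super().__init__()
        self.embed = nn.Linear(dim, dim)
        self.blocks = nn.ModuleList([TinyLayer(dim) for _ in range(depth)])

    def forward(self, x):
        x = self.embed(x)
        for blk in self.blocks:
            x = blk(x)
        return x


class LeafTracer(torch.fx.Tracer):
    """Keep each TinyLayer opaque so it appears as a single call_module node."""

    def is_leaf_module(self, m, qualified_name):
        return isinstance(m, TinyLayer) or super().is_leaf_module(m, qualified_name)


counter = -1                                         # start at -1 so the embedding joins stage 0

def stage_of(node: torch.fx.Node, layers_per_stage: int = 2) -> int:
    """Split callback: assign every call_module node to a pipeline stage index."""
    global counter
    stage = max(counter, 0) // layers_per_stage
    if node.op == 'call_module':
        counter += 1
    return stage


model = TinyModel()
graph = LeafTracer().trace(model)
traced = torch.fx.GraphModule(model, graph)
stages = split_module(traced, model, stage_of)       # produces submod_0, submod_1, ...

x = torch.randn(2, 16)
print([name for name, _ in stages.named_children()], stages(x).shape)
```

The resulting `submod_0`/`submod_1` children are what `RPCWorker._auto_pp_init_model` in this patch picks out with `get_submodule` before handing one partition per rank to the `AutoPipelineCommWrapper`.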