diff --git a/README.md b/README.md
index 308ca0f..eac5c10 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ Energon-AI provides 3 levels of abstraction for enabling the large-scale model i
Models trained with [Colossal-AI](https://github.com/hpcaitech/ColossalAI) can be transferred to Energon-AI seamlessly.
Single-device models require manual coding work to introduce tensor parallelism and pipeline parallelism.
-At present, we pre-build distributed Bert, GPT, and ViT models.
+At present, we pre-build distributed Bert, GPT, and ViT models.
For GPT, the pre-built model scales up to 175B parameters, i.e. [GPT3](https://arxiv.org/abs/2005.14165).
For Bert, Google reports a [super-large Bert with 481B parameters](https://mlcommons.org/en/training-normal-11/) in MLPerf-Training v1.1 open, indicating that Bert can also scale to a very large size.
@@ -55,8 +55,9 @@ Method 2:
#### Scaling Ability
Here a 12-layer GPT3 model in FP16 is adopted.
-Here a node with 8 A100 80 GB GPUs is adopted. GPUs are fully connected with NvLink.
-Energon-AI adopts the redundant computation elimination method from [EffectiveTransformer](https://github.com/bytedance/effective_transformer) and the sequence length is set the half of the padding length.
+Here a node with 8 A100 80 GB GPUs is adopted, and the GPUs are fully connected with NVLink.
+Energon-AI adopts the redundant computation elimination method. The method was first proposed in [EffectiveTransformer](https://github.com/bytedance/effective_transformer), and our implementation follows [TurboTransformer](https://github.com/Tencent/TurboTransformers/blob/master/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu).
+Here the sequence length is set to half of the padding length.
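+
+The idea behind eliminating the redundant computation on padding, as a minimal illustrative sketch (the helper below is hypothetical, not the actual kernel):
+
+```python
+import torch
+
+def remove_padding(x: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
+    # x: [batch, padded_len, hidden]; keep only the first seq_lens[i] tokens of each
+    # sequence, so downstream matmuls run on sum(seq_lens) rows instead of
+    # batch * padded_len rows.
+    return torch.cat([x[i, :l] for i, l in enumerate(seq_lens.tolist())], dim=0)
+```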
diff --git a/energonai/engine/auto_pipeline_wrapper.py b/energonai/engine/auto_pipeline_wrapper.py
new file mode 100644
index 0000000..cabb869
--- /dev/null
+++ b/energonai/engine/auto_pipeline_wrapper.py
@@ -0,0 +1,154 @@
+import threading
+
+import torch
+import torch.fx
+import torch.nn as nn
+from typing import List, Tuple, Union
+
+from energonai.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+
+from .pipeline_meta import PipelineMeta
+from .pipeline_msg_dict import PipelineMsgDict, CircleInt # PipelineMsgPriorityQueue
+
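+# Collect the leading placeholder (input) nodes of a traced torch.fx GraphModule into a name -> None dict.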
+def filter_inputs(traced: torch.fx.GraphModule):
+ inputs = {}
+ for node in traced.graph.nodes:
+ if node.op == 'placeholder':
+ inputs[node.name] = None
+ else:
+ break
+ return inputs
+
+# This wrapper is only for Transformer models.
+class AutoPipelineCommWrapper:
+
+ def __init__(self, model: nn.Module, max_batch_size: int = 1, dtype=torch.float) -> None:
+        # TODO (dujiangsu): support sample construction for different input types (iterables, tensors, and others).
+ self.model = model
+ self.dtype = dtype
+
+ self.tensor_dim = 0
+ self.hidden_size = 0
+ self.max_batch_size = max_batch_size
+
+ if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
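+            # Dummy sample used only to exchange tensor shapes across pipeline stages at init time; its values are never returned to a caller.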
+ input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
+ attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
+ sample = dict(input_ids=input_ids, attention_mask=attention_mask)
+ self._init_tensor_meta(sample)
+
+ self.pipe_msg_queue = PipelineMsgDict()
+ self.lock = threading.Lock()
+ self.key = CircleInt()
+
+ def _init_tensor_meta(self, sample):
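+        # Warm-up pass: push the dummy sample through the pipeline once so every stage learns the rank (tensor_dim) and hidden size of the tensor it will receive.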
+
+ with torch.inference_mode():
+ recv_tensor_shape = None
+ if gpc.is_first_rank(ParallelMode.PIPELINE):
+ # inputs = filter_inputs(self.model)
+ output = self.model(sample['input_ids'],
+ sample['attention_mask']) # ([32, 512, 1600])
+ send_tensor_meta(output)
+ send_forward(output)
+ self.tensor_dim = output.dim()
+ self.hidden_size = output.size()[-1]
+ elif gpc.is_last_rank(ParallelMode.PIPELINE):
+ recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
+ input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
+ self.tensor_dim = input_tensor.dim()
+ self.hidden_size = input_tensor.size()[-1]
+ else:
+ recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
+ input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now
+ self.tensor_dim = input_tensor.dim()
+ self.hidden_size = input_tensor.size()[-1]
+ output = self.model(input_tensor,
+ sample['attention_mask'])
+ send_tensor_meta(output)
+ send_forward(output)
+
+ def run(self, key, inputs):
+ if gpc.is_initialized(ParallelMode.PIPELINE):
+ return self.run_with_pp(key, inputs)
+ # else:
+ # return self.run_without_pp(key, inputs)
+
+ # def run_without_pp(self, key, inputs):
+ # pipe_meta = None
+ # self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)
+
+ # self.lock.acquire()
+
+ # cur_key = self.key.val
+ # sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
+ # self.key.addOne()
+ # with torch.inference_mode():
+ # output = self.model(hidden_states=None,
+ # input_ids=sample['input_ids'],
+ # attention_mask=sample['attention_mask'],
+ # seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
+ # self.lock.release()
+
+ # return output, cur_key
+
+    '''
+    Fill the pipeline meta tensor with the shape of the hidden states passed between
+    stages, e.g. ([32, 512, 1600]) for (batch_size, seq_len, hidden_size).
+    For different model types, fill_meta_tensor is different.
+    '''
+
+ def fill_meta_tensor(self, inputs, pipe_meta):
+ if 'seq_lens' in inputs:
+ pipe_meta.get_meta_tensor()[0] = 1
+ pipe_meta.get_meta_tensor()[1] = 1
+ pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
+ else:
+ pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
+ pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
+ pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]
+
+ pipe_meta.get_meta_tensor()[3] = self.hidden_size
+ pipe_meta.update_meta()
+
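+    # Pipeline forward for one request: the first stage consumes input_ids and sends its hidden states forward,
+    # later stages receive hidden states from the previous stage, and only the last stage returns an output to the caller.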
+ def run_with_pp(self, key, inputs):
+ pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
+ self.fill_meta_tensor(inputs, pipe_meta)
+ self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)
+
+ self.lock.acquire()
+ cur_key = self.key.val
+ sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
+ self.key.addOne()
+
+ with torch.inference_mode():
+
+ if gpc.is_first_rank(ParallelMode.PIPELINE):
+ output = self.model(sample['input_ids'],
+ sample['attention_mask'])
+ # seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
+
+ send_forward(output)
+ self.lock.release()
+ return None
+
+ if gpc.is_last_rank(ParallelMode.PIPELINE):
+
+ # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
+ input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
+ output = self.model(input_tensor,
+ sample['attention_mask'])
+ self.lock.release()
+ return output, cur_key
+
+ else:
+
+ input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
+ output = self.model(input_tensor,
+                                    sample['attention_mask'])
+ send_forward(output)
+ self.lock.release()
+ return None
+
+
diff --git a/energonai/engine/engine.py b/energonai/engine/engine.py
index 325f6ef..84e1cc5 100644
--- a/energonai/engine/engine.py
+++ b/energonai/engine/engine.py
@@ -17,7 +17,6 @@
from colossalai.logging import get_dist_logger
from energonai.initialize import launch_from_multiprocess
-from energonai.utils import ensure_directory_exists
@@ -33,9 +32,11 @@ def __init__(self,
max_batch_size: int = 1,
tp_init_size: int = -1,
pp_init_size: int = -1,
+ auto_pp: bool = False,
host: str = 'localhost',
port: int = 29500,
- dtype=None):
+ dtype=None,
+ ):
"""
Args:
model: torch.nn.Module
@@ -57,8 +58,10 @@ def __init__(self,
self.tp_size = tp_init_size
self.pp_size = pp_init_size
- # for TP
+ # for TP, PP
self.rrefs = []
+ self.auto_pp = auto_pp
+
# for rpc
self.WORKER_NAME = "wok{}"
self._init_dist_rpc()
@@ -93,7 +96,7 @@ def _init_model(self):
rpc.remote(ob_info,
RPCWorker,
args=(self.model_class, self.model_config, self.model_type, self.dtype,
- self.max_batch_size)))
+ self.max_batch_size, self.auto_pp)))
def run(self, inputs):
res_rref = 0
diff --git a/energonai/engine/rpc_worker.py b/energonai/engine/rpc_worker.py
index 5894ef9..97f25ec 100644
--- a/energonai/engine/rpc_worker.py
+++ b/energonai/engine/rpc_worker.py
@@ -7,6 +7,9 @@
from .pipeline_wrapper import PipelineCommWrapper
from .vit_pipeline_wrapper import ViTPipelineCommWrapper
+from .auto_pipeline_wrapper import AutoPipelineCommWrapper
+
+from energonai.pipelinable import split_transformer_into_partitions
from energonai.context import mcfg
@@ -15,9 +18,15 @@
pipe_wrapper = {
'vit': ViTPipelineCommWrapper,
'bert': PipelineCommWrapper,
- 'gpt': PipelineCommWrapper
+ 'gpt': PipelineCommWrapper,
+ 'auto': AutoPipelineCommWrapper,
}
+pipe_split = {
+ 'bert': split_transformer_into_partitions,
+ 'gpt': split_transformer_into_partitions,
+ }
+
class ReturnDict:
@@ -38,13 +47,14 @@ def top(self, key):
class RPCWorker:
- def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1) -> None:
+ def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1, auto_pp: bool = False) -> None:
self.model_class = model_class
self.model_config = model_config
self.dtype = dtype
self.max_batch_size = max_batch_size
self.model_type = model_type
+ # self.auto_pp = auto_pp
self.WORKER_NAME = "wok{}"
self.model = None # call the model
@@ -52,8 +62,19 @@ def __init__(self, model_class, model_config, model_type, dtype, max_batch_size:
torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
# self.trt_sample = None
- self._init_self()
- self.return_dict = ReturnDict()
+ if auto_pp:
+ self._auto_pp_init_model()
+ else:
+ self._init_self()
+ self.return_dict = ReturnDict()
+
+ def _auto_pp_init_model(self):
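+        # Split the traced model into pipeline stages (submod_0, submod_1, ...) and keep only
+        # the stage owned by this pipeline rank, wrapped for pipeline communication.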
+ logger.info("Init automatic pipeline model in rank {}".format(self.rank))
+ submodules = pipe_split[self.model_type](self.model_class)
+ self.model = submodules.get_submodule(f'submod_{gpc.get_local_rank(ParallelMode.PIPELINE)}')
+ del submodules
+ self.model = pipe_wrapper['auto'](model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
+
def _init_self(self):
logger.info("Init model in rank {}".format(self.rank))
@@ -69,7 +90,7 @@ def _init_self(self):
try:
logger.info('Import Torch2Trt')
from torch2trt import torch2trt
- from energonai.engine import trt_converter
+ from energonai.engine import trt_converter
except:
logger.error("Installation Required, \n \
follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html \
@@ -100,4 +121,4 @@ def run(self, key, inputs):
self.return_dict.enqueue(cur_key, output.cpu())
return self.return_dict.top(key)
- return None
+ # return None
diff --git a/energonai/pipelinable/__init__.py b/energonai/pipelinable/__init__.py
new file mode 100644
index 0000000..6cde0b3
--- /dev/null
+++ b/energonai/pipelinable/__init__.py
@@ -0,0 +1 @@
+from .split_method import split_transformer_into_partitions
diff --git a/energonai/pipelinable/energon_tracer.py b/energonai/pipelinable/energon_tracer.py
new file mode 100644
index 0000000..f0a8cd1
--- /dev/null
+++ b/energonai/pipelinable/energon_tracer.py
@@ -0,0 +1,7 @@
+import torch.fx
+from energonai.context import mcfg
+
+class EnergonTracer(torch.fx.Tracer):
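+    """A torch.fx tracer that treats every module type listed in mcfg["LeafSet"] as a leaf,
+    so e.g. a whole transformer layer stays a single call_module node when the graph is split."""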
+ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
+ leaves = mcfg["LeafSet"] # set([BertTransformerLayer])
+ return type(m) in leaves
\ No newline at end of file
diff --git a/energonai/pipelinable/split_method.py b/energonai/pipelinable/split_method.py
new file mode 100644
index 0000000..7f24621
--- /dev/null
+++ b/energonai/pipelinable/split_method.py
@@ -0,0 +1,22 @@
+from torch.fx.passes.split_module import split_module
+from .split_policy import module_equal_partition, naive_equal_partition, transformer_partition
+from .energon_tracer import EnergonTracer
+import torch.fx
+
+# Count the nodes of a given op type (e.g. "call_module") in a traced graph.
+def filter_graph(traced: torch.fx.GraphModule, filter_type: str):
+    count = 0
+    for node in traced.graph.nodes:
+        if node.op == filter_type:
+            count = count + 1
+    return count
+
+
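+# Trace the model with EnergonTracer, infer the transformer depth from the number of
+# call_module nodes (minus one, presumably for the embedding), and split the traced
+# graph so that each pipeline stage receives an equal share of the layers.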
+def split_transformer_into_partitions(model_class):
+ model = model_class()
+ graph = EnergonTracer().trace(model)
+ traced = torch.fx.GraphModule(model, graph)
+ depth = filter_graph(traced, "call_module") - 1
+ submodules = split_module(traced, model, transformer_partition(depth))
+ del model
+
+ return submodules
\ No newline at end of file
diff --git a/energonai/pipelinable/split_policy.py b/energonai/pipelinable/split_policy.py
new file mode 100644
index 0000000..5e62fb4
--- /dev/null
+++ b/energonai/pipelinable/split_policy.py
@@ -0,0 +1,52 @@
+import functools
+from torch.fx.node import Node
+from energonai.context import mcfg
+
+
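+# Each policy below returns a callable that maps a torch.fx Node to a partition index,
+# in the form expected by torch.fx.passes.split_module.split_module; the module-level
+# counters keep running state across nodes within a single split pass.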
+partition_counter_0 = 0
+
+# partition_nums: number of nodes assigned to each submodule
+def _naive_equal_partition(node: Node, partition_nums):
+ global partition_counter_0
+ partition = partition_counter_0 // partition_nums
+ partition_counter_0 = partition_counter_0 + 1
+ return partition
+
+def naive_equal_partition(partition_nums):
+ mod_partition = functools.partial(_naive_equal_partition, partition_nums = partition_nums)
+ return mod_partition
+
+partition_counter_1 = 0
+
+# partition_nums: number of call_module nodes assigned to each submodule
+def _module_equal_partition(node: Node, partition_nums):
+ global partition_counter_1
+ partition = partition_counter_1 // partition_nums
+ if node.op == 'call_module':
+ partition_counter_1 = partition_counter_1 + 1
+ return partition
+
+def module_equal_partition(partition_nums):
+ mod_partition = functools.partial(_module_equal_partition, partition_nums = partition_nums)
+ return mod_partition
+
+
+
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+partition_counter_2 = -1 # for embedding layer
+# depth: total number of transformer layers; each pipeline stage receives depth // pipeline_size of them
+def _transformer_partition(node: Node, depth):
+ global partition_counter_2
+ assert gpc.is_initialized(ParallelMode.PIPELINE), "Pipeline communication group should be initialized!"
+
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+ partition_nums = depth // pipeline_size
+ partition = abs(partition_counter_2) // partition_nums
+ if node.op == 'call_module':
+ partition_counter_2 = partition_counter_2 + 1
+ return partition
+
+def transformer_partition(depth):
+ mod_partition = functools.partial(_transformer_partition, depth = depth)
+ return mod_partition
\ No newline at end of file
diff --git a/examples/auto_pipeline/bert.py b/examples/auto_pipeline/bert.py
new file mode 100644
index 0000000..ea42017
--- /dev/null
+++ b/examples/auto_pipeline/bert.py
@@ -0,0 +1,290 @@
+import math
+from typing import Callable
+
+import os
+import torch
+from torch import nn as nn, Tensor, dtype
+
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from energonai.logging import get_dist_logger
+from colossalai.nn.layer.utils import divide, ACT2FN
+from colossalai.nn import Linear1D_Col, Linear1D_Row, Classifier1D
+from colossalai.nn import LayerNorm1D
+from energonai.kernel import transpose_pad, transpose_depad, depad
+from energonai.nn import VocabParallelEmbedding1D
+from energonai.utils import get_current_device, is_using_pp
+
+__all__ = [
+    'BertEmbedding1D',
+ 'BertMLP1D',
+ 'BertSelfAttention1D',
+ 'BertTransformerLayer1D'
+]
+
+from energonai.utils.checkpointing import load_checkpoint
+
+
+class BertEmbedding1D(nn.Module):
+ def __init__(self,
+ embedding_dim: int, # hidden_size
+ vocab_size: int,
+ max_position_embeddings: int,
+ num_tokentypes: int = 0,
+ padding_idx: int = 0,
+ layernorm_epsilon: float = 1e-5,
+ dtype: dtype = None) -> None:
+ super().__init__()
+ self.word_embeddings = VocabParallelEmbedding1D(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype,
+ skip_tp=True)
+ self.position_embeddings = VocabParallelEmbedding1D(max_position_embeddings, embedding_dim, dtype=dtype,
+ skip_tp=True)
+ if num_tokentypes > 0:
+ self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, embedding_dim, dtype=dtype)
+ else:
+ self.tokentype_embeddings = None
+
+ self.LayerNorm = LayerNorm1D(embedding_dim, eps=layernorm_epsilon)
+
+ def forward(self, input_ids, position_ids=None, tokentype_ids=None):
+ # max_padding_size = input_ids.shape[1]
+
+ # TODO: register_buffer in advance for position_ids to speedup
+
+ # if position_ids is None:
+ # position_ids = torch.arange(max_padding_size, dtype=torch.long, device=get_current_device()).unsqueeze(0)
+
+ x = self.word_embeddings(input_ids) # + self.position_embeddings(position_ids)
+
+ if self.tokentype_embeddings is not None and tokentype_ids is not None:
+ x = x + self.tokentype_embeddings(tokentype_ids)
+
+ x = self.LayerNorm(x)
+
+ # if seq_lens is not None:
+ # x = depad(x, batch_size, seq_lens)
+
+ return x
+
+
+class BertSelfAttention1D(nn.Module):
+ def __init__(self,
+ hidden_size: int,
+ num_heads: int,
+ bias: bool = True,
+ fuse_scale_mask_softmax: bool = False,
+ layernorm_epsilon: float = 1e-5,
+ dtype: dtype = None) -> None:
+ super().__init__()
+ if hidden_size % num_heads != 0:
+ raise ValueError(
+                f"The hidden size ({hidden_size}) is not a multiple of the number of attention heads ({num_heads}).")
+ self.hidden_size = hidden_size
+ self.attention_head_size = divide(hidden_size, num_heads)
+ self.fuse_scale_mask_softmax = fuse_scale_mask_softmax
+
+ self.query_key_value = Linear1D_Col(hidden_size, 3 * hidden_size, bias=bias, dtype=dtype)
+
+ if fuse_scale_mask_softmax:
+ raise NotImplementedError
+
+ self.dense = Linear1D_Row(hidden_size, hidden_size, bias=True, dtype=dtype, parallel_input=True)
+ self.LayerNorm = LayerNorm1D(hidden_size, eps=layernorm_epsilon)
+
+ def forward(self, hidden_states, attention_mask=None):
+
+ attention_output = self.query_key_value(hidden_states)
+ all_head_size = attention_output.shape[-1] // 3
+ num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads
+
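+        # reshape [batch, seq, 3 * all_head_size] -> [batch, seq, num_attention_heads, 3 * head_size]
+        # so q, k, v can be split per attention head below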
+ new_qkv_shape = attention_output.shape[:-1] + (num_attention_heads, 3 * self.attention_head_size)
+ attention_output = attention_output.view(new_qkv_shape)
+
+ # if seq_lens is not None:
+ # # TODO: use FasterTransformer's implementation.
+ # attention_output = transpose_pad(attention_output, batch_size, max_padding_size, seq_lens,
+ # num_attention_heads, self.attention_head_size * 3)
+ # else:
+ attention_output = attention_output.permute(0, 2, 1, 3)
+ # TODO: make sure self.attention_head_size*3 is correct
+
+ q, k, v = torch.chunk(attention_output, 3, dim=-1)
+
+ attention_output = torch.matmul(q, k.transpose(-1, -2))
+ if self.fuse_scale_mask_softmax:
+ raise NotImplementedError
+ else:
+ attention_output = attention_output / math.sqrt(self.attention_head_size)
+ # if attention_mask is not None:
+ # attention_output = attention_output + attention_mask
+ attention_output = nn.functional.softmax(attention_output, dim=-1)
+
+ attention_output = torch.matmul(attention_output, v)
+
+ # if seq_lens is not None:
+ # sum_seq = torch.sum(seq_lens)
+ # attention_output = transpose_depad(attention_output, batch_size, sum_seq, max_padding_size, seq_lens,
+ # num_attention_heads, self.attention_head_size)
+ # else:
+ attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
+
+ new_context_layer_shape = attention_output.size()[:-2] + (all_head_size,)
+ attention_output = attention_output.reshape(new_context_layer_shape)
+ attention_output = self.dense(attention_output)
+
+ hidden_states = self.LayerNorm(attention_output + hidden_states)
+
+ return hidden_states
+
+
+def gelu_impl(x):
+ """OpenAI's gelu implementation."""
+ return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
+ (1.0 + 0.044715 * x * x)))
+
+
+class BertMLP1D(nn.Module):
+ def __init__(self,
+ hidden_size: int,
+ mlp_ratio: float,
+ activation: Callable = gelu_impl,
+ layernorm_epsilon: float = 1e-5,
+ dtype: dtype = None,
+ bias: bool = True):
+ super().__init__()
+ intermediate_dim = int(hidden_size * mlp_ratio)
+ self.layer_0 = Linear1D_Col(hidden_size, intermediate_dim, bias=bias, dtype=dtype, gather_output=False)
+ self.activation = activation
+ self.layer_1 = Linear1D_Row(intermediate_dim, hidden_size, bias=bias, dtype=dtype, parallel_input=True)
+ self.LayerNorm = LayerNorm1D(hidden_size, eps=layernorm_epsilon)
+
+ def forward(self, input_tensor):
+ hidden_states = self.layer_0(input_tensor)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.layer_1(hidden_states)
+
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertTransformerLayer1D(nn.Module):
+ def __init__(self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float,
+ activation: Callable = gelu_impl,
+ layernorm_epsilon: float = 1e-5,
+ dtype: dtype = None,
+ bias: bool = True,
+ fuse_scale_mask_softmax: bool = False):
+ super().__init__()
+
+ self.attention = BertSelfAttention1D(hidden_size,
+ num_heads,
+ bias,
+ fuse_scale_mask_softmax,
+ layernorm_epsilon,
+ dtype)
+ self.mlp = BertMLP1D(hidden_size,
+ mlp_ratio,
+ activation,
+ layernorm_epsilon,
+ dtype,
+ bias)
+
+ def forward(self, hidden_states, attention_mask):
+
+ batch_size = hidden_states.shape[0]
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.view(batch_size, -1)
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
+ hidden_states = self.attention(hidden_states, attention_mask)
+ hidden_states = self.mlp(hidden_states)
+
+ return hidden_states
+
+class Bert1D(nn.Module):
+
+ def __init__(self,
+ vocab_size: int = 50304,
+ max_position_embeddings: int = 1024,
+ hidden_size: int = 768,
+ num_heads: int = 12,
+ depth: int = 12,
+ mlp_ratio: float = 4.0,
+ layernorm_epsilon: float = 1e-5,
+ activation: Callable = nn.functional.gelu,
+ padding_idx: int = 0,
+ dtype: dtype = None,
+ bias: bool = True,
+ fuse_scale_mask_softmax: bool = False,
+ ):
+ super().__init__()
+ self.embed = BertEmbedding1D(embedding_dim=hidden_size,
+ vocab_size=vocab_size,
+ max_position_embeddings=max_position_embeddings,
+ padding_idx=padding_idx,
+ layernorm_epsilon=layernorm_epsilon,
+ dtype=dtype)
+ self.blocks = nn.ModuleList()
+
+ for i in range(depth):
+ self.blocks.append(BertTransformerLayer1D(
+ hidden_size=hidden_size,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ activation=activation,
+ layernorm_epsilon=layernorm_epsilon,
+ dtype=dtype,
+ bias=bias,
+ fuse_scale_mask_softmax=fuse_scale_mask_softmax,)
+ )
+
+ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None):
+
+ # batch_size = input_ids.shape[0]
+ # max_padding_size = input_ids.shape[1]
+
+ hidden_states = self.embed(input_ids=input_ids, position_ids=None, tokentype_ids=None) # , seq_lens
+
+ for block in self.blocks:
+ hidden_states = block(hidden_states=hidden_states, attention_mask=attention_mask)
+
+ hidden_states = hidden_states[:, 1, :]
+
+ return hidden_states
+
+
+
+def _create_bert_model(model_kwargs):
+ model = Bert1D(**model_kwargs)
+ return model
+
+
+def bert_small(**kwargs):
+ model_kwargs = dict(hidden_size=768, depth=12, num_heads=12, **kwargs)
+ return _create_bert_model(model_kwargs)
+
+
+def bert_large(**kwargs):
+ model_kwargs = dict(hidden_size=1024, depth=24, num_heads=16, **kwargs)
+ return _create_bert_model(model_kwargs)
+
+
+def bert_xl(**kwargs):
+ model_kwargs = dict(hidden_size=1600, depth=48, num_heads=16, **kwargs)
+ return _create_bert_model(model_kwargs)
+
+
+def bert_8B(**kwargs):
+ model_kwargs = dict(hidden_size=3072, depth=72, num_heads=24, **kwargs)
+ return _create_bert_model(model_kwargs)
+
+
+def bert_175B(**kwargs):
+ model_kwargs = dict(hidden_size=12288, depth=96, num_heads=96, **kwargs)
+ return _create_bert_model(model_kwargs)
diff --git a/examples/auto_pipeline/bert_config.py b/examples/auto_pipeline/bert_config.py
new file mode 100644
index 0000000..b373ffd
--- /dev/null
+++ b/examples/auto_pipeline/bert_config.py
@@ -0,0 +1,24 @@
+from bert import bert_small, bert_large, bert_xl, bert_8B, bert_175B
+from bert_server import launch_engine
+from bert import BertEmbedding1D, BertTransformerLayer1D
+
+model_class = bert_8B
+model_type = "bert"
+engine_server = launch_engine
+
+
+# parallel
+tp_init_size = 2
+pp_init_size = 2
+auto_pp = True
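+# Module types in LeafSet are treated as leaves by EnergonTracer, so the embedding and
+# each transformer layer stay single nodes when the model is split into pipeline stages.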
+LeafSet = set([BertTransformerLayer1D, BertEmbedding1D])
+
+
+
+host = "127.0.0.1"
+port = 29400
+half = False
+server_host = "127.0.0.1"
+server_port = 8010
+log_level = "info"
+backend = "nccl"
\ No newline at end of file
diff --git a/examples/auto_pipeline/bert_server.py b/examples/auto_pipeline/bert_server.py
new file mode 100644
index 0000000..946278b
--- /dev/null
+++ b/examples/auto_pipeline/bert_server.py
@@ -0,0 +1,92 @@
+import os
+import torch
+import uvicorn
+from fastapi import FastAPI
+from energonai.engine import InferenceEngine
+from energonai.context import mcfg
+
+app = FastAPI()  # create the FastAPI application
+
+@app.get("/")  # root route
+def root():
+ return {"200"}
+
+@app.get("/model_with_padding")
+def run():
+ # for the performance only
+ seq_len = 512
+ batch_size = 32
+
+ input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
+ attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
+ # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly
+ hidden_states = None
+ sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
+
+ output = engine.run(sample)
+ output = output.to_here()
+ print(output)
+ return {"To return the string result."}
+
+# @app.get("/model_rm_padding")
+# def run():
+# # for the performance only
+# seq_len = 512
+# batch_size = 32
+
+# input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
+# attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
+# seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int) # generate seq_lens randomly
+# hidden_states = None
+# sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask, seq_lens=seq_lens)
+
+# output = engine.run(sample)
+# output = output.to_here()
+# print(output)
+# return {"To return the string result."}
+
+
+@app.get("/shutdown")
+async def shutdown():
+ engine.clear()
+ server.should_exit = True
+ server.force_exit = True
+ await server.shutdown()
+
+
+def launch_engine(model_class,
+ model_type,
+ max_batch_size: int = 1,
+ tp_init_size: int = -1,
+ pp_init_size: int = -1,
+ host: str = "localhost",
+ port: int = 29500,
+ dtype = torch.float,
+ checkpoint: str = None,
+ tokenizer_path: str = None,
+ server_host = "localhost",
+ server_port = 8005,
+ log_level = "info"
+ ):
+
+ if checkpoint:
+ model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint}
+ else:
+ model_config = {'dtype': dtype}
+
+ global engine
+ engine = InferenceEngine(model_class,
+ model_config,
+ model_type,
+ max_batch_size = max_batch_size,
+ tp_init_size = tp_init_size,
+ pp_init_size = pp_init_size,
+ auto_pp = mcfg['auto_pp'],
+ host = host,
+ port = port,
+ dtype = dtype)
+
+ global server
+ config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level)
+ server = uvicorn.Server(config=config)
+ server.run()
\ No newline at end of file