From 5201e018b5a63322d6f5cd5c9f79f11684ca3cc4 Mon Sep 17 00:00:00 2001
From: dujiangsu
Date: Wed, 11 May 2022 15:18:52 +0800
Subject: [PATCH] combine two pipeline wrappers

---
 energon/engine/engine.py               |   1 -
 energon/engine/gpt_pipeline_wrapper.py |  23 ++--
 energon/engine/pipeline_wrapper.py     | 149 +++++++++++++++++++++++++
 energon/engine/rpc_worker.py           |  12 +-
 examples/gpt/gpt.py                    |  25 ++---
 5 files changed, 178 insertions(+), 32 deletions(-)
 create mode 100644 energon/engine/pipeline_wrapper.py

diff --git a/energon/engine/engine.py b/energon/engine/engine.py
index fb1685b..790c63a 100644
--- a/energon/engine/engine.py
+++ b/energon/engine/engine.py
@@ -17,7 +17,6 @@
 from energon.utils import ensure_directory_exists
 from energon.logging import get_dist_logger
-from energon.nn import PipelineCommWrapper
 
 
 class InferenceEngine(Module):
diff --git a/energon/engine/gpt_pipeline_wrapper.py b/energon/engine/gpt_pipeline_wrapper.py
index dfacef4..58f726c 100644
--- a/energon/engine/gpt_pipeline_wrapper.py
+++ b/energon/engine/gpt_pipeline_wrapper.py
@@ -60,7 +60,7 @@ def _init_tensor_meta(self, sample):
                 input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                 self.tensor_dim = input_tensor.dim()
                 self.hidden_size = input_tensor.size()[-1]
-                output = self.model(hidden_states=None, input_ids=input_tensor, attention_mask=sample['attention_mask'])
+                output = self.model(hidden_states=input_tensor, input_ids=input_tensor, attention_mask=sample['attention_mask'])
                 send_tensor_meta(output)
                 send_forward(output)
 
@@ -92,9 +92,15 @@ def run_without_pp(self, key, inputs):
     '''
 
     def fill_meta_tensor(self, inputs, pipe_meta):
-        pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
-        pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
-        pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]
+        if 'seq_lens' in inputs:
+            pipe_meta.get_meta_tensor()[0] = 1
+            pipe_meta.get_meta_tensor()[1] = 1
+            pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
+        else:
+            pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
+            pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
+            pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]
+
         pipe_meta.get_meta_tensor()[3] = self.hidden_size
         pipe_meta.update_meta()
 
@@ -113,7 +119,8 @@ def run_with_pp(self, key, inputs):
             if gpc.is_first_rank(ParallelMode.PIPELINE):
                 output = self.model(hidden_states=None,
                                     input_ids=sample['input_ids'],
-                                    attention_mask=sample['attention_mask'])
+                                    attention_mask=sample['attention_mask'],
+                                    seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
 
                 send_forward(output)
                 self.lock.release()
@@ -125,7 +132,8 @@ def run_with_pp(self, key, inputs):
                 input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                 output = self.model(hidden_states=input_tensor,
                                     input_ids=sample['input_ids'],
-                                    attention_mask=sample['attention_mask'])
+                                    attention_mask=sample['attention_mask'],
+                                    seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
                 self.lock.release()
                 return output, cur_key
 
@@ -134,7 +142,8 @@ def run_with_pp(self, key, inputs):
                 input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                 output = self.model(hidden_states=input_tensor,
                                     input_ids=sample['input_ids'],
-                                    attention_mask=sample['attention_mask'])
+                                    attention_mask=sample['attention_mask'],
+                                    seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
                 send_forward(output)
                 self.lock.release()
                 return None
diff --git a/energon/engine/pipeline_wrapper.py b/energon/engine/pipeline_wrapper.py
new file mode 100644
index 0000000..b558b00
--- /dev/null
+++ b/energon/engine/pipeline_wrapper.py
@@ -0,0 +1,149 @@
+import inspect
+import threading
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from typing import List, Tuple, Union
+
+from energon.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta
+from energon.context import ParallelMode
+from energon.core import global_context as gpc
+
+from .pipeline_meta import PipelineMeta
+from .pipeline_msg_dict import PipelineMsgDict, CircleInt  # PipelineMsgPriorityQueue
+
+
+# The Wrapper is only for Transformer Model.
+class PipelineCommWrapper:
+    def __init__(self,
+                 model: nn.Module,
+                 max_batch_size: int = 1,
+                 dtype=torch.float) -> None:
+        # TODO (dujiangsu): to make sample capability for different types. Iteration, Tensor, and others.
+        self.model = model
+        self.dtype = dtype
+
+        self.tensor_dim = 0
+        self.hidden_size = 0
+        self.max_batch_size = max_batch_size
+
+        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+            input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
+            attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
+            hidden_states = None
+            sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
+            self._init_tensor_meta(sample)
+
+        self.pipe_msg_queue = PipelineMsgDict()
+        self.lock = threading.Lock()
+        self.key = CircleInt()
+
+    def _init_tensor_meta(self, sample):
+
+        with torch.inference_mode():
+            recv_tensor_shape = None
+            if gpc.is_first_rank(ParallelMode.PIPELINE):
+                output = self.model(hidden_states=None, input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'])  # ([32, 512, 1600])
+                send_tensor_meta(output)
+                send_forward(output)
+                self.tensor_dim = output.dim()
+                self.hidden_size = output.size()[-1]
+            elif gpc.is_last_rank(ParallelMode.PIPELINE):
+                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
+                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
+                self.tensor_dim = input_tensor.dim()
+                self.hidden_size = input_tensor.size()[-1]
+            else:
+                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
+                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
+                self.tensor_dim = input_tensor.dim()
+                self.hidden_size = input_tensor.size()[-1]
+                output = self.model(hidden_states=input_tensor, input_ids=input_tensor, attention_mask=sample['attention_mask'])
+                send_tensor_meta(output)
+                send_forward(output)
+
+    def run(self, key, inputs):
+        if gpc.is_initialized(ParallelMode.PIPELINE):
+            return self.run_with_pp(key, inputs)
+        else:
+            return self.run_without_pp(key, inputs)
+
+    def run_without_pp(self, key, inputs):
+        pipe_meta = None
+        self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)
+
+        self.lock.acquire()
+
+        cur_key = self.key.val
+        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
+        self.key.addOne()
+        output = self.model(hidden_states=None,
+                            input_ids=sample['input_ids'],
+                            attention_mask=sample['attention_mask'])
+        self.lock.release()
+
+        return output, cur_key
+
+    '''
+    hidden_size : ([32, 512, 1600])
+    For different model type, fill_meta_tensor is different
+    '''
+
+    def fill_meta_tensor(self, inputs, pipe_meta):
+        if 'seq_lens' in inputs:
+            pipe_meta.get_meta_tensor()[0] = 1
+            pipe_meta.get_meta_tensor()[1] = 1
+            pipe_meta.get_meta_tensor()[2] = torch.sum(inputs['seq_lens'])
+        else:
+            pipe_meta.get_meta_tensor()[0] = inputs['input_ids'].shape[0]
+            pipe_meta.get_meta_tensor()[1] = inputs['input_ids'].shape[0]
+            pipe_meta.get_meta_tensor()[2] = inputs['input_ids'].shape[1]
+
+        pipe_meta.get_meta_tensor()[3] = self.hidden_size
+        pipe_meta.update_meta()
+
+    def run_with_pp(self, key, inputs):
+        pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
+        self.fill_meta_tensor(inputs, pipe_meta)
+        self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)
+
+        self.lock.acquire()
+        cur_key = self.key.val
+        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
+        self.key.addOne()
+
+        with torch.inference_mode():
+
+            if gpc.is_first_rank(ParallelMode.PIPELINE):
+                output = self.model(hidden_states=None,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'],
+                                    seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
+
+                send_forward(output)
+                self.lock.release()
+                return None
+
+            if gpc.is_last_rank(ParallelMode.PIPELINE):
+
+                # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
+                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
+                output = self.model(hidden_states=input_tensor,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'],
+                                    seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
+                self.lock.release()
+                return output, cur_key
+
+            else:
+
+                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
+                output = self.model(hidden_states=input_tensor,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'],
+                                    seq_lens=inputs['seq_lens'] if 'seq_lens' in inputs else None)
+                send_forward(output)
+                self.lock.release()
+                return None
diff --git a/energon/engine/rpc_worker.py b/energon/engine/rpc_worker.py
index 2001bc9..060a6d9 100644
--- a/energon/engine/rpc_worker.py
+++ b/energon/engine/rpc_worker.py
@@ -7,14 +7,7 @@
 from energon.core import global_context as gpc
 from energon.context import ParallelMode
 from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
-from .gpt_pipeline_wrapper import GPTPipelineCommWrapper
-from .bert_pipeline_wrapper import BertPipelineCommWrapper
-
-WRAPPER_TYPES = {
-    "gpt": GPTPipelineCommWrapper,
-    "bert": BertPipelineCommWrapper,
-}
-
+from .pipeline_wrapper import PipelineCommWrapper
 
 class ReturnDict:
     def __init__(self):
@@ -41,7 +34,6 @@ def __init__(self,
         self.model_config = model_config
         self.dtype = dtype
         self.max_batch_size = max_batch_size
-        self.pipe_wrapper = WRAPPER_TYPES[model_type]
         self.WORKER_NAME = "wok{}"
 
         self.model = None  # call the model
@@ -62,7 +54,7 @@ def _init_self(self):
         # print("Pass")
         self.model.eval()
 
-        self.model = self.pipe_wrapper(model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
+        self.model = PipelineCommWrapper(model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
 
     def run(self, key, inputs):
         # print("key: {}".format(key), flush=True)
diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py
index 0fa662b..401d00d 100644
--- a/examples/gpt/gpt.py
+++ b/examples/gpt/gpt.py
@@ -89,7 +89,7 @@ def __init__(self,
         self.softmax = nn.Softmax(dim=-1)
         self.dense = Linear1D_Row(dim, dim, bias=True, dtype=dtype, parallel_input=True)
 
-    def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None):
+    def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None, valid_word_num=None):
         qkv = self.query_key_value(x)
         all_head_size = qkv.shape[-1] // 3
         num_attention_heads = divide(all_head_size, self.attention_head_size)    # num_heads
@@ -115,6 +115,7 @@ def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None
             causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
                                                 device=get_current_device())).view(1, 1, q_len, k_len).bool()
             x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))
+
             if attention_mask is not None:
                 x = x + attention_mask
             x = self.softmax(x)
@@ -122,7 +123,7 @@ def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None
         x = torch.matmul(x, v)
 
         if seq_lens is not None:
-            x = transpose_depad(x, batch_size, valid_word_num[0].item(), max_padding_size, seq_lens, num_attention_heads, self.attention_head_size)
+            x = transpose_depad(x, batch_size, valid_word_num, max_padding_size, seq_lens, num_attention_heads, self.attention_head_size)
         else:
             x = x.transpose(1, 2)
 
@@ -142,6 +143,7 @@ def __init__(self,
                  dtype: dtype = None,
                  bias: bool = True):
         super().__init__()
+
         intermediate_dim = int(dim * mlp_ratio)
         self.dense_1 = Linear1D_Col(dim, intermediate_dim, bias=bias, dtype=dtype, gather_output=False)
         self.activation = activation
@@ -179,13 +181,13 @@ def __init__(self,
         self.norm2 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon)
         self.mlp = GPTMLP1D(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dtype=dtype, bias=bias)
 
-    def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None):
+    def forward(self, x, attention_mask=None, batch_size=None, max_padding_size=None, seq_lens=None, valid_word_num=None):
         if not self.apply_post_layernorm:
             residual = x
         x = self.norm1(x)
         if self.apply_post_layernorm:
             residual = x
-        x = residual + self.attn(x, attention_mask, batch_size, max_padding_size, seq_lens)
+        x = residual + self.attn(x, attention_mask, batch_size, max_padding_size, seq_lens, valid_word_num)
 
         if not self.apply_post_layernorm:
             residual = x
@@ -295,30 +297,25 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_l
         if seq_lens is not None:
             hidden_states = ft_remove_padding(hidden_states, self.tmp_mask_offset,
                                               self.mask_offset, self.valid_word_num[0].item(), self.dim)
-        elif seq_lens is not None:
-            ft_remove_padding(hidden_states, self.tmp_mask_offset,
-                              self.mask_offset, self.valid_word_num[0].item(), self.dim)
 
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # Adapted from huggingface
+
         if attention_mask is not None:
-            if self.first:
-                batch_size = input_ids.shape[0]
-            else:
-                batch_size = hidden_states.shape[0]
             attention_mask = attention_mask.view(batch_size, -1)
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = attention_mask.to(dtype=hidden_states.dtype)  # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * -10000.0
+
         for block in self.blocks:
-            hidden_states = block(hidden_states, attention_mask, batch_size, max_padding_size, seq_lens)
+            hidden_states = block(hidden_states, attention_mask, batch_size, max_padding_size, seq_lens, self.valid_word_num[0].item())
 
         if self.last:
             if seq_lens is not None:
-                hidden_states = ft_rebuild_padding(hidden_states, self.mask_offset, self.valid_word_num[0].item(), self.dim, batch_size, max_padding_size)
+                hidden_states = ft_rebuild_padding(hidden_states, self.tmp_mask_offset[0:self.valid_word_num[0].item()], self.valid_word_num[0].item(), self.dim, batch_size, max_padding_size)
             hidden_states = self.head(self.norm(hidden_states))
 
         # res = []
         # for i in range(hidden_states.shape[0]):
@@ -420,4 +417,4 @@ def gpt2_8B(**kwargs):
 
 def gpt3(**kwargs):
     model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
-    return _create_gpt_pipeline_model(**model_kwargs)
+    return _create_gpt_pipeline_model(**model_kwargs)
\ No newline at end of file
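
Usage sketch (reviewer note, not part of the patch): after this change every model type is served through the single PipelineCommWrapper, as rpc_worker.py now does unconditionally. The snippet below only illustrates the intended call pattern; the helper names, the dtype, and the input tensors are hypothetical placeholders, not code from the repository.

# Illustrative sketch only; assumes the pipeline parallel context (gpc) is already initialized.
import torch
from energon.engine.pipeline_wrapper import PipelineCommWrapper

def build_wrapper(model, max_batch_size=32, dtype=torch.half):
    # The constructor probes tensor metadata across pipeline stages,
    # so it must run after parallel initialization, as in rpc_worker.py.
    model.eval()
    return PipelineCommWrapper(model=model, max_batch_size=max_batch_size, dtype=dtype)

def serve_request(wrapper, key, input_ids, attention_mask, seq_lens=None):
    inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}
    if seq_lens is not None:
        # With seq_lens present, fill_meta_tensor advertises a single row whose
        # length is the total number of un-padded tokens (torch.sum(seq_lens)).
        inputs['seq_lens'] = seq_lens
    # First and middle pipeline stages return None; the last stage returns (output, key).
    return wrapper.run(key, inputs)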