From c95c3268736b80c343911f2386c123259afab8b9 Mon Sep 17 00:00:00 2001
From: dujiangsu
Date: Wed, 27 Apr 2022 17:00:59 +0800
Subject: [PATCH 1/2] add gpt example

---
 examples/gpt/gpt_config.py | 15 ++++++++
 examples/gpt/gpt_server.py | 75 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 90 insertions(+)
 create mode 100644 examples/gpt/gpt_config.py
 create mode 100644 examples/gpt/gpt_server.py

diff --git a/examples/gpt/gpt_config.py b/examples/gpt/gpt_config.py
new file mode 100644
index 0000000..c94792e
--- /dev/null
+++ b/examples/gpt/gpt_config.py
@@ -0,0 +1,15 @@
+from gpt import gpt2_small, gpt2_medium, gpt2_large, gpt2_xl, gpt2_8B, gpt3
+from gpt_server import launch_engine
+
+model_class = gpt2_large
+model_type = "gpt"
+engine_server = launch_engine
+tp_init_size = 2
+pp_init_size = 2
+host = "127.0.0.1"
+port = 29400
+half = True
+server_host = "127.0.0.1"
+server_port = 8010
+log_level = "info"
+backend = "nccl"
\ No newline at end of file
diff --git a/examples/gpt/gpt_server.py b/examples/gpt/gpt_server.py
new file mode 100644
index 0000000..b4afbf4
--- /dev/null
+++ b/examples/gpt/gpt_server.py
@@ -0,0 +1,75 @@
+import os
+import torch
+import uvicorn
+from fastapi import FastAPI
+from fastapi import Response
+import torch.distributed.rpc as rpc
+from energon.engine import InferenceEngine
+
+app = FastAPI() # create the api object
+
+@app.get("/") # root route
+def root():
+    return {"200"}
+
+@app.get("/model_with_padding")
+def run():
+    # for the performance only
+    seq_len = 512
+    batch_size = 32
+
+    input_ids = torch.randint(1, 10, (batch_size, seq_len), dtype=torch.int64)
+    attention_mask = torch.randint(0, 1, (batch_size, 1, seq_len), dtype=torch.int64)
+    # seq_lens = torch.randint(1, 128, (batch_size, ), dtype=torch.int64) # generate seq_lens randomly
+    hidden_states = None
+    sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
+
+    output = engine.run(sample)
+    output = output.to_here()
+    print(output)
+    return {"To return the string result."}
+
+
+@app.get("/shutdown")
+async def shutdown():
+    engine.clear()
+    server.should_exit = True
+    server.force_exit = True
+    await server.shutdown()
+
+
+def launch_engine(model_name,
+                  model_type,
+                  max_batch_size: int = 1,
+                  tp_init_size: int = -1,
+                  pp_init_size: int = -1,
+                  host: str = "localhost",
+                  port: int = 29500,
+                  dtype = torch.float,
+                  checkpoint: str = None,
+                  tokenizer_path: str = None,
+                  server_host = "localhost",
+                  server_port = 8005,
+                  log_level = "info"
+                  ):
+
+    if checkpoint:
+        model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint}
+    else:
+        model_config = {'dtype': dtype}
+
+    global engine
+    engine = InferenceEngine(model_name,
+                             model_config,
+                             model_type,
+                             max_batch_size = max_batch_size,
+                             tp_init_size = tp_init_size,
+                             pp_init_size = pp_init_size,
+                             host = host,
+                             port = port,
+                             dtype = dtype)
+
+    global server
+    config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level)
+    server = uvicorn.Server(config=config)
+    server.run()
\ No newline at end of file

From 5d2e6de0e187983565740e0e916ebb3b08e1bf42 Mon Sep 17 00:00:00 2001
From: dujiangsu
Date: Wed, 27 Apr 2022 17:02:35 +0800
Subject: [PATCH 2/2] del model dir in energon

---
 energon/model/__init__.py      |   2 -
 energon/model/bert/__init__.py |   4 -
 energon/model/bert/bert.py     | 362 -------------------------
 energon/model/gpt/__init__.py  |   5 -
 energon/model/gpt/gpt.py       | 470 --------------------------------
 energon/model/gpt/hf_gpt2.py   | 480 ---------------------------------
 6 files
changed, 1323 deletions(-) delete mode 100644 energon/model/__init__.py delete mode 100644 energon/model/bert/__init__.py delete mode 100644 energon/model/bert/bert.py delete mode 100644 energon/model/gpt/__init__.py delete mode 100644 energon/model/gpt/gpt.py delete mode 100644 energon/model/gpt/hf_gpt2.py diff --git a/energon/model/__init__.py b/energon/model/__init__.py deleted file mode 100644 index 360ac1f..0000000 --- a/energon/model/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .bert import * -from .gpt import * \ No newline at end of file diff --git a/energon/model/bert/__init__.py b/energon/model/bert/__init__.py deleted file mode 100644 index a6ceb46..0000000 --- a/energon/model/bert/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .bert import bert_small, bert_large, bert_xl, bert_8B, bert_175B - - -__all__ = ['bert_small', 'bert_large', 'bert_xl', 'bert_8B', 'bert_175B'] \ No newline at end of file diff --git a/energon/model/bert/bert.py b/energon/model/bert/bert.py deleted file mode 100644 index 5cb5e89..0000000 --- a/energon/model/bert/bert.py +++ /dev/null @@ -1,362 +0,0 @@ -import math -from typing import Callable - -import os -import torch -from torch import nn as nn, Tensor, dtype - -from energon.context import ParallelMode -from energon.core import global_context as gpc -from energon.logging import get_dist_logger -from energon.nn.layer.utils import divide, ACT2FN -from energon.nn import Linear1D_Col, Linear1D_Row, Classifier1D -from energon.nn import LayerNorm1D -from energon.kernel import transpose_pad, transpose_depad, depad -from energon.nn import VocabParallelEmbedding1D -from energon.utils import get_current_device, is_using_pp - -__all__ = [ - 'BertEmbedding1D' - 'BertMLP1D', - 'BertSelfAttention1D', - 'BertTransformerLayer1D' -] - -from energon.utils.checkpointing import load_checkpoint - - -class BertEmbedding1D(nn.Module): - def __init__(self, - embedding_dim: int, # hidden_size - vocab_size: int, - max_position_embeddings: int, - num_tokentypes: int = 0, - padding_idx: int = 0, - layernorm_epsilon: float = 1e-5, - dtype: dtype = None) -> None: - super().__init__() - self.word_embeddings = VocabParallelEmbedding1D(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype, skip_tp=True) - self.position_embeddings = VocabParallelEmbedding1D(max_position_embeddings, embedding_dim, dtype=dtype, skip_tp=True) - if num_tokentypes > 0: - self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, embedding_dim, dtype=dtype) - else: - self.tokentype_embeddings = None - - # self.LayerNorm = nn.LayerNorm(embedding_dim, eps=layernorm_epsilon, dtype=dtype) - self.LayerNorm = LayerNorm1D(embedding_dim, eps=layernorm_epsilon) - - def forward(self, input_ids, position_ids=None, tokentype_ids=None, seq_lens=None, batch_size=None, max_padding_size=None): - max_padding_size = input_ids.shape[1] - - # TODO: register_buffer in advance for position_ids to speedup - - if position_ids is None: - position_ids = torch.arange(max_padding_size, dtype=torch.long, device=get_current_device()).unsqueeze(0) - - x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) - - if self.tokentype_embeddings is not None and tokentype_ids is not None: - x = x + self.tokentype_embeddings(tokentype_ids) - - x = self.LayerNorm(x) - - if seq_lens is not None: - x = depad(x, batch_size, seq_lens) - - return x - - -class BertSelfAttention1D(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - bias: bool = True, - fuse_scale_mask_softmax: bool = False, 
- layernorm_epsilon: float = 1e-5, - dtype: dtype = None) -> None: - super().__init__() - if hidden_size % num_heads != 0: - raise ValueError( - f"The hidden size ({hidden_size}) is not a multiple of the number of attention ") - self.hidden_size = hidden_size - self.attention_head_size = divide(hidden_size, num_heads) - self.fuse_scale_mask_softmax = fuse_scale_mask_softmax - - self.query_key_value = Linear1D_Col(hidden_size, 3 * hidden_size, bias=bias, dtype=dtype) - - if fuse_scale_mask_softmax: - raise NotImplementedError - - self.dense = Linear1D_Row(hidden_size, hidden_size, bias=True, dtype=dtype, parallel_input=True) - # self.LayerNorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) - self.LayerNorm = LayerNorm1D(hidden_size, eps=layernorm_epsilon) - - - def forward(self, hidden_states, attention_mask=None, seq_lens=None, batch_size=None, max_padding_size=None): - - attention_output = self.query_key_value(hidden_states) - all_head_size = attention_output.shape[-1] // 3 - num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads - - new_qkv_shape = attention_output.shape[:-1] + (num_attention_heads, 3*self.attention_head_size) - attention_output = attention_output.view(new_qkv_shape) - - print(f'1: {attention_output.size()}') - if seq_lens is not None: - attention_output = transpose_pad(attention_output, batch_size, max_padding_size, seq_lens, num_attention_heads, self.attention_head_size*3) - else: - attention_output = attention_output.permute(0, 2, 1, 3) - # TODO: make sure self.attention_head_size*3 is correct - - print(f'2: {attention_output.size()}') - - q, k, v = torch.chunk(attention_output, 3, dim = -1) - - attention_output = torch.matmul(q, k.transpose(-1, -2)) - print(f'3: {attention_output.size()}') - if self.fuse_scale_mask_softmax: - raise NotImplementedError - else: - attention_output = attention_output / math.sqrt(self.attention_head_size) - # if attention_mask is not None: - # attention_output = attention_output + attention_mask - attention_output = nn.functional.softmax(attention_output, dim=-1) - - attention_output = torch.matmul(attention_output, v) - - print(f'4: {attention_output.size()}') - - if seq_lens is not None: - sum_seq = torch.sum(seq_lens) - attention_output = transpose_depad(attention_output, batch_size, sum_seq, max_padding_size, seq_lens, num_attention_heads, self.attention_head_size) - else: - attention_output = attention_output.permute(0, 2, 1, 3).contiguous() - - print(f'5: {attention_output.size()}') - - - new_context_layer_shape = attention_output.size()[:-2] + (all_head_size,) - attention_output = attention_output.reshape(new_context_layer_shape) - attention_output = self.dense(attention_output) - - hidden_states = self.LayerNorm(attention_output + hidden_states) - - return hidden_states - -def gelu_impl(x): - """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * - (1.0 + 0.044715 * x * x))) - -class BertMLP1D(nn.Module): - def __init__(self, - hidden_size: int, - mlp_ratio: float, - activation: Callable = gelu_impl, - layernorm_epsilon: float = 1e-5, - dtype: dtype = None, - bias: bool = True): - super().__init__() - intermediate_dim = int(hidden_size * mlp_ratio) - self.layer_0 = Linear1D_Col(hidden_size, intermediate_dim, bias=bias, dtype=dtype, gather_output=False) - self.activation = activation - self.layer_1 = Linear1D_Row(intermediate_dim, hidden_size, bias=bias,dtype=dtype, parallel_input=True) - # self.LayerNorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) - 
self.LayerNorm = LayerNorm1D(hidden_size, eps=layernorm_epsilon) - - def forward(self, input_tensor): - hidden_states = self.layer_0(input_tensor) - hidden_states = self.activation(hidden_states) - hidden_states = self.layer_1(hidden_states) - - hidden_states = self.LayerNorm(hidden_states+input_tensor) - return hidden_states - -class BertTransformerLayer1D(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - mlp_ratio: float, - activation: Callable = gelu_impl, - layernorm_epsilon: float = 1e-5, - dtype: dtype = None, - bias: bool = True, - fuse_scale_mask_softmax: bool = False): - - super().__init__() - - self.attention = BertSelfAttention1D(hidden_size, - num_heads, - bias, - fuse_scale_mask_softmax, - layernorm_epsilon, - dtype) - self.mlp = BertMLP1D(hidden_size, - mlp_ratio, - activation, - layernorm_epsilon, - dtype, - bias) - - def forward(self, hidden_states, attention_mask, seq_lens=None, batch_size=None, max_padding_size=None): - hidden_states = self.attention(hidden_states, attention_mask, seq_lens, batch_size, max_padding_size) - hidden_states = self.mlp(hidden_states) - - return hidden_states - - -class PipelineBert1D(nn.Module): - - def __init__(self, - vocab_size: int = 50304, - max_position_embeddings: int = 1024, - hidden_size: int = 768, - num_heads: int = 12, - depth: int = 12, - mlp_ratio: float = 4.0, - layernorm_epsilon: float = 1e-5, - activation: Callable = nn.functional.gelu, - padding_idx: int = 0, - dtype: dtype = None, - bias: bool = True, - fuse_scale_mask_softmax: bool = False, - first: bool = False, - last: bool = False): - super().__init__() - self.first = first - self.last = last - - if first: - self.embed = BertEmbedding1D(embedding_dim=hidden_size, - vocab_size=vocab_size, - max_position_embeddings=max_position_embeddings, - padding_idx=padding_idx, - layernorm_epsilon=layernorm_epsilon, - dtype=dtype) - self.blocks = nn.ModuleList() - self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if is_using_pp() else 0 - for id_ in range(depth): - self.blocks.register_module("blk_{}".format(id_ + self.pp_rank * depth), - BertTransformerLayer1D( - hidden_size=hidden_size, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - activation=activation, - layernorm_epsilon=layernorm_epsilon, - dtype=dtype, - bias=bias, - fuse_scale_mask_softmax=fuse_scale_mask_softmax, - ) - ) - - - def forward(self, hidden_states=None, input_ids=None, attention_mask=None, seq_lens=None): - - batch_size = None - max_padding_size = None - if seq_lens is not None: - batch_size = input_ids.shape[0] - max_padding_size = input_ids.shape[1] - - print(self.first) - print(self.last) - - if self.first: - hidden_states = self.embed(input_ids=input_ids, position_ids=None, tokentype_ids=None, seq_lens=seq_lens, batch_size=batch_size, max_padding_size=max_padding_size) #, seq_lens - - for block in self.blocks: - hidden_states = block(hidden_states=hidden_states, attention_mask=attention_mask, seq_lens=seq_lens, batch_size=batch_size, max_padding_size=max_padding_size) - - if self.last: - hidden_states = hidden_states[:, 1, :] - print(f'Hidden States: {hidden_states.size()}') - - return hidden_states - - - -def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert num_items % num_chunks == 0, \ - "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" - - logger = get_dist_logger() - parts = [[] for _ in range(pipeline_parallel_size)] # 4 - partition_items = num_items // num_chunks # 96 // 2 - for idx in 
range(num_chunks): - base_idx = idx * partition_items - chunk_size = partition_items // pipeline_parallel_size - left = pipeline_parallel_size - partition_items % pipeline_parallel_size - if chunk_size == 0: - logger.warning("Some nodes in Pipeline have no requests") - - for p in range(pipeline_parallel_size): - st = base_idx - base_idx += chunk_size + (p >= left) - parts[p].append((st, base_idx)) - - return parts - -def _create_bert_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs): - logger = get_dist_logger() - pipeline_size = 0 - pipeline_rank = 0 - if gpc.is_initialized(ParallelMode.PIPELINE): - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - else: - pipeline_size = 1 - pipeline_rank = 0 - - rank = gpc.get_global_rank() - - parts = partition_uniform(depth, pipeline_size, - num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions - models = [] - for start, end in parts: - model_kwargs['first'] = start == 0 - model_kwargs['last'] = end == depth - model_kwargs['depth'] = end - start - chunk = PipelineBert1D(**model_kwargs).to(get_current_device()) - models.append(chunk) - logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}') - - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - - numel = 0 - for _, param in model.named_parameters(recurse=True): - numel += param.numel() - - if "checkpoint" in model_kwargs.keys(): - if model_kwargs["checkpoint"] is True: - if gpc.get_global_rank() == 0: - assert "checkpoint_path" in model_kwargs.keys(), "You have to specify a file path to use checkpoint loading" - assert os.path.exists(model_kwargs["checkpoint_path"]), "Checkpoint file not found" - load_checkpoint(model_kwargs["checkpoint_path"], model, **model_kwargs) - - logger.info(f'Rank{rank}/{pipeline_rank} model size in FP16 = {numel * 2 / 1e9} GB') - return model - -def bert_small(**kwargs): - model_kwargs = dict(hidden_size=768, depth=12, num_heads=12, **kwargs) - return _create_bert_pipeline_model(**model_kwargs) - -def bert_large(**kwargs): - model_kwargs = dict(hidden_size=1024, depth=24, num_heads=16, **kwargs) - return _create_bert_pipeline_model(**model_kwargs) - -def bert_xl(**kwargs): - model_kwargs = dict(hidden_size=1600, depth=48, num_heads=16, **kwargs) - return _create_bert_pipeline_model(**model_kwargs) - - -def bert_8B(**kwargs): - model_kwargs = dict(hidden_size=3072, depth=72, num_heads=24, **kwargs) - return _create_bert_pipeline_model(**model_kwargs) - - -def bert_175B(**kwargs): - model_kwargs = dict(hidden_size=12288, depth=96, num_heads=96, **kwargs) - return _create_bert_pipeline_model(**model_kwargs) \ No newline at end of file diff --git a/energon/model/gpt/__init__.py b/energon/model/gpt/__init__.py deleted file mode 100644 index d102207..0000000 --- a/energon/model/gpt/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .gpt import gpt2_small, gpt2_medium, gpt2_large, gpt2_xl, gpt2_8B, gpt3 -from .hf_gpt2 import hf_gpt2 - - -__all__ = ['gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt2_8B', 'gpt3', 'hf_gpt2'] diff --git a/energon/model/gpt/gpt.py b/energon/model/gpt/gpt.py deleted file mode 100644 index 69ae90c..0000000 --- a/energon/model/gpt/gpt.py +++ /dev/null @@ -1,470 +0,0 @@ -import math -from typing import Callable -import os - -import torch -from torch import nn as nn, Tensor, dtype - -from energon.context import ParallelMode -from energon.core import global_context as gpc -from 
energon.logging import get_dist_logger -from energon.nn.layer.utils import divide, ACT2FN -from energon.nn import Linear1D_Col, Linear1D_Row, Classifier1D -from energon.nn import LayerNorm1D -from energon.nn import VocabParallelEmbedding1D -from energon.utils import get_current_device, is_using_pp -from energon.utils.checkpointing import load_checkpoint - -__all__ = [ - 'GPTEmbedding1D' - 'GPTMLP1D', - 'GPTSelfAttention1D', - 'GPTTransformerLayer1D' -] - - -class GPTEmbedding1D(nn.Module): - - def __init__(self, - embedding_dim: int, - vocab_size: int, - max_position_embeddings: int, - num_tokentypes: int = 0, - padding_idx: int = 0, - dtype: dtype = None) -> None: - super().__init__() - self.word_embeddings = VocabParallelEmbedding1D(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype, skip_tp=True) - self.position_embeddings = VocabParallelEmbedding1D(max_position_embeddings, embedding_dim, dtype=dtype, skip_tp=True) - if num_tokentypes > 0: - self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, embedding_dim, dtype=dtype, skip_tp=True) - else: - self.tokentype_embeddings = None - - @property - def word_embedding_weight(self): - return self.word_embeddings.weight - - def forward(self, input_ids, position_ids=None, tokentype_ids=None): - # padding condition, not for variable length - seq_length = input_ids.size(1) - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0) - x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) - if self.tokentype_embeddings is not None and tokentype_ids is not None: - x = x + self.tokentype_embeddings(tokentype_ids) - - return x - - -class GPTSelfAttention1D(nn.Module): - - def __init__(self, - dim: int, - num_heads: int, - bias: bool = True, - fuse_scale_mask_softmax: bool = False, - dtype: dtype = None) -> None: - super().__init__() - self.fuse_scale_mask_softmax = fuse_scale_mask_softmax # TODO - self.attention_head_size = divide(dim, num_heads) - self.query_key_value = Linear1D_Col(dim, 3 * dim, bias=bias, dtype=dtype) - - if fuse_scale_mask_softmax: - from colossalai.kernel import FusedScaleMaskSoftmax - from colossalai.kernel.cuda_native.scaled_softmax import \ - AttnMaskType - self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True, - input_in_bf16=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=True, - mask_func=None, - softmax_in_fp32=True, - scale=math.sqrt(self.attention_head_size)) - else: - self.softmax = nn.Softmax(dim=-1) - self.dense = Linear1D_Row(dim, dim, bias=True, dtype=dtype, parallel_input=True) - - def forward(self, x, attention_mask=None): - qkv = self.query_key_value(x) - - # print(f'qkv {qkv.shape}') - - all_head_size = qkv.shape[-1] // 3 - num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads - - new_qkv_shape = qkv.shape[:-1] + \ - (num_attention_heads, 3 * self.attention_head_size) - qkv = qkv.view(new_qkv_shape) - qkv = qkv.permute((0, 2, 1, 3)) - q, k, v = torch.chunk(qkv, 3, dim=-1) - # print(f'qkv {qkv.shape}') # 6 40 128 - - x = torch.matmul(q, k.transpose(-1, -2)) - - if self.fuse_scale_mask_softmax: - x = self.softmax(x, attention_mask) - else: - x = x / math.sqrt(self.attention_head_size) - # causal mask - q_len, k_len = q.size(-2), k.size(-2) - causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8, - device=get_current_device())).view(1, 1, q_len, k_len).bool() - x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, 
device=get_current_device())) - if attention_mask is not None: - x = x + attention_mask - x = self.softmax(x) - - x = torch.matmul(x, v) - x = x.transpose(1, 2) - new_context_layer_shape = x.size()[:-2] + (all_head_size,) - x = x.reshape(new_context_layer_shape) - - x = self.dense(x) - - return x - - -class GPTMLP1D(nn.Module): - - def __init__(self, - dim: int, - mlp_ratio: float, - activation: Callable, - dtype: dtype = None, - bias: bool = True): - super().__init__() - intermediate_dim = int(dim * mlp_ratio) - self.dense_1 = Linear1D_Col(dim, intermediate_dim, bias=bias, dtype=dtype, gather_output=False) - self.activation = activation - self.dense_2 = Linear1D_Row(intermediate_dim, dim, bias=bias, dtype=dtype, parallel_input=True) - - def forward(self, x): - x = self.dense_1(x) - x = self.activation(x) - x = self.dense_2(x) - return x - - -class GPTBlock1D(nn.Module): - - def __init__(self, - dim: int, - num_heads: int, - mlp_ratio: float, - activation: Callable, - layernorm_epsilon: float = 1e-5, - dtype: dtype = None, - bias: bool = True, - apply_post_layernorm: bool = False, - fuse_scale_mask_softmax: bool = False): - super().__init__() - - self.apply_post_layernorm = apply_post_layernorm - # self.norm1 = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - self.norm1 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) - self.attn = GPTSelfAttention1D(dim=dim, - num_heads=num_heads, - bias=bias, - fuse_scale_mask_softmax=fuse_scale_mask_softmax, - dtype=dtype) - - # self.norm2 = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - self.norm2 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) - self.mlp = GPTMLP1D(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dtype=dtype, bias=bias) - - def forward(self, x, attention_mask=None): - if not self.apply_post_layernorm: - residual = x - x = self.norm1(x) - if self.apply_post_layernorm: - residual = x - x = residual + self.attn(x, attention_mask) - - if not self.apply_post_layernorm: - residual = x - x = self.norm2(x) - if self.apply_post_layernorm: - residual = x - x = residual + self.mlp(x) - - return x, attention_mask - - -class GPTLMHead1D(nn.Module): - - def __init__(self, - dim: int, - vocab_size: int, - word_embeding_weight: nn.Parameter = None, - bias: bool = False, - dtype: dtype = None) -> None: - super().__init__() - self.dense = Classifier1D(dim, vocab_size, word_embeding_weight, bias=bias, dtype=dtype) - - @property - def weight(self): - return self.dense.weight - - def forward(self, x): - x = self.dense(x) - return x - - -class GPT1D(nn.Module): - - def __init__(self, - vocab_size: int = 50304, - max_position_embeddings: int = 1024, - dim: int = 768, - num_heads: int = 12, - depth: int = 12, - mlp_ratio: float = 4.0, - layernorm_epsilon: float = 1e-5, - activation: Callable = nn.functional.gelu, - padding_idx: int = 0, - dtype: dtype = None, - bias: bool = True, - apply_post_layernorm: bool = False, - fuse_scale_mask_softmax: bool = False) -> None: - super().__init__() - self.embed = GPTEmbedding1D(embedding_dim=dim, - vocab_size=vocab_size, - max_position_embeddings=max_position_embeddings, - padding_idx=padding_idx, - dtype=dtype) - self.blocks = nn.ModuleList() - self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - for id_ in range(depth): - self.blocks.register_module("blk_{}".format(id_ + self.pp_rank * depth), - GPTBlock1D( - dim=dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - activation=activation, - layernorm_epsilon=layernorm_epsilon, - 
dtype=dtype, - bias=bias, - apply_post_layernorm=apply_post_layernorm, - fuse_scale_mask_softmax=fuse_scale_mask_softmax, - ) - ) - # self.blocks = nn.ModuleList([ - # GPTBlock1D( - # dim=dim, - # num_heads=num_heads, - # mlp_ratio=mlp_ratio, - # activation=activation, - # layernorm_epsilon=layernorm_epsilon, - # dtype=dtype, - # bias=bias, - # apply_post_layernorm=apply_post_layernorm, - # fuse_scale_mask_softmax=fuse_scale_mask_softmax, - # ) for _ in range(depth) - # ]) - # self.norm = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - self.norm = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) - self.head = GPTLMHead1D(dim=dim, - vocab_size=vocab_size, - word_embeding_weight=self.embed.word_embedding_weight, - dtype=dtype) - - def forward(self, input_ids, attention_mask=None): - x = self.embed(input_ids) - - if attention_mask is not None: - batch_size = input_ids.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - attention_mask = attention_mask.to(dtype=x.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 - - for block in self.blocks: - x, attention_mask = block(x, attention_mask) - - x = self.head(self.norm(x)) - - return x - - -class PipelineGPT1D(nn.Module): - - def __init__(self, - vocab_size: int = 50257, - max_position_embeddings: int = 1024, - dim: int = 768, - num_heads: int = 12, - depth: int = 12, - mlp_ratio: float = 4.0, - layernorm_epsilon: float = 1e-5, - activation: Callable = nn.functional.gelu, - padding_idx: int = 0, - dtype: dtype = None, - bias: bool = True, - apply_post_layernorm: bool = False, - fuse_scale_mask_softmax: bool = False, - first: bool = False, - last: bool = False, **kwargs): - super().__init__() - self.first = first - self.last = last - if first: - self.embed = GPTEmbedding1D(embedding_dim=dim, - vocab_size=vocab_size, - max_position_embeddings=max_position_embeddings, - padding_idx=padding_idx, - dtype=dtype) - self.blocks = nn.ModuleList() - self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if is_using_pp() else 0 - for id_ in range(depth): - self.blocks.register_module("blk_{}".format(id_ + self.pp_rank * depth), - GPTBlock1D( - dim=dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - activation=activation, - layernorm_epsilon=layernorm_epsilon, - dtype=dtype, - bias=bias, - apply_post_layernorm=apply_post_layernorm, - fuse_scale_mask_softmax=fuse_scale_mask_softmax, - ) - ) - # self.blocks = nn.ModuleList([ - # GPTBlock1D( - # dim=dim, - # num_heads=num_heads, - # mlp_ratio=mlp_ratio, - # activation=activation, - # layernorm_epsilon=layernorm_epsilon, - # dtype=dtype, - # bias=bias, - # apply_post_layernorm=apply_post_layernorm, - # fuse_scale_mask_softmax=fuse_scale_mask_softmax, - # ) for _ in range(depth) - # ]) - if self.last: - # self.norm = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - self.norm = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) - self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size, - dtype=dtype) # word_embeeding_weight=self.embed.word_embedding_weight not in the same process - - def forward(self, hidden_states=None, input_ids=None, attention_mask=None): - if self.first: - hidden_states = self.embed(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. 
- # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # Adapted from huggingface - if attention_mask is not None: - if self.first: - batch_size = input_ids.shape[0] - else: - batch_size = hidden_states.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 - - for block in self.blocks: - hidden_states, attention_mask = block(hidden_states, attention_mask) - - if self.last: - hidden_states = self.head(self.norm(hidden_states)) - - return hidden_states - - -def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert num_items % num_chunks == 0, \ - "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" - - logger = get_dist_logger() - parts = [[] for _ in range(pipeline_parallel_size)] # 4 - partition_items = num_items // num_chunks # 96 // 2 - for idx in range(num_chunks): - base_idx = idx * partition_items - chunk_size = partition_items // pipeline_parallel_size - left = pipeline_parallel_size - partition_items % pipeline_parallel_size - if chunk_size == 0: - logger.warning("Some nodes in Pipeline have no requests") - - for p in range(pipeline_parallel_size): - st = base_idx - base_idx += chunk_size + (p >= left) - parts[p].append((st, base_idx)) - - return parts - - -def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs): - logger = get_dist_logger() - pipeline_size = 0 - pipeline_rank = 0 - if gpc.is_initialized(ParallelMode.PIPELINE): - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - else: - pipeline_size = 1 - pipeline_rank = 0 - - rank = gpc.get_global_rank() - - parts = partition_uniform(depth, pipeline_size, - num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions - models = [] - for start, end in parts: - model_kwargs['first'] = start == 0 - model_kwargs['last'] = end == depth - model_kwargs['depth'] = end - start - chunk = PipelineGPT1D(**model_kwargs).to(get_current_device()) - models.append(chunk) - logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}') - - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - - numel = 0 - for _, param in model.named_parameters(recurse=True): - numel += param.numel() - if "checkpoint" in model_kwargs.keys(): - if model_kwargs["checkpoint"] is True: - if gpc.get_global_rank() == 0: - assert "checkpoint_path" in model_kwargs.keys(), "You have to specify a file path to use checkpoint loading" - print(model_kwargs["checkpoint_path"]) - assert os.path.exists(model_kwargs["checkpoint_path"]), "Checkpoint file not found" - load_checkpoint(model_kwargs["checkpoint_path"], model, **model_kwargs) - logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB') - return model - - -def gpt2_small(**kwargs): - model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs) - return _create_gpt_pipeline_model(**model_kwargs) - - -def gpt2_medium(**kwargs): - model_kwargs = dict(dim=1024, depth=24, num_heads=8, **kwargs) - return _create_gpt_pipeline_model(**model_kwargs) - - -def gpt2_large(**kwargs): - model_kwargs = dict(dim=1536, depth=36, num_heads=12, **kwargs) - return _create_gpt_pipeline_model(**model_kwargs) - - -def 
gpt2_xl(**kwargs): - model_kwargs = dict(dim=1600, depth=48, num_heads=16, **kwargs) - return _create_gpt_pipeline_model(**model_kwargs) - - -def gpt2_8B(**kwargs): - model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs) - return _create_gpt_pipeline_model(**model_kwargs) - - -def gpt3(**kwargs): - model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs) - return _create_gpt_pipeline_model(**model_kwargs) diff --git a/energon/model/gpt/hf_gpt2.py b/energon/model/gpt/hf_gpt2.py deleted file mode 100644 index 44a31bf..0000000 --- a/energon/model/gpt/hf_gpt2.py +++ /dev/null @@ -1,480 +0,0 @@ -import os -import math -import torch -import random -from torch import nn as nn, Tensor, dtype -from typing import Callable - -from energon.context import ParallelMode -from energon.core import global_context as gpc -from energon.logging import get_dist_logger -from energon.nn.layer.utils import divide, ACT2FN -from energon.nn import Linear1D_Col, Linear1D_Row, Classifier1D -from energon.nn import LayerNorm1D -from energon.nn import VocabParallelEmbedding1D -from energon.utils import get_current_device, is_using_pp -from energon.utils.checkpointing_hf_gpt2 import load_checkpoint - -__all__ = [ - 'GPTEmbedding1D' - 'GPTMLP1D', - 'GPTSelfAttention1D', - 'GPTTransformerLayer1D' -] - - -class GPTEmbedding1D(nn.Module): - - def __init__(self, - embedding_dim: int, - vocab_size: int, - max_position_embeddings: int, - num_tokentypes: int = 0, - padding_idx: int = 0, - dtype: dtype = None) -> None: - super().__init__() - self.word_embeddings = VocabParallelEmbedding1D(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype, skip_tp=True) - self.position_embeddings = VocabParallelEmbedding1D(max_position_embeddings, embedding_dim, dtype=dtype, skip_tp=True) - if num_tokentypes > 0: - self.tokentype_embeddings = VocabParallelEmbedding1D(num_tokentypes, embedding_dim, dtype=dtype, skip_tp=True) - else: - self.tokentype_embeddings = None - - @property - def word_embedding_weight(self): - return self.word_embeddings.weight - - def forward(self, input_ids, position_ids=None, tokentype_ids=None): - # padding condition, not for variable length - seq_length = input_ids.size(1) - if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0) - x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids) - if self.tokentype_embeddings is not None and tokentype_ids is not None: - x = x + self.tokentype_embeddings(tokentype_ids) - # print("wte: {}".format(self.word_embeddings(input_ids))) - # print("wpe: {}".format(self.position_embeddings(position_ids))) - # print("hidden_states: {}".format(x)) - return x - - -class GPTSelfAttention1D(nn.Module): - - def __init__(self, - dim: int, - num_heads: int, - bias: bool = True, - fuse_scale_mask_softmax: bool = False, - dtype: dtype = None) -> None: - super().__init__() - self.fuse_scale_mask_softmax = fuse_scale_mask_softmax # TODO - self.attention_head_size = divide(dim, num_heads) - # self.query_key_value = Linear1D_Col(dim, 3 * dim, bias=bias, dtype=dtype) - self.query_ = Linear1D_Col(dim, dim, bias=bias, dtype=dtype) - self.key_ = Linear1D_Col(dim, dim, bias=bias, dtype=dtype) - self.value_ = Linear1D_Col(dim, dim, bias=bias, dtype=dtype) - if fuse_scale_mask_softmax: - from colossalai.kernel import FusedScaleMaskSoftmax - from colossalai.kernel.cuda_native.scaled_softmax import \ - AttnMaskType - self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True, - 
input_in_bf16=False, - attn_mask_type=AttnMaskType.causal, - scaled_masked_softmax_fusion=True, - mask_func=None, - softmax_in_fp32=True, - scale=math.sqrt(self.attention_head_size)) - else: - self.softmax = nn.Softmax(dim=-1) - self.dense = Linear1D_Row(dim, dim, bias=True, dtype=dtype, parallel_input=True) - - - def _split_heads(self, tensor, num_heads, attn_head_size): - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - - - def forward(self, x, attention_mask=None): - # print("x: {}".format(x.shape)) - # qkv = self.query_key_value(x) - - # print(f'qkv {qkv.shape}') - q = self.query_(x) - k = self.key_(x) - v = self.value_(x) - all_head_size = q.shape[-1] - num_attention_heads = divide(all_head_size, self.attention_head_size) # num_heads - # print(self.attention_head_size) - # new_qkv_shape = qkv.shape[:-1] + \ - # (num_attention_heads, 3 * self.attention_head_size) - # qkv = qkv.view(new_qkv_shape) - # qkv = qkv.permute((0, 2, 1, 3)) - # print("{} qkv: {} {}".format(gpc.get_global_rank(), qkv.shape, qkv)) - # # q, k, v = torch.chunk(qkv, 3, dim=-1) - # q, k, v = qkv.split(all_head_size, dim=2) - # print("{} q: {} {}".format(gpc.get_global_rank(), q.shape, q)) - # print("{} k: {} {}".format(gpc.get_global_rank(), k.shape, k)) - # print("{} v: {} {}".format(gpc.get_global_rank(), v.shape, v)) - q = self._split_heads(q, num_attention_heads, self.attention_head_size) - k = self._split_heads(k, num_attention_heads, self.attention_head_size) - v = self._split_heads(v, num_attention_heads, self.attention_head_size) - # print(f'qkv {qkv.shape}') # 6 40 128 - x = torch.matmul(q, k.transpose(-1, -2)) - - if self.fuse_scale_mask_softmax: - x = self.softmax(x, attention_mask) - else: - x = x / math.sqrt(self.attention_head_size) - # causal mask - q_len, k_len = q.size(-2), k.size(-2) - causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8, - device=get_current_device())).view(1, 1, q_len, k_len).bool() - x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device())) - if attention_mask is not None: - x = x + attention_mask - x = self.softmax(x) - - x = torch.matmul(x, v) - x = x.transpose(1, 2) - new_context_layer_shape = x.size()[:-2] + (all_head_size,) - x = x.reshape(new_context_layer_shape) - # print("{} before dense: {} {}".format(gpc.get_global_rank(), x.shape, x)) - x = self.dense(x) - # print("after mlp: {}".format(x)) - - return x - - -class GPTMLP1D(nn.Module): - - def __init__(self, - dim: int, - mlp_ratio: float, - activation: Callable, - dtype: dtype = None, - bias: bool = True): - super().__init__() - intermediate_dim = int(dim * mlp_ratio) - self.dense_1 = Linear1D_Col(dim, intermediate_dim, bias=bias, dtype=dtype, gather_output=False) - self.activation = activation - self.dense_2 = Linear1D_Row(intermediate_dim, dim, bias=bias, dtype=dtype, parallel_input=True) - - def forward(self, x): - x = self.dense_1(x) - x = self.activation(x) - x = self.dense_2(x) - return x - - -class GPTBlock1D(nn.Module): - - def __init__(self, - dim: int, - num_heads: int, - mlp_ratio: float, - activation: Callable, - layernorm_epsilon: float = 1e-5, - dtype: dtype = None, - bias: bool = True, - apply_post_layernorm: bool = False, - fuse_scale_mask_softmax: bool = False): - super().__init__() - - self.apply_post_layernorm = apply_post_layernorm - # self.norm1 = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - 
self.norm1 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) - self.attn = GPTSelfAttention1D(dim=dim, - num_heads=num_heads, - bias=bias, - fuse_scale_mask_softmax=fuse_scale_mask_softmax, - dtype=dtype) - - # self.norm2 = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - self.norm2 = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) - self.mlp = GPTMLP1D(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dtype=dtype, bias=bias) - - def forward(self, x, attention_mask=None): - if not self.apply_post_layernorm: - residual = x - x = self.norm1(x) - # print("{} after norm1: {}".format(gpc.get_global_rank(), x)) - if self.apply_post_layernorm: - residual = x - x = residual + self.attn(x, attention_mask) - - if not self.apply_post_layernorm: - residual = x - # print("{} after attn: {}".format(gpc.get_global_rank(), x)) - x = self.norm2(x) - # print("{} after norm2: {}".format(gpc.get_global_rank(), x)) - if self.apply_post_layernorm: - residual = x - x = residual + self.mlp(x) - # print("{} after mlp: {}".format(gpc.get_global_rank(), x)) - return x, attention_mask - - -class GPTLMHead1D(nn.Module): - - def __init__(self, - dim: int, - vocab_size: int, - word_embeding_weight: nn.Parameter = None, - bias: bool = False, - dtype: dtype = None) -> None: - super().__init__() - self.dense = Classifier1D(dim, vocab_size, word_embeding_weight, bias=bias, dtype=dtype) - - @property - def weight(self): - return self.dense.weight - - def forward(self, x): - x = self.dense(x) - return x - - -class GPT1D(nn.Module): - - def __init__(self, - vocab_size: int = 50304, - max_position_embeddings: int = 1024, - dim: int = 768, - num_heads: int = 12, - depth: int = 12, - mlp_ratio: float = 4.0, - layernorm_epsilon: float = 1e-5, - activation: Callable = nn.functional.gelu, - padding_idx: int = 0, - dtype: dtype = None, - bias: bool = True, - apply_post_layernorm: bool = False, - fuse_scale_mask_softmax: bool = False) -> None: - super().__init__() - self.embed = GPTEmbedding1D(embedding_dim=dim, - vocab_size=vocab_size, - max_position_embeddings=max_position_embeddings, - padding_idx=padding_idx, - dtype=dtype) - self.blocks = nn.ModuleList() - self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - for id_ in range(depth): - self.blocks.register_module("blk_{}".format(id_ + self.pp_rank * depth), - GPTBlock1D( - dim=dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - activation=activation, - layernorm_epsilon=layernorm_epsilon, - dtype=dtype, - bias=bias, - apply_post_layernorm=apply_post_layernorm, - fuse_scale_mask_softmax=fuse_scale_mask_softmax, - ) - ) - # self.blocks = nn.ModuleList([ - # GPTBlock1D( - # dim=dim, - # num_heads=num_heads, - # mlp_ratio=mlp_ratio, - # activation=activation, - # layernorm_epsilon=layernorm_epsilon, - # dtype=dtype, - # bias=bias, - # apply_post_layernorm=apply_post_layernorm, - # fuse_scale_mask_softmax=fuse_scale_mask_softmax, - # ) for _ in range(depth) - # ]) - # self.norm = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - self.norm = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) - self.head = GPTLMHead1D(dim=dim, - vocab_size=vocab_size, - word_embeding_weight=self.embed.word_embedding_weight, - dtype=dtype) - - def forward(self, input_ids, attention_mask=None): - x = self.embed(input_ids) - - if attention_mask is not None: - batch_size = input_ids.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - attention_mask = 
attention_mask.to(dtype=x.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 - - for block in self.blocks: - x, attention_mask = block(x, attention_mask) - - x = self.head(self.norm(x)) - - return x - -def select_top_k(predictions, k=5): - predicted_index = random.choice(predictions[0, -1, :].sort(descending=True)[1][:10]) #.item() - return predicted_index - -class PipelineGPT1D(nn.Module): - - def __init__(self, - vocab_size: int = 50257, - max_position_embeddings: int = 1024, - dim: int = 768, - num_heads: int = 12, - depth: int = 12, - mlp_ratio: float = 4.0, - layernorm_epsilon: float = 1e-5, - activation: Callable = nn.functional.gelu, - padding_idx: int = 0, - dtype: dtype = None, - bias: bool = True, - apply_post_layernorm: bool = False, - fuse_scale_mask_softmax: bool = False, - first: bool = False, - last: bool = False, - **kwargs): - super().__init__() - self.first = first - self.last = last - if first: - self.embed = GPTEmbedding1D(embedding_dim=dim, - vocab_size=vocab_size, - max_position_embeddings=max_position_embeddings, - padding_idx=padding_idx, - dtype=dtype) - self.blocks = nn.ModuleList() - self.pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) if is_using_pp() else 0 - for id_ in range(depth): - self.blocks.register_module("blk_{}".format(id_ + self.pp_rank * depth), - GPTBlock1D( - dim=dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - activation=activation, - layernorm_epsilon=layernorm_epsilon, - dtype=dtype, - bias=bias, - apply_post_layernorm=apply_post_layernorm, - fuse_scale_mask_softmax=fuse_scale_mask_softmax, - ) - ) - # self.blocks = nn.ModuleList([ - # GPTBlock1D( - # dim=dim, - # num_heads=num_heads, - # mlp_ratio=mlp_ratio, - # activation=activation, - # layernorm_epsilon=layernorm_epsilon, - # dtype=dtype, - # bias=bias, - # apply_post_layernorm=apply_post_layernorm, - # fuse_scale_mask_softmax=fuse_scale_mask_softmax, - # ) for _ in range(depth) - # ]) - if self.last: - # self.norm = nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype) - self.norm = LayerNorm1D(normalized_shape=dim, eps=layernorm_epsilon) - self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size, - dtype=dtype) # word_embeeding_weight=self.embed.word_embedding_weight not in the same process - - def forward(self, hidden_states=None, input_ids=None, attention_mask=None): - topk = 5 # TODO: add as a parameter - if self.first: - hidden_states = self.embed(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. 
- # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # Adapted from huggingface - if attention_mask is not None: - if self.first: - batch_size = input_ids.shape[0] - else: - batch_size = hidden_states.shape[0] - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 - # print("processed attention mask: {}".format(attention_mask)) - clk_cnt = 0 - for block in self.blocks: - # print("="*30) - # print("processing blk {}".format(clk_cnt)) - clk_cnt += 1 - hidden_states, attention_mask = block(hidden_states, attention_mask) - - if self.last: - hidden_states = self.head(self.norm(hidden_states)) - hidden_states = select_top_k(hidden_states, k=topk) - - return hidden_states - - -def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert num_items % num_chunks == 0, \ - "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" - - logger = get_dist_logger() - parts = [[] for _ in range(pipeline_parallel_size)] # 4 - partition_items = num_items // num_chunks # 96 // 2 - for idx in range(num_chunks): - base_idx = idx * partition_items - chunk_size = partition_items // pipeline_parallel_size - left = pipeline_parallel_size - partition_items % pipeline_parallel_size - if chunk_size == 0: - logger.warning("Some nodes in Pipeline have no requests") - - for p in range(pipeline_parallel_size): - st = base_idx - base_idx += chunk_size + (p >= left) - parts[p].append((st, base_idx)) - - return parts - - -def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs): - logger = get_dist_logger() - pipeline_size = 0 - pipeline_rank = 0 - if gpc.is_initialized(ParallelMode.PIPELINE): - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - else: - pipeline_size = 1 - pipeline_rank = 0 - - rank = gpc.get_global_rank() - - parts = partition_uniform(depth, pipeline_size, - num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions - models = [] - for start, end in parts: - model_kwargs['first'] = start == 0 - model_kwargs['last'] = end == depth - model_kwargs['depth'] = end - start - chunk = PipelineGPT1D(**model_kwargs).to(get_current_device()) - models.append(chunk) - logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}') - - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - - numel = 0 - for _, param in model.named_parameters(recurse=True): - numel += param.numel() - if "checkpoint" in model_kwargs.keys(): - if model_kwargs["checkpoint"] is True: - if gpc.get_global_rank() == 0: - assert "checkpoint_path" in model_kwargs.keys(), "You have to specify a file path to use checkpoint loading" - print(model_kwargs["checkpoint_path"]) - assert os.path.exists(model_kwargs["checkpoint_path"]), "Checkpoint file not found" - load_checkpoint(model_kwargs["checkpoint_path"], model, **model_kwargs) - logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB') - return model - - -def hf_gpt2(**kwargs): - model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs) - return _create_gpt_pipeline_model(**model_kwargs)
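Note on the new example: examples/gpt/gpt_config.py only defines module-level settings, and this patch does not show the launcher that consumes them. As a rough sketch of how those fields map onto launch_engine from gpt_server.py (assuming the config module is importable as gpt_config and that calling launch_engine directly is acceptable; max_batch_size and the half -> dtype mapping below are assumptions, not part of the config file):

import torch

import gpt_config as cfg   # the config module added by this patch

cfg.engine_server(                     # engine_server is launch_engine
    cfg.model_class,                   # gpt2_large
    cfg.model_type,                    # "gpt"
    max_batch_size=32,                 # assumed; not defined in gpt_config.py
    tp_init_size=cfg.tp_init_size,     # tensor-parallel world size
    pp_init_size=cfg.pp_init_size,     # pipeline-parallel world size
    host=cfg.host,
    port=cfg.port,
    dtype=torch.half if cfg.half else torch.float,   # assumed meaning of the 'half' flag
    server_host=cfg.server_host,
    server_port=cfg.server_port,
    log_level=cfg.log_level,
)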
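Once the server from gpt_server.py is running, its three GET endpoints can be exercised with any HTTP client. A minimal sketch using the third-party requests package, assuming the server_host/server_port values from gpt_config.py:

import requests

base = "http://127.0.0.1:8010"          # server_host / server_port from gpt_config.py

print(requests.get(f"{base}/").json())                    # liveness check
print(requests.get(f"{base}/model_with_padding").json())  # runs one fixed 32x512 benchmark batch
try:
    requests.get(f"{base}/shutdown", timeout=5)           # stops the engine and the uvicorn server
except requests.exceptions.RequestException:
    pass                                                  # the server may drop the connection while exiting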
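partition_uniform, repeated verbatim in bert.py, gpt.py and hf_gpt2.py above, decides which transformer layers each pipeline stage builds. A standalone copy (logger calls dropped, behaviour otherwise unchanged) with two worked examples:

def partition_uniform(num_items, pipeline_parallel_size, num_chunks):
    assert num_items % num_chunks == 0, \
        "Layer length should be divided by the number of chunks"
    parts = [[] for _ in range(pipeline_parallel_size)]
    partition_items = num_items // num_chunks
    for idx in range(num_chunks):
        base_idx = idx * partition_items
        chunk_size = partition_items // pipeline_parallel_size
        # the first 'left' stages get chunk_size layers, the remaining stages get chunk_size + 1
        left = pipeline_parallel_size - partition_items % pipeline_parallel_size
        for p in range(pipeline_parallel_size):
            st = base_idx
            base_idx += chunk_size + (p >= left)
            parts[p].append((st, base_idx))
    return parts

print(partition_uniform(24, 2, 1))   # [[(0, 12)], [(12, 24)]]
print(partition_uniform(10, 4, 1))   # [[(0, 2)], [(2, 4)], [(4, 7)], [(7, 10)]]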
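The deleted GPTSelfAttention1D combines an additive padding mask (built from the 2D attention_mask, broadcast to [batch, 1, 1, seq]) with a lower-triangular causal mask before the softmax. A shape-only sketch of that masking scheme in plain PyTorch, with arbitrary sizes and no tensor-parallel layers:

import math
import torch

batch, heads, seq, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)   # [batch, heads, seq, seq]

pad_mask = torch.ones(batch, seq)                         # 1 = real token, 0 = padding
additive = (1.0 - pad_mask.view(batch, 1, 1, seq)) * -10000.0

causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool)).view(1, 1, seq, seq)
scores = torch.where(causal, scores, torch.tensor(-1e4))  # block attention to future positions
scores = scores + additive                                # block attention to padding positions

probs = torch.softmax(scores, dim=-1)                     # rows over masked positions stay near zero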