[SLM] GPT2 multi-gpu support (mlc-ai#1647)
LeshengJin authored Jan 22, 2024
1 parent 8ce2358 commit a7c73fe
Showing 1 changed file with 49 additions and 9 deletions.
python/mlc_chat/model/gpt2/gpt2_model.py (49 additions, 9 deletions)
@@ -11,6 +11,7 @@

 from mlc_chat import op as op_ext
 from mlc_chat.support import logging
+from mlc_chat.support import tensor_parallel as tp
 from mlc_chat.support.config import ConfigBase
 from mlc_chat.support.style import bold

@@ -31,6 +32,7 @@ class GPT2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes
     prefill_chunk_size: int = 0
     scale_attn_by_inverse_layer_idx: bool = False
     tensor_parallel_shards: int = 1
+    head_dim: int = 0
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

     def __post_init__(self):
@@ -53,7 +55,9 @@ def __post_init__(self):
                 "`context_window_size`, `n_positions` or `max_sequence_length` is "
                 "provided in `config.json`."
             )
-        assert self.tensor_parallel_shards == 1, "GPT2 currently does not support sharding."
+        if self.head_dim == 0:
+            self.head_dim = self.n_embd // self.n_head
+        assert self.head_dim * self.n_head == self.n_embd
         if self.prefill_chunk_size == 0:
             logger.info(
                 "%s defaults to %s (%d)",
@@ -79,8 +83,8 @@ def __post_init__(self):
 class GPT2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: GPT2Config, layer_idx: int = None):
         self.embed_dim = config.n_embd
-        self.num_heads = config.n_head
-        self.head_dim = self.embed_dim // self.num_heads
+        self.num_heads = config.n_head // config.tensor_parallel_shards
+        self.head_dim = config.head_dim
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
         self.layer_idx = layer_idx

@@ -89,7 +93,7 @@ def __init__(self, config: GPT2Config, layer_idx: int = None):
             out_features=3 * self.num_heads * self.head_dim,
             bias=True,
         )
-        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+        self.c_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True)

         self.k_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim])
         self.v_cache = nn.KVCache(config.context_window_size, [self.num_heads, self.head_dim])
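
c_attn fuses the Q, K and V projections into a single weight with 3 * num_heads * head_dim output rows; the _set_tp hook further down shards that fused weight along dim 0 with segs=[q, k, v]. The numpy sketch below is only my illustration of what such a segmented split has to do so that each shard keeps the Q, K and V rows of its own heads; the shapes again assume GPT-2 small with 2 shards, not anything stated in the commit.

    import numpy as np

    # Illustrative sketch of a segmented dim-0 split; not the mlc_chat implementation.
    n_embd, head_dim, n_head, shards = 768, 64, 12, 2
    rows_per_seg = n_head // shards * head_dim            # Q/K/V rows owned by one shard

    w = np.random.randn(3 * n_head * head_dim, n_embd)    # fused QKV weight, dim 0 = output rows
    q, k, v = np.split(w, 3, axis=0)                      # segs=[q, k, v]: split the segments first

    shard_weights = [
        np.concatenate(
            [seg[i * rows_per_seg:(i + 1) * rows_per_seg] for seg in (q, k, v)], axis=0
        )
        for i in range(shards)
    ]
    # Each shard holds a (3 * local_heads * head_dim, n_embd) slice of c_attn,
    # matching the per-shard out_features of 3 * self.num_heads * self.head_dim.
    assert shard_weights[0].shape == (3 * rows_per_seg, n_embd)
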
@@ -124,7 +128,7 @@ def forward(
 class GPT2MLP(nn.Module):
     def __init__(self, config: GPT2Config):
         embed_dim = config.n_embd
-        intermediate_size = config.n_inner
+        intermediate_size = config.n_inner // config.tensor_parallel_shards
         self.c_fc = nn.Linear(embed_dim, intermediate_size)
         self.c_proj = nn.Linear(intermediate_size, embed_dim)
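
The MLP gets the usual Megatron-style treatment: c_fc's output dimension (dim 0 of its weight, per the _set_tp hints below) is split across shards, c_proj's input dimension (dim 1) is split to match, and the per-shard partial outputs are summed by the allreduce in GPT2Block.forward. A small numpy sketch of why that reconstruction is exact, with toy sizes chosen purely for illustration (bias and the elementwise GELU are omitted; the GELU acts per slice, so it does not change the argument):

    import numpy as np

    # Toy sizes for illustration only.
    n_embd, n_inner, shards = 8, 32, 2
    rng = np.random.default_rng(0)

    x = rng.standard_normal((1, n_embd))
    w_fc = rng.standard_normal((n_inner, n_embd))      # c_fc weight, laid out (out, in)
    w_proj = rng.standard_normal((n_embd, n_inner))    # c_proj weight, laid out (out, in)

    # Unsharded reference.
    ref = (x @ w_fc.T) @ w_proj.T

    # Shard c_fc along dim 0 and c_proj along dim 1, then sum the partial results
    # (the sum is what op.ccl_allreduce(..., "sum") performs across GPUs).
    chunk = n_inner // shards
    partial = sum(
        (x @ w_fc[i * chunk:(i + 1) * chunk].T) @ w_proj[:, i * chunk:(i + 1) * chunk].T
        for i in range(shards)
    )
    assert np.allclose(ref, partial)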

@@ -143,12 +147,45 @@ def __init__(self, config: GPT2Config, layer_idx: int = None):
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPT2MLP(config)

+        def _set_tp():
+            def _set(param, hint):
+                param.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = k = v = self.attn.num_heads * hd
+            _set(
+                self.attn.c_attn.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(
+                self.attn.c_attn.bias,
+                tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]),
+            )
+            _set(self.attn.c_proj.weight, tp.ShardSingleDim("_shard_attn_c_proj", dim=1))
+            _set(
+                self.mlp.c_fc.weight,
+                tp.ShardSingleDim("_shard_c_fc_weight", dim=0),
+            )
+            _set(self.mlp.c_fc.bias, tp.ShardSingleDim("_shard_c_fc_bias", dim=0))
+            _set(self.mlp.c_proj.weight, tp.ShardSingleDim("_shard_mlp_c_proj", dim=1))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
+
     def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var):
-        hidden_states = (
-            self.attn(self.ln_1(hidden_states), attention_mask, total_seq_len) + hidden_states
-        )
+        def _apply_residual(out, residual):
+            if self.tensor_parallel_shards > 1:
+                return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum")
+            return out + residual
+
+        with tp.shard_bias(self.attn.c_proj, self.tensor_parallel_shards), tp.shard_bias(
+            self.mlp.c_proj, self.tensor_parallel_shards
+        ):
+            hidden_states = _apply_residual(
+                self.attn(self.ln_1(hidden_states), attention_mask, total_seq_len), hidden_states
+            )
+            hidden_states = _apply_residual(self.mlp(self.ln_2(hidden_states)), hidden_states)
+
-        hidden_states = self.mlp(self.ln_2(hidden_states)) + hidden_states
         return hidden_states
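
Each shard's attention and MLP outputs are only partial sums, so adding the residual on every shard before the allreduce would count it tensor_parallel_shards times. _apply_residual therefore adds residual / tensor_parallel_shards per shard, and the sum-allreduce restores the residual exactly once; the shard_bias context managers presumably apply the same idea to the c_proj biases, though that reading of the helper is my assumption rather than something shown in this diff. A numpy sketch of the identity, with the allreduce modeled as a plain sum over shards:

    import numpy as np

    # Illustration only: model the sum-allreduce as an explicit Python sum.
    shards = 4
    rng = np.random.default_rng(0)

    residual = rng.standard_normal(8)
    partial_outputs = [rng.standard_normal(8) for _ in range(shards)]      # one partial result per GPU

    allreduced = sum(out + residual / shards for out in partial_outputs)   # what _apply_residual feeds in
    reference = sum(partial_outputs) + residual                            # single-device result
    assert np.allclose(allreduced, reference)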


@@ -159,8 +196,11 @@ def __init__(self, config: GPT2Config):
         self.wpe = nn.Embedding(config.context_window_size, config.n_embd)
         self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.tensor_parallel_shards = config.tensor_parallel_shards

     def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor):
+        if self.tensor_parallel_shards > 1:
+            inputs = op.ccl_broadcast_from_worker0(inputs)
         # Token Embeddings
         t_embd = self.wte(inputs)
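
At the model level, only worker 0 receives the tokenized input from the runtime, so with more than one shard it is broadcast before the embedding lookup and every worker then runs the identical forward pass on its own weight slices. A schematic sketch of that control flow; ccl_broadcast_from_worker0_sim is a hypothetical stand-in for the real collective, not an mlc_chat API:

    import numpy as np

    def ccl_broadcast_from_worker0_sim(per_worker_inputs):
        # Hypothetical stand-in: after the collective, every worker sees worker 0's tensor.
        return [per_worker_inputs[0] for _ in per_worker_inputs]

    token_ids = np.array([15496, 995])                     # tokenized prompt, held by worker 0
    per_worker = [token_ids, np.zeros_like(token_ids)]     # worker 1 starts with placeholder ids
    per_worker = ccl_broadcast_from_worker0_sim(per_worker)
    assert all((ids == token_ids).all() for ids in per_worker)
    # Every worker now embeds the same tokens with its own shard of the weights.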

