Add deepseek autotp (#6937)
This PR adds AutoTP support for DeepSeek, covering Multi-Head Latent Attention (MLA) and MoE.

For MLA TP, we need to skip the two low-rank projection layers ("q_a_proj" and "kv_a_proj_with_mqa").

For DeepSeek MoE, tp_parse resolves the MoE layer name as layer_idx.down_proj, which makes it hard to add a dedicated policy, so we route down_proj layers to all_reduce_linears by default. A simplified sketch of the resulting routing logic follows the file summary below.
Yejing-Lai authored Jan 9, 2025
1 parent 53fb579 commit 45fce45
Showing 2 changed files with 16 additions and 6 deletions.
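A minimal sketch of the routing described in the commit message, using a hypothetical `route_layer` helper and made-up DeepSeek-V2 layer names; the real decision lives in `_replace` in auto_tp.py below and also handles Yuan, Arctic, and GLM special cases:

```python
# Illustration only: name-based routing that mirrors the auto_tp.py change.
def route_layer(name, all_reduce_linears=()):
    # MLA low-rank projections and MoE gates keep their full weights (no slicing).
    if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate":
        return "skip"
    # down_proj outputs are row-parallel, so they join the all-reduce group by default.
    if "down_proj" in name or name in all_reduce_linears:
        return "all_reduce"
    # Everything else goes through the usual column/row slicing path.
    return "shard"

for n in ["self_attn.q_a_proj", "self_attn.kv_a_proj_with_mqa",
          "mlp.experts.0.down_proj", "mlp.experts.0.gate_proj"]:
    print(n, "->", route_layer(n))
# self_attn.q_a_proj -> skip
# self_attn.kv_a_proj_with_mqa -> skip
# mlp.experts.0.down_proj -> all_reduce
# mlp.experts.0.gate_proj -> shard
```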
15 changes: 10 additions & 5 deletions deepspeed/module_inject/auto_tp.py
@@ -134,7 +134,8 @@ def is_load_module(module):
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm"
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding", "Phi3LongRoPEScaledRotaryEmbedding", "Qwen2RMSNorm",
"DeepseekV2RMSNorm", "DeepseekV2YarnRotaryEmbedding", "MoEGate"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

@@ -332,9 +333,9 @@ def _replace(self, child, name, conv_linear_layer):
return
weight_shape = child.weight.shape
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
- # For mixtral-7x8b, need to skip MoE gate linear replace.
- if name == "block_sparse_moe.gate" or (('mlp.shared_expert_gate' == name or 'mlp.gate' == name)
- and 'qwen2_moe' in str(type(self.module))):
+ # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
+ if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or (
+ ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))):
return child
# For Yuan model
if 'Yuan' in str(self.module):
@@ -350,11 +351,15 @@ def _replace(self, child, name, conv_linear_layer):
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
+ # For MoE MLP model, e.g., deepseek and jamba
+ down_proj = False
+ if 'down_proj' in name:
+ down_proj = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
- if name in self.all_reduce_linears or arctic_w2_all_reduce_linear:
+ if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

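The first hunk above also extends `load_layer_names`, so DeepseekV2RMSNorm, DeepseekV2YarnRotaryEmbedding and MoEGate modules take the plain load path (a full copy on every rank) instead of being tensor-sliced. A rough sketch of that check, reduced to the newly added names:

```python
# Hypothetical reduction of the is_load_module idea to the names added in this commit.
REPLICATED = {"DeepseekV2RMSNorm", "DeepseekV2YarnRotaryEmbedding", "MoEGate"}

def loads_whole_module(module) -> bool:
    # nn.Module._get_name() returns the class name, e.g. "DeepseekV2RMSNorm".
    return module._get_name() in REPLICATED
```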
7 changes: 6 additions & 1 deletion deepspeed/module_inject/tp_shard.py
@@ -42,11 +42,16 @@ def get_num_attention_heads():
def get_shard_size(total_size, mp_size, name=None, rank=None):
global num_kv_heads
last_linear = ["lm_head", "embed_out"]
+ # MoE MLP layer use near even division will get better perf.
+ moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"]
+ not_moe_mlp_layer = True
+ if name != None and any(s in str(name) for s in moe_mlp_layer):
+ not_moe_mlp_layer = False
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
if rank == None:
rank = dist.get_rank()
if num_kv_heads != None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
- name) not in last_linear:
+ name) not in last_linear and not_moe_mlp_layer:
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
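The tp_shard.py change keeps MoE MLP weights on the near-even split even when num_kv_heads is defined; per the comment in the diff, near-even division gives better performance for these layers. A worked example with 8 KV heads and mp_size=3, using a hypothetical near-even fallback (the real fallback branch of get_shard_size is not shown in this diff):

```python
# Simplified get_shard_size: only the branch logic relevant to this commit.
def shard_size(total_size, mp_size, rank, name, num_kv_heads=None):
    moe_mlp_layer = ["gate_proj", "up_proj", "down_proj", "w1", "w2", "w3"]
    is_moe_mlp = name is not None and any(s in name for s in moe_mlp_layer)
    if num_kv_heads and total_size % num_kv_heads == 0 and "mlp" not in name and not is_moe_mlp:
        # KV-head-aware split: whole heads per rank, so shards can be uneven.
        my_slices = num_kv_heads // mp_size + (1 if rank < num_kv_heads % mp_size else 0)
        return total_size * my_slices // num_kv_heads
    # Hypothetical near-even fallback for MoE MLP (and other) layers.
    return total_size // mp_size + (1 if rank < total_size % mp_size else 0)

# Attention projection: 8 KV heads over 3 ranks -> 3/3/2 heads -> shards 1536/1536/1024.
print([shard_size(4096, 3, r, "self_attn.k_proj", num_kv_heads=8) for r in range(3)])
# Expert down_proj takes the near-even path even though num_kv_heads is set -> 1366/1365/1365.
print([shard_size(4096, 3, r, "mlp.experts.0.down_proj", num_kv_heads=8) for r in range(3)])
```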
