[WIP] add deepseek-v3 #35926
base: main
Conversation
Hi @bzantium, this looks great so far! We'll need tests added for the model + a green CI, and then feel free to ping me to assign a reviewer, or if you have any problems with the port.
Ultra kudos! It's super nice. Mostly missing tests; here you can use a similar approach to the gemma2 tests, which use inheritance!
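For reference, here is a rough sketch of what an inherited test file could look like; the class names, import paths, and tiny config values are assumptions to show the pattern, not the final layout:

```python
# Hypothetical tests/models/deepseek_v3/test_modeling_deepseek_v3.py, reusing the
# Llama test suite through inheritance the same way the gemma2 tests reuse gemma's.
import unittest

from transformers import DeepseekV3Config, is_torch_available
from transformers.testing_utils import require_torch

from ...test_configuration_common import ConfigTester
from ..llama.test_modeling_llama import LlamaModelTest, LlamaModelTester


if is_torch_available():
    from transformers import DeepseekV3ForCausalLM, DeepseekV3Model


class DeepseekV3ModelTester(LlamaModelTester):
    # Override only the config: reuse the tiny dims from the Llama tester and add
    # the MLA-specific fields (field names assumed to match this PR's config).
    def get_config(self):
        return DeepseekV3Config(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            kv_lora_rank=16,
            q_lora_rank=32,
            qk_rope_head_dim=8,
            qk_nope_head_dim=8,
            v_head_dim=8,
        )


@require_torch
class DeepseekV3ModelTest(LlamaModelTest, unittest.TestCase):
    all_model_classes = (DeepseekV3Model, DeepseekV3ForCausalLM) if is_torch_available() else ()

    def setUp(self):
        self.model_tester = DeepseekV3ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=DeepseekV3Config, hidden_size=37)
```

Everything else (forward tests, generation tests, etc.) is then inherited from the Llama suite and only needs to be overridden where DeepSeek-V3 genuinely differs.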
@bzantium Thanks for the amazing work! I was wondering if you were able to train V3 with FSDP? If so, how many GPUs did you need? Thanks!
One big thing would be
This is great work and I'm looking forward to trying it out. For multi-token prediction, is this planned to be implemented in this PR via the
Thanks for the detailed comments; following them, I revised the code quite a lot and fixed some mismatches between the original code and this PR. I checked that the outputs from both are the same. I think now I can add test code. For
to: @ArthurZucker
Of course! As you commented, it looks like there's still a lot of work left (code optimization, training code, testing code, and multi-token prediction if possible).
Thanks for refactoring! I further optimized
Nice. I put it a bit on pause; I talked to @NouamaneTazi about how we would go about "training" this, as TP would be a little bit hard (you said so as well). We need to enable the layer norm to use TP.
Great, I think applying expert parallel instead of tensor parallel would be better in this case if possible (it may need a larger change than TP for
I found that:

```python
cos, sin = position_embeddings
q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot)
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
...

def reshape_for_rope(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
    return q, k
```

In addition, following the

```python
base_model_tp_plan = {
    "layers.*.gate_proj": "colwise",
    "layers.*.up_proj": "colwise",
    "layers.*.down_proj": "rowwise",
    "layers.*.self_attn.q_b_proj": "colwise",
    "layers.*.self_attn.kv_b_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
}
```

to: @ArthurZucker
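For context, once this plan is in place it could be exercised via the `tp_plan="auto"` path in `from_pretrained` (available in recent transformers releases). The snippet below is only a sketch and uses the tiny test checkpoint mentioned later in this thread purely as an illustration:

```python
# Launch with: torchrun --nproc-per-node 2 run_tp.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bzantium/tiny-deepseek-v3"  # tiny test checkpoint, illustrative only

# tp_plan="auto" shards the modules listed in base_model_tp_plan
# (colwise/rowwise) across the torchrun ranks.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Tensor parallel smoke test", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)
```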
…or TP; rename q_head_dim with qk_head_dim
Yep sounds good! I'll be a bit slow today! A few things for the todo-list:
I need to run a bit of maintenance today so don't worry, but I'll take a look whenever I can here!
I fixed it by using both:

```python
class DeepseekV3Model(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        self._register_load_state_dict_pre_hook(self.load_pre_hook)
        self._register_state_dict_hook(self.load_hook)
        self.post_init()

    def load_pre_hook(self, state_dict, prefix, *args):
        """
        Weights have to be permuted for correct rope formulation. We can't do this in the weights
        as every other framework already uses the `Llama` original function (which is copyrighted btw).
        And I am not even sure it's better.... anyways end of my rant
        """

        def permute_for_rope(input_tensor):
            """
            When you go from the complex ROPE formulation to sin and cos one, you need
            to permute the query and key weights (to avoid doing it on the fly)
            """
            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
            return input_tensor

        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
            weight = state_dict[key]
            weight = weight.view(num_heads, head_dim, -1)
            weight_rot = weight[:, -rope_dim:]
            weight_rot = permute_for_rope(weight_rot)
            weight[:, -rope_dim:] = weight_rot
            weight = weight.view(-1, weight.shape[-1])
            state_dict[key] = weight

        for k in state_dict:
            if "q_b_proj." in k:
                permute_layer_for_rope(
                    k,
                    num_heads=self.config.num_attention_heads,
                    head_dim=self.config.qk_head_dim,
                    rope_dim=self.config.qk_rope_head_dim,
                )
            if "kv_a_proj_with_mqa." in k:
                permute_layer_for_rope(
                    k,
                    num_heads=1,
                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
                    rope_dim=self.config.qk_rope_head_dim,
                )

    def load_hook(self, module, state_dict, prefix, *args):
        self.load_pre_hook(state_dict, prefix, *args)
```
Yep, looks alright.
I am currently using this model (5.35B) for testing, with the same config but 6 layers and fewer experts. I will upload a much smaller model later if you need it. https://huggingface.co/bzantium/tiny-deepseek-v3
to: @ArthurZucker
Perfect!
Sorry to jump in here, but I read the DeepSeek-V2 paper in detail and noted that the implementation they published, which you seem to be using here, is really quite inefficient and somewhat at odds with what they write in their paper. This also applies to https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py.
To explain: to this end, you have to avoid computing the Q, K, V tensors explicitly and instead work in the lower-rank space. This is quite possible (I worked it out), but it is not properly detailed in their paper. What is really funny: the only reason the paper extends K and Q by an extra part for RoPE encoding is to avoid having to compute Q and K explicitly. If you compute them explicitly (as is done here), you can just as well apply RoPE encoding to them! There is a way to make this work so that the role of K and V is played by a single tensor of shape
When I say "plays the role of K, ...", I mean that you call
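To make that concrete, here is a small self-contained sketch of the rewriting (a sketch of my reading, not code from the paper or this PR; all names and sizes are illustrative). The per-head slices of `kv_b_proj` are absorbed into the query and the output side, so attention runs directly against the compressed latent plus the shared RoPE key, and full K/V are never materialized:

```python
import torch

# Illustrative sizes only (not the real DeepSeek-V3 config)
b, s, h = 2, 5, 4              # batch, sequence length, number of heads
r = 16                         # kv_lora_rank (compressed latent size)
d_nope, d_rope, d_v = 8, 4, 8  # no-RoPE / RoPE / value head dims

# Tensors the existing layers would already produce:
q_nope = torch.randn(b, h, s, d_nope)   # from q_b_proj (non-RoPE part)
q_rope = torch.randn(b, h, s, d_rope)   # from q_b_proj (RoPE part, already rotated)
c_kv   = torch.randn(b, s, r)           # compressed latent from kv_a_proj_with_mqa
k_rope = torch.randn(b, s, d_rope)      # shared RoPE key, already rotated (one "head")

# Per-head slices of the kv_b_proj weight: c_kv -> k_nope and c_kv -> v
W_UK = torch.randn(h, d_nope, r)
W_UV = torch.randn(h, d_v, r)

scale = (d_nope + d_rope) ** -0.5

# --- naive path: materialize K and V, then do ordinary attention ---
k_nope = torch.einsum("hdr,btr->bhtd", W_UK, c_kv)
v = torch.einsum("hdr,btr->bhtd", W_UV, c_kv)
k = torch.cat([k_nope, k_rope[:, None].expand(b, h, s, d_rope)], dim=-1)
q = torch.cat([q_nope, q_rope], dim=-1)
attn = torch.softmax(torch.einsum("bhsd,bhtd->bhst", q, k) * scale, dim=-1)
out_naive = torch.einsum("bhst,bhtd->bhsd", attn, v)

# --- absorbed path: never build K or V, work in the rank-r latent space ---
q_abs = torch.einsum("bhsd,hdr->bhsr", q_nope, W_UK)            # absorb W_UK into Q
scores = (
    torch.einsum("bhsr,btr->bhst", q_abs, c_kv)                 # non-RoPE part against the latent
    + torch.einsum("bhsd,btd->bhst", q_rope, k_rope)            # RoPE part is MQA-style
)
attn2 = torch.softmax(scores * scale, dim=-1)
ctx_latent = torch.einsum("bhst,btr->bhsr", attn2, c_kv)        # context stays in latent space
out_absorbed = torch.einsum("bhsr,hdr->bhsd", ctx_latent, W_UV)  # absorb W_UV on the way out

print(torch.allclose(out_naive, out_absorbed, atol=1e-4))  # True: same maths, no K/V tensors
```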
Ahhhhh very nice. |
Thanks for the great comment! Could you suggest a code block to apply? Following what you mention, attention could be applied to the compressed latent instead of QKV, but I cannot directly imagine how it is possible. For the Cache, I see what you meant; we need to cache the latent KV, not the original K, V.
Talked to @zzhhjjj, and it seems like if we go for the optimized version we won't be able to have TP (Tensor Parallel).
What is TP? The optimized version I am referring to is just a rewriting of the maths. I am working for a company and have asked whether I am allowed to contribute this. If they give me the green light, I am happy to do it. I think HuggingFace should have the best code possible.
It does not have to delay this PR here. If this goes in first, I can still open a second PR, which would provide (1) an alternative to the Attention class, and also some code to convert from the current weights to the new ones. The new formulation uses somewhat different weights, because the *_b tensors are combined into one (more or less). BTW: this is all I have; I did not look at the MoE part or anything else. I also did not run any comparisons.
As for the cache: you only have to store what is called
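To put a number on it, some back-of-the-envelope cache arithmetic; the config values below are what I recall from the published DeepSeek-V3 config, so treat them as approximate:

```python
# Per-token, per-layer cache entries (element counts, dtype-agnostic).
num_heads = 128          # num_attention_heads (assumed from the published config)
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
kv_lora_rank = 512

# Caching full K and V per head, as the current remote-code implementation does:
full_kv = num_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)

# Caching only the compressed latent plus the shared RoPE key:
latent_kv = kv_lora_rank + qk_rope_head_dim

print(full_kv, latent_kv, round(full_kv / latent_kv, 1))  # 40960 576 71.1
```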
I agree with @bzantium that me jumping in should not delay this PR, also because frankly I need to get approval from my employer. So please do go ahead; I'd then make my contribution afterwards. However, I'd appreciate help then, especially with testing.
If you can implement the test code, that would be greatly helpful!
Sure @bzantium, I'll do that. I'd implement tests which ensure that the computations give the same result. I'll also provide code to convert the weights.
What does this PR do?
This PR adds the code for DeepSeek-V3.
The code relies heavily on the original remote code.
resolved: #35425
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case: DeepSeek V3 Support #35425
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
to: @ArthurZucker