[WIP] add deepseek-v3 #35926

Open · wants to merge 30 commits into main

Conversation

@bzantium (Contributor) commented Jan 28, 2025

What does this PR do?

This PR adds the code for DeepSeek-V3. The code relies heavily on the original remote code.

Resolves: #35425

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

to: @ArthurZucker

@Rocketknight1 (Member) commented:

Hi @bzantium, this looks great so far! We'll need tests added for the model plus a green CI; after that, feel free to ping me to assign a reviewer, or if you have any problems with the port.

@bzantium changed the title from "[WIP] add deepseekv3" to "[WIP] add deepseek-v3" on Jan 29, 2025

@ArthurZucker (Collaborator) left a comment

Ultra kudos! It's super nice.
Mostly missing tests; here you can use a similar approach to the gemma2 tests, which use inheritance!

src/transformers/models/deepseek_v3/modular_deepseek_v3.py — 9 review comments (outdated, resolved)
@cuichenx commented:

@bzantium Thanks for the amazing work! I was wondering if you were able to train V3 with FSDP? If so, how many GPUs did you need? Thanks!

@ArthurZucker (Collaborator) commented Jan 29, 2025

One big thing would be TP support: the base_tp_plan would probably need to be updated to make sure each MLP's gate/up/down projections have the correct order, unless the direct usage of dist removes this need.

@casper-hansen commented:

This is great work and I'm looking forward to trying it out. For multi-token prediction, is this planned to be implemented in this PR via the num_nextn_predict_layers attribute in the config?

@bzantium (Contributor, Author) commented Jan 30, 2025

Thanks for the detailed comments; following them, I revised the code quite a lot and fixed some mismatches between the original code and this PR. I checked that the outputs from both are the same. I think I can now add test code. For TP support, I think it can be applied only to the MLP layers but not to self_attn, because the attention uses operations like splitting on the hidden dimension. I added the following:

    base_model_tp_plan = {
        "layers.*.gate_proj": "colwise",
        "layers.*.up_proj": "colwise",
        "layers.*.down_proj": "rowwise",
    }

to: @ArthurZucker

@bzantium (Contributor, Author) commented Jan 30, 2025

Do you want me to jump on the PR and help you merge this faster @bzantium ? 🤗

Of course! As you commented, it looks like there's still a lot of work left (code optimization, training code, test code, and multi-token prediction if possible).
to: @ArthurZucker

@bzantium (Contributor, Author) commented:

Thanks for refactoring! I further optimized moe_infer and made it trainable, so I renamed it to moe.
Also, I revised load_hook to work as intended. Lastly, I rolled back to using Llama instead of Mixtral, since the two don't match each other.
to: @ArthurZucker

@ArthurZucker (Collaborator) commented:

Nice. I put it a bit on pause and talked to @NouamaneTazi about how we would go about "training" this, as TP would be a little bit hard (you said so as well). We need to enable the layer norm to use TP.

@bzantium (Contributor, Author) commented Feb 1, 2025

Nice. I put it a bit on pause and talked to @NouamaneTazi about how we would go about "training" this, as TP would be a little bit hard (you said so as well). We need to enable the layer norm to use TP.

Great. I think applying expert parallelism instead of tensor parallelism would be better in this case if possible (it may require a larger change than TP for transformers to support generally). The original code provides a direct expert parallelism implementation for the MoE; a rough sketch of the idea is below.
ref: https://pytorch.org/blog/training-moes/
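Roughly, expert parallelism keeps whole experts on different ranks and ships each token to the rank that owns its routed expert, rather than slicing every weight matrix. Below is a hypothetical sketch of the dispatch step only (top-1 routing for brevity; dispatch_tokens, ep_group, and all shapes are illustrative assumptions, not code from this PR or the linked post):

import torch
import torch.distributed as dist

def dispatch_tokens(hidden, router_logits, num_local_experts, ep_group):
    """Hypothetical helper: route each token to the rank that owns its top-1 expert."""
    world_size = dist.get_world_size(ep_group)
    expert_idx = router_logits.argmax(dim=-1)                 # (num_tokens,) top-1 routing
    dest_rank = expert_idx // num_local_experts               # owner rank of each expert
    order = torch.argsort(dest_rank)                          # group tokens by destination rank
    tokens_by_dest = hidden[order]
    send_counts = torch.bincount(dest_rank, minlength=world_size)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts, group=ep_group)   # exchange counts first
    recv_buf = hidden.new_empty(int(recv_counts.sum()), hidden.shape[-1])
    dist.all_to_all_single(
        recv_buf, tokens_by_dest,
        output_split_sizes=recv_counts.tolist(),
        input_split_sizes=send_counts.tolist(),
        group=ep_group,
    )
    # each rank then runs its local experts over recv_buf; the combine step is the
    # mirror-image all_to_all followed by undoing `order`
    return recv_buf, order, send_counts, recv_counts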

@bzantium (Contributor, Author) commented Feb 3, 2025

I found that _register_load_state_dict_pre_hook is not executed when using tp_plan="auto", so I use reshape_for_rope instead of load_hook, as follows:

cos, sin = position_embeddings
# permute on the fly from the interleaved (complex-style) layout to the half-split
# layout that apply_rotary_pos_emb expects
q_rot, k_rot = self.reshape_for_rope(q_rot, k_rot)
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
...

def reshape_for_rope(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # reorder the last dim from [x0, y0, x1, y1, ...] to [x0, x1, ..., y0, y1, ...]
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(-1, -2).reshape(b, h, s, d)
    return q, k

In addition, following the vLLM TP reference, I revised the attention forward and added more modules to base_model_tp_plan.

base_model_tp_plan = {
    "layers.*.gate_proj": "colwise",
    "layers.*.up_proj": "colwise",
    "layers.*.down_proj": "rowwise",
    "layers.*.self_attn.q_b_proj": "colwise",
    "layers.*.self_attn.kv_b_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
}
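For context, a minimal sketch of how a plan like this gets exercised end to end via the tensor-parallel loading path in transformers; the checkpoint id, dtype, and launch command here are illustrative assumptions, not part of this PR:

# launch with one process per GPU, e.g.: torchrun --nproc-per-node 8 run_tp.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "deepseek-ai/DeepSeek-V3"   # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
# tp_plan="auto" shards the weights across ranks according to base_model_tp_plan
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model(**inputs)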

to: @ArthurZucker

@ArthurZucker (Collaborator) commented:

Yep, sounds good! I'll be a bit slow today!
We should enforce the use of the load hook; otherwise rotating for RoPE on the fly is going to be a real issue in terms of performance!

A few things for the todo-list:

  • Layer norm probably needs a special layer replacement, if I am not mistaken
  • The load-balancing function for now is the one we usually use; I need to check the paper to see if it's the same or not
  • add megablocks as a soft dependency to try and use its kernels?
  • make sure we support fp8 (cc @SunMarc)!

I need to run a bit of maintenance today so don't worry, but I'll take a look whenever I can here!

@bzantium (Contributor, Author) commented Feb 3, 2025

I fixed it by using both _register_load_state_dict_pre_hook and _register_state_dict_hook. I got the same results with/without TP.

class DeepseekV3Model(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        self._register_load_state_dict_pre_hook(self.load_pre_hook)
        self._register_state_dict_hook(self.load_hook)
        self.post_init()

    def load_pre_hook(self, state_dict, prefix, *args):
        """
        Weights have to be permuted for correct rope formulation. We can't do this in the weights
        as every other framework already uses the `Llama` original function (which is copyrighted btw).
        And I am not even sure it's better.... anyways end of my rant
        """

        def permute_for_rope(input_tensor):
            """
            When you go from the complex ROPE formulation to sin and cos one, you need
            to permute the query and key weights (to avoid doing it on the fly)
            """
            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
            return input_tensor

        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
            weight = state_dict[key]
            weight = weight.view(num_heads, head_dim, -1)
            weight_rot = weight[:, -rope_dim:]
            weight_rot = permute_for_rope(weight_rot)
            weight[:, -rope_dim:] = weight_rot
            weight = weight.view(-1, weight.shape[-1])
            state_dict[key] = weight

        for k in state_dict:
            if "q_b_proj." in k:
                permute_layer_for_rope(
                    k,
                    num_heads=self.config.num_attention_heads,
                    head_dim=self.config.qk_head_dim,
                    rope_dim=self.config.qk_rope_head_dim,
                )
            if "kv_a_proj_with_mqa." in k:
                permute_layer_for_rope(
                    k,
                    num_heads=1,
                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
                    rope_dim=self.config.qk_rope_head_dim,
                )

    def load_hook(self, module, state_dict, prefix, *args):
        self.load_pre_hook(state_dict, prefix, *args)

@ArthurZucker

@ArthurZucker (Collaborator) commented:

Yep, looks alright.
For quick testing we can create a tiny model (this will help in our CI), and I'll be able to download it as well!

@bzantium (Contributor, Author) commented Feb 3, 2025

I am currently using this model (5.35B parameters) for testing: the same config but with 6 layers and fewer experts. I will upload a much smaller model later if you need it.

https://huggingface.co/bzantium/tiny-deepseek-v3
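For anyone who wants to poke at it, a quick smoke-test sketch against that checkpoint; this assumes a transformers install from this PR's branch and that the repo ships a tokenizer, and the prompt and generation settings are arbitrary:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "bzantium/tiny-deepseek-v3"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo)

inputs = tokenizer("Hello", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))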

to: @ArthurZucker

@ArthurZucker (Collaborator) commented:

Perfect!

@mseeger commented Feb 3, 2025

Sorry to jump in here, but I read the DeepSeek-V2 paper in detail and noted that the implementation they published, which you seem to be using here, is really quite inefficient and somewhat at odds with what they write in their paper. This also applies to https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py.

@mseeger commented Feb 3, 2025

To explain:
The key idea in the paper is to use low rank approximations to Q, K, V tensors. This reduces KV cache size. But it also leads to more efficient computations, even in training.

To this end, you have to avoid computing the Q, K, V tensors, but instead work in the lower ranks. This is quite possible (I worked it out), but is not properly detailed in their paper.

What is really funny: the trick of extending K and Q by an extra part for the RoPE encoding is done in the paper only to avoid having to compute Q and K explicitly. If you compute them explicitly (as is done here), you can just as well apply the RoPE encoding to them directly!

There is a way to make this work so that the role of K and V is played by a single tensor of shape (batch_size, 1, q_len, kv_lora_rank + qk_rope_head_dim). There is no num_attention_heads here. And single tensor, because "V" is really just a part of "K". The role of Q is played by a tensor of shape (batch_size, num_attention_heads, q_len, kv_lora_rank + qk_rope_head_dim), and dot products are computed between these.

When I say "plays the role of K, ...", I mean that you call torch.nn.functional.scaled_dot_product_attention with these, instead of the original Q, K, V.
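To make the rewriting concrete, here is a minimal numerical sketch of that "absorbed" formulation, written with plain einsum/softmax rather than scaled_dot_product_attention so the shapes are explicit. The toy sizes and the weight names W_UK / W_UV (following the paper's notation) are illustrative assumptions, not code from this PR:

import torch

# toy sizes, for illustration only
b, h, q_len, kv_len = 2, 4, 5, 5
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 16, 12, 8, 12

# per-head up-projections from the compressed KV latent (W^UK, W^UV in the paper)
W_UK = torch.randn(h, qk_nope_head_dim, kv_lora_rank)
W_UV = torch.randn(h, v_head_dim, kv_lora_rank)

# per-head query parts (RoPE already applied to q_rope) and the shared cached latents
q_nope = torch.randn(b, h, q_len, qk_nope_head_dim)
q_rope = torch.randn(b, h, q_len, qk_rope_head_dim)
c_kv   = torch.randn(b, kv_len, kv_lora_rank)        # compressed KV latent (what gets cached)
k_rope = torch.randn(b, kv_len, qk_rope_head_dim)    # decoupled RoPE key, shared across heads

# absorb W_UK into the query: q_nope . (W_UK c_kv) == (W_UK^T q_nope) . c_kv
q_abs = torch.einsum("bhqd,hdr->bhqr", q_nope, W_UK)
q_eff = torch.cat([q_abs, q_rope], dim=-1)           # plays the role of Q, per head
k_eff = torch.cat([c_kv, k_rope], dim=-1)            # plays the role of K, no head dimension

# scores are scaled by the original per-head query/key dimension, not the latent rank
scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
attn = torch.softmax(torch.einsum("bhqr,bsr->bhqs", q_eff, k_eff) * scale, dim=-1)

# "V" is just the latent itself; W_UV is applied after the weighted sum (or folded into o_proj)
ctx_latent = torch.einsum("bhqs,bsr->bhqr", attn, c_kv)
out = torch.einsum("bhqr,hvr->bhqv", ctx_latent, W_UV)  # (b, h, q_len, v_head_dim)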

@ArthurZucker (Collaborator) commented:

Ahhhhh very nice.
Indeed, I usually base stuff off code rather than papers; it would be nice to get an optimized version.

@bzantium (Contributor, Author) commented Feb 4, 2025

Thanks for the great comment! Could you suggest a code block to apply? Following what you mention, attention could be applied to the compressed latent instead of Q, K, V, but I cannot directly picture how that is possible. For the cache, I see what you meant; we need to cache the latent KV, not the original K and V.
to: @mseeger

@ArthurZucker (Collaborator) commented Feb 4, 2025

Talked to @zzhhjjj, and it seems like if we go for the optimized version we won't be able to have TP (tensor parallelism).

@mseeger commented Feb 4, 2025

What is TP? The optimized version I am referring to is just a rewriting of the maths.

I work for a company and have asked whether I am allowed to contribute this. If they give me the green light, I am happy to do it. I think Hugging Face should have the best code possible.

@mseeger commented Feb 4, 2025

It does not have to delay this PR here. If this goes in first, I can still open a second PR, which would provide an alternative to the Attention class, plus some code to convert from the current weights to the new ones. The new formulation uses somewhat different weights, because the *_b tensors are combined into one (more or less).

BTW: this is all I have; I did not look at the MoE part or anything else. I also did not run any comparisons.

@mseeger commented Feb 4, 2025

As for the cache: you only have to store what are called C_KV and K_R in the DeepSeek-V2 paper. You return the concatenation of these two as the K tensor, and C_KV alone as the V tensor.
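To put rough numbers on that saving, a back-of-the-envelope check; the constants below are what I believe are the published DeepSeek-V3 config values, so treat them as assumptions and read the exact figures from the config:

# cached values per token, per layer (counts of numbers, not bytes)
num_heads, qk_nope, qk_rope, v_dim, kv_lora_rank = 128, 128, 64, 128, 512  # assumed V3 config

full_kv = num_heads * ((qk_nope + qk_rope) + v_dim)   # caching full per-head K and V
latent  = kv_lora_rank + qk_rope                      # caching C_KV plus the shared K_R
print(full_kv, latent, round(full_kv / latent, 1))    # 40960 576 71.1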

@mseeger commented Feb 5, 2025

I agree with @bzantium that my jumping in should not delay this PR, also because, frankly, I need to get approval from my employer. So please do go ahead; I'd then make my contribution afterwards. However, I'd appreciate help then, especially with testing.

@bzantium (Contributor, Author) commented Feb 5, 2025

I agree with @bzantium that my jumping in should not delay this PR, also because, frankly, I need to get approval from my employer. So please do go ahead; I'd then make my contribution afterwards. However, I'd appreciate help then, especially with testing.

If you can implement the test code, that would be greatly helpful!

@mseeger commented Feb 5, 2025

Sure @bzantium, I'd do that. I'd implement tests which ensure that the computations give the same result. I'll also provide code to convert the weights.

Development

Successfully merging this pull request may close these issues.

DeepSeek V3 Support
7 participants