[WIP] add deepseek-v3 #35926

Open · wants to merge 31 commits into base: main

Commits (31):
- 704767e add deepseekv3 modeling (bzantium, Jan 28, 2025)
- 737ee3a Merge branch 'main' into feature/#35425 (bzantium, Jan 28, 2025)
- fc3a4c7 Merge branch 'main' of https://github.com/bzantium/transformers into … (bzantium, Jan 28, 2025)
- 244e793 remove redundant code (bzantium, Jan 28, 2025)
- 0968df5 Merge branch 'feature/#35425' of https://github.com/bzantium/transfor… (bzantium, Jan 28, 2025)
- 4fb2a80 apply make style (bzantium, Jan 28, 2025)
- 6b002e5 apply fix-copies (bzantium, Jan 28, 2025)
- 4ec1e88 make format (bzantium, Jan 28, 2025)
- 114ab84 add init files (bzantium, Jan 28, 2025)
- 779f8d2 rename deepseekv3 into deepseek_v3 based on its model_type (bzantium, Jan 28, 2025)
- 22623a3 rename deepseekv3 into deepseek_v3 based on its model_type (bzantium, Jan 28, 2025)
- 78b19b0 deepseek-v3 not deepseek_v3 (bzantium, Jan 28, 2025)
- eb0e3a4 set model_type as deepseek_v3 (bzantium, Jan 28, 2025)
- 57088cc use default docs (bzantium, Jan 28, 2025)
- 0ef561b apply make (bzantium, Jan 28, 2025)
- 9a75a56 fill type and docstring (bzantium, Jan 28, 2025)
- cdf83e4 add rope_config_validation (bzantium, Jan 29, 2025)
- 51990b9 use custom DeepseekV3MLP (bzantium, Jan 29, 2025)
- f4f0ebd hold code only for checkpoints congifuration; remove redundant (bzantium, Jan 30, 2025)
- 4b72b30 revise rope yarn for DeepSeek variation (bzantium, Jan 30, 2025)
- 96562c4 Merge branch 'main' into feature/#35425 (bzantium, Jan 30, 2025)
- 6792cb5 rename DeepSeek-V3 (bzantium, Jan 30, 2025)
- 3bf3b32 some refactoring (ArthurZucker, Jan 31, 2025)
- 24bc8b2 revise load_hook to work properly; make moe func trainable; use llama… (bzantium, Jan 31, 2025)
- 5c0cd91 fix attention forward (bzantium, Jan 31, 2025)
- 8e994dd use -1 for not-changing dim when to use exapnd (bzantium, Feb 1, 2025)
- 7405a95 refactor DeepseekV3TopkRouter (bzantium, Feb 1, 2025)
- ea3c922 use reshape_for_rope instead of load_hook; revise attention forward f… (bzantium, Feb 3, 2025)
- c813268 register pre_hook and hook both (bzantium, Feb 3, 2025)
- 4ab2f9e make style (bzantium, Feb 3, 2025)
- c5429ec use n_shared_experts (bzantium, Feb 10, 2025)
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -118,6 +118,7 @@ Flax), PyTorch, and/or TensorFlow.
| [DeBERTa](model_doc/deberta) | ✅ | ✅ | ❌ |
| [DeBERTa-v2](model_doc/deberta-v2) | ✅ | ✅ | ❌ |
| [Decision Transformer](model_doc/decision_transformer) | ✅ | ❌ | ❌ |
| [DeepSeek-V3](model_doc/deepseek_v3) | ✅ | ❌ | ❌ |
| [Deformable DETR](model_doc/deformable_detr) | ✅ | ❌ | ❌ |
| [DeiT](model_doc/deit) | ✅ | ✅ | ❌ |
| [DePlot](model_doc/deplot) | ✅ | ❌ | ❌ |
46 changes: 46 additions & 0 deletions docs/source/en/model_doc/deepseek_v3.md
@@ -0,0 +1,46 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# DeepSeek-V3

## Overview

The DeepSeek-V3 model was proposed in [DeepSeek-V3 Technical Report](https://arxiv.org/abs/2412.19437) by DeepSeek-AI Team.

The abstract from the paper is the following:
We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 671B total parameters with 37B activated for each token. To achieve efficient inference and cost-effective training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA) and DeepSeekMoE architectures, which were thoroughly validated in DeepSeek-V2. Furthermore, DeepSeek-V3 pioneers an auxiliary-loss-free strategy for load balancing and sets a multi-token prediction training objective for stronger performance. We pre-train DeepSeek-V3 on 14.8 trillion diverse and high-quality tokens, followed by Supervised Fine-Tuning and Reinforcement Learning stages to fully harness its capabilities. Comprehensive evaluations reveal that DeepSeek-V3 outperforms other open-source models and achieves performance comparable to leading closed-source models. Despite its excellent performance, DeepSeek-V3 requires only 2.788M H800 GPU hours for its full training. In addition, its training process is remarkably stable. Throughout the entire training process, we did not experience any irrecoverable loss spikes or perform any rollbacks. The model checkpoints are available at https://github.com/deepseek-ai/DeepSeek-V3.

### Usage tips
The model uses Multi-head Latent Attention (MLA) and the DeepSeekMoE architecture for efficient inference and cost-effective training. It employs an auxiliary-loss-free strategy for load balancing and a multi-token prediction training objective. The model can be used for various language tasks after being pre-trained on 14.8 trillion tokens and going through Supervised Fine-Tuning and Reinforcement Learning stages.
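Long-context checkpoints built on this architecture configure YaRN rope scaling through the `rope_scaling` dict that this PR teaches `modeling_rope_utils.py` to validate. A hypothetical fragment is sketched below; every value is illustrative and not taken from any released checkpoint's configuration:

```python
# Hypothetical `rope_scaling` entry for a YaRN-extended DeepSeek-V3 config.
# All values are illustrative; they are not a released checkpoint's settings.
rope_scaling = {
    "rope_type": "yarn",                       # required key
    "factor": 40.0,                            # required key; superseded when the ratio form applies
    "original_max_position_embeddings": 4096,  # pretrained context length
    "beta_fast": 32,                           # extrapolation boundary (paper default)
    "beta_slow": 1,                            # interpolation boundary (paper default)
    "mscale": 1.0,                             # attention magnitude scaling
    "mscale_all_dim": 1.0,                     # paired with `mscale` for the ratio form
}
```

Per the `_validate_yarn_parameters` change in this PR, `rope_type` and `factor` are required; `attention_factor`, `beta_fast`, `beta_slow`, `original_max_position_embeddings`, `mscale`, and `mscale_all_dim` are optional.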

## DeepseekV3Config

[[autodoc]] DeepseekV3Config

## DeepseekV3Model

[[autodoc]] DeepseekV3Model
- forward

## DeepseekV3ForCausalLM

[[autodoc]] DeepseekV3ForCausalLM
- forward

## DeepseekV3ForSequenceClassification

[[autodoc]] DeepseekV3ForSequenceClassification
- forward
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
@@ -341,6 +341,7 @@
],
"models.deberta_v2": ["DebertaV2Config"],
"models.decision_transformer": ["DecisionTransformerConfig"],
"models.deepseek_v3": ["DeepseekV3Config"],
"models.deformable_detr": ["DeformableDetrConfig"],
"models.deit": ["DeiTConfig"],
"models.deprecated": [],
@@ -1956,6 +1957,14 @@
"DecisionTransformerPreTrainedModel",
]
)
_import_structure["models.deepseek_v3"].extend(
[
"DeepseekV3ForCausalLM",
"DeepseekV3ForSequenceClassification",
"DeepseekV3Model",
"DeepseekV3PreTrainedModel",
]
)
_import_structure["models.deformable_detr"].extend(
[
"DeformableDetrForObjectDetection",
@@ -5395,6 +5404,9 @@
from .models.decision_transformer import (
DecisionTransformerConfig,
)
from .models.deepseek_v3 import (
DeepseekV3Config,
)
from .models.deformable_detr import (
DeformableDetrConfig,
)
@@ -6966,6 +6978,12 @@
DecisionTransformerModel,
DecisionTransformerPreTrainedModel,
)
from .models.deepseek_v3 import (
DeepseekV3ForCausalLM,
DeepseekV3ForSequenceClassification,
DeepseekV3Model,
DeepseekV3PreTrainedModel,
)
from .models.deformable_detr import (
DeformableDetrForObjectDetection,
DeformableDetrModel,
36 changes: 30 additions & 6 deletions src/transformers/modeling_rope_utils.py
@@ -189,13 +189,31 @@ def _compute_yarn_parameters(
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]
attention_factor = config.rope_scaling.get("attention_factor")
mscale = config.rope_scaling.get("mscale")
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")

# NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if "original_max_position_embeddings" in config.rope_scaling:
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
factor = config.max_position_embeddings / original_max_position_embeddings
else:
original_max_position_embeddings = config.max_position_embeddings

def get_mscale(scale, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0

# Sets the attention factor as suggested in the paper
attention_factor = config.rope_scaling.get("attention_factor")
if attention_factor is None:
attention_factor = 0.1 * math.log(factor) + 1.0
if mscale and mscale_all_dim:
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
else:
attention_factor = get_mscale(factor)

# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -227,15 +245,14 @@ def linear_ramp_factor(min, max, dim):
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)

low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)

# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)

return inv_freq, attention_factor


@@ -425,7 +442,14 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
optional_keys = {
"attention_factor",
"beta_fast",
"beta_slow",
"original_max_position_embeddings",
"mscale",
"mscale_all_dim",
}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

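The revised `_compute_yarn_parameters` above derives `factor` from the context-length ratio when `original_max_position_embeddings` is present, and `attention_factor` from the mscale terms. A stand-alone sketch of just that arithmetic (plain Python, no transformers dependency, illustrative inputs):

```python
import math

def get_mscale(scale: float, mscale: float = 1.0) -> float:
    # DeepSeek-style magnitude scaling; identity when no context extension is applied
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

def yarn_attention_params(rope_scaling: dict, max_position_embeddings: int):
    """Sketch of the factor/attention_factor logic in the diff above."""
    factor = rope_scaling["factor"]
    if "original_max_position_embeddings" in rope_scaling:
        # DeepSeek-V3: the ratio of current to pretrained context length wins over `factor`
        factor = max_position_embeddings / rope_scaling["original_max_position_embeddings"]

    attention_factor = rope_scaling.get("attention_factor")
    if attention_factor is None:
        mscale = rope_scaling.get("mscale")
        mscale_all_dim = rope_scaling.get("mscale_all_dim")
        if mscale and mscale_all_dim:
            attention_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)
        else:
            attention_factor = get_mscale(factor)
    return factor, attention_factor

factor, att = yarn_attention_params(
    {"rope_type": "yarn", "factor": 40.0, "original_max_position_embeddings": 4096,
     "mscale": 1.0, "mscale_all_dim": 1.0},
    max_position_embeddings=16384,
)
print(factor, att)  # factor = 16384 / 4096 = 4.0; att = 1.0 since mscale == mscale_all_dim
```

Note how equal `mscale` and `mscale_all_dim` make the ratio collapse to 1.0, while omitting both falls back to the plain `0.1 * log(factor) + 1.0` form from the original code.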
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -69,6 +69,7 @@
deberta,
deberta_v2,
decision_transformer,
deepseek_v3,
deformable_detr,
deit,
deprecated,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -87,6 +87,7 @@
("deberta", "DebertaConfig"),
("deberta-v2", "DebertaV2Config"),
("decision_transformer", "DecisionTransformerConfig"),
("deepseek_v3", "DeepseekV3Config"),
("deformable_detr", "DeformableDetrConfig"),
("deit", "DeiTConfig"),
("depth_anything", "DepthAnythingConfig"),
@@ -406,6 +407,7 @@
("deberta", "DeBERTa"),
("deberta-v2", "DeBERTa-v2"),
("decision_transformer", "Decision Transformer"),
("deepseek_v3", "DeepSeek-V3"),
("deformable_detr", "Deformable DETR"),
("deit", "DeiT"),
("deplot", "DePlot"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -86,6 +86,7 @@
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("decision_transformer", "DecisionTransformerModel"),
("deepseek_v3", "DeepseekV3Model"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
("deta", "DetaModel"),
@@ -501,6 +502,7 @@
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForCausalLM"),
("dbrx", "DbrxForCausalLM"),
("deepseek_v3", "DeepseekV3ForCausalLM"),
("diffllama", "DiffLlamaForCausalLM"),
("electra", "ElectraForCausalLM"),
("emu3", "Emu3ForCausalLM"),
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -170,6 +170,13 @@
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
),
),
(
"deepseek_v3",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"diffllama",
(
27 changes: 27 additions & 0 deletions src/transformers/models/deepseek_v3/__init__.py
@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_deepseek_v3 import *
from .modeling_deepseek_v3 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
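The `_LazyModule` indirection above defers importing the heavy modeling code until an attribute is actually requested. A minimal stand-alone sketch of the same deferred-import idea, using PEP 562's module-level `__getattr__` rather than the actual `_LazyModule` implementation (the helper and its names are hypothetical):

```python
import importlib
import sys
import types

def make_lazy_module(name: str, attr_to_module: dict) -> types.ModuleType:
    """Build a module whose attributes import their backing module on first access."""
    mod = types.ModuleType(name)

    def __getattr__(attr):  # PEP 562: consulted only when normal lookup fails
        if attr in attr_to_module:
            value = getattr(importlib.import_module(attr_to_module[attr]), attr)
            setattr(mod, attr, value)  # cache so later lookups skip this hook
            return value
        raise AttributeError(f"module {name!r} has no attribute {attr!r}")

    mod.__getattr__ = __getattr__
    return mod

# 'json' is only imported when lazy.dumps is first touched
lazy = make_lazy_module("lazy_demo", {"dumps": "json"})
sys.modules["lazy_demo"] = lazy
print(lazy.dumps({"a": 1}))  # {"a": 1}
```

The real `_LazyModule` additionally walks `define_import_structure`'s `__all__`-based mapping and handles submodules, but the first-access-then-cache behavior is the same.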