Add the Bamba Model #34982
Conversation
Hi @fabianlim, do you have a paper reference for this model or any details on the trained checkpoints?
@Rocketknight1 thanks for reaching out. Yes, my colleagues are preparing a paper and a GitHub repo with the (training) code, and the checkpoints will be 1.8T, 2T, 2.2T, and an SFT model. We will update the PR accordingly. cc: @raghukiran1224
The data used is all open, we plan to share any and all details on what the community would want! Open source is the name of the game 😄
Cool! @molbap will be the point of contact at Hugging Face for this PR, so feel free to ping me or him if you have any questions as you're working on it.
All tests + integration pass and modular looks better, thanks for working on it and congrats on the model! pinging @ArthurZucker so core review starts ;)
docs/source/en/model_doc/bamba.md (Outdated)

## Overview

TODO
This should be filled in as well before merging
updated in 44788dc
docs/source/en/model_doc/bamba.md (Outdated)

print(i)
```

<!-- update this -->
updated in 44788dc
Added overview, update Model inference card and added config
Minor fixes
Added overview and other additional details for Bamba
Great work, a few nits to address and we can merge!
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_bamba": ["BambaConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_bamba"] = [
        "BambaForCausalLM",
        "BambaModel",
        "BambaPreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_bamba import BambaConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_bamba import (
            BambaForCausalLM,
            BambaModel,
            BambaPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Suggested change:

# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure

if TYPE_CHECKING:
    from .configuration_bamba import *
    from .modeling_bamba import *
    from .processing_bamba import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
and you need to define __all__, like:

__all__ = [
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
    "GemmaPreTrainedModel",
]

at the end of the modular file! see modular gemma2
addressed in latest commit
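For illustration, a minimal sketch (an assumption based on the classes in the import structure above, not necessarily the exact final export list) of what the trailing __all__ in the Bamba modular file could look like:

```python
# Hypothetical sketch: public exports at the end of modular_bamba.py,
# mirroring the class names used in the import structure shown earlier.
__all__ = [
    "BambaModel",
    "BambaForCausalLM",
    "BambaPreTrainedModel",
]
```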
assert mamba_intermediate % mamba_n_heads == 0, "mamba_n_heads must divide mamba_expand * hidden_size"

# for the mamba_v2, must satisfy the following
if mamba_d_head == "auto":
    mamba_d_head = mamba_intermediate // mamba_n_heads
assert mamba_d_head * mamba_n_heads == mamba_intermediate
let's remove asserts and rather raise errors please
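For reference, a minimal sketch (reusing the variable names from the quoted snippet, not the merged code) of the same checks expressed as raised errors:

```python
# Sketch only: the assert-based checks above rewritten as explicit exceptions.
if mamba_intermediate % mamba_n_heads != 0:
    raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")

if mamba_d_head == "auto":
    mamba_d_head = mamba_intermediate // mamba_n_heads

if mamba_d_head * mamba_n_heads != mamba_intermediate:
    raise ValueError("The Mamba head dimension times the number of heads must equal the intermediate size")
```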
done in latest commit
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
    state_dict = {}

    for orig_k, param in original_sd.items():
        k = orig_k.replace("backbone", "model")

        # for embeddings
        k = k.replace("embedding", "embed_tokens")

        # for mixer
        k = k.replace("mixer", "mamba")

        # for final layernorm
        k = k.replace("norm_f", "final_layernorm")

        # for block layernorm
        k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
        k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)

        # for mlp
        k = k.replace("mlp.fc2", "feed_forward.down_proj")

        if "mlp.fc1" in k:
            param, param2 = torch.chunk(param, 2, dim=0)
            k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
            state_dict[k2] = param2
            k = k.replace("mlp.fc1", "feed_forward.up_proj")

        if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
            "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
        ):
            # then this must be a mamba
            pass
        else:
            # for attn
            # - because mixer was replaced to mamba above
            k = k.replace("mamba.out_proj", "self_attn.o_proj")
            if "mamba.in_proj" in k:
                m, n = param.shape
                d = (m - n) // 2
                param, param2, param3 = torch.split(param, [n, d, d], dim=0)
                k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
                state_dict[k2] = param2
                k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
                state_dict[k2] = param3
                k = k.replace("mamba.in_proj", "self_attn.q_proj")

        state_dict[k] = param
we like to have a more explicit dict like this one
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
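For context, a hedged sketch of what such an explicit mapping could look like for the renames performed in the loop above (illustrative only; the weight splitting for mlp.fc1 and mamba.in_proj would still need dedicated handling):

```python
import re

# Illustrative regex rename table covering the simple key substitutions
# performed imperatively in convert_state_dict_from_mamba_ssm above.
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"backbone": "model",
    r"embedding": "embed_tokens",
    r"mixer": "mamba",
    r"norm_f": "final_layernorm",
    r"(\d+)\.norm\.": r"\1.input_layernorm.",
    r"(\d+)\.norm2\.": r"\1.pre_ff_layernorm.",
    r"mlp\.fc2": "feed_forward.down_proj",
}


def rename_key(key: str) -> str:
    # Apply each pattern in order to produce the converted key name.
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        key = re.sub(pattern, replacement, key)
    return key
```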
leaving for future/followup PR
# Adapted from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# - handles the case if the rotary embedding is smaller than head_dim
really not sure this is worth changing as it just adds an assert.
Otherwise we should use the notation like
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
changed to adapt from GLM
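For context, a rough sketch (an assumption, not the merged code; the exact interleaving in the GLM-derived version may differ) of applying rotary embeddings to only the first rotary_dim channels of each head:

```python
import torch


def rotate_half(x):
    # Rotates half of the hidden dims of the input (standard RoPE helper).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_partial_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin cover only rotary_dim <= head_dim, so the trailing channels
    # of q and k pass through unrotated.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return torch.cat([q_rot, q_pass], dim=-1), torch.cat([k_rot, k_pass], dim=-1)
```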
""" | ||
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. | ||
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) | ||
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, | ||
and is why Mamba is called **selective** state spaces) | ||
""" |
As a reviewer, but also for any dev who is going to read this code, we need to know what the differences with Mamba2Mixer are.
Could you add comments here explaining why we have to redefine everything?
added a list to the comments
class BambaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
        super().__init__()

        del self.self_attn

        del self.mlp
        del self.post_attention_layernorm
        self.feed_forward = BambaMLP(config)
        self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.layer_type = layer_type
        if layer_type == "mamba":
            self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
        elif layer_type == "attention":
            self.self_attn = BAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        else:
            raise ValueError("Invalid layer_type")
This is a lot more similar to:

class JambaAttentionDecoderLayer(nn.Module):
    def __init__(self, config: JambaConfig, layer_idx: int):
        super().__init__()
        num_experts = config.layers_num_experts[layer_idx]
        self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config)
        self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

which should be a better base 😉
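For illustration, a rough sketch (not the merged implementation; BambaMixer, BambaMLP, BambaRMSNorm, and BAMBA_ATTENTION_CLASSES are the names defined in the PR's modeling file) of how the decoder layer could follow the Jamba pattern while keeping the mamba/attention switch:

```python
import torch.nn as nn


class BambaDecoderLayer(nn.Module):
    # Sketch based on the Jamba-style layer above, keeping Bamba's layer_type
    # switch between a Mamba mixer and standard attention.
    def __init__(self, config, layer_idx: int, layer_type: str = "mamba"):
        super().__init__()
        self.layer_type = layer_type
        if layer_type == "mamba":
            self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
        elif layer_type == "attention":
            self.self_attn = BAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        else:
            raise ValueError("Invalid layer_type")
        self.feed_forward = BambaMLP(config)
        self.input_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
```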
done in latest commit
# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
forward and init should be the same as LlamaForCausalLM, which you should be able to inherit from and just change the prepare inputs for generation, as it is the only difference, no?
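As a sketch of the suggested inheritance (hypothetical; the real override in the modular file has to deal with the hybrid Mamba cache), this is roughly what it would look like:

```python
from transformers.models.llama.modeling_llama import LlamaForCausalLM


class BambaForCausalLM(LlamaForCausalLM):
    # __init__ and forward are inherited unchanged from LlamaForCausalLM;
    # only input preparation for generation differs for the hybrid layers.
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # Sketch only: the actual logic must account for the Mamba state cache
        # alongside the attention KV cache.
        ...
```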
done in latest commit
Kudos! 🚀
What does this PR do?
This PR merges the BambaModel, which is a hybrid mamba2 architecture with SwiGLU. The checkpoints are jointly trained by IBM, Princeton, and UIUC. The implementation is based on ai21labs/Jamba-v0.1 and the mamba2 implementation ported over to HF for the codestral model.
cc: @ani300, @raghukiran1224
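For reference, once checkpoints are released, loading should follow the standard causal-LM flow; the model id below is a placeholder, not an actual released checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model id; replace with a released Bamba checkpoint.
model_id = "some-org/bamba-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The Bamba architecture combines Mamba2 layers with attention", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```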
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.