
Add the Bamba Model #34982

Merged: 44 commits into huggingface:main, Dec 18, 2024

Conversation

@fabianlim (Contributor) commented on Nov 28, 2024:

What does this PR do?

This PR adds the BambaModel, a hybrid mamba2 architecture with SwiGLU. The checkpoints are jointly trained by IBM, Princeton, and UIUC.

The implementation is based on ai21labs/Jamba-v0.1 and the mamba2 implementation ported to HF for the Codestral model.

cc: @ani300, @raghukiran1224
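
A minimal usage sketch of what loading the model could look like once the checkpoints are released; the checkpoint id below is a placeholder for illustration only, not a published artifact:

```python
# Hypothetical usage sketch; "ibm-ai-platform/Bamba-9B" is a placeholder
# checkpoint id, not a confirmed release name.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ibm-ai-platform/Bamba-9B"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Bamba is a hybrid mamba2 model that", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```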

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@fabianlim fabianlim marked this pull request as draft November 28, 2024 00:35
@fabianlim fabianlim changed the title initial commit for PR Add the Bamba Model Nov 28, 2024
Co-authored-by: Gabe Goodhart <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@Rocketknight1 (Member) commented:

Hi @fabianlim, do you have a paper reference for this model or any details on the trained checkpoints?

@fabianlim (Contributor, Author) replied:

@Rocketknight1 thanks for reaching out. Yes, my colleagues are preparing a paper and a GitHub repo with the (training) code. The checkpoints will be at 1.8T, 2T, and 2.2T tokens, plus an SFT model. We will update the PR accordingly.

cc: @raghukiran1224

@raghukiran1224 replied:

The data used is all open, and we plan to share any and all details the community wants! Open source is the name of the game 😄

@Rocketknight1 (Member) replied:

Cool! @molbap will be the point of contact at Hugging Face for this PR, so feel free to ping me or him if you have any questions as you're working on it.

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim fabianlim mentioned this pull request Dec 5, 2024
@molbap added the "State space models" and "New model" labels on Dec 9, 2024
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim fabianlim marked this pull request as ready for review December 16, 2024 15:33
@molbap (Contributor) left a comment:

All tests + integration pass and modular looks better, thanks for working on it and congrats on the model! pinging @ArthurZucker so core review starts ;)


```md
## Overview

TODO
```

Comment (Contributor): This should be filled in as well before merging.

Reply (Contributor, Author): updated in 44788dc

On the docs example ending in:

```
print(i)
```

<!-- update this -->

Suggested change (Contributor): remove the `<!-- update this -->` placeholder.

Reply (Contributor, Author): updated in 44788dc

divya-kumari32 and others added 9 commits December 16, 2024 22:51
Added overview, update Model inference card and added config
Minor fixes
Added overview and other additional details for Bamba
Signed-off-by: Antoni Viros i Martin <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@pcuenca pcuenca requested a review from ArthurZucker December 18, 2024 12:06
@ArthurZucker (Collaborator) left a comment:

Great work, a few nits to address and we can merge!

Comment on lines 1 to 57

```python
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_bamba": ["BambaConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_bamba"] = [
        "BambaForCausalLM",
        "BambaModel",
        "BambaPreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_bamba import BambaConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_bamba import (
            BambaForCausalLM,
            BambaModel,
            BambaPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
```
Comment (Collaborator): Suggested change (replacing the `__init__` above with):
```python
# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure

if TYPE_CHECKING:
    from .configuration_bamba import *
    from .modeling_bamba import *
    from .processing_bamba import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```

Comment (Collaborator): And you need to define `__all__` like:

```python
__all__ = [
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
    "GemmaPreTrainedModel",
]
```

at the end of the modular file! See modular gemma2.

Reply (Contributor): addressed in latest commit

Comment on lines 169 to 175

```python
assert mamba_intermediate % mamba_n_heads == 0, "mamba_n_heads must divide mamba_expand * hidden_size"

# for the mamba_v2, must satisfy the following
if mamba_d_head == "auto":
    mamba_d_head = mamba_intermediate // mamba_n_heads
assert mamba_d_head * mamba_n_heads == mamba_intermediate
```
Comment (Collaborator): let's remove the asserts and raise errors instead, please.

Reply (Contributor): done in latest commit
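
For illustration, the requested assert-to-exception change could look roughly like the sketch below; variable names follow the hunk above, and this is not the merged code:

```python
# Sketch only: raise explicit errors instead of using asserts.
if mamba_intermediate % mamba_n_heads != 0:
    raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")

# for mamba_v2, the head dimension must tile the intermediate size exactly
if mamba_d_head == "auto":
    mamba_d_head = mamba_intermediate // mamba_n_heads
if mamba_d_head * mamba_n_heads != mamba_intermediate:
    raise ValueError("mamba_d_head * mamba_n_heads must equal mamba_expand * hidden_size")
```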

Comment on lines +34 to +82

```python
def convert_state_dict_from_mamba_ssm(original_sd: Dict) -> Dict[str, torch.Tensor]:
    state_dict = {}

    for orig_k, param in original_sd.items():
        k = orig_k.replace("backbone", "model")

        # for embeddings
        k = k.replace("embedding", "embed_tokens")

        # for mixer
        k = k.replace("mixer", "mamba")

        # for final layernorm
        k = k.replace("norm_f", "final_layernorm")

        # for block layernorm
        k = re.sub(r"(\d+)\.norm\.", r"\1.input_layernorm.", k)
        k = re.sub(r"(\d+)\.norm2\.", r"\1.pre_ff_layernorm.", k)

        # for mlp
        k = k.replace("mlp.fc2", "feed_forward.down_proj")

        if "mlp.fc1" in k:
            param, param2 = torch.chunk(param, 2, dim=0)
            k2 = k.replace("mlp.fc1", "feed_forward.gate_proj")
            state_dict[k2] = param2
            k = k.replace("mlp.fc1", "feed_forward.up_proj")

        if ("in_proj" in k and orig_k.replace("in_proj", "conv1d") in original_sd) or (
            "out_proj" in k and orig_k.replace("out_proj", "conv1d") in original_sd
        ):
            # then this must be a mamba
            pass
        else:
            # for attn
            # - because mixer was replaced to mamba above
            k = k.replace("mamba.out_proj", "self_attn.o_proj")
            if "mamba.in_proj" in k:
                m, n = param.shape
                d = (m - n) // 2
                param, param2, param3 = torch.split(param, [n, d, d], dim=0)
                k2 = k.replace("mamba.in_proj", "self_attn.k_proj")
                state_dict[k2] = param2
                k2 = k.replace("mamba.in_proj", "self_attn.v_proj")
                state_dict[k2] = param3
                k = k.replace("mamba.in_proj", "self_attn.q_proj")

        state_dict[k] = param
```
Comment (Collaborator): we like to have a more explicit dict like this one, but it's not blocking merge!

@ani300 (Contributor) replied on Dec 18, 2024: leaving for a future/follow-up PR
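
For the follow-up, an explicit mapping table in the style the reviewer prefers could look roughly like the sketch below; the table name and helper are hypothetical, and the patterns simply mirror the replacements in the function above (the fused `mlp.fc1` and attention `in_proj` tensors would still need the splitting logic):

```python
import re

# Hypothetical sketch of an explicit key-mapping table for the converter.
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"backbone": "model",
    r"embedding": "embed_tokens",
    r"mixer": "mamba",
    r"norm_f": "final_layernorm",
    r"(\d+)\.norm\.": r"\1.input_layernorm.",
    r"(\d+)\.norm2\.": r"\1.pre_ff_layernorm.",
    r"mlp\.fc2": "feed_forward.down_proj",
}


def rename_key(key: str) -> str:
    # Apply each rename in order; order matters because "mixer" -> "mamba"
    # happens before any attention-specific handling.
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        key = re.sub(pattern, replacement, key)
    return key
```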



```python
# Adapted from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# - handles the case if the rotary embedding is smaller than head_dim
```
Comment (Collaborator): really not sure this is worth changing, as it just adds an assert.

Comment (Collaborator): Otherwise we should use the notation like

```python
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
```

Reply (Contributor): changed to adapt from GLM
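
For context on the "rotary smaller than head_dim" case: only the leading `rotary_dim` channels are rotated and the rest pass through unchanged. A rough, hypothetical sketch of that idea (not the GLM code that was actually adopted):

```python
import torch


def rotate_half(x):
    # Standard Llama-style half rotation.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_partial_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Assumes cos/sin cover only the first rotary_dim channels of head_dim.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)
```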

Comment on lines 249 to 254

```python
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
"""
```
Comment (Collaborator): As a reviewer, but also for any dev who is going to read this code, we need to know what the differences are with Mamba2Mixer. Could you add comments here explaining why we have to redefine everything?

Reply (Contributor): added a list to the comments
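
For reviewers less familiar with the docstring's notation, the standard Mamba selective-scan update it refers to is, in summary (not quoted from this PR):

```latex
\begin{aligned}
\bar{A}_t &= \exp(\Delta_t A), \qquad \bar{B}_t = \Delta_t B_t \\
h_t &= \bar{A}_t \, h_{t-1} + \bar{B}_t \, x_t \\
y_t &= C_t \, h_t + D \, x_t
\end{aligned}
```

Here A and D are input-independent, while ∆_t, B_t, C_t are computed from the input x_t, which is what makes the state space "selective".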

Comment on lines 723 to 740

```python
class BambaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
        super().__init__()

        del self.self_attn

        del self.mlp
        del self.post_attention_layernorm
        self.feed_forward = BambaMLP(config)
        self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.layer_type = layer_type
        if layer_type == "mamba":
            self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
        elif layer_type == "attention":
            self.self_attn = BAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        else:
            raise ValueError("Invalid layer_type")
```
Comment (Collaborator): This is a lot more similar to:

```python
class JambaAttentionDecoderLayer(nn.Module):
    def __init__(self, config: JambaConfig, layer_idx: int):
        super().__init__()
        num_experts = config.layers_num_experts[layer_idx]
        self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config)
        self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
```

which should be a better base 😉

Reply (Contributor): done in latest commit
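
For reference, a rough sketch of the Jamba-style structure applied to Bamba; this is hypothetical (the merged class may differ), and `BambaRMSNorm`, `BambaMLP`, `BambaMixer` and `BAMBA_ATTENTION_CLASSES` are the names from the hunk above, defined elsewhere in the module:

```python
import torch.nn as nn


class BambaDecoderLayer(nn.Module):
    # Hypothetical sketch following the JambaAttentionDecoderLayer pattern:
    # build the layer explicitly instead of inheriting LlamaDecoderLayer and
    # deleting its submodules.
    def __init__(self, config, layer_idx: int, layer_type: str = "mamba"):
        super().__init__()
        self.input_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = BambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.feed_forward = BambaMLP(config)

        self.layer_type = layer_type
        if layer_type == "mamba":
            self.mamba = BambaMixer(config=config, layer_idx=layer_idx)
        elif layer_type == "attention":
            self.self_attn = BAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        else:
            raise ValueError(f"Invalid layer_type: {layer_type}")
```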



```python
# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
```
Comment (Collaborator): forward and __init__ should be the same as LlamaForCausalLM, which you should be able to inherit from, and just change `prepare_inputs_for_generation`, as it is the only difference, no?

Reply (Contributor): done in latest commit
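
In modular-transformers terms, the suggestion amounts to something like the sketch below; it is hypothetical and simplified, since the real override also has to thread the hybrid Mamba/attention cache through generation:

```python
# Hypothetical sketch: inherit __init__/forward from LlamaForCausalLM and only
# override prepare_inputs_for_generation.
from transformers.models.llama.modeling_llama import LlamaForCausalLM


class BambaForCausalLM(LlamaForCausalLM):
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # Simplified: the real method would pass the hybrid Mamba/attention
        # cache object through instead of a plain key/value cache.
        return super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, **kwargs
        )
```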

@molbap molbap merged commit 9613933 into huggingface:main Dec 18, 2024
23 checks passed
@ArthurZucker (Collaborator) commented:

Kudos! 🚀

@garrett361 garrett361 mentioned this pull request Jan 24, 2025
Labels: New model, run-slow, State space models

7 participants