[Feature] Support variable-length sequences for mamba block #244
Conversation
Hello @tridao @albertfgu, thanks for the awesome work on mamba; it is really a strong competitor to the transformer! We have noticed some issues (#236, #180) stating a need for training on variable-length sequences, but the corresponding functionality is not available in the current codebase. Also, in real-world scenarios the length distribution of datasets varies a lot, and simply padding every sample to the maximum length wastes computing resources on meaningless padded tokens. So we implemented this PR and hope it helps!
Hello, it's great to see your work on variable-length data. How can I use the method you provided? Is there any difference in results between it and padding?
Thank you for your interest in this PR! Update (2024/03/19):
Thank you for your reply. Due to performance considerations, I would like to use bidirectional mamba. Should I wait for your updated code?
Hi @EricPaul03, @Dmovic has created a unit test for the backward pass of the mamba block with variable-length sequences, and the results show numerical equality for both the forward and backward passes with varlen inputs. I haven't tried it with bidirectional mamba, but since it is numerically equivalent for the default unidirectional mamba, I think you can just give it a try!
To give a simple example: what we originally pass into the mamba block is a padded input, with every sequence in the batch padded to the maximum length. With this PR, multiple variable-length sequences are instead packed into a single sequence, and `cu_seqlens` marks the sequence boundaries. From the figure above, we can clearly see that through this PR the mamba block can focus computing resources on the variable-length sequences themselves and avoid the overhead of meaningless padding tokens. Variable-length training is very useful for improving hardware utilization during training, and the well-known flash attention already supports variable-length training via its varlen API.
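To make the packed layout concrete, here is a minimal plain-PyTorch sketch (an illustration only, not code from this PR) of the same three variable-length sequences in padded form versus packed form with `cu_seqlens`:

```python
import torch
import torch.nn.functional as F

dim = 256
seqs = [torch.randn(L, dim) for L in (5, 2, 7)]        # three sequences of different lengths
max_len = max(s.shape[0] for s in seqs)

# Padded layout: (batch, max_len, dim) -> 3 * 7 = 21 token slots, only 14 hold real tokens.
padded = torch.stack([F.pad(s, (0, 0, 0, max_len - s.shape[0])) for s in seqs])

# Packed layout: (1, total_tokens, dim) plus cumulative boundaries, no wasted slots.
packed = torch.cat(seqs, dim=0).unsqueeze(0)           # shape (1, 14, dim)
cu_seqlens = torch.tensor([0, 5, 7, 14], dtype=torch.int32)

print(padded.shape, packed.shape, cu_seqlens.tolist())
```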
Thank you for your answer. This is great code that I will try to use in my project!
Sorry to bother you again. I would like to implement the same operation for bidirectional mamba. Do I also need to recompute `cu_seqlens` for the flipped sequence when reversing the scan direction, and can the two directions share `d_conv`?

I think I should divide `conv1d_out`, `delta`, etc. into subsequences and reverse each subsequence separately? (Instead of reversing the entire sequence and using the same `cu_seqlens`?)
I copied some methods from MixerModel to make it easier to use this feature.
For bidirectional mamba, you need to pass in the `cu_seqlens` for the flipped sequence as well. For example, given the `cu_seqlens` of the forward direction, we can calculate the boundaries of the reversed packed sequence from it.
I think you might not need to divide these items into subsequences. All you need is to pass in the `cu_seqlens` describing the boundaries of the flipped packed sequence. For combining the benefits of bidirectional mamba and this PR's variable-length sequences, I drew my graphical understanding here: the mechanism can simply be viewed as the hidden states being reset at the sequence boundaries of both scan directions.
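To make this concrete, a small sketch of that understanding (an assumption based on this discussion, not code from this PR): flip the whole packed sequence once along the token dimension and derive the boundaries of the flipped layout from the original `cu_seqlens`.

```python
import torch

def flip_cu_seqlens(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Reversing the packed sequence reverses the order of the subsequences,
    # so the boundaries become total_len minus the reversed offsets.
    total_len = cu_seqlens[-1]
    return total_len - cu_seqlens.flip(0)

cu_seqlens = torch.tensor([0, 2, 6, 10], dtype=torch.int32)   # lengths 2, 4, 4
packed = torch.randn(1, 10, 256)

packed_rev = packed.flip(dims=[1])                # reverse the whole packed sequence
cu_seqlens_rev = flip_cu_seqlens(cu_seqlens)      # tensor([ 0,  4,  8, 10], dtype=torch.int32)
```

The backward-direction scan would then run on `packed_rev` with `cu_seqlens_rev`, and its output would be flipped back before being combined with the forward-direction output.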
Update from 2024/07/22:
python tests/ops/test_mamba_cu_seqlens_equivalence.py
Generate random cu_seqlens = [0, 116, 155, 349, 479, 674, 864, 881, 1024]
max diff for output in varlen_mamba fwd pass: 4.470348358154297e-08
mean diff for output in varlen_mamba fwd pass: 5.5261386577853955e-09
max diff for A_log in varlen_mamba bwd pass: 6.239861249923706e-08
mean diff for A_log in varlen_mamba bwd pass: 5.321690865756068e-10
max diff for D in varlen_mamba bwd pass: 6.318092346191406e-06
mean diff for D in varlen_mamba bwd pass: 6.176169335958548e-07
max diff for in_proj.weight in varlen_mamba bwd pass: 1.9073486328125e-05
mean diff for in_proj.weight in varlen_mamba bwd pass: 1.098805341825937e-06
max diff for conv1d.weight in varlen_mamba bwd pass: 5.662441253662109e-06
mean diff for conv1d.weight in varlen_mamba bwd pass: 8.699786349097849e-07
max diff for conv1d.bias in varlen_mamba bwd pass: 1.0013580322265625e-05
mean diff for conv1d.bias in varlen_mamba bwd pass: 1.4602501323679462e-06
max diff for x_proj.weight in varlen_mamba bwd pass: 3.6954879760742188e-06
mean diff for x_proj.weight in varlen_mamba bwd pass: 2.984411295869904e-08
max diff for dt_proj.weight in varlen_mamba bwd pass: 8.731149137020111e-09
mean diff for dt_proj.weight in varlen_mamba bwd pass: 3.4094516099258954e-10
max diff for dt_proj.bias in varlen_mamba bwd pass: 2.60770320892334e-08
mean diff for dt_proj.bias in varlen_mamba bwd pass: 2.458180992093162e-09
max diff for out_proj.weight in varlen_mamba bwd pass: 5.7220458984375e-06
mean diff for out_proj.weight in varlen_mamba bwd pass: 2.629302855439164e-07

pytest tests/
============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.3.1, pluggy-1.5.0
rootdir: /dev/varlen_mamba
plugins: typeguard-3.0.2
collected 392 items
tests/ops/test_selective_scan.py .................... [ 5%]
tests/ops/test_selective_scan_var_len.py ............ [ 8%]
tests/ops/triton/test_layernorm_gated.py .s.s.s.s.s.s.....s.s.s.s.s.s... [ 16%]
..s.s.s.s.s.s.....s.s.s.s.s.s.... [ 24%]
tests/ops/triton/test_selective_state_update.py ........................ [ 30%]
........................................................................ [ 48%]
........................................................................ [ 67%]
........................................................................ [ 85%]
.............................. [ 93%]
tests/ops/triton/test_ssd.py ........................ [ 99%]
tests/test_generation.py .. [100%]
================= 368 passed, 24 skipped in 256.74s (0:04:16) ==================
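For readers who want to reproduce this kind of check themselves, here is a heavily simplified sketch of the comparison pattern behind the numbers above (not the actual tests/ops/test_mamba_cu_seqlens_equivalence.py; the `cu_seqlens` keyword on the mamba forward call is assumed to be the interface this PR adds):

```python
import copy
import torch
from mamba_ssm.modules.mamba_simple import Mamba

torch.manual_seed(0)
dim, lengths, device = 256, [5, 2, 7], "cuda"

mamba_ref = Mamba(d_model=dim).to(device)
mamba_var = copy.deepcopy(mamba_ref)                 # identical weights on both paths

seqs = [torch.randn(1, L, dim, device=device) for L in lengths]
cu = [0]
for L in lengths:
    cu.append(cu[-1] + L)
cu_seqlens = torch.tensor(cu, dtype=torch.int32, device=device)

# Reference path: run each sequence on its own (no padding tokens involved).
out_ref = torch.cat([mamba_ref(s) for s in seqs], dim=1)
out_ref.sum().backward()

# Varlen path: one packed sequence plus cu_seqlens (assumed PR interface).
packed = torch.cat(seqs, dim=1)
out_var = mamba_var(packed, cu_seqlens=cu_seqlens)
out_var.sum().backward()

print("max diff for output:", (out_var - out_ref).abs().max().item())
for (name, p_ref), (_, p_var) in zip(mamba_ref.named_parameters(), mamba_var.named_parameters()):
    print(f"max diff for {name} grad:", (p_var.grad - p_ref.grad).abs().max().item())
```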
Dear authors @tridao @albertfgu, firstly, thanks for the awesome work on the theoretical analysis and code development of mamba, mamba2, and the other state space models! Currently, many users (#356, #236, #180) expect mamba to natively support variable-length training, just like what flash attention provides through its varlen API. This PR adds that support to the mamba block via `cu_seqlens`. So, could this PR be reviewed and merged as a feature for mamba, if possible? Thanks!
It's great to see that there is already a paper/project (Is Mamba Compatible with Trajectory Optimization in Offline Reinforcement Learning, NeurIPS'24) adopting our code in the area of offline reinforcement learning.
Hi @zigzagcai, thank you for the great work! I tried to install your version but encountered the import error below. The full pipeline I did is the following:

# (optionally) clone causal-conv1d, also tried pip install causal-conv1d==1.4.0
git clone https://github.com/Dao-AILab/causal-conv1d
cd causal-conv1d
git checkout v1.4.0
pip install -e .
cd ..
# clone and checkout your pr
git clone https://github.com/state-spaces/mamba
cd mamba
git fetch origin pull/244/head:pr-244
git checkout pr-244
pip install -e .

I tried installing with pytorch 2.4 and 2.1, and with cuda 12.5 and 12.1. All settings have the same problem:

> python tests/ops/test_mamba_cu_seqlens_equivalence.py
Traceback (most recent call last):
File "/.../mamba/tests/ops/test_mamba_cu_seqlens_equivalence.py", line 5, in <module>
from mamba_ssm.modules.mamba_simple import Mamba
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/__init__.py", line 3, in <module>
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/selective_scan_interface.py", line 16, in <module>
import selective_scan_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/selective_scan_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops10zeros_like4callERKNS_6TensorESt8optionalIN3c1010ScalarTypeEES5_INS6_6LayoutEES5_INS6_6DeviceEES5_IbES5_INS6_12MemoryFormatEE

Additionally, I also found that the installed mamba-ssm package does not list causal-conv1d among its requirements:

> pip show mamba-ssm
Name: mamba_ssm
Version: 2.2.2
Summary: Mamba state-space model
Home-page:
Author:
Author-email: Tri Dao <[email protected]>, Albert Gu <[email protected]>
...
Location: /usr/local/lib/python3.10/dist-packages
Requires: einops, ninja, packaging, setuptools, torch, transformers, triton (causal-conv1d is not here)
Required-by:

If this issue doesn't occur for you, could you provide the install script you are using for the most up-to-date version? Thanks!
Hi @JindongJiang, I share my minimal reproduction steps here.
Hi @JindongJiang, firstly, thanks for your interest in this PR!
Hi @zigzagcai, thank you very much for the help. Interestingly, deleting the
Besides the pytorch and cuda versions, I used the same setup as you suggested:
I will now try using cuda 11.8 as well and will let you know if I get the same problem.
Hi @zigzagcai, I am back with the cuda 11.8 results; the problem still exists. This time I am (almost) fully following your setup script:
The only difference is that I have to do
Complete results and env:
It is actually quite surprising that the big discrepancies only happen at the beginning and the end (in_proj and out_proj). Could you provide some comments on this? Thanks!
Hi @JindongJiang, the error below is caused by the
I just reverted the recent merge commit in 0a15f1d. Could you please re-try my branch? I just re-tested the code in my environment and it works fine.
The test results
FYI, my local environment (including cuda version and pip packages):
BTW @JindongJiang, which GPU model are you using: A100, H100, or something else? That way I can better understand your software and hardware environment.
Hi @zigzagcai, thank you very much for the updates and new commits. I will test the new setup. I got the above results using an A100.
Hi @zigzagcai, it seems that the grad discrepancy only exists when I use a docker image on slurm. I have two ways to run the experiments:

Thank you for your help again! I think the problem is not in the implementation then. I will use conda without docker for now.
Very glad to see it is helpful to you! You are right. I guess there might be some conflicts when you install packages in that docker environment.
Hi @zigzagcai, here is how I install dependencies, which might be useful for those working with CUDA 12.5:

conda create -n your_env_name python=3.10.13
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
pip install -r requirements.txt
git clone [email protected]:hustvl/Vim.git
pip install -e causal-conv1d>=1.1.0
pip install -e mamba-1p1p1
pip install --upgrade huggingface-hub==0.24.0

I made a slight adjustment to your example, and here is the revised version (the original formatting was lost, so only the outline is shown):

from collections import Counter
from itertools import chain
import torch

sentences = [...]
word_counter = Counter(chain(*[sentence.lower().split() for sentence in sentences]))

def variable_length_sequences(new_tensor): ...
def unpack(packed_hidden_states, cu_seqlens): ...
def pack(hidden_states, cu_seqlens): ...

hidden_dim = 256
new_tensor_reeshaped_index = variable_length_sequences(padded_sequences)
out_ref = mamba(hidden_states)

I noticed that when processing 4 sentences, you receive embeddings for only 3 sentences (torch.Size([3, 6, 256])). It might be helpful to append last_index + 1 to the list in your variable_length_sequences function (i.e., start_indexes.append(last_index + 1)). This adjustment should ensure that the number of output sentences matches the number of input sentences (torch.Size([4, 6, 256])). I am now receiving embeddings with a shape of torch.Size([4, 6, 256]). However, one of my sentences contains only three words. Should I apply masking to the returned sequences to remove embeddings that might not be meaningful? Thanks!
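To illustrate the boundary-index fix being discussed (the commenter's `variable_length_sequences` helper is not shown in full above, so the helper below is only a hypothetical reconstruction of the relevant part):

```python
import torch

def build_cu_seqlens(lengths):
    """Hypothetical helper: cumulative boundaries for packing sequences of the given lengths.

    The final boundary (the total token count, i.e. `last_index + 1` in the discussion
    above) must be appended; otherwise the last sequence is dropped when unpacking and
    only n - 1 sequences come back out.
    """
    cu = [0]
    for length in lengths:
        cu.append(cu[-1] + length)
    return torch.tensor(cu, dtype=torch.int32)

lengths = [6, 6, 6, 3]                      # e.g. four sentences, one with only three words
cu_seqlens = build_cu_seqlens(lengths)      # tensor([ 0,  6, 12, 18, 21], dtype=torch.int32)

# After unpacking back to a padded (4, max_len, dim) tensor, positions beyond each
# sentence's true length are padding, so a mask like this can drop them before pooling or loss.
mask = torch.arange(max(lengths))[None, :] < torch.tensor(lengths)[:, None]
print(cu_seqlens.tolist())
print(mask)
```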
Support variable-length sequences for the mamba block via `cu_seqlens` in the `forward` pass and `backward` pass, similar to what has been done in the flash attention `varlen_fwd`/`varlen_bwd` API (which takes cumulative sequence lengths `cu_seqlens`, or equivalently a lower-triangular block-diagonal `attention mask`). We have tested that training with variable-length sequences on real-world datasets can bring a 2~4x speedup.
Why do we need it?
High speedup and hardware utilization on the real-world datasets we tested. It can be used to improve hardware utilization when you have variable-length sequences and don't want to waste computing resources on meaningless padded tokens. It is especially useful for mamba training on real-world datasets, where the length distribution varies widely and a large proportion of samples are short sequences. Last but not least, we ensure exact fwd/bwd numerical equality with the padding approach.
How to use?
Zero learning overhead: the packed mamba API is similar to the packed flash-attn API or the packed mamba2 API. Just pack multiple variable-length sequences into one sequence and additionally pass `cu_seqlens` into the mamba `forward` pass, as sketched below.
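A minimal usage sketch (the `cu_seqlens` keyword below is meant to illustrate the interface; see tests/ops/test_mamba_cu_seqlens_equivalence.py for the exact usage):

```python
import torch
from mamba_ssm.modules.mamba_simple import Mamba

dim = 256
mamba = Mamba(d_model=dim).cuda()

# Pack three variable-length sequences into one and record their boundaries.
seqs = [torch.randn(1, L, dim, device="cuda") for L in (5, 2, 7)]
packed = torch.cat(seqs, dim=1)                                    # (1, 14, dim)
cu_seqlens = torch.tensor([0, 5, 7, 14], dtype=torch.int32, device="cuda")

out = mamba(packed, cu_seqlens=cu_seqlens)                         # (1, 14, dim)

# Slice the packed output back into per-sequence outputs if needed.
bounds = cu_seqlens.tolist()
outs = [out[:, s:e] for s, e in zip(bounds[:-1], bounds[1:])]
```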
Note:
We thank @wang-zerui for the python reference implementation and invaluable discussion on how to ensure numerical equality.
This is joint work with @wang-zerui, @Dmovic, and @ptxu78.
Some related issues about mamba and flash-attn variable-length training: