-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add feature / fix bug: I fixed the kNRows
feature in forward
#161
Conversation
@MzeroMiko could you please share you benchmark numbers and platform? I see that computation slows down quite a bit when using nrows>1. Am I missing something? These are the times I get on A100 80GB GPU My benchmark code can be found here Thanks! |
Thank you very much, @apoorv2904.
|
Thank you for sharing this splendid work!
I found that
kNRows
is always1
inoriginal selective_scan
, and I observed that if I use greaterkNRows
inselective scan
, the faster the code would run. The phenomenon is consistent with mamba.py, when addingd_state
, the time consumption keeps. Though it is not strictly right, but adding the burden of one thread and reducing the number of blocks (as SM is limited) really works in most of cases.So I reopen that feature which may be deprecated in
original selective_scan
, and fixed some bugs related to it.I have tested with
pytest tests/ops/test_selective_scan_.py
(which you may delete later), and all tests pass.Note that I have only fixed the forward procedure, so in backward,
nrows
is still1
.Before Merging: I found that, when I uncomment all alternative parameters, the test is not all pass. However,
mamba_ssm-1.1.3.post1+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
acts the same.