[triton] Support head_dim not 2^n in triton extend and decode attention #1281
Conversation
Nice work!
Unit test is failing. Let me take a look.
I have followed the hardcode way by adding
Also related to #1109.
Looks like flakiness in the 2-GPU test @zhyncs.
Latest main is OK; main has been merged now.
I tested DeepSeek-V2-Lite and Meta-Llama-3-8B on A100 and see slight performance degradation, @ByronHsu could you help verify?
[DeepSeek-V2-Lite benchmark results]
[Meta-Llama-3-8B benchmark results]
@ByronHsu We will release v0.3 soon; this PR will be reviewed and merged after v0.3, thank you for your understanding. Our main optimization and daily CI environment is H100. Can you reproduce it on H100? Thanks.
Sounds good! I will try to see if I can get some on-demand H100 to test.
Do we want to incorporate this into v0.3, given the on-par performance on H100? I can later work on supporting dim=80 (#1109). The code iterates so fast that I am afraid there will be many conflicts soon.
I’ll take a look.
@zhyncs Can we merge this to unblock phi3 support? |
Wait for me to check the DeepSeek V2 perf. |
How is the testing going?
My computer is currently being repaired. I expect to pick it up tomorrow afternoon. 😂 |
Wish all the best for your computer XD |
I've got my new computer and I'll review it asap. |
@zhyncs if you run "main" multiple times, it might also show variance. Maybe we can kick off 5 runs for both and calculate the mean?
From my point of view, I am also willing to interpret it as a fluctuation; a 2% fluctuation is acceptable in my opinion, and it currently only affects the Triton runtime. With that being said, we can go ahead and support Phi. I will assist with merging that PR as soon as possible; can you take a look at the reasons for the current CI failures?
@ByronHsu Thanks for the contribution. It is merged. Can you follow up with these items?
OK, created an issue to track this for myself: #1359. Please assign it to me.
Motivation
In #1159, users found that head_dim = 96 is not supported in the native Triton kernels.
Modifications
Originally, we enforced head_dim to be 2^n (there is one special case, 576, where it is split into 512 and 64, both still 2^n). This is because we don't handle padding for the non-2^n case on head_dim's BLOCK; to be specific, `tl.arange(0, BLOCK_DMODEL)` errors out because it only accepts 2^n. The solution is to pad the block size on head_dim's BLOCK to the next power of 2 and intersect the original mask with a head_dim mask.
By doing so, I was able to run a model whose head_dim is not 2^n. (The answer was correct, but it contained some redundant output; not sure if it is because of the model or the kernel.)
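For readers unfamiliar with the trick, here is a minimal standalone sketch of the padding idea (not this PR's actual extend/decode kernels; the kernel and wrapper names are made up for illustration): the head-dim block is rounded up to the next power of 2 so `tl.arange` accepts it, and a mask restricts loads/stores to the real head_dim lanes.

```python
# Minimal sketch (hypothetical kernel, not from this PR): pad the head-dim
# block to the next power of 2 and mask out the padded lanes.
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_head_dim_kernel(src_ptr, dst_ptr, head_dim, BLOCK_DMODEL: tl.constexpr):
    row = tl.program_id(0)
    # tl.arange requires a power-of-2 extent, so BLOCK_DMODEL is the padded
    # size (e.g. 128 for head_dim = 96); mask_d disables the padded lanes.
    offs_d = tl.arange(0, BLOCK_DMODEL)
    mask_d = offs_d < head_dim
    vals = tl.load(src_ptr + row * head_dim + offs_d, mask=mask_d, other=0.0)
    tl.store(dst_ptr + row * head_dim + offs_d, vals, mask=mask_d)


def copy_rows(x: torch.Tensor) -> torch.Tensor:
    # x: contiguous [num_rows, head_dim] tensor on GPU; head_dim may be
    # a non power of 2 such as 96.
    num_rows, head_dim = x.shape
    out = torch.empty_like(x)
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)  # 96 -> 128
    _copy_head_dim_kernel[(num_rows,)](x, out, head_dim, BLOCK_DMODEL=BLOCK_DMODEL)
    return out
```

In the real attention kernels, a mask like `mask_d` would be combined (via `&`) with the existing sequence-length masks on Q/K/V loads and stores, which is what "intersect the original mask with a head_dim mask" refers to above.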
Discussion
head_dim = 576 case.
Checklist
cc @merrymercy @Ying1123 @zhyncs to take a look. Thanks!