optimize MiniMax-Text-01 lightning_attn_decode triton #2966

Merged
merged 1 commit into main from optimize_lighting_attention_decode_triton on Jan 18, 2025

Conversation

BBuf (Collaborator) commented Jan 18, 2025

main branch:

lightning-attention-decode-performance:
    batch_size  seq_len  Original PyTorch Implementation  Triton Implementation
0          1.0      1.0                       384.128004             312.096000
1          2.0      1.0                       373.088002             311.919987
2          3.0      1.0                       374.751985             310.431987
3          4.0      1.0                       375.216007             314.303994
4          5.0      1.0                       372.031987             314.911991
5          6.0      1.0                       372.287989             312.703997
6          7.0      1.0                       373.600006             315.104008
7          8.0      1.0                       372.191995             311.648011
8          9.0      1.0                       373.504013             312.848002
9         10.0      1.0                       374.944001             312.000006
10        11.0      1.0                       373.120010             312.096000
11        12.0      1.0                       371.616006             315.775990
12        13.0      1.0                       371.807992             314.112008
13        14.0      1.0                       371.071994             314.080000
14        15.0      1.0                       371.583998             324.992001
15        16.0      1.0                       372.447997             338.496000
16        17.0      1.0                       368.319988             346.047997
17        18.0      1.0                       365.520000             352.800012
18        19.0      1.0                       365.935981             362.496018
19        20.0      1.0                       367.152005             370.368004
20        21.0      1.0                       369.120002             377.088010
21        22.0      1.0                       378.975987             387.071997
22        23.0      1.0                       388.736010             395.359993
23        24.0      1.0                       396.256000             402.079999
24        25.0      1.0                       404.736012             410.896003
25        26.0      1.0                       411.872000             420.320004
26        27.0      1.0                       420.704007             428.880006
27        28.0      1.0                       428.063989             437.103987
28        29.0      1.0                       436.576009             441.280007
29        30.0      1.0                       446.720004             450.031996
30        31.0      1.0                       453.967988             458.559990
31        32.0      1.0                       460.927993             466.399997

PR branch:

lightning-attention-decode-performance:
    batch_size  seq_len  Original PyTorch Implementation  Triton Implementation
0          1.0      1.0                       379.135996             297.376007
1          2.0      1.0                       370.447993             297.248006
2          3.0      1.0                       376.271993             297.280014
3          4.0      1.0                       377.391994             297.695994
4          5.0      1.0                       379.007995             297.695994
5          6.0      1.0                       371.087998             298.720002
6          7.0      1.0                       371.183991             298.272014
7          8.0      1.0                       372.864008             300.576001
8          9.0      1.0                       373.695999             300.448000
9         10.0      1.0                       373.439997             298.287988
10        11.0      1.0                       373.663992             298.848003
11        12.0      1.0                       372.031987             298.224002
12        13.0      1.0                       372.943997             300.096005
13        14.0      1.0                       373.344004             299.216002
14        15.0      1.0                       370.528013             301.472008
15        16.0      1.0                       372.736007             300.927997
16        17.0      1.0                       364.672005             295.264006
17        18.0      1.0                       367.967993             296.095997
18        19.0      1.0                       366.560012             295.583993
19        20.0      1.0                       365.920007             290.399998
20        21.0      1.0                       370.496005             291.440010
21        22.0      1.0                       373.279989             295.759976
22        23.0      1.0                       380.735993             294.784009
23        24.0      1.0                       392.623991             295.967996
24        25.0      1.0                       402.000010             293.152004
25        26.0      1.0                       408.960015             294.559985
26        27.0      1.0                       419.072002             293.056011
27        28.0      1.0                       425.024003             293.504000
28        29.0      1.0                       434.816003             294.719994
29        30.0      1.0                       444.768012             297.280014
30        31.0      1.0                       453.664005             305.375993
31        32.0      1.0                       460.272014             307.711989

By removing the explicit padding in the lightning_attn_decode Triton kernel and instead computing directly with a mask inside the kernel, the padding overhead is significantly reduced. This results in a 30%-50% end-to-end speedup compared to the original PyTorch MiniMax-Text-01 decoding implementation.
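For reference, here is a minimal sketch of the masking idea described above (a hypothetical illustration, not the kernel from this PR; the tensor layout, the per-head decay formula, and all names are assumptions). Instead of padding q/k/v to a power-of-two head dim on the host and launching over the padded size, the kernel rounds the head dim up to BLOCK_D internally and masks the out-of-range lanes in every tl.load/tl.store:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _decode_step_kernel(
    q_ptr, k_ptr, v_ptr, kv_ptr, out_ptr, slope_ptr,
    H: tl.constexpr,        # number of heads
    D: tl.constexpr,        # true head dim (need not be a power of two)
    BLOCK_D: tl.constexpr,  # next power of two >= D
):
    pid = tl.program_id(0)              # one program per (batch, head)
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D                     # mask replaces explicit host-side padding

    q = tl.load(q_ptr + pid * D + offs, mask=mask, other=0.0)
    k = tl.load(k_ptr + pid * D + offs, mask=mask, other=0.0)
    v = tl.load(v_ptr + pid * D + offs, mask=mask, other=0.0)

    # per-head decay rate (assumed exp(-slope) for a single decode step)
    decay = tl.exp(-tl.load(slope_ptr + pid % H))

    kv_offs = pid * D * D + offs[:, None] * D + offs[None, :]
    kv_mask = mask[:, None] & mask[None, :]
    kv = tl.load(kv_ptr + kv_offs, mask=kv_mask, other=0.0)

    # recurrent state update: kv <- decay * kv + k^T v, then o = q @ kv
    kv = decay * kv + k[:, None] * v[None, :]
    o = tl.sum(q[:, None] * kv, axis=0)

    tl.store(kv_ptr + kv_offs, kv, mask=kv_mask)
    tl.store(out_ptr + pid * D + offs, o, mask=mask)


def lightning_decode_step(q, k, v, kv, slopes):
    # q, k, v: [B, H, 1, D] contiguous; kv: [B, H, D, D]; slopes: [H]
    B, H, _, D = q.shape
    out = torch.empty_like(q)
    _decode_step_kernel[(B * H,)](
        q, k, v, kv, out, slopes,
        H=H, D=D, BLOCK_D=triton.next_power_of_2(D),
    )
    return out, kv
```

The point of the change is that no padded copies of q/k/v or the kv state are materialized; the mask is applied only at load/store time.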

@zhyncs

@BBuf BBuf requested a review from zhyncs January 18, 2025 15:22
@zhyncs zhyncs requested a review from ispobock January 18, 2025 15:24
@zhyncs zhyncs merged commit c2f212d into main Jan 18, 2025
3 checks passed
@zhyncs zhyncs deleted the optimize_lighting_attention_decode_triton branch January 18, 2025 15:41
yzhangcs commented Jan 18, 2025

@BBuf Hi, it looks like BLOCK_SIZE does not actually work in this kernel?

BBuf (Collaborator, Author) commented Jan 19, 2025

@BBuf Hi, it looks like BLOCK_SIZE does not actually work in this kernel?

Yeah, I'll fix it, thanks.
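(Editorial aside, a hypothetical example rather than code from this repo: a tl.constexpr argument such as BLOCK_SIZE only has an effect if the kernel body actually uses it, for instance to shape the tile it loads; a declared but unused constexpr is silently ignored, which is presumably the issue raised above.)

```python
import triton
import triton.language as tl


@triton.jit
def _copy_times_two(x_ptr, y_ptr, n, BLOCK_SIZE: tl.constexpr):
    # BLOCK_SIZE matters only because it shapes the tile below;
    # an argument that never appears in the body does nothing.
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, x * 2.0, mask=mask)
```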
