
[triton] Support head_dim not 2^n in triton extend and decode attention #1281

Merged · 13 commits merged into sgl-project:main on Sep 9, 2024

Conversation

@ByronHsu (Collaborator) commented Sep 1, 2024

Motivation

In #1159, users found that head_dim = 96 is not supported by the native Triton kernel.

Modifications

Originally, we enforced head_dim to be 2^n (with one special case, 576, which is split into 512 and 64, both still 2^n). This is because we did not handle padding for the non-2^n case on head_dim's block; specifically, tl.arange(0, BLOCK_DMODEL) errors out because it only accepts powers of two.

The solution is to pad the head_dim block size (BLOCK_DMODEL) up to the next power of two and intersect the original mask with a head_dim mask.
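
For illustration, here is a minimal standalone Triton kernel showing the same idea (this is a sketch, not the actual sglang extend/decode attention kernel; the kernel and helper names are made up):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_head_dim_kernel(X, Y, head_dim, stride_x, stride_y, BLOCK_DMODEL: tl.constexpr):
    pid = tl.program_id(0)
    offs_d = tl.arange(0, BLOCK_DMODEL)  # tl.arange only accepts a power-of-two size
    mask_d = offs_d < head_dim           # masks off the padded tail, e.g. lanes 96..127
    x = tl.load(X + pid * stride_x + offs_d, mask=mask_d, other=0.0)
    tl.store(Y + pid * stride_y + offs_d, x, mask=mask_d)

def copy_rows(x: torch.Tensor) -> torch.Tensor:
    n_rows, head_dim = x.shape
    y = torch.empty_like(x)
    # pad the compile-time block size up to the next power of two, e.g. 96 -> 128
    BLOCK_DMODEL = triton.next_power_of_2(head_dim)
    _copy_head_dim_kernel[(n_rows,)](
        x, y, head_dim, x.stride(0), y.stride(0), BLOCK_DMODEL=BLOCK_DMODEL
    )
    return y

# e.g. copy_rows(torch.randn(8, 96, device="cuda")) works with head_dim = 96
```

In the attention kernels, this head_dim mask is intersected (AND-ed) with the existing masks on loads and stores, as described above.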

By doing so, I was able to:

  1. Passed the extend attention test with head_dim = 96 (I added a new test):
python extend_attention.py 
Mean:  tensor(8.3447e-07, device='cuda:0', dtype=torch.float16)
Max:  tensor(0.0001, device='cuda:0', dtype=torch.float16)
Mean:  tensor(8.3447e-07, device='cuda:0', dtype=torch.float16)
Max:  tensor(0.0001, device='cuda:0', dtype=torch.float16)
  2. Hosted the model successfully and ran the example code.

[screenshots of the example run]
(The answer was correct, but it contained some redundant output; not sure whether this is due to the model or the kernel.)

  3. Confirmed the original 2^n case still works fine. The answer from gemma-2b-it is very legit.

[screenshot of the gemma-2b-it output]

Discussion

  1. Is there any specific reason we don't handle the non-2^n case?
  2. I only see an in-file test for extend_attention and cannot find a test for decode attention. Any pointers appreciated!
  3. Should we remove the assertion now that we handle the non-2^n case? I have little context on the head_dim = 576 case.

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

cc @merrymercy @Ying1123 @zhyncs to take a look. Thanks!

@ByronHsu marked this pull request as ready for review September 1, 2024 05:35
@ByronHsu changed the title from "Support head_dim not 2^n in triton extend and decode attention" to "[bug fix] Support head_dim not 2^n in triton extend and decode attention" Sep 1, 2024
@zhyncs (Member) commented Sep 1, 2024

Nice work!

@ByronHsu (Collaborator, Author) commented Sep 1, 2024

The unit test is failing. Let me take a look.

@ByronHsu (Collaborator, Author) commented Sep 1, 2024

I have followed the hardcoded approach by adding 96 to the allowed set. However, I believe that with this change we can support all non-2^n cases too; we can follow up in a later PR.

@ByronHsu (Collaborator, Author) commented Sep 1, 2024

Also related to #1109.

@ByronHsu (Collaborator, Author) commented Sep 1, 2024

Looks like flakiness in the 2-GPU test @zhyncs

@zhyncs (Member) commented Sep 1, 2024

> Looks like flakiness in the 2-GPU test @zhyncs

The latest main is OK; main has been merged now.

@ByronHsu changed the title from "[bug fix] Support head_dim not 2^n in triton extend and decode attention" to "[triton] Support head_dim not 2^n in triton extend and decode attention" Sep 1, 2024
@ispobock (Collaborator) commented Sep 1, 2024

I tested DeepSeek-V2-Lite and Meta-Llama-3-8B on A100 and saw slight performance degradation. @ByronHsu could you help verify?

DeepSeek-V2-Lite
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --trust-remote-code --disable-radix-cache --tp=1 --enable-mla
python3 -m sglang.bench_serving --backend sglang --num-prompts 3000

# main branch
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     3000
Benchmark duration (s):                  155.49
Total input tokens:                      714456
Total generated tokens:                  656556
Total generated tokens (retokenized):    655790
Request throughput (req/s):              19.29
Input token throughput (tok/s):          4594.85
Output token throughput (tok/s):         4222.48
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   66964.88
Median E2E Latency (ms):                 62149.36
---------------Time to First Token----------------
Mean TTFT (ms):                          21425.47
Median TTFT (ms):                        21079.74
P99 TTFT (ms):                           41214.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          761.30
Median TPOT (ms):                        265.87
P99 TPOT (ms):                           6491.08
---------------Inter-token Latency----------------
Mean ITL (ms):                           251.98
Median ITL (ms):                         114.43
P99 ITL (ms):                            758.35
==================================================

# this PR
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     3000
Benchmark duration (s):                  159.45
Total input tokens:                      714456
Total generated tokens:                  656556
Total generated tokens (retokenized):    655794
Request throughput (req/s):              18.81
Input token throughput (tok/s):          4480.81
Output token throughput (tok/s):         4117.68
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   70939.11
Median E2E Latency (ms):                 66236.11
---------------Time to First Token----------------
Mean TTFT (ms):                          24700.33
Median TTFT (ms):                        24615.69
P99 TTFT (ms):                           44581.79
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          793.96
Median TPOT (ms):                        274.87
P99 TPOT (ms):                           7126.74
---------------Inter-token Latency----------------
Mean ITL (ms):                           255.57
Median ITL (ms):                         112.59
P99 ITL (ms):                            781.51
==================================================
Meta-Llama-3-8B
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B --trust-remote-code --disable-radix-cache --tp=1 --disable-flashinfer
python3 -m sglang.bench_serving --backend sglang --num-prompts 3000

# main branch
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     3000
Benchmark duration (s):                  150.40
Total input tokens:                      658644
Total generated tokens:                  607312
Total generated tokens (retokenized):    601739
Request throughput (req/s):              19.95
Input token throughput (tok/s):          4379.19
Output token throughput (tok/s):         4037.89
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   76772.80
Median E2E Latency (ms):                 77663.50
---------------Time to First Token----------------
Mean TTFT (ms):                          39820.40
Median TTFT (ms):                        28166.97
P99 TTFT (ms):                           99590.99
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          247.05
Median TPOT (ms):                        210.01
P99 TPOT (ms):                           1289.23
---------------Inter-token Latency----------------
Mean ITL (ms):                           188.23
Median ITL (ms):                         142.75
P99 ITL (ms):                            694.74
==================================================

# this PR
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     3000
Benchmark duration (s):                  154.41
Total input tokens:                      658644
Total generated tokens:                  607312
Total generated tokens (retokenized):    601909
Request throughput (req/s):              19.43
Input token throughput (tok/s):          4265.45
Output token throughput (tok/s):         3933.02
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   77446.45
Median E2E Latency (ms):                 78719.17
---------------Time to First Token----------------
Mean TTFT (ms):                          40720.12
Median TTFT (ms):                        29141.32
P99 TTFT (ms):                           100828.28
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          244.30
Median TPOT (ms):                        207.76
P99 TPOT (ms):                           1168.11
---------------Inter-token Latency----------------
Mean ITL (ms):                           189.12
Median ITL (ms):                         136.21
P99 ITL (ms):                            653.50
==================================================

@ByronHsu (Collaborator, Author) commented Sep 1, 2024

I tested Llama 3 on L4, and the PR seems slightly faster, hmm. DeepSeek OOM'ed on L4.
(left: main, right: PR)
[benchmark screenshots]

@zhyncs (Member) commented Sep 1, 2024

@ByronHsu We will soon release v0.3; this PR will be reviewed and merged after v0.3, thank you for your understanding. Our main optimization and daily CI environment is H100. Can you reproduce the benchmark on H100? Thanks.

@ByronHsu (Collaborator, Author) commented Sep 1, 2024

Sounds good! I will see if I can get an on-demand H100 to test.

@ByronHsu (Collaborator, Author) commented Sep 1, 2024

Benchmark

left: main, right: this PR

TL;DR: I tested on both A100 and H100 using Lambda Cloud, and it looks like there is no notable slowdown (I do observe some randomness in the benchmark, though, which may be what caused the difference above).

cc @merrymercy @zhyncs

A100-SXM4-40GB

python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B --trust-remote-code --disable-radix-cache --tp=1 --disable-flashinfer
python3 -m sglang.bench_serving --backend sglang --num-prompts 1000

Llama3-8B does not show an apparent regression.

[benchmark screenshots]
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --trust-remote-code --disable-radix-cache --tp=1 --enable-mla
python3 -m sglang.bench_serving --backend sglang --num-prompts 1000

DeepSeek-V2-Lite does not show an apparent regression either.

[benchmark screenshots]

H100-PCIE-80GB

llama
[benchmark screenshots]

deepseek
[benchmark screenshots]

@ByronHsu (Collaborator, Author) commented Sep 1, 2024

Do we want to incorporate this into v0.3, given the on-par performance on H100? I can later work on supporting dim=80 (#1109). The code is iterating so fast that I am afraid there may be many conflicts soon.

@zhyncs (Member) commented Sep 2, 2024

I’ll take a look.

@merrymercy (Contributor) commented:

@zhyncs Can we merge this to unblock phi3 support?

@zhyncs (Member) commented Sep 3, 2024

Wait for me to check the DeepSeek V2 perf.

@ByronHsu (Collaborator, Author) commented Sep 7, 2024

How is the testing going?

@zhyncs (Member) commented Sep 7, 2024

My computer is currently being repaired. I expect to pick it up tomorrow afternoon. 😂

@ByronHsu (Collaborator, Author) commented Sep 7, 2024

Wish all the best for your computer XD

@zhyncs (Member) commented Sep 8, 2024

I've got my new computer and I'll review it asap.

@zhyncs (Member) commented Sep 8, 2024

I verified on this branch: with 1000 prompts the results are basically consistent on H100, and at 5000 prompts throughput decreases by about 2%. @ispobock @ByronHsu

@zhyncs (Member) commented Sep 8, 2024

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V2-Lite --trust-remote-code --disable-radix-cache --tp=1 --enable-mla
python3 -m sglang.bench_serving --backend sglang --num-prompts 5000
# main
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     5000
Benchmark duration (s):                  140.72
Total input tokens:                      1224620
Total generated tokens:                  1061203
Total generated tokens (retokenized):    1059740
Request throughput (req/s):              35.53
Input token throughput (tok/s):          8702.71
Output token throughput (tok/s):         7541.40
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   66816.46
Median E2E Latency (ms):                 63138.32
---------------Time to First Token----------------
Mean TTFT (ms):                          24583.26
Median TTFT (ms):                        22696.98
P99 TTFT (ms):                           49841.68
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          589.90
Median TPOT (ms):                        278.05
P99 TPOT (ms):                           4141.21
---------------Inter-token Latency----------------
Mean ITL (ms):                           261.19
Median ITL (ms):                         112.11
P99 ITL (ms):                            1234.07
==================================================

# pr round 1
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     5000
Benchmark duration (s):                  145.43
Total input tokens:                      1224620
Total generated tokens:                  1061203
Total generated tokens (retokenized):    1059822
Request throughput (req/s):              34.38
Input token throughput (tok/s):          8420.75
Output token throughput (tok/s):         7297.06
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   68165.95
Median E2E Latency (ms):                 64177.48
---------------Time to First Token----------------
Mean TTFT (ms):                          24173.78
Median TTFT (ms):                        23413.88
P99 TTFT (ms):                           48337.06
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          593.05
Median TPOT (ms):                        287.36
P99 TPOT (ms):                           3859.84
---------------Inter-token Latency----------------
Mean ITL (ms):                           270.71
Median ITL (ms):                         116.18
P99 ITL (ms):                            1228.87
==================================================

# pr round 2
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     5000
Benchmark duration (s):                  142.15
Total input tokens:                      1224620
Total generated tokens:                  1061203
Total generated tokens (retokenized):    1059756
Request throughput (req/s):              35.17
Input token throughput (tok/s):          8615.12
Output token throughput (tok/s):         7465.49
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   67235.38
Median E2E Latency (ms):                 63448.25
---------------Time to First Token----------------
Mean TTFT (ms):                          24270.60
Median TTFT (ms):                        23483.24
P99 TTFT (ms):                           48095.19
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          594.59
Median TPOT (ms):                        282.76
P99 TPOT (ms):                           3877.19
---------------Inter-token Latency----------------
Mean ITL (ms):                           270.61
Median ITL (ms):                         110.07
P99 ITL (ms):                            1260.49
==================================================

@ByronHsu (Collaborator, Author) commented Sep 8, 2024

@zhyncs if you run "main" multiple times, it might also show variance. Maybe we can kick off 5 runs for both and calculate the mean?
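
A rough sketch of how several rounds could be averaged (the command and the parsing of the bench_serving output are illustrative, not an existing script):

```python
import re
import statistics
import subprocess

# Run the serving benchmark several times and average the request throughput.
CMD = ["python3", "-m", "sglang.bench_serving", "--backend", "sglang", "--num-prompts", "5000"]

throughputs = []
for _ in range(5):
    out = subprocess.run(CMD, capture_output=True, text=True).stdout
    match = re.search(r"Request throughput \(req/s\):\s+([\d.]+)", out)
    if match:
        throughputs.append(float(match.group(1)))

print("mean request throughput (req/s):", statistics.mean(throughputs))
```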

@zhyncs (Member) commented Sep 8, 2024

From my point of view, I am also willing to interpret it as fluctuation, and a 2% fluctuation is acceptable in my opinion, since it currently only affects the Triton runtime. With that said, we can go ahead and support Phi; I will assist with merging that PR as soon as possible. Could you take a look at the reasons for the currently failing CI runs?

@zhyncs enabled auto-merge (squash) September 8, 2024 21:19
@zhyncs merged commit 8e6bdf8 into sgl-project:main Sep 9, 2024
9 checks passed
@merrymercy (Contributor) commented Sep 9, 2024

@ByronHsu Thanks for the contribution. It is merged. Can you follow up with these items?

  1. Remove the assertion
  2. Set up some unit test cases following the guide here: https://github.com/sgl-project/sglang/tree/main/test#test-backend-runtime. It can be a file called sglang/test/srt/test_triton_attention_kernels.py (see the sketch below).
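
A rough sketch of what such a test could look like; the Triton kernel entry point and its signature are placeholders here, not the actual sglang API:

```python
import unittest
import torch

def reference_attention(q, k, v):
    # naive O(n^2) attention used as a correctness oracle
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v.float()).to(q.dtype)

class TestTritonAttentionHeadDim(unittest.TestCase):
    def test_extend_attention_head_dim_96(self):
        torch.manual_seed(0)
        q = torch.randn(128, 8, 96, dtype=torch.float16, device="cuda")
        k = torch.randn(128, 8, 96, dtype=torch.float16, device="cuda")
        v = torch.randn(128, 8, 96, dtype=torch.float16, device="cuda")
        expected = reference_attention(q, k, v)
        # Placeholder: call the real Triton kernel here and compare, e.g.
        # out = extend_attention_fwd(q, k, v, ...)
        # torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)
        self.assertEqual(expected.shape, q.shape)  # sanity check on the reference itself

if __name__ == "__main__":
    unittest.main()
```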

@ByronHsu (Collaborator, Author) commented Sep 9, 2024

OK, I created an issue to track this for myself: #1359. Please assign it to me.
