
Faster overlap mode scheduler #1738

Merged: merrymercy merged 4 commits into main from multi-stream on Oct 21, 2024

Conversation

merrymercy (Contributor) commented Oct 21, 2024

This PR improves the order of kernel launches and result fetching. The overlap scheduler now brings a 10% throughput improvement even when the radix cache is turned off; when the radix cache is turned on, we can expect a larger speedup.
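For readers skimming this PR, a minimal sketch of the overlap pattern is shown below. It is not the actual sglang implementation (the real code lives in tp_worker_overlap_thread.py); the queue names, the model callable, and the pinned-memory caveat are simplifying assumptions. The only point is that a forward thread launches GPU work for the current batch and hands the device-to-host copy to a side stream, so the scheduler can build the next batch before fetching the previous results.

    # Illustrative sketch only -- not the actual sglang code. It assumes a CUDA
    # device and a model(batch) callable that launches the forward pass and
    # sampling, returning next-token ids as a GPU tensor.
    import queue
    import threading

    import torch

    input_queue: queue.Queue = queue.Queue()   # batches pushed by the scheduler
    output_queue: queue.Queue = queue.Queue()  # (copy_event, host_tensor) pairs

    copy_stream = torch.cuda.Stream()          # side stream for D2H copies


    def forward_thread_func(model):
        """Runs on a background thread; launches GPU work for each batch."""
        while True:
            batch = input_queue.get()
            if batch is None:
                break
            next_token_ids = model(batch)  # async kernel launches on the default stream
            # Do the device-to-host copy on a side stream so it does not delay
            # the kernel launches of the next batch. (A real implementation would
            # copy into pinned host memory to make the copy truly asynchronous.)
            copy_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(copy_stream):
                host_ids = next_token_ids.to("cpu", non_blocking=True)
                copy_event = torch.cuda.Event()
                copy_event.record(copy_stream)
            output_queue.put((copy_event, host_ids))


    def start_forward_worker(model) -> threading.Thread:
        t = threading.Thread(target=forward_thread_func, args=(model,), daemon=True)
        t.start()
        return t


    def resolve_last_batch():
        """Called by the scheduler one step later, after the next batch is enqueued."""
        copy_event, host_ids = output_queue.get()
        copy_event.synchronize()  # wait only for the copy, not for unrelated kernels
        return host_ids.tolist()

In this arrangement the scheduler calls resolve_last_batch only after it has already enqueued the next batch, which is what lets the kernel launches and the result fetch overlap.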

Benchmark results

Overlap mode: 51.03 req/s

python -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --disable-radix --enable-overlap
python -m sglang.bench_serving --model meta-llama/Llama-3.1-8B-Instruct --num-prompt 3000
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     3000
Benchmark duration (s):                  58.79
Total input tokens:                      673672
Total generated tokens:                  581627
Total generated tokens (retokenized):    581405
Request throughput (req/s):              51.03
Input token throughput (tok/s):          11459.26
Output token throughput (tok/s):         9893.56
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   28986.97
Median E2E Latency (ms):                 29088.28
---------------Time to First Token----------------
Mean TTFT (ms):                          14495.13
Median TTFT (ms):                        11312.61
P99 TTFT (ms):                           36408.59
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          144.25
Median TPOT (ms):                        86.74
P99 TPOT (ms):                           1081.64
---------------Inter-token Latency----------------
Mean ITL (ms):                           78.78
Median ITL (ms):                         32.48
P99 ITL (ms):                            529.30
==================================================

Normal mode: 46.06 req/s

python -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --disable-radix 
python -m sglang.bench_serving --model meta-llama/Llama-3.1-8B-Instruct --num-prompt 3000
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Successful requests:                     3000
Benchmark duration (s):                  65.14
Total input tokens:                      673672
Total generated tokens:                  581627
Total generated tokens (retokenized):    581402
Request throughput (req/s):              46.06
Input token throughput (tok/s):          10342.28
Output token throughput (tok/s):         8929.19
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   31574.46
Median E2E Latency (ms):                 31581.12
---------------Time to First Token----------------
Mean TTFT (ms):                          15352.12
Median TTFT (ms):                        11615.68
P99 TTFT (ms):                           39444.51
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          157.51
Median TPOT (ms):                        96.38
P99 TPOT (ms):                           1131.20
---------------Inter-token Latency----------------
Mean ITL (ms):                           87.11
Median ITL (ms):                         37.10
P99 ITL (ms):                            554.28
==================================================

Notes

  1. We still only use multi-threading, so we remain limited by the GIL. We can expect a larger improvement if we move to multi-processing or if the GIL can be turned off (a rough sketch of the multi-processing direction follows after these notes).
  2. The overlap scheduler is an experimental feature. I verified its accuracy on GSM-8k, and it matches that of the normal scheduler. It works for standard decoding, but it does not support sampling penalizers (e.g., frequency and repetition penalties) or constrained decoding (e.g., regex, JSON).
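
To make the multi-processing remark in note 1 concrete, here is a purely hypothetical sketch; nothing below exists in sglang, and load_model / run_forward are made-up placeholders. The point is only that a separate process owns its own interpreter and GIL, so the scheduler no longer competes with the forward loop for the same lock.

    # Hypothetical sketch of the multi-processing direction mentioned in note 1.
    # load_model() and run_forward() are placeholders, not real sglang functions.
    import multiprocessing as mp


    def load_model():
        # Placeholder: stands in for building the real model inside the worker.
        return lambda batch: [0] * len(batch)


    def run_forward(model, batch):
        # Placeholder: stands in for the forward pass + sampling.
        return model(batch)


    def forward_worker(input_q, output_q):
        model = load_model()
        while True:
            batch = input_q.get()
            if batch is None:          # sentinel to shut the worker down
                break
            output_q.put(run_forward(model, batch))


    if __name__ == "__main__":
        input_q, output_q = mp.Queue(), mp.Queue()
        worker = mp.Process(target=forward_worker, args=(input_q, output_q), daemon=True)
        worker.start()
        # The parent process keeps scheduling new batches while the worker runs,
        # without sharing a GIL with it.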

merrymercy changed the title from "Launch a copy thread for overlapped scheduler" to "Faster overlap mode scheduler" on Oct 21, 2024
merrymercy merged commit 7ce3606 into main on Oct 21, 2024 (9 of 10 checks passed)
merrymercy deleted the multi-stream branch on October 21, 2024 at 11:30
merrymercy mentioned this pull request on Oct 23, 2024
fengyang95 commented:

@merrymercy Has this been tested on larger models? I tried the deepseek-v2.5 fp8 version, but it doesn't seem to show much improvement.


ykcombat commented Dec 1, 2024

@merrymercy Have you ever tested the overlap mode scheduler when requests arrive at a certain request rate, rather than all being sent at the beginning?
When I test it without specifying a request rate, everything works fine:
python -m sglang.bench_serving --backend sglang --num-prompt 10
But when I specify a request rate, so that requests are sent following a Poisson distribution:
python -m sglang.bench_serving --backend sglang --num-prompt 10 --request-rate 2
I encounter a mysterious bug:
CUDA Error: device-side assert triggered (710) /tmp/build-via-sdist-d34cpfe8/flashinfer-0.1.6+cu121torch2.4/include/flashinfer/attention/decode.cuh: line 749 at function cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
Exception in thread Thread-3 (forward_thread_func):
Traceback (most recent call last):
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ykchen/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 99, in forward_thread_func
    self.forward_thread_func_()
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ykchen/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 116, in forward_thread_func_
    logits_output, next_token_ids = self.worker.forward_batch_generation(
  File "/home/ykchen/sglang/python/sglang/srt/managers/tp_worker.py", line 139, in forward_batch_generation
    logits_output = self.model_runner.forward(forward_batch)
  File "/home/ykchen/sglang/python/sglang/srt/model_executor/model_runner.py", line 594, in forward
    return self.forward_decode(forward_batch)
  File "/home/ykchen/sglang/python/sglang/srt/model_executor/model_runner.py", line 565, in forward_decode
    return self.model.forward(
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/ykchen/sglang/python/sglang/srt/models/llama.py", line 371, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ykchen/sglang/python/sglang/srt/models/llama.py", line 284, in forward
    hidden_states, residual = layer(
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ykchen/sglang/python/sglang/srt/models/llama.py", line 234, in forward
    hidden_states = self.self_attn(
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ykchen/sglang/python/sglang/srt/models/llama.py", line 171, in forward
    attn_output = self.attn(q, k, v, forward_batch)
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ykchen/sglang/python/sglang/srt/layers/radix_attention.py", line 60, in forward
    return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
  File "/home/ykchen/sglang/python/sglang/srt/layers/attention/__init__.py", line 58, in forward
    return self.forward_decode(q, k, v, layer, forward_batch)
  File "/home/ykchen/sglang/python/sglang/srt/layers/attention/flashinfer_backend.py", line 284, in forward_decode
    o = decode_wrapper.forward(
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/flashinfer/decode.py", line 589, in forward
    return self.run(
  File "/state/partition/ykchen/conda/envs/sglang/lib/python3.10/site-packages/flashinfer/decode.py", line 673, in run
    out = self._wrapper.run(
RuntimeError: BatchDecodeWithPagedKVCache failed with error device-side assert triggered
This script works with the normal scheduler. Did I make a mistake, or is this a bug in the overlap mode scheduler?

merrymercy (Contributor, Author) commented:

@ykcombat Did you try it with the latest main branch? If the error is still there, please open a new issue with reproducible instructions. We will fix it very soon if we can reproduce that.


ykcombat commented Dec 2, 2024

> @ykcombat Did you try it with the latest main branch? If the error is still there, please open a new issue with reproducible instructions. We will fix it very soon if we can reproduce that.

@merrymercy Thanks for your quick reply! I tried it with the latest main branch but it seems that the error is still there. I have opened a new issue at #2312.
