
[V1][Perf] Reduce scheduling overhead in model runner after cuda sync #12094

Merged · 6 commits into vllm-project:main · Jan 26, 2025

Conversation

@youngkent (Contributor) commented on Jan 15, 2025

The model runner performs some CPU bookkeeping after each decoding iteration. Part of that bookkeeping can be done in parallel while waiting on the CUDA sync, so that after the sync only simple and fast updates remain.

The change should reduce scheduling overhead between decode iterations by ~20% (see the attached GPU traces).

Before the optimization: (GPU trace screenshot, 2025-01-15 10:39 AM)

After the optimization: (GPU trace screenshot, 2025-01-15 10:40 AM)
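The idea, as a minimal sketch (names like req_states and its fields are illustrative stand-ins, not vLLM's actual runner code): kick off a non-blocking device-to-host copy of the sampled tokens, do the token-independent bookkeeping while the copy is in flight, and sync as late as possible so only cheap updates remain afterwards.

    import torch

    def finish_decode_step(sampled_gpu: torch.Tensor, req_states: list) -> list:
        # Start an async GPU -> CPU copy into pinned memory; the CPU keeps
        # working while the copy (and any queued kernels) complete.
        host_buf = torch.empty(sampled_gpu.shape, dtype=sampled_gpu.dtype,
                               pin_memory=True)
        host_buf.copy_(sampled_gpu, non_blocking=True)

        # Bookkeeping that does not depend on the sampled token values can
        # overlap with the in-flight copy ("num_computed_tokens" is a
        # hypothetical field).
        for state in req_states:
            state["num_computed_tokens"] += 1

        # GPU -> CPU sync happens here, as late as possible; afterwards only
        # a cheap list conversion and fast per-request appends remain.
        torch.cuda.synchronize()
        token_ids = host_buf.tolist()
        for state, tok in zip(req_states, token_ids):
            state["output_token_ids"].append(tok)
        return token_ids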

E2E latency benchmark, run with:

VLLM_USE_V1=1 python3 benchmarks/benchmark_latency.py --model "/data/users/ktong/llama/llm_8b_oss" --tensor-parallel-size 1 --input_len 1000 --batch_size 32

Output (1-2% E2E latency reduction):
Avg latency: 2.338167402730323 seconds
10% percentile latency: 2.3207896508742123 seconds
25% percentile latency: 2.3264574960339814 seconds
50% percentile latency: 2.3333765944698825 seconds
75% percentile latency: 2.343035737867467 seconds
90% percentile latency: 2.3567665563430635 seconds
99% percentile latency: 2.3934816433605737 seconds


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, a small and essential subset of tests that quickly catches errors. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@@ -8,7 +8,7 @@
class SamplerOutput:

    # [num_reqs]
    sampled_token_ids: List[int]
Member commented:

Is this necessary? IIRC, @tlrmchlsmth used List[int] because it is cheaper to serialize, which benefits the tensor-parallel case, where we need to pass the tokens across processes.

Collaborator commented:

This is true — I didn’t look at how it impacts the non-TP case though

@robertgshaw2-redhat (Collaborator) commented on Jan 25, 2025:

The ModelRunnerOutput is what we serialize for TP; we don't serialize the SamplerOutput directly, so this is not a concern.

@tlrmchlsmth (Collaborator) commented on Jan 25, 2025:

Ah, yep that's right -- I did change this line in #9856, but that was just downstream of changing sampled_token_ids to a List in the ModelRunnerOutput. This looks good to me since that's left as-is!
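For anyone curious about the serialization-cost point, here is a quick way to check it (a sketch using pickle as a stand-in serializer; vLLM's actual IPC path may differ):

    import pickle
    import timeit

    import torch

    toks = list(range(1024))   # sampled token ids as a plain Python list
    tens = torch.tensor(toks)  # the same ids as a tensor

    t_list = timeit.timeit(lambda: pickle.dumps(toks), number=1000)
    t_tens = timeit.timeit(lambda: pickle.dumps(tens), number=1000)
    print(f"list: {t_list:.4f}s  tensor: {t_tens:.4f}s")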

@robertgshaw2-redhat (Collaborator) commented:

Wow, great idea. I'm going to run some performance analysis on this tomorrow.

vllm/v1/sample/sampler.py: review thread (outdated, resolved)
Signed-off-by: Keyun Tong <[email protected]>
@WoosukKwon (Collaborator) left a comment:

LGTM! Thanks for discovering and fixing this!

assert all(req_id is not None for req_id in
           self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])

# NOTE: GPU -> CPU Sync happens here.
Collaborator commented:

Just for the record: if top-p or top-k sampling is used (with the FlashInfer kernel), CPU-GPU synchronization happens inside the sampler at

# NOTE: CPU-GPU synchronization happens here.
if not success.all():

Collaborator commented:

Do you think we can avoid this in a follow-up PR?

Collaborator commented:

I don't think we can. This is a fundamental limitation of the kernel (or of the algorithm itself): rejection sampling cannot guarantee success 100% of the time.
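To illustrate the synchronization this thread refers to (a sketch, not vLLM's actual sampler): the success mask lives on the GPU, and Python's if needs a concrete boolean, so evaluating the condition forces a GPU -> CPU transfer.

    import torch

    # `success` stands in for the per-request success mask returned by a
    # rejection-sampling kernel (hypothetical values here).
    success = torch.ones(32, dtype=torch.bool, device="cuda")

    # .all() produces a 0-dim GPU tensor; converting it to a Python bool
    # for the `if` blocks until the queued GPU work has finished.
    if not success.all():
        # Fall back to another sampling path for the failed rows.
        pass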

@WoosukKwon added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jan 25, 2025

# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
sampled_token_ids = sampler_output.sampled_token_ids.tolist()
@robertgshaw2-redhat (Collaborator) commented on Jan 25, 2025:

It might be faster to do `sampler_output.sampled_token_ids.cpu()` and then `sampler_output.sampled_token_ids[i].item()` in the inner loop.

Collaborator commented:

In my experience, `item()` takes considerable time, so it should be avoided.
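A quick micro-benchmark of the two options discussed above (a sketch, not vLLM code), comparing one bulk `.tolist()` against a per-element `.item()` loop on a CPU tensor:

    import timeit

    import torch

    x = torch.arange(1024)  # stand-in for a CPU copy of sampled_token_ids

    bulk = timeit.timeit(lambda: x.tolist(), number=1000)
    per_elem = timeit.timeit(lambda: [x[i].item() for i in range(x.numel())],
                             number=1000)
    print(f".tolist(): {bulk:.3f}s  per-element .item(): {per_elem:.3f}s")
    # Each .item() call crosses the Python/C++ boundary separately, which is
    # why the loop tends to be much slower than one bulk conversion.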

@WoosukKwon merged commit fa63e71 into vllm-project:main on Jan 26, 2025 (42 of 44 checks passed)
@WoosukKwon (Collaborator) commented:
@youngkent Thanks for the PR! This change helps vLLM's performance noticeably.

tjtanaa added a commit to EmbeddedLLM/vllm that referenced this pull request on Jan 27, 2025
tjtanaa pushed a commit to EmbeddedLLM/vllm that referenced this pull request Jan 28, 2025
rasmith pushed a commit to rasmith/vllm that referenced this pull request Jan 30, 2025
Isotr0py pushed a commit to Isotr0py/vllm that referenced this pull request Feb 2, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
Projects: none yet
8 participants