
[BUG] fix crash on flashinfer backend with cudagraph disabled, when attention group_size not in [1,2,4,8] #7509

Merged (3 commits) on Aug 21, 2024

Conversation

@learninmou (Contributor) commented on Aug 14, 2024

When I use the FlashInfer backend with CUDA graph disabled and load a model whose attention group_size is 6, vLLM crashes with the following error:
(screenshot: screenshot-20240814-161315)

This error consistently occurs when all of the following hold (a minimal sketch of the relevant decode-wrapper setup follows this list):

  1. The user explicitly selects the FlashInfer attention backend (env VLLM_ATTENTION_BACKEND="FLASHINFER").
  2. The model uses GQA and its group size is not in [1, 2, 4, 8]. (I've reviewed FlashInfer's source: its non-tensor-core decode kernels only support group sizes 1, 2, 4, and 8; for any other group size the user must set use_tensor_cores=True for decode.)
  3. CUDA graph is disabled, which happens when the user sets enforce_eager=True or when the decoding batch_size exceeds _BATCH_SIZES_TO_CAPTURE[-1] (a variable in model_runner.py, currently 256).
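
For context, here is a minimal sketch of how the decode wrapper could be selected based on the group size. It assumes FlashInfer's BatchDecodeWithPagedKVCacheWrapper and its use_tensor_cores flag; the head counts and workspace size are hypothetical, and this illustrates the idea behind the fix rather than the exact patch.

    import torch
    import flashinfer

    # Hypothetical GQA configuration with group_size = 48 // 8 = 6.
    num_qo_heads, num_kv_heads = 48, 8
    group_size = num_qo_heads // num_kv_heads

    # FlashInfer's non-tensor-core decode kernels only cover group sizes 1, 2, 4, 8;
    # any other group size (e.g. 6) needs the tensor-core decode path.
    use_tensor_cores = group_size not in (1, 2, 4, 8)

    # Workspace buffer for the wrapper (128 MB, as in common FlashInfer examples).
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
    )

Without selecting use_tensor_cores this way, a group size such as 6 falls outside the compiled non-tensor-core kernels and decode crashes, which is exactly the failure described above.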

(Automated bot comment)

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build on the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI, as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@learninmou changed the title from "fix crash on flashinfer backend with cudagraph disabled, when attention group_size not in [1,2,4,8]" to "[BUG] fix crash on flashinfer backend with cudagraph disabled, when attention group_size not in [1,2,4,8]" on Aug 14, 2024
@comaniac (Collaborator) left a comment:

LGTM. Can you add a unit test?

Two inline review comments on vllm/worker/model_runner.py (outdated; both resolved).
@learninmou (Contributor, Author), replying to "LGTM. Can you add a unit test?":

Added a unit test case in the last commit; it covers this bug when BatchDecodeWithPagedKVCacheWrapper is not set up properly.

@learninmou (Contributor, Author) commented:

/ready

@github-actions bot added the ready label ("ONLY add when PR is ready to merge/full CI is needed") on Aug 16, 2024
@comaniac enabled auto-merge (squash) on August 16, 2024 at 02:27
@JaheimLee commented:

When will this PR be merged?

@comaniac (Collaborator) commented:

@learninmou could you rebase to see if the CI failure is already fixed on the main branch?

Auto-merge was automatically disabled on August 21, 2024 at 03:44 (the head branch was pushed to by a user without write access).

@learninmou (Contributor, Author), replying to the rebase request:

Rebase finished; all checks have passed.

@comaniac merged commit 53328d7 into vllm-project:main on Aug 21, 2024
49 checks passed
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Aug 21, 2024
@elfiegg (Contributor) commented on Aug 26, 2024:

Hi, we recently discovered that this PR caused an 11.4% perf regression:
Before: 5233.12 tokens/s
After: 4697.55 tokens/s

Tested on H200 machines with the command below:
python benchmarks/benchmark_throughput.py --model=neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV \
    --num-prompts=1024 --output-len=256 --input-len=256 --quantization=fp8 \
    --max-num-batched-tokens=8192 --max-num-seqs=1024

Have we done a perf test on this PR? If this is the case, can we revert it until further investigations are completed?

@comaniac (Collaborator) commented:

Hmm, this PR shouldn't introduce a performance regression to existing workloads, as it attempts to fix previously uncovered cases. Is your vLLM benchmark before and after exactly this PR, with the same environment? If so, is it possible for you to identify the root cause? It could be one of the following:

  1. The use of use_tensor_cores in the FlashInfer kernel.
  2. The overhead of the logic checking the group size.

Also, which model and GPU are you benchmarking? Thanks.

omrishiv pushed a commit to omrishiv/vllm that referenced this pull request Aug 26, 2024
@elfiegg (Contributor) commented on Aug 26, 2024:

I'm not exactly sure why this would cause a perf regression yet, but I reran the benchmark and it shows about an 18% slowdown:

Before: Throughput: 10.56 requests/s, 5409.15 tokens/s
After: Throughput: 8.91 requests/s, 4563.67 tokens/s

The benchmark was performed with Llama3-70B on H200.
Command:
python benchmarks/benchmark_throughput.py --model=neuralmagic/Meta-Llama-3-70B-Instruct-FP8 \
    --num-prompts=768 --output-len=256 --input-len=256 --quantization=fp8 --kv-cache-dtype=fp8 \
    --gpu-memory-utilization=0.98 --max-num-batched-tokens=2048 --max-num-seqs=768 --max_model_len=512

Is it possible to revert it for now until further investigation?

@comaniac (Collaborator) commented:

Ok, I probably know the reason. Before this PR:

use_tensor_cores = num_qo_heads // num_kv_heads >= 4

After this PR:

use_tensor_cores = (num_qo_heads // num_kv_heads) not in (1, 2, 4, 8)

  • Llama-3-8B: qo_heads (32) // kv_heads (8) = 4. Before this PR we set use_tensor_cores=True, but after this PR it is set to False.
  • Llama-3-70B: qo_heads (64) // kv_heads (8) = 8. Before this PR we set use_tensor_cores=True, but after this PR it is set to False.

I'll file a PR to fix this, and you could test that PR to see if it helps.
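
For illustration, here is how the two predicates above evaluate for the models mentioned; a small sketch built from the expressions quoted in this comment, not the actual vLLM code.

    def use_tensor_cores_before(num_qo_heads: int, num_kv_heads: int) -> bool:
        # Condition used before this PR.
        return num_qo_heads // num_kv_heads >= 4

    def use_tensor_cores_after(num_qo_heads: int, num_kv_heads: int) -> bool:
        # Condition introduced by this PR.
        return (num_qo_heads // num_kv_heads) not in (1, 2, 4, 8)

    # Llama-3-8B: group size 32 // 8 = 4.
    print(use_tensor_cores_before(32, 8), use_tensor_cores_after(32, 8))   # True False
    # Llama-3-70B: group size 64 // 8 = 8.
    print(use_tensor_cores_before(64, 8), use_tensor_cores_after(64, 8))   # True False

So both benchmarked models silently switched from the tensor-core decode path to the non-tensor-core path, which would explain the throughput drop.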

@yzh119 commented on Aug 26, 2024:

Sorry, I just noticed this PR; this fix will degrade performance. The reason FlashInfer compiles kernels with use_tensor_cores=False for group sizes in (1, 2, 4, 8) is not related to performance: those are simply the most popular group sizes, and I keep group_size 8 under use_tensor_cores=False mainly for correctness testing. We should always enable tensor cores for large group sizes (more specifically, when group_size > 4). I should make this clearer in the FlashInfer documentation.
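
Following this guidance, a follow-up condition would enable tensor cores for large group sizes while still covering the unsupported sizes that caused the original crash. The sketch below is an assumption of what that combined policy could look like, not the actual follow-up patch.

    def should_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
        group_size = num_qo_heads // num_kv_heads
        # Enable tensor cores when the group size is large (> 4), and also when it
        # falls outside FlashInfer's non-tensor-core set (1, 2, 4, 8), i.e. the
        # case this PR originally fixed. Only group sizes 1, 2, and 4 keep the
        # non-tensor-core decode kernels.
        return group_size > 4 or group_size not in (1, 2, 4, 8)

Under this rule, group sizes 1, 2, and 4 (e.g. Llama-3-8B) stay on the non-tensor-core kernels, while group size 8 (e.g. Llama-3-70B) and unsupported sizes like 6 use tensor cores.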

@comaniac comaniac mentioned this pull request Aug 26, 2024
@yzh119 commented on Aug 26, 2024:

It would be great if you could set me as a reviewer for pull requests related to FlashInfer; I always miss mentions because of the huge amount of notifications.

comaniac added a commit that referenced this pull request Aug 27, 2024
triple-Mu pushed a commit to triple-Mu/vllm_official that referenced this pull request Sep 4, 2024
Jeffwan pushed a commit to aibrix/vllm that referenced this pull request Sep 19, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024