Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support #11844
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Force-pushed the branch from 82b5a4c to 4c4a33e.
I see that you have
Force-pushed the branch from 4c4a33e to 6b7c49e.
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed the branch from 6b7c49e to 35aac26.
Force-pushed the branch from 35aac26 to 91d5476.
All conflicts fixed, could you please take another look? Thanks!
```python
st] = decode_metadata.block_tables[i, st:ed]
decode_metadata.block_tables_intra = block_tables_intra

seq_lens_succ = (chunk_num_curr -
```
When I run the needle-in-a-haystack test with Qwen-7B and Llama-8B (I modified the code to support Llama), there is a bug that produces a negative number once the context length exceeds roughly 13k–15k tokens.
I modified the code as below and confirmed that it works:

```python
seq_lens_succ = ((chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len)
```
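For reference, a standalone check that the clipped expression stays non-negative. This is only a minimal sketch: `chunk_len` and the sequence lengths are made-up values, and the derivation of `chunk_num_curr` here is an assumption for illustration, not the PR's exact code.

```python
import torch

chunk_len = 8192  # assumed chunk length, for illustration only
cache_seq_lens = torch.tensor([2048, 13312, 15360, 65536])

# assumed derivation: number of completed chunks per cached sequence
chunk_num_curr = (cache_seq_lens - 1) // chunk_len

# clipped form: 0 when no chunk has completed yet, chunk_len otherwise
seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len
assert (seq_lens_succ >= 0).all()
print(seq_lens_succ)  # tensor([   0, 8192, 8192, 8192])
```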
This pull request has merge conflicts that must be resolved before it can be merged.
I thought this was fixed so I tested it again, but I still hit the same problem as below.
Force-pushed the branch from 91d5476 to c8781cd.
Dual chunk attention doesn't support CUDA graphs, and I have added an assertion in
It is indeed a bug introduced while preparing this PR; fixed. Thanks!
Force-pushed the branch from c8781cd to 8648b1e.
Rebased against main. Hi @youkaichao @simon-mo @WoosukKwon, do you folks think there is still anything that needs to be improved in this pull request? Thanks!
Spotted a few bits of commented-out code that look like debug cruft or are otherwise mysterious. Could you clean those up, along with any other similar spots?
This pull request has merge conflicts that must be resolved before it can be merged.
```python
qc_freqs = torch.einsum("i,j -> ij", qc_t, inv_freq)
k_freqs = torch.einsum("i,j -> ij", k_t, inv_freq)
qc_no_clamp_freqs = torch.einsum("i,j -> ij", qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.einsum("i,j -> ij", q_inter_t, inv_freq)
```
nit: I think these einsums are still slower on CUDA than `(a * b).sum(-1)`; not on the hot path though, so not critical.
I ran bench_einsum.py from that issue on an H100 and got:

```
$ python einsum_bench.py
[------------------------------------- -------------------------------------]
                                   |  mul/sum  |  torch.einsum  |  numpy.einsum
1 threads: -------------------------------------------------------------------
  Nc,Nc->N cpu  (1048576, 2)       |    5000   |      3100      |     4000
  Nc,Nc->N cuda (1048576, 2)       |      20   |       747      |     3300

Times are in microseconds (us).
```
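For the `"i,j -> ij"` outer products in this diff specifically, the non-einsum form would be a broadcasted multiply rather than a `sum`. A quick sketch with made-up tensor sizes, just to show the equivalence:

```python
import torch

t = torch.arange(4096, dtype=torch.float32)  # stand-in for qc_t / k_t / ...
inv_freq = torch.rand(64)

freqs_einsum = torch.einsum("i,j -> ij", t, inv_freq)
freqs_broadcast = t[:, None] * inv_freq[None, :]  # same result, no einsum

torch.testing.assert_close(freqs_einsum, freqs_broadcast)
```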
```python
logits_soft_cap, attn_type, **{
    "dual_chunk_attention_config": dual_chunk_attention_config,
    "prefix": prefix,
} if dual_chunk_attention_config is not None else {})
```
I feel like this is messy; I think we should maybe do something like:

```python
def __init__(..., **extra_attn_kwargs):
    self.impl = impl_cls(..., **extra_attn_kwargs)
```

The challenge here is that `prefix` would not be captured by `extra_attn_kwargs` but is only (currently) used by `DualChunkFlashAttentionImpl`. I do think it would be less messy to do this and make `prefix` a standard arg for attention impls, given that it is pretty generic. Thoughts @WoosukKwon?
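Roughly what I have in mind, as a self-contained sketch; `_DummyImpl` stands in for the real impl class, and the names and signatures here are illustrative, not vLLM's actual ones:

```python
import torch


class _DummyImpl:
    """Stand-in for a backend impl class such as DualChunkFlashAttentionImpl."""

    def __init__(self, num_heads: int, head_size: int, scale: float,
                 prefix: str = "", **extra_impl_kwargs):
        self.prefix = prefix
        self.extra_impl_kwargs = extra_impl_kwargs


class Attention(torch.nn.Module):
    """Sketch of the proposed layer signature, not the real vllm/attention/layer.py."""

    def __init__(self, num_heads: int, head_size: int, scale: float,
                 prefix: str = "", **extra_attn_kwargs):
        super().__init__()
        # prefix becomes a standard argument for every impl; anything
        # backend-specific is forwarded untouched via **extra_attn_kwargs
        self.impl = _DummyImpl(num_heads, head_size, scale,
                               prefix=prefix, **extra_attn_kwargs)


layer = Attention(8, 128, 0.088, prefix="model.layers.0.self_attn",
                  dual_chunk_attention_config={"chunk_size": 8192})
print(layer.impl.extra_impl_kwargs)  # {'dual_chunk_attention_config': {...}}
```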
vllm/attention/layer.py
```python
if self.dual_chunk_attention_config:
    assert query_succ_and_inter is not None
    dca_kwargs = {
        "query_succ": query_succ_and_inter[0],
        "query_inter": query_succ_and_inter[1],
        "query_succ_critical": query_succ_and_inter[2],
        "query_inter_critical": query_succ_and_inter[3],
    } if query_succ_and_inter else {}
else:
    dca_kwargs = {}
```
I think we should try hard to see if there is a cleaner way of passing these; maybe they can be bundled into a single `q` tensor that gets reinterpreted as its components via a combination of slicing and `.view` calls in the attn impl?
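Something along these lines, say; a sketch with made-up shapes, where the actual split would live inside the attn impl:

```python
import torch

num_tokens, num_heads, head_size = 16, 8, 128

# caller stacks the four query variants along a new leading dim instead of
# passing them as four separate kwargs
query_succ, query_inter, query_succ_critical, query_inter_critical = (
    torch.randn(num_tokens, num_heads * head_size) for _ in range(4))
q_packed = torch.stack(
    (query_succ, query_inter, query_succ_critical, query_inter_critical), dim=0)

# attn impl recovers the components via .view + indexing
unpacked = q_packed.view(4, num_tokens, num_heads, head_size)
assert torch.equal(unpacked[0].reshape(num_tokens, -1), query_succ)
```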
I will take a stab at it and see if it can be simplified.
Force-pushed the branch from 8648b1e to c7a11ee.
Hi @LucasWilkinson, most of the comments have been addressed, could you please take another look? Thanks! The lint error comes from the prompt text; do you have any suggestions on how I could skip/resolve it?
Force-pushed the branch from c230fb7 to e1b1d0f.
…h sparse attention support. Signed-off-by: Tao He <[email protected]>
Force-pushed the branch from e1b1d0f to 81d8004.
@sighingnow Sorry for the delayed response! I've merged main into your branch, so the pre-commit error should be cleared. I'll mark this PR as ready so at least we can get CI going before @tlrmchlsmth or @LucasWilkinson give their final green light!
Signed-off-by: Roger Wang <[email protected]>
I think we should revisit whether it makes sense to have this 1m context example here. This file would be by far the largest in the vLLM git repo at 4MB, and I think it might be better for the example to pull the data from an external source.
Side note, I thought having the example prompts start by discussing the differences between various narcotics was an interesting choice!
```python
vertical_indices_count: torch.
Tensor,  # [N_HEADS] : different head use different number of indices
```
nit: some weird whitespace here - I wouldn't want us to spend too much effort fighting clang-format but maybe the comment can be moved to its own line
```python
for i in range(batch_size):
    st = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len //
          self.block_size)
    ed = min(
        st + (max_seq_len_succ - 1) // self.block_size + 1,
        (cache_seq_lens[i] - 1) // self.block_size + 1,
    )
    block_tables_succ[i, :ed -
                      st] = decode_metadata.block_tables[i, st:ed]
```
Improved readability:

```python
for i in range(batch_size):
    start = ((chunk_num_curr[i] - 1).clip(min=0) * chunk_len //
             self.block_size)
    end = min(
        start + (max_seq_len_succ - 1) // self.block_size + 1,
        (cache_seq_lens[i] - 1) // self.block_size + 1,
    )
    block_tables_succ[i, :end -
                      start] = decode_metadata.block_tables[i, start:end]
```
```python
if sparse_attn_enabled:
    flash_result = self._do_flash_attn(
        q_states_intra,
        k_states_intra,
        v_states_intra,
        softmax_scale=softmax_scale,
        causal=True,
        stage="intra",
        vertical_indices=vertical_buffer,
        slash_indices=slash_buffer,
        vertical_indices_count=vertical_size_buffer,
        slash_indices_count=slash_sizes_buffer,
        mergehead_softmax_scale=softmax_scale,
        sparse_attn_enabled=sparse_attn_enabled)
else:
    flash_result = self._do_flash_attn(
        q_states_intra,
        k_states_intra,
        v_states_intra,
        softmax_scale=softmax_scale,
        causal=True,
        stage="intra",
        vertical_indices=intra_vertical_indices,
        slash_indices=intra_slash_indices,
        sparse_attn_enabled=sparse_attn_enabled)
```
QQ: is the if/else necessary, since `sparse_attn_enabled` is passed into the function in both cases?
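For illustration, one way the two call sites could collapse into one. This is only a sketch with a stub `_do_flash_attn`, and it assumes the sparse-only kwargs default to `None` in the real impl:

```python
import torch


def _do_flash_attn(q, k, v, *, softmax_scale, causal, stage,
                   vertical_indices=None, slash_indices=None,
                   vertical_indices_count=None, slash_indices_count=None,
                   mergehead_softmax_scale=None, sparse_attn_enabled=False):
    """Stub with an assumed signature; sparse-only kwargs default to None."""
    return q


q = k = v = torch.randn(4, 8, 128)
softmax_scale = 0.088
sparse_attn_enabled = True
vertical_buffer = slash_buffer = torch.zeros(8, 16, dtype=torch.int32)
vertical_size_buffer = slash_sizes_buffer = torch.full((8,), 16, dtype=torch.int32)
intra_vertical_indices = intra_slash_indices = torch.zeros(8, 16, dtype=torch.int32)

# select the index buffers up front, then issue a single call
sparse_kwargs = dict(
    vertical_indices=vertical_buffer,
    slash_indices=slash_buffer,
    vertical_indices_count=vertical_size_buffer,
    slash_indices_count=slash_sizes_buffer,
    mergehead_softmax_scale=softmax_scale,
) if sparse_attn_enabled else dict(
    vertical_indices=intra_vertical_indices,
    slash_indices=intra_slash_indices,
)

flash_result = _do_flash_attn(q, k, v, softmax_scale=softmax_scale,
                              causal=True, stage="intra",
                              sparse_attn_enabled=sparse_attn_enabled,
                              **sparse_kwargs)
```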
Please add SPDX headers to new files (see #12628 for reference)
This pull request has merge conflicts that must be resolved before it can be merged.
A few more review comments, mostly minor stuff. Looks pretty good, although I do suggest getting rid of the qwen_1m example.
```cpp
__device__ void save_blocks(int* block_offset, int64_t range_start,
                            int64_t range_end, int64_t block_size,
                            int64_t& block_count, int64_t kv_seqlen) {
  if (range_start >= kv_seqlen) {
    return;
  }
  if (range_end > kv_seqlen) {
    range_end = kv_seqlen;
  }
  for (int idx = range_start; idx < range_end; idx += block_size) {
    block_offset[block_count++] = idx;
  }
}
```
I think this function would be clearer and more explicit in its behavior if it returned the updated block count instead of modifying its input argument:

```cpp
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
                               int64_t range_end, int64_t block_size,
                               int64_t input_block_count, int64_t kv_seqlen) {
  if (range_start >= kv_seqlen) {
    return input_block_count;
  }
  if (range_end > kv_seqlen) {
    range_end = kv_seqlen;
  }
  int64_t current_block_count = input_block_count;
  for (int idx = range_start; idx < range_end; idx += block_size) {
    block_offset[current_block_count++] = idx;
  }
  return current_block_count;
}
```
Could you add some comments describing what the functions in this file are doing? Comments describing what the blocks of code within `convert_vertical_slash_indexes_kernel` are doing would be helpful as well.
```python
int32_max = 2147483647  # avoid sort
int32_min = -2147483648
```
nit: best to avoid hardcoding constants

```python
int32_max = torch.iinfo(torch.int32).max
int32_min = torch.iinfo(torch.int32).min
```

(Also, could you elaborate on the "avoid sort" comment?)
```python
q=query_states.bfloat16(),
k=key_states.bfloat16(),
v=value_states.bfloat16(),
```
why convert these to bfloat16? I don't think we should be doing this e.g. if the model's dtype is float16
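For example, something like this instead of hardcoding bfloat16. A sketch only: flash-attn kernels generally expect fp16/bf16, so fp32 inputs would still need a downcast, but fp16 models would keep fp16:

```python
import torch

query_states = torch.randn(4, 8, 128, dtype=torch.float16)  # assumed model dtype

# keep the incoming dtype when it is already fp16/bf16; otherwise fall back
if query_states.dtype in (torch.float16, torch.bfloat16):
    attn_dtype = query_states.dtype
else:
    attn_dtype = torch.bfloat16

q = query_states.to(attn_dtype)  # no-op cast when dtypes already match
```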
```python
softmax_scale,
chunk_size,
local_size,
scaling_factor.item(),
```
Suggested change:

```python
scaling_factor[i].item(),
```

Otherwise this will error since `scaling_factor` contains more than one element here.
This PR implements dual-chunk flash attention, a training-free method to extend model context length (see also #6139), with sparse attention support (https://github.com/microsoft/MInference).
This PR requires the sparse attention kernel from vllm-flash-attention. Qwen models with 1M context length support will be open-sourced in the next week or two, and unit tests will be added later.
FIX #12452