[CPU] Remove the limitation that requires to memset zero for KVCache of PagedAttention #28681
Conversation
@@ -2356,6 +2356,28 @@ struct AttentionExecutor : public PagedAttentionExecutor {
            _slot_mapping.ptr<int32_t>()[idx++] =
                block_number * _helper._block_size + block_offset % _helper._block_size;
        }
        // To simplify tails of the kernels for Q*K and W*V:
This is a WA (workaround) for the first-token kernels, which can't correctly support tails (matmul(attn_score, value), to be exact). Why isn't this WA placed near the kernel code?
It would be useful if the zero-padding logic were merged into exec_loop_mixed/pack kv, but if merged, the parallel work items would drop from (head number * zero-padded tokens) to just (head number). So keeping it here is still reasonable, since this is the centralized logic that handles the destination KV cache.
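For illustration, a minimal sketch of what such centralized tail padding could look like. Only `kv_len` and the block size appear in the diff; the buffer layout, the function name, and `bytes_per_token` are hypothetical, not the actual OpenVINO implementation:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Sketch: zero the unused slots of the last (partial) K/V cache block so that
// block-granular first-token kernels can safely read whole blocks.
// `k_tail_block` / `v_tail_block` point at the partial block; `bytes_per_token`
// is head_size * element size. All names here are hypothetical.
void zero_pad_tail_block(uint8_t* k_tail_block,
                         uint8_t* v_tail_block,
                         size_t kv_len,
                         size_t block_size,
                         size_t bytes_per_token) {
    const size_t used = kv_len % block_size;  // valid tokens in the tail block
    if (used == 0)
        return;  // the last block is full, nothing to pad
    const size_t pad = (block_size - used) * bytes_per_token;
    std::memset(k_tail_block + used * bytes_per_token, 0, pad);
    std::memset(v_tail_block + used * bytes_per_token, 0, pad);
}
```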
LGTM!
        // W*V aka [m, k1] * [n1, k1]', there is no tails handling for n1, so tails of v_cache need to be set to
        // zero.
        // for the second token, the kernels have tail-handling logic
        if (q_len != 1 && kv_len % _helper._block_size != 0) {
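To see why zeroing the tail is sufficient for W*V: every padded slot contributes `w[i] * 0` to the accumulator, so a kernel that always iterates over full blocks produces the same output as one with tail handling. A self-contained toy illustration (plain C++ with made-up numbers, not the actual kernel):

```cpp
#include <array>
#include <cstddef>
#include <iostream>

int main() {
    constexpr size_t block_size = 4;  // pretend a block holds 4 tokens
    constexpr size_t kv_len = 3;      // only 3 valid tokens -> 1 padded slot
    // attention weights for one query row, one per token slot in the block
    std::array<float, block_size> w = {0.2f, 0.3f, 0.5f, 0.7f};  // w[3] is garbage
    // one value-cache column; the padded slot v[3] has been memset to zero
    std::array<float, block_size> v = {1.0f, 2.0f, 3.0f, 0.0f};

    float tail_aware = 0.0f, full_block = 0.0f;
    for (size_t i = 0; i < kv_len; ++i)
        tail_aware += w[i] * v[i];      // kernel with tail handling
    for (size_t i = 0; i < block_size; ++i)
        full_block += w[i] * v[i];      // tail-free kernel over the full block
    std::cout << tail_aware << " == " << full_block << '\n';  // 2.3 == 2.3
}
```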
So in serving scenarios (where prompt processing is interleaved with second-token generation) and in beam-search or speculative-decoding cases (where even the second token is processed with q_len != 1), we will have memsets on each iteration?
Details:
1 (batch) * 32 (head number) * 31 (tokens to pad) * 128 (head size) * 2 bytes (f16) * 2 (K+V) * 32 (layers) = 16.25 MB. The cost is 16.25 MB / (50 GB/s) ≈ 0.3 ms, which should have a small impact compared to the cost of the first token, typically hundreds or thousands of milliseconds.
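For reference, the same back-of-the-envelope estimate written out; the shapes and the 50 GB/s memory bandwidth are the assumptions stated above, not measurements:

```cpp
#include <cstdio>

int main() {
    // Numbers from the estimate above: assumed shapes and bandwidth.
    const double batch = 1, heads = 32, pad_tokens = 31, head_size = 128;
    const double elem_bytes = 2;  // f16
    const double kv = 2;          // both K and V caches
    const double layers = 32;
    const double bytes = batch * heads * pad_tokens * head_size * elem_bytes * kv * layers;
    const double bandwidth = 50e9;  // assumed 50 GB/s
    // Prints roughly "16.25 MB, 0.33 ms"
    std::printf("%.2f MB, %.2f ms\n", bytes / 1e6, bytes / bandwidth * 1e3);
}
```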
Tickets: