[KVCache] Attention func accepting over-padded qkv and output NDArray (apache#17401)

This PR enhances the `AttentionWithFusedQKV` function of `PagedKVCache`
so that it can now accept input `qkv_data` and `o_data` that have
padding along the sequence dimension.

We introduce this enhancement to give the caller of `PagedKVCache` more
flexibility in deciding whether or not to pad the input qkv/o NDArrays.
MasterJH5574 authored Sep 22, 2024
1 parent 36ff1f1 commit ce46185
Showing 1 changed file with 13 additions and 4 deletions.
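
For context, here is a minimal caller-side sketch of what this change permits (not part of the PR; the `padded_len` bound, the example sizes, and the commented-out call are assumptions). The qkv/o buffers are allocated once with a padded sequence dimension, only the leading rows carry valid tokens, and the cache now trims the padding internally.

```cpp
// Hypothetical caller-side usage; shapes and the fixed `padded_len` bound are assumptions.
#include <tvm/runtime/ndarray.h>

using tvm::runtime::NDArray;

void PreparePaddedBuffers() {
  const int64_t padded_len = 128;  // fixed upper bound chosen by the caller (assumption)
  const int64_t num_qo_heads = 32, num_kv_heads = 32, head_dim = 128;  // example sizes
  DLDevice dev{kDLCUDA, 0};
  DLDataType f16{kDLFloat, 16, 1};

  // Fused qkv layout: (seq, num_qo_heads + 2 * num_kv_heads, head_dim).
  NDArray qkv_data =
      NDArray::Empty({padded_len, num_qo_heads + 2 * num_kv_heads, head_dim}, f16, dev);
  NDArray o_data = NDArray::Empty({padded_len, num_qo_heads, head_dim}, f16, dev);

  // Only the first `total_seq_length` rows (the real tokens, <= padded_len) need to be
  // filled; the tail rows are padding that the enhanced attention function now ignores:
  //   kv_cache->AttentionWithFusedQKV(layer_id, qkv_data, NullOpt, o_data,
  //                                   /*attn_score_scaling_factor=*/1.0);
}
```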
17 changes: 13 additions & 4 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -1755,7 +1755,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) {
total_seq_length += cur_append_lengths_[seq_id];
}
- CHECK_EQ(total_seq_length, qkv_data->shape[0]);
+ CHECK_LE(total_seq_length, qkv_data->shape[0]);
// Sync the copy stream and the compute stream.
ComputeStreamWaitForCopyStream();
// The auxiliary data structure on device must have been synchronized.
@@ -1767,12 +1767,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
qkv_data->dtype);
NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, num_kv_heads_, head_dim_},
qkv_data->dtype);

NDArray qkv_data_view = qkv_data;
NDArray o_data_view = o_data;
if (total_seq_length != qkv_data->shape[0]) {
qkv_data_view = qkv_data.CreateView(
{total_seq_length, qkv_data->shape[1], qkv_data->shape[2]}, qkv_data->dtype);
o_data_view =
o_data.CreateView({total_seq_length, num_qo_heads_, head_dim_}, qkv_data->dtype);
}
// Part 2. Split fused qkv and apply rotary embedding to q/k data.
if (!rope_ext_factors_.defined()) {
- f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
+ f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data,
static_cast<int>(rope_mode_ == RoPEMode::kNormal));
} else {
- f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
+ f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data,
rope_ext_factors_.value());
}

@@ -1781,7 +1790,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
}
// Part 4: perform attention
- AttentionInternal(layer_id, q_data, k_data, v_data, o_data, attn_score_scaling_factor);
+ AttentionInternal(layer_id, q_data, k_data, v_data, o_data_view, attn_score_scaling_factor);
// Part 5. Append k/v data to kv-cache if flag "append_before_attn" is not set.
if (!append_before_attn_) {
f_transpose_append_(pages_[local_layer_id], k_data, v_data, append_position_map_view_);
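
The trimming relies on `NDArray::CreateView`, which returns a zero-copy view over the same storage: since the padding sits at the tail of the leading (sequence) dimension of a contiguous buffer, shrinking that dimension is sufficient. A standalone sketch of the same idea (the helper name is hypothetical, not part of the patch):

```cpp
// Hypothetical helper illustrating the view-based trimming applied in this commit.
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>

#include <vector>

tvm::runtime::NDArray TrimSeqPadding(tvm::runtime::NDArray padded, int64_t valid_len) {
  ICHECK_LE(valid_len, padded->shape[0]);
  if (valid_len == padded->shape[0]) {
    return padded;  // no padding; reuse the array as-is
  }
  // Copy the shape and shrink only the leading (sequence) dimension.
  std::vector<int64_t> shape(padded->shape, padded->shape + padded->ndim);
  shape[0] = valid_len;
  // CreateView shares the underlying storage, so no data is copied.
  return padded.CreateView(tvm::runtime::ShapeTuple(shape), padded->dtype);
}
```

Because the view shares memory with the padded buffer, attention results written into `o_data_view` land in the leading rows of the caller's padded `o_data`.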