[Unity] Support TIR kernel for PagedKVCache
This PR supports PagedKVCache by leveraging TIR kernels.

Right now we do not have sufficient TIR kernels for multi-level
sequences in PagedKVCache. Therefore, `Fork` in PagedKVCache is
disabled when the required kernels (the ragged prefill and
merge-score functions) do not exist.

This PR adds a "reduced" creator of PagedKVCache, where some
auxiliary functions, such as the begin/end forward functions
of prefill/decode, default to None.
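
For illustration, here is a minimal Python sketch (not part of the diff
below) of how the reduced creator might be invoked. It assumes
`f_transpose_append`, `f_attention_prefill`, `f_attention_decode`, and
`f_rotary` are PackedFuncs already built from the corresponding TIR
kernels; every numeric value is illustrative only.

    import tvm

    # Sketch only: the four kernel handles below are assumed to be PackedFuncs
    # compiled from TIR kernels elsewhere; they are not defined in this snippet.
    create_reduced = tvm.get_global_func(
        "vm.builtin.paged_attention_kv_cache_create_reduced"
    )
    kv_cache = create_reduced(
        tvm.runtime.ShapeTuple([8, 4096, 16]),  # (reserved_num_seqs, total_token_capacity, page_size)
        32,   # num_layers
        32,   # num_qo_heads
        32,   # num_kv_heads
        128,  # head_dim
        1.0,  # rotary_scale
        1e4,  # rotary_theta
        tvm.nd.empty((), "float16", tvm.cuda(0)),  # init array; supplies dtype and device
        f_transpose_append,
        f_attention_prefill,
        f_attention_decode,
        f_rotary,
        None,  # f_debug_get_kv is optional here
    )

The remaining auxiliary functions (begin/end forward of prefill/decode,
ragged prefill, merge) are filled with NullOpt internally, which is what
disables `Fork` for caches created this way.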

CUDA tests are added to ensure correctness.

Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
3 people committed Jan 11, 2024
1 parent 474c06b commit 1603a90
Showing 3 changed files with 1,149 additions and 40 deletions.
116 changes: 77 additions & 39 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -277,15 +277,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   PackedFunc f_transpose_append_;
   PackedFunc f_attention_prefill_;
   PackedFunc f_attention_decode_;
-  PackedFunc f_attention_prefill_ragged_;
-  PackedFunc f_attention_prefill_ragged_begin_forward_;
-  PackedFunc f_attention_prefill_ragged_end_forward_;
-  PackedFunc f_attention_prefill_begin_forward_;
-  PackedFunc f_attention_prefill_end_forward_;
-  PackedFunc f_attention_decode_begin_forward_;
-  PackedFunc f_attention_decode_end_forward_;
+  Optional<PackedFunc> f_attention_prefill_ragged_;
+  Optional<PackedFunc> f_attention_prefill_ragged_begin_forward_;
+  Optional<PackedFunc> f_attention_prefill_ragged_end_forward_;
+  Optional<PackedFunc> f_attention_prefill_begin_forward_;
+  Optional<PackedFunc> f_attention_prefill_end_forward_;
+  Optional<PackedFunc> f_attention_decode_begin_forward_;
+  Optional<PackedFunc> f_attention_decode_end_forward_;
   PackedFunc f_rotary_;
-  PackedFunc f_merge_inplace_;
+  Optional<PackedFunc> f_merge_inplace_;
   Optional<PackedFunc> f_debug_get_kv_;
 
   /*! \brief Number of fork depth in the current round of forward. */
@@ -297,19 +297,23 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
 
  public:
   /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
-  explicit PagedAttentionKVCacheObj(
-      int64_t page_size,  //
-      int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads,
-      int64_t head_dim,  //
-      int64_t reserved_num_seqs, int64_t num_total_pages,  //
-      double rotary_scale, double rotary_theta,  //
-      DLDataType dtype, DLDevice device, PackedFunc f_transpose_append,
-      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_ragged_begin_forward,
-      PackedFunc f_attention_prefill_ragged_end_forward,
-      PackedFunc f_attention_prefill_begin_forward, PackedFunc f_attention_prefill_end_forward,
-      PackedFunc f_attention_decode_begin_forward, PackedFunc f_attention_decode_end_forward,
-      PackedFunc f_rotary, PackedFunc f_merge_inplace, Optional<PackedFunc> f_debug_get_kv)
+  explicit PagedAttentionKVCacheObj(int64_t page_size,  //
+                                    int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads,
+                                    int64_t head_dim,  //
+                                    int64_t reserved_num_seqs, int64_t num_total_pages,  //
+                                    double rotary_scale, double rotary_theta,  //
+                                    DLDataType dtype, DLDevice device,
+                                    PackedFunc f_transpose_append, PackedFunc f_attention_prefill,
+                                    PackedFunc f_attention_decode,
+                                    Optional<PackedFunc> f_attention_prefill_ragged,
+                                    Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
+                                    Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
+                                    Optional<PackedFunc> f_attention_prefill_begin_forward,
+                                    Optional<PackedFunc> f_attention_prefill_end_forward,
+                                    Optional<PackedFunc> f_attention_decode_begin_forward,
+                                    Optional<PackedFunc> f_attention_decode_end_forward,
+                                    PackedFunc f_rotary, Optional<PackedFunc> f_merge_inplace,
+                                    Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
         num_qo_heads_(num_qo_heads),
@@ -418,6 +422,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
         << "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
     CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
         << "The child sequence \"" << child_seq_id << "\" is already in the KV cache.";
+    CHECK(f_merge_inplace_.defined() && f_attention_prefill_ragged_.defined())
+        << "Attention merge-score function not available. ForkSequence is thereby not supported.";
 
     int32_t parent_block_idx = parent_it->second.last_block_idx;
     // Create a child block with the parent block pointer.
@@ -558,13 +564,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   }
 
   void EndForward() final {
+    if (!f_attention_prefill_end_forward_.defined() || !f_attention_decode_end_forward_.defined() ||
+        !f_attention_prefill_ragged_end_forward_.defined()) {
+      return;
+    }
     // Mark the dirty flag as true, so that BeginForward is required
     // to be invoked before the next round of model forward.
     dirty_aux_data_device_ = true;
-    f_attention_prefill_ragged_end_forward_();
+    f_attention_prefill_ragged_end_forward_.value()();
     for (int d = 0; d < num_depths_; ++d) {
-      f_attention_prefill_end_forward_(d);
-      f_attention_decode_end_forward_(d);
+      f_attention_prefill_end_forward_.value()(d);
+      f_attention_decode_end_forward_.value()(d);
     }
   }
 
@@ -845,30 +855,36 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
 
   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
   void KernelBeginForward() {
+    if (!f_attention_prefill_begin_forward_.defined() ||
+        !f_attention_decode_begin_forward_.defined() ||
+        !f_attention_prefill_ragged_begin_forward_.defined()) {
+      return;
+    }
+
     if (num_depths_ == 1) {
       if (use_decode_kernel_[0]) {
-        f_attention_decode_begin_forward_(
+        f_attention_decode_begin_forward_.value()(
             /*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0],
             num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/true);
       } else {
-        f_attention_prefill_begin_forward_(/*depth=*/0, qo_indptr_on_depths_view_[0],
-                                           cur_batch_size_, num_qo_heads_, num_kv_heads_);
+        f_attention_prefill_begin_forward_.value()(/*depth=*/0, qo_indptr_on_depths_view_[0],
+                                                   cur_batch_size_, num_qo_heads_, num_kv_heads_);
       }
     } else {
-      f_attention_prefill_ragged_begin_forward_(cur_append_length_indptr_view_, cur_batch_size_,
-                                                num_qo_heads_, num_kv_heads_);
+      f_attention_prefill_ragged_begin_forward_.value()(
+          cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_, num_kv_heads_);
       for (int d = 0; d < num_depths_; ++d) {
        if (page_indices_on_depths_view_[d]->shape[0] == 0) {
          continue;
        }
        if (use_decode_kernel_[d]) {
-          f_attention_decode_begin_forward_(
+          f_attention_decode_begin_forward_.value()(
              d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_,
              num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/false);
        } else {
-          f_attention_prefill_begin_forward_(/*depth=*/d, qo_indptr_on_depths_view_[d],
-                                             last_page_len_on_depths_view_[d]->shape[0],
-                                             num_qo_heads_, num_kv_heads_);
+          f_attention_prefill_begin_forward_.value()(/*depth=*/d, qo_indptr_on_depths_view_[d],
+                                                     last_page_len_on_depths_view_[d]->shape[0],
+                                                     num_qo_heads_, num_kv_heads_);
        }
      }
    }
@@ -896,10 +912,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       }
     } else {
       // Compute appended text self-attention
-      f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data,
-                                  cur_append_length_indptr_view_, output, merged_attn_scores_view_,
-                                  /*causal=*/1,
-                                  /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+      f_attention_prefill_ragged_.value()(q_data, cur_append_length_indptr_view_, k_data, v_data,
+                                          cur_append_length_indptr_view_, output,
+                                          merged_attn_scores_view_,
+                                          /*causal=*/1,
+                                          /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
 
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -920,8 +937,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
                              /*causal=*/0,
                              /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
         }
-        f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_,
-                         temp_attn_scores_view_);
+        f_merge_inplace_.value()(output, merged_attn_scores_view_, temp_attn_output_view_,
+                                 temp_attn_scores_view_);
       }
     }
   }
@@ -1068,6 +1085,27 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       return PagedAttentionKVCache(std::move(n));
     });
 
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
+    .set_body_typed([](ShapeTuple cache_config, int64_t num_layers, int64_t num_qo_heads,
+                       int64_t num_kv_heads, int64_t head_dim, double rotary_scale,
+                       double rotary_theta, NDArray init, PackedFunc f_transpose_append,
+                       PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
+                       PackedFunc f_rotary, Optional<PackedFunc> f_debug_get_kv) {
+      CHECK_EQ(cache_config.size(), 3);
+      int64_t reserved_num_seqs = cache_config[0];
+      int64_t total_token_capacity = cache_config[1];
+      int64_t page_size = cache_config[2];
+      int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size;
+      ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
+          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
+          num_total_pages, rotary_scale, rotary_theta, init->dtype, init->device,
+          std::move(f_transpose_append), std::move(f_attention_prefill),
+          std::move(f_attention_decode),  //
+          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
+          std::move(f_rotary), NullOpt, std::move(f_debug_get_kv));
+      return PagedAttentionKVCache(std::move(n));
+    });
+
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
     .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::Clear);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")
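
A side note on the `num_total_pages` computation in the reduced creator
above: `(total_token_capacity + page_size - 1) / page_size` is integer
ceiling division, rounding the token capacity up to whole pages. A quick
Python check with illustrative numbers:

    page_size = 16
    total_token_capacity = 100  # illustrative values
    num_total_pages = (total_token_capacity + page_size - 1) // page_size
    assert num_total_pages == 7  # 6 full pages hold 96 tokens; the last 4 need a 7th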
@@ -81,12 +81,20 @@ def kv_cache_transpose_append(
         for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim):
             with T.block("k_transpose_append"):
                 vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
+                T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
+                T.writes(
+                    pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf]
+                )
                 position: T.int64 = T.Cast("int64", position_map[vgpos])
                 pages[
                     T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf
                 ] = k_data[vgpos, vh, vf]
             with T.block("v_transpose_append"):
                 vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
+                T.reads(position_map[vgpos], v_data[vgpos, vh, vf])
+                T.writes(
+                    pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf]
+                )
                 position: T.int64 = T.Cast("int64", position_map[vgpos])
                 pages[
                     T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf
@@ -115,6 +123,11 @@ def copy_cache(
         for p, h, d in T.grid(seqlen, num_kv_heads, head_dim):
             with T.block("copy0"):
                 vp, vh, vd = T.axis.remap("SSS", [p, h, d])
+                T.reads(
+                    position_map[vp],
+                    pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd],
+                )
+                T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
                 position: T.int64 = T.Cast("int64", position_map[vp])
                 k_data[layer_id, vp, vh, vd] = pages[
                     T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd
@@ -457,7 +470,7 @@ def test_paged_attention_kv_cache_popn(kv_cache):
         if pop_length != 0:
             cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
             cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
-    verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)
+    verify_cached_kv(kv_cache, seq_ids=list(range(5)), expected_k=cached_k, expected_v=cached_v)
 
 
 if __name__ == "__main__":