[3rdparty] Bump FlashInfer for tmp workspace reduction (apache#17400)
This PR bumps FlashInfer to reduce the size of the required temporary workspace.
MasterJH5574 authored Sep 22, 2024
1 parent 72d542e commit 36ff1f1
Showing 4 changed files with 21 additions and 14 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/flashinfer
Submodule flashinfer updated 72 files
+8 −1 .github/workflows/release_wheel.yml
+1 −1 .release-please-manifest.json
+50 −1 CHANGELOG.md
+61 −11 CMakeLists.txt
+12 −2 README.md
+2 −0 cmake/config.cmake
+4 −0 docs/api/python/cascade.rst
+3 −0 docs/api/python/sampling.rst
+2 −2 docs/conf.py
+1 −1 docs/installation.rst
+18 −0 docs/tutorials/kv_layout.rst
+74 −0 include/flashinfer/activation.cuh
+85 −66 include/flashinfer/attention/cascade.cuh
+22 −37 include/flashinfer/attention/decode.cuh
+86 −79 include/flashinfer/attention/handler.cuh
+507 −368 include/flashinfer/attention/prefill.cuh
+9 −8 include/flashinfer/frag_layout_swizzle.cuh
+78 −2 include/flashinfer/mma.cuh
+115 −17 include/flashinfer/norm.cuh
+62 −13 include/flashinfer/permuted_smem.cuh
+19 −18 include/flashinfer/prefill_attention_decl.cuh
+504 −165 include/flashinfer/sampling.cuh
+6 −0 include/flashinfer/utils.cuh
+313 −235 include/flashinfer/vec_dtypes.cuh
+60 −0 python/csrc/activation.cu
+25 −15 python/csrc/batch_decode.cu
+222 −105 python/csrc/batch_prefill.cu
+6 −37 python/csrc/flashinfer_ops.cu
+33 −126 python/csrc/flashinfer_ops.h
+32 −0 python/csrc/flashinfer_ops_decode.cu
+59 −0 python/csrc/flashinfer_ops_decode.h
+47 −0 python/csrc/flashinfer_ops_prefill.cu
+96 −0 python/csrc/flashinfer_ops_prefill.h
+46 −16 python/csrc/norm.cu
+4 −0 python/csrc/pytorch_extension_utils.h
+176 −35 python/csrc/sampling.cu
+1 −1 python/csrc/single_decode.cu
+15 −3 python/csrc/single_prefill.cu
+28 −23 python/flashinfer/__init__.py
+102 −0 python/flashinfer/activation.py
+280 −21 python/flashinfer/cascade.py
+41 −24 python/flashinfer/decode.py
+50 −1 python/flashinfer/group_gemm.py
+27 −6 python/flashinfer/norm.py
+3 −3 python/flashinfer/page.py
+97 −46 python/flashinfer/prefill.py
+22 −0 python/flashinfer/quantization.py
+489 −41 python/flashinfer/sampling.py
+37 −10 python/flashinfer/sparse.py
+8 −6 python/generate_batch_paged_prefill_inst.py
+8 −6 python/generate_batch_ragged_prefill_inst.py
+7 −5 python/generate_single_prefill_inst.py
+75 −46 python/setup.py
+45 −0 python/tests/test_activation.py
+208 −0 python/tests/test_fp8_prefill.py
+34 −3 python/tests/test_norm.py
+179 −26 python/tests/test_sampling.py
+40 −51 python/tests/test_shared_prefix_kernels.py
+24 −14 src/bench_batch_decode.cu
+11 −6 src/bench_batch_prefill.cu
+36 −20 src/bench_cascade.cu
+4 −4 src/bench_sampling.cu
+77 −1 src/bench_single_prefill.cu
+15 −13 src/flashinfer_ops.cuh
+6 −3 src/test_batch_decode.cu
+220 −122 src/test_batch_prefill.cu
+28 −22 src/test_cascade.cu
+71 −0 src/test_fast_dequant.cu
+2 −2 src/test_sampling.cu
+108 −15 src/test_single_prefill.cu
+47 −32 src/tvm_wrapper.cu
+1 −1 version.txt
29 changes: 18 additions & 11 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -57,8 +57,10 @@ namespace relax_vm {
 constexpr const int kPagedKVCacheMaxBlockDepth = 2;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
-/*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
+/*! \brief The 1MB workspace size for integer attention auxiliary data. */
+constexpr const int kIntAttnWorkspaceByte = 1 * 1024 * 1024;
+/*! \brief The 128MB workspace size for floating-point attention auxiliary data. */
+constexpr const int kFloatAttnWorkspaceByte = 128 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
 constexpr const int kPagedKVCacheTempPageId = -1;
 
@@ -915,7 +917,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   NDArray temp_attn_output_device_;
   NDArray temp_attn_scores_device_;
   NDArray merged_attn_scores_device_;
-  std::vector<NDArray> temp_attn_workspace_;
+  std::vector<NDArray> temp_int_attn_workspace_;
+  NDArray temp_float_attn_workspace_;
 
   //-------------------------------------------
   // Below are the auxiliary data structure on CPU.
@@ -1089,8 +1092,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
 
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
       if (NeedKernelBeginForward()) {
-        temp_attn_workspace_.push_back(
-            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+        temp_int_attn_workspace_.push_back(
+            NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device));
       }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
@@ -1103,8 +1106,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
     if (NeedKernelBeginForward()) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      temp_int_attn_workspace_.push_back(
+          NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      temp_float_attn_workspace_ =
+          NDArray::Empty({kFloatAttnWorkspaceByte / 4}, DataType::Float(32), device);
     }
 
     temp_attn_q_device_ =
@@ -2324,7 +2329,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     if (!append_before_attn_) {
      if (is_chain_on_depths_[0]) {
        f_attention_prefill_ragged_begin_forward_.value()(
-            temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
+            temp_float_attn_workspace_, temp_int_attn_workspace_[0],
+            cur_append_lengths_indptr_host_.as_ndarray(),
             cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
             num_kv_heads_, head_dim_, copy_stream_);
       }
@@ -2336,14 +2342,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window.";
       if (use_decode_kernel_[d]) {
         f_attention_decode_begin_forward_.value()(
-            d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
+            d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
+            page_indptr_on_depths_host_[d].as_ndarray(),
             last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
             page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
       } else {
         f_attention_prefill_begin_forward_.value()(
-            /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-            page_indptr_on_depths_host_[d].as_ndarray(),
+            /*depth=*/d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
+            qo_indptr_on_depths_host_[d].as_ndarray(), page_indptr_on_depths_host_[d].as_ndarray(),
             static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
             num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
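To put rough numbers on the reduction: the old code allocated one kAttnWorkspaceByte (128MB) float32 buffer per block depth plus one more for the "prefill with ragged kv" kernel, while the new code shrinks those per-depth buffers to 1MB integer workspaces and shares a single 128MB floating-point workspace across all kernels. A back-of-the-envelope sketch in Python, assuming kPagedKVCacheMaxBlockDepth = 2 as defined above:

# Workspace footprint before and after this change (values from the diff above).
MiB = 1024 * 1024
max_block_depth = 2                 # kPagedKVCacheMaxBlockDepth
num_buffers = max_block_depth + 1   # one per depth, plus one for the ragged-kv kernel

old_total = num_buffers * 128 * MiB                # three 128 MiB workspaces
new_total = num_buffers * 1 * MiB + 128 * MiB      # three 1 MiB int workspaces + one shared float workspace
print(old_total // MiB, new_total // MiB)          # 384 vs. 131 MiB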
@@ -324,7 +324,7 @@ def set_global_func():
     )
     fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place")
 
-    target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")
+    target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
         kv_cache_transpose_append,
@@ -111,7 +111,7 @@ def set_global_func(head_dim, dtype):
     fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
     fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
 
-    target = tvm.target.Target("cuda")
+    target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
         _kv_cache_transpose_append(num_kv_heads, head_dim, dtype),
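Both test updates replace a hard-coded target string with a target derived from the device the test actually runs on, so the compiled kernels match the attributes of the local GPU rather than a fixed model. A minimal sketch of the pattern, assuming a CUDA device is available:

import tvm

# Derive the build target from the runtime device instead of a hard-coded
# string such as "cuda" or "nvidia/geforce-rtx-3090-ti".
device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
print(target)  # a cuda target with arch and thread limits queried from the GPU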
