diff --git a/docs/api/python/decode.rst b/docs/api/python/decode.rst index ca972664..5c937a41 100644 --- a/docs/api/python/decode.rst +++ b/docs/api/python/decode.rst @@ -24,3 +24,6 @@ Batch Decoding .. autoclass:: BatchDecodeWithPagedKVCacheWrapper :members: + +.. autoclass:: CUDAGraphDecodeWithPagedKVCacheWrapper + :members: diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 1b3e0881..b982c048 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -17,6 +17,7 @@ #define FLASHINFER_HANDLER_CUH_ #include +#include #include #include #include @@ -101,7 +102,7 @@ template = max_grid_size) { + if (batch_size * num_kv_heads >= max_grid_size && !enable_cuda_graph) { // do not use partition-kv kernel + // TODO(Zihao): if enable_cuda_graph, we should always use partition-kv kernel + // so that only one kernel will be captured in the graph. tmp_size = 0; new_batch_size = batch_size; } else { @@ -299,39 +302,42 @@ class BatchDecodeHandler { DTypeOut, IdType>; FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, - page_size, stream_)); + page_size, + /*enable_cuda_graph=*/false, stream_)); batch_size_after_partition_ = new_batch_size; if (tmp_size > 0) { AlignedAlloactor allocator(buffer, workspace_size_in_bytes); float_buffer_ = allocator.aligned_alloc(tmp_size, 16); new_indptr_ = allocator.aligned_alloc((batch_size_after_partition_ + 1) * sizeof(IdType), 16); - void* new_indptr_h_ = host_buffer_; + void* new_indptr_h_ = page_locked_buffer_; new_last_page_len_ = allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); void* new_last_page_len_h_ = - (char*)host_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); + (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); chunk_indptr_ = allocator.aligned_alloc((batch_size_before_partition_ + 1) * sizeof(IdType), 16); - void* chunk_indptr_h_ = (char*)host_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); + void* chunk_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); batch_idx_map_ = allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); - void* batch_idx_map_h_ = (char*)host_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); + void* batch_idx_map_h_ = + (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); chunk_start_pos_ = allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); void* chunk_start_pos_h_ = - (char*)host_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); + (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); seq_lengths_before_partition_ = allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); void* seq_lengths_before_partition_h_ = - (char*)host_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); + (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, - (IdType*)seq_lengths_before_partition_h_, new_indptr_, host_buffer_, num_bytes_to_copy, - stream_)); + 
(IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_, + num_bytes_to_copy, stream_)); } forward_started_ = true; return cudaSuccess; @@ -353,6 +359,11 @@ class BatchDecodeHandler { bool IsForwardStarted() const { return forward_started_; } + void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) { + cudaFreeHost(page_locked_buffer_); + cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); + } + uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; } uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; } @@ -372,17 +383,19 @@ class BatchDecodeHandler { seq_lengths_before_partition_(nullptr), forward_started_(false), stream_(nullptr) { - cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes); + cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); } ~BatchDecodeHandler() { EndForward(); - cudaFreeHost(host_buffer_); + cudaFreeHost(page_locked_buffer_); } - private: + virtual bool IsCUDAGraphMode() const { return false; } + + protected: uint32_t batch_size_before_partition_; uint32_t batch_size_after_partition_; - void* host_buffer_; + void* page_locked_buffer_; void* float_buffer_; void* new_indptr_; void* new_last_page_len_; @@ -394,6 +407,86 @@ class BatchDecodeHandler { cudaStream_t stream_; }; +class CUDAGraphBatchDecodeHandler : public BatchDecodeHandler { + public: + template + cudaError_t CUDAGraphBeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, + IdType* indptr, IdType* last_page_len, + uint32_t batch_size, uint32_t num_qo_heads, + uint32_t page_size) { + batch_size_before_partition_ = batch_size; + uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; + auto work_estimation_func = + BatchDecodeWithPagedKVCacheWorkEstimationDispatched; + FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, + new_batch_size, batch_size, indptr, num_qo_heads, + page_size, + /*enable_cuda_graph=*/true, stream_)); + // NOTE(Zihao): max_batch_size_after_partition_ is determined in handler initialization. + // the value should not be changed during the lifetime of the handler. + // So it should be compatible with CUDAGraph which requires fixed pointer. 
+ batch_size_after_partition_ = new_batch_size; + size_t max_tmp_size = num_qo_heads * max_batch_size_after_partition_ * + (HEAD_DIM * sizeof(DTypeOut) + 2 * sizeof(float)); + AlignedAlloactor allocator(buffer, workspace_size_in_bytes); + float_buffer_ = allocator.aligned_alloc(max_tmp_size, 16); + new_indptr_ = + allocator.aligned_alloc((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16); + + void* new_indptr_h_ = page_locked_buffer_; + new_last_page_len_ = + allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); + void* new_last_page_len_h_ = + (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); + chunk_indptr_ = + allocator.aligned_alloc((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16); + void* chunk_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); + batch_idx_map_ = + allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); + void* batch_idx_map_h_ = + (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); + chunk_start_pos_ = + allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); + void* chunk_start_pos_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); + seq_lengths_before_partition_ = + allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); + void* seq_lengths_before_partition_h_ = + (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); + + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; + FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( + max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, + (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_, + (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, + (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_, + num_bytes_to_copy, stream_)); + forward_started_ = true; + return cudaSuccess; + } + CUDAGraphBatchDecodeHandler(size_t max_batch_size) { + int dev_id = 0, num_sm = 0, max_thread_blocks_per_sm = 0; + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id); + cudaDeviceGetAttribute(&max_thread_blocks_per_sm, cudaDevAttrMaxBlocksPerMultiprocessor, + dev_id); + max_batch_size_after_partition_ = + std::max(max_thread_blocks_per_sm * num_sm, max_batch_size); + std::cout << max_thread_blocks_per_sm * num_sm << " " << max_batch_size << std::endl; + size_t max_workspace_size_in_bytes = + 6 * (sizeof(uint64_t) * (max_batch_size_after_partition_ + 1) + 16); + cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); + } + bool IsCUDAGraphMode() const override { return true; } + + private: + uint32_t max_batch_size_after_partition_; +}; + class BatchPrefillHandler { public: template @@ -412,6 +505,11 @@ class BatchPrefillHandler { bool IsForwardStarted() const { return request_indices_ != nullptr; } + void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) { + cudaFreeHost(page_locked_buffer_); + cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); + } + template cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, @@ -429,14 +527,15 @@ class BatchPrefillHandler { AlignedAlloactor allocator(buffer, workspace_size_in_bytes); request_indices_ = allocator.aligned_alloc(sizeof(IdType) * 
request_indices_vec.size(), 16); - void* request_indices_h_ = host_buffer_; + void* request_indices_h_ = page_locked_buffer_; tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * tile_indices_vec.size(), 16); - void* tile_indices_h_ = (char*)host_buffer_ + ((char*)tile_indices_ - (char*)request_indices_); + void* tile_indices_h_ = + (char*)page_locked_buffer_ + ((char*)tile_indices_ - (char*)request_indices_); std::copy(request_indices_vec.begin(), request_indices_vec.end(), (IdType*)request_indices_h_); std::copy(tile_indices_vec.begin(), tile_indices_vec.end(), (IdType*)tile_indices_h_); size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_; - FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, host_buffer_, num_bytes_to_copy, + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy, cudaMemcpyHostToDevice, stream_)); return cudaSuccess; @@ -462,15 +561,15 @@ class BatchPrefillHandler { num_qo_tiles_(0U), forward_started_(false), stream_(nullptr) { - cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes); + cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); } ~BatchPrefillHandler() { EndForward(); - cudaFreeHost(host_buffer_); + cudaFreeHost(page_locked_buffer_); } private: - void* host_buffer_; + void* page_locked_buffer_; void* request_indices_; void* tile_indices_; uint32_t num_frags_x_; diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 877ba363..90edae5e 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -309,8 +309,8 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t q_idx_base, #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { // load q fragment from gmem to smem - q_smem->load_128b_async(q_smem_offset_w, q_ptr, - q_idx < qo_upper_bound && group_id < group_size); + q_smem->load_128b_async( + q_smem_offset_w, q_ptr, q_idx < qo_upper_bound && group_id < group_size); q_smem_offset_w = q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); q_ptr += 8 * num_elems_per_128b(); } @@ -933,7 +933,7 @@ __global__ void SinglePrefillWithKVCacheKernel( constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); static_assert(num_frags_z * num_frags_y % num_warps == 0); - static_assert(group_size == 1 || group_size >= 4 && group_size <=8); + static_assert(group_size == 1 || group_size >= 4 && group_size <= 8); extern __shared__ uint8_t smem[]; @@ -1341,7 +1341,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( kv_len = (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - 1) * paged_kv.page_size + paged_kv.last_page_len[request_idx]; - const uint32_t qo_upper_bound = min(qo_len, (tile_idx + 1) * (num_rows_per_cta / aligned_group_size)); + const uint32_t qo_upper_bound = + min(qo_len, (tile_idx + 1) * (num_rows_per_cta / aligned_group_size)); constexpr bool partition_kv = false; constexpr uint32_t head_dim = num_frags_y * 16; @@ -1364,7 +1365,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( } init_states(o_frag, m, d); - const uint32_t qo_idx_base = ((tile_idx * num_warps + ty) * num_frags_x * 16) / aligned_group_size; + const uint32_t qo_idx_base = + ((tile_idx * num_warps + ty) * num_frags_x * 16) / aligned_group_size; const uint32_t qo_n_stride = get_n_stride_impl(num_qo_heads), qo_h_stride = get_h_stride_impl(qo_len); smem_t qo_smem(smem); @@ -1386,12 +1388,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( if constexpr 
(pos_encoding_mode == PosEncodingMode::kRoPELlama) { if (q_offset == nullptr) { - q_smem_inplace_apply_rotary_multiply_sm_scale(qo_idx_base, qo_len, kv_len, &qo_smem, - &q_smem_offset_r, rope_freq, sm_scale); + q_smem_inplace_apply_rotary_multiply_sm_scale( + qo_idx_base, qo_len, kv_len, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); } else { - q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( + q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale( qo_indptr[request_idx] + qo_idx_base, q_offset, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); } @@ -1418,14 +1420,16 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( cp_async::commit_group(); const uint32_t num_iterations = ceil_div( - (causal ? min(kv_len, - kv_len - qo_len + ((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size) - : kv_len), + (causal + ? min(kv_len, kv_len - qo_len + + ((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size) + : kv_len), 16 * num_frags_z); const uint32_t mask_iteration = (causal - ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len, kv_len) + ? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len, + kv_len) : kv_len) / (16 * num_frags_z); @@ -1453,8 +1457,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel( } // apply mask if (iter >= mask_iteration) { - mask_s( - qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag); + mask_s(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag); } // compute m,d states in online softmax diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 58831e26..f4511e56 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -282,8 +282,8 @@ std::tuple, std::vector> split_qo_in uint32_t num_qo_tiles = 0; for (uint32_t i = 0; i < batch_size; ++i) { - for (uint32_t j = qo_indptr_h[i] * aligned_gqa_group_size; j < qo_indptr_h[i + 1] * aligned_gqa_group_size; - j += num_rows_per_cta) { + for (uint32_t j = qo_indptr_h[i] * aligned_gqa_group_size; + j < qo_indptr_h[i + 1] * aligned_gqa_group_size; j += num_rows_per_cta) { request_indices.push_back(i); tile_indices.push_back((j - qo_indptr_h[i] * aligned_gqa_group_size) / num_rows_per_cta); ++num_qo_tiles; diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index e90d9969..fa9275a9 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -132,7 +132,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - handler_.SetCUDAStream(torch_current_stream); + handler_->SetCUDAStream(torch_current_stream); if (is_float8_tensor(empty_data)) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] { @@ -141,17 +141,32 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_.BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - 
cudaGetErrorString(status)); + if (handler_->IsCUDAGraphMode()) { + // NOTE(Zihao): use runtime dispatch because template function is not virtual + auto cuda_graph_handler_ = + dynamic_cast(handler_.get()); + cudaError_t status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched< + GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE, + c_type, nv_half, int32_t>(static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), + batch_size, num_qo_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ", + cudaGetErrorString(status)); + } else { + cudaError_t status = handler_->BeginForwardDispatched< + GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE, + c_type, nv_half, int32_t>(static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), + batch_size, num_qo_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + } return true; }); }); @@ -165,17 +180,32 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_.BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); + if (handler_->IsCUDAGraphMode()) { + // NOTE(Zihao): use runtime dispatch because template function is not virtual + auto cuda_graph_handler_ = + dynamic_cast(handler_.get()); + auto status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched< + GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE, + c_type, c_type, int32_t>(static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), + batch_size, num_qo_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ", + cudaGetErrorString(status)); + } else { + cudaError_t status = handler_->BeginForwardDispatched< + GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE, + c_type, c_type, int32_t>(static_cast(workspace_buffer.data_ptr()), + workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), + batch_size, num_qo_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + } return true; }); }); @@ -185,7 +215,12 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( } } -void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_.EndForward(); } +void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } + +void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( + unsigned int max_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); +} std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor 
paged_kv_data, torch::Tensor paged_kv_indptr, @@ -249,8 +284,8 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, c_type, nv_half, int32_t>( - &handler_, static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, - static_cast(o.data_ptr()), + handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, + paged_kv, static_cast(o.data_ptr()), /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); @@ -279,8 +314,8 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, c_type, c_type, int32_t>( - &handler_, static_cast(q.data_ptr()), /*q_offset=*/nullptr, paged_kv, - static_cast(o.data_ptr()), + handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, + paged_kv, static_cast(o.data_ptr()), /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), sm_scale, rope_scale, rope_theta, /*stream=*/torch_current_stream); diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index ef00acb5..54d5f55d 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -34,17 +34,22 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - handler_.SetCUDAStream(torch_current_stream); + handler_->SetCUDAStream(torch_current_stream); cudaError_t status = - handler_.BeginForward(static_cast(workspace_buffer.data_ptr()), + handler_->BeginForward(static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status)); } -void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_.EndForward(); } +void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } + +void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( + unsigned int max_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); +} std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, @@ -117,7 +122,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( PageStorage::kIndices, KV_LAYOUT, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, int32_t>( - &handler_, static_cast(q.data_ptr()), + handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, @@ -157,17 +162,22 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( CHECK_EQ(qo_indptr.scalar_type(), torch::kInt32); size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size(); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - handler_.SetCUDAStream(torch_current_stream); + handler_->SetCUDAStream(torch_current_stream); cudaError_t status = - handler_.BeginForward(static_cast(workspace_buffer.data_ptr()), + handler_->BeginForward(static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ", cudaGetErrorString(status)); } -void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_.EndForward(); } +void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); } + +void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( + unsigned int max_workspace_size_in_bytes) { + handler_->UpdatePageLockedBufferSize(max_workspace_size_in_bytes); +} std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, @@ -218,7 +228,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( cudaError_t status = BatchPrefillWithRaggedKVCacheWrapperDispatched< GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, c_type, c_type, int32_t>( - &handler_, static_cast(q.data_ptr()), + handler_.get(), static_cast(q.data_ptr()), static_cast(qo_indptr.data_ptr()), static_cast(k.data_ptr()), static_cast(v.data_ptr()), static_cast(kv_indptr.data_ptr()), diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 376fbd50..67daa87d 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -37,20 +37,34 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); py::class_(m, "BatchDecodeWithPagedKVCachePyTorchWrapper") - .def(py::init()) + .def(py::init()) .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward) .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward) + .def("update_page_locked_buffer_size", + &BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward); + py::class_( + m, "CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper") + .def(py::init()) + .def("begin_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward) + .def("end_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::EndForward) + .def("update_page_locked_buffer_size", + &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) + .def("forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::Forward); py::class_( m, "BatchPrefillWithPagedKVCachePyTorchWrapper") - .def(py::init()) + .def(py::init()) .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward) .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward) + .def("update_page_locked_buffer_size", + &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward); py::class_( m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") - 
.def(py::init()) + .def(py::init()) .def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward) .def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward) + .def("update_page_locked_buffer_size", + &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward); } diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 285485e5..acfa39bb 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -18,6 +18,7 @@ #include #include +#include // namespace flashinfer { // class BatchPrefillHandler; @@ -70,25 +71,43 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, torch::Tensor empty_data); void EndForward(); + void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float sm_scale, float rope_scale, float rope_theta, bool return_lse); - BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout) - : kv_layout_(flashinfer::QKVLayout(layout)) {} - - private: - flashinfer::BatchDecodeHandler handler_; + BatchDecodeWithPagedKVCachePyTorchWrapper( + std::shared_ptr handler_ptr, flashinfer::QKVLayout kv_layout) + : handler_(handler_ptr), kv_layout_(kv_layout) {} + BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, + unsigned int max_workspace_size_in_bytes) + : kv_layout_(flashinfer::QKVLayout(layout)), + handler_( + std::make_shared(max_workspace_size_in_bytes)) {} + + protected: + std::shared_ptr handler_; flashinfer::QKVLayout kv_layout_; }; +class CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper + : public BatchDecodeWithPagedKVCachePyTorchWrapper { + public: + CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, + unsigned int max_batch_size) + : BatchDecodeWithPagedKVCachePyTorchWrapper( + std::make_shared(max_batch_size), + flashinfer::QKVLayout(layout)) {} +}; + class BatchPrefillWithPagedKVCachePyTorchWrapper { public: void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim); void EndForward(); + void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, @@ -96,11 +115,13 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); - BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout) - : kv_layout_(flashinfer::QKVLayout(layout)) {} + BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, + unsigned int max_workspace_size_in_bytes) + : kv_layout_(flashinfer::QKVLayout(layout)), + handler_(std::make_shared(max_workspace_size_in_bytes)) {} private: - flashinfer::BatchPrefillHandler handler_; + std::shared_ptr handler_; flashinfer::QKVLayout kv_layout_; }; @@ -110,15 +131,18 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim); void EndForward(); + void 
UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); - BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout) - : kv_layout_(flashinfer::QKVLayout(layout)) {} + BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, + unsigned int max_workspace_size_in_bytes) + : kv_layout_(flashinfer::QKVLayout(layout)), + handler_(std::make_shared(max_workspace_size_in_bytes)) {} private: - flashinfer::BatchPrefillHandler handler_; + std::shared_ptr handler_; flashinfer::QKVLayout kv_layout_; }; diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 1980c0b1..f7353d48 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -19,6 +19,7 @@ batch_decode_with_padded_kv_cache, batch_decode_with_padded_kv_cache_return_lse, BatchDecodeWithPagedKVCacheWrapper, + CUDAGraphBatchDecodeWithPagedKVCacheWrapper, ) from .prefill import ( single_prefill_with_kv_cache, diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index a59079c4..cedbefaa 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -405,7 +405,8 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchDecodeWithPagedKVCachePyTorchWrapper( - TensorLayout[kv_layout].value + TensorLayout[kv_layout].value, + workspace_buffer.numel() * workspace_buffer.element_size(), ) self._paged_kv_indptr = None self._paged_kv_indices = None @@ -628,3 +629,285 @@ def forward_return_lse( rope_theta, True, ) + + +class CUDAGraphBatchDecodeWithPagedKVCacheWrapper: + r"""CUDAGraph-compatible wrapper class for decode attention with paged kv-cache (first + proposed in `vLLM `_) for batch of requests. + + Note that this wrapper may not be as efficient as :class:`BatchDecodeWithPagedKVCacheWrapper` + because we won't dispatch to different kernels for different batch sizes/sequence lengths/etc. + to accommodate the CUDAGraph requirement. + + Check :ref:`our tutorial` for page table layout. + # TODO(Zihao): update documentation + + Note + ---- + The :meth:`begin_forward` method cannot be captured by CUDAGraph. + + See Also + -------- + :class:`BatchDecodeWithPagedKVCacheWrapper` + """ + + def __init__( + self, + workspace_buffer: torch.Tensor, + indptr_buffer: torch.Tensor, + indices_buffer: torch.Tensor, + last_page_len_buffer: torch.Tensor, + kv_layout: str = "NHD", + ): + r"""Constructor of :class:`CUDAGraphBatchDecodeWithPagedKVCacheWrapper`. + + Parameters + ---------- + workspace_buffer : torch.Tensor + The user reserved workspace buffer on GPU used to store auxiliary data structures, + recommended size is 128MB, the device of the workspace buffer should be the + same as the device of the input tensors. + indptr_buffer : torch.Tensor + The user reserved buffer on GPU to store the indptr of the paged kv cache, should + be large enough to store the indptr of maximum batch size (``[max_batch_size + 1]``) + during the lifecycle of this wrapper. + indices_buffer : torch.Tensor + The user reserved buffer on GPU to store the page indices of the paged kv cache, + should be large enough to store the maximum number of page indices + (``max_num_pages``) during the lifecycle of this wrapper. 
+ last_page_len_buffer : torch.Tensor + The user reserved buffer on GPU to store the number of entries in the last page, + should be large enough to store the maximum batch size (``[max_batch_size]``) + during the lifecycle of this wrapper. + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + """ + check_kv_layout(kv_layout) + self._kv_layout = kv_layout + self._workspace_buffer = workspace_buffer + max_batch_size = len(last_page_len_buffer) + self._wrapper = _kernels.CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper( + TensorLayout[kv_layout].value, + max_batch_size, + ) + self._paged_kv_indptr_buf = indptr_buffer + self._paged_kv_indices_buf = indices_buffer + self._paged_kv_last_page_len_buf = last_page_len_buffer + + def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): + r"""Reset the workspace buffer. + + Parameters + ---------- + new_workspace_buffer : torch.Tensor + The new workspace buffer, the device of the new workspace buffer should + be the same as the device of the input tensors. + """ + self._workspace_buffer = new_workspace_buffer + + def begin_forward( + self, + indptr: torch.Tensor, + indices: torch.Tensor, + last_page_len: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + pos_encoding_mode: str = "NONE", + data_type: Union[str, torch.dtype] = "float16", + ): + r"""Create auxiliary data structures for batch decode for multiple forward calls + within the same decode step. + + Parameters + ---------- + indptr : torch.Tensor + The indptr of the paged kv cache, shape: ``[batch_size + 1]`` + indices : torch.Tensor + The page indices of the paged kv cache, shape: ``[indptr[-1]]`` + last_page_len : torch.Tensor + The number of entries in the last page of each request in the paged kv + cache, shape: ``[batch_size]`` + num_qo_heads : int + The number of query/output heads + num_kv_heads : int + The number of key/value heads + head_dim : int + The dimension of the heads + page_size : int + The page size of the paged kv cache + pos_encoding_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + data_type : Union[str, torch.dtype] + The data type of the paged kv cache + + Note + ---- + The :meth:`begin_forward` method should be called before any :meth:`forward` or + :meth:`forward_return_lse` calls, auxiliary data structures will be created + during this call and cached for multiple forward calls. + + The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` + is not equal to ``num_kv_heads``, the function will use + `grouped query attention `_. 
+ """ + + self._paged_kv_indptr_buf[: len(indptr)] = indptr + self._paged_kv_indices_buf[: len(indices)] = indices + self._paged_kv_last_page_len_buf[: len(last_page_len)] = last_page_len + + batch_size = len(indptr) - 1 + # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info + empty_data = torch.empty( + 0, + dtype=( + getattr(torch, data_type) if isinstance(data_type, str) else data_type + ), + ) + self._wrapper.begin_forward( + self._workspace_buffer, + indptr, + last_page_len, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + PosEncodingMode[pos_encoding_mode].value, + empty_data, + ) + + def end_forward(self): + r"""Clear auxiliary data structures created by :meth:`begin_forward`.""" + self._wrapper.end_forward() + + def forward( + self, + q: torch.Tensor, + paged_kv_data: torch.Tensor, + pos_encoding_mode: str = "NONE", + sm_scale: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, + ): + r"""Compute batch decode attention between query and paged kv cache. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` + paged_kv_data : torch.Tensor + A 5-D tensor of the reserved paged kv-cache data, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``HND``. + pos_encoding_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. + sm_scale : Optional[float] + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + torch.Tensor + The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. + """ + check_pos_encoding_mode(pos_encoding_mode) + if sm_scale is None: + head_dim = q.shape[-1] + sm_scale = 1.0 / math.sqrt(head_dim) + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + + paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) + return self._wrapper.forward( + q, + paged_kv_data, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, + PosEncodingMode[pos_encoding_mode].value, + sm_scale, + rope_scale, + rope_theta, + False, + )[0] + + def forward_return_lse( + self, + q: torch.Tensor, + paged_kv_data: torch.Tensor, + pos_encoding_mode: str = "NONE", + sm_scale: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, + ): + r"""Compute batch decode attention with paged kv cache, return attention output + and logsumexp of attention scores. + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` + paged_kv_data : torch.Tensor + A 5-D tensor of the reserved paged kv-cache data, shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, or + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``HND``. + pos_encoding_mode : str + Whether to apply RoPE on-the-fly inside attention kernels, could be + ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. 
+ sm_scale : Optional[float] + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. + + Returns + ------- + V : torch.Tensor + The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. + S : torch.Tensor + The logsumexp of attention scores, shape: ``[batch_size, num_qo_heads]``. + + Notes + ----- + Please refer to the :ref:`tutorial ` for a detailed + explanation of the log-sum-exp function and attention states. + """ + check_pos_encoding_mode(pos_encoding_mode) + if sm_scale is None: + head_dim = q.shape[-1] + sm_scale = 1.0 / math.sqrt(head_dim) + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) + return self._wrapper.forward( + q, + paged_kv_data, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, + PosEncodingMode[pos_encoding_mode].value, + sm_scale, + rope_scale, + rope_theta, + True, + ) diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 70064e11..f024a1ac 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -353,7 +353,8 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( - TensorLayout[kv_layout].value + TensorLayout[kv_layout].value, + workspace_buffer.numel() * workspace_buffer.element_size(), ) self._qo_indptr = None self._paged_kv_indptr = None @@ -666,7 +667,8 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchPrefillWithRaggedKVCachePyTorchWrapper( - TensorLayout[kv_layout].value + TensorLayout[kv_layout].value, + workspace_buffer.numel() * workspace_buffer.element_size(), ) self._qo_indptr = None self._kv_indptr = None diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index 739dd823..13bba6f1 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -23,7 +23,6 @@ @pytest.mark.parametrize("batch_size", [12, 17]) @pytest.mark.parametrize("kv_len", [54, 97]) -@pytest.mark.parametrize("qo_len", [37, 17]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) @@ -36,7 +35,6 @@ def test_batch_decode_with_paged_kv_cache( batch_size, kv_len, - qo_len, page_size, num_kv_heads, num_qo_heads, @@ -116,10 +114,151 @@ numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("page_size", [1, 8, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +@pytest.mark.parametrize( + "dtype", [torch.float16, torch.float8_e4m3fn, 
torch.float8_e5m2] +) +def test_cuda_graph_batch_decode_with_paged_kv_cache( + batch_size, + kv_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + kv_layout, + pos_encoding_mode, + dtype, +): + q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(dtype) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = ( + torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0) + if kv_layout == "HND" + else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim).to(0) + ) + kv_indptr_host_warmup = torch.arange(0, batch_size + 1).int() + kv_indices_host_warmup = torch.arange(0, batch_size).int() + kv_last_page_len_host_warmup = torch.full( + (batch_size,), page_size, dtype=torch.int32 + ) + + kv_indptr_host = torch.arange(0, batch_size + 1).int() * num_pages_per_seq + kv_indices_host = torch.arange(0, total_num_pages).int() + kv_last_page_len_host = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ) + + kv_indptr_device_buffer = torch.empty(batch_size + 1).int().to(0) + kv_indices_device_buffer = torch.empty(total_num_pages).int().to(0) + kv_last_page_device_buffer = torch.empty(batch_size).int().to(0) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_indptr_device_buffer, + kv_indices_device_buffer, + kv_last_page_device_buffer, + kv_layout, + ) + wrapper.begin_forward( + kv_indptr_host_warmup, + kv_indices_host_warmup, + kv_last_page_len_host_warmup, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + "NONE", + dtype, + ) + # warmup + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + o = wrapper.forward( + q, kv_data.to(dtype), pos_encoding_mode=pos_encoding_mode + ) + torch.cuda.current_stream().wait_stream(s) + # capture + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + o = wrapper.forward(q, kv_data.to(dtype), pos_encoding_mode=pos_encoding_mode) + + wrapper.begin_forward( + kv_indptr_host, + kv_indices_host, + kv_last_page_len_host, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + "NONE", + dtype, + ) + g.replay() + + # compute ground truth and compare + kv_indptr = kv_indptr_host.to(0) + kv_indices = kv_indices_host.to(0) + kv_last_page_len = kv_last_page_len_host.to(0) + for i in range(batch_size): + perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] + perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] + qi = q[i] + ki = torch.cat( + [ + kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 0] + .permute(*perm_dims) + .reshape(-1, num_kv_heads, head_dim), + ( + kv_data[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]] + if kv_layout == "HND" + else kv_data[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :] + ) + .permute(*perm_dims_last) + .reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ).to(dtype) + vi = torch.cat( + [ + kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] + .permute(*perm_dims) + .reshape(-1, num_kv_heads, head_dim), + ( + kv_data[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]] + if kv_layout == "HND" + else kv_data[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :] + ) + .permute(*perm_dims_last) + .reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ).to(dtype) + o_ref_i = flashinfer.single_decode_with_kv_cache( + qi, ki, vi, pos_encoding_mode=pos_encoding_mode + ) + o_i_np = 
o[i].cpu().numpy() + o_ref_i_np = o_ref_i.cpu().numpy() + numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": test_batch_decode_with_paged_kv_cache( - 12, 54, 37, 8, 8, 8, 128, "HND", "NONE", torch.float16 + 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16 ) test_batch_decode_with_paged_kv_cache( - 12, 54, 37, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2 + 12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2 + ) + test_cuda_graph_batch_decode_with_paged_kv_cache( + 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16 )
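
Usage note: the intended capture/replay flow for the new CUDAGraphBatchDecodeWithPagedKVCacheWrapper follows the pattern exercised by test_cuda_graph_batch_decode_with_paged_kv_cache above: allocate fixed-size device buffers, call begin_forward outside of graph capture, capture a single forward call, then refresh the page-table metadata with another begin_forward before each replay. The sketch below is a minimal illustration assuming only the API added in this diff; the head counts, buffer sizes, and dummy single-page metadata are placeholders, not prescribed values.

import torch
import flashinfer

# Illustrative sizes; real deployments pick these from the serving configuration.
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 4, 128, 16
max_batch_size, max_num_pages = 128, 1024

# Fixed buffers that live for the whole lifetime of the wrapper (and the graph).
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
indptr_buffer = torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda")
indices_buffer = torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
last_page_len_buffer = torch.empty(max_batch_size, dtype=torch.int32, device="cuda")
wrapper = flashinfer.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer, indptr_buffer, indices_buffer, last_page_len_buffer, "NHD"
)

batch_size = 7
q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_data = torch.randn(
    max_num_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)
# Dummy single-page-per-request metadata (host tensors) used for warmup/capture.
kv_indptr_host = torch.arange(0, batch_size + 1).int()
kv_indices_host = torch.arange(0, batch_size).int()
kv_last_page_len_host = torch.full((batch_size,), page_size, dtype=torch.int32)

# begin_forward cannot be captured; call it before capture and again before replay.
wrapper.begin_forward(
    kv_indptr_host, kv_indices_host, kv_last_page_len_host,
    num_qo_heads, num_kv_heads, head_dim, page_size, "NONE", torch.float16,
)

# Warm up on a side stream, as in the test above, before capturing.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        o = wrapper.forward(q, kv_data)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    o = wrapper.forward(q, kv_data)

# Each decode step: copy the new query/KV contents into q / kv_data in place,
# refresh the page-table metadata for the current batch, then replay the graph.
wrapper.begin_forward(
    kv_indptr_host, kv_indices_host, kv_last_page_len_host,
    num_qo_heads, num_kv_heads, head_dim, page_size, "NONE", torch.float16,
)
g.replay()  # the attention output is written into `o`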