Commit
Revert "feat: support cuda graph for batched multi-query(prefill/appe…
Browse files Browse the repository at this point in the history
…nd) attention" (#276)

Reverts #275
yzh119 authored Jun 2, 2024
1 parent 83ceb67 commit 081a4c5
Showing 9 changed files with 538 additions and 489 deletions.
1 change: 1 addition & 0 deletions include/flashinfer/attention/decode.cuh
@@ -36,6 +36,7 @@
#include "../utils.cuh"
#include "../vec_dtypes.cuh"
#include "cascade.cuh"
#include "handler.cuh"
#include "state.cuh"

namespace flashinfer {
209 changes: 102 additions & 107 deletions include/flashinfer/attention/handler.cuh

Large diffs are not rendered by default.
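
Since the handler.cuh diff is collapsed, here is a hedged sketch of the interface this revert restores, reconstructed only from the call sites in batch_decode.cu and flashinfer_ops.h below; the placeholder enums and anything else not named at those call sites are assumptions, not FlashInfer's actual header.

// Hedged sketch, not the real handler.cuh: reconstructed from the call sites
// below. Placeholder enums and members not visible in this diff are assumptions.
#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

namespace flashinfer_sketch {

enum class PageStorage { kIndices };   // placeholder for flashinfer's enum
enum class QKVLayout { kNHD, kHND };   // placeholder
enum class PosEncodingMode { kNone };  // placeholder

class BatchDecodeHandler {
 public:
  explicit BatchDecodeHandler(size_t max_workspace_size_in_bytes);
  virtual bool IsCUDAGraphMode() const { return false; }

  // Non-virtual template member: plans the decode schedule and carves scratch
  // space out of the caller-provided workspace buffer.
  template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage PAGE_STORAGE,
            QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, typename DTypeIn,
            typename DTypeOut, typename IdType>
  cudaError_t BeginForwardDispatched(void* workspace, size_t workspace_size_in_bytes,
                                     IdType* indptr, IdType* last_page_len,
                                     uint32_t batch_size, uint32_t num_qo_heads,
                                     uint32_t page_size);

  virtual ~BatchDecodeHandler() = default;
};

class CUDAGraphBatchDecodeHandler : public BatchDecodeHandler {
 public:
  explicit CUDAGraphBatchDecodeHandler(uint32_t max_batch_size);
  bool IsCUDAGraphMode() const override { return true; }

  // CUDA-graph-safe variant: buffers are sized for max_batch_size up front so
  // pointers and launch shapes stay fixed across graph replays.
  template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage PAGE_STORAGE,
            QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, typename DTypeIn,
            typename DTypeOut, typename IdType>
  cudaError_t CUDAGraphBeginForwardDispatched(void* workspace,
                                              size_t workspace_size_in_bytes,
                                              IdType* indptr, IdType* last_page_len,
                                              uint32_t batch_size, uint32_t num_qo_heads,
                                              uint32_t page_size);
};

}  // namespace flashinfer_sketch

The apparent intent of the split is that the CUDA-graph handler commits to max_batch_size once, so later graph captures and replays see stable pointers and launch dimensions.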

2 changes: 2 additions & 0 deletions include/flashinfer/attention/prefill.cuh
@@ -35,7 +35,9 @@
#include "../pos_enc.cuh"
#include "../utils.cuh"
#include "cascade.cuh"
#include "handler.cuh"
#include "mask.cuh"
#include "state.cuh"

namespace flashinfer {

1 change: 1 addition & 0 deletions include/flashinfer/sampling.cuh
@@ -679,6 +679,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);

bool rejected = false;
uint32_t pos = 0;
for (pos = 0; pos < num_speculative_tokens; ++pos) {
IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos];
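
This hunk sits inside the rejection loop of the ChainSpeculativeSampling kernel. As background, a hedged host-side sketch of the standard speculative-sampling accept/reject rule such a kernel implements: accept the draft token at position i with probability min(1, p_target/p_draft); on the first rejection (or after all drafts are accepted) sample from the renormalized clamped residual max(p_target - p_draft, 0). Names and structure here are illustrative, not the kernel's.

#include <algorithm>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

// Host-side illustration of chain speculative sampling (accept/reject rule per
// Leviathan et al., 2023); not the kernel's code. draft_probs has one row per
// draft token, target_probs one extra row for the bonus position after the last
// draft. Returns {number of accepted draft tokens, id of the next sampled token}.
std::pair<std::size_t, int> ChainSpeculativeSampleHost(
    const std::vector<std::vector<float>>& draft_probs,
    const std::vector<std::vector<float>>& target_probs,
    const std::vector<int>& draft_token_ids, std::mt19937& rng) {
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  std::size_t pos = 0;
  for (; pos < draft_token_ids.size(); ++pos) {
    const int tok = draft_token_ids[pos];
    const float p_draft = draft_probs[pos][tok];
    const float p_target = target_probs[pos][tok];
    // Accept the draft token with probability min(1, p_target / p_draft).
    if (uniform(rng) * p_draft > p_target) break;  // rejected
  }
  // Sample the next token from the clamped residual max(p_target - p_draft, 0)
  // at the rejection position; if every draft token was accepted, this reduces
  // to sampling from the target distribution at the bonus position.
  std::vector<float> residual(target_probs[pos].size());
  float norm = 0.0f;
  for (std::size_t v = 0; v < residual.size(); ++v) {
    const float d = (pos < draft_token_ids.size()) ? draft_probs[pos][v] : 0.0f;
    residual[v] = std::max(target_probs[pos][v] - d, 0.0f);
    norm += residual[v];
  }
  float u = uniform(rng) * norm;
  for (std::size_t v = 0; v < residual.size(); ++v) {
    u -= residual[v];
    if (u <= 0.0f) return {pos, static_cast<int>(v)};
  }
  return {pos, static_cast<int>(residual.size()) - 1};
}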
74 changes: 52 additions & 22 deletions python/csrc/batch_decode.cu
@@ -141,17 +141,32 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
cudaError_t status =
handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
KV_LAYOUT, POS_ENCODING_MODE, c_type,
nv_half, int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
if (handler_->IsCUDAGraphMode()) {
// NOTE(Zihao): use runtime dispatch because template function is not virtual
auto cuda_graph_handler_ =
dynamic_cast<CUDAGraphBatchDecodeHandler*>(handler_.get());
cudaError_t status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched<
GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
c_type, nv_half, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()),
batch_size, num_qo_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ",
cudaGetErrorString(status));
} else {
cudaError_t status = handler_->BeginForwardDispatched<
GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
c_type, nv_half, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()),
batch_size, num_qo_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
return true;
});
});
@@ -165,17 +180,32 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
cudaError_t status =
handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
int32_t>(
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
if (handler_->IsCUDAGraphMode()) {
// NOTE(Zihao): use runtime dispatch because template function is not virtual
auto cuda_graph_handler_ =
dynamic_cast<CUDAGraphBatchDecodeHandler*>(handler_.get());
auto status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched<
GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
c_type, c_type, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()),
batch_size, num_qo_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ",
cudaGetErrorString(status));
} else {
cudaError_t status = handler_->BeginForwardDispatched<
GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
c_type, c_type, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
workspace_size_in_bytes,
static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()),
batch_size, num_qo_heads, page_size);
TORCH_CHECK(status == cudaSuccess,
"BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
}
return true;
});
});
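
The NOTE(Zihao) comments mark the central design constraint in this file: CUDAGraphBeginForwardDispatched is a template member function, and template members cannot be virtual, so the wrapper branches on IsCUDAGraphMode() at runtime and dynamic_casts the shared base-class handler pointer to CUDAGraphBatchDecodeHandler before calling it. A minimal self-contained sketch of that pattern, with illustrative names rather than FlashInfer's:

#include <cassert>
#include <iostream>
#include <memory>

// Minimal illustration of the dispatch pattern above: a template member
// function cannot be virtual, so the caller checks a virtual predicate and
// downcasts to the derived type before calling the derived-only template.
struct Handler {
  virtual bool IsCUDAGraphMode() const { return false; }
  virtual ~Handler() = default;
  template <int HEAD_DIM>
  void BeginForward() { std::cout << "regular plan, head_dim=" << HEAD_DIM << "\n"; }
};

struct CUDAGraphHandler : Handler {
  bool IsCUDAGraphMode() const override { return true; }
  template <int HEAD_DIM>
  void CUDAGraphBeginForward() {
    std::cout << "graph-safe plan (fixed shapes), head_dim=" << HEAD_DIM << "\n";
  }
};

template <int HEAD_DIM>
void BeginForwardDispatch(const std::shared_ptr<Handler>& handler) {
  if (handler->IsCUDAGraphMode()) {
    // Runtime dispatch: the template cannot be virtual, so downcast explicitly.
    auto* graph_handler = dynamic_cast<CUDAGraphHandler*>(handler.get());
    assert(graph_handler != nullptr);
    graph_handler->CUDAGraphBeginForward<HEAD_DIM>();
  } else {
    handler->BeginForward<HEAD_DIM>();
  }
}

int main() {
  BeginForwardDispatch<128>(std::make_shared<Handler>());
  BeginForwardDispatch<128>(std::make_shared<CUDAGraphHandler>());
}

The sketch guards the downcast with an assert; in the wrapper above the cast should only be reached when IsCUDAGraphMode() is true, i.e. when the handler was constructed as the CUDA-graph subclass exposed in flashinfer_ops.h below.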
18 changes: 11 additions & 7 deletions python/csrc/flashinfer_ops.cu
@@ -44,30 +44,34 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
"BatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, unsigned int, bool>())
.def(py::init<unsigned int, unsigned int>())
.def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
py::class_<CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper>(
m, "CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, unsigned int>())
.def("begin_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
.def("update_page_locked_buffer_size",
&CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
.def(py::init<unsigned int, unsigned int, bool>())
.def(py::init<unsigned int, unsigned int>())
.def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward)
.def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask);
py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
.def(py::init<unsigned int, unsigned int, bool>())
.def(py::init<unsigned int, unsigned int>())
.def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
.def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
.def("is_cuda_graph_enabled",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
.def("update_page_locked_buffer_size",
&BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
.def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)
32 changes: 17 additions & 15 deletions python/csrc/flashinfer_ops.h
@@ -84,7 +84,6 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
unsigned int pos_encoding_mode, torch::Tensor empty_data);
void EndForward();
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor paged_kv_data,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
torch::Tensor paged_kv_last_page_len,
@@ -93,24 +92,32 @@
BatchDecodeWithPagedKVCachePyTorchWrapper(
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout)
: handler_(handler_ptr), kv_layout_(kv_layout) {}
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, unsigned int max_batch_size,
bool enable_cuda_graph)
BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_workspace_size_in_bytes)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(
std::make_shared<flashinfer::BatchDecodeHandler>(max_batch_size, enable_cuda_graph)) {}
handler_(std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}

protected:
std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
flashinfer::QKVLayout kv_layout_;
};

class CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper
: public BatchDecodeWithPagedKVCachePyTorchWrapper {
public:
CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_batch_size)
: BatchDecodeWithPagedKVCachePyTorchWrapper(
std::make_shared<flashinfer::CUDAGraphBatchDecodeHandler>(max_batch_size),
flashinfer::QKVLayout(layout)) {}
};

class BatchPrefillWithPagedKVCachePyTorchWrapper {
public:
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int head_dim);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
@@ -126,11 +133,9 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_workspace_size_in_bytes,
bool enable_cuda_graph)
unsigned int max_workspace_size_in_bytes)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes,
enable_cuda_graph)) {}
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
@@ -143,7 +148,6 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
unsigned int head_dim);
void EndForward();
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
torch::Tensor v, torch::Tensor kv_indptr, bool causal,
@@ -158,11 +162,9 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
bool allow_fp16_qk_reduction, float sm_scale,
float rope_scale, float rope_theta, bool return_lse);
BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout,
unsigned int max_workspace_size_in_bytes,
bool enable_cuda_graph)
unsigned int max_workspace_size_in_bytes)
: kv_layout_(flashinfer::QKVLayout(layout)),
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes,
enable_cuda_graph)) {}
handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}

private:
std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
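
In the restored header the CUDA-graph path is a separate wrapper class rather than a constructor flag: CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper derives from the regular wrapper and injects a CUDAGraphBatchDecodeHandler through the protected handler-taking constructor, so Forward and the rest of the decode path are shared unchanged. A hedged standalone sketch of that construction pattern, with illustrative names:

#include <cstddef>
#include <iostream>
#include <memory>
#include <utility>

// Standalone illustration (not FlashInfer's code) of the wrapper pattern in the
// restored flashinfer_ops.h: the CUDA-graph wrapper is a subclass that swaps in
// a derived handler through the base class's protected handler-taking
// constructor, so the shared forward path never needs to know which mode it is in.
struct DecodeHandler {
  virtual const char* Name() const { return "BatchDecodeHandler"; }
  virtual ~DecodeHandler() = default;
};

struct CUDAGraphDecodeHandler : DecodeHandler {
  const char* Name() const override { return "CUDAGraphBatchDecodeHandler"; }
};

class DecodeWrapper {
 public:
  explicit DecodeWrapper(size_t max_workspace_size_in_bytes)
      : workspace_bytes_(max_workspace_size_in_bytes),
        handler_(std::make_shared<DecodeHandler>()) {}

  void Forward() const {
    std::cout << "forward via " << handler_->Name() << " (workspace " << workspace_bytes_
              << " bytes)\n";
  }

 protected:
  // Handler-injecting constructor used by subclasses to swap the handler type.
  explicit DecodeWrapper(std::shared_ptr<DecodeHandler> handler)
      : workspace_bytes_(0), handler_(std::move(handler)) {}

  size_t workspace_bytes_;
  std::shared_ptr<DecodeHandler> handler_;
};

class CUDAGraphDecodeWrapper : public DecodeWrapper {
 public:
  explicit CUDAGraphDecodeWrapper(size_t /*max_batch_size*/)
      : DecodeWrapper(std::make_shared<CUDAGraphDecodeHandler>()) {}
};

int main() {
  DecodeWrapper(16u << 20).Forward();     // regular handler
  CUDAGraphDecodeWrapper(128).Forward();  // graph-capturable handler, same Forward()
}

Compared with the reverted enable_cuda_graph constructor flag, the subclass approach keeps the graph-capturable configuration a distinct type visible from Python, which is what the separate CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper binding in flashinfer_ops.cu exposes.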