refactor: replace `begin_forward`/`forward`/`end_forward` with `plan`/`run` (#466)

This PR replaces the `begin_forward`/`forward`/`end_forward` API with the new
`plan`/`run` API.
- `forward` is consistent with PyTorch but confusing here, because flashinfer
focuses on inference and has no corresponding `backward` phase. This PR renames
it to `run`, which is more precise and consistent with the naming convention of
CUTLASS's Python API.
- `begin_forward` is renamed to `plan`, consistent with the naming convention of
the nvmath API.
- `end_forward` is deprecated and has no effect after this PR.

There is a slight difference between the old `forward` and the new `run` API:
- All problem specifications are now provided to `plan` (previously
`begin_forward`) and cached until the next `plan` call, so `run` only needs the
query and KV-cache tensors.

This is not a breaking change: the old `begin_forward`/`forward`/`end_forward`
APIs are kept for backward compatibility and will be gradually deprecated in
future releases.
yzh119 authored Aug 25, 2024
1 parent 957572e commit d940d2e
Showing 27 changed files with 673 additions and 586 deletions.
47 changes: 14 additions & 33 deletions include/flashinfer/attention/handler.cuh
@@ -309,11 +309,11 @@ class BatchDecodeHandler {
template <uint32_t HEAD_DIM, PageStorage page_storage, LogitsPostHook LOGITS_POST_HOOK,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut,
typename IdType>
cudaError_t BeginForwardDispatched(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, size_t int_workspace_size_in_bytes,
IdType* indptr_h, IdType* last_page_len_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t page_size) {
cudaError_t PlanDispatched(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, size_t int_workspace_size_in_bytes, IdType* indptr_h,
IdType* last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t page_size) {
Clear();
batch_size_before_partition_ = batch_size;
bool split_kv;
uint32_t max_grid_size, max_num_pages_per_batch, new_batch_size;
@@ -438,12 +438,10 @@ class BatchDecodeHandler {
}
}
});
forward_started_ = true;
return cudaSuccess;
}

cudaError_t EndForward() {
forward_started_ = false;
void Clear() {
padded_batch_size_ = 0;
batch_size_before_partition_ = 0;
batch_size_after_partition_ = 0;
@@ -456,11 +454,8 @@
batch_idx_map_ = nullptr;
chunk_start_pos_ = nullptr;
seq_lengths_before_partition_ = nullptr;
return cudaSuccess;
}

bool IsForwardStarted() const { return forward_started_; }

void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) {
cudaFreeHost(page_locked_buffer_);
cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes);
@@ -490,16 +485,12 @@
batch_idx_map_(nullptr),
chunk_start_pos_(nullptr),
seq_lengths_before_partition_(nullptr),
forward_started_(false),
cuda_graph_enabled_(enable_cuda_graph),
fixed_batch_size_(batch_size),
stream_(nullptr) {
cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024);
}
~BatchDecodeHandler() {
EndForward();
cudaFreeHost(page_locked_buffer_);
}
~BatchDecodeHandler() { cudaFreeHost(page_locked_buffer_); }

bool IsCUDAGraphEnabled() const { return cuda_graph_enabled_; }

@@ -516,7 +507,6 @@
void* batch_idx_map_;
void* chunk_start_pos_;
void* seq_lengths_before_partition_;
bool forward_started_;
bool cuda_graph_enabled_;
uint32_t padded_batch_size_;
uint32_t fixed_batch_size_;
@@ -656,19 +646,17 @@ class BatchPrefillHandler {

uint32_t GetTotalNumRows() const { return total_num_rows_; }

bool IsForwardStarted() const { return request_indices_ != nullptr; }

void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) {
cudaFreeHost(page_locked_buffer_);
cudaMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes);
}

template <typename DTypeOut, typename IdType>
cudaError_t BeginForward(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, size_t int_workspace_size_in_bytes,
IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size) {
cudaError_t Plan(void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
size_t int_workspace_size_in_bytes, IdType* qo_indptr_h, IdType* kv_indptr_h,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim, uint32_t page_size) {
Clear();
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -812,8 +800,7 @@ class BatchPrefillHandler {
return cudaSuccess;
}

cudaError_t EndForward() {
forward_started_ = false;
void Clear() {
request_indices_ = nullptr;
qo_tile_indices_ = nullptr;
kv_tile_indices_ = nullptr;
@@ -826,7 +813,6 @@ class BatchPrefillHandler {
total_num_rows_ = 0U;
padded_batch_size_ = 0U;
warp_layout_ = WarpLayout::k4x1x2;
return cudaSuccess;
}

cudaStream_t GetCUDAStream() const { return stream_; }
@@ -848,15 +834,11 @@ class BatchPrefillHandler {
total_num_rows_(0U),
padded_batch_size_(0U),
warp_layout_(WarpLayout::k4x1x2),
forward_started_(false),
enable_cuda_graph_(enable_cuda_graph),
stream_(nullptr) {
cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024);
}
~BatchPrefillHandler() {
EndForward();
cudaFreeHost(page_locked_buffer_);
}
~BatchPrefillHandler() { cudaFreeHost(page_locked_buffer_); }

protected:
void* page_locked_buffer_;
@@ -872,7 +854,6 @@ class BatchPrefillHandler {
uint32_t total_num_rows_;
uint32_t padded_batch_size_;
WarpLayout warp_layout_;
bool forward_started_;
bool enable_cuda_graph_;
cudaStream_t stream_;
};
27 changes: 10 additions & 17 deletions include/flashinfer/decode_attention_decl.cuh
@@ -57,23 +57,16 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
float* tmp_s = handler->GetTempS();

if (handler->IsForwardStarted()) {
if (tmp_v != nullptr) {
// create auxiliary information for cooperative kernels
new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
}
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
"BatchDecodeWithPagedKVCacheWrapper()";
throw std::runtime_error(err_msg.str());
if (tmp_v != nullptr) {
// create auxiliary information for cooperative kernels
new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
}

return BatchDecodeWithPagedKVCacheDispatched<HEAD_DIM, PAGE_STORAGE, LOGITS_POST_HOOK,
62 changes: 24 additions & 38 deletions include/flashinfer/prefill_attention_decl.cuh
@@ -81,25 +81,18 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
WarpLayout warp_layout;
uint32_t padded_batch_size = 0U;
uint32_t total_num_rows = 0U;
if (handler->IsForwardStarted()) {
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
warp_layout = handler->GetWarpLayout();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
"BatchPrefillWithPagedKVCacheWrapper()";
throw std::runtime_error(err_msg.str());
}
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
warp_layout = handler->GetWarpLayout();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();

DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
return BatchPrefillWithPagedKVCacheDispatched<
@@ -131,25 +124,18 @@
WarpLayout warp_layout;
uint32_t padded_batch_size = 0U;
uint32_t total_num_rows = 0U;
if (handler->IsForwardStarted()) {
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
warp_layout = handler->GetWarpLayout();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();
} else {
std::ostringstream err_msg;
err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
"BatchPrefillWithRaggedKVWrapperCache()";
throw std::runtime_error(err_msg.str());
}
tmp_v = handler->GetTempV<DTypeOut>();
tmp_s = handler->GetTempS();
request_indices = handler->GetRequestIndices<IdType>();
qo_tile_indices = handler->GetQOTileIndices<IdType>();
kv_tile_indices = handler->GetKVTileIndices<IdType>();
block_valid_mask = handler->GetBlockValidMask();
o_indptr = handler->GetOIndptr<IdType>();
merge_indptr = handler->GetMergeIndptr<IdType>();
kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
warp_layout = handler->GetWarpLayout();
padded_batch_size = handler->GetPaddedBatchSize();
total_num_rows = handler->GetTotalNumRows();

DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
return BatchPrefillWithRaggedKVCacheDispatched<WARP_LAYOUT, HEAD_DIM, LOGITS_POST_HOOK,
29 changes: 13 additions & 16 deletions python/csrc/batch_decode.cu
@@ -20,7 +20,7 @@

using namespace flashinfer;

void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
void BatchDecodeWithPagedKVCachePyTorchWrapper::Plan(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor indptr,
torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
@@ -62,15 +62,15 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
cudaError_t status =
handler_->BeginForwardDispatched<HEAD_DIM, PageStorage::kIndices,
LOGITS_POST_HOOK, POS_ENCODING_MODE, qkv_type,
qkv_type, qkv_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()),
float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
int_workspace_size_in_bytes, static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
num_kv_heads, page_size);
handler_
->PlanDispatched<HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK,
POS_ENCODING_MODE, qkv_type, qkv_type, qkv_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()),
float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
int_workspace_size_in_bytes, static_cast<int32_t*>(indptr.data_ptr()),
static_cast<int32_t*>(last_page_len.data_ptr()), batch_size,
num_qo_heads, num_kv_heads, page_size);
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCache failed with error ",
cudaGetErrorString(status));
return true;
@@ -86,9 +86,8 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
cudaError_t status =
handler_->BeginForwardDispatched<HEAD_DIM, PageStorage::kIndices,
LOGITS_POST_HOOK, POS_ENCODING_MODE, q_type,
kv_type, q_type, int32_t>(
handler_->PlanDispatched<HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK,
POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()),
float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()),
@@ -107,14 +106,12 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
}
}

void BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Run(
torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
20 changes: 8 additions & 12 deletions python/csrc/batch_prefill.cu
@@ -20,7 +20,7 @@

using namespace flashinfer;

void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
void BatchPrefillWithPagedKVCachePyTorchWrapper::Plan(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
@@ -48,7 +48,7 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
handler_->SetCUDAStream(torch_current_stream);

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
cudaError_t status = handler_->Plan<q_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
static_cast<int32_t*>(qo_indptr.data_ptr()),
@@ -60,14 +60,12 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
});
}

void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Run(
torch::Tensor q, torch::Tensor qo_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
@@ -257,7 +255,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
}
}

std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask(
std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::RunCustomMask(
torch::Tensor q, torch::Tensor qo_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
@@ -452,7 +450,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
}
}

void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
void BatchPrefillWithRaggedKVCachePyTorchWrapper::Plan(
torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer,
torch::Tensor qo_indptr, torch::Tensor kv_indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
@@ -479,7 +477,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
handler_->SetCUDAStream(torch_current_stream);

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] {
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
cudaError_t status = handler_->Plan<q_type, int32_t>(
static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
static_cast<void*>(int_workspace_buffer.data_ptr()), int_workspace_size_in_bytes,
static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
@@ -491,14 +489,12 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
});
}

void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }

void BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
unsigned int int_workspace_size_in_bytes) {
handler_->UpdatePageLockedBufferSize(int_workspace_size_in_bytes);
}

std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Run(
torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v,
torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode,
bool allow_fp16_qk_reduction, int window_left, float logits_soft_cap, float sm_scale,
@@ -605,7 +601,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
}
}

std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask(
std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::RunCustomMask(
torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v,
torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, int window_left,