misc: enhance allocator error info and add shape check for prefill begin forward functions #413

Merged (7 commits) on Jul 31, 2024
8 changes: 6 additions & 2 deletions include/flashinfer/allocator.h
@@ -17,6 +17,7 @@
 #define FLASHINFER_ALLOCATOR_H_
 
 #include <memory>
+#include <sstream>
 #include <stdexcept>
 
 namespace flashinfer {
@@ -26,14 +27,17 @@ struct AlignedAllocator {
   size_t space;
   AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {}
   template <typename T>
-  T* aligned_alloc(size_t size, size_t alignment) {
+  T* aligned_alloc(size_t size, size_t alignment, std::string name) {
     if (std::align(alignment, size, ptr, space)) {
       T* result = reinterpret_cast<T*>(ptr);
       ptr = (char*)ptr + size;
       space -= size;
       return result;
     } else {
-      throw std::runtime_error("RuntimeError: Out of workspace memory in AlignedAlloactor");
+      std::ostringstream oss;
+      oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment "
+          << alignment << " in AlignedAllocator";
+      throw std::runtime_error(oss.str());
     }
     return nullptr;
   }
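Reviewer note: a minimal sketch of how the new name argument surfaces at the call site, assuming the header is included as flashinfer/allocator.h and using an ordinary heap buffer as the workspace for illustration (not part of this PR):

#include <iostream>
#include <stdexcept>
#include <vector>

#include "flashinfer/allocator.h"

int main() {
  // A deliberately tiny workspace so the second request fails.
  std::vector<char> workspace(64);
  flashinfer::AlignedAllocator allocator(workspace.data(), workspace.size());

  // Succeeds: 32 bytes with 16-byte alignment fits in the workspace.
  float* tmp_scores = allocator.aligned_alloc<float>(32, 16, "tmp_scores");
  (void)tmp_scores;

  try {
    // Fails: the remaining space cannot hold 128 bytes.
    allocator.aligned_alloc<int>(128, 16, "kv_indptr");
  } catch (const std::runtime_error& e) {
    // The exception now reports the buffer name, requested size, and alignment
    // instead of the previous generic "out of workspace memory" message.
    std::cerr << e.what() << std::endl;
  }
  return 0;
}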
114 changes: 67 additions & 47 deletions include/flashinfer/attention/handler.cuh
@@ -111,7 +111,6 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(
       high = mid;
     }
   }
-
   new_batch_size = 0;
   for (uint32_t i = 0; i < batch_size; ++i) {
     new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) *
@@ -340,32 +339,37 @@ class BatchDecodeHandler {
         padded_batch_size_ = padded_batch_size;
         AlignedAllocator allocator(buffer, workspace_size_in_bytes);
         tmp_v_ = allocator.aligned_alloc<void>(
-            num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
-        tmp_s_ =
-            allocator.aligned_alloc<float>(num_qo_heads * padded_batch_size * sizeof(float), 16);
-        new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
+            num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16,
+            "batch_decode_tmp_v");
+        tmp_s_ = allocator.aligned_alloc<float>(num_qo_heads * padded_batch_size * sizeof(float),
+            16, "batch_decode_tmp_s");
+        new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16,
+            "batch_decode_new_indptr");

         void* new_indptr_h_ = page_locked_buffer_;
-        new_last_page_len_ =
-            allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+        new_last_page_len_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16,
+            "batch_decode_new_last_page_len");
         void* new_last_page_len_h_ =
             (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
-        chunk_indptr_ =
-            allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
+        chunk_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType),
+            16, "batch_decode_chunk_indptr");
         void* chunk_indptr_h_ =
             (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
-        batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+        batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16,
+            "batch_decode_batch_idx_map");
         void* batch_idx_map_h_ =
             (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
-        chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+        chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16,
+            "batch_decode_chunk_start_pos");
         void* chunk_start_pos_h_ =
             (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
-        seq_lengths_before_partition_ =
-            allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
+        seq_lengths_before_partition_ = allocator.aligned_alloc<void>(
+            padded_batch_size * sizeof(IdType), 16, "batch_decode_seq_lengths_before_partition");
         void* seq_lengths_before_partition_h_ =
             (char*)page_locked_buffer_ +
             ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
-        block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
+        block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16,
+            "batch_decode_block_valid_mask");
         bool* block_valid_mask_h_ =
             (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
         std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);
@@ -390,30 +394,32 @@ class BatchDecodeHandler {
       if (split_kv) {
         AlignedAllocator allocator(buffer, workspace_size_in_bytes);
         tmp_v_ = allocator.aligned_alloc<void>(
-            num_qo_heads * new_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
-        tmp_s_ =
-            allocator.aligned_alloc<float>(num_qo_heads * new_batch_size * sizeof(float), 16);
-        new_indptr_ =
-            allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
+            num_qo_heads * new_batch_size * HEAD_DIM * sizeof(DTypeOut), 16,
+            "batch_decode_tmp_v");
+        tmp_s_ = allocator.aligned_alloc<float>(num_qo_heads * new_batch_size * sizeof(float), 16,
+            "batch_decode_tmp_s");
+        new_indptr_ = allocator.aligned_alloc<void>(
+            (batch_size_after_partition_ + 1) * sizeof(IdType), 16, "batch_decode_new_indptr");
         void* new_indptr_h_ = page_locked_buffer_;
-        new_last_page_len_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+        new_last_page_len_ = allocator.aligned_alloc<void>(
+            batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_new_last_page_len");
         void* new_last_page_len_h_ =
             (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
         chunk_indptr_ = allocator.aligned_alloc<void>(
-            (batch_size_before_partition_ + 1) * sizeof(IdType), 16);
+            (batch_size_before_partition_ + 1) * sizeof(IdType), 16, "batch_decode_chunk_indptr");
         void* chunk_indptr_h_ =
             (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
-        batch_idx_map_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+        batch_idx_map_ = allocator.aligned_alloc<void>(
+            batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_batch_idx_map");
         void* batch_idx_map_h_ =
             (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
-        chunk_start_pos_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+        chunk_start_pos_ = allocator.aligned_alloc<void>(
+            batch_size_after_partition_ * sizeof(IdType), 16, "batch_decode_chunk_start_pos");
         void* chunk_start_pos_h_ =
             (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
         seq_lengths_before_partition_ =
-            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
+            allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16,
+                "batch_decode_seq_lengths_before_partition");
         void* seq_lengths_before_partition_h_ =
             (char*)page_locked_buffer_ +
             ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
@@ -678,27 +684,34 @@ class BatchPrefillHandler {
     if (IsCUDAGraphEnabled()) {
       padded_batch_size_ = std::max(split_max_batch_size, total_num_tiles_q);
       AlignedAllocator allocator(buffer, workspace_size_in_bytes);
-      request_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16);
+      request_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16,
+          "batch_prefill_request_indices");
       void* request_indices_h_ = page_locked_buffer_;
-      qo_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16);
+      qo_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16,
+          "batch_prefill_qo_tile_indices");
       void* qo_tile_indices_h_ =
           (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_);
-      kv_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16);
+      kv_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * padded_batch_size_, 16,
+          "batch_prefill_kv_tile_indices");
       void* kv_tile_indices_h_ =
          (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_);
-      o_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * (batch_size + 1), 16);
+      o_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * (batch_size + 1), 16,
+          "batch_prefill_o_indptr");
       void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_);
-      kv_chunk_size_ptr_ = allocator.aligned_alloc<void>(sizeof(IdType), 1);
+      kv_chunk_size_ptr_ =
+          allocator.aligned_alloc<void>(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr");
       void* kv_chunk_size_ptr_h_ =
          (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_);
      *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size;
      if (total_num_tiles_q < split_max_batch_size) {
        // need merge_indptr
-        merge_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * (total_num_rows_ + 1), 16);
+        merge_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * (total_num_rows_ + 1), 16,
+            "batch_prefill_merge_indptr");
        void* merge_indptr_h_ =
            (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_);
        std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_);
-        block_valid_mask_ = allocator.aligned_alloc<bool>(sizeof(bool) * padded_batch_size_, 16);
+        block_valid_mask_ = allocator.aligned_alloc<bool>(sizeof(bool) * padded_batch_size_, 16,
+            "batch_prefill_block_valid_mask");
        bool* block_valid_mask_h_ =
            (bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)request_indices_);
        for (uint32_t i = 0; i < padded_batch_size_; ++i) {
@@ -724,37 +737,42 @@ class BatchPrefillHandler {

      if (total_num_tiles_q < split_max_batch_size) {
        tmp_v_ = allocator.aligned_alloc<void>(
-            num_qo_heads * split_max_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16);
+            num_qo_heads * split_max_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16,
+            "batch_prefill_tmp_v");
        tmp_s_ = allocator.aligned_alloc<float>(
-            num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float), 16);
+            num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float), 16,
+            "batch_prefill_tmp_s");
      } else {
        tmp_v_ = nullptr;
        tmp_s_ = nullptr;
      }
    } else {
      padded_batch_size_ = new_batch_size;
      AlignedAllocator allocator(buffer, workspace_size_in_bytes);
-      request_indices_ =
-          allocator.aligned_alloc<void>(sizeof(IdType) * request_indices_vec.size(), 16);
+      request_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * request_indices_vec.size(),
+          16, "batch_prefill_request_indices");
      void* request_indices_h_ = page_locked_buffer_;
-      qo_tile_indices_ =
-          allocator.aligned_alloc<void>(sizeof(IdType) * qo_tile_indices_vec.size(), 16);
+      qo_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * qo_tile_indices_vec.size(),
+          16, "batch_prefill_qo_tile_indices");
      void* qo_tile_indices_h_ =
          (char*)page_locked_buffer_ + ((char*)qo_tile_indices_ - (char*)request_indices_);
-      kv_tile_indices_ =
-          allocator.aligned_alloc<void>(sizeof(IdType) * kv_tile_indices_vec.size(), 16);
+      kv_tile_indices_ = allocator.aligned_alloc<void>(sizeof(IdType) * kv_tile_indices_vec.size(),
+          16, "batch_prefill_kv_tile_indices");
      void* kv_tile_indices_h_ =
          (char*)page_locked_buffer_ + ((char*)kv_tile_indices_ - (char*)request_indices_);
      if (split_kv) {
        // need merge_indptr when split_kv is true
-        merge_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * merge_indptr_vec.size(), 16);
+        merge_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * merge_indptr_vec.size(), 16,
+            "batch_prefill_merge_indptr");
        void* merge_indptr_h_ =
            (char*)page_locked_buffer_ + ((char*)merge_indptr_ - (char*)request_indices_);
        std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), (IdType*)merge_indptr_h_);
      }
-      o_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * o_indptr_vec.size(), 16);
+      o_indptr_ = allocator.aligned_alloc<void>(sizeof(IdType) * o_indptr_vec.size(), 16,
+          "batch_prefill_o_indptr");
      void* o_indptr_h_ = (char*)page_locked_buffer_ + ((char*)o_indptr_ - (char*)request_indices_);
-      kv_chunk_size_ptr_ = allocator.aligned_alloc<void>(sizeof(IdType), 1);
+      kv_chunk_size_ptr_ =
+          allocator.aligned_alloc<void>(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr");
      void* kv_chunk_size_ptr_h_ =
          (char*)page_locked_buffer_ + ((char*)kv_chunk_size_ptr_ - (char*)request_indices_);
      *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size;
@@ -772,9 +790,11 @@ class BatchPrefillHandler {

      if (split_kv) {
        tmp_v_ = allocator.aligned_alloc<void>(
-            num_qo_heads * new_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16);
+            num_qo_heads * new_batch_size * qo_tile_size * head_dim * sizeof(DTypeOut), 16,
+            "batch_prefill_tmp_v");
        tmp_s_ = allocator.aligned_alloc<float>(
-            num_qo_heads * new_batch_size * qo_tile_size * sizeof(float), 16);
+            num_qo_heads * new_batch_size * qo_tile_size * sizeof(float), 16,
+            "batch_prefill_tmp_s");
      } else {
        tmp_v_ = nullptr;
        tmp_s_ = nullptr;
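A note on the repeated pointer arithmetic in this file: every buffer carved out of the device workspace is mirrored at the same byte offset inside a page-locked host staging buffer, with the first allocation (new_indptr_ or request_indices_) serving as the base. A small sketch of that offset calculation, with HostMirror as a hypothetical helper name used only for illustration:

#include <cstddef>

// Map a pointer returned by AlignedAllocator into the page-locked host staging
// buffer by reusing its byte offset from the first device-side allocation.
inline void* HostMirror(void* page_locked_buffer, void* device_base, void* device_ptr) {
  std::ptrdiff_t offset = static_cast<char*>(device_ptr) - static_cast<char*>(device_base);
  return static_cast<char*>(page_locked_buffer) + offset;
}

// Equivalent to the pattern in the handler above, e.g.
//   void* chunk_indptr_h_ =
//       (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
// could be written as
//   void* chunk_indptr_h_ = HostMirror(page_locked_buffer_, new_indptr_, chunk_indptr_);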
14 changes: 7 additions & 7 deletions include/flashinfer/group_gemm/wrapper.cuh
@@ -43,13 +43,13 @@ cudaError_t CutlassSegmentGEMMWrapper(CutlassSegmentGEMMHandler* handler, DType*
   AlignedAllocator allocator(handler->GetWorkspace(), handler->GetWorkspaceSizeInBytes());
   cutlass::gemm::GemmCoord* problem_sizes_device =
       allocator.aligned_alloc<cutlass::gemm::GemmCoord>(
-          batch_size * sizeof(cutlass::gemm::GemmCoord), 16);
-  DType** x_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
-  DType** w_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
-  DType** y_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16);
-  int64_t* ld_x = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
-  int64_t* ld_w = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
-  int64_t* ld_y = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16);
+          batch_size * sizeof(cutlass::gemm::GemmCoord), 16, "problem_sizes_device");
+  DType** x_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "x_data");
+  DType** w_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "w_data");
+  DType** y_data = allocator.aligned_alloc<DType*>(batch_size * sizeof(DType*), 16, "y_data");
+  int64_t* ld_x = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_x");
+  int64_t* ld_w = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_w");
+  int64_t* ld_y = allocator.aligned_alloc<int64_t>(batch_size * sizeof(int64_t), 16, "ld_y");
 
   // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API,
   // so I just use the kernel function directly, need to investigate more.
6 changes: 6 additions & 0 deletions python/csrc/batch_prefill.cu
@@ -30,7 +30,10 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
   CHECK_CONTIGUOUS(paged_kv_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
+  CHECK_DIM(1, paged_kv_indptr);
   CHECK_DIM(1, workspace_buffer);
+  CHECK_EQ(qo_indptr.size(0), batch_size + 1);
+  CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
   qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
   paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
   auto device = workspace_buffer.device();
@@ -361,7 +364,10 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   CHECK_DIM(1, qo_indptr);
+  CHECK_DIM(1, kv_indptr);
   CHECK_DIM(1, workspace_buffer);
+  CHECK_EQ(qo_indptr.size(0), batch_size + 1);
+  CHECK_EQ(kv_indptr.size(0), batch_size + 1);
   qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
   kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
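For clarity, the added checks enforce that both indptr tensors are 1-D with exactly batch_size + 1 entries before BeginForward partitions the work. A hedged sketch of that validation follows; CHECK_DIM and CHECK_EQ are assumed here to be thin TORCH_CHECK wrappers (the project's actual macro definitions may differ), and ValidateIndptr is an illustrative helper rather than part of the PR:

#include <torch/extension.h>

// Assumed stand-ins for the project's CHECK_DIM / CHECK_EQ macros.
#define CHECK_DIM(d, x) TORCH_CHECK((x).dim() == (d), #x " must be a ", (d), "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "expected " #a " == " #b ", got ", (a), " vs ", (b))

void ValidateIndptr(const torch::Tensor& qo_indptr, const torch::Tensor& kv_indptr,
                    int64_t batch_size) {
  // Both indptr tensors must be 1-D with batch_size + 1 entries; a malformed
  // indptr previously surfaced only later as an out-of-range access or an
  // opaque allocator failure inside the prefill handler.
  CHECK_DIM(1, qo_indptr);
  CHECK_DIM(1, kv_indptr);
  CHECK_EQ(qo_indptr.size(0), batch_size + 1);
  CHECK_EQ(kv_indptr.size(0), batch_size + 1);
}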