Skip to content

Commit

Permalink
refactor: move gqa_group_size from template parameter to input argu…
Browse files Browse the repository at this point in the history
…ments (#301)

#262 is out of sync with main; this PR rebases that code onto the main branch.
  • Loading branch information
yzh119 authored Jun 15, 2024
1 parent bb1783b commit c111ca6
Show file tree
Hide file tree
Showing 27 changed files with 1,303 additions and 1,553 deletions.
279 changes: 133 additions & 146 deletions CMakeLists.txt

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ set(FLASHINFER_FASTDIV_TEST ON)
set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_PAGE_SIZES 1 16 32)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
Expand Down
376 changes: 190 additions & 186 deletions include/flashinfer/attention/decode.cuh

Large diffs are not rendered by default.

220 changes: 112 additions & 108 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -297,121 +297,125 @@ class BatchDecodeHandler {

bool* GetBlockValidMask() const { return block_valid_mask_; }

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage,
LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE,
typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
template <uint32_t HEAD_DIM, PageStorage page_storage, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout kv_layout, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ,
typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
IdType* last_page_len, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t page_size) {
uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t page_size) {
batch_size_before_partition_ = batch_size;
uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE, DTypeQ,
DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
new_batch_size, batch_size, indptr, num_qo_heads,
page_size,
/*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_));
batch_size_after_partition_ = new_batch_size;
if (IsCUDAGraphEnabled()) {
if (batch_size != fixed_batch_size_) {
std::ostringstream err_msg;
err_msg << "The running batch size " << batch_size
<< " is not compatible with the fixed batch size " << fixed_batch_size_
<< " initialized for CUDAGraph";
throw std::runtime_error(err_msg.str());
}
size_t padded_batch_size = max_grid_size / num_kv_heads;
if (tmp_size > 0) {
padded_batch_size_ = padded_batch_size;
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
tmp_s_ =
allocator.aligned_alloc<void>(num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);

void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
bool* block_valid_mask_h_ =
(bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);

size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr,
last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched<
GROUP_SIZE, HEAD_DIM, page_storage, LOGITS_POST_HOOK, kv_layout, POS_ENCODING_MODE,
DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(
work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size,
batch_size, indptr, num_qo_heads, page_size,
/*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_));
batch_size_after_partition_ = new_batch_size;
if (IsCUDAGraphEnabled()) {
if (batch_size != fixed_batch_size_) {
std::ostringstream err_msg;
err_msg << "The running batch size " << batch_size
<< " is not compatible with the fixed batch size " << fixed_batch_size_
<< " initialized for CUDAGraph";
throw std::runtime_error(err_msg.str());
}
size_t padded_batch_size = max_grid_size / num_kv_heads;
if (tmp_size > 0) {
padded_batch_size_ = padded_batch_size;
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(DTypeOut), 16);
tmp_s_ = allocator.aligned_alloc<void>(
num_qo_heads * padded_batch_size * 2 * sizeof(float), 16);
new_indptr_ = allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);

void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ =
allocator.aligned_alloc<void>((padded_batch_size + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ = allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(padded_batch_size * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
block_valid_mask_ = allocator.aligned_alloc<bool>(padded_batch_size * sizeof(bool), 16);
bool* block_valid_mask_h_ =
(bool*)page_locked_buffer_ + ((bool*)block_valid_mask_ - (bool*)new_indptr_);
std::fill(block_valid_mask_h_, block_valid_mask_h_ + padded_batch_size, 0);

size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, padded_batch_size, page_size, indptr,
last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_, block_valid_mask_h_,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
} else {
block_valid_mask_ = nullptr;
padded_batch_size_ = batch_size;
}
} else {
// NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
block_valid_mask_ = nullptr;
padded_batch_size_ = batch_size;
}
} else {
// NOTE(Zihao): we don't use block_valid_mask when CUDAGraph is disabled.
block_valid_mask_ = nullptr;
// do not pad the batch size when not using CUDAGraph
padded_batch_size_ = batch_size_after_partition_;
if (tmp_size > 0) {
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
tmp_s_ = (char*)tmp_v_ +
num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut);
new_indptr_ =
allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ =
allocator.aligned_alloc<void>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr,
last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_,
/*block_valid_mask_h=*/nullptr,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
// do not pad the batch size when not using CUDAGraph
padded_batch_size_ = batch_size_after_partition_;
if (tmp_size > 0) {
AlignedAllocator allocator(buffer, workspace_size_in_bytes);
tmp_v_ = allocator.aligned_alloc<void>(tmp_size, 16);
tmp_s_ = (char*)tmp_v_ +
num_qo_heads * batch_size_after_partition_ * HEAD_DIM * sizeof(DTypeOut);
new_indptr_ =
allocator.aligned_alloc<void>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ = allocator.aligned_alloc<void>(
(batch_size_before_partition_ + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void>(batch_size_after_partition_ * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ +
((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, batch_size_after_partition_, page_size, indptr,
last_page_len, (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_,
(IdType*)chunk_indptr_h_, (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_,
/*block_valid_mask_h=*/nullptr,
/*device_buffer=*/new_indptr_,
/*host_buffer=*/page_locked_buffer_, num_bytes_to_copy, stream_));
}
}
}
});
forward_started_ = true;
return cudaSuccess;
}
Expand Down
Loading

0 comments on commit c111ca6

Please sign in to comment.