perf: initial cuda graph support (#256)

As requested in #187, this PR adds initial `CUDAGraph` compatibility for the flashinfer batch decode attention kernels. It is the first step towards full CUDAGraph support; CUDAGraph-compatible prefill operators will be implemented in later PRs.

# Proposed APIs
We add another wrapper, `CUDAGraphBatchDecodeWithPagedKVCacheWrapper`; users need to pre-allocate the page data structure buffers to initialize this wrapper class. Once initialized, these buffers are pinned on the GPU for the entire life cycle of the wrapper.
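
Below is a rough sketch of the intended usage, assuming the new wrapper follows the conventions of the existing `BatchDecodeWithPagedKVCacheWrapper`. The constructor arguments, buffer sizes, and dtypes here are illustrative assumptions, not the final API:

```python
import torch
import flashinfer

# Sizes are illustrative assumptions; pick them for the largest workload
# the captured graph will ever serve.
max_batch_size = 128
max_num_pages = 4096

# Auxiliary workspace plus paged KV-cache metadata buffers. Once the
# wrapper is constructed, these buffers are pinned for its whole life
# cycle, so captured kernels always read from the same addresses.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
kv_indptr_buffer = torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda")
kv_indices_buffer = torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len_buffer = torch.empty(max_batch_size, dtype=torch.int32, device="cuda")

# Hypothetical constructor call; the exact signature may differ.
decode_wrapper = flashinfer.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
    workspace_buffer,
    kv_indptr_buffer,
    kv_indices_buffer,
    kv_last_page_len_buffer,
    kv_layout="NHD",
)
```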

The behavior of `CUDAGraphBatchDecodeWithPagedKVCacheWrapper` is slightly different from `BatchDecodeWithPagedKVCacheWrapper`'s: in CUDAGraph mode we always run a fixed set of kernels, no matter what the input shape is, whereas the original implementation dispatches to different kernels according to the input shape.
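
The work-estimation change in `handler.cuh` below makes this concrete: the eager path may skip the partition-KV kernel when the batch already saturates the GPU, while the CUDAGraph path keeps a single fixed kernel choice. A rough Python rendering of that dispatch decision, for illustration only:

```python
def use_partition_kv_kernel(batch_size: int, num_kv_heads: int,
                            max_grid_size: int, enable_cuda_graph: bool) -> bool:
    """Simplified mirror of the dispatch condition in
    BatchDecodeWithPagedKVCacheWorkEstimationDispatched."""
    if batch_size * num_kv_heads >= max_grid_size and not enable_cuda_graph:
        # Enough work to saturate the GPU: the eager path skips the
        # partition-KV (split-KV) kernel and the merge step.
        return False
    # CUDAGraph mode always falls through to here, so the captured graph
    # contains one fixed kernel sequence regardless of the input shape.
    return True
```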

This PR also fixes the addresses of all kernel input pointers to accommodate the constraints of CUDAGraph capture.

# Examples
See `test_cuda_graph_batch_decode_with_paged_kv_cache` in the unit tests. The `begin_forward` functions should not be captured, because some of the operators they launch are not allowed inside CUDA graph capture.
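
A minimal capture sketch, continuing the wrapper from the sketch above (shapes, dtypes, and argument names are illustrative assumptions; only the split between planning and capture matters):

```python
import torch

# Illustrative shapes for a toy decode batch.
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size = 7, 32, 4, 128, 16
pages_per_seq = 8
total_pages = batch_size * pages_per_seq

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(total_pages, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")
kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
kv_indices = torch.arange(total_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

# 1. Plan outside of capture: begin_forward performs host-side work and
#    copies that are not allowed inside a CUDA graph.
decode_wrapper.begin_forward(kv_indptr, kv_indices, kv_last_page_len,
                             num_qo_heads, num_kv_heads, head_dim, page_size)

# 2. Capture only the decode forward call (warm-up iterations omitted).
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = decode_wrapper.forward(q, kv_cache)

# 3. Replay: write new data into q / kv_cache in place, then replay.
g.replay()

decode_wrapper.end_forward()
```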

cc @AgrawalAmey  @LiuXiaoxuanPKU  @comaniac
yzh119 authored May 24, 2024
1 parent ed20304 commit 7e9cc7f
Showing 12 changed files with 710 additions and 96 deletions.
3 changes: 3 additions & 0 deletions docs/api/python/decode.rst
@@ -24,3 +24,6 @@ Batch Decoding

.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
:members:

.. autoclass:: CUDAGraphDecodeWithPagedKVCacheWrapper
:members:
141 changes: 120 additions & 21 deletions include/flashinfer/attention/handler.cuh
@@ -17,6 +17,7 @@
#define FLASHINFER_HANDLER_CUH_

#include <algorithm>
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <vector>
@@ -101,7 +102,7 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVL
cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
const uint32_t page_size, cudaStream_t stream) {
const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
@@ -126,8 +127,10 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * num_kv_heads >= max_grid_size) {
if (batch_size * num_kv_heads >= max_grid_size && !enable_cuda_graph) {
// do not use partition-kv kernel
// TODO(Zihao): if enable_cuda_graph, we should always use partition-kv kernel
// so that only one kernel will be captured in the graph.
tmp_size = 0;
new_batch_size = batch_size;
} else {
@@ -299,39 +302,42 @@ class BatchDecodeHandler {
DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
new_batch_size, batch_size, indptr, num_qo_heads,
page_size, stream_));
page_size,
/*enable_cuda_graph=*/false, stream_));
batch_size_after_partition_ = new_batch_size;
if (tmp_size > 0) {
AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
float_buffer_ = allocator.aligned_alloc<void*>(tmp_size, 16);
new_indptr_ =
allocator.aligned_alloc<void*>((batch_size_after_partition_ + 1) * sizeof(IdType), 16);
void* new_indptr_h_ = host_buffer_;
void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)host_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ =
allocator.aligned_alloc<void*>((batch_size_before_partition_ + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ = (char*)host_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ =
allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
void* batch_idx_map_h_ = (char*)host_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ =
allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)host_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void*>(batch_size_after_partition_ * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)host_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
(char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);
size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
(IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_,
(IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_, new_indptr_, host_buffer_, num_bytes_to_copy,
stream_));
(IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_,
num_bytes_to_copy, stream_));
}
forward_started_ = true;
return cudaSuccess;
@@ -353,6 +359,11 @@ class BatchDecodeHandler {

bool IsForwardStarted() const { return forward_started_; }

void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) {
cudaFreeHost(page_locked_buffer_);
cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
}

uint32_t GetBatchSizeBeforePartition() const { return batch_size_before_partition_; }

uint32_t GetBatchSizeAfterPartition() const { return batch_size_after_partition_; }
@@ -372,17 +383,19 @@ class BatchDecodeHandler {
seq_lengths_before_partition_(nullptr),
forward_started_(false),
stream_(nullptr) {
cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes);
cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
}
~BatchDecodeHandler() {
EndForward();
cudaFreeHost(host_buffer_);
cudaFreeHost(page_locked_buffer_);
}

private:
virtual bool IsCUDAGraphMode() const { return false; }

protected:
uint32_t batch_size_before_partition_;
uint32_t batch_size_after_partition_;
void* host_buffer_;
void* page_locked_buffer_;
void* float_buffer_;
void* new_indptr_;
void* new_last_page_len_;
@@ -394,6 +407,86 @@ class BatchDecodeHandler {
cudaStream_t stream_;
};

class CUDAGraphBatchDecodeHandler : public BatchDecodeHandler {
public:
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t CUDAGraphBeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes,
IdType* indptr, IdType* last_page_len,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t page_size) {
batch_size_before_partition_ = batch_size;
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func =
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
kv_layout, POS_ENCODING_MODE, DTypeIn,
DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
new_batch_size, batch_size, indptr, num_qo_heads,
page_size,
/*enable_cuda_graph=*/true, stream_));
// NOTE(Zihao): max_batch_size_after_partition_ is determined in handler initialization.
// the value should not be changed during the lifetime of the handler.
// So it should be compatible with CUDAGraph which requires fixed pointer.
batch_size_after_partition_ = new_batch_size;
size_t max_tmp_size = num_qo_heads * max_batch_size_after_partition_ *
(HEAD_DIM * sizeof(DTypeOut) + 2 * sizeof(float));
AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
float_buffer_ = allocator.aligned_alloc<void*>(max_tmp_size, 16);
new_indptr_ =
allocator.aligned_alloc<void*>((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16);

void* new_indptr_h_ = page_locked_buffer_;
new_last_page_len_ =
allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
void* new_last_page_len_h_ =
(char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_);
chunk_indptr_ =
allocator.aligned_alloc<void*>((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16);
void* chunk_indptr_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_);
batch_idx_map_ =
allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
void* batch_idx_map_h_ =
(char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_);
chunk_start_pos_ =
allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
void* chunk_start_pos_h_ =
(char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_);
seq_lengths_before_partition_ =
allocator.aligned_alloc<void*>(max_batch_size_after_partition_ * sizeof(IdType), 16);
void* seq_lengths_before_partition_h_ =
(char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_);

size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_;
FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo(
max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len,
(IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_,
(IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_,
(IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_,
num_bytes_to_copy, stream_));
forward_started_ = true;
return cudaSuccess;
}
CUDAGraphBatchDecodeHandler(size_t max_batch_size) {
int dev_id = 0, num_sm = 0, max_thread_blocks_per_sm = 0;
cudaGetDevice(&dev_id);
cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id);
cudaDeviceGetAttribute(&max_thread_blocks_per_sm, cudaDevAttrMaxBlocksPerMultiprocessor,
dev_id);
max_batch_size_after_partition_ =
std::max<size_t>(max_thread_blocks_per_sm * num_sm, max_batch_size);
std::cout << max_thread_blocks_per_sm * num_sm << " " << max_batch_size << std::endl;
size_t max_workspace_size_in_bytes =
6 * (sizeof(uint64_t) * (max_batch_size_after_partition_ + 1) + 16);
cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
}
bool IsCUDAGraphMode() const override { return true; }

private:
uint32_t max_batch_size_after_partition_;
};

class BatchPrefillHandler {
public:
template <typename IdType>
@@ -412,6 +505,11 @@ class BatchPrefillHandler {

bool IsForwardStarted() const { return request_indices_ != nullptr; }

void UpdatePageLockedBufferSize(size_t max_workspace_size_in_bytes) {
cudaFreeHost(page_locked_buffer_);
cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
}

template <typename IdType>
cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
@@ -429,14 +527,15 @@
AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
request_indices_ =
allocator.aligned_alloc<void*>(sizeof(IdType) * request_indices_vec.size(), 16);
void* request_indices_h_ = host_buffer_;
void* request_indices_h_ = page_locked_buffer_;
tile_indices_ = allocator.aligned_alloc<void*>(sizeof(IdType) * tile_indices_vec.size(), 16);
void* tile_indices_h_ = (char*)host_buffer_ + ((char*)tile_indices_ - (char*)request_indices_);
void* tile_indices_h_ =
(char*)page_locked_buffer_ + ((char*)tile_indices_ - (char*)request_indices_);
std::copy(request_indices_vec.begin(), request_indices_vec.end(), (IdType*)request_indices_h_);
std::copy(tile_indices_vec.begin(), tile_indices_vec.end(), (IdType*)tile_indices_h_);
size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)request_indices_;

FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, host_buffer_, num_bytes_to_copy,
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(request_indices_, page_locked_buffer_, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream_));

return cudaSuccess;
@@ -462,15 +561,15 @@
num_qo_tiles_(0U),
forward_started_(false),
stream_(nullptr) {
cudaMallocHost(&host_buffer_, max_workspace_size_in_bytes);
cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes);
}
~BatchPrefillHandler() {
EndForward();
cudaFreeHost(host_buffer_);
cudaFreeHost(page_locked_buffer_);
}

private:
void* host_buffer_;
void* page_locked_buffer_;
void* request_indices_;
void* tile_indices_;
uint32_t num_frags_x_;
36 changes: 20 additions & 16 deletions include/flashinfer/attention/prefill.cuh
@@ -309,8 +309,8 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t q_idx_base,
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
// load q fragment from gmem to smem
q_smem->load_128b_async<SharedMemFillMode::kFillZero>(q_smem_offset_w, q_ptr,
q_idx < qo_upper_bound && group_id < group_size);
q_smem->load_128b_async<SharedMemFillMode::kFillZero>(
q_smem_offset_w, q_ptr, q_idx < qo_upper_bound && group_id < group_size);
q_smem_offset_w = q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo);
q_ptr += 8 * num_elems_per_128b<DTypeIn>();
}
@@ -933,7 +933,7 @@ __global__ void SinglePrefillWithKVCacheKernel(
constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b<DTypeOut>();

static_assert(num_frags_z * num_frags_y % num_warps == 0);
static_assert(group_size == 1 || group_size >= 4 && group_size <=8);
static_assert(group_size == 1 || group_size >= 4 && group_size <= 8);

extern __shared__ uint8_t smem[];

@@ -1341,7 +1341,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
kv_len = (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - 1) *
paged_kv.page_size +
paged_kv.last_page_len[request_idx];
const uint32_t qo_upper_bound = min(qo_len, (tile_idx + 1) * (num_rows_per_cta / aligned_group_size));
const uint32_t qo_upper_bound =
min(qo_len, (tile_idx + 1) * (num_rows_per_cta / aligned_group_size));

constexpr bool partition_kv = false;
constexpr uint32_t head_dim = num_frags_y * 16;
@@ -1364,7 +1365,8 @@
}
init_states<num_frags_x, num_frags_y>(o_frag, m, d);

const uint32_t qo_idx_base = ((tile_idx * num_warps + ty) * num_frags_x * 16) / aligned_group_size;
const uint32_t qo_idx_base =
((tile_idx * num_warps + ty) * num_frags_x * 16) / aligned_group_size;
const uint32_t qo_n_stride = get_n_stride_impl<QKVLayout::kNHD, head_dim>(num_qo_heads),
qo_h_stride = get_h_stride_impl<QKVLayout::kNHD, head_dim>(qo_len);
smem_t qo_smem(smem);
@@ -1386,12 +1388,12 @@

if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) {
if (q_offset == nullptr) {
q_smem_inplace_apply_rotary_multiply_sm_scale<aligned_group_size, num_warps, num_frags_x, num_frags_y,
DTypeIn>(qo_idx_base, qo_len, kv_len, &qo_smem,
&q_smem_offset_r, rope_freq, sm_scale);
q_smem_inplace_apply_rotary_multiply_sm_scale<aligned_group_size, num_warps, num_frags_x,
num_frags_y, DTypeIn>(
qo_idx_base, qo_len, kv_len, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale);
} else {
q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale<aligned_group_size, num_warps, num_frags_x,
num_frags_y, DTypeIn>(
q_smem_inplace_apply_rotary_with_pos_multiply_sm_scale<aligned_group_size, num_warps,
num_frags_x, num_frags_y, DTypeIn>(
qo_indptr[request_idx] + qo_idx_base, q_offset, &qo_smem, &q_smem_offset_r, rope_freq,
sm_scale);
}
@@ -1418,14 +1420,16 @@
cp_async::commit_group();

const uint32_t num_iterations = ceil_div(
(causal ? min(kv_len,
kv_len - qo_len + ((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size)
: kv_len),
(causal
? min(kv_len, kv_len - qo_len +
((tile_idx + 1) * num_frags_x * num_warps * 16) / aligned_group_size)
: kv_len),
16 * num_frags_z);

const uint32_t mask_iteration =
(causal
? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len, kv_len)
? min(kv_len + (tile_idx * num_warps * num_frags_x * 16) / aligned_group_size - qo_len,
kv_len)
: kv_len) /
(16 * num_frags_z);

Expand Down Expand Up @@ -1453,8 +1457,8 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
}
// apply mask
if (iter >= mask_iteration) {
mask_s<partition_kv, causal, aligned_group_size, num_warps, num_frags_x, num_frags_y, num_frags_z>(
qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag);
mask_s<partition_kv, causal, aligned_group_size, num_warps, num_frags_x, num_frags_y,
num_frags_z>(qo_idx_base, iter * 16 * num_frags_z, qo_len, kv_len, kv_len, s_frag);
}

// compute m,d states in online softmax
4 changes: 2 additions & 2 deletions include/flashinfer/utils.cuh
@@ -282,8 +282,8 @@ std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_in
uint32_t num_qo_tiles = 0;

for (uint32_t i = 0; i < batch_size; ++i) {
for (uint32_t j = qo_indptr_h[i] * aligned_gqa_group_size; j < qo_indptr_h[i + 1] * aligned_gqa_group_size;
j += num_rows_per_cta) {
for (uint32_t j = qo_indptr_h[i] * aligned_gqa_group_size;
j < qo_indptr_h[i + 1] * aligned_gqa_group_size; j += num_rows_per_cta) {
request_indices.push_back(i);
tile_indices.push_back((j - qo_indptr_h[i] * aligned_gqa_group_size) / num_rows_per_cta);
++num_qo_tiles;