Skip to content

Commit

Permalink
feat: Separate Q and KV dtypes for decode (#286)
Browse files Browse the repository at this point in the history
Closes #285

Modified unit tests pass. May need some extra validation.
  • Loading branch information
Yard1 authored Jun 13, 2024
1 parent 1250b68 commit 5602659
Show file tree
Hide file tree
Showing 19 changed files with 473 additions and 371 deletions.
113 changes: 59 additions & 54 deletions include/flashinfer/attention/decode.cuh

Large diffs are not rendered by default.

30 changes: 16 additions & 14 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ namespace flashinfer {

template <bool partition_kv, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeIn* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse,
bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
Expand Down Expand Up @@ -86,7 +86,7 @@ std::pair<uint32_t, uint32_t> PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
* \brief Estimate the temporary buffer size and the maximum grid size for the
* partition-kv BatchDecodeWithPagedKVCache kernel
* \tparam page_storage Whether to store indices or pointers of each active page
* \tparam DTypeIn A template type indicates the input data type
* \tparam DTypeKV A template type indicates the key-value data type
* \tparam DTypeOut A template type indicates the output data type
* \tparam IdType A template type indicates the index data type
* \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel
Expand All @@ -100,27 +100,29 @@ std::pair<uint32_t, uint32_t> PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
* \return status Indicates whether CUDA calls are successful
*/
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U;
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
const uint32_t smem_size =
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float));
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));

// Note that the dtype of Q should not impact the cudaOccupancyMaxActiveBlocksPerMultiprocessor
// return, which is why we just use DTypeKV as it simplifies the API.
auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel<
/*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx,
bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>;
bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
Expand Down Expand Up @@ -294,7 +296,7 @@ class BatchDecodeHandler {
bool* GetBlockValidMask() const { return block_valid_mask_; }

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
IdType* last_page_len, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t page_size) {
Expand All @@ -303,8 +305,8 @@ class BatchDecodeHandler {
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func =
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
kv_layout, POS_ENCODING_MODE, DTypeIn,
DTypeOut, IdType>;
kv_layout, POS_ENCODING_MODE,
DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
new_batch_size, batch_size, indptr, num_qo_heads,
page_size,
Expand Down
22 changes: 11 additions & 11 deletions include/flashinfer/decode_attention_decl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,35 @@
namespace flashinfer {

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
PosEncodingMode pos_encoding_mode, typename DTypeIn, typename DTypeOut>
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
PosEncodingMode pos_encoding_mode, typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, uint32_t num_kv_heads,
uint32_t seq_len, float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream);

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);

template <PageStorage page_storage, QKVLayout KV_LAYOUT, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset,
paged_kv_t<page_storage, KV_LAYOUT, DTypeKV, IdType> paged_kv, DTypeOut* o, float* lse,
float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> new_paged_kv = paged_kv;
paged_kv_t<page_storage, KV_LAYOUT, DTypeKV, IdType> new_paged_kv = paged_kv;
kv_partition_info_t<IdType> kv_partition_info;
DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
float* tmp_s = handler->GetTempS<float>();
Expand All @@ -82,7 +82,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
}

return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, KV_LAYOUT,
POS_ENCODING_MODE, DTypeIn, DTypeOut, IdType>(
POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>(
q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta,
stream);
Expand Down
Loading

0 comments on commit 5602659

Please sign in to comment.