refactor: reduce the binary size of batch decode kernels #343

Merged (2 commits) on Jun 30, 2024
Changes from all commits
42 changes: 8 additions & 34 deletions CMakeLists.txt
@@ -136,9 +136,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
list(APPEND single_decode_kernels_src ${generated_kernel_src})
endforeach(dtype)

- # fp8 in, fp16 out
- foreach(dtype IN LISTS DECODE_FP8_DTYPES)
- set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu)
+ # fp8 kv-cache
+ foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES)
+ set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src}
@@ -147,7 +147,7 @@ foreach(head_dim IN LISTS HEAD_DIMS)
VERBATIM
)
list(APPEND single_decode_kernels_src ${generated_kernel_src})
- endforeach(dtype)
+ endforeach(dtype_kv)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
endforeach(logits_post_hook)
@@ -172,9 +172,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
list(APPEND batch_decode_kernels_src ${generated_kernel_src})
endforeach(dtype)

- # fp8 in, fp16 out
- foreach(dtype IN LISTS DECODE_FP8_DTYPES)
- set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16_idtype_${idtype}.cu)
+ # fp8 kv-cache
+ foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES)
+ set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src}
@@ -183,34 +183,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
VERBATIM
)
list(APPEND batch_decode_kernels_src ${generated_kernel_src})
- endforeach()
+ endforeach(dtype_kv)
endforeach(idtype)

- # padded kv-cache
- foreach(dtype IN LISTS DECODE_DTYPES)
- set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
- add_custom_command(
- OUTPUT ${generated_kernel_src}
- COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src}
- DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py
- COMMENT "Generating additional source file ${generated_kernel_src}"
- VERBATIM
- )
- list(APPEND batch_decode_kernels_src ${generated_kernel_src})
- endforeach(dtype)
-
- # padded kv-cache, fp8 in, fp16 out
- foreach(dtype IN LISTS DECODE_FP8_DTYPES)
- set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu)
- add_custom_command(
- OUTPUT ${generated_kernel_src}
- COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src}
- DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py
- COMMENT "Generating additional source file ${generated_kernel_src}"
- VERBATIM
- )
- list(APPEND batch_decode_kernels_src ${generated_kernel_src})
- endforeach()
endforeach(pos_encoding_mode)
endforeach(kv_layout)
endforeach(logits_post_hook)
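To make the new "fp8 kv-cache" naming concrete: in these generated kernels the query and output stay in fp16 and only the kv-cache dtype is an fp8 format. The snippet below is an illustrative sketch, not code from this PR, and it assumes DECODE_FP8_DTYPES covers the CUDA e4m3/e5m2 formats; it shows how such a filename's dtype fields would map to C++ types.

// Illustrative mapping from generated-filename fields to kernel template types
// for the fp8 kv-cache case (hypothetical example, not repository code),
// e.g. a file named ..._dtypeq_f16_dtypekv_e4m3_..._dtypeout_f16.cu
#include <cuda_fp16.h>  // __half
#include <cuda_fp8.h>   // __nv_fp8_e4m3 / __nv_fp8_e5m2 (CUDA 11.8+)

struct Fp8KvDecodeTypesExample {
  using DTypeQ = __half;          // dtypeq_f16: queries stay fp16
  using DTypeKV = __nv_fp8_e4m3;  // dtypekv_e4m3 (or __nv_fp8_e5m2 for e5m2 files)
  using DTypeOut = __half;        // dtypeout_f16: outputs stay fp16
};

// fp8 kv entries occupy a single byte, which is the point of an fp8 kv-cache.
static_assert(sizeof(Fp8KvDecodeTypesExample::DTypeKV) == 1, "fp8 kv entry is 1 byte");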
200 changes: 0 additions & 200 deletions include/flashinfer/attention/decode.cuh
@@ -358,141 +358,6 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
}
}

template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
typename DTypeQ, typename DTypeKV, typename DTypeOut>
__global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
DTypeKV* __restrict__ v,
DTypeOut* __restrict__ o,
float* __restrict__ lse,
tensor_info_t<kv_layout, bdx * vec_size> info,
float logits_soft_cap, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta) {
auto block = cg::this_thread_block();
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));

constexpr uint32_t head_dim = bdx * vec_size;
uint32_t kv_head_idx = blockIdx.y;
uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y;
uint32_t batch_idx = blockIdx.x;
uint32_t num_qo_heads = info.num_qo_heads;
uint32_t num_kv_heads = info.num_kv_heads;
const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e;
uint32_t seq_len = info.kv_len;

extern __shared__ uint8_t smem[];
DTypeKV* k_smem = (DTypeKV*)smem;
DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV));
float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV));

uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
vec_t<float, vec_size> q_vec;
vec_t<float, vec_size> freq;
if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) {
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
freq[i] = rope_rcp_scale *
__powf(rope_rcp_theta,
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}
// apply rotary embedding to q matrix
q_vec = vec_apply_llama_rope<vec_size, bdx>(
q + batch_idx * num_qo_heads * head_dim + info.get_qo_elem_offset(0, qo_head_idx, 0), freq,
seq_len - 1);
} else {
// do not apply rotary embedding to q matrix
q_vec.cast_load(q + batch_idx * num_qo_heads * head_dim +
info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
}
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_vec[i] *= sm_scale;
}
block.sync();

// preload k tiles and v tiles
uint32_t producer_kv_idx_base = 0;
constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8;
#pragma unroll
for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
k + batch_idx * seq_len * num_kv_heads * head_dim +
info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
tx * vec_size),
producer_kv_idx_base + tz * bdy + ty < seq_len);
cp_async::commit_group();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
v + batch_idx * seq_len * num_kv_heads * head_dim +
info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
tx * vec_size),
producer_kv_idx_base + tz * bdy + ty < seq_len);
cp_async::commit_group();
producer_kv_idx_base += bdy * bdz;
}

// pipelining k/v tiles loading and state updating
uint32_t consumer_kv_idx_base = 0, stage_idx = 0;
state_t<vec_size> st_local;
float s[bdy];

#pragma unroll 4
for (uint32_t iter = 0; iter < ceil_div(seq_len, bdy * bdz); ++iter) {
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy>(
k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq,
consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local,
logits_soft_cap);
block.sync();
// load k
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
k + batch_idx * seq_len * num_kv_heads * head_dim +
info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
tx * vec_size),
producer_kv_idx_base + tz * bdy + ty < seq_len);
cp_async::commit_group();

// update m/d/o state
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
update_local_state<vec_size, bdx, bdy>(v_smem + (stage_idx * bdz + tz) * bdy * head_dim, s,
stage_idx, st_local);
block.sync();

// load v
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
v + batch_idx * seq_len * num_kv_heads * head_dim +
info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
tx * vec_size),
producer_kv_idx_base + tz * bdy + ty < seq_len);
cp_async::commit_group();

stage_idx = (stage_idx + 1) % num_stages_smem;
producer_kv_idx_base += bdy * bdz;
consumer_kv_idx_base += bdy * bdz;
}
cp_async::wait_group<0>();
block.sync();

// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);

st_local.normalize();
st_local.o.cast_store(o + batch_idx * num_qo_heads * head_dim +
info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));

// write lse
if (lse != nullptr) {
lse[batch_idx * num_qo_heads + qo_head_idx] = st_local.get_lse();
}
}

/*!
* \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests
* \tparam logits_post_hook The logits post hook used in the kernel
@@ -937,71 +802,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
return cudaSuccess;
}

/*!
* \brief FlashAttention decoding cuda kernel with paged kv-cache for batched requests
* \tparam page_storage Whether to store indices or pointers of each active page
* \tparam DTypeQ A template type indicates the query data type
* \tparam DTypeKV A template type indicates the key-value data type
* \tparam DTypeOut A template type indicates the output data type
* \tparam IdType A template type indicates the index data type used in paged kv-cache
* \param q [batch_size, num_qo_heads, head_dim] The query matrix
* \param paged_kv The paged kv cache data structure
* \param o [batch_size, num_qo_heads, head_dim] The output matrix
* \param tmp Used-allocated temporary buffer
* \param lse The logsumexp values.
* \param num_qo_heads A integer indicates the number of heads of query and output
* \param pos_encoding_mode The positional encoding mode
* \param rope_scale The scaling ratio used in RoPE Interpolation.
* \param rope_theta A floating point number indicate the "theta" used in RoPE
* \param stream The cuda stream to launch the kernel
* \return status Indicates whether CUDA calls are successful
*/
template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
uint32_t num_kv_heads, float logits_soft_cap,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;

constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);

const uint32_t smem_size = 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
2 * bdy * bdz * sizeof(float);

dim3 nblks(batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
auto kernel = BatchDecodeWithPaddedKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
num_stages_smem, vec_size, bdx, bdy, bdz,
DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
tensor_info_t<KV_LAYOUT, HEAD_DIM> info(1, padded_kv_len, num_qo_heads, num_kv_heads);
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&o,
(void*)&lse,
(void*)&info,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});
return cudaSuccess;
}

} // namespace flashinfer

#endif // FLASHINFER_DECODE_CUH_
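As background for the dispatch logic shown above: the (now removed) BatchDecodeWithPaddedKVCacheDispatched derives its thread-block shape and dynamic shared-memory size from HEAD_DIM, the kv-cache element size, and the GQA group size. The host-side sketch below only reruns that arithmetic for one assumed configuration (HEAD_DIM = 128, fp16 kv-cache, group size 4) to make the constants concrete; the variable names mirror the code above, but the program itself is illustrative.

// Reproduces the launch-geometry arithmetic from the dispatcher above for one
// assumed configuration; prints the derived block shape and shared-memory size.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t HEAD_DIM = 128;       // assumed head dimension
  constexpr uint32_t kv_elem_bytes = 2;    // fp16 kv-cache entries
  constexpr uint32_t GROUP_SIZE = 4;       // num_qo_heads / num_kv_heads

  constexpr uint32_t vec_size = std::max(16U / kv_elem_bytes, HEAD_DIM / 32U);  // 8
  constexpr uint32_t num_stages_smem = 2U;
  constexpr uint32_t bdx = HEAD_DIM / vec_size;                // 16 lanes across head_dim
  constexpr uint32_t bdy = GROUP_SIZE;                         // query heads per kv head
  constexpr uint32_t num_threads = std::max(128U, bdx * bdy);  // 128
  constexpr uint32_t bdz = num_threads / (bdx * bdy);          // 2 kv chunks per block

  constexpr size_t smem_size =
      2 * num_stages_smem * bdy * bdz * HEAD_DIM * kv_elem_bytes  // k/v staging tiles
      + 2 * bdy * bdz * sizeof(float);                            // m/d reduction scratch

  // Expected output: vec_size=8 bdx=16 bdy=4 bdz=2 threads=128 smem=8256
  std::printf("vec_size=%u bdx=%u bdy=%u bdz=%u threads=%u smem=%zu\n",
              vec_size, bdx, bdy, bdz, num_threads, smem_size);
}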
9 changes: 0 additions & 9 deletions include/flashinfer/decode_attention_decl.cuh
@@ -45,15 +45,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);

template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
uint32_t num_kv_heads, float logits_soft_cap,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);

template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
3 changes: 3 additions & 0 deletions include/flashinfer/utils.cuh
@@ -97,6 +97,9 @@
if (group_size == 1) { \
constexpr size_t GROUP_SIZE = 1; \
__VA_ARGS__ \
} else if (group_size == 2) { \
constexpr size_t GROUP_SIZE = 2; \
__VA_ARGS__ \
} else if (group_size == 4) { \
constexpr size_t GROUP_SIZE = 4; \
__VA_ARGS__ \
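The hunk above adds a group_size == 2 branch to DISPATCH_GQA_GROUP_SIZE. The macro follows the usual runtime-to-compile-time dispatch pattern: the runtime ratio num_qo_heads / num_kv_heads is matched against a small set of supported values, and each branch rebinds it as a constexpr GROUP_SIZE so the decode kernels can be specialized on it (the removed dispatcher above uses it exactly this way: DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { constexpr uint32_t bdy = GROUP_SIZE; ... })). The standalone sketch below imitates that pattern with plain if/else; the function names are invented for illustration, and the real macro may support more group sizes than the ones visible in this hunk.

// Minimal imitation of the GROUP_SIZE dispatch pattern (illustrative only; the
// library uses a macro so the same body can be pasted into each constexpr branch).
#include <cstdint>
#include <cstdio>
#include <stdexcept>

template <uint32_t GROUP_SIZE>
void launch_decode_stub(uint32_t num_kv_heads) {
  // Stand-in for a templated kernel launch; GROUP_SIZE is a compile-time constant here.
  std::printf("GROUP_SIZE=%u, num_kv_heads=%u\n", GROUP_SIZE, num_kv_heads);
}

void dispatch_group_size(uint32_t group_size, uint32_t num_kv_heads) {
  if (group_size == 1) {
    launch_decode_stub<1>(num_kv_heads);
  } else if (group_size == 2) {  // the newly supported ratio in this PR
    launch_decode_stub<2>(num_kv_heads);
  } else if (group_size == 4) {
    launch_decode_stub<4>(num_kv_heads);
  } else {  // the real macro may cover additional sizes
    throw std::invalid_argument("unsupported GQA group size");
  }
}

int main() {
  dispatch_group_size(16 / 8, 8);  // e.g. 16 query heads over 8 kv heads
}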
1 change: 0 additions & 1 deletion python/MANIFEST.in
@@ -1,6 +1,5 @@
# sdist & wheel
include version.txt
include generate_batch_padded_decode_inst.py
include generate_batch_paged_decode_inst.py
include generate_batch_paged_prefill_inst.py
include generate_batch_ragged_prefill_inst.py