diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh
index 48df322b..907a686e 100644
--- a/include/flashinfer/attention/decode.cuh
+++ b/include/flashinfer/attention/decode.cuh
@@ -246,11 +246,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
                      float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
     }
     // apply rotary embedding to q matrix
-    q_vec = vec_apply_llama_rope(q + info.get_qo_elem_offset(0, qo_head_idx, 0),
-                                 freq, seq_len - 1);
+    q_vec = vec_apply_llama_rope(q + info.get_q_elem_offset(0, qo_head_idx, 0), freq,
+                                 seq_len - 1);
   } else {
     // do not apply rotary embedding to q matrix
-    q_vec.cast_load(q + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
+    q_vec.cast_load(q + info.get_q_elem_offset(0, qo_head_idx, tx * vec_size));
   }
   // multiple q_vec by sm_scale
 #pragma unroll
diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index 8dc1be19..34ae1180 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -174,7 +174,7 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(
 template 
 __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T** gptr,
-                                           const uint32_t kv_n_stride, const uint32_t kv_idx_base,
+                                           const uint32_t kv_stride_n, const uint32_t kv_idx_base,
                                            const uint32_t kv_len) {
   constexpr uint32_t head_dim = num_frags_y * 16;
   constexpr uint32_t num_warps = num_warps_x * num_warps_z;
@@ -194,7 +194,7 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T
       kv_idx += num_warps * 4;
       *smem_offset = smem.advance_offset_by_row(*smem_offset) - 2 * num_frags_y;
-      *gptr += num_warps * 4 * kv_n_stride - 2 * num_frags_y * num_elems_per_128b();
+      *gptr += num_warps * 4 * kv_stride_n - 2 * num_frags_y * num_elems_per_128b();
     }
     *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_in;
 }
@@ -276,8 +276,8 @@ template 
 __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset,
                                                    const uint32_t qo_upper_bound,
-                                                   DTypeIn* q_ptr_base, const uint32_t qo_n_stride,
-                                                   const uint32_t qo_h_stride,
+                                                   DTypeIn* q_ptr_base, const uint32_t q_stride_n,
+                                                   const uint32_t q_stride_h,
                                                    const uint_fastdiv group_size, smem_t* q_smem) {
   constexpr uint32_t head_dim = num_frags_y * 16;
   constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b();
@@ -294,7 +294,7 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset,
         uint32_t q, r;
         group_size.divmod(packed_offset + lane_idx / 8 + fx * 16 + j * 4, q, r);
         const uint32_t q_idx = q;
-        DTypeIn* q_ptr = q_ptr_base + q * qo_n_stride + r * qo_h_stride;
+        DTypeIn* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h;
 #pragma unroll
         for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
           // load q fragment from gmem to smem
@@ -879,8 +879,8 @@ template 
 __device__ __forceinline__ void write_o_reg_gmem(
     float (*o_frag)[num_frags_y][8], smem_t* o_smem, DTypeOut* o_ptr_base,
-    const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t qo_n_stride,
-    const uint32_t qo_h_stride, const uint_fastdiv group_size) {
+    const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t o_stride_n,
+    const uint32_t o_stride_h, const uint_fastdiv group_size) {
   constexpr uint32_t head_dim = num_frags_y * 16;
   constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b();
   const uint32_t warp_idx_x = get_warp_idx_x();
@@ -920,7 +920,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
           uint32_t q, r;
           group_size.divmod(o_packed_idx_base + lane_idx / 8 + fx * 16 + j * 4, q, r);
           const uint32_t o_idx = q;
-          DTypeOut* o_ptr = o_ptr_base + q * qo_n_stride + r * qo_h_stride;
+          DTypeOut* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h;
 #pragma unroll
           for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
             if (o_idx < qo_upper_bound) {
@@ -971,7 +971,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
     DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v,
     uint8_t* __restrict__ custom_mask, DTypeOut* __restrict__ o, float* __restrict__ lse,
     const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size,
-    const QKVLayout kv_layout, const float logits_soft_cap, float sm_scale,
+    const uint32_t q_stride_n, const uint32_t q_stride_h, const uint32_t kv_stride_n,
+    const uint32_t kv_stride_h, const float logits_soft_cap, float sm_scale,
     const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) {
   static_assert(sizeof(DTypeIn) == 2);
   static_assert(sizeof(DTypeOut) == 2);
@@ -981,8 +982,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
   const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z;
   const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size;
   constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16;
-  const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout,
-                               num_frags_y * 16);
+  const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h,
+                               kv_stride_n, kv_stride_h, /*head_dim=*/num_frags_y * 16);
   float alibi_slopes[num_frags_x][2];
   const uint32_t num_chunks = gridDim.y;
@@ -1010,24 +1011,22 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
   // cooperative fetch q fragment from gmem to reg
   const uint32_t qo_packed_idx_base = (bx * num_warps_x + get_warp_idx_x()) * num_frags_x * 16;
-  const uint32_t kv_n_stride = qkv_info.kv_stride_n, qo_n_stride = qkv_info.qo_stride_n,
-                 qo_h_stride = qkv_info.qo_stride_h;
   smem_t qo_smem(smem);
   DTypeIn* q_ptr_base =
-      q + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
-                                      (lane_idx % 8) * num_elems_per_128b());
+      q + qkv_info.get_q_elem_offset(0, kv_head_idx * group_size,
+                                     (lane_idx % 8) * num_elems_per_128b());
   DTypeOut* o_ptr_base =
       partition_kv ? o + chunk_idx * num_qo_heads * head_dim +
-                         qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
-                                                     (lane_idx % 8) * num_elems_per_128b())
-                   : o + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
-                                                     (lane_idx % 8) * num_elems_per_128b());
+                         qkv_info.get_o_elem_offset(0, kv_head_idx * group_size,
+                                                    (lane_idx % 8) * num_elems_per_128b())
+                   : o + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size,
+                                                    (lane_idx % 8) * num_elems_per_128b());
   uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
       get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16);
   load_q_global_smem(
-      qo_packed_idx_base, qo_len, q_ptr_base, qo_n_stride, qo_h_stride, group_size, &qo_smem);
+      qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem);
   cp_async::commit_group();
   cp_async::wait_group<0>();
@@ -1092,10 +1091,10 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
   kv_smem_offset_w = smem_t::get_permuted_offset(
       warp_idx * 4 + lane_idx / 8, lane_idx % 8);
   produce_kv(
-      k_smem, &kv_smem_offset_w, &k_ptr, kv_n_stride, chunk_start, chunk_end);
+      k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, chunk_start, chunk_end);
   cp_async::commit_group();
   produce_kv(
-      v_smem, &kv_smem_offset_w, &v_ptr, kv_n_stride, chunk_start, chunk_end);
+      v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, chunk_start, chunk_end);
   cp_async::commit_group();
 #pragma unroll 1
@@ -1143,7 +1142,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
     block.sync();
     produce_kv(
-        k_smem, &kv_smem_offset_w, &k_ptr, kv_n_stride,
+        k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n,
         chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, chunk_end);
     cp_async::commit_group();
     cp_async::wait_group<1>();
@@ -1155,7 +1154,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
     block.sync();
     produce_kv(
-        v_smem, &kv_smem_offset_w, &v_ptr, kv_n_stride,
+        v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n,
         chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, chunk_end);
     cp_async::commit_group();
   }
@@ -1172,7 +1171,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
   // write back
   write_o_reg_gmem(
       o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
-      partition_kv ? qo_n_stride * num_chunks : qo_n_stride, qo_h_stride, group_size);
+      /*o_stride_n=*/partition_kv ? num_qo_heads * head_dim * num_chunks : num_qo_heads * head_dim,
+      /*o_stride_h=*/head_dim, group_size);
   // write lse
   if (lse != nullptr || partition_kv) {
@@ -1213,7 +1213,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
     IdType* __restrict__ k_rope_pos_offset, IdType* __restrict__ o_indptr, DTypeOut* __restrict__ o,
     float* __restrict__ lse, bool* __restrict__ block_valid_mask,
     IdType* __restrict__ kv_chunk_size_ptr, const uint_fastdiv group_size,
-    const QKVLayout kv_layout, const float logits_soft_cap, float sm_scale,
+    const uint32_t q_stride_n, const uint32_t q_stride_h, const uint32_t kv_stride_n,
+    const uint32_t kv_stride_h, const float logits_soft_cap, float sm_scale,
     const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) {
   static_assert(sizeof(DTypeIn) == 2);
   static_assert(sizeof(DTypeOut) == 2);
@@ -1237,8 +1238,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
   const uint32_t chunk_size = partition_kv ? kv_chunk_size : kv_len;
   const uint32_t chunk_start = partition_kv ? kv_tile_idx * chunk_size : 0;
   const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * chunk_size, kv_len) : kv_len;
-  const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout,
-                               num_frags_y * 16);
+  const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h,
+                               kv_stride_n, kv_stride_h, /*head_dim=*/num_frags_y * 16);
   float alibi_slopes[num_frags_x][2];
   const uint32_t qo_upper_bound =
       min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size));
@@ -1261,28 +1262,25 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
   const uint32_t qo_packed_idx_base =
       (qo_tile_idx * num_warps_x + get_warp_idx_x()) * num_frags_x * 16;
-  const uint32_t kv_n_stride = qkv_info.kv_stride_n, qo_n_stride = qkv_info.qo_stride_n,
-                 qo_h_stride = qkv_info.qo_stride_h;
   smem_t qo_smem(smem);
   DTypeIn* q_ptr_base =
-      q + qkv_info.get_qo_elem_offset(q_indptr[request_idx], kv_head_idx * group_size,
-                                      (lane_idx % 8) * num_elems_per_128b());
+      q + qkv_info.get_q_elem_offset(q_indptr[request_idx], kv_head_idx * group_size,
+                                     (lane_idx % 8) * num_elems_per_128b());
   DTypeIn* o_ptr_base =
       partition_kv ? o + kv_tile_idx * num_qo_heads * head_dim +
-                         qkv_info.get_qo_elem_offset(o_indptr[request_idx], kv_head_idx * group_size,
-                                                     (lane_idx % 8) * num_elems_per_128b())
-                   : o + qkv_info.get_qo_elem_offset(o_indptr[request_idx], kv_head_idx * group_size,
-                                                     (lane_idx % 8) * num_elems_per_128b());
+                         qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size,
+                                                    (lane_idx % 8) * num_elems_per_128b())
+                   : o + qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size,
+                                                    (lane_idx % 8) * num_elems_per_128b());
   uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
       get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16);
   load_q_global_smem(
-      qo_packed_idx_base, qo_upper_bound, q_ptr_base, qo_n_stride, qo_h_stride, group_size,
-      &qo_smem);
+      qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem);
   cp_async::commit_group();
   cp_async::wait_group<0>();
@@ -1357,10 +1355,10 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
       kv_head_idx, (lane_idx % 8) * num_elems_per_128b());
   produce_kv(
-      k_smem, &kv_smem_offset_w, &k_ptr, kv_n_stride, chunk_start, chunk_end);
+      k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, chunk_start, chunk_end);
   cp_async::commit_group();
   produce_kv(
-      v_smem, &kv_smem_offset_w, &v_ptr, kv_n_stride, chunk_start, chunk_end);
+      v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, chunk_start, chunk_end);
   cp_async::commit_group();
 #pragma unroll 1
@@ -1410,7 +1408,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
     block.sync();
     produce_kv(
-        k_smem, &kv_smem_offset_w, &k_ptr, kv_n_stride,
+        k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n,
         chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_len);
     cp_async::commit_group();
     cp_async::wait_group<1>();
@@ -1422,7 +1420,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
     block.sync();
     produce_kv(
-        v_smem, &kv_smem_offset_w, &v_ptr, kv_n_stride,
+        v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n,
         chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_len);
     cp_async::commit_group();
   }
@@ -1441,7 +1439,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
   // write back
   write_o_reg_gmem(
       o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
-      partition_kv ? qo_n_stride * num_kv_chunks : qo_n_stride, qo_h_stride, group_size);
+      /*o_stride_n=*/
+      partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
+      /*o_stride_h=*/head_dim, group_size);
   // write lse
   if (lse != nullptr) {
@@ -1534,25 +1534,24 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
   const uint32_t qo_packed_idx_base =
       (qo_tile_idx * num_warps_x + get_warp_idx_x()) * num_frags_x * 16;
-  const uint32_t qo_n_stride = num_qo_heads * head_dim, qo_h_stride = head_dim;
+  const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim;
   smem_t qo_smem(smem);
   DTypeIn* q_ptr_base = q + get_elem_offset_impl(q_indptr[request_idx], kv_head_idx * group_size,
                                                  (lane_idx % 8) * num_elems_per_128b(),
-                                                 qo_n_stride, qo_h_stride);
+                                                 q_stride_n, q_stride_h);
   DTypeIn* o_ptr_base =
       partition_kv ? o + kv_tile_idx * num_qo_heads * head_dim +
                          get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size,
                                               (lane_idx % 8) * num_elems_per_128b(),
-                                              qo_n_stride, qo_h_stride)
+                                              num_qo_heads * head_dim, head_dim)
                    : o + get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size,
                                               (lane_idx % 8) * num_elems_per_128b(),
-                                              qo_n_stride, qo_h_stride);
+                                              num_qo_heads * head_dim, head_dim);
   uint32_t q_smem_offset_r = smem_t::get_permuted_offset(
       get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16);
   load_q_global_smem(
-      qo_packed_idx_base, qo_upper_bound, q_ptr_base, qo_n_stride, qo_h_stride, group_size,
-      &qo_smem);
+      qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem);
   cp_async::commit_group();
   cp_async::wait_group<0>();
@@ -1732,7 +1731,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
   // write_back
   write_o_reg_gmem(
       o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
-      partition_kv ? qo_n_stride * num_kv_chunks : qo_n_stride, qo_h_stride, group_size);
+      /*o_stride_n=*/
+      partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
+      /*o_stride_h=*/head_dim, group_size);
   // write lse
   if (lse != nullptr) {
@@ -1765,7 +1766,8 @@ template ();
@@ -2002,7 +2011,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
       (void*)&block_valid_mask,
       (void*)&kv_chunk_size_ptr,
       (void*)&group_size_fastdiv,
-      (void*)&kv_layout,
+      (void*)&q_stride_n,
+      (void*)&q_stride_h,
+      (void*)&kv_stride_n,
+      (void*)&kv_stride_h,
       (void*)&logits_soft_cap,
       (void*)&sm_scale,
       (void*)&log2_rope_rcp_scale,
@@ -2035,7 +2047,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
       (void*)&block_valid_mask,
       (void*)&kv_chunk_size_ptr,
       (void*)&group_size_fastdiv,
-      (void*)&kv_layout,
+      (void*)&q_stride_n,
+      (void*)&q_stride_h,
+      (void*)&kv_stride_n,
+      (void*)&kv_stride_h,
       (void*)&logits_soft_cap,
       (void*)&sm_scale,
       (void*)&log2_rope_rcp_scale,
diff --git a/include/flashinfer/layout.cuh b/include/flashinfer/layout.cuh
index b5fd0ef7..caa0df2c 100644
--- a/include/flashinfer/layout.cuh
+++ b/include/flashinfer/layout.cuh
@@ -17,6 +17,7 @@
 #define FLASHINFER_LAYOUT_CUH_
 #include 
+#include 
 namespace flashinfer {
@@ -36,42 +37,64 @@ __host__ __device__ __forceinline__ size_t get_elem_offset_impl(size_t elem_idx,
   return elem_idx * stride_n + head_idx * stride_h + feat_idx;
 }
+__host__ __forceinline__ auto get_qkv_strides(QKVLayout kv_layout, uint32_t kv_len,
+                                              uint32_t num_qo_heads, uint32_t num_kv_heads,
+                                              uint32_t head_dim) {
+  const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim,
+                 kv_stride_n = (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim,
+                 kv_stride_h = (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim;
+  return std::make_tuple(q_stride_n, q_stride_h, kv_stride_n, kv_stride_h);
+}
+
 struct tensor_info_t {
   uint32_t qo_len;
   uint32_t kv_len;
   uint32_t num_qo_heads;
   uint32_t num_kv_heads;
-  uint32_t qo_stride_n;
-  uint32_t qo_stride_h;
+  uint32_t q_stride_n;
+  uint32_t q_stride_h;
   uint32_t kv_stride_n;
   uint32_t kv_stride_h;
+  uint32_t head_dim;
   __host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, uint32_t kv_len,
                                                     uint32_t num_qo_heads, uint32_t num_kv_heads,
-                                                    uint32_t qo_stride_n, uint32_t qo_stride_h,
-                                                    uint32_t kv_stride_n, uint32_t kv_stride_h)
+                                                    uint32_t q_stride_n, uint32_t q_stride_h,
+                                                    uint32_t kv_stride_n, uint32_t kv_stride_h,
+                                                    uint32_t head_dim)
       : qo_len(qo_len),
         kv_len(kv_len),
        num_qo_heads(num_qo_heads),
        num_kv_heads(num_kv_heads),
-        qo_stride_n(qo_stride_n),
-        qo_stride_h(qo_stride_h),
+        q_stride_n(q_stride_n),
+        q_stride_h(q_stride_h),
        kv_stride_n(kv_stride_n),
-        kv_stride_h(kv_stride_h) {}
+        kv_stride_h(kv_stride_h),
+        head_dim(head_dim) {}
   __host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, uint32_t kv_len,
                                                     uint32_t num_qo_heads, uint32_t num_kv_heads,
                                                     QKVLayout kv_layout, uint32_t head_dim)
-      : qo_len(qo_len), kv_len(kv_len), num_qo_heads(num_qo_heads), num_kv_heads(num_kv_heads) {
-    qo_stride_n = num_qo_heads * head_dim;
-    qo_stride_h = head_dim;
+      : qo_len(qo_len),
+        kv_len(kv_len),
+        num_qo_heads(num_qo_heads),
+        num_kv_heads(num_kv_heads),
+        head_dim(head_dim) {
+    q_stride_n = num_qo_heads * head_dim;
+    q_stride_h = head_dim;
     kv_stride_n = (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim;
     kv_stride_h = (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim;
   }
-  __host__ __device__ __forceinline__ size_t get_qo_elem_offset(uint32_t qo_idx,
-                                                                uint32_t qo_head_idx,
-                                                                uint32_t feat_idx) const {
-    return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, qo_stride_n, qo_stride_h);
+  __host__ __device__ __forceinline__ size_t get_q_elem_offset(uint32_t qo_idx,
+                                                               uint32_t qo_head_idx,
+                                                               uint32_t feat_idx) const {
+    return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, q_stride_n, q_stride_h);
+  }
+
+  __host__ __device__ __forceinline__ size_t get_o_elem_offset(uint32_t qo_idx,
+                                                               uint32_t qo_head_idx,
+                                                               uint32_t feat_idx) const {
+    return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, num_qo_heads * head_dim, head_dim);
   }
   __host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx,
diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh
index 2a09fc7c..4d699caf 100644
--- a/include/flashinfer/prefill_attention_decl.cuh
+++ b/include/flashinfer/prefill_attention_decl.cuh
@@ -34,8 +34,8 @@ template BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
     torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode,
     bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale,
     float rope_theta, bool return_lse) {
-  CHECK_INPUT(q);
   CHECK_INPUT(qo_indptr);
-  CHECK_INPUT(k);
-  CHECK_INPUT(v);
+  CHECK_CUDA(q);
+  CHECK_CUDA(k);
+  CHECK_CUDA(v);
   CHECK_INPUT(kv_indptr);
   auto device = q.device();
   CHECK_EQ(device, qo_indptr.device());
@@ -414,11 +414,22 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   int64_t head_dim = q.size(2);
   CHECK_GE(kv_indptr.size(0), batch_size + 1);
   int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0);
+  CHECK_EQ(q.stride(2), 1);
+  CHECK_EQ(k.stride(2), 1);
+  CHECK_EQ(v.stride(2), 1);
   CHECK_EQ(k.size(0), v.size(0));
   CHECK_EQ(k.size(1), v.size(1));
   CHECK_EQ(k.size(2), v.size(2));
   CHECK_EQ(k.size(2), head_dim);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
+  uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h;
+  if (kv_layout_ == QKVLayout::kNHD) {
+    kv_stride_n = k.stride(0);
+    kv_stride_h = k.stride(1);
+  } else {
+    kv_stride_h = k.stride(0);
+    kv_stride_n = k.stride(1);
+  }
   qo_indptr = qo_indptr.to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kInt32);
@@ -453,8 +464,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
         /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()),
         /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr,
-        num_qo_heads, num_kv_heads, kv_layout_, logits_soft_cap, sm_scale,
-        rope_scale, rope_theta,
+        num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n,
+        kv_stride_h, logits_soft_cap, sm_scale, rope_scale, rope_theta,
         /*stream=*/torch_current_stream);
     TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ",
@@ -479,10 +490,10 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
     torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr,
     unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float logits_soft_cap,
     float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
-  CHECK_INPUT(q);
   CHECK_INPUT(qo_indptr);
-  CHECK_INPUT(k);
-  CHECK_INPUT(v);
+  CHECK_CUDA(q);
+  CHECK_CUDA(k);
+  CHECK_CUDA(v);
   CHECK_INPUT(kv_indptr);
   CHECK_INPUT(custom_mask);
   CHECK_INPUT(qk_indptr);
@@ -509,11 +520,22 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   CHECK_GE(kv_indptr.size(0), batch_size + 1);
   CHECK_GE(qk_indptr.size(0), batch_size + 1);
   int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0);
+  CHECK_EQ(q.stride(2), 1);
+  CHECK_EQ(k.stride(2), 1);
+  CHECK_EQ(v.stride(2), 1);
   CHECK_EQ(k.size(0), v.size(0));
   CHECK_EQ(k.size(1), v.size(1));
   CHECK_EQ(k.size(2), v.size(2));
   CHECK_EQ(k.size(2), head_dim);
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
+  uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h;
+  if (kv_layout_ == QKVLayout::kNHD) {
+    kv_stride_n = k.stride(0);
+    kv_stride_h = k.stride(1);
+  } else {
+    kv_stride_h = k.stride(0);
+    kv_stride_n = k.stride(1);
+  }
   qo_indptr = qo_indptr.to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);
@@ -549,8 +571,8 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
         /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, static_cast(o.data_ptr()),
         /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr,
-        num_qo_heads, num_kv_heads, kv_layout_, logits_soft_cap, sm_scale,
-        rope_scale, rope_theta,
+        num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n,
+        kv_stride_h, logits_soft_cap, sm_scale, rope_scale, rope_theta,
         /*stream=*/torch_current_stream);
     TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCache failed with error ",
diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu
index 79d7162f..37d1a838 100644
--- a/python/csrc/single_prefill.cu
+++ b/python/csrc/single_prefill.cu
@@ -24,9 +24,9 @@ std::vector single_prefill_with_kv_cache(
     torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
     unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
     float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
-  CHECK_INPUT(q);
-  CHECK_INPUT(k);
-  CHECK_INPUT(v);
+  CHECK_CUDA(q);
+  CHECK_CUDA(k);
+  CHECK_CUDA(v);
   CHECK_INPUT(tmp);
   auto device = q.device();
   CHECK_EQ(k.device(), device);
@@ -36,6 +36,9 @@ std::vector single_prefill_with_kv_cache(
   CHECK_DIM(3, k);
   CHECK_DIM(3, v);
   CHECK_SHAPE(k, v);
+  CHECK_EQ(q.stride(2), 1);
+  CHECK_EQ(k.stride(2), 1);
+  CHECK_EQ(v.stride(2), 1);
   CHECK_EQ(q.size(2), k.size(2));
   CHECK_EQ(q.scalar_type(), k.scalar_type());
   CHECK_EQ(q.scalar_type(), v.scalar_type());
@@ -44,12 +47,17 @@ std::vector single_prefill_with_kv_cache(
   QKVLayout kv_layout = static_cast(layout);
   qo_len = q.size(0);
   num_qo_heads = q.size(1);
+  uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h;
   if (kv_layout == QKVLayout::kNHD) {
     kv_len = k.size(0);
     num_kv_heads = k.size(1);
-  } else {
+    kv_stride_n = k.stride(0);
+    kv_stride_h = k.stride(1);
+  } else {  // QKVLayout::kHND
     kv_len = k.size(1);
     num_kv_heads = k.size(0);
+    kv_stride_h = k.stride(0);
+    kv_stride_n = k.stride(1);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
@@ -63,7 +71,7 @@ std::vector single_prefill_with_kv_cache(
   TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
   const LogitsPostHook logits_post_hook =
       logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;
-
+
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
       return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
@@ -72,16 +80,19 @@ std::vector single_prefill_with_kv_cache(
             allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] {
               return DISPATCH_pos_encoding_mode(
                   PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
-                    cudaError_t status = SinglePrefillWithKVCacheDispatched<
-                        HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION,
-                        MASK_MODE>(
-                        static_cast(q.data_ptr()), static_cast(k.data_ptr()),
-                        static_cast(v.data_ptr()),
-                        /*custom_mask=*/nullptr, static_cast(o.data_ptr()),
-                        static_cast(tmp.data_ptr()),
-                        /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr,
-                        num_qo_heads, num_kv_heads, qo_len, kv_len, kv_layout, logits_soft_cap,
-                        sm_scale, rope_scale, rope_theta, torch_current_stream);
+                    cudaError_t status =
+                        SinglePrefillWithKVCacheDispatched(
+                            static_cast(q.data_ptr()),
+                            static_cast(k.data_ptr()),
+                            static_cast(v.data_ptr()),
+                            /*custom_mask=*/nullptr, static_cast(o.data_ptr()),
+                            static_cast(tmp.data_ptr()),
+                            /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr,
+                            num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
+                            kv_stride_n, kv_stride_h, logits_soft_cap, sm_scale, rope_scale,
+                            rope_theta, torch_current_stream);
                     TORCH_CHECK(status == cudaSuccess,
                                 "SinglePrefillWithKVCache kernel launch failed, error: " +
                                     std::string(cudaGetErrorString(status)));
@@ -105,9 +116,9 @@ std::vector single_prefill_with_kv_cache_custom_mask(
     torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode,
     bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale,
     float rope_theta, bool return_lse) {
-  CHECK_INPUT(q);
-  CHECK_INPUT(k);
-  CHECK_INPUT(v);
+  CHECK_CUDA(q);
+  CHECK_CUDA(k);
+  CHECK_CUDA(v);
   CHECK_INPUT(packed_custom_mask);
   auto device = q.device();
   CHECK_EQ(k.device(), device);
@@ -118,6 +129,9 @@ std::vector single_prefill_with_kv_cache_custom_mask(
   CHECK_DIM(3, v);
   CHECK_DIM(1, packed_custom_mask);
   CHECK_SHAPE(k, v);
+  CHECK_EQ(q.stride(2), 1);
+  CHECK_EQ(k.stride(2), 1);
+  CHECK_EQ(v.stride(2), 1);
   CHECK_EQ(q.size(2), k.size(2));
   // packed_custom_mask must be uint8
   TORCH_CHECK(packed_custom_mask.scalar_type() == torch::kUInt8,
@@ -127,12 +141,17 @@ std::vector single_prefill_with_kv_cache_custom_mask(
   QKVLayout kv_layout = static_cast(layout);
   qo_len = q.size(0);
   num_qo_heads = q.size(1);
+  uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h;
   if (kv_layout == QKVLayout::kNHD) {
     kv_len = k.size(0);
     num_kv_heads = k.size(1);
+    kv_stride_n = k.stride(0);
+    kv_stride_h = k.stride(1);
   } else {
     kv_len = k.size(1);
     num_kv_heads = k.size(0);
+    kv_stride_h = k.stride(0);
+    kv_stride_n = k.stride(1);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
@@ -164,8 +183,9 @@ std::vector single_prefill_with_kv_cache_custom_mask(
                         static_cast(o.data_ptr()), static_cast(tmp.data_ptr()),
                         /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr,
-                        num_qo_heads, num_kv_heads, qo_len, kv_len, kv_layout, logits_soft_cap,
-                        sm_scale, rope_scale, rope_theta, torch_current_stream);
+                        num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
+                        kv_stride_n, kv_stride_h, logits_soft_cap, sm_scale, rope_scale,
+                        rope_theta, torch_current_stream);
                     TORCH_CHECK(status == cudaSuccess,
                                 "SinglePrefillWithKVCache kernel launch failed, error: " +
                                     std::string(cudaGetErrorString(status)));
diff --git a/python/generate_batch_ragged_prefill_inst.py b/python/generate_batch_ragged_prefill_inst.py
index b326f189..cbb3f2ac 100644
--- a/python/generate_batch_ragged_prefill_inst.py
+++ b/python/generate_batch_ragged_prefill_inst.py
@@ -46,8 +46,8 @@ def get_cu_file_str(
     uint8_t* custom_mask, {idtype}* qk_indptr, {idtype}* q_offset, {idtype}* k_rope_pos_offset,
     {idtype}* o_indptr, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse,
     {idtype}* merge_indptr, bool* block_valid_mask, {idtype}* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads,
-    uint32_t padded_batch_size, uint32_t num_kv_heads, QKVLayout kv_layout,
-    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
+    uint32_t padded_batch_size, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h,
+    uint32_t kv_stride_n, uint32_t kv_stride_h, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
     cudaStream_t stream);
 """.format(
     warp_layout=warp_layout_literal[warp_layout],
diff --git a/python/generate_single_prefill_inst.py b/python/generate_single_prefill_inst.py
index 1749456e..cf1702ab 100644
--- a/python/generate_single_prefill_inst.py
+++ b/python/generate_single_prefill_inst.py
@@ -42,8 +42,8 @@ def get_cu_file_str(
 template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>(
     {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, uint8_t* custom_mask, {dtype_out}* o,
     {dtype_out}* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
-    QKVLayout kv_layout, float logits_soft_cap, float sm_scale, float rope_scale,
-    float rope_theta, cudaStream_t stream);
+    uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h,
+    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);
 }}
 """.format(
diff --git a/python/tests/test_non_contiguous_prefill.py b/python/tests/test_non_contiguous_prefill.py
new file mode 100644
index 00000000..398f6d75
--- /dev/null
+++ b/python/tests/test_non_contiguous_prefill.py
@@ -0,0 +1,113 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +import numpy +import pytest +import torch + +import flashinfer + + +@pytest.mark.parametrize("seq_len", [1, 7, 127, 999, 3579]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("num_qo_heads", [4, 8, 32]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("causal", [True, False]) +def test_single_prefill_packed_input( + seq_len, num_kv_heads, num_qo_heads, head_dim, causal +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be a multiple of num_kv_heads") + qkv_packed = torch.randn( + seq_len, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape( + seq_len, num_qo_heads, head_dim + ) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(seq_len, num_kv_heads, head_dim) + v = qkv_packed[:, (num_qo_heads + num_kv_heads) * head_dim :].reshape( + seq_len, num_kv_heads, head_dim + ) + + o_packed = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=causal) + o_contiguous = flashinfer.single_prefill_with_kv_cache( + q.contiguous(), k.contiguous(), v.contiguous(), causal=causal + ) + + numpy.testing.assert_allclose( + o_packed.cpu(), o_contiguous.cpu(), rtol=1e-3, atol=1e-3 + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99]) +@pytest.mark.parametrize("seq_len", [1, 7, 127, 257]) +@pytest.mark.parametrize("num_kv_heads", [1, 4, 8]) +@pytest.mark.parametrize("num_qo_heads", [4, 8]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("causal", [True, False]) +def test_batch_ragged_prefill_packed_input( + batch_size, seq_len, num_kv_heads, num_qo_heads, head_dim, causal +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be a multiple of num_kv_heads") + nnz = batch_size * seq_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + v = qkv_packed[:, (num_qo_heads + num_kv_heads) * head_dim :].reshape( + nnz, num_kv_heads, head_dim + ) + qo_indptr = torch.tensor( + [i * seq_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + kv_indptr = qo_indptr + + workspace_buffer = torch.empty( + (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0" + ) + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer) + wrapper.begin_forward( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim, + ) + o_packed = wrapper.forward(q, k, v, causal=causal) + o_contiguous = wrapper.forward( + q.contiguous(), k.contiguous(), v.contiguous(), causal=causal + ) + + numpy.testing.assert_allclose( + o_packed.cpu(), o_contiguous.cpu(), rtol=1e-3, atol=1e-3 + ) + + +if __name__ == "__main__": + test_single_prefill_packed_input(127, 4, 4, 64, True) + test_batch_ragged_prefill_packed_input(37, 127, 4, 4, 64, True) diff --git a/src/cpu_reference.h b/src/cpu_reference.h index 4f0bcd5d..a5c8fb5d 100644 --- a/src/cpu_reference.h +++ b/src/cpu_reference.h @@ -94,7 +94,7 @@ std::vector single_mha(const std::vector& q, const std::vect float max_val = -5e4; if (pos_encoding_mode == PosEncodingMode::kRoPELlama) { q_rotary_local = std::move(cpu_reference::apply_llama_rope( - q.data() + 
-            q.data() + info.get_qo_elem_offset(q_idx, qo_head_idx, 0), head_dim,
+            q.data() + info.get_q_elem_offset(q_idx, qo_head_idx, 0), head_dim,
             q_idx + kv_len - qo_len, rope_scale, rope_theta));
       }
       for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) {
@@ -102,7 +102,7 @@ std::vector single_mha(const std::vector& q, const std::vect
         switch (pos_encoding_mode) {
           case PosEncodingMode::kNone: {
             for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) {
-              att[kv_idx] += float(q[info.get_qo_elem_offset(q_idx, qo_head_idx, feat_idx)]) *
+              att[kv_idx] += float(q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx)]) *
                              float(k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]) *
                              sm_scale;
             }
@@ -147,7 +147,7 @@ std::vector single_mha(const std::vector& q, const std::vect
           o_float += att[kv_idx] * float(v[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]);
         }
-        o[info.get_qo_elem_offset(q_idx, qo_head_idx, feat_idx)] = dtype_out(o_float);
+        o[info.get_o_elem_offset(q_idx, qo_head_idx, feat_idx)] = dtype_out(o_float);
       }
     }
   }
diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh
index 21bd367e..2c6bbc3b 100644
--- a/src/flashinfer_ops.cuh
+++ b/src/flashinfer_ops.cuh
@@ -32,6 +32,8 @@ cudaError_t SinglePrefillWithKVCacheCustomMask(
     bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt,
     float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) {
   const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
+  auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] =
+      get_qkv_strides(kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim);
   DISPATCH_allow_fp16_qk_reduction(
       allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION,
       {DISPATCH_head_dim(
           POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION, MaskMode::kCustom>(
               q, k, v, custom_mask, o, tmp, lse, num_qo_heads, num_kv_heads, qo_len, kv_len,
-              kv_layout,
+              qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h,
               /*logits_soft_cap*/ 0.f, sm_scale, rope_scale, rope_theta, stream);
       })})});
   return cudaSuccess;
 }
@@ -82,19 +84,22 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu
     cudaStream_t stream = nullptr) {
   const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
   const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone;
+  auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] =
+      get_qkv_strides(kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim);
   DISPATCH_allow_fp16_qk_reduction(
       allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION,
       {DISPATCH_mask_mode(
           mask_mode, MASK_MODE,
-          {DISPATCH_head_dim(head_dim, HEAD_DIM,
-                             {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
-                               return SinglePrefillWithKVCacheDispatched<
-                                   HEAD_DIM, LogitsPostHook::kNone, POS_ENCODING_MODE,
-                                   ALLOW_FP16_QK_REDUCTION, MASK_MODE>(
-                                   q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, num_qo_heads,
-                                   num_kv_heads, qo_len, kv_len, kv_layout, /*logits_soft_cap=*/0.f,
-                                   sm_scale, rope_scale, rope_theta, stream);
-                             })})})});
+          {DISPATCH_head_dim(
+              head_dim, HEAD_DIM,
+              {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
+                return SinglePrefillWithKVCacheDispatched(
+                    q, k, v, /*custom_mask=*/nullptr, o, tmp, lse, num_qo_heads, num_kv_heads,
+                    qo_len, kv_len, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h,
+                    /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream);
+              })})})});
   return cudaSuccess;
 }
@@ -109,6 +114,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
     const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
   const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
   const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone;
+  auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] =
+      get_qkv_strides(kv_layout, 0, num_qo_heads, num_kv_heads, head_dim);
   DISPATCH_head_dim(
       head_dim, HEAD_DIM,
       {DISPATCH_mask_mode(
           mask_mode, MASK_MODE,
           {                                              MASK_MODE, DTypeIn, DTypeOut, IdType>(
                 handler, q, qo_indptr, k, v, kv_indptr, /*custom_mask=*/nullptr,
                 /*qk_indptr=*/nullptr, q_offset, k_rope_pos_offset, o, lse, num_qo_heads,
-                num_kv_heads, kv_layout, /*logits_soft_cap=*/0.f, sm_scale, rope_scale,
-                rope_theta, stream);
+                num_kv_heads, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h,
+                /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream);
           })})})});
   return cudaSuccess;
 }