feat: sliding window attention #406

Merged
merged 7 commits on Jul 29, 2024
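What the change does: the decode kernels gain a window_left argument, threaded from SingleDecodeWithKVCacheDispatched / BatchDecodeWithPagedKVCacheDispatched down into compute_qk. When window_left >= 0, the decode query at position seq_len - 1 may attend only to the last window_left + 1 KV positions; logits for anything earlier are clamped to the -5e4 sentinel before the online-softmax update. Below is a minimal standalone sketch of that masking rule (illustrative only — the kernels fuse this into compute_qk; the helper names mirror the diff, not a public API):

// Saturating subtraction; mirrors the sub_if_greater_or_zero used in the diff.
__host__ __device__ inline uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) {
  return x > y ? x - y : 0u;
}

// First KV position visible to the decode query; window_left < 0 disables the window.
__host__ __device__ inline uint32_t left_close_bound(uint32_t seq_len, int32_t window_left) {
  return window_left >= 0 ? sub_if_greater_or_zero(seq_len, uint32_t(window_left) + 1) : 0u;
}

// A position's logit survives iff it lies inside [left_close_bound, seq_len).
__host__ __device__ inline bool in_window(uint32_t pos, uint32_t seq_len, int32_t window_left) {
  return pos >= left_close_bound(seq_len, window_left) && pos < seq_len;
}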
45 changes: 28 additions & 17 deletions include/flashinfer/attention/decode.cuh
@@ -74,9 +74,10 @@ template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode, ui
__device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx,
const vec_t<float, vec_size>& q_vec,
const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
- uint32_t iter_base, uint32_t iter_bound,
- const int32_t q_offset, float alibi_slope, float* s,
- state_t<vec_size>& st, const float logits_soft_cap) {
+ uint32_t iter_base, uint32_t left_close_bound,
+ uint32_t iter_bound, const int32_t q_offset,
+ float alibi_slope, float* s, state_t<vec_size>& st,
+ const float logits_soft_cap) {
uint32_t tx = threadIdx.x, tz = threadIdx.z;
float m_prev = st.m;
#pragma unroll
@@ -100,9 +101,10 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
s[j] += math::shfl_xor_sync(s[j], offset);
}
s[j] = apply_logits_post_hook<logits_post_hook>(s[j], logits_soft_cap);
- s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4;
+ const uint32_t pos = kv_idx_base + tz * tile_size + j;
+ s[j] = (iter_base + tz * tile_size + j < iter_bound && pos >= left_close_bound) ? s[j] : -5e4;
if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) {
- s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset);
+ s[j] += alibi_slope * float(int(pos) - q_offset);
}
st.m = max(st.m, s[j]);
}
@@ -212,9 +214,9 @@ template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode po
__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
float* __restrict__ lse, tensor_info_t info,
- float logits_soft_cap, float sm_scale,
- float rope_rcp_scale, float rope_rcp_theta,
- uint32_t kv_chunk_size) {
+ int32_t window_left, float logits_soft_cap,
+ float sm_scale, float rope_rcp_scale,
+ float rope_rcp_theta, uint32_t kv_chunk_size) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
sm_scale *=
@@ -227,6 +229,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
uint32_t num_qo_heads = info.num_qo_heads;
const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e;
uint32_t seq_len = info.kv_len;
+ uint32_t left_close_bound =
+ (window_left >= 0) ? sub_if_greater_or_zero(seq_len, window_left + 1) : 0;

extern __shared__ uint8_t smem[];
DTypeKV* k_smem = (DTypeKV*)smem;
@@ -303,8 +307,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
block.sync();
compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
- freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size,
- seq_len - 1, alibi_slope, s, st_local, logits_soft_cap);
+ freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, left_close_bound,
+ kv_chunk_size, seq_len - 1, alibi_slope, s, st_local, logits_soft_cap);
block.sync();
// load k
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
@@ -389,8 +393,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
- float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap,
- float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
+ float* __restrict__ lse, bool* __restrict__ block_valid_mask, int32_t window_left,
+ float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
auto block = cg::this_thread_block();
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));
@@ -415,6 +419,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
: 0;
const uint32_t seq_len =
partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len;
+ const uint32_t left_close_bound =
+ (window_left >= 0) ? sub_if_greater_or_zero(seq_len, window_left + 1) : 0;
const uint32_t mapped_batch_idx =
partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx;

@@ -521,8 +527,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
freq,
(paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
- iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, q_offset_val, alibi_slope, s, st,
- logits_soft_cap);
+ iter * tile_size_per_bdx * bdy * bdz, left_close_bound, kv_chunk_len, q_offset_val,
+ alibi_slope, s, st, logits_soft_cap);
block.sync();

#pragma unroll
@@ -627,8 +633,9 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, PosEncodingMode PO
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t seq_len,
- QKVLayout kv_layout, float logits_soft_cap,
- float sm_scale, float rope_scale, float rope_theta,
+ QKVLayout kv_layout, int32_t window_left,
+ float logits_soft_cap, float sm_scale,
+ float rope_scale, float rope_theta,
cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;
@@ -664,6 +671,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
(void*)&o,
(void*)&lse,
(void*)&info,
+ (void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
@@ -704,6 +712,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
(void*)&tmp,
(void*)&tmp_lse,
(void*)&info,
+ (void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
@@ -724,7 +733,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads,
- float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
+ int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;
@@ -761,6 +770,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
(void*)&o,
(void*)&lse,
(void*)&block_valid_mask,
+ (void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
@@ -782,6 +792,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&block_valid_mask,
+ (void*)&window_left,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
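For concreteness, a hypothetical host-side replay of the new bound (not part of the diff): with seq_len = 10 and window_left = 3, left_close_bound = 10 - (3 + 1) = 6, so positions 6..9 keep their logits while positions 0..5 receive the -5e4 sentinel, matching the pos >= left_close_bound check added to compute_qk. Note that the batch kernel derives the bound from seq_lens_before_partition, so every KV partition of the same request masks against the same window.

#include <cstdint>
#include <cstdio>

static uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) { return x > y ? x - y : 0u; }

int main() {
  const uint32_t seq_len = 10;
  const int32_t window_left = 3;  // attend to the last window_left + 1 = 4 positions
  const uint32_t lo = window_left >= 0
                          ? sub_if_greater_or_zero(seq_len, uint32_t(window_left) + 1)
                          : 0u;
  for (uint32_t pos = 0; pos < seq_len; ++pos) {
    float s = 1.0f;                                // stand-in for a computed logit
    s = (pos >= lo && pos < seq_len) ? s : -5e4f;  // same sentinel the kernel uses
    printf("pos=%u s=%g\n", pos, s);               // 0..5 masked, 6..9 kept
  }
  return 0;
}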
4 changes: 2 additions & 2 deletions include/flashinfer/attention/handler.cuh
@@ -42,8 +42,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
- float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap,
- float sm_scale, float rope_rcp_scale, float rope_rcp_theta);
+ float* __restrict__ lse, bool* __restrict__ block_valid_mask, int maybe_window_left,
+ float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta);

/*!
* \brief Compute the maximum number of pages per batch and the new batch size