Skip to content

Commit

Permalink
feat: support non-contiguous (packed) input for prefill kernels (#404)
Browse files Browse the repository at this point in the history
This PR implements #311 , after this PR, we support packed qkv input
without explictly convert make the input contiguous:
```python
packed_qkv = W_qkv(x) # (nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim)
q = packed_qkv[..., : num_qo_heads * head_dim].reshape(-1, num_qo_heads, head_dim)
k = packed_qkv[..., num_qo_heads * head_dim: (num_qo_heads + num_kv_heads) * head_dim].reshape(-1, num_kv_heads, head_dim)
v = packed_qkv[..., (num_qo_heads + num_kv_heads) * head_dim:].reshape(-1, num_kv_heads, head_dim)
apply_rope_inplace(q, k, indptr, offsets)
ragged_prefill_wrapper.forward(q, k, v)
```

Before this PR, we need to make `q`/`k`/`v` contiguous before we launch
the attention kernel, which incurs some overhead.

I observe slight (<1%) performance degration after this PR for
non-packed input, which is acceptable IMO.
  • Loading branch information
yzh119 authored Jul 29, 2024
1 parent 4c89dec commit 68c3719
Show file tree
Hide file tree
Showing 11 changed files with 333 additions and 131 deletions.
6 changes: 3 additions & 3 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}
// apply rotary embedding to q matrix
q_vec = vec_apply_llama_rope<vec_size, bdx>(q + info.get_qo_elem_offset(0, qo_head_idx, 0),
freq, seq_len - 1);
q_vec = vec_apply_llama_rope<vec_size, bdx>(q + info.get_q_elem_offset(0, qo_head_idx, 0), freq,
seq_len - 1);
} else {
// do not apply rotary embedding to q matrix
q_vec.cast_load(q + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
q_vec.cast_load(q + info.get_q_elem_offset(0, qo_head_idx, tx * vec_size));
}
// multiple q_vec by sm_scale
#pragma unroll
Expand Down
129 changes: 72 additions & 57 deletions include/flashinfer/attention/prefill.cuh

Large diffs are not rendered by default.

51 changes: 37 additions & 14 deletions include/flashinfer/layout.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#define FLASHINFER_LAYOUT_CUH_

#include <string>
#include <tuple>

namespace flashinfer {

Expand All @@ -36,42 +37,64 @@ __host__ __device__ __forceinline__ size_t get_elem_offset_impl(size_t elem_idx,
return elem_idx * stride_n + head_idx * stride_h + feat_idx;
}

__host__ __forceinline__ auto get_qkv_strides(QKVLayout kv_layout, uint32_t kv_len,
uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim) {
const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim,
kv_stride_n = (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim,
kv_stride_h = (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim;
return std::make_tuple(q_stride_n, q_stride_h, kv_stride_n, kv_stride_h);
}

struct tensor_info_t {
uint32_t qo_len;
uint32_t kv_len;
uint32_t num_qo_heads;
uint32_t num_kv_heads;
uint32_t qo_stride_n;
uint32_t qo_stride_h;
uint32_t q_stride_n;
uint32_t q_stride_h;
uint32_t kv_stride_n;
uint32_t kv_stride_h;
uint32_t head_dim;
__host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, uint32_t kv_len,
uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t qo_stride_n, uint32_t qo_stride_h,
uint32_t kv_stride_n, uint32_t kv_stride_h)
uint32_t q_stride_n, uint32_t q_stride_h,
uint32_t kv_stride_n, uint32_t kv_stride_h,
uint32_t head_dim)
: qo_len(qo_len),
kv_len(kv_len),
num_qo_heads(num_qo_heads),
num_kv_heads(num_kv_heads),
qo_stride_n(qo_stride_n),
qo_stride_h(qo_stride_h),
q_stride_n(q_stride_n),
q_stride_h(q_stride_h),
kv_stride_n(kv_stride_n),
kv_stride_h(kv_stride_h) {}
kv_stride_h(kv_stride_h),
head_dim(head_dim) {}

__host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, uint32_t kv_len,
uint32_t num_qo_heads, uint32_t num_kv_heads,
QKVLayout kv_layout, uint32_t head_dim)
: qo_len(qo_len), kv_len(kv_len), num_qo_heads(num_qo_heads), num_kv_heads(num_kv_heads) {
qo_stride_n = num_qo_heads * head_dim;
qo_stride_h = head_dim;
: qo_len(qo_len),
kv_len(kv_len),
num_qo_heads(num_qo_heads),
num_kv_heads(num_kv_heads),
head_dim(head_dim) {
q_stride_n = num_qo_heads * head_dim;
q_stride_h = head_dim;
kv_stride_n = (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim;
kv_stride_h = (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim;
}

__host__ __device__ __forceinline__ size_t get_qo_elem_offset(uint32_t qo_idx,
uint32_t qo_head_idx,
uint32_t feat_idx) const {
return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, qo_stride_n, qo_stride_h);
__host__ __device__ __forceinline__ size_t get_q_elem_offset(uint32_t qo_idx,
uint32_t qo_head_idx,
uint32_t feat_idx) const {
return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, q_stride_n, q_stride_h);
}

__host__ __device__ __forceinline__ size_t get_o_elem_offset(uint32_t qo_idx,
uint32_t qo_head_idx,
uint32_t feat_idx) const {
return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, num_qo_heads * head_dim, head_dim);
}

__host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx,
Expand Down
18 changes: 10 additions & 8 deletions include/flashinfer/prefill_attention_decl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, PosEncodingMode PO
cudaError_t SinglePrefillWithKVCacheDispatched(
DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp,
float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
QKVLayout kv_layout, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream);
uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);

template <WarpLayout WARP_LAYOUT, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
Expand All @@ -46,8 +46,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o,
DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask,
IdType* kv_chunk_size_ptr, uint32_t total_num_rows, uint32_t num_qo_heads,
uint32_t padded_batch_size, uint32_t num_kv_heads, QKVLayout layout, float logits_soft_cap,
float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream = nullptr);
uint32_t padded_batch_size, uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h,
uint32_t kv_stride_n, uint32_t kv_stride_h, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream = nullptr);

template <PageStorage page_storage, WarpLayout WARP_LAYOUT, uint32_t HEAD_DIM,
LogitsPostHook LOGITS_POST_HOOK, PosEncodingMode pos_encoding_mode,
Expand Down Expand Up @@ -117,8 +118,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, DTypeIn* k, DTypeIn* v,
IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t num_qo_heads,
uint32_t num_kv_heads, QKVLayout kv_layout, float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream) {
uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n,
uint32_t kv_stride_h, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
DTypeOut* tmp_v = nullptr;
float* tmp_s = nullptr;
IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr,
Expand Down Expand Up @@ -154,8 +156,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, k, v, kv_indptr,
custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse,
merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads,
padded_batch_size, num_kv_heads, kv_layout, logits_soft_cap, sm_scale, rope_scale,
rope_theta, stream);
padded_batch_size, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h,
logits_soft_cap, sm_scale, rope_scale, rope_theta, stream);
});
return cudaSuccess;
}
Expand Down
42 changes: 32 additions & 10 deletions python/csrc/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,10 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
torch::Tensor kv_indptr, bool causal, unsigned int pos_encoding_mode,
bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta, bool return_lse) {
CHECK_INPUT(q);
CHECK_INPUT(qo_indptr);
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_CUDA(q);
CHECK_CUDA(k);
CHECK_CUDA(v);
CHECK_INPUT(kv_indptr);
auto device = q.device();
CHECK_EQ(device, qo_indptr.device());
Expand All @@ -414,11 +414,22 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
int64_t head_dim = q.size(2);
CHECK_GE(kv_indptr.size(0), batch_size + 1);
int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0);
CHECK_EQ(q.stride(2), 1);
CHECK_EQ(k.stride(2), 1);
CHECK_EQ(v.stride(2), 1);
CHECK_EQ(k.size(0), v.size(0));
CHECK_EQ(k.size(1), v.size(1));
CHECK_EQ(k.size(2), v.size(2));
CHECK_EQ(k.size(2), head_dim);
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h;
if (kv_layout_ == QKVLayout::kNHD) {
kv_stride_n = k.stride(0);
kv_stride_h = k.stride(1);
} else {
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
qo_indptr = qo_indptr.to(torch::kInt32);
kv_indptr = kv_indptr.to(torch::kInt32);

Expand Down Expand Up @@ -453,8 +464,8 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
/*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
num_qo_heads, num_kv_heads, kv_layout_, logits_soft_cap, sm_scale,
rope_scale, rope_theta,
num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n,
kv_stride_h, logits_soft_cap, sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithRaggedKVCache failed with error ",
Expand All @@ -479,10 +490,10 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr,
unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float logits_soft_cap,
float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
CHECK_INPUT(q);
CHECK_INPUT(qo_indptr);
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_CUDA(q);
CHECK_CUDA(k);
CHECK_CUDA(v);
CHECK_INPUT(kv_indptr);
CHECK_INPUT(custom_mask);
CHECK_INPUT(qk_indptr);
Expand All @@ -509,11 +520,22 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
CHECK_GE(kv_indptr.size(0), batch_size + 1);
CHECK_GE(qk_indptr.size(0), batch_size + 1);
int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0);
CHECK_EQ(q.stride(2), 1);
CHECK_EQ(k.stride(2), 1);
CHECK_EQ(v.stride(2), 1);
CHECK_EQ(k.size(0), v.size(0));
CHECK_EQ(k.size(1), v.size(1));
CHECK_EQ(k.size(2), v.size(2));
CHECK_EQ(k.size(2), head_dim);
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h;
if (kv_layout_ == QKVLayout::kNHD) {
kv_stride_n = k.stride(0);
kv_stride_h = k.stride(1);
} else {
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
qo_indptr = qo_indptr.to(torch::kInt32);
kv_indptr = kv_indptr.to(torch::kInt32);
qk_indptr = qk_indptr.to(torch::kInt32);
Expand Down Expand Up @@ -549,8 +571,8 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
/*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
static_cast<c_type*>(o.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
num_qo_heads, num_kv_heads, kv_layout_, logits_soft_cap, sm_scale,
rope_scale, rope_theta,
num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n,
kv_stride_h, logits_soft_cap, sm_scale, rope_scale, rope_theta,
/*stream=*/torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"BatchPrefillWithRaggedKVCache failed with error ",
Expand Down
60 changes: 40 additions & 20 deletions python/csrc/single_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
unsigned int layout, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
CHECK_INPUT(q);
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_CUDA(q);
CHECK_CUDA(k);
CHECK_CUDA(v);
CHECK_INPUT(tmp);
auto device = q.device();
CHECK_EQ(k.device(), device);
Expand All @@ -36,6 +36,9 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
CHECK_DIM(3, k);
CHECK_DIM(3, v);
CHECK_SHAPE(k, v);
CHECK_EQ(q.stride(2), 1);
CHECK_EQ(k.stride(2), 1);
CHECK_EQ(v.stride(2), 1);
CHECK_EQ(q.size(2), k.size(2));
CHECK_EQ(q.scalar_type(), k.scalar_type());
CHECK_EQ(q.scalar_type(), v.scalar_type());
Expand All @@ -44,12 +47,17 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
qo_len = q.size(0);
num_qo_heads = q.size(1);
uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h;
if (kv_layout == QKVLayout::kNHD) {
kv_len = k.size(0);
num_kv_heads = k.size(1);
} else {
kv_stride_n = k.stride(0);
kv_stride_h = k.stride(1);
} else { // QKVLayout::kHND
kv_len = k.size(1);
num_kv_heads = k.size(0);
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
Expand All @@ -63,7 +71,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
const LogitsPostHook logits_post_hook =
logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
Expand All @@ -72,16 +80,19 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] {
return DISPATCH_pos_encoding_mode(
PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
cudaError_t status = SinglePrefillWithKVCacheDispatched<
HEAD_DIM, LOGITS_POST_HOOK, POS_ENCODING_MODE, ALLOW_FP16_QK_REDUCTION,
MASK_MODE>(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()),
/*custom_mask=*/nullptr, static_cast<c_type*>(o.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
num_qo_heads, num_kv_heads, qo_len, kv_len, kv_layout, logits_soft_cap,
sm_scale, rope_scale, rope_theta, torch_current_stream);
cudaError_t status =
SinglePrefillWithKVCacheDispatched<HEAD_DIM, LOGITS_POST_HOOK,
POS_ENCODING_MODE,
ALLOW_FP16_QK_REDUCTION, MASK_MODE>(
static_cast<c_type*>(q.data_ptr()),
static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(v.data_ptr()),
/*custom_mask=*/nullptr, static_cast<c_type*>(o.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
kv_stride_n, kv_stride_h, logits_soft_cap, sm_scale, rope_scale,
rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"SinglePrefillWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
Expand All @@ -105,9 +116,9 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
torch::Tensor tmp, unsigned int layout, unsigned int pos_encoding_mode,
bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale,
float rope_theta, bool return_lse) {
CHECK_INPUT(q);
CHECK_INPUT(k);
CHECK_INPUT(v);
CHECK_CUDA(q);
CHECK_CUDA(k);
CHECK_CUDA(v);
CHECK_INPUT(packed_custom_mask);
auto device = q.device();
CHECK_EQ(k.device(), device);
Expand All @@ -118,6 +129,9 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
CHECK_DIM(3, v);
CHECK_DIM(1, packed_custom_mask);
CHECK_SHAPE(k, v);
CHECK_EQ(q.stride(2), 1);
CHECK_EQ(k.stride(2), 1);
CHECK_EQ(v.stride(2), 1);
CHECK_EQ(q.size(2), k.size(2));
// packed_custom_mask must be uint8
TORCH_CHECK(packed_custom_mask.scalar_type() == torch::kUInt8,
Expand All @@ -127,12 +141,17 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
qo_len = q.size(0);
num_qo_heads = q.size(1);
uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), kv_stride_n, kv_stride_h;
if (kv_layout == QKVLayout::kNHD) {
kv_len = k.size(0);
num_kv_heads = k.size(1);
kv_stride_n = k.stride(0);
kv_stride_h = k.stride(1);
} else {
kv_len = k.size(1);
num_kv_heads = k.size(0);
kv_stride_h = k.stride(0);
kv_stride_n = k.stride(1);
}
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
Expand Down Expand Up @@ -164,8 +183,9 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
static_cast<c_type*>(o.data_ptr()),
static_cast<c_type*>(tmp.data_ptr()),
/*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
num_qo_heads, num_kv_heads, qo_len, kv_len, kv_layout, logits_soft_cap,
sm_scale, rope_scale, rope_theta, torch_current_stream);
num_qo_heads, num_kv_heads, qo_len, kv_len, q_stride_n, q_stride_h,
kv_stride_n, kv_stride_h, logits_soft_cap, sm_scale, rope_scale,
rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"SinglePrefillWithKVCache kernel launch failed, error: " +
std::string(cudaGetErrorString(status)));
Expand Down
Loading

0 comments on commit 68c3719

Please sign in to comment.