feat: non-inplace rope operators #405

Merged
2 commits merged on Jul 29, 2024
2 changes: 2 additions & 0 deletions docs/api/python/rope.rst
@@ -12,3 +12,5 @@ Kernels for applying rotary embeddings.

apply_rope_inplace
apply_llama31_rope_inplace
apply_rope
apply_llama31_rope
4 changes: 2 additions & 2 deletions include/flashinfer/attention/prefill.cuh
@@ -1440,7 +1440,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
/*o_stride_n=*/
partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
/*o_stride_h=*/head_dim, group_size);

// write lse
@@ -1732,7 +1732,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(
o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len,
/*o_stride_n=*/
partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim,
/*o_stride_h=*/head_dim, group_size);

// write lse
174 changes: 174 additions & 0 deletions include/flashinfer/pos_enc.cuh
@@ -191,6 +191,86 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(
}
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryKernel(DType* __restrict__ q, DType* __restrict__ k,
DType* __restrict__ q_rope, DType* __restrict__ k_rope,
IdType* __restrict__ indptr, IdType* __restrict__ offsets,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, size_t q_stride_n,
size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
float smooth_a, float smooth_b, float rope_rcp_scale,
float rope_rcp_theta) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
const uint32_t bdy = blockDim.y;
vec_t<float, vec_size> freq;
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
if constexpr (interleave) {
freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim));
} else {
freq[i] = __powf(rope_rcp_theta,
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}

float smooth = freq[i] * smooth_a + smooth_b;
smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1]
freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
}

if (bx < batch_size * num_qo_heads) {
// apply rotary to q
const uint32_t batch_idx = bx / num_qo_heads;
const uint32_t qo_head_idx = bx % num_qo_heads;
const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
const uint32_t offset = offsets[batch_idx];
#pragma unroll 2
for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
vec_t<float, vec_size> q_vec;
if (i * bdy + ty < seq_len) {
DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
q_stride_n, q_stride_h);
DType* q_rope_ptr =
q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
/*q_stride_n=*/num_qo_heads * head_dim,
/*q_stride_h=*/head_dim);
if constexpr (interleave) {
q_vec =
vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
} else {
q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
}
q_vec.cast_store(q_rope_ptr + tx * vec_size);
}
}
} else {
// apply rotary to k
uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads;
uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads;
const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
const uint32_t offset = offsets[batch_idx];
#pragma unroll 2
for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
vec_t<float, vec_size> k_vec;
if (i * bdy + ty < seq_len) {
DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
k_stride_n, k_stride_h);
DType* k_rope_ptr =
k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
/*kv_stride_n=*/num_kv_heads * head_dim,
/*kv_stride_h=*/head_dim);
if constexpr (interleave) {
k_vec =
vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
} else {
k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
}
k_vec.cast_store(k_rope_ptr + tx * vec_size);
}
}
}
}
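
For context, a minimal PyTorch reference of what this kernel computes per token (a hedged sketch, not the PR's implementation; it assumes `rope_scale == 1` so the smoothing terms are inert, and the helper name `ref_rope` is made up for illustration):

```python
import torch

def ref_rope(x: torch.Tensor, pos: torch.Tensor, rope_theta: float = 1e4,
             interleave: bool = False) -> torch.Tensor:
    """Reference RoPE for x of shape (nnz, num_heads, head_dim) and pos of shape (nnz,)."""
    d = x.shape[-1]
    idx = torch.arange(d, device=x.device)
    # Mirrors the freq[i] computation above: theta^{-2*(i//2)/d} (interleave)
    # or theta^{-2*(i % (d/2))/d} (non-interleave).
    exponent = 2 * (idx // 2 if interleave else idx % (d // 2)) / d
    freq = rope_theta ** (-exponent.float())
    angle = pos[:, None, None].float() * freq            # (nnz, 1, head_dim)
    cos, sin = angle.cos(), angle.sin()
    if interleave:
        # Rotate adjacent pairs: (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin).
        x_shift = torch.stack([-x[..., 1::2], x[..., 0::2]], dim=-1).flatten(-2)
    else:
        # Rotate halves: (lo, hi) -> (lo*cos - hi*sin, hi*cos + lo*sin).
        x_shift = torch.cat([-x[..., d // 2:], x[..., :d // 2]], dim=-1)
    return (x.float() * cos + x_shift.float() * sin).to(x.dtype)
```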

#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
@@ -289,6 +369,100 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace(
return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k,
DType* __restrict__ q_rope, DType* __restrict__ k_rope,
IdType* __restrict__ indptr, IdType* __restrict__ offsets,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim, size_t q_stride_n, size_t q_stride_h,
size_t k_stride_n, size_t k_stride_h, bool interleave,
float rope_scale, float rope_theta, cudaStream_t stream = nullptr) {
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = 0.f;
float smooth_b = 0.f;

DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&q_rope,
(void*)&k_rope,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&smooth_a,
(void*)&smooth_b,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
}
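
A quick sanity check on the launch configuration above, with assumed example values (fp16, head_dim = 128); this is just the dispatch arithmetic spelled out, not extra functionality:

```python
# Hypothetical numbers mirroring the dispatch logic above for fp16 and head_dim = 128.
sizeof_dtype = 2                                       # sizeof(half)
head_dim = 128
vec_size = max(16 // sizeof_dtype, head_dim // 32)     # max(8, 4) = 8 elements per thread
bdx = head_dim // vec_size                             # 16 threads span one head
num_threads = max(128, bdx)                            # 128
bdy = num_threads // bdx                               # 8 tokens handled per block iteration

batch_size, num_qo_heads, num_kv_heads = 4, 32, 8
nblks = batch_size * (num_qo_heads + num_kv_heads)     # 160 blocks: one per (request, q-or-k head)
print(vec_size, bdx, bdy, nblks)                       # 8 16 8 160
```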

template <typename DType, typename IdType>
cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ k,
DType* __restrict__ q_rope, DType* __restrict__ k_rope,
IdType* __restrict__ indptr, IdType* __restrict__ offsets,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n,
size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
bool interleave, float rope_scale, float rope_theta,
float low_freq_factor, float high_freq_factor,
float old_context_length, cudaStream_t stream = nullptr) {
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);

DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&q_rope,
(void*)&k_rope,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&smooth_a,
(void*)&smooth_b,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
}
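
The `smooth_a` / `smooth_b` constants above fold the Llama 3.1 wavelength-dependent interpolation into the per-element clamp in the shared kernel (`smooth = clamp(freq * smooth_a + smooth_b, 0, 1)`, then blend scaled and unscaled frequencies). A hedged Python sketch of the equivalent per-frequency computation; the default parameter values are illustrative, not mandated by this PR:

```python
import math

def llama31_scaled_freq(freq: float, rope_scale: float = 8.0,
                        low_freq_factor: float = 1.0, high_freq_factor: float = 4.0,
                        old_context_length: float = 8192.0) -> float:
    """freq is the base rotary frequency theta^{-2i/d}; returns the Llama-3.1-adjusted value."""
    smooth_a = old_context_length / (2 * math.pi * high_freq_factor
                                     - 2 * math.pi * low_freq_factor)
    smooth_b = -1.0 / (high_freq_factor / low_freq_factor - 1.0)
    smooth = min(1.0, max(0.0, freq * smooth_a + smooth_b))
    # smooth == 0: long wavelength, fully rescaled; smooth == 1: short wavelength, untouched.
    return (1 - smooth) * (freq / rope_scale) + smooth * freq
```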

} // namespace flashinfer

#endif // FLASHINFER_POS_ENC_CUH_
2 changes: 2 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -45,6 +45,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
m.def("apply_rope", &apply_rope, "Apply RoPE");
m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
10 changes: 10 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -83,6 +83,16 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor
float rope_theta, float low_freq_factor, float high_freq_factor,
float old_context_length);

std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta);

std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
torch::Tensor indptr, torch::Tensor offsets,
bool interleave, float rope_scale, float rope_theta,
float low_freq_factor, float high_freq_factor,
float old_context_length);

torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);

torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
100 changes: 99 additions & 1 deletion python/csrc/rope.cu
@@ -102,4 +102,102 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor
std::string(cudaGetErrorString(status)));
return true;
});
}
}

std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(1, indptr); // indptr: (B + 1)
CHECK_DIM(1, offsets); // offsets: (B)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int batch_size = offsets.size(0);
CHECK_EQ(indptr.size(0), batch_size + 1);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);
// NOTE(Zihao): empty_like does not copy strides, so it's okay to use it here.
auto q_rope = torch::empty_like(q);
auto k_rope = torch::empty_like(k);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotary(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotary failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});

return {q_rope, k_rope};
}
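
From the Python side, a hedged usage sketch of the new non-inplace binding using the ragged layout these checks expect (q: (nnz, H_Q, D), k: (nnz, H_K, D), indptr: (B+1,), offsets: (B,)); shapes are illustrative and the keyword names are assumed to mirror apply_rope_inplace:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
seq_lens = [7, 5]                                    # two requests packed into one ragged batch
nnz = sum(seq_lens)
q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
indptr = torch.tensor([0, 7, 12], dtype=torch.int32, device="cuda")  # cumulative lengths
offsets = torch.tensor([0, 0], dtype=torch.int32, device="cuda")     # absolute start positions

# Unlike apply_rope_inplace, q and k are left untouched; rotated copies are returned.
q_rope, k_rope = flashinfer.apply_rope(q, k, indptr, offsets,
                                       interleave=False, rope_scale=1.0, rope_theta=1e4)
```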

std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
torch::Tensor indptr, torch::Tensor offsets,
bool interleave, float rope_scale, float rope_theta,
float low_freq_factor, float high_freq_factor,
float old_context_length) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(1, indptr); // indptr: (B + 1)
CHECK_DIM(1, offsets); // offsets: (B)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int batch_size = offsets.size(0);
CHECK_EQ(indptr.size(0), batch_size + 1);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);

// NOTE(Zihao): empty_like does not copy strides, so it's okay to use it here.
auto q_rope = torch::empty_like(q);
auto k_rope = torch::empty_like(k);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyLlama31Rotary(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor,
old_context_length, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31Rotary failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});

return {q_rope, k_rope};
}
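
And the Llama 3.1 variant, which forwards the extra frequency-smoothing parameters checked above (again a sketch; arguments are passed positionally in the binding's order to avoid guessing the Python keyword names, and the values are illustrative):

```python
# Reusing q, k, indptr, offsets from the apply_rope sketch above. Positional arguments follow
# the binding: interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor,
# old_context_length.
q_rope, k_rope = flashinfer.apply_llama31_rope(
    q, k, indptr, offsets, True, 8.0, 5e5, 1.0, 4.0, 8192.0)
```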
2 changes: 1 addition & 1 deletion python/csrc/single_prefill.cu
@@ -71,7 +71,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
const LogitsPostHook logits_post_hook =
logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;

bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] {
return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] {
7 changes: 6 additions & 1 deletion python/flashinfer/__init__.py
@@ -44,7 +44,12 @@
chain_speculative_sampling,
)
from .norm import rmsnorm
from .rope import apply_rope_inplace, apply_llama31_rope_inplace
from .rope import (
apply_rope_inplace,
apply_llama31_rope_inplace,
apply_rope,
apply_llama31_rope,
)
from .group_gemm import SegmentGEMMWrapper
from .quantization import packbits, segment_packbits
