From 4e007852eaaedaff443c802896e5d778eec6c835 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 06:57:59 +0000 Subject: [PATCH 01/13] upd --- include/flashinfer/pos_enc.cuh | 84 ++++++++++++++++++++++++++++++---- 1 file changed, 75 insertions(+), 9 deletions(-) diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 7f6cab39..e3936b4d 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -16,6 +16,7 @@ #ifndef FLASHINFER_POS_ENC_CUH_ #define FLASHINFER_POS_ENC_CUH_ +#include #include #include "layout.cuh" @@ -98,6 +99,8 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __ IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + size_t q_stride_n, size_t q_stride_h, + size_t k_stride_n, size_t k_stride_h, float rope_rcp_scale, float rope_rcp_theta) { uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; const uint32_t bdy = blockDim.y; @@ -120,7 +123,7 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __ vec_t q_vec; if (i * bdy + ty < seq_len) { DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, - num_qo_heads * head_dim, head_dim); + q_stride_n, q_stride_h); q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); q_vec.cast_store(q_ptr + tx * vec_size); } @@ -136,7 +139,60 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __ vec_t k_vec; if (i * bdy + ty < seq_len) { DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, - num_kv_heads * head_dim, head_dim); + k_stride_n, k_stride_h); + k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); + k_vec.cast_store(k_ptr + tx * vec_size); + } + } + } +} + +template +__global__ void Llama31BatchQKApplyRotaryInPlaceKernel( + DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, float smooth_a, + float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + const uint32_t bdy = blockDim.y; + vec_t freq; +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + freq[i] = + __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); + float smooth = freq[i] * smooth_a + smooth_b; + smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] + freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * (freq[i] * rope_rcp_scale); + } + + if (bx < batch_size * num_qo_heads) { + // apply rotary to q + const uint32_t batch_idx = bx / num_qo_heads; + const uint32_t qo_head_idx = bx % num_qo_heads; + const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; + const uint32_t offset = offsets[batch_idx]; +#pragma unroll 2 + for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + vec_t q_vec; + if (i * bdy + ty < seq_len) { + DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, + q_stride_n, q_stride_h); + q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); + q_vec.cast_store(q_ptr + tx * vec_size); + } + } + } else { + // apply rotary to k + uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads; + uint32_t kv_head_idx = (bx - batch_size * 
num_qo_heads) % num_kv_heads; + const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; + const uint32_t offset = offsets[batch_idx]; +#pragma unroll 2 + for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + vec_t k_vec; + if (i * bdy + ty < seq_len) { + DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, + k_stride_n, k_stride_h); k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); k_vec.cast_store(k_ptr + tx * vec_size); } @@ -145,14 +201,18 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __ } template -cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, - IdType* __restrict__ indptr, IdType* __restrict__ offsets, - uint32_t batch_size, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, - float rope_scale = 1.f, float rope_theta = 1e4, - cudaStream_t stream = nullptr) { +cudaError_t Llama31BatchQKApplyRotaryInPlace( + DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + float rope_scale = 1.f, float rope_theta = 1e4, float low_freq_factor = 1.f, + float high_freq_factor = 4.f, float old_context_length = 8192, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; + float low_freq_wavelen = old_context_length / low_freq_factor; + float high_freq_wavelen = old_context_length / high_freq_factor; + float smooth_a = 1.0f / (2 * M_PI / high_freq_wavelen - 2 * M_PI / low_freq_wavelen); + float smooth_b = -1.0f / (low_freq_wavelen / high_freq_wavelen - 1.0f); DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); @@ -161,7 +221,7 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ uint32_t bdy = num_threads / bdx; dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); dim3 nthrs(bdx, bdy); - auto kernel = BatchQKApplyRotaryInPlaceKernel; + auto kernel = Llama31BatchQKApplyRotaryInPlaceKernel; void* args[] = {(void*)&q, (void*)&k, (void*)&indptr, @@ -169,6 +229,12 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ (void*)&batch_size, (void*)&num_qo_heads, (void*)&num_kv_heads, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); From 0cc5b57357438f2dc7457507fed4ec01a455e0e3 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 07:03:14 +0000 Subject: [PATCH 02/13] upd --- include/flashinfer/pos_enc.cuh | 107 +++++++++++++++------------------ 1 file changed, 49 insertions(+), 58 deletions(-) diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index e3936b4d..b60fc34c 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -95,60 +95,7 @@ __device__ __forceinline__ vec_t vec_apply_llama_rope( } template -__global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __restrict__ k, - IdType* __restrict__ indptr, - IdType* __restrict__ offsets, uint32_t batch_size, - uint32_t num_qo_heads, uint32_t num_kv_heads, - size_t q_stride_n, size_t q_stride_h, - size_t 
k_stride_n, size_t k_stride_h, - float rope_rcp_scale, float rope_rcp_theta) { - uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; - const uint32_t bdy = blockDim.y; - vec_t freq; -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - freq[i] = - rope_rcp_scale * - __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); - } - - if (bx < batch_size * num_qo_heads) { - // apply rotary to q - const uint32_t batch_idx = bx / num_qo_heads; - const uint32_t qo_head_idx = bx % num_qo_heads; - const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; - const uint32_t offset = offsets[batch_idx]; -#pragma unroll 2 - for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { - vec_t q_vec; - if (i * bdy + ty < seq_len) { - DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, - q_stride_n, q_stride_h); - q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); - q_vec.cast_store(q_ptr + tx * vec_size); - } - } - } else { - // apply rotary to k - uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads; - uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads; - const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; - const uint32_t offset = offsets[batch_idx]; -#pragma unroll 2 - for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { - vec_t k_vec; - if (i * bdy + ty < seq_len) { - DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, - k_stride_n, k_stride_h); - k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); - k_vec.cast_store(k_ptr + tx * vec_size); - } - } - } -} - -template -__global__ void Llama31BatchQKApplyRotaryInPlaceKernel( +__global__ void BatchQKApplyRotaryInPlaceKernel( DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, float smooth_a, @@ -162,7 +109,7 @@ __global__ void Llama31BatchQKApplyRotaryInPlaceKernel( __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); float smooth = freq[i] * smooth_a + smooth_b; smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] - freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * (freq[i] * rope_rcp_scale); + freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; } if (bx < batch_size * num_qo_heads) { @@ -200,13 +147,57 @@ __global__ void Llama31BatchQKApplyRotaryInPlaceKernel( } } +template +cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + float rope_scale, float rope_theta, + cudaStream_t stream = nullptr) { + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + float low_freq_wavelen = old_context_length / low_freq_factor; + float high_freq_wavelen = old_context_length / high_freq_factor; + float smooth_a = 0.f; + float smooth_b = 1.f; + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + 
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = BatchQKApplyRotaryInPlaceKernel; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); + + return cudaSuccess; +} + template cudaError_t Llama31BatchQKApplyRotaryInPlace( DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - float rope_scale = 1.f, float rope_theta = 1e4, float low_freq_factor = 1.f, - float high_freq_factor = 4.f, float old_context_length = 8192, cudaStream_t stream = nullptr) { + float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float low_freq_wavelen = old_context_length / low_freq_factor; @@ -221,7 +212,7 @@ cudaError_t Llama31BatchQKApplyRotaryInPlace( uint32_t bdy = num_threads / bdx; dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); dim3 nthrs(bdx, bdy); - auto kernel = Llama31BatchQKApplyRotaryInPlaceKernel; + auto kernel = BatchQKApplyRotaryInPlaceKernel; void* args[] = {(void*)&q, (void*)&k, (void*)&indptr, From 2309374f2072f6fb0eb180b0a9f5edeae2901075 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 07:10:56 +0000 Subject: [PATCH 03/13] simplify --- include/flashinfer/pos_enc.cuh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index b60fc34c..298f69bb 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -200,10 +200,8 @@ cudaError_t Llama31BatchQKApplyRotaryInPlace( float old_context_length, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; - float low_freq_wavelen = old_context_length / low_freq_factor; - float high_freq_wavelen = old_context_length / high_freq_factor; - float smooth_a = 1.0f / (2 * M_PI / high_freq_wavelen - 2 * M_PI / low_freq_wavelen); - float smooth_b = -1.0f / (low_freq_wavelen / high_freq_wavelen - 1.0f); + float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); + float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); From f7d064dd543cfb754129a5f205d209feb84809be Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 07:13:27 +0000 Subject: [PATCH 04/13] bugfix --- include/flashinfer/pos_enc.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 298f69bb..b639a03f 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -157,8 +157,6 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / 
rope_theta; - float low_freq_wavelen = old_context_length / low_freq_factor; - float high_freq_wavelen = old_context_length / high_freq_factor; float smooth_a = 0.f; float smooth_b = 1.f; From 827cb9977a50603ea3c06f6155d7c848a43596f2 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 07:47:16 +0000 Subject: [PATCH 05/13] upd --- include/flashinfer/pos_enc.cuh | 2 +- python/csrc/flashinfer_ops.cu | 3 + python/csrc/flashinfer_ops.h | 8 +++ python/csrc/rope.cu | 102 +++++++++++++++++++++++++++++++++ src/tvm_wrapper.cu | 7 ++- 5 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 python/csrc/rope.cu diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index b639a03f..1cd5db14 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -190,7 +190,7 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ } template -cudaError_t Llama31BatchQKApplyRotaryInPlace( +cudaError_t BatchQKApplyLlama31RotaryInPlace( DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 4193f304..79c34b21 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -42,6 +42,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("chain_speculative_sampling", &chain_speculative_sampling, "Speculative sampling from sequence of probabilities"); m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); + m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, + "Apply Llama 3.1 style RoPE in-place"); m.def("packbits", &packbits, "GPU packbits operator"); m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); py::class_(m, diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index d837528f..cf309aff 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -75,6 +75,14 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps); +void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, float rope_scale, float rope_theta); + +void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length); + torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu new file mode 100644 index 00000000..243ca458 --- /dev/null +++ b/python/csrc/rope.cu @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "flashinfer_ops.h" +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, float rope_scale, float rope_theta) { + CHECK_CUDA(q); // not necessarily contiguous + CHECK_CUDA(k); // not necessarily contiguous + CHECK_INPUT(indptr); + CHECK_INPUT(offsets); + + auto device = q.device(); + CHECK_EQ(k.device(), device); + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) + CHECK_DIM(1, indptr); // indptr: (B + 1) + CHECK_DIM(1, offsets); // offsets: (B) + CHECK_EQ(q.size(0), k.size(0)); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int batch_size = offsets.size(0); + CHECK_EQ(indptr.size(0), batch_size + 1); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + indptr = indptr.to(torch::kInt32); + offsets = offsets.to(torch::kInt32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + cudaError_t status = BatchQKApplyRotaryInPlace( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_h, q_stride_h, k_stride_n, + k_stride_h, rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " + + std::string(cudaGetErrorString(status))); + }); +} + +void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float old_context_length) { + CHECK_CUDA(q); // not necessarily contiguous + CHECK_CUDA(k); // not necessarily contiguous + CHECK_INPUT(indptr); + CHECK_INPUT(offsets); + + auto device = q.device(); + CHECK_EQ(k.device(), device); + CHECK_DIM(3, q); // q: (nnz, H_Q, D) + CHECK_DIM(3, k); // k: (nnz, H_K, D) + CHECK_DIM(1, indptr); // indptr: (B + 1) + CHECK_DIM(1, offsets); // offsets: (B) + CHECK_EQ(q.size(0), k.size(0)); + CHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int batch_size = offsets.size(0); + CHECK_EQ(indptr.size(0), batch_size + 1); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + indptr = indptr.to(torch::kInt32); + offsets = offsets.to(torch::kInt32); + + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { + cudaError_t status = BatchQKApplyLlama31RotaryInPlace( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + 
static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_h, q_stride_h, k_stride_n, + k_stride_h, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " + + std::string(cudaGetErrorString(status))); + }); +} \ No newline at end of file diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 73c41dbf..4a5b9767 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -681,12 +681,17 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in DLTensor* offsets, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, double rope_scale, double rope_theta) { + size_t q_stride_n = q->strides[0]; + size_t q_stride_h = q->strides[1]; + size_t k_stride_n = k->strides[0]; + size_t k_stride_h = k->strides[1]; DISPATCH_TVM_CUDA_DTYPE( q->dtype, dtype, {DISPATCH_TVM_CUDA_IDTYPE(indptr->dtype, idtype, { cudaError_t status = BatchQKApplyRotaryInPlace( static_cast(q->data), static_cast(k->data), static_cast(indptr->data), static_cast(offsets->data), batch_size, - num_qo_heads, num_kv_heads, head_dim, rope_scale, rope_theta); + num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, + rope_scale, rope_theta); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } From bd5c2d0c01ebf789f173fac0fd306c1602a8e0c2 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 07:59:47 +0000 Subject: [PATCH 06/13] python --- python/flashinfer/rope.py | 105 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 python/flashinfer/rope.py diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py new file mode 100644 index 00000000..922e968a --- /dev/null +++ b/python/flashinfer/rope.py @@ -0,0 +1,105 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch + +# mypy: disable-error-code="attr-defined" +try: + from . import _kernels +except ImportError as e: + import os + import logging + + if os.environ.get("BUILD_DOC", "0") == "1": + _kernels = None + logging.warning("Kernels are not loaded in documentation build mode.") + else: + raise e + + +def apply_rope_inplace( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + rope_scale: float = 1, + rope_theta: float = 1e4, +) -> None: + r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)`. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``. + indptr : torch.Tensor + Indptr tensor, shape: ``(batch_size + 1)``. + offsets : torch.Tensor + The relative position offsets of each query in the batch, shape: ``(batch_size)``. 
+ rope_scale : float + The scaling factor used in the rope embedding, default: ``1``. + rope_theta : float + The theta value used in the rope embedding, default: ``1e4``. + """ + return _kernels.apply_rope_inplace(q, k, indptr, offsets, rope_scale, rope_theta) + + +def apply_llama31_rope_inplace( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + rope_scale: float = 8, + rope_theta: float = 1e4, + low_freq_factor: float = 1, + high_freq_factor: float = 4, + old_context_len: int = 8192, +) -> None: + r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + + Parameters + ---------- + q : torch.Tensor + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``. + k : torch.Tensor + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``. + indptr : torch.Tensor + Indptr tensor, shape: ``(batch_size + 1)``. + offsets : torch.Tensor + The relative position offsets of each query in the batch, shape: ``(batch_size)``. + rope_scale : float + The scaling factor used in the rope embedding, default: ``8``. + rope_theta : float + The theta value used in the rope embedding, default: ``1e4``. + low_freq_factor : float + The low frequency factor used in Llama 3.1 RoPE, default: ``1``. + high_freq_factor : float + The high frequency factor used in Llama 3.1 RoPE, default: ``4``. + old_context_len : int + The old context length used in Llama 3.1 RoPE, default. + """ + return _kernels.apply_llama31_rope_inplace( + q, + k, + indptr, + offsets, + rope_scale, + rope_theta, + low_freq_factor, + high_freq_factor, + float(old_context_len), + ) From 606dc6687ce59151bf195ad9a1b0a8c2f2cb8717 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 08:15:09 +0000 Subject: [PATCH 07/13] upd --- python/flashinfer/rope.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 922e968a..406dd63e 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -40,12 +40,21 @@ def apply_rope_inplace( ) -> None: r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial ` for more details about the + ragged tensor. + Parameters ---------- q : torch.Tensor - Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)`. + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)`, where ``nnz`` is the last + element of ``indptr``. k : torch.Tensor - Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``. + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. indptr : torch.Tensor Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor @@ -69,7 +78,15 @@ def apply_llama31_rope_inplace( high_freq_factor: float = 4, old_context_len: int = 8192, ) -> None: - r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as + RaggedTensor) inplace. 
+ + We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th + segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the + i-th segment is ``k[indptr[i]:indptr[i+1]]``, the first element of :attr:`indptr` is always + 0 and the last element of :attr:`indptr` is the total number of queries/keys in the batch. + Please see :ref:`Ragged Tensor tutorial ` for more details about the + ragged tensor. Parameters ---------- From c37df8284d4f118ba05c23c867ee5991c45de1fb Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 08:35:00 +0000 Subject: [PATCH 08/13] upd --- docs/api/python/rope.rst | 14 ++++++++++++++ docs/index.rst | 1 + python/csrc/rope.cu | 2 ++ python/flashinfer/rope.py | 6 ++++-- python/setup.py | 1 + 5 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 docs/api/python/rope.rst diff --git a/docs/api/python/rope.rst b/docs/api/python/rope.rst new file mode 100644 index 00000000..b27ac7e9 --- /dev/null +++ b/docs/api/python/rope.rst @@ -0,0 +1,14 @@ +.. _apirope: + +flashinfer.rope +=============== + +Kernels for applying rotary embeddings. + +.. currentmodule:: flashinfer.rope + +.. autosummary:: + :toctree: _generate + + apply_rope_inplace + apply_llama31_rope_inplace diff --git a/docs/index.rst b/docs/index.rst index c8ed40e0..ce0129f7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -35,4 +35,5 @@ FlashInfer is a library for Large Language Models that provides high-performance api/python/sampling api/python/group_gemm api/python/norm + api/python/rope api/python/quantization diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index 243ca458..92cf77ee 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -56,6 +56,7 @@ void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, k_stride_h, rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " + std::string(cudaGetErrorString(status))); + return true; }); } @@ -98,5 +99,6 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " + std::string(cudaGetErrorString(status))); + return true; }); } \ No newline at end of file diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 406dd63e..da9d30fb 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -91,9 +91,11 @@ def apply_llama31_rope_inplace( Parameters ---------- q : torch.Tensor - Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``. + Query ragged tensor, shape: ``(nnz, num_q_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. k : torch.Tensor - Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``. + Key ragged tensor, shape: ``(nnz, num_k_heads, head_dim)``, where ``nnz`` is the last + element of ``indptr``. indptr : torch.Tensor Indptr tensor, shape: ``(batch_size + 1)``. 
offsets : torch.Tensor diff --git a/python/setup.py b/python/setup.py index 448e6acd..b6424a21 100644 --- a/python/setup.py +++ b/python/setup.py @@ -318,6 +318,7 @@ def __init__(self, *args, **kwargs) -> None: "csrc/batch_prefill.cu", "csrc/sampling.cu", "csrc/norm.cu", + "csrc/rope.cu", "csrc/group_gemm.cu", "csrc/quantization.cu", ] From a12272e265fae66729dcfba98d736dd5b2d3ec67 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 10:09:03 +0000 Subject: [PATCH 09/13] bugfix --- python/csrc/rope.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index 92cf77ee..eff0dfd7 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -52,7 +52,7 @@ void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, cudaError_t status = BatchQKApplyRotaryInPlace( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_h, q_stride_h, k_stride_n, + batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " + std::string(cudaGetErrorString(status))); From 3cb93398abbe5871577dbf61a9e2b51da3fc43e4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 10:09:38 +0000 Subject: [PATCH 10/13] upd --- python/flashinfer/__init__.py | 1 + python/flashinfer/rope.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 184116b2..db818d98 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -44,6 +44,7 @@ chain_speculative_sampling, ) from .norm import rmsnorm +from .rope import apply_rope_inplace, apply_llama31_rope_inplace from .group_gemm import SegmentGEMMWrapper from .quantization import packbits, segment_packbits diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index da9d30fb..49304ffa 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -73,7 +73,7 @@ def apply_llama31_rope_inplace( indptr: torch.Tensor, offsets: torch.Tensor, rope_scale: float = 8, - rope_theta: float = 1e4, + rope_theta: float = 5e5, low_freq_factor: float = 1, high_freq_factor: float = 4, old_context_len: int = 8192, @@ -103,7 +103,7 @@ def apply_llama31_rope_inplace( rope_scale : float The scaling factor used in the rope embedding, default: ``8``. rope_theta : float - The theta value used in the rope embedding, default: ``1e4``. + The theta value used in the rope embedding, default: ``5e5``. low_freq_factor : float The low frequency factor used in Llama 3.1 RoPE, default: ``1``. 
high_freq_factor : float From 0ea07a1f23384bdc2bf99caf3920a0dd2b4c3bc4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 10:22:29 +0000 Subject: [PATCH 11/13] upd --- include/flashinfer/pos_enc.cuh | 175 ++++++++++++++++++++++----------- python/csrc/flashinfer_ops.h | 6 +- python/csrc/rope.cu | 13 +-- python/flashinfer/rope.py | 11 ++- python/tests/rope_reference.py | 70 +++++++++++++ python/tests/test_rope.py | 134 +++++++++++++++++++++++++ 6 files changed, 341 insertions(+), 68 deletions(-) create mode 100644 python/tests/rope_reference.py create mode 100644 python/tests/test_rope.py diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 1cd5db14..efa5c8bf 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -94,7 +94,36 @@ __device__ __forceinline__ vec_t vec_apply_llama_rope( return vec; } -template +/*! + * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim] with interleave, + * return thread-local vector. + * \tparam vec_size A template integer indicates the vector size used + * in the kernel + * \tparam bdx A template integer indicates the blockDim.x + * \tparam T A template type indicates the x data type + * \param x A pointer to the start of x data + * \param freq A vector of float indicates the thread-local rope frequency + * \param offset A integer indicates the offset of the position in RoPE + */ +template +__device__ __forceinline__ vec_t vec_apply_llama_rope_interleave( + const T* x, const vec_t& freq, int32_t offset) { + vec_t vec, vec_before; + vec.cast_load(x + threadIdx.x * vec_size); + vec_before = vec; + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(offset) * freq[i]; + float cos, sin; + __sincosf(embed, &sin, &cos); + vec[i] = vec[i] * cos + ((i % 2 == 0) ? 
-vec_before[i ^ 1] : vec_before[i ^ 1]) * sin; + } + return vec; +} + +template __global__ void BatchQKApplyRotaryInPlaceKernel( DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, @@ -105,8 +134,13 @@ __global__ void BatchQKApplyRotaryInPlaceKernel( vec_t freq; #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { - freq[i] = - __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); + if constexpr (interleave) { + freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim)); + } else { + freq[i] = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); + } + float smooth = freq[i] * smooth_a + smooth_b; smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; @@ -124,7 +158,12 @@ __global__ void BatchQKApplyRotaryInPlaceKernel( if (i * bdy + ty < seq_len) { DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, q_stride_n, q_stride_h); - q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); + if constexpr (interleave) { + q_vec = + vec_apply_llama_rope_interleave(q_ptr, freq, offset + i * bdy + ty); + } else { + q_vec = vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty); + } q_vec.cast_store(q_ptr + tx * vec_size); } } @@ -140,50 +179,67 @@ __global__ void BatchQKApplyRotaryInPlaceKernel( if (i * bdy + ty < seq_len) { DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, k_stride_n, k_stride_h); - k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); + if constexpr (interleave) { + k_vec = + vec_apply_llama_rope_interleave(k_ptr, freq, offset + i * bdy + ty); + } else { + k_vec = vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty); + } k_vec.cast_store(k_ptr + tx * vec_size); } } } } +#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) 
\ + if (interleave) { \ + const bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + const bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } + template cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - float rope_scale, float rope_theta, + bool interleave, float rope_scale, float rope_theta, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = 0.f; - float smooth_b = 1.f; - - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - uint32_t num_threads = std::max(128U, bdx); - uint32_t bdy = num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); - dim3 nthrs(bdx, bdy); - auto kernel = BatchQKApplyRotaryInPlaceKernel; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&indptr, - (void*)&offsets, - (void*)&batch_size, - (void*)&num_qo_heads, - (void*)&num_kv_heads, - (void*)&q_stride_n, - (void*)&q_stride_h, - (void*)&k_stride_n, - (void*)&k_stride_h, - (void*)&smooth_a, - (void*)&smooth_b, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + float smooth_b = 0.f; + + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = + BatchQKApplyRotaryInPlaceKernel; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); }); return cudaSuccess; @@ -194,37 +250,40 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr, IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, - float rope_scale, float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length, cudaStream_t stream = nullptr) { + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - uint32_t num_threads = std::max(128U, bdx); - uint32_t bdy = 
num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); - dim3 nthrs(bdx, bdy); - auto kernel = BatchQKApplyRotaryInPlaceKernel; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&indptr, - (void*)&offsets, - (void*)&batch_size, - (void*)&num_qo_heads, - (void*)&num_kv_heads, - (void*)&q_stride_n, - (void*)&q_stride_h, - (void*)&k_stride_n, - (void*)&k_stride_h, - (void*)&smooth_a, - (void*)&smooth_b, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = + BatchQKApplyRotaryInPlaceKernel; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + }); }); return cudaSuccess; diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index cf309aff..32617c69 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -76,11 +76,11 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps); void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, float rope_scale, float rope_theta); + torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, float old_context_length); torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index eff0dfd7..4bed69bf 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -21,7 +21,8 @@ using namespace flashinfer; void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, float rope_scale, float rope_theta) { + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -53,7 +54,7 @@ void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, rope_scale, rope_theta, torch_current_stream); + k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " + std::string(cudaGetErrorString(status))); return true; @@ -61,8 +62,8 @@ 
void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, } void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, float old_context_length) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous @@ -95,8 +96,8 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_h, q_stride_h, k_stride_n, - k_stride_h, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, - torch_current_stream); + k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, + old_context_length, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " + std::string(cudaGetErrorString(status))); return true; diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 49304ffa..791fd023 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -35,6 +35,7 @@ def apply_rope_inplace( k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + interleave: bool = True, rope_scale: float = 1, rope_theta: float = 1e4, ) -> None: @@ -59,12 +60,16 @@ def apply_rope_inplace( Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. + interleave : bool + Whether to use interleave layout in the last dimension, default: ``True``. rope_scale : float The scaling factor used in the rope embedding, default: ``1``. rope_theta : float The theta value used in the rope embedding, default: ``1e4``. """ - return _kernels.apply_rope_inplace(q, k, indptr, offsets, rope_scale, rope_theta) + return _kernels.apply_rope_inplace( + q, k, indptr, offsets, interleave, rope_scale, rope_theta + ) def apply_llama31_rope_inplace( @@ -72,6 +77,7 @@ def apply_llama31_rope_inplace( k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, + interleave: bool = True, rope_scale: float = 8, rope_theta: float = 5e5, low_freq_factor: float = 1, @@ -100,6 +106,8 @@ def apply_llama31_rope_inplace( Indptr tensor, shape: ``(batch_size + 1)``. offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. + interleave : bool + Whether to use interleave layout in the last dimension, default: ``True``. rope_scale : float The scaling factor used in the rope embedding, default: ``8``. rope_theta : float @@ -116,6 +124,7 @@ def apply_llama31_rope_inplace( k, indptr, offsets, + interleave, rope_scale, rope_theta, low_freq_factor, diff --git a/python/tests/rope_reference.py b/python/tests/rope_reference.py new file mode 100644 index 00000000..4b1daa07 --- /dev/null +++ b/python/tests/rope_reference.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. +import torch +import math +from typing import Tuple + + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False +): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) diff --git a/python/tests/test_rope.py b/python/tests/test_rope.py new file mode 100644 index 00000000..8316c499 --- /dev/null +++ b/python/tests/test_rope.py @@ -0,0 +1,134 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import torch +import numpy as np +import flashinfer +import pytest +from rope_reference import * + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama_rope( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + # reference implementation + freqs_cis = precompute_freqs_cis( + head_dim, qkv_len + offset, 10000.0, use_scaled=False + ).to("cuda:0") + q_rope, k_rope = apply_rotary_emb( + q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), + k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), + freqs_cis[offset : offset + qkv_len], + ) + q_rope = q_rope.reshape(nnz, num_qo_heads, head_dim) + k_rope = k_rope.reshape(nnz, num_kv_heads, head_dim) + + # flashinfer implementation + flashinfer.apply_rope_inplace(q, k, indptr, offsets, rope_theta=1e4) + + # compare + np.testing.assert_allclose( + q_rope.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + np.testing.assert_allclose( + k_rope.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) +@pytest.mark.parametrize("num_qo_heads", [8, 16]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("offset", [0, 15, 99]) +@pytest.mark.parametrize("head_dim", [64, 128, 256]) +def test_llama31_rope( + batch_size, + qkv_len, + num_qo_heads, + num_kv_heads, + offset, + head_dim, +): + nnz = batch_size * qkv_len + qkv_packed = torch.randn( + nnz, + (num_qo_heads + 2 * num_kv_heads) * head_dim, + dtype=torch.float16, + device="cuda:0", + ) + q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) + k = qkv_packed[ + :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim + ].reshape(nnz, num_kv_heads, head_dim) + indptr = torch.tensor( + [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") + + # reference implementation + freqs_cis = precompute_freqs_cis( + head_dim, qkv_len + offset, 5e5, use_scaled=True + ).to("cuda:0") + q_rope, k_rope = apply_rotary_emb( + q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), + k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), + freqs_cis[offset : offset + qkv_len], + ) + q_rope = q_rope.reshape(nnz, num_qo_heads, head_dim) + k_rope = k_rope.reshape(nnz, num_kv_heads, head_dim) + + # flashinfer implementation + flashinfer.apply_llama31_rope_inplace(q, k, indptr, offsets, rope_theta=5e5) + + # compare + np.testing.assert_allclose( + q_rope.cpu().numpy(), q.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + np.testing.assert_allclose( + 
k_rope.cpu().numpy(), k.cpu().numpy(), rtol=1e-3, atol=1e-3 + ) + + +if __name__ == "__main__": + test_llama_rope(2, 1, 8, 8, 1, 128) + test_llama31_rope(1, 1, 8, 8, 0, 128) From db129ee89ffc1ed8f994dd3a1be1e85ec1b6426a Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 10:33:45 +0000 Subject: [PATCH 12/13] bugfix again --- python/csrc/rope.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu index 4bed69bf..7fb9f483 100644 --- a/python/csrc/rope.cu +++ b/python/csrc/rope.cu @@ -95,7 +95,7 @@ void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor cudaError_t status = BatchQKApplyLlama31RotaryInPlace( static_cast(q.data_ptr()), static_cast(k.data_ptr()), static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), - batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_h, q_stride_h, k_stride_n, + batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " + From da8e2dc25cfcb84b0d35535473efe75fe09881da Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 27 Jul 2024 10:37:14 +0000 Subject: [PATCH 13/13] upd --- python/flashinfer/rope.py | 24 ++++++++++++++++++++---- python/tests/test_rope.py | 8 ++++++-- src/tvm_wrapper.cu | 2 +- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py index 791fd023..6bb67eb8 100644 --- a/python/flashinfer/rope.py +++ b/python/flashinfer/rope.py @@ -35,7 +35,7 @@ def apply_rope_inplace( k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, - interleave: bool = True, + interleave: bool = False, rope_scale: float = 1, rope_theta: float = 1e4, ) -> None: @@ -61,7 +61,15 @@ def apply_rope_inplace( offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. interleave : bool - Whether to use interleave layout in the last dimension, default: ``True``. + Whether to use interleaved layout in the last dimension, default: ``False``. + + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + rope_scale : float The scaling factor used in the rope embedding, default: ``1``. rope_theta : float @@ -107,7 +115,15 @@ def apply_llama31_rope_inplace( offsets : torch.Tensor The relative position offsets of each query in the batch, shape: ``(batch_size)``. interleave : bool - Whether to use interleave layout in the last dimension, default: ``True``. + Whether to use interleaved layout in the last dimension, default: ``False``. + + * If ``True``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + + * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e., + we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + rope_scale : float The scaling factor used in the rope embedding, default: ``8``. 
rope_theta : float @@ -117,7 +133,7 @@ def apply_llama31_rope_inplace( high_freq_factor : float The high frequency factor used in Llama 3.1 RoPE, default: ``4``. old_context_len : int - The old context length used in Llama 3.1 RoPE, default. + The old context length used in Llama 3.1 RoPE, default: ``8192``. """ return _kernels.apply_llama31_rope_inplace( q, diff --git a/python/tests/test_rope.py b/python/tests/test_rope.py index 8316c499..e0676126 100644 --- a/python/tests/test_rope.py +++ b/python/tests/test_rope.py @@ -64,7 +64,9 @@ def test_llama_rope( k_rope = k_rope.reshape(nnz, num_kv_heads, head_dim) # flashinfer implementation - flashinfer.apply_rope_inplace(q, k, indptr, offsets, rope_theta=1e4) + flashinfer.apply_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) # compare np.testing.assert_allclose( @@ -118,7 +120,9 @@ def test_llama31_rope( k_rope = k_rope.reshape(nnz, num_kv_heads, head_dim) # flashinfer implementation - flashinfer.apply_llama31_rope_inplace(q, k, indptr, offsets, rope_theta=5e5) + flashinfer.apply_llama31_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=5e5 + ) # compare np.testing.assert_allclose( diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 4a5b9767..809fb585 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -691,7 +691,7 @@ void _FlashInferBatchQKApplyRotaryInPlace(DLTensor* q, DLTensor* k, DLTensor* in static_cast(q->data), static_cast(k->data), static_cast(indptr->data), static_cast(offsets->data), batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, - rope_scale, rope_theta); + /*interleave=*/false, rope_scale, rope_theta); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); }
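
A minimal usage sketch of the Python API introduced by this series, assuming the shapes and defaults shown in the docstrings and tests above (the segment lengths, offsets, and sizes below are illustrative only, not part of the patch):

    import torch
    import flashinfer

    # Two variable-length sequences packed into one ragged batch (nnz = 3 + 5).
    nnz, num_qo_heads, num_kv_heads, head_dim = 8, 8, 8, 128
    q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
    indptr = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda:0")   # segment boundaries, shape (batch_size + 1)
    offsets = torch.tensor([10, 0], dtype=torch.int32, device="cuda:0")    # per-request position offsets, shape (batch_size)

    # Standard RoPE applied to q/k in place (non-interleaved layout by default).
    flashinfer.apply_rope_inplace(q, k, indptr, offsets, rope_theta=1e4)

    # Llama 3.1 style RoPE with frequency smoothing, also in place.
    flashinfer.apply_llama31_rope_inplace(
        q, k, indptr, offsets, rope_theta=5e5,
        low_freq_factor=1, high_freq_factor=4, old_context_len=8192,
    )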