From 5aa001bece45c63dd94c3154e931b0cad57b3346 Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Thu, 30 Jan 2025 02:25:49 -0800
Subject: [PATCH 1/4] upd

---
 sgl-kernel/3rdparty/cutlass | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass
index b78588d1630..bdd641790ad 160000
--- a/sgl-kernel/3rdparty/cutlass
+++ b/sgl-kernel/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit b78588d1630aa6643bf021613717bafb705df4ef
+Subproject commit bdd641790ad49353b40ada41330552a78d2f8b5a

From 10d0421cd13bc438a9092488ba3686d63dbebe71 Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Thu, 30 Jan 2025 02:45:06 -0800
Subject: [PATCH 2/4] upd

---
 sgl-kernel/3rdparty/flashinfer | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer
index 4f1f08989c7..e5a3befbe3e 160000
--- a/sgl-kernel/3rdparty/flashinfer
+++ b/sgl-kernel/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5
+Subproject commit e5a3befbe3e63025f0158bc96b218a9c5f402ac7

From f12f46dfb1344e787f78f81d62fe0a55ca5b78e1 Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Thu, 30 Jan 2025 03:13:40 -0800
Subject: [PATCH 3/4] upd

---
 .../csrc/fused_add_rms_norm_kernel.cu         | 113 +-----------------
 1 file changed, 4 insertions(+), 109 deletions(-)

diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu
index 4c4ecb966ee..f0f3a51744e 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu
@@ -1,116 +1,11 @@
-// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh
-// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu
-// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0
-
 #include <ATen/cuda/CUDAContext.h>
 
-#include <flashinfer/math.cuh>
-#include <flashinfer/utils.cuh>
-#include <flashinfer/vec_dtypes.cuh>
-#include <numeric>
+#include <flashinfer/norm.cuh>
 
 #include "utils.h"
 
 using namespace flashinfer;
 
-template <uint32_t VEC_SIZE, typename T>
-__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight,
-                                      const uint32_t d, float eps) {
-  const uint32_t bx = blockIdx.x;
-  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
-  constexpr uint32_t warp_size = 32;
-  const uint32_t num_warps = blockDim.y;
-  const uint32_t thread_id = tx + ty * warp_size;
-  const uint32_t num_threads = num_warps * warp_size;
-  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
-  extern __shared__ float smem[];
-
-  float sum_sq = 0.f;
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    input_vec.fill(0.f);
-    vec_t<T, VEC_SIZE> residual_vec;
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = float(input_vec[j]);
-      x += float(residual_vec[j]);
-      sum_sq += x * x;
-      residual_vec[j] = (T)x;
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-
-  // first, warp reduce sum
-#pragma unroll
-  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-    sum_sq += math::shfl_xor_sync(sum_sq, offset);
-  }
-
-  smem[ty] = sum_sq;
-  __syncthreads();
-  // then, cross warp reduce sum using only the first warp
-  if (ty == 0) {
-    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
-#pragma unroll
-    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
-      sum_sq += math::shfl_xor_sync(sum_sq, offset);
-    }
-    smem[0] = sum_sq;
-  }
-  __syncthreads();
-
-  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
-
-  for (uint32_t i = 0; i < rounds; i++) {
-    vec_t<T, VEC_SIZE> input_vec;
-    vec_t<T, VEC_SIZE> weight_vec;
-    vec_t<T, VEC_SIZE> residual_vec;
-    input_vec.fill(0.f);
-    weight_vec.fill(0.f);
-    residual_vec.fill(0.f);
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-#pragma unroll
-    for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
-    }
-    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
-      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
-    }
-  }
-}
-
-template <typename T>
-cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5,
-                            cudaStream_t stream = 0) {
-  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
-
-  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
-  const uint32_t num_warps = ceil_div(block_size, 32);
-  dim3 nblks(batch_size);
-  dim3 nthrs(32, num_warps);
-  const uint32_t smem_size = num_warps * sizeof(float);
-  void* args[] = {&input, &residual, &weight, &d, &eps};
-
-  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
-    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
-    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
-  });
-
-  return cudaSuccess;
-}
-
 void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
   CHECK_INPUT(input);
   CHECK_INPUT(residual);
@@ -130,9 +25,9 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
   // support float16, bfloat16 and float32
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
-    cudaError_t status =
-        FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
-                        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
+    cudaError_t status = norm::FusedAddRMSNorm(
+        static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
+        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
     TORCH_CHECK(status == cudaSuccess,
                 "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
     return true;

From 7fb53608fa459f985fd0a4b446d90047a0fb499a Mon Sep 17 00:00:00 2001
From: zhyncs
Date: Thu, 30 Jan 2025 03:14:27 -0800
Subject: [PATCH 4/4] upd

---
 sgl-kernel/pyproject.toml | 2 +-
 sgl-kernel/version.py     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index aca6f045054..bb7d6943348 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.3"
+version = "0.0.3.post1"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.9"
diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py
index 27fdca497c3..647733203b6 100644
--- a/sgl-kernel/version.py
+++ b/sgl-kernel/version.py
@@ -1 +1 @@
-__version__ = "0.0.3"
+__version__ = "0.0.3.post1"
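
For reference, a minimal PyTorch sketch of the semantics behind the sgl_fused_add_rmsnorm binding, i.e. what the removed FusedAddRMSNormKernel computed and what flashinfer's norm::FusedAddRMSNorm is switched in to provide: residual is updated in place with input + residual, and input is overwritten with the RMS-normalized, weight-scaled result. The helper name below is illustrative only and is not part of this patch; only the C++ binding signature above comes from the source.

import torch

def fused_add_rmsnorm_reference(input: torch.Tensor, residual: torch.Tensor,
                                weight: torch.Tensor, eps: float) -> None:
    # residual <- input + residual (the kernel does the add in float32 and
    # stores the sum back into the residual tensor in its own dtype)
    residual.add_(input)
    # mean of squares and rsqrt are computed in float32, as in the kernel
    x = residual.float()
    rms_rcp = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # input <- residual * rms_rcp * weight, cast back to the input dtype
    input.copy_((x * rms_rcp * weight.float()).to(input.dtype))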