cleanup sgl-kernel kernels #3175

Merged · 7 commits · Jan 27, 2025
1 change: 1 addition & 0 deletions .github/workflows/pr-test.yml
@@ -51,6 +51,7 @@ jobs:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false
     runs-on: 1-gpu-runner
     strategy:
+      fail-fast: false
       matrix:
         range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100]
     steps:
2 changes: 1 addition & 1 deletion sgl-kernel/setup.py
@@ -88,7 +88,7 @@ def _get_version():
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
"3rdparty/flashinfer/csrc/activation.cu",
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
"3rdparty/flashinfer/csrc/norm.cu",
2 changes: 0 additions & 2 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -17,7 +17,6 @@
     moe_align_block_size,
     register_graph_buffers,
     rmsnorm,
-    rotary_embedding,
     sampling_scaling_penalties,
     silu_and_mul,
     top_k_renorm_prob,
@@ -44,7 +43,6 @@
"moe_align_block_size",
"register_graph_buffers",
"rmsnorm",
"rotary_embedding",
"sampling_scaling_penalties",
"silu_and_mul",
"top_k_renorm_prob",
140 changes: 140 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu
@@ -0,0 +1,140 @@
// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh
// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu
// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0

#include <ATen/cuda/CUDAContext.h>

#include <flashinfer/math.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include <numeric>

#include "utils.h"

using namespace flashinfer;

template <uint32_t VEC_SIZE, typename T>
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight,
                                      const uint32_t d, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
  const uint32_t num_warps = blockDim.y;
  const uint32_t thread_id = tx + ty * warp_size;
  const uint32_t num_threads = num_warps * warp_size;
  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
  extern __shared__ float smem[];

  float sum_sq = 0.f;

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    input_vec.fill(0.f);
    vec_t<T, VEC_SIZE> residual_vec;
    residual_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      float x = float(input_vec[j]);
      x += float(residual_vec[j]);
      sum_sq += x * x;
      residual_vec[j] = (T)x;
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }

  // first, warp reduce sum
#pragma unroll
  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
    sum_sq += math::shfl_xor_sync(sum_sq, offset);
  }

  smem[ty] = sum_sq;
  __syncthreads();
  // then, cross warp reduce sum using only the first warp
  if (ty == 0) {
    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
      sum_sq += math::shfl_xor_sync(sum_sq, offset);
    }
    smem[0] = sum_sq;
  }
  __syncthreads();

  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    vec_t<T, VEC_SIZE> weight_vec;
    vec_t<T, VEC_SIZE> residual_vec;
    input_vec.fill(0.f);
    weight_vec.fill(0.f);
    residual_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }
}

template <typename T>
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5,
                            cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
  void* args[] = {&input, &residual, &weight, &d, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });

  return cudaSuccess;
}

void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
  CHECK_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  auto device = input.device();
  CHECK_EQ(residual.device(), device);
  CHECK_EQ(weight.device(), device);
  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
  CHECK_DIM(1, weight);    // weight: (hidden_size)
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);

  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status =
        FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
                        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
    TORCH_CHECK(status == cudaSuccess,
                "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
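For readers skimming the diff: the new kernel fuses the residual add with RMSNorm and updates both tensors in place. Below is a minimal PyTorch sketch of the same math, for illustration only; it is not part of this PR, and the function name is made up here.

import torch

def fused_add_rmsnorm_reference(input: torch.Tensor, residual: torch.Tensor,
                                weight: torch.Tensor, eps: float = 1e-6) -> None:
    # Accumulate input + residual in float32, as the kernel does.
    x = input.float() + residual.float()
    # The kernel stores the sum back into `residual` (cast to its dtype) and
    # then normalizes the stored value; float32 here is a close stand-in.
    residual.copy_(x.to(residual.dtype))
    rms_rcp = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # `input` receives the RMS-normalized sum scaled by `weight`.
    input.copy_((x * rms_rcp * weight.float()).to(input.dtype))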
6 changes: 1 addition & 5 deletions sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
@@ -50,15 +50,11 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
                                 const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                 torch::Tensor new_kv);
 
-// rotary embedding
-void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
-                      torch::Tensor& cos_sin_cache, bool is_neox);
-
 // rms norm
 void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
 
 // fused rms norm
-void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
 
 // gemma rms norm
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
10 changes: 1 addition & 9 deletions sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -142,12 +142,6 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
     )


-def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
-    return torch.ops.sgl_kernels.rotary_embedding(
-        positions, query, key, head_size, cos_sin_cache, is_neox
-    )
-
-
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
 # Kudos to @yzh119
 def rmsnorm(
@@ -167,9 +161,7 @@ def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
     with input.device as device:
-        torch.ops.sgl_kernels.fused_add_rmsnorm(
-            input, residual, weight, eps, _get_cuda_stream(device)
-        )
+        torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)


 def gemma_rmsnorm(
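A minimal usage sketch of the updated wrapper, assuming the built wheel exposes it as sgl_kernel.fused_add_rmsnorm (illustration only):

import torch
from sgl_kernel import fused_add_rmsnorm  # assumed import path

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

# Updates x and residual in place. The cuda_stream argument is gone:
# the C++ side now queries the current stream via at::cuda::getCurrentCUDAStream().
fused_add_rmsnorm(x, residual, weight, eps=1e-6)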
10 changes: 2 additions & 8 deletions sgl-kernel/src/sgl-kernel/torch_extension.cc
@@ -45,19 +45,13 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
"new_kv) -> ()");
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);

// rotary embedding
m.def(
"rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool "
"is_neox) -> ()");
m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);

// rms norm
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);

// fused rms norm
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm);
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);

// gemma rms norm
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
2 changes: 1 addition & 1 deletion sgl-kernel/tests/test_norm.py
@@ -69,7 +69,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):

@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
eps = 1e-6

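The test body itself is collapsed in this view; below is a hedged sketch of how such a parametrized case is typically checked against a plain-PyTorch reference. The helper name, import path, and tolerances are hypothetical, not the PR's actual code.

import torch
from sgl_kernel import fused_add_rmsnorm  # assumed import path

def check_fused_add_rmsnorm(batch_size, hidden_size, dtype, eps=1e-6):
    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    # Float64 reference for both in-place outputs.
    ref_sum = x.double() + residual.double()
    ref_out = ref_sum * torch.rsqrt(ref_sum.pow(2).mean(-1, keepdim=True) + eps) * weight.double()

    fused_add_rmsnorm(x, residual, weight, eps)  # kernel under test, in place

    torch.testing.assert_close(residual.double(), ref_sum, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(x.double(), ref_out, rtol=2e-2, atol=2e-2)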