diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
index 587293cd197..02ecfa37574 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
@@ -229,69 +229,64 @@ void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
                          const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                          const c10::optional& bias) {
   uint32_t const m = a.size(0);
-  // uint32_t const mp2 =
-  //     std::max(static_cast(32), next_pow_2(m));  // next power of 2
-  uint32_t const mp2 = next_pow_2(m);  // next power of 2
-
   uint32_t const n = out.size(1);
-  uint32_t const np2 = next_pow_2(n);
 
-  if (mp2 == 1) {
-    if (np2 <= 8192) {
+  if (m == 1) {
+    if (n <= 8192) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
     } else {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
     }
-  } else if (mp2 <= 16) {
+  } else if (m <= 16) {
     // M in (1, 16]
-    if (np2 <= 8192) {
+    if (n <= 8192) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
-    } else if (np2 <= 16384) {
+    } else if (n <= 16384) {
      return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
     } else {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
     }
-  } else if (mp2 <= 64) {
+  } else if (m <= 64) {
     // M in (16, 64]
-    if (np2 <= 16384) {
+    if (n <= 16384) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
     } else {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
     }
-  } else if (mp2 <= 128) {
+  } else if (m <= 128) {
     // M in (64, 128]
-    if (np2 <= 8192) {
+    if (n <= 8192) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
-    } else if (np2 <= 16384) {
+    } else if (n <= 16384) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
     } else {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
     }
-  } else if (mp2 <= 256) {
+  } else if (m <= 256) {
     // M in (128, 256]
-    if (np2 <= 8192) {
+    if (n <= 8192) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
-    } else if (np2 <= 16384) {
+    } else if (n <= 16384) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
     } else {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
     }
-  } else if (mp2 <= 512) {
+  } else if (m <= 512) {
     // M in (256, 512)
-    if (np2 <= 16384) {
+    if (n <= 16384) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
     } else {
@@ -300,7 +295,7 @@ void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
     }
   } else {
     // M in (512, inf)
-    if (np2 <= 8192) {
+    if (n <= 8192) {
       return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
     } else {
@@ -417,7 +412,7 @@ struct DeviceGemmFp8RowwiseSm90 {
           TileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(
               sizeof(typename CollectiveEpilogue::SharedStorage))>,
-          MainloopScheduleType>::CollectiveOp;
+      MainloopScheduleType>::CollectiveOp;
 
   using GemmKernel = cutlass::gemm::kernel::GemmUniversal,  // Indicates ProblemShape
                                                             CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
@@ -451,7 +446,6 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
   ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr());
   ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr());
-  // TODO: confirm correctess
   StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
   StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
   StrideC stride_c;
@@ -510,34 +504,27 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
   TORCH_CHECK(status == cutlass::Status::kSuccess)
 }
 
-template 
-void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
-                        const torch::Tensor& scales_a, const torch::Tensor& scales_b,
-                        const c10::optional& bias,
-                        bool fast_accum = true,
-                        bool use_persistent = false) {
-  using ElementInput = cutlass::float_e4m3_t;
-  using ElementOutput = OutType;
-  using AccumElementType = float;
-  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
-
-  if (bias) {
-    using Gemm = typename DeviceGemmFp8RowwiseSm90::Gemm;
-    return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias);
-  } else {
-    using Gemm = typename DeviceGemmFp8RowwiseSm90::Gemm;
-    return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias);
-  }
+template 
+void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
+                        const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+                        const c10::optional& bias, bool fast_accum = true, bool use_persistent = false) {
+  using ElementInput = cutlass::float_e4m3_t;
+  using ElementOutput = OutType;
+  using AccumElementType = float;
+  using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
+
+  if (bias) {
+    using Gemm =
+        typename DeviceGemmFp8RowwiseSm90::Gemm;
+    return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias);
+  } else {
+    using Gemm =
+        typename DeviceGemmFp8RowwiseSm90::Gemm;
+    return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias);
+  }
 }
 
 template 
@@ -545,25 +532,30 @@ void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
                          const torch::Tensor& scales_a, const torch::Tensor& scales_b,
                          const c10::optional& bias) {
   uint32_t const m = a.size(0);
-  uint32_t const mp2 = std::max(static_cast(64), next_pow_2(m));  // next power of 2
-using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
-using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
-using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
-using BasicTileScheduler = void;
-  if (mp2 <= 1) {
-    return sm90_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler, BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
-  } if (mp2 <= 64) {
+  using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+  using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
+  using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
+  using BasicTileScheduler = void;
+  if (m <= 1) {
+    return sm90_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler,
+                               BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
+  }
+  if (m <= 64) {
     // m in [1, 64]
-    return sm90_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler, PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
-  } else if (mp2 <= 256) {
+    return sm90_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler,
+                               PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
+  } else if (m <= 256) {
     // m in (64, 256]
-    return sm90_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
-  } else if (mp2 <= 1024) {
+    return sm90_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler,
+                               PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
+  } else if (m <= 1024) {
     // m in (256, 1024]
-    return sm90_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
+    return sm90_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler,
+                               PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
   } else {
     // m in (1024, inf)
-    return sm90_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler, PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
+    return sm90_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler,
+                               PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
   }
 }
 #endif
diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.h b/sgl-kernel/src/sgl-kernel/csrc/utils.h
index 5820b1350ab..2fed2d60c03 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/utils.h
+++ b/sgl-kernel/src/sgl-kernel/csrc/utils.h
@@ -44,8 +44,3 @@ inline int getSMVersion() {
   CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
   return sm_major * 10 + sm_minor;
 }
-
-inline uint32_t next_pow_2(uint32_t const num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py
index b55bd089a4d..1a731865944 100644
--- a/sgl-kernel/tests/test_fp8_gemm.py
+++ b/sgl-kernel/tests/test_fp8_gemm.py
@@ -2,7 +2,6 @@
 import torch
 
 from sgl_kernel import fp8_scaled_mm
-from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
 
 
 def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
@@ -20,23 +19,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
 
 
 class TestFp8Gemm(unittest.TestCase):
     def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
-        a = torch.randn((M, K), device=device)
-        b = torch.randn((N, K), device=device)
+        fp8_info = torch.finfo(torch.float8_e4m3fn)
+        fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
-        scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) * 0.001
-        scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) * 0.001
+        a_fp32 = (
+            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+        )
+        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        b_fp32 = (
+            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
+        )
+        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
+        scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
         if with_bias:
-            bias = torch.randn((N,), device="cuda", dtype=out_dtype)
+            bias = torch.randn((N,), device=device, dtype=out_dtype)
         else:
             bias = None
-        o1 = torch.empty((a.shape[0], b.shape[1]), device="cuda", dtype=torch.bfloat16)
-        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
+        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
         b_fp8 = b_fp8.t()
-        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-        o = torch_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias)
-        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias)
+        o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
+        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
         rtol = 0.02
-        atol = 2
+        atol = 1
         torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
         print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
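Note, appended and not part of the patch: dropping the next_pow_2 rounding from sm89_dispatch_shape and sm90_dispatch_shape does not change which branch is taken, because every M/N threshold in those heuristics (1, 16, 64, 128, 256, 512, 1024, 8192, 16384) is itself a power of two, and next_pow_2(x) <= t is equivalent to x <= t whenever t is a power of two. A minimal standalone Python sketch of that argument, using a bit_length-based stand-in for the removed C++ helper:

    def next_pow_2(x: int) -> int:
        # Smallest power of two >= x, mirroring the removed utils.h helper for x >= 1.
        return 1 if x <= 1 else 1 << (x - 1).bit_length()

    # All dispatch thresholds are powers of two, so comparing the raw dimension
    # selects the same branch as comparing its rounded-up value.
    for t in (1, 16, 64, 128, 256, 512, 1024, 8192, 16384):
        for x in range(1, 2 * t + 2):
            assert (next_pow_2(x) <= t) == (x <= t)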
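Also appended for context, likewise not part of the patch: with the vllm scaled_fp8_quant dependency removed, the test builds its FP8 operands directly (uniform float32 values clamped to the float8_e4m3fn range, then cast) and passes the per-row scale_a of shape (M,) and per-column scale_b of shape (N,) straight to both torch_scaled_mm and fp8_scaled_mm. Assuming the usual row-wise scaling convention for this kernel, the comparison reduces to roughly the following reference; the helper name is hypothetical and not taken from the test:

    import torch

    def rowwise_scaled_mm_reference(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias=None):
        # a_fp8: (M, K) float8_e4m3fn; b_fp8: (K, N) float8_e4m3fn (already transposed);
        # scale_a: (M,) float32; scale_b: (N,) float32; bias: optional (N,) in out_dtype.
        o = torch.matmul(a_fp8.to(torch.float32), b_fp8.to(torch.float32))
        o = o * scale_a.view(-1, 1) * scale_b.view(1, -1)
        if bias is not None:
            o = o + bias.to(torch.float32)
        return o.to(out_dtype)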