diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu index 42dd059408d..3f26162c6a9 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -240,69 +240,61 @@ void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch uint32_t const n = out.size(1); uint32_t const np2 = next_pow_2(n); - if (m <= 1) { - if (np2 <= 8192) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); - } else if (np2 <= 16384) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); - } else { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); - } - } else if (mp2 <= 16) { - // M in (1, 16] - if (np2 <= 8192) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); - } else if (np2 <= 16384) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); - } else { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); - } - } else if (mp2 <= 64) { - // M in (16, 64] - if (np2 <= 8192) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); - } else if (np2 <= 16384) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); - } else { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); - } - } else if (mp2 <= 128) { - // M in (64, 128] - if (np2 <= 8192) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); - } else if (np2 <= 16384) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); - } else { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); - } - } else if (mp2 <= 256) { - // M in (128, 256] - if (np2 <= 8192) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias); - } else if (np2 <= 16384) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); - } else { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); - } - } else if (mp2 <= 512) { - // M in (256, 512) - if (np2 <= 8192) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); - } else if (np2 <= 16384) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); - } else { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); - } - } else { - // M in (512, inf) - if (np2 <= 8192) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); - } else if (np2 <= 16384) { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); - } else { - return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + if (mp2 == 1) { + if (np2 <= 8192) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 16) { + // M in (1, 16] + if (np2 <= 8192) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (np2 <= 16384) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 64) { + // M in (16, 64] + if (np2 <= 16384) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 128) { + // M in (64, 128] + if (np2 <= 8192) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (np2 <= 16384) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 256) { + // M in (128, 256] + if (np2 <= 8192) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else if (np2 <= 16384) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 512) { + // M in (256, 512) + if (np2 <= 16384) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else { + // M in (512, inf) + if (np2 <= 8192) { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + } } - } } #endif @@ -532,12 +524,15 @@ void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch if (mp2 <= 64) { // m in [1, 64] - return sm90_dispatch_bias, Shape<_1, _8, _1>>(out, a, b, scales_a, scales_b, bias); - } else if (mp2 <= 128) { - // m in (64, 128] - return sm90_dispatch_bias, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); + return sm90_dispatch_bias, Shape<_1, _4, _1>>(out, a, b, scales_a, scales_b, bias); + } else if (mp2 <= 256) { + // m in (64, 256] + return sm90_dispatch_bias, Shape<_1, _1, _1>>(out, a, b, scales_a, scales_b, bias); + } else if (mp2 <= 1024) { + // m in (256, 1024] + return sm90_dispatch_bias, Shape<_1, _1, _1>>(out, a, b, scales_a, scales_b, bias); } else { - // m in (128, inf) + // m in (1024, inf) return sm90_dispatch_bias, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); } }