Commit

fix review issues
HandH1998 committed Jan 24, 2025
1 parent b9980af commit cd51083
Showing 3 changed files with 76 additions and 82 deletions.
122 changes: 57 additions & 65 deletions sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
@@ -229,69 +229,64 @@ void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
uint32_t const m = a.size(0);
// uint32_t const mp2 =
// std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
uint32_t const mp2 = next_pow_2(m); // next power of 2

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (mp2 == 1) {
if (np2 <= 8192) {
if (m == 1) {
if (n <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>,
7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>,
5>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 16) {
} else if (m <= 16) {
// M in (1, 16]
if (np2 <= 8192) {
if (n <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>,
4>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
} else if (n <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>,
5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>,
7>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 64) {
} else if (m <= 64) {
// M in (16, 64]
if (np2 <= 16384) {
if (n <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>,
7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>,
7>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 128) {
} else if (m <= 128) {
// M in (64, 128]
if (np2 <= 8192) {
if (n <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 64, 64>,
4>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
} else if (n <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 64, 64>,
5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>,
5>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 256) {
} else if (m <= 256) {
// M in (128, 256]
if (np2 <= 8192) {
if (n <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>, cutlass::gemm::GemmShape<64, 32, 64>,
5>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
} else if (n <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>,
7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>,
4>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 512) {
} else if (m <= 512) {
// M in (256, 512)
if (np2 <= 16384) {
if (n <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>,
2>(out, a, b, scales_a, scales_b, bias);
} else {
@@ -300,7 +295,7 @@ void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
}
} else {
// M in (512, inf)
if (np2 <= 8192) {
if (n <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>,
3>(out, a, b, scales_a, scales_b, bias);
} else {
@@ -417,7 +412,7 @@ struct DeviceGemmFp8RowwiseSm90 {
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp;
MainloopScheduleType>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
@@ -451,7 +446,6 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::
ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast<ElementComputeEpilogue const*>(scales_a.data_ptr());
ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast<ElementComputeEpilogue const*>(scales_b.data_ptr());

// TODO: confirm correctess
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
StrideC stride_c;
@@ -510,60 +504,58 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
TORCH_CHECK(status == cutlass::Status::kSuccess)
}

template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType, typename TileSchedulerType>
void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias,
bool fast_accum = true,
bool use_persistent = false) {
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;

if (bias) {
using Gemm = typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType,
CTAShape, ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
} else {
using Gemm = typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType,
CTAShape, ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
}
template <typename OutType, typename CTAShape, typename ClusterShape, typename MainloopScheduleType,
typename TileSchedulerType>
void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias, bool fast_accum = true, bool use_persistent = false) {
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;
using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;

if (bias) {
using Gemm =
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, true>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
} else {
using Gemm =
typename DeviceGemmFp8RowwiseSm90<ElementInput, ElementOutput, AccumElementType, CTAShape, ClusterShape,
MainloopScheduleType, EpilogueScheduleType, TileSchedulerType, false>::Gemm;
return launch_sm90_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
}
}

template <typename OutType>
void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
uint32_t const m = a.size(0);
uint32_t const mp2 = std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
using BasicTileScheduler = void;
if (mp2 <= 1) {
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, FastBasicScheduler, BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
} if (mp2 <= 64) {
using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
using PersistentTileScheduler = cutlass::gemm::PersistentScheduler;
using BasicTileScheduler = void;
if (m <= 1) {
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>, FastBasicScheduler,
BasicTileScheduler>(out, a, b, scales_a, scales_b, bias);
}
if (m <= 64) {
// m in [1, 64]
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, FastPingpongScheduler, PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 256) {
// m in (64, 256]
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, FastPingpongScheduler, PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 1024) {
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else if (m <= 1024) {
// m in (256, 1024]
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, FastPingpongScheduler, PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
} else {
// m in (1024, inf)
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, FastPingpongScheduler, PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>, FastPingpongScheduler,
PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias);
}
}
#endif
5 changes: 0 additions & 5 deletions sgl-kernel/src/sgl-kernel/csrc/utils.h
@@ -44,8 +44,3 @@ inline int getSMVersion() {
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
return sm_major * 10 + sm_minor;
}

inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
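
A note on why the helper above can go: both dispatch functions in fp8_gemm_kernel.cu now branch on m and n directly instead of on their next power of two, so next_pow_2 has no remaining callers. Every dispatch threshold (16, 64, 128, 256, 512, 1024, 8192, 16384) is itself a power of two, so next_pow_2(m) <= T and m <= T select the same branch. The one visible behavioral difference is in the sm90 path, where the old std::max(64, next_pow_2(m)) clamp made its mp2 <= 1 branch unreachable, while the new m <= 1 check can actually reach the single-row configuration. The standalone sketch below is not part of the commit; it just reproduces the removed helper and checks the threshold equivalence.

#include <cassert>
#include <climits>
#include <cstdint>

// Reproduction of the removed helper (uses the GCC/Clang __builtin_clz builtin).
inline uint32_t next_pow_2(uint32_t num) {
  if (num <= 1) return num;
  return 1u << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

int main() {
  // For a power-of-two threshold T, "next_pow_2(m) <= T" and "m <= T" hold for
  // exactly the same m, so the simplified dispatch picks the same tile
  // configuration as before.
  const uint32_t thresholds[] = {16, 64, 128, 256, 512, 1024, 8192, 16384};
  for (uint32_t t : thresholds) {
    for (uint32_t m = 1; m <= 4 * t; ++m) {
      assert((m <= t) == (next_pow_2(m) <= t));
    }
  }
  return 0;
}
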
31 changes: 19 additions & 12 deletions sgl-kernel/tests/test_fp8_gemm.py
@@ -2,7 +2,6 @@

import torch
from sgl_kernel import fp8_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant


def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
@@ -20,23 +19,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):

class TestFp8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
a = torch.randn((M, K), device=device)
b = torch.randn((N, K), device=device)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min

scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) * 0.001
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) * 0.001
a_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

b_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
)
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
if with_bias:
bias = torch.randn((N,), device="cuda", dtype=out_dtype)
bias = torch.randn((N,), device=device, dtype=out_dtype)
else:
bias = None
o1 = torch.empty((a.shape[0], b.shape[1]), device="cuda", dtype=torch.bfloat16)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
b_fp8 = b_fp8.t()
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
o = torch_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias)
o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias)
o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
rtol = 0.02
atol = 2
atol = 1
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
