Skip to content

Commit

Permalink
Merge pull request #6 from HandH1998/tmptmp
Browse files Browse the repository at this point in the history
clean code
  • Loading branch information
HandH1998 authored Jan 22, 2025
2 parents 38bcf52 + e620244 commit 699fe9e
Showing 1 changed file with 62 additions and 67 deletions.
129 changes: 62 additions & 67 deletions sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -240,69 +240,61 @@ void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (m <= 1) {
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 16) {
// M in (1, 16]
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 64) {
// M in (16, 64]
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 128) {
// M in (64, 128]
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 256) {
// M in (128, 256]
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 512) {
// M in (256, 512]
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias);
}
} else {
// M in (512, inf)
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
if (mp2 == 1) {
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 16) {
// M in (1, 16]
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 64) {
// M in (16, 64]
if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 128) {
// M in (64, 128]
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 256) {
// M in (128, 256]
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias);
}
} else if (mp2 <= 512) {
// M in (256, 512]
if (np2 <= 16384) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias);
}
} else {
// M in (512, inf)
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias);
} else {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias);
}
}
}
}
#endif

Expand Down Expand Up @@ -532,12 +524,15 @@ void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch

if (mp2 <= 64) {
// m in [1, 64]
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 128) {
// m in (64, 128]
return sm90_dispatch_bias<OutType, Shape<_64, _128, _128>, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias);
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
// m in (64, 256]
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 1024) {
// m in (256, 1024]
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>>(out, a, b, scales_a, scales_b, bias);
} else {
// m in (128, inf)
// m in (1024, inf)
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias);
}
}
Expand Down

0 comments on commit 699fe9e

Please sign in to comment.