gemm split-k implementation #696

Draft · wants to merge 4 commits into base: main_perf
Changes from all commits
137 changes: 109 additions & 28 deletions python/perf-kernels/gemm.py
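For context on the diff below: split-K partitions the reduction (K) dimension across SPLIT_K programs, each of which computes a partial product over its slice of K; the partial results are then summed to recover the full GEMM. A minimal host-side sketch of the idea in plain PyTorch (function and variable names are illustrative only, not part of this PR; the kernel itself interleaves K blocks across programs rather than taking contiguous chunks, but the math is the same):

import torch

def splitk_reference(a, b, split_k=4):
    # a: (M, K), b: (K, N). Each slice of K yields one partial product, playing the
    # role of one pid_z in the split-K kernel.
    M, K = a.shape
    _, N = b.shape
    chunk = (K + split_k - 1) // split_k
    partials = [
        a[:, s * chunk:(s + 1) * chunk].float() @ b[s * chunk:(s + 1) * chunk, :].float()
        for s in range(split_k)
    ]
    # Summing over the split dimension recovers the full GEMM result.
    return torch.stack(partials, dim=0).sum(dim=0)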
@@ -10,30 +10,45 @@
@triton.autotune(
configs=[
triton.Config(
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0},
num_warps=8, num_stages=2),
{
'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'SPLIT_K': 1, 'GROUP_SIZE_M': 4,
'waves_per_eu': 0
}, num_warps=8, num_stages=2),
triton.Config(
{
'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'waves_per_eu': 2,
'kpack': 2, 'matrix_instr_nonkdim': 16
'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1, 'GROUP_SIZE_M': 8,
'waves_per_eu': 2, 'kpack': 2, 'matrix_instr_nonkdim': 16
}, num_warps=8, num_stages=2),
triton.Config(
{
'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1, 'waves_per_eu': 0,
'kpack': 1
'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 1, 'GROUP_SIZE_M': 1,
'waves_per_eu': 0, 'kpack': 1
}, num_warps=8, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'waves_per_eu': 0},
num_warps=8, num_stages=2),
{
'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1, 'GROUP_SIZE_M': 4,
'waves_per_eu': 0
}, num_warps=8, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
num_warps=4, num_stages=2),
{
'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'SPLIT_K': 1, 'GROUP_SIZE_M': 4,
'waves_per_eu': 2
}, num_warps=4, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
num_warps=8, num_stages=2),
{
'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1, 'GROUP_SIZE_M': 1,
'waves_per_eu': 2
}, num_warps=8, num_stages=2),
triton.Config(
{
'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'SPLIT_K': 1, 'GROUP_SIZE_M': 32,
'waves_per_eu': 2
}, num_warps=4, num_stages=2),
triton.Config(
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 32, 'waves_per_eu': 2},
num_warps=4, num_stages=2),
{
'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'SPLIT_K': 1, 'GROUP_SIZE_M': 1,
'waves_per_eu': 0, 'kpack': 2, 'matrix_instr_nonkdim': 16
}, num_warps=4, num_stages=2),
],
key=['M', 'N', 'K'],
use_cuda_graph=True,
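Every config added in this draft keeps SPLIT_K at 1, so the autotuner never selects the split-K path yet; a config that exercised it would set SPLIT_K > 1, for example (illustrative values, not part of this PR):

triton.Config(
    {
        'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'SPLIT_K': 4, 'GROUP_SIZE_M': 1,
        'waves_per_eu': 0
    }, num_warps=4, num_stages=2),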
@@ -46,6 +61,7 @@ def matmul_kernel(
a_ptr,
b_ptr,
c_ptr,
c_buf_ptr,
M,
N,
K,
@@ -61,6 +77,7 @@
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
APPLY_SCALE: tl.constexpr,
@@ -74,6 +91,7 @@
# This is done in a grouped ordering to promote L2 data reuse.
# TODO(vgokhale): Add XCD remapping.
pid = tl.program_id(axis=0)
pid_z = tl.program_id(1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
if GROUP_SIZE_M == 1:
@@ -88,7 +106,10 @@
pid_n = (pid % num_pid_in_group) // group_size_m

# Create pointers for first block of A and B input matrices
offs_k = tl.arange(0, BLOCK_SIZE_K)
if SPLIT_K == 1:
offs_k = tl.arange(0, BLOCK_SIZE_K)
else:
offs_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
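With SPLIT_K > 1, pid_z picks the starting K block for each program and the main loop below then strides by BLOCK_SIZE_K * SPLIT_K, so the programs interleave over the K blocks. A quick worked example (illustrative numbers, not from this PR):

K, BLOCK_SIZE_K, SPLIT_K = 8192, 64, 4  # assumed shapes for illustration
# Program pid_z = 1 starts at offset 1 * 64 = 64 and advances by 64 * 4 = 256 each iteration,
# visiting K offsets 64, 320, 576, ... for cdiv(8192, 256) = 32 of the 128 K blocks.
offsets = [1 * BLOCK_SIZE_K + i * BLOCK_SIZE_K * SPLIT_K for i in range(K // (BLOCK_SIZE_K * SPLIT_K))]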
@@ -100,20 +121,21 @@
acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
accumulator += tl.dot(a, b)

# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
# Apply scale to recover dynamic range reduced due to lower precision inputs.
if APPLY_SCALE:
accumulator = accumulator * a_scale * b_scale
@@ -128,7 +150,44 @@
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)

if SPLIT_K == 1:
tl.store(c_ptrs, c, mask=c_mask)
else:
c_buf_ptrs = c_buf_ptr + pid_z * M * N + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(c_buf_ptrs, accumulator, mask=c_mask)


@triton.jit
def splitK_reduce(c_ptr, c_buf_ptr, M, N, K, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr):

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

for k in range(SPLIT_K):
c_block_ptrs = c_buf_ptr + k * M * N + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
partial_result = tl.load(c_block_ptrs, mask=c_mask)
accumulator += partial_result

if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)

c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(c_ptrs, accumulator, mask=c_mask)
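splitK_reduce sums the SPLIT_K partial (M, N) slices that the GEMM kernel stored at offsets pid_z * M * N in c_buf and applies the activation once, after the reduction. The draft does not launch it yet (the host wrapper below reduces with torch.sum instead); a host-side launch might look like the following sketch, where the grid and tile sizes are assumptions, not part of the PR:

reduce_grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
splitK_reduce[reduce_grid](
    c, c_buf, M, N, K,
    c.stride(0), c.stride(1),
    BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64,  # illustrative tile sizes
    SPLIT_K=splitk, GROUP_SIZE_M=1, ACTIVATION=activation,
)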


# Activation function.
@@ -145,11 +204,14 @@ def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
assert a.dtype == b.dtype, "Mixed dtype GEMMs are not supported!!!"
M, K = a.shape
K, N = b.shape
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
splitk = 1
c_buf = torch.empty((M, N, splitk), device=a.device, dtype=torch.float32)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K'])
matmul_kernel[grid](
a,
b,
c,
c_buf,
M,
N,
K,
@@ -163,7 +225,10 @@ def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
b_scale,
APPLY_SCALE=scale_a8_b8,
ACTIVATION=activation,
# SPLIT_K = splitk,
)
if splitk > 1:
c.copy_(torch.sum(c_buf, dim=2))
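Note that the GEMM kernel writes partial results at pid_z * M * N, i.e. it treats c_buf as SPLIT_K stacked contiguous (M, N) slices, whereas this wrapper allocates (M, N, splitk) and sums over dim=2; the two only coincide while splitk == 1. A host-side reduction consistent with the kernel's indexing would be (sketch, assuming the splitk > 1 path is enabled):

c_buf = torch.empty((splitk, M, N), device=a.device, dtype=torch.float32)  # one (M, N) slice per split
# ... launch matmul_kernel as above ...
if splitk > 1:
    c.copy_(c_buf.sum(dim=0).to(c.dtype))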


name_to_torch_types = {
@@ -212,9 +277,11 @@ def gen_input(M, N, dtype, needTrans, seed, device='cuda'):


def get_x_vals():
x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range(1, 9)]
# x_vals = [(1024 * v, 1024 * v, 1024 * v) for v in range(1, 9)]

# x_vals += [(4864, 4096, 8192), (9728, 8192, 65536), (4864, 8192, 4160)]

x_vals += [(4864, 4096, 8192), (9728, 8192, 65536), (4864, 8192, 4160)]
x_vals = [(1, 8192, 28672)]

return x_vals

@@ -259,17 +326,29 @@ def get_type(provider):
return res[0][1:-1]


def ms_to_gibps(M: int, N: int, K: int, milliseconds: float) -> float:
read_elems: int = M * K + K * N
write_elems: int = M * N
transf_elems: int = read_elems + write_elems
transf_bytes: int = 2 * transf_elems # times 2 due to fp16
transf_gibibytes: float = 2**-30 * transf_bytes
seconds: float = 1e-3 * milliseconds
return round(transf_gibibytes / seconds, 2)
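As a quick sanity check with the one shape benchmarked below (M, N, K = 1, 8192, 28672): transf_elems = 1*28672 + 28672*8192 + 1*8192 ≈ 2.35e8, so transf_bytes ≈ 4.70e8 ≈ 0.44 GiB; at an assumed runtime of 0.5 ms this works out to roughly 875 GiB/s.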


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M', 'N', 'K'],
x_vals=get_x_vals(),
line_arg='provider',
line_vals=[
'rocblas(fp16)', 'rocblas(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)',
'triton(fp8e5)'
# 'rocblas(fp16)', 'rocblas(bf16)', 'triton(fp16)', 'triton(bf16)', 'triton(int8)', 'triton(fp8e4)',
# 'triton(fp8e5)'
'rocblas(fp16)', 'rocblas(bf16)', 'triton(fp16)'
],
line_names=[
"rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5"
# "rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16", "Triton.Bf16", "Triton.Int8", "Triton.Fp8E4", "Triton.Fp8E5"
"rocBLAS.Fp16", "rocBLAS.Bf16", "Triton.Fp16"
],
ylabel="TFLOPS",
plot_name="matmul-performance",
@@ -299,7 +378,9 @@ def benchmark(M, N, K, provider):
quantiles=quantiles)
global verbose
if verbose:
print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.best_config()})')
gbps = ms_to_gibps(M, N, K, ms)
# print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.best_config()})')
print(f'SIZE: {M},{N},{K}, gbps: {gbps}')
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
