add m16n8k32 testcase
yzh119 committed Feb 22, 2022
1 parent c6d029f commit 0cb15a3
Showing 1 changed file with 99 additions and 3 deletions.
102 changes: 99 additions & 3 deletions tests/python/unittest/test_tir_ptx_mma_sp.py
@@ -49,13 +49,29 @@ def get_meta_m16n8k16_half(mask):

    for i in range(8):
        base = 1
-        for k in range(2):
+        for blk in range(2):
            for j in range(8):
-                ret[i] |= int(mask[k * 8 + i, j]) * base
+                ret[i] |= int(mask[blk * 8 + i, j]) * base
                base = base << 2
    return ret


def get_meta_m16n8k32_half(mask):
    assert mask.shape == (16, 8, 2)
    mask = mask.reshape(16, 2, 8)
    ret = np.zeros((8, 2)).astype("uint32")

    for i in range(8):
        for k in range(2):
            base = 1
            for blk in range(2):
                for j in range(8):
                    ret[i, k] |= int(mask[blk * 8 + i, k, j]) * base
                    base = base << 2

    return ret.reshape(16)
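
# How the packing above works: each kept-element index from the 2:4 mask takes
# two bits; the indices for rows i and i + 8 of one 16-column half are packed
# into a single uint32, giving 16 words for the full 16x32 sparse operand.
# A minimal usage sketch, assuming the gen_2in4_mask helper defined earlier in
# this file:
#
#     mask = gen_2in4_mask(16, 32)          # (16, 8, 2): two kept indices per 4-wide group
#     meta = get_meta_m16n8k32_half(mask)   # 16 uint32 words fed to the kernel as metadata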


@T.prim_func
def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@@ -109,6 +125,59 @@ def mma_sp_m16n8k16_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
        C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i)


@T.prim_func
def mma_sp_m16n8k32_fp16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 16], dtype="float16")
    B = T.match_buffer(b, [32, 8], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float16")
    metadata = T.match_buffer(_metadata, [16], dtype="uint32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    multi_a = T.allocate([8], "float16", scope="local")
    multi_b = T.allocate([8], "float16", scope="local")
    accum = T.allocate([4], "float16", scope="local")
    meta_local = T.allocate([1], "uint32", scope="local")
    for i in range(4):
        accum[i] = T.float16(0)

    for i in range(8):
        multi_a[i] = A[(i % 4) // 2 * 8 + tx // 4, i // 4 * 8 + tx % 4 * 2 + i % 2]

    for i in range(8):
        multi_b[i] = B[i // 2 * 8 + tx % 4 * 2 + i % 2, tx // 4]

    meta_local[0] = metadata[tx // 4 * 2 + tx % 2]

    T.evaluate(
        T.ptx_mma_sp(
            "m16n8k32",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp16",
            multi_a,
            0,
            multi_b,
            0,
            accum,
            0,
            meta_local,
            0,
            False,
            dtype="float16",
        )
    )

    for i in range(4):
        C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = T.load("float16", accum, i)
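
# Per-thread fragment sizes for the sparse m16n8k32 fp16 case, as reflected in
# the allocations above (a rough accounting over one warp of 32 threads):
#   multi_a: 16 x 16 compressed A values / 32 threads = 8 fp16 per thread
#   multi_b: 32 x 8  B values            / 32 threads = 8 fp16 per thread
#   accum:   16 x 8  C values            / 32 threads = 4 fp16 per thread
#   meta_local: one uint32 of 2:4 metadata, chosen by the thread index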


@tvm.testing.requires_cuda
def test_mma_sp_m16n8k16_fp16():
    sch = tvm.tir.Schedule(mma_sp_m16n8k16_fp16)
@@ -118,7 +187,6 @@ def test_mma_sp_m16n8k16_fp16():
        # Requires SM80+
        return
    cuda_mod = tvm.build(sch.mod, target="cuda")
-    print(cuda_mod.imported_modules[0].get_source())

    A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
    B_np = np.random.uniform(-1, 1, [16, 8]).astype("float16")
@@ -137,5 +205,33 @@ def test_mma_sp_m16n8k16_fp16():
    tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3)


@tvm.testing.requires_cuda
def test_mma_sp_m16n8k32_fp16():
    sch = tvm.tir.Schedule(mma_sp_m16n8k32_fp16)
    arch = tvm.contrib.nvcc.get_target_compute_version()
    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
    if major < 8:
        # Requires SM80+
        return
    cuda_mod = tvm.build(sch.mod, target="cuda")

    A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16")
    B_np = np.random.uniform(-1, 1, [32, 8]).astype("float16")
    mask = gen_2in4_mask(16, 32)
    A_dense_np = get_dense_mat_by_mask(A_np, mask)
    C_np = np.matmul(A_dense_np, B_np)
    meta = get_meta_m16n8k32_half(mask)

    ctx = tvm.cuda()
    A_tvm = tvm.nd.array(A_np, ctx)
    B_tvm = tvm.nd.array(B_np, ctx)
    C_tvm = tvm.nd.array(np.zeros_like(C_np), ctx)
    meta_tvm = tvm.nd.array(meta, ctx)
    cuda_mod(A_tvm, B_tvm, C_tvm, meta_tvm)

    tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
    test_mma_sp_m16n8k16_fp16()
    test_mma_sp_m16n8k32_fp16()
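
Besides running the file directly, the new case can be selected on its own with
pytest (assuming a CUDA-enabled TVM build and an SM80+ device):

    pytest tests/python/unittest/test_tir_ptx_mma_sp.py -k m16n8k32_fp16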
