[Software pipeline] Fix hardcoded index in access_ptr rewriting, add a GPU test with depth 4 #11495

Merged: 8 commits, May 28, 2022
110 changes: 110 additions & 0 deletions python/tvm/testing/tir.py
@@ -59,3 +59,113 @@ def render(e)
    assert (
        expected_error_text in errors
    ), f'check_error expects "{expected_error_text}" in str(errors): {errors}'


def mma_schedule(
    workload,
    k_inner,
    in_dtype,
    b_transposed,
    i_factors,
    j_factors,
    k_factors,
    index_map_A,
    index_map_B,
    index_map_C,
    ldmatrix_a_intrin,
    ldmatrix_b_intrin,
    mma_intrin,
    mma_fill_intrin,
    mma_store_intrin,
    shared_scope="shared",
):
"""Create a tensorized schedule for GEMM with MMA intrinsics."""
ir_module = tvm.IRModule({"main": workload})
sch = tvm.tir.Schedule(ir_module)

block = sch.get_block("C")
i, j, k = sch.get_loops(block)
i, i_tc = sch.split(i, factors=[None, 16])
j, j_tc = sch.split(j, factors=[None, 16])
k, k_tc = sch.split(k, factors=[None, k_inner])

sch.reorder(i, j, k, i_tc, j_tc, k_tc)

block_inner = sch.blockize(i_tc)
block_outer, block_inner = block_inner, block

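    # i_factors[2] * j_factors[2] is the extent of the fused (j2, i2) loop
    # that gets bound to threadIdx.y below.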
    num_ty = i_factors[2] * j_factors[2]

    i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
    j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
    k0, k1, k2 = sch.split(k, k_factors)

    sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4)

    block_idx = sch.fuse(i0, j0)
    block_idy = sch.fuse(i1, j1)
    thread_idy = sch.fuse(j2, i2)
    sch.bind(block_idx, "blockIdx.x")
    sch.bind(block_idy, "blockIdx.y")
    sch.bind(thread_idy, "threadIdx.y")

    def fetch_to_shared(block, idx, ndim):
        block_read = sch.cache_read(block, idx, shared_scope)
        sch.compute_at(block_read, k0)
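        # 16-byte vectorized copies: 16 lanes of int8 or 8 lanes of fp16.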
        vector_size = 16 if in_dtype == "int8" else 8
        warp_size = 32
        fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
        _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
        sch.bind(f_2, "threadIdx.x")
        sch.bind(f_1, "threadIdx.y")
        sch.vectorize(f_3)
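        # Pad each row of the shared tile (factor 32, offset 8 or 16 elements)
        # to avoid shared-memory bank conflicts.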
        offset = 8 if in_dtype == "float16" else 16
        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)

        return block_read

    fetch_to_shared(block_outer, 0, 2)
    fetch_to_shared(block_outer, 1, 2)

    A_warp = sch.cache_read(block_outer, 0, "warp")
    B_warp = sch.cache_read(block_outer, 1, "warp")

    sch.compute_at(A_warp, k1)
    sch.compute_at(B_warp, k1)

    C_warp = sch.cache_write(block_outer, 0, "warp")
    sch.reverse_compute_at(C_warp, thread_idy)

    ii, jj = sch.get_loops(C_warp)[-2:]
    io, ii = sch.split(ii, factors=[None, 16])
    jo, ji = sch.split(jj, factors=[None, 16])
    sch.reorder(io, jo, ii, ji)

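    # Peel the accumulator initialization into a separate C_init block so it
    # can be tensorized with the mma_fill intrinsic below.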
    sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
    block_init_c = sch.get_block("C_init")

    def tile_wmma_fragment(block_read, height, width):
        i, j = sch.get_loops(block_read)[-2:]
        i0, i1 = sch.split(i, factors=[None, height])
        j0, j1 = sch.split(j, factors=[None, width])
        sch.reorder(i0, j0, i1, j1)
        return i1

    loop_a = tile_wmma_fragment(A_warp, 16, k_inner)

    if b_transposed:
        loop_b = tile_wmma_fragment(B_warp, 16, k_inner)
    else:
        loop_b = tile_wmma_fragment(B_warp, k_inner, 16)

    sch.transform_layout(A_warp, ("write", 0), index_map_A)
    sch.transform_layout(B_warp, ("write", 0), index_map_B)
    sch.transform_layout(C_warp, ("read", 0), index_map_C)

    sch.tensorize(loop_a, ldmatrix_a_intrin)
    sch.tensorize(loop_b, ldmatrix_b_intrin)
    sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin)
    sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
    sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)

    return sch
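For reference, a minimal sketch (not part of this diff) of driving mma_schedule end to end. The intrinsic names and shared_16x16_to_ldmatrix_32x8_layout come from tvm.tir.tensor_intrin.cuda; the workload, index map, and factor lists are illustrative assumptions, constrained only in that the factors must multiply out to M/16, N/16, and K/k_inner respectively.

import tvm
from tvm import te
from tvm.testing.tir import mma_schedule
from tvm.tir.tensor_intrin.cuda import (
    LDMATRIX_16x16_A_INTRIN,
    LDMATRIX_16x16_B_INTRIN,
    MMA_f16f16f32_INTRIN,
    MMA_fill_16x16_f32_INTRIN,
    MMA_store_16x16_f32_global_INTRIN,
    shared_16x16_to_ldmatrix_32x8_layout,
)


def fp16_matmul(M, N, K):
    # fp16 GEMM accumulating in fp32; the compute block is named "C", which
    # mma_schedule looks up via sch.get_block("C").
    a = te.placeholder((M, K), "float16", name="A")
    b = te.placeholder((K, N), "float16", name="B")
    k = te.reduce_axis((0, K), name="k")
    c = te.compute(
        (M, N),
        lambda i, j: te.sum(a[i, k].astype("float32") * b[k, j].astype("float32"), axis=k),
        name="C",
    )
    return te.create_prim_func([a, b, c])


def index_map(i, j):
    # Map a 16x16 tile coordinate onto the 32-thread x 8-element ldmatrix layout.
    return (i // 16, j // 16, *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16))


sch = mma_schedule(
    fp16_matmul(4096, 4096, 4096),
    16,  # k_inner
    "float16",
    False,  # b_transposed
    [4, 8, 2, 4, 1],  # i_factors: product * 16 == M
    [1, 64, 2, 1, 2],  # j_factors: product * 16 == N
    [128, 2, 1],  # k_factors: product * k_inner == K
    index_map,
    index_map,
    index_map,
    LDMATRIX_16x16_A_INTRIN,
    LDMATRIX_16x16_B_INTRIN,
    MMA_f16f16f32_INTRIN,
    MMA_fill_16x16_f32_INTRIN,
    MMA_store_16x16_f32_global_INTRIN,
)
f = tvm.build(sch.mod["main"], target="cuda")  # running it needs an Ampere or newer GPU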
21 changes: 18 additions & 3 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -54,7 +54,7 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
HALF_WARP_expr = lift(HALF_WARP)


-def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
+def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
    local_size = (M_DIM * k_dim) // WARP_SIZE
    shared_offset = None
    index_map = None
@@ -115,7 +115,12 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
    @T.prim_func
    def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
        shared = T.match_buffer(
-            shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared"
+            shared_handle,
+            shmem_shape,
+            dtype,
+            align=128,
+            offset_factor=16,
+            scope=shared_scope,
        )
        warp = T.match_buffer(
            warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
@@ -144,7 +149,7 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
            dtype,
            align=128,
            offset_factor=16,
-            scope="shared",
+            scope=shared_scope,
            strides=[s0, s1],
        )
        warp = T.match_buffer(
@@ -412,6 +417,16 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b"
TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False))

+LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn"
+TensorIntrin.register(
+    LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn")
+)
+
+LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn"
+TensorIntrin.register(
+    LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn")
+)
+
LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
TensorIntrin.register(
    LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True)
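A hedged sketch of how the new *_DYN variants pair with the shared_scope parameter that mma_schedule gained above, reusing fp16_matmul, index_map, and the factor lists from the earlier sketch: the cache_read staging scope and the scope baked into the ldmatrix descriptor must both be "shared.dyn" for tensorize to pattern-match.

from tvm.tir.tensor_intrin.cuda import (
    LDMATRIX_16x16_A_DYN_INTRIN,
    LDMATRIX_16x16_B_DYN_INTRIN,
    MMA_f16f16f32_INTRIN,
    MMA_fill_16x16_f32_INTRIN,
    MMA_store_16x16_f32_global_INTRIN,
)

# Same schedule as before, but A/B tiles are staged in dynamic shared memory.
sch = mma_schedule(
    fp16_matmul(4096, 4096, 4096),
    16,
    "float16",
    False,
    [4, 8, 2, 4, 1],
    [1, 64, 2, 1, 2],
    [128, 2, 1],
    index_map,
    index_map,
    index_map,
    LDMATRIX_16x16_A_DYN_INTRIN,
    LDMATRIX_16x16_B_DYN_INTRIN,
    MMA_f16f16f32_INTRIN,
    MMA_fill_16x16_f32_INTRIN,
    MMA_store_16x16_f32_global_INTRIN,
    shared_scope="shared.dyn",
)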
3 changes: 2 additions & 1 deletion src/tir/transforms/inject_software_pipeline.cc
@@ -135,7 +135,8 @@ class PipelineOpaqueAccessRewriter {
      } else {
        offset = new_buffer->strides[0];
      }
-      PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset;
+      PrimExpr new_index =
+          old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
      new_args.Set(2, new_index);
      return Call(call->dtype, call->op, new_args, call->span);
    }
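The one-line fix above removes a double-buffering assumption: opaque access_ptr indices used to rotate with floormod(loop_var, 2), which is only correct when the multi-versioned buffer holds exactly two versions. The versioned buffer's leading dimension, new_buffer->shape[0], is the actual version count, so taking the modulus by it also covers deeper pipelines such as the new depth-4 GPU test. A toy Python illustration (not TVM code) of the rotation:

def version_index(loop_var, num_versions):
    # Which version of the multi-buffered allocation this iteration touches.
    return loop_var % num_versions

print([version_index(i, 2) for i in range(8)])  # [0, 1, 0, 1, 0, 1, 0, 1]  old hardcoded modulus
print([version_index(i, 4) for i in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]  new: modulus = shape[0]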
106 changes: 18 additions & 88 deletions tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -42,6 +42,7 @@
)
import tvm.testing
import numpy as np
+from tvm.testing.tir import mma_schedule


M = 4096
@@ -98,94 +99,23 @@ def run_test(
    mma_fill_intrin,
    mma_store_intrin,
):
[88 deleted lines omitted: the inline schedule construction, moved into tvm.testing.tir.mma_schedule above with shared_scope parameterized instead of hardcoded to "shared"]
+    sch = mma_schedule(
+        te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)),
+        k_inner,
+        in_dtype,
+        b_transposed,
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map_A,
+        index_map_B,
+        index_map_C,
+        ldmatrix_a_intrin,
+        ldmatrix_b_intrin,
+        mma_intrin,
+        mma_fill_intrin,
+        mma_store_intrin,
+    )

    if not is_ampere_or_newer():
        return None