Skip to content

Commit

Permalink
[Software pipeline] Fix hardcoded index in access_ptr rewriting, ad…
Browse files Browse the repository at this point in the history
…d a GPU test with depth 4 (#11495)

* fixed hard-coded index in software pipeling

* fixed three-stage pipeline test

* add three stage pipelined gemm test

* refactor mma test

* use mma_4k schedule utility in test

* apply pipeling annotation

* black

* require ampere in test
  • Loading branch information
masahi authored May 28, 2022
1 parent afb67e6 commit 2389f1f
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 96 deletions.
110 changes: 110 additions & 0 deletions python/tvm/testing/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,113 @@ def render(e):
assert (
expected_error_text in errors
), f'check_error expects "{expected_error_text} in str(errors): {errors}'


def mma_schedule(
workload,
k_inner,
in_dtype,
b_transposed,
i_factors,
j_factors,
k_factors,
index_map_A,
index_map_B,
index_map_C,
ldmatrix_a_intrin,
ldmatrix_b_intrin,
mma_intrin,
mma_fill_intrin,
mma_store_intrin,
shared_scope="shared",
):
"""Create a tensorized schedule for GEMM with MMA intrinsics."""
ir_module = tvm.IRModule({"main": workload})
sch = tvm.tir.Schedule(ir_module)

block = sch.get_block("C")
i, j, k = sch.get_loops(block)
i, i_tc = sch.split(i, factors=[None, 16])
j, j_tc = sch.split(j, factors=[None, 16])
k, k_tc = sch.split(k, factors=[None, k_inner])

sch.reorder(i, j, k, i_tc, j_tc, k_tc)

block_inner = sch.blockize(i_tc)
block_outer, block_inner = block_inner, block

num_ty = i_factors[2] * j_factors[2]

i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
k0, k1, k2 = sch.split(k, k_factors)

sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4)

block_idx = sch.fuse(i0, j0)
block_idy = sch.fuse(i1, j1)
thread_idy = sch.fuse(j2, i2)
sch.bind(block_idx, "blockIdx.x")
sch.bind(block_idy, "blockIdx.y")
sch.bind(thread_idy, "threadIdx.y")

def fetch_to_shared(block, idx, ndim):
block_read = sch.cache_read(block, idx, shared_scope)
sch.compute_at(block_read, k0)
vector_size = 16 if in_dtype == "int8" else 8
warp_size = 32
fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
_, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
sch.bind(f_2, "threadIdx.x")
sch.bind(f_1, "threadIdx.y")
sch.vectorize(f_3)
offset = 8 if in_dtype == "float16" else 16
sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)

return block_read

fetch_to_shared(block_outer, 0, 2)
fetch_to_shared(block_outer, 1, 2)

A_warp = sch.cache_read(block_outer, 0, "warp")
B_warp = sch.cache_read(block_outer, 1, "warp")

sch.compute_at(A_warp, k1)
sch.compute_at(B_warp, k1)

C_warp = sch.cache_write(block_outer, 0, "warp")
sch.reverse_compute_at(C_warp, thread_idy)

ii, jj = sch.get_loops(C_warp)[-2:]
io, ii = sch.split(ii, factors=[None, 16])
jo, ji = sch.split(jj, factors=[None, 16])
sch.reorder(io, jo, ii, ji)

sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
block_init_c = sch.get_block("C_init")

def tile_wmma_fragment(block_read, height, width):
i, j = sch.get_loops(block_read)[-2:]
i0, i1 = sch.split(i, factors=[None, height])
j0, j1 = sch.split(j, factors=[None, width])
sch.reorder(i0, j0, i1, j1)
return i1

loop_a = tile_wmma_fragment(A_warp, 16, k_inner)

if b_transposed:
loop_b = tile_wmma_fragment(B_warp, 16, k_inner)
else:
loop_b = tile_wmma_fragment(B_warp, k_inner, 16)

sch.transform_layout(A_warp, ("write", 0), index_map_A)
sch.transform_layout(B_warp, ("write", 0), index_map_B)
sch.transform_layout(C_warp, ("read", 0), index_map_C)

sch.tensorize(loop_a, ldmatrix_a_intrin)
sch.tensorize(loop_b, ldmatrix_b_intrin)
sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin)
sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)

return sch
21 changes: 18 additions & 3 deletions python/tvm/tir/tensor_intrin/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
HALF_WARP_expr = lift(HALF_WARP)


def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
local_size = (M_DIM * k_dim) // WARP_SIZE
shared_offset = None
index_map = None
Expand Down Expand Up @@ -115,7 +115,12 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
@T.prim_func
def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
shared = T.match_buffer(
shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared"
shared_handle,
shmem_shape,
dtype,
align=128,
offset_factor=16,
scope=shared_scope,
)
warp = T.match_buffer(
warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
Expand Down Expand Up @@ -144,7 +149,7 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
dtype,
align=128,
offset_factor=16,
scope="shared",
scope=shared_scope,
strides=[s0, s1],
)
warp = T.match_buffer(
Expand Down Expand Up @@ -412,6 +417,16 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b"
TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False))

LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn"
TensorIntrin.register(
LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn")
)

LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn"
TensorIntrin.register(
LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn")
)

LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
TensorIntrin.register(
LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True)
Expand Down
3 changes: 2 additions & 1 deletion src/tir/transforms/inject_software_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ class PipelineOpaqueAccessRewriter {
} else {
offset = new_buffer->strides[0];
}
PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset;
PrimExpr new_index =
old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
new_args.Set(2, new_index);
return Call(call->dtype, call->op, new_args, call->span);
}
Expand Down
106 changes: 18 additions & 88 deletions tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
import tvm.testing
import numpy as np
from tvm.testing.tir import mma_schedule


M = 4096
Expand Down Expand Up @@ -98,94 +99,23 @@ def run_test(
mma_fill_intrin,
mma_store_intrin,
):
workload = te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed))
ir_module = tvm.IRModule({"main": workload})
sch = tvm.tir.Schedule(ir_module)

block = sch.get_block("C")
i, j, k = sch.get_loops(block)
i, i_tc = sch.split(i, factors=[None, 16])
j, j_tc = sch.split(j, factors=[None, 16])
k, k_tc = sch.split(k, factors=[None, k_inner])

sch.reorder(i, j, k, i_tc, j_tc, k_tc)

block_inner = sch.blockize(i_tc)
block_outer, block_inner = block_inner, block

num_ty = i_factors[2] * j_factors[2]

i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
k0, k1, k2 = sch.split(k, k_factors)

sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4)

block_idx = sch.fuse(i0, j0)
block_idy = sch.fuse(i1, j1)
thread_idy = sch.fuse(j2, i2)
sch.bind(block_idx, "blockIdx.x")
sch.bind(block_idy, "blockIdx.y")
sch.bind(thread_idy, "threadIdx.y")

def fetch_to_shared(block, idx, ndim):
block_read = sch.cache_read(block, idx, "shared")
sch.compute_at(block_read, k0)
vector_size = 16 if in_dtype == "int8" else 8
warp_size = 32
fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
_, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
sch.bind(f_2, "threadIdx.x")
sch.bind(f_1, "threadIdx.y")
sch.vectorize(f_3)
offset = 8 if in_dtype == "float16" else 16
sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)

return block_read

fetch_to_shared(block_outer, 0, 2)
fetch_to_shared(block_outer, 1, 2)

A_warp = sch.cache_read(block_outer, 0, "warp")
B_warp = sch.cache_read(block_outer, 1, "warp")

sch.compute_at(A_warp, k1)
sch.compute_at(B_warp, k1)

C_warp = sch.cache_write(block_outer, 0, "warp")
sch.reverse_compute_at(C_warp, thread_idy)

ii, jj = sch.get_loops(C_warp)[-2:]
io, ii = sch.split(ii, factors=[None, 16])
jo, ji = sch.split(jj, factors=[None, 16])
sch.reorder(io, jo, ii, ji)

sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
block_init_c = sch.get_block("C_init")

def tile_wmma_fragment(block_read, height, width):
i, j = sch.get_loops(block_read)[-2:]
i0, i1 = sch.split(i, factors=[None, height])
j0, j1 = sch.split(j, factors=[None, width])
sch.reorder(i0, j0, i1, j1)
return i1

loop_a = tile_wmma_fragment(A_warp, 16, k_inner)

if b_transposed:
loop_b = tile_wmma_fragment(B_warp, 16, k_inner)
else:
loop_b = tile_wmma_fragment(B_warp, k_inner, 16)

sch.transform_layout(A_warp, ("write", 0), index_map_A)
sch.transform_layout(B_warp, ("write", 0), index_map_B)
sch.transform_layout(C_warp, ("read", 0), index_map_C)

sch.tensorize(loop_a, ldmatrix_a_intrin)
sch.tensorize(loop_b, ldmatrix_b_intrin)
sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin)
sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
sch = mma_schedule(
te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)),
k_inner,
in_dtype,
b_transposed,
i_factors,
j_factors,
k_factors,
index_map_A,
index_map_B,
index_map_C,
ldmatrix_a_intrin,
ldmatrix_b_intrin,
mma_intrin,
mma_fill_intrin,
mma_store_intrin,
)

if not is_ampere_or_newer():
return None
Expand Down
Loading

0 comments on commit 2389f1f

Please sign in to comment.