diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index cedaafe80a52..8dd482673829 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -59,3 +59,113 @@ def render(e): assert ( expected_error_text in errors ), f'check_error expects "{expected_error_text} in str(errors): {errors}' + + +def mma_schedule( + workload, + k_inner, + in_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, + shared_scope="shared", +): + """Create a tensorized schedule for GEMM with MMA intrinsics.""" + ir_module = tvm.IRModule({"main": workload}) + sch = tvm.tir.Schedule(ir_module) + + block = sch.get_block("C") + i, j, k = sch.get_loops(block) + i, i_tc = sch.split(i, factors=[None, 16]) + j, j_tc = sch.split(j, factors=[None, 16]) + k, k_tc = sch.split(k, factors=[None, k_inner]) + + sch.reorder(i, j, k, i_tc, j_tc, k_tc) + + block_inner = sch.blockize(i_tc) + block_outer, block_inner = block_inner, block + + num_ty = i_factors[2] * j_factors[2] + + i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors) + j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors) + k0, k1, k2 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0) + vector_size = 16 if in_dtype == "int8" else 8 + warp_size = 32 + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + offset = 8 if in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset) + + return block_read + + fetch_to_shared(block_outer, 0, 2) + fetch_to_shared(block_outer, 1, 2) + + A_warp = sch.cache_read(block_outer, 0, "warp") + B_warp = sch.cache_read(block_outer, 1, "warp") + + sch.compute_at(A_warp, k1) + sch.compute_at(B_warp, k1) + + C_warp = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(C_warp, thread_idy) + + ii, jj = sch.get_loops(C_warp)[-2:] + io, ii = sch.split(ii, factors=[None, 16]) + jo, ji = sch.split(jj, factors=[None, 16]) + sch.reorder(io, jo, ii, ji) + + sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) + block_init_c = sch.get_block("C_init") + + def tile_wmma_fragment(block_read, height, width): + i, j = sch.get_loops(block_read)[-2:] + i0, i1 = sch.split(i, factors=[None, height]) + j0, j1 = sch.split(j, factors=[None, width]) + sch.reorder(i0, j0, i1, j1) + return i1 + + loop_a = tile_wmma_fragment(A_warp, 16, k_inner) + + if b_transposed: + loop_b = tile_wmma_fragment(B_warp, 16, k_inner) + else: + loop_b = tile_wmma_fragment(B_warp, k_inner, 16) + + sch.transform_layout(A_warp, ("write", 0), index_map_A) + sch.transform_layout(B_warp, ("write", 0), index_map_B) + sch.transform_layout(C_warp, ("read", 0), index_map_C) + + sch.tensorize(loop_a, ldmatrix_a_intrin) + sch.tensorize(loop_b, ldmatrix_b_intrin) + sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin) + sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin) + sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) + + return sch diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 853a37735486..c5883fd072c5 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -54,7 +54,7 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): HALF_WARP_expr = lift(HALF_WARP) -def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): +def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"): local_size = (M_DIM * k_dim) // WARP_SIZE shared_offset = None index_map = None @@ -115,7 +115,12 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): @T.prim_func def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: shared = T.match_buffer( - shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared" + shared_handle, + shmem_shape, + dtype, + align=128, + offset_factor=16, + scope=shared_scope, ) warp = T.match_buffer( warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp" @@ -144,7 +149,7 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: dtype, align=128, offset_factor=16, - scope="shared", + scope=shared_scope, strides=[s0, s1], ) warp = T.match_buffer( @@ -412,6 +417,16 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b" TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False)) +LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn" +TensorIntrin.register( + LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn") +) + +LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn" +TensorIntrin.register( + LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn") +) + LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans" TensorIntrin.register( LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 7402d6426bc2..de9aa79583b4 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -135,7 +135,8 @@ class PipelineOpaqueAccessRewriter { } else { offset = new_buffer->strides[0]; } - PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset; + PrimExpr new_index = + old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; new_args.Set(2, new_index); return Call(call->dtype, call->op, new_args, call->span); } diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index e9ee990a2415..9feb994e7158 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -42,6 +42,7 @@ ) import tvm.testing import numpy as np +from tvm.testing.tir import mma_schedule M = 4096 @@ -98,94 +99,23 @@ def run_test( mma_fill_intrin, mma_store_intrin, ): - workload = te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)) - ir_module = tvm.IRModule({"main": workload}) - sch = tvm.tir.Schedule(ir_module) - - block = sch.get_block("C") - i, j, k = sch.get_loops(block) - i, i_tc = sch.split(i, factors=[None, 16]) - j, j_tc = sch.split(j, factors=[None, 16]) - k, k_tc = sch.split(k, factors=[None, k_inner]) - - sch.reorder(i, j, k, i_tc, j_tc, k_tc) - - block_inner = sch.blockize(i_tc) - block_outer, block_inner = block_inner, block - - num_ty = i_factors[2] * j_factors[2] - - i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors) - j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors) - k0, k1, k2 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared") - sch.compute_at(block_read, k0) - vector_size = 16 if in_dtype == "int8" else 8 - warp_size = 32 - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - offset = 8 if in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset) - - return block_read - - fetch_to_shared(block_outer, 0, 2) - fetch_to_shared(block_outer, 1, 2) - - A_warp = sch.cache_read(block_outer, 0, "warp") - B_warp = sch.cache_read(block_outer, 1, "warp") - - sch.compute_at(A_warp, k1) - sch.compute_at(B_warp, k1) - - C_warp = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(C_warp, thread_idy) - - ii, jj = sch.get_loops(C_warp)[-2:] - io, ii = sch.split(ii, factors=[None, 16]) - jo, ji = sch.split(jj, factors=[None, 16]) - sch.reorder(io, jo, ii, ji) - - sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) - block_init_c = sch.get_block("C_init") - - def tile_wmma_fragment(block_read, height, width): - i, j = sch.get_loops(block_read)[-2:] - i0, i1 = sch.split(i, factors=[None, height]) - j0, j1 = sch.split(j, factors=[None, width]) - sch.reorder(i0, j0, i1, j1) - return i1 - - loop_a = tile_wmma_fragment(A_warp, 16, k_inner) - - if b_transposed: - loop_b = tile_wmma_fragment(B_warp, 16, k_inner) - else: - loop_b = tile_wmma_fragment(B_warp, k_inner, 16) - - sch.transform_layout(A_warp, ("write", 0), index_map_A) - sch.transform_layout(B_warp, ("write", 0), index_map_B) - sch.transform_layout(C_warp, ("read", 0), index_map_C) - - sch.tensorize(loop_a, ldmatrix_a_intrin) - sch.tensorize(loop_b, ldmatrix_b_intrin) - sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin) - sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin) - sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) + sch = mma_schedule( + te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)), + k_inner, + in_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, + ) if not is_ampere_or_newer(): return None diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 50f96d052b14..fddda05eb5b0 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -16,11 +16,23 @@ # under the License. import pytest import sys +import numpy as np import tvm import tvm.testing +import tvm.tir.tensor_intrin.cuda from tvm import tir, te, TVMError from tvm.script import tir as T +from tvm.meta_schedule.testing import te_workload +from tvm.testing.tir import mma_schedule +from tvm.tir.tensor_intrin.cuda import ( + LDMATRIX_16x16_A_DYN_INTRIN, + LDMATRIX_16x16_B_DYN_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + shared_16x16_to_ldmatrix_32x8_layout, +) def _check(original, transformed): @@ -156,7 +168,7 @@ def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), with T.block(): T.reads(B[tx, 0]) T.writes(C[tx, 0]) - C[tx, 0] = A[tx, 0] + T.float32(2) + C[tx, 0] = B[tx, 0] + T.float32(2) with T.block(): T.reads(C[tx, 0]) T.writes(D[tx, i]) @@ -185,7 +197,7 @@ def transformed_three_stage_compute( T.where(1 <= i) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) @@ -197,7 +209,7 @@ def transformed_three_stage_compute( with T.block(): T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i]) @@ -210,7 +222,7 @@ def transformed_three_stage_compute( T.where(i < 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i + 14]) @@ -1022,5 +1034,62 @@ def test_error_missing_annotation(): _check_error(simple_compute_missing_annotation) +@tvm.testing.requires_cuda +def test_three_stage_gemm(): + N = K = M = 4096 + i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1] + + def is_ampere_or_newer(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + return major >= 8 + + def index_map(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + workload = te.create_prim_func(te_workload.matmul_fp16(N, M, K)) + + sch = mma_schedule( + workload, + 16, + "float16", + False, + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_DYN_INTRIN, + LDMATRIX_16x16_B_DYN_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + "shared.dyn", + ) + + k0 = sch.get_loops(sch.get_block("C_o_update"))[3] + + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + + if is_ampere_or_newer(): + f = tvm.build(sch.mod["main"], target="cuda") + + dev = tvm.device("cuda", 0) + a_np = np.random.uniform(size=(N, K)).astype("float16") + b_np = np.random.uniform(size=(K, M)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) + f(a, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + + if __name__ == "__main__": tvm.testing.main()