From d4769f79c5c37933cd9f20798221bf7e0a897450 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 May 2022 05:43:55 +0900 Subject: [PATCH 1/8] fixed hard-coded index in software pipeling --- src/tir/transforms/inject_software_pipeline.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 7402d6426bc2..de9aa79583b4 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -135,7 +135,8 @@ class PipelineOpaqueAccessRewriter { } else { offset = new_buffer->strides[0]; } - PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset; + PrimExpr new_index = + old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; new_args.Set(2, new_index); return Call(call->dtype, call->op, new_args, call->span); } From cf4dc363301c932faacc61d8811a285e3b9bcb49 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 May 2022 06:21:11 +0900 Subject: [PATCH 2/8] fixed three-stage pipeline test --- .../test_tir_transform_inject_software_pipeline.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 50f96d052b14..f9d304f9915a 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -156,7 +156,7 @@ def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), with T.block(): T.reads(B[tx, 0]) T.writes(C[tx, 0]) - C[tx, 0] = A[tx, 0] + T.float32(2) + C[tx, 0] = B[tx, 0] + T.float32(2) with T.block(): T.reads(C[tx, 0]) T.writes(D[tx, i]) @@ -185,7 +185,7 @@ def transformed_three_stage_compute( T.where(1 <= i) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0]) T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14]) @@ -197,7 +197,7 @@ def transformed_three_stage_compute( with T.block(): T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i]) @@ -210,7 +210,7 @@ def transformed_three_stage_compute( T.where(i < 1) T.reads(B[0:2, tx, 0]) T.writes(C[0:2, tx, 0]) - C[(i + 1) % 2, tx, 0] = A[tx, 0] + T.float32(2) + C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2) with T.block(): T.reads(C[0:2, tx, 0]) T.writes(D[tx, i + 14]) @@ -1023,4 +1023,5 @@ def test_error_missing_annotation(): if __name__ == "__main__": - tvm.testing.main() + # tvm.testing.main() + test_three_stage_compute() From 869fb972d59cd5fa72132964303db441a3ac3e59 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 May 2022 09:02:13 +0900 Subject: [PATCH 3/8] add three stage pipelined gemm test --- ..._tir_transform_inject_software_pipeline.py | 409 +++++++++++++++++- 1 file changed, 408 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index f9d304f9915a..77e98a4c02ba 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -16,9 +16,11 @@ # under the License. import pytest import sys +import numpy as np import tvm import tvm.testing +import tvm.tir.tensor_intrin.cuda from tvm import tir, te, TVMError from tvm.script import tir as T @@ -1022,6 +1024,411 @@ def test_error_missing_annotation(): _check_error(simple_compute_missing_annotation) +def test_three_stage_gemm(): + @tvm.script.ir_module + class Module_pipelined: + @T.prim_func + def main( + A: T.Buffer[(4096, 4096), "float16"], + B: T.Buffer[(4096, 4096), "float16"], + C: T.Buffer[(4096, 4096), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # var definition + tx = T.env_thread("threadIdx.x") + s0 = T.var("int32") + s0_1 = T.var("int32") + s0_2 = T.var("int32") + s1 = T.var("int32") + s1_1 = T.var("int32") + s1_2 = T.var("int32") + # body + # with T.block("root") + A_shared = T.alloc_buffer([4096, 4096], dtype="float16", scope="shared.dyn") + B_shared = T.alloc_buffer([4096, 4096], dtype="float16", scope="shared.dyn") + A_shared_warp = T.alloc_buffer([256, 256, 32, 8], dtype="float16", scope="warp") + B_shared_warp = T.alloc_buffer([256, 256, 32, 8], dtype="float16", scope="warp") + C_warp = T.alloc_buffer([256, 256, 32, 8], dtype="float32", scope="warp") + for i0_0_0_i1_0_0_fused in T.thread_binding(4, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(512, thread="blockIdx.y"): + for i1_0_2_i0_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): + for i0_0_3_init, i1_0_4_init in T.grid(4, 2): + with T.block("C_o_init"): + i_o = T.axis.spatial( + 256, + i0_0_0_i1_0_0_fused * 64 + + i0_0_1_i1_0_1_fused // 64 * 8 + + i1_0_2_i0_0_2_fused % 2 * 4 + + i0_0_3_init, + ) + j_o = T.axis.spatial( + 256, + i0_0_1_i1_0_1_fused % 64 * 4 + + i1_0_2_i0_0_2_fused // 2 * 2 + + i1_0_4_init, + ) + T.reads() + T.writes(C_warp[i_o, j_o, 0:32, 0:8]) + with T.block("C_init_o"): + i_init_o = T.axis.spatial(1, 0) + j_init_o = T.axis.spatial(1, 0) + T.reads() + T.writes(C_warp[i_o, j_o, 0:32, 0:8]) + C_warp_1 = T.match_buffer( + C_warp[i_o, j_o, 0:32, 0:8], + [32, 8], + dtype="float32", + scope="warp", + offset_factor=1, + ) + T.launch_thread(tx, 32) + T.evaluate( + T.mma_fill( + 8, C_warp_1.data, C_warp_1.elem_offset, dtype="float32" + ) + ) + for i2_0_0 in T.serial( + 128, + annotations={ + "software_pipeline_order": [0, 1, 2], + "software_pipeline_stage": [0, 0, 3], + }, + ): + for ax0_ax1_fused_0 in T.serial(4): + for ax0_ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding( + 32, thread="threadIdx.x" + ): + for ax0_ax1_fused_3 in T.vectorized(8): + with T.block("A_shared"): + v0 = T.axis.spatial( + 4096, + i0_0_0_i1_0_0_fused * 1024 + + i0_0_1_i1_0_1_fused // 64 * 128 + + ( + ax0_ax1_fused_0 * 1024 + + ax0_ax1_fused_1 * 256 + + ax0_ax1_fused_2 * 8 + + ax0_ax1_fused_3 + ) + // 32, + ) + v1 = T.axis.spatial( + 4096, + i2_0_0 * 32 + + ( + ax0_ax1_fused_0 * 1024 + + ax0_ax1_fused_1 * 256 + + ax0_ax1_fused_2 * 8 + + ax0_ax1_fused_3 + ) + % 32, + ) + T.reads(A[v0, v1]) + T.writes(A_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(2): + for ax0_ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding( + 32, thread="threadIdx.x" + ): + for ax0_ax1_fused_3 in T.vectorized(8): + with T.block("B_shared"): + v0 = T.axis.spatial( + 4096, + i2_0_0 * 32 + + ( + ax0_ax1_fused_0 * 1024 + + ax0_ax1_fused_1 * 256 + + ax0_ax1_fused_2 * 8 + + ax0_ax1_fused_3 + ) + // 64, + ) + v1 = T.axis.spatial( + 4096, + i0_0_1_i1_0_1_fused % 64 * 64 + + ( + ax0_ax1_fused_0 * 1024 + + ax0_ax1_fused_1 * 256 + + ax0_ax1_fused_2 * 8 + + ax0_ax1_fused_3 + ) + % 64, + ) + T.reads(B[v0, v1]) + T.writes(B_shared[v0, v1]) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1 in T.serial(2): + for ax0_0, ax1_0 in T.grid(4, 1): + with T.block("A_shared_warp_o"): + v0_o = T.axis.spatial( + 256, + i0_0_0_i1_0_0_fused * 64 + + i0_0_1_i1_0_1_fused // 64 * 8 + + i1_0_2_i0_0_2_fused % 2 * 4 + + ax0_0, + ) + v1_o = T.axis.spatial(256, i2_0_0 * 2 + i2_0_1) + T.reads( + A_shared[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) + T.writes(A_shared_warp[v0_o, v1_o, 0:32, 0:8]) + warp = T.match_buffer( + A_shared_warp[v0_o, v1_o, 0:32, 0:8], + [32, 8], + dtype="float16", + scope="warp", + offset_factor=16, + ) + shared = T.match_buffer( + A_shared[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + [16, 16], + dtype="float16", + strides=[s0, s1], + scope="shared.dyn", + offset_factor=16, + ) + T.launch_thread(tx, 32) + T.evaluate( + T.ptx_ldmatrix( + False, + 4, + ".b16", + warp.data, + warp.elem_offset + 8 * tx, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + shared.data, + shared.elem_offset, + s0 * 16, + 1, + dtype="handle", + ), + s0 * (tx % 16) + 8 * (tx // 16), + dtype="float16", + ) + ) + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("B_shared_warp_o"): + v0_o = T.axis.spatial(256, i2_0_0 * 2 + i2_0_1) + v1_o = T.axis.spatial( + 256, + i0_0_1_i1_0_1_fused % 64 * 4 + + i1_0_2_i0_0_2_fused // 2 * 2 + + ax1_0, + ) + T.reads( + B_shared[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ] + ) + T.writes(B_shared_warp[v0_o, v1_o, 0:32, 0:8]) + warp_1 = T.match_buffer( + B_shared_warp[v0_o, v1_o, 0:32, 0:8], + [32, 8], + dtype="float16", + scope="warp", + offset_factor=16, + ) + shared_1 = T.match_buffer( + B_shared[ + v0_o * 16 : v0_o * 16 + 16, + v1_o * 16 : v1_o * 16 + 16, + ], + [16, 16], + dtype="float16", + strides=[s0_1, s1_1], + scope="shared.dyn", + offset_factor=16, + ) + T.launch_thread(tx, 32) + T.evaluate( + T.ptx_ldmatrix( + True, + 4, + ".b16", + warp_1.data, + warp_1.elem_offset + 8 * tx, + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), + shared_1.data, + shared_1.elem_offset, + s0_1 * 16, + 1, + dtype="handle", + ), + s0_1 * (tx % 16) + 8 * (tx // 16), + dtype="float16", + ) + ) + for i0_0_3, i1_0_3, i2_0_2, i0_0_4, i1_0_4 in T.grid(4, 1, 1, 1, 2): + with T.block("C_o_update"): + i_o = T.axis.spatial( + 256, + i0_0_0_i1_0_0_fused * 64 + + i0_0_1_i1_0_1_fused // 64 * 8 + + i1_0_2_i0_0_2_fused % 2 * 4 + + i0_0_3, + ) + j_o = T.axis.spatial( + 256, + i0_0_1_i1_0_1_fused % 64 * 4 + + i1_0_2_i0_0_2_fused // 2 * 2 + + i1_0_4, + ) + k_o = T.axis.reduce(256, i2_0_0 * 2 + i2_0_1) + T.reads( + C_warp[i_o, j_o, 0:32, 0:8], + A_shared_warp[i_o, k_o, 0:32, 0:8], + B_shared_warp[k_o, j_o, 0:32, 0:8], + ) + T.writes(C_warp[i_o, j_o, 0:32, 0:8]) + with T.block("C_o"): + i_o_1 = T.axis.spatial(1, 0) + j_o_1 = T.axis.spatial(1, 0) + k_o_1 = T.axis.reduce(1, 0) + T.reads( + C_warp[i_o, j_o, 0:32, 0:8], + A_shared_warp[i_o, k_o, 0:32, 0:8], + B_shared_warp[k_o, j_o, 0:32, 0:8], + ) + T.writes(C_warp[i_o, j_o, 0:32, 0:8]) + A_1 = T.match_buffer( + A_shared_warp[i_o, k_o, 0:32, 0:8], + [32, 8], + dtype="float16", + scope="warp", + offset_factor=16, + ) + B_1 = T.match_buffer( + B_shared_warp[k_o, j_o, 0:32, 0:8], + [32, 8], + dtype="float16", + scope="warp", + offset_factor=16, + ) + C_1 = T.match_buffer( + C_warp[i_o, j_o, 0:32, 0:8], + [32, 8], + dtype="float32", + scope="warp", + offset_factor=16, + ) + T.launch_thread(tx, 32) + T.evaluate( + T.ptx_mma( + "m16n8k16", + "row", + "col", + "fp16", + "fp16", + "fp32", + A_1.data, + A_1.elem_offset + tx * 8, + B_1.data, + B_1.elem_offset + tx * 8, + C_1.data, + C_1.elem_offset + tx * 8, + False, + dtype="float32", + ) + ) + T.evaluate( + T.ptx_mma( + "m16n8k16", + "row", + "col", + "fp16", + "fp16", + "fp32", + A_1.data, + A_1.elem_offset + tx * 8, + B_1.data, + B_1.elem_offset + tx * 8 + 8 // 2, + C_1.data, + C_1.elem_offset + tx * 8 + 8 // 2, + False, + dtype="float32", + ) + ) + for ax0_0, ax1_0 in T.grid(4, 2): + with T.block("C_warp_o"): + v0_o = T.axis.spatial( + 256, + i0_0_0_i1_0_0_fused * 64 + + i0_0_1_i1_0_1_fused // 64 * 8 + + i1_0_2_i0_0_2_fused % 2 * 4 + + ax0_0, + ) + v1_o = T.axis.spatial( + 256, + i0_0_1_i1_0_1_fused % 64 * 4 + + i1_0_2_i0_0_2_fused // 2 * 2 + + ax1_0, + ) + T.reads(C_warp[v0_o, v1_o, 0:32, 0:8]) + T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + C_warp_2 = T.match_buffer( + C_warp[v0_o, v1_o, 0:32, 0:8], + [32, 8], + dtype="float32", + scope="warp", + offset_factor=1, + ) + C_2 = T.match_buffer( + C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], + [16, 16], + dtype="float32", + strides=[s0_2, s1_2], + offset_factor=1, + ) + T.launch_thread(tx, 32) + T.evaluate( + T.mma_store( + 16, + 16, + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), + C_2.data, + C_2.elem_offset, + s0_2 * 16, + 2, + dtype="handle", + ), + C_warp_2.data, + C_warp_2.elem_offset, + s0_2, + dtype="float32", + ) + ) + + f = tvm.build(Module_pipelined, target="cuda") + + N = K = M = 4096 + dev = tvm.device("cuda", 0) + a_np = np.random.uniform(size=(N, K)).astype("float16") + b_np = np.random.uniform(size=(K, M)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) + f(a, b, c) + # print(f.imported_modules[0].get_source()) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + print("ok") + + if __name__ == "__main__": # tvm.testing.main() - test_three_stage_compute() + test_three_stage_gemm() From 966d0c07db8d0c7992a2bc00501a37cc8dc0393d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 May 2022 16:02:43 +0900 Subject: [PATCH 4/8] refactor mma test --- python/tvm/testing/tir.py | 110 ++++++++++++++++++ ...est_tir_schedule_tensorize_ldmatrix_mma.py | 106 +++-------------- 2 files changed, 129 insertions(+), 87 deletions(-) diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index cedaafe80a52..950abe8e5e87 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -59,3 +59,113 @@ def render(e): assert ( expected_error_text in errors ), f'check_error expects "{expected_error_text} in str(errors): {errors}' + + +def mma_4k_schedule( + workload, + k_inner, + in_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, + shared_scope="shared", +): + """Create a tensorized schedule for 4k GEMM with MMA intrinsics.""" + ir_module = tvm.IRModule({"main": workload}) + sch = tvm.tir.Schedule(ir_module) + + block = sch.get_block("C") + i, j, k = sch.get_loops(block) + i, i_tc = sch.split(i, factors=[None, 16]) + j, j_tc = sch.split(j, factors=[None, 16]) + k, k_tc = sch.split(k, factors=[None, k_inner]) + + sch.reorder(i, j, k, i_tc, j_tc, k_tc) + + block_inner = sch.blockize(i_tc) + block_outer, block_inner = block_inner, block + + num_ty = i_factors[2] * j_factors[2] + + i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors) + j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors) + k0, k1, k2 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0) + vector_size = 16 if in_dtype == "int8" else 8 + warp_size = 32 + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + offset = 8 if in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset) + + return block_read + + fetch_to_shared(block_outer, 0, 2) + fetch_to_shared(block_outer, 1, 2) + + A_warp = sch.cache_read(block_outer, 0, "warp") + B_warp = sch.cache_read(block_outer, 1, "warp") + + sch.compute_at(A_warp, k1) + sch.compute_at(B_warp, k1) + + C_warp = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(C_warp, thread_idy) + + ii, jj = sch.get_loops(C_warp)[-2:] + io, ii = sch.split(ii, factors=[None, 16]) + jo, ji = sch.split(jj, factors=[None, 16]) + sch.reorder(io, jo, ii, ji) + + sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) + block_init_c = sch.get_block("C_init") + + def tile_wmma_fragment(block_read, height, width): + i, j = sch.get_loops(block_read)[-2:] + i0, i1 = sch.split(i, factors=[None, height]) + j0, j1 = sch.split(j, factors=[None, width]) + sch.reorder(i0, j0, i1, j1) + return i1 + + loop_a = tile_wmma_fragment(A_warp, 16, k_inner) + + if b_transposed: + loop_b = tile_wmma_fragment(B_warp, 16, k_inner) + else: + loop_b = tile_wmma_fragment(B_warp, k_inner, 16) + + sch.transform_layout(A_warp, ("write", 0), index_map_A) + sch.transform_layout(B_warp, ("write", 0), index_map_B) + sch.transform_layout(C_warp, ("read", 0), index_map_C) + + sch.tensorize(loop_a, ldmatrix_a_intrin) + sch.tensorize(loop_b, ldmatrix_b_intrin) + sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin) + sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin) + sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) + + return sch diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index e9ee990a2415..eae513d71a8a 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -42,6 +42,7 @@ ) import tvm.testing import numpy as np +from tvm.testing.tir import mma_4k_schedule M = 4096 @@ -98,94 +99,25 @@ def run_test( mma_fill_intrin, mma_store_intrin, ): - workload = te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)) - ir_module = tvm.IRModule({"main": workload}) - sch = tvm.tir.Schedule(ir_module) - - block = sch.get_block("C") - i, j, k = sch.get_loops(block) - i, i_tc = sch.split(i, factors=[None, 16]) - j, j_tc = sch.split(j, factors=[None, 16]) - k, k_tc = sch.split(k, factors=[None, k_inner]) - - sch.reorder(i, j, k, i_tc, j_tc, k_tc) - - block_inner = sch.blockize(i_tc) - block_outer, block_inner = block_inner, block - - num_ty = i_factors[2] * j_factors[2] - - i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors) - j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors) - k0, k1, k2 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared") - sch.compute_at(block_read, k0) - vector_size = 16 if in_dtype == "int8" else 8 - warp_size = 32 - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - offset = 8 if in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset) - - return block_read - - fetch_to_shared(block_outer, 0, 2) - fetch_to_shared(block_outer, 1, 2) - - A_warp = sch.cache_read(block_outer, 0, "warp") - B_warp = sch.cache_read(block_outer, 1, "warp") - - sch.compute_at(A_warp, k1) - sch.compute_at(B_warp, k1) - - C_warp = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(C_warp, thread_idy) - - ii, jj = sch.get_loops(C_warp)[-2:] - io, ii = sch.split(ii, factors=[None, 16]) - jo, ji = sch.split(jj, factors=[None, 16]) - sch.reorder(io, jo, ii, ji) - - sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) - block_init_c = sch.get_block("C_init") - - def tile_wmma_fragment(block_read, height, width): - i, j = sch.get_loops(block_read)[-2:] - i0, i1 = sch.split(i, factors=[None, height]) - j0, j1 = sch.split(j, factors=[None, width]) - sch.reorder(i0, j0, i1, j1) - return i1 - - loop_a = tile_wmma_fragment(A_warp, 16, k_inner) - - if b_transposed: - loop_b = tile_wmma_fragment(B_warp, 16, k_inner) - else: - loop_b = tile_wmma_fragment(B_warp, k_inner, 16) - - sch.transform_layout(A_warp, ("write", 0), index_map_A) - sch.transform_layout(B_warp, ("write", 0), index_map_B) - sch.transform_layout(C_warp, ("read", 0), index_map_C) + sch = mma_4k_schedule( + te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)), + k_inner, + in_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, + ) - sch.tensorize(loop_a, ldmatrix_a_intrin) - sch.tensorize(loop_b, ldmatrix_b_intrin) - sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin) - sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin) - sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) + print(sch.mod.script()) if not is_ampere_or_newer(): return None From 24af027d1824a17d8d6efd7a0e5bfd060f8368b5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 May 2022 16:14:16 +0900 Subject: [PATCH 5/8] use mma_4k schedule utility in test --- python/tvm/tir/tensor_intrin/cuda.py | 12 +- ...est_tir_schedule_tensorize_ldmatrix_mma.py | 2 - ..._tir_transform_inject_software_pipeline.py | 456 +++--------------- 3 files changed, 64 insertions(+), 406 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 853a37735486..cbf3ba0c691e 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -54,7 +54,7 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): HALF_WARP_expr = lift(HALF_WARP) -def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): +def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"): local_size = (M_DIM * k_dim) // WARP_SIZE shared_offset = None index_map = None @@ -115,7 +115,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): @T.prim_func def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: shared = T.match_buffer( - shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared" + shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope=shared_scope, ) warp = T.match_buffer( warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp" @@ -144,7 +144,7 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: dtype, align=128, offset_factor=16, - scope="shared", + scope=shared_scope, strides=[s0, s1], ) warp = T.match_buffer( @@ -412,6 +412,12 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b" TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False)) +LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn" +TensorIntrin.register(LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn")) + +LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn" +TensorIntrin.register(LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn")) + LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans" TensorIntrin.register( LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True) diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index eae513d71a8a..929766c8680c 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -117,8 +117,6 @@ def run_test( mma_store_intrin, ) - print(sch.mod.script()) - if not is_ampere_or_newer(): return None diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 77e98a4c02ba..76fb2f9042c5 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -23,6 +23,16 @@ import tvm.tir.tensor_intrin.cuda from tvm import tir, te, TVMError from tvm.script import tir as T +from tvm.meta_schedule.testing import te_workload +from tvm.testing.tir import mma_4k_schedule +from tvm.tir.tensor_intrin.cuda import ( + LDMATRIX_16x16_A_DYN_INTRIN, + LDMATRIX_16x16_B_DYN_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + shared_16x16_to_ldmatrix_32x8_layout, +) def _check(original, transformed): @@ -1025,408 +1035,52 @@ def test_error_missing_annotation(): def test_three_stage_gemm(): - @tvm.script.ir_module - class Module_pipelined: - @T.prim_func - def main( - A: T.Buffer[(4096, 4096), "float16"], - B: T.Buffer[(4096, 4096), "float16"], - C: T.Buffer[(4096, 4096), "float32"], - ) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # var definition - tx = T.env_thread("threadIdx.x") - s0 = T.var("int32") - s0_1 = T.var("int32") - s0_2 = T.var("int32") - s1 = T.var("int32") - s1_1 = T.var("int32") - s1_2 = T.var("int32") - # body - # with T.block("root") - A_shared = T.alloc_buffer([4096, 4096], dtype="float16", scope="shared.dyn") - B_shared = T.alloc_buffer([4096, 4096], dtype="float16", scope="shared.dyn") - A_shared_warp = T.alloc_buffer([256, 256, 32, 8], dtype="float16", scope="warp") - B_shared_warp = T.alloc_buffer([256, 256, 32, 8], dtype="float16", scope="warp") - C_warp = T.alloc_buffer([256, 256, 32, 8], dtype="float32", scope="warp") - for i0_0_0_i1_0_0_fused in T.thread_binding(4, thread="blockIdx.x"): - for i0_0_1_i1_0_1_fused in T.thread_binding(512, thread="blockIdx.y"): - for i1_0_2_i0_0_2_fused in T.thread_binding(4, thread="threadIdx.y"): - for i0_0_3_init, i1_0_4_init in T.grid(4, 2): - with T.block("C_o_init"): - i_o = T.axis.spatial( - 256, - i0_0_0_i1_0_0_fused * 64 - + i0_0_1_i1_0_1_fused // 64 * 8 - + i1_0_2_i0_0_2_fused % 2 * 4 - + i0_0_3_init, - ) - j_o = T.axis.spatial( - 256, - i0_0_1_i1_0_1_fused % 64 * 4 - + i1_0_2_i0_0_2_fused // 2 * 2 - + i1_0_4_init, - ) - T.reads() - T.writes(C_warp[i_o, j_o, 0:32, 0:8]) - with T.block("C_init_o"): - i_init_o = T.axis.spatial(1, 0) - j_init_o = T.axis.spatial(1, 0) - T.reads() - T.writes(C_warp[i_o, j_o, 0:32, 0:8]) - C_warp_1 = T.match_buffer( - C_warp[i_o, j_o, 0:32, 0:8], - [32, 8], - dtype="float32", - scope="warp", - offset_factor=1, - ) - T.launch_thread(tx, 32) - T.evaluate( - T.mma_fill( - 8, C_warp_1.data, C_warp_1.elem_offset, dtype="float32" - ) - ) - for i2_0_0 in T.serial( - 128, - annotations={ - "software_pipeline_order": [0, 1, 2], - "software_pipeline_stage": [0, 0, 3], - }, - ): - for ax0_ax1_fused_0 in T.serial(4): - for ax0_ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding( - 32, thread="threadIdx.x" - ): - for ax0_ax1_fused_3 in T.vectorized(8): - with T.block("A_shared"): - v0 = T.axis.spatial( - 4096, - i0_0_0_i1_0_0_fused * 1024 - + i0_0_1_i1_0_1_fused // 64 * 128 - + ( - ax0_ax1_fused_0 * 1024 - + ax0_ax1_fused_1 * 256 - + ax0_ax1_fused_2 * 8 - + ax0_ax1_fused_3 - ) - // 32, - ) - v1 = T.axis.spatial( - 4096, - i2_0_0 * 32 - + ( - ax0_ax1_fused_0 * 1024 - + ax0_ax1_fused_1 * 256 - + ax0_ax1_fused_2 * 8 - + ax0_ax1_fused_3 - ) - % 32, - ) - T.reads(A[v0, v1]) - T.writes(A_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) - A_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused_0 in T.serial(2): - for ax0_ax1_fused_1 in T.thread_binding(4, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding( - 32, thread="threadIdx.x" - ): - for ax0_ax1_fused_3 in T.vectorized(8): - with T.block("B_shared"): - v0 = T.axis.spatial( - 4096, - i2_0_0 * 32 - + ( - ax0_ax1_fused_0 * 1024 - + ax0_ax1_fused_1 * 256 - + ax0_ax1_fused_2 * 8 - + ax0_ax1_fused_3 - ) - // 64, - ) - v1 = T.axis.spatial( - 4096, - i0_0_1_i1_0_1_fused % 64 * 64 - + ( - ax0_ax1_fused_0 * 1024 - + ax0_ax1_fused_1 * 256 - + ax0_ax1_fused_2 * 8 - + ax0_ax1_fused_3 - ) - % 64, - ) - T.reads(B[v0, v1]) - T.writes(B_shared[v0, v1]) - T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]]}) - B_shared[v0, v1] = B[v0, v1] - for i2_0_1 in T.serial(2): - for ax0_0, ax1_0 in T.grid(4, 1): - with T.block("A_shared_warp_o"): - v0_o = T.axis.spatial( - 256, - i0_0_0_i1_0_0_fused * 64 - + i0_0_1_i1_0_1_fused // 64 * 8 - + i1_0_2_i0_0_2_fused % 2 * 4 - + ax0_0, - ) - v1_o = T.axis.spatial(256, i2_0_0 * 2 + i2_0_1) - T.reads( - A_shared[ - v0_o * 16 : v0_o * 16 + 16, - v1_o * 16 : v1_o * 16 + 16, - ] - ) - T.writes(A_shared_warp[v0_o, v1_o, 0:32, 0:8]) - warp = T.match_buffer( - A_shared_warp[v0_o, v1_o, 0:32, 0:8], - [32, 8], - dtype="float16", - scope="warp", - offset_factor=16, - ) - shared = T.match_buffer( - A_shared[ - v0_o * 16 : v0_o * 16 + 16, - v1_o * 16 : v1_o * 16 + 16, - ], - [16, 16], - dtype="float16", - strides=[s0, s1], - scope="shared.dyn", - offset_factor=16, - ) - T.launch_thread(tx, 32) - T.evaluate( - T.ptx_ldmatrix( - False, - 4, - ".b16", - warp.data, - warp.elem_offset + 8 * tx, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - shared.data, - shared.elem_offset, - s0 * 16, - 1, - dtype="handle", - ), - s0 * (tx % 16) + 8 * (tx // 16), - dtype="float16", - ) - ) - for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("B_shared_warp_o"): - v0_o = T.axis.spatial(256, i2_0_0 * 2 + i2_0_1) - v1_o = T.axis.spatial( - 256, - i0_0_1_i1_0_1_fused % 64 * 4 - + i1_0_2_i0_0_2_fused // 2 * 2 - + ax1_0, - ) - T.reads( - B_shared[ - v0_o * 16 : v0_o * 16 + 16, - v1_o * 16 : v1_o * 16 + 16, - ] - ) - T.writes(B_shared_warp[v0_o, v1_o, 0:32, 0:8]) - warp_1 = T.match_buffer( - B_shared_warp[v0_o, v1_o, 0:32, 0:8], - [32, 8], - dtype="float16", - scope="warp", - offset_factor=16, - ) - shared_1 = T.match_buffer( - B_shared[ - v0_o * 16 : v0_o * 16 + 16, - v1_o * 16 : v1_o * 16 + 16, - ], - [16, 16], - dtype="float16", - strides=[s0_1, s1_1], - scope="shared.dyn", - offset_factor=16, - ) - T.launch_thread(tx, 32) - T.evaluate( - T.ptx_ldmatrix( - True, - 4, - ".b16", - warp_1.data, - warp_1.elem_offset + 8 * tx, - T.tvm_access_ptr( - T.type_annotation(dtype="float16"), - shared_1.data, - shared_1.elem_offset, - s0_1 * 16, - 1, - dtype="handle", - ), - s0_1 * (tx % 16) + 8 * (tx // 16), - dtype="float16", - ) - ) - for i0_0_3, i1_0_3, i2_0_2, i0_0_4, i1_0_4 in T.grid(4, 1, 1, 1, 2): - with T.block("C_o_update"): - i_o = T.axis.spatial( - 256, - i0_0_0_i1_0_0_fused * 64 - + i0_0_1_i1_0_1_fused // 64 * 8 - + i1_0_2_i0_0_2_fused % 2 * 4 - + i0_0_3, - ) - j_o = T.axis.spatial( - 256, - i0_0_1_i1_0_1_fused % 64 * 4 - + i1_0_2_i0_0_2_fused // 2 * 2 - + i1_0_4, - ) - k_o = T.axis.reduce(256, i2_0_0 * 2 + i2_0_1) - T.reads( - C_warp[i_o, j_o, 0:32, 0:8], - A_shared_warp[i_o, k_o, 0:32, 0:8], - B_shared_warp[k_o, j_o, 0:32, 0:8], - ) - T.writes(C_warp[i_o, j_o, 0:32, 0:8]) - with T.block("C_o"): - i_o_1 = T.axis.spatial(1, 0) - j_o_1 = T.axis.spatial(1, 0) - k_o_1 = T.axis.reduce(1, 0) - T.reads( - C_warp[i_o, j_o, 0:32, 0:8], - A_shared_warp[i_o, k_o, 0:32, 0:8], - B_shared_warp[k_o, j_o, 0:32, 0:8], - ) - T.writes(C_warp[i_o, j_o, 0:32, 0:8]) - A_1 = T.match_buffer( - A_shared_warp[i_o, k_o, 0:32, 0:8], - [32, 8], - dtype="float16", - scope="warp", - offset_factor=16, - ) - B_1 = T.match_buffer( - B_shared_warp[k_o, j_o, 0:32, 0:8], - [32, 8], - dtype="float16", - scope="warp", - offset_factor=16, - ) - C_1 = T.match_buffer( - C_warp[i_o, j_o, 0:32, 0:8], - [32, 8], - dtype="float32", - scope="warp", - offset_factor=16, - ) - T.launch_thread(tx, 32) - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A_1.data, - A_1.elem_offset + tx * 8, - B_1.data, - B_1.elem_offset + tx * 8, - C_1.data, - C_1.elem_offset + tx * 8, - False, - dtype="float32", - ) - ) - T.evaluate( - T.ptx_mma( - "m16n8k16", - "row", - "col", - "fp16", - "fp16", - "fp32", - A_1.data, - A_1.elem_offset + tx * 8, - B_1.data, - B_1.elem_offset + tx * 8 + 8 // 2, - C_1.data, - C_1.elem_offset + tx * 8 + 8 // 2, - False, - dtype="float32", - ) - ) - for ax0_0, ax1_0 in T.grid(4, 2): - with T.block("C_warp_o"): - v0_o = T.axis.spatial( - 256, - i0_0_0_i1_0_0_fused * 64 - + i0_0_1_i1_0_1_fused // 64 * 8 - + i1_0_2_i0_0_2_fused % 2 * 4 - + ax0_0, - ) - v1_o = T.axis.spatial( - 256, - i0_0_1_i1_0_1_fused % 64 * 4 - + i1_0_2_i0_0_2_fused // 2 * 2 - + ax1_0, - ) - T.reads(C_warp[v0_o, v1_o, 0:32, 0:8]) - T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - C_warp_2 = T.match_buffer( - C_warp[v0_o, v1_o, 0:32, 0:8], - [32, 8], - dtype="float32", - scope="warp", - offset_factor=1, - ) - C_2 = T.match_buffer( - C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], - [16, 16], - dtype="float32", - strides=[s0_2, s1_2], - offset_factor=1, - ) - T.launch_thread(tx, 32) - T.evaluate( - T.mma_store( - 16, - 16, - T.tvm_access_ptr( - T.type_annotation(dtype="float32"), - C_2.data, - C_2.elem_offset, - s0_2 * 16, - 2, - dtype="handle", - ), - C_warp_2.data, - C_warp_2.elem_offset, - s0_2, - dtype="float32", - ) - ) - - f = tvm.build(Module_pipelined, target="cuda") - N = K = M = 4096 - dev = tvm.device("cuda", 0) - a_np = np.random.uniform(size=(N, K)).astype("float16") - b_np = np.random.uniform(size=(K, M)).astype("float16") - c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) - f(a, b, c) - # print(f.imported_modules[0].get_source()) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) - print("ok") + i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1] + + def index_map(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + workload = te.create_prim_func(te_workload.matmul_fp16(N, M, K)) + + sch = mma_4k_schedule( + workload, + 16, + "float16", + False, + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_DYN_INTRIN, + LDMATRIX_16x16_B_DYN_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + "shared.dyn", + ) + + print(sch.mod.script()) + + f = tvm.build(sch.mod["main"], target="cuda") + + # dev = tvm.device("cuda", 0) + # a_np = np.random.uniform(size=(N, K)).astype("float16") + # b_np = np.random.uniform(size=(K, M)).astype("float16") + # c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) + # a = tvm.nd.array(a_np, dev) + # b = tvm.nd.array(b_np, dev) + # c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) + # f(a, b, c) + # # print(f.imported_modules[0].get_source()) + # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + # print("ok") if __name__ == "__main__": From 3d8b3ccd928cc1061b327ac4899590f04063b98f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 May 2022 16:24:54 +0900 Subject: [PATCH 6/8] apply pipeling annotation --- python/tvm/testing/tir.py | 2 +- ...est_tir_schedule_tensorize_ldmatrix_mma.py | 4 +-- ..._tir_transform_inject_software_pipeline.py | 29 ++++++++++--------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index 950abe8e5e87..53e082c9ec8e 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -61,7 +61,7 @@ def render(e): ), f'check_error expects "{expected_error_text} in str(errors): {errors}' -def mma_4k_schedule( +def mma_schedule( workload, k_inner, in_dtype, diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py index 929766c8680c..9feb994e7158 100644 --- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -42,7 +42,7 @@ ) import tvm.testing import numpy as np -from tvm.testing.tir import mma_4k_schedule +from tvm.testing.tir import mma_schedule M = 4096 @@ -99,7 +99,7 @@ def run_test( mma_fill_intrin, mma_store_intrin, ): - sch = mma_4k_schedule( + sch = mma_schedule( te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)), k_inner, in_dtype, diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 76fb2f9042c5..cb7efd943808 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -24,7 +24,7 @@ from tvm import tir, te, TVMError from tvm.script import tir as T from tvm.meta_schedule.testing import te_workload -from tvm.testing.tir import mma_4k_schedule +from tvm.testing.tir import mma_schedule from tvm.tir.tensor_intrin.cuda import ( LDMATRIX_16x16_A_DYN_INTRIN, LDMATRIX_16x16_B_DYN_INTRIN, @@ -1047,7 +1047,7 @@ def index_map(i, j): workload = te.create_prim_func(te_workload.matmul_fp16(N, M, K)) - sch = mma_4k_schedule( + sch = mma_schedule( workload, 16, "float16", @@ -1066,21 +1066,22 @@ def index_map(i, j): "shared.dyn", ) - print(sch.mod.script()) + k0 = sch.get_loops(sch.get_block("C_o_update"))[3] + + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) f = tvm.build(sch.mod["main"], target="cuda") - # dev = tvm.device("cuda", 0) - # a_np = np.random.uniform(size=(N, K)).astype("float16") - # b_np = np.random.uniform(size=(K, M)).astype("float16") - # c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) - # a = tvm.nd.array(a_np, dev) - # b = tvm.nd.array(b_np, dev) - # c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) - # f(a, b, c) - # # print(f.imported_modules[0].get_source()) - # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) - # print("ok") + dev = tvm.device("cuda", 0) + a_np = np.random.uniform(size=(N, K)).astype("float16") + b_np = np.random.uniform(size=(K, M)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) + f(a, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) if __name__ == "__main__": From c403252d64fdb536d77b0aa54331aa5962cfb32b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 27 May 2022 17:35:11 +0900 Subject: [PATCH 7/8] black --- python/tvm/tir/tensor_intrin/cuda.py | 15 ++++++++++++--- ...test_tir_transform_inject_software_pipeline.py | 3 +-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index cbf3ba0c691e..c5883fd072c5 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -115,7 +115,12 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"): @T.prim_func def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: shared = T.match_buffer( - shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope=shared_scope, + shared_handle, + shmem_shape, + dtype, + align=128, + offset_factor=16, + scope=shared_scope, ) warp = T.match_buffer( warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp" @@ -413,10 +418,14 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False)) LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn" -TensorIntrin.register(LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn")) +TensorIntrin.register( + LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn") +) LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn" -TensorIntrin.register(LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn")) +TensorIntrin.register( + LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn") +) LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans" TensorIntrin.register( diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index cb7efd943808..f89dbacee33d 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1085,5 +1085,4 @@ def index_map(i, j): if __name__ == "__main__": - # tvm.testing.main() - test_three_stage_gemm() + tvm.testing.main() From 853a128894847719d4decbc9d111dd760bc130f3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 28 May 2022 04:59:36 +0900 Subject: [PATCH 8/8] require ampere in test --- python/tvm/testing/tir.py | 2 +- ..._tir_transform_inject_software_pipeline.py | 29 ++++++++++++------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index 53e082c9ec8e..8dd482673829 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -79,7 +79,7 @@ def mma_schedule( mma_store_intrin, shared_scope="shared", ): - """Create a tensorized schedule for 4k GEMM with MMA intrinsics.""" + """Create a tensorized schedule for GEMM with MMA intrinsics.""" ir_module = tvm.IRModule({"main": workload}) sch = tvm.tir.Schedule(ir_module) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index f89dbacee33d..fddda05eb5b0 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1034,10 +1034,16 @@ def test_error_missing_annotation(): _check_error(simple_compute_missing_annotation) +@tvm.testing.requires_cuda def test_three_stage_gemm(): N = K = M = 4096 i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1] + def is_ampere_or_newer(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + return major >= 8 + def index_map(i, j): return ( i // 16, @@ -1071,17 +1077,18 @@ def index_map(i, j): sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - f = tvm.build(sch.mod["main"], target="cuda") - - dev = tvm.device("cuda", 0) - a_np = np.random.uniform(size=(N, K)).astype("float16") - b_np = np.random.uniform(size=(K, M)).astype("float16") - c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(b_np, dev) - c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) - f(a, b, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + if is_ampere_or_newer(): + f = tvm.build(sch.mod["main"], target="cuda") + + dev = tvm.device("cuda", 0) + a_np = np.random.uniform(size=(N, K)).astype("float16") + b_np = np.random.uniform(size=(K, M)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((N, M), dtype="float32"), dev) + f(a, b, c) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) if __name__ == "__main__":