diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index fec214fa1fc7..c644fbecdf5c 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -326,23 +326,62 @@ Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block
   if (!opt_tensorize_info) return NullOpt;
   const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
   if (info->block_iter_paddings.defined()) {
+    // We have to track whether each producer or consumer is padded.
+    // To do so, we first record all the original producer and consumer blocks.
+    std::unordered_set<const tir::StmtSRefNode*> original_producers, original_consumers;
+    {
+      for (const auto& p : GetProducers(sch->state(), sch->GetSRef(block_rv)))
+        original_producers.insert(p.get());
+      for (const auto& c : GetConsumers(sch->state(), sch->GetSRef(block_rv)))
+        original_consumers.insert(c.get());
+    }
+
+    // Pad. Maybe we can make PadEinsum return the changes it made, to avoid bookkeeping?
     sch->PadEinsum(block_rv, info->block_iter_paddings.value());
+
+    // Now we need to find out all the padded blocks.
+    Array<tir::BlockRV> inlined_producers, inlined_consumers;
+    for (const auto& producer : sch->GetProducers(block_rv)) {
+      // PadEinsum will not modify the producer if it does not need padding.
+      if (original_producers.count(sch->GetSRef(producer).get())) {
+        // Producer not padded. No inlining.
+        continue;
+      }
+      auto the_original_producers = sch->GetProducers(producer);
+      if (the_original_producers.empty()) {
+        // The original producer is the function input.
+        continue;
+      }
+      ICHECK_EQ(the_original_producers.size(), 1u);
+      auto the_original_producer = the_original_producers[0];
+      ICHECK(original_producers.count(sch->GetSRef(the_original_producer).get()));
+      inlined_producers.push_back(the_original_producer);
+    }
+    for (const auto& consumer : sch->GetConsumers(block_rv)) {
+      // PadEinsum will not modify the consumer if it does not need padding.
+      if (original_consumers.count(sch->GetSRef(consumer).get())) {
+        // Consumer not padded. No inlining.
+        continue;
+      }
+      auto the_original_consumers = sch->GetConsumers(consumer);
+      if (the_original_consumers.empty()) {
+        // The original consumer is the function output.
+        continue;
+      }
+      ICHECK_EQ(the_original_consumers.size(), 1u);
+      auto the_original_consumer = the_original_consumers[0];
+      ICHECK(original_consumers.count(sch->GetSRef(the_original_consumer).get()));
+      inlined_consumers.push_back(consumer);
+    }
+
     // Inline the producer and consumer padding blocks
-    auto producers = sch->GetProducers(block_rv);
-    for (const auto& producer : producers) {
-      auto original_producers = sch->GetProducers(producer);
-      // NOTICE: there may not all producers padded.
+    for (const auto& the_original_producer : inlined_producers) {
       // Inline the original producer into the padding block. This ensures that the new producer
       // has the padded shape.
-      if (original_producers.size() == 1u) {
-        sch->ComputeInline(original_producers[0]);
-      }
+      sch->ComputeInline(the_original_producer);
     }
-    auto consumers = sch->GetConsumers(block_rv);
-    for (const auto& consumer : consumers) {
-      auto sref = sch->GetSRef(consumer);
-      if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
-        sch->ComputeInline(consumer);
+    for (const auto& consumer : inlined_consumers) {
+      sch->ComputeInline(consumer);
     }
   }
   // Construct a mapping from tir loops back to LoopRVs
diff --git a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
index 1fd2ab84749e..be936e6e84fb 100644
--- a/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -1207,5 +1207,300 @@ def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buf
     )
+
+
+def test_padded_matmul_single_padded_input():
+    # fmt: off
+    @T.prim_func
+    def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), "float16"), B: T.Buffer((4096, 1024), "float16"), C: T.Buffer((1023, 1024), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        C_reindex_pad_shared = T.alloc_buffer((8, 32, 8, 2, 16, 16), scope="shared")
+        C_reindex_pad_shared_wmma_accumulator = T.alloc_buffer((8, 32, 8, 2, 16, 16), scope="wmma.accumulator")
+        A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer((4096, 1024), "float16", scope="shared")
+        A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096), "float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024), "float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(32, thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"):
+                    for ax2_0_0 in range(32):
+                        for ax0_ax1_fused in range(65536):
+                            with T.block("A_reindex_pad_shared"):
+                                v0 = T.axis.spatial(1024, ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_fused // 128)
+                                v1 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused % 128)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
+                                A_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 1023, A[v0, v1], T.float16(0.0))
+                        for ax0_ax1_fused in range(8192):
+                            with T.block("B_reindex_shared"):
+                                v0 = T.axis.spatial(4096, ax2_0_0 * 128 + ax0_ax1_fused // 64)
+                                v1 = T.axis.spatial(1024, ax0_0_1_ax1_0_1_fused % 16 * 64 + ax0_ax1_fused % 64)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1})
+                                B_reindex_shared[v0, v1] = B[v0, v1]
+                        for ax2_0_1 in range(8):
+                            for ax0_0, ax1_0 in T.grid(8, 1):
+                                with T.block("A_reindex_pad_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0)
+                                    v1_o = T.axis.spatial(256, ax2_0_0 * 8 + ax2_0_1 + ax1_0)
+                                    T.reads(A_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("A_reindex_pad_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(1, 2):
+                                with T.block("B_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(256, ax2_0_0 * 8 + ax2_0_1 + ax0_0)
+                                    v1_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0)
+                                    T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("B_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 1, 1, 4, 2):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused // 16 * 32 + ax0_0_2_ax1_0_2_fused // 2 * 8 + ax0_0_3 * 4 + ax0_0_4)
+                                    v1_o = T.axis.spatial(64, ax0_0_1_ax1_0_1_fused % 16 * 4 + ax0_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3 * 2 + ax1_0_4)
+                                    v2_o = T.axis.reduce(256, ax2_0_0 * 8 + ax2_0_1 + ax2_0_2)
+                                    T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, 0:16, 0:16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i_init, v1_i_init])
+                                                C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0.0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            T.writes(C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i])
+                                            T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i] = C_reindex_pad_shared_wmma_accumulator[v0_o // 8, v1_o // 2, v0_o % 8, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                for ax2 in range(8):
+                    for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 2):
+                            with T.block("C_reindex_pad_shared_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 * 4 + ax0_ax1_fused // 2)
+                                v1_o = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_fused % 2)
+                                v2_o = T.axis.spatial(8, ax2 + ax2_1)
+                                v3_o = T.axis.spatial(2, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
+                                T.writes(C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with T.block("C_reindex_pad_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
+                                        T.reads(C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        T.writes(C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        C_reindex_pad_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = C_reindex_pad_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+                        with T.block("C_reindex_pad_shared"):
+                            v0 = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 16 * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024)
+                            v1 = T.axis.spatial(32, ax0_0_1_ax1_0_1_fused % 16 * 2 + ax0_ax1_ax3_ax4_ax5_fused % 1024 // 512)
+                            v2 = T.axis.spatial(8, ax2)
+                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
+                            T.where(ax0_0_1_ax1_0_1_fused // 16 * 512 + ax0_ax1_ax3_ax4_ax5_fused // 1024 * 128 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 1023)
+                            T.reads(C_reindex_pad_shared[v0, v1, v2, v3, v4, v5])
+                            T.writes(C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 4})
+                            C[v4 + v2 * 16 + v0 * 128, v5 + v3 * 16 + v1 * 32] = C_reindex_pad_shared[v0, v1, v2, v3, v4, v5]
+    # fmt: on
+
+    decision_0 = [
+        ("SamplePerfectTile", [1, 2, 4, 2, 4]),
+        ("SamplePerfectTile", [1, 16, 2, 1, 2]),
+        ("SamplePerfectTile", [32, 8, 1]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 0),
+    ]
+    mod = te.create_prim_func(
+        te_workload.matmul(
+            n=1023,
+            m=1024,
+            k=4096,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = generate_design_space(
+        kind="cuda",
+        mod=mod,
+        target=tvm.target.Target("cuda --arch=sm_70"),
+        types=None,
+        sch_rules=[multi_level_tiling_tensor_core()]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    )
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[padded_matmul_single_padded_input_0],
+        expected_decisions=[decision_0],
+    )
+
+
+def test_padded_matmul_no_padded_output():
+    # fmt: off
+    @T.prim_func
+    def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), B: T.Buffer((4095, 1024), "float16"), C: T.Buffer((1024, 1024), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        C_reindex_shared = T.alloc_buffer((32, 16, 2, 4, 16, 16), scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer((32, 16, 2, 4, 16, 16), scope="wmma.accumulator")
+        A_reindex_pad_shared = T.alloc_buffer((1024, 4096), "float16", scope="shared")
+        B_reindex_pad_shared = T.alloc_buffer((4096, 1024), "float16", scope="shared")
+        A_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((1024, 4096), "float16", scope="wmma.matrix_a")
+        B_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((4096, 1024), "float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(64, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(4, thread="threadIdx.y"):
+                    for ax2_0_0 in range(128):
+                        for ax0_ax1_fused in range(4096):
+                            with T.block("A_reindex_pad_shared"):
+                                v0 = T.axis.spatial(1024, ax0_0_0_ax1_0_0_fused // 16 * 256 + ax0_0_1_ax1_0_1_fused * 128 + ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(4096, ax2_0_0 * 32 + ax0_ax1_fused % 32)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8})
+                                A_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 4095, A[v0, v1], T.float16(0.0))
+                        for ax0_ax1_fused in range(2048):
+                            with T.block("B_reindex_pad_shared"):
+                                v0 = T.axis.spatial(4096, ax2_0_0 * 32 + ax0_ax1_fused // 64)
+                                v1 = T.axis.spatial(1024, ax0_0_0_ax1_0_0_fused % 16 * 64 + ax0_ax1_fused % 64)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_pad_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1})
+                                B_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 4095, B[v0, v1], T.float16(0.0))
+                        for ax2_0_1 in range(2):
+                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                with T.block("A_reindex_pad_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + ax0_0_2_ax1_0_2_fused * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(256, ax2_0_0 * 2 + ax2_0_1 + ax1_0)
+                                    T.reads(A_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("A_reindex_pad_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(1, 4):
+                                with T.block("B_reindex_pad_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(256, ax2_0_0 * 2 + ax2_0_1 + ax0_0)
+                                    v1_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0)
+                                    T.reads(B_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("B_reindex_pad_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            B_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(2, 1, 1, 1, 4):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused // 16 * 16 + ax0_0_1_ax1_0_1_fused * 8 + ax0_0_2_ax1_0_2_fused * 2 + ax0_0_3 + ax0_0_4)
+                                    v1_o = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 16 * 4 + ax1_0_3 * 4 + ax1_0_4)
+                                    v2_o = T.axis.reduce(256, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2)
+                                    T.reads(A_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
+                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16])
+                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init])
+                                                C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0.0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i])
+                                            T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
+                                            C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                for ax2 in range(2):
+                    for ax0_ax1_fused in T.thread_binding(4, thread="threadIdx.y"):
+                        for ax2_1, ax3 in T.grid(1, 4):
+                            with T.block("C_reindex_shared_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_fused)
+                                v1_o = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16)
+                                v2_o = T.axis.spatial(2, ax2 + ax2_1)
+                                v3_o = T.axis.spatial(4, ax3)
+                                v4_o = T.axis.spatial(1, 0)
+                                v5_o = T.axis.spatial(1, 0)
+                                T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
+                                T.writes(C_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
+                                T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"})
+                                for ax4, ax5 in T.grid(16, 16):
+                                    with T.block("C_reindex_shared_wmma.accumulator"):
+                                        v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
+                                        T.reads(C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        T.writes(C_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
+                                        C_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
+                    for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
+                        with T.block("C_reindex_shared"):
+                            v0 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused // 16 * 8 + ax0_0_1_ax1_0_1_fused * 4 + ax0_ax1_ax3_ax4_ax5_fused // 1024)
+                            v1 = T.axis.spatial(16, ax0_0_0_ax1_0_0_fused % 16)
+                            v2 = T.axis.spatial(2, ax2)
+                            v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256)
+                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
+                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
+                            T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
+                            T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64])
+                            T.block_attr({"meta_schedule.cooperative_fetch": 3})
+                            C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = C_reindex_shared[v0, v1, v2, v3, v4, v5]
+    # fmt: on
+
+    decision_0 = [
+ ("SamplePerfectTile", [4, 2, 4, 2, 1]), + ("SamplePerfectTile", [16, 1, 1, 1, 4]), + ("SamplePerfectTile", [128, 2, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 3), + ("SampleCategorical", 0), + ] + mod = te.create_prim_func( + te_workload.matmul( + n=1024, + m=1024, + k=4095, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = generate_design_space( + kind="cuda", + mod=mod, + target=tvm.target.Target("cuda --arch=sm_70"), + types=None, + sch_rules=[multi_level_tiling_tensor_core()] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual, + expected_mods=[padded_matmul_no_padded_output_0], + expected_decisions=[decision_0], + ) + + if __name__ == "__main__": tvm.testing.main()