Skip to content

Commit

Permalink
[MetaSchedule] Add a test case for padded conv2d in meta_schedule (#17171)
Browse files Browse the repository at this point in the history
### Bug Fix

In the `TileWithTensorIntrin` function, when the `allow_padding` parameter is enabled, the original implementation inlines all consumer blocks. This behavior can lead to incorrect inlining of output blocks, causing issues with block shapes and dependencies. To ensure correct inlining operations, only non-output consumer blocks should be inlined.
---------
Co-authored-by: yuxiyue <[email protected]>
  • Loading branch information
YXY-0922 authored Jul 22, 2024
1 parent e5bf56d commit 18ff9ff
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::Block
}
auto consumers = sch->GetConsumers(block_rv);
for (const auto& consumer : consumers) {
sch->ComputeInline(consumer);
auto sref = sch->GetSRef(consumer);
if (!tir::IsOutputBlock(sch->state(), sref, tir::GetScopeRoot(sch->state(), sref, true)))
sch->ComputeInline(consumer);
}
}
// Construct a mapping from tir loops back to LoopRVs
Expand Down
152 changes: 152 additions & 0 deletions tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,5 +1055,157 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer(
)


def test_padded_conv():
    """Design-space check for a conv2d whose reduction axis needs padding.

    The workload is a 7x7/stride-2 NHWC conv with only 3 input channels, so the
    flattened reduction extent (7*7*3 = 147) is not a multiple of the 16-wide
    wmma intrinsic and gets padded to 160 (see the ``v1 < 147`` / ``v0 < 147``
    predicates below).  Regression test for TileWithTensorIntrin with
    allow_padding: only non-output consumer blocks may be inlined.
    """
    # fmt: off
    # Expected sketch: structurally compared against the generated schedule, so
    # every token below must stay exactly as MetaSchedule emits it.
    @T.prim_func
    def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Staging buffers: shared-memory tiles for both operands (reduction dim
        # padded 147 -> 160), wmma fragments, and the accumulator/output cache.
        conv2d_nhwc_reindex_shared = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="shared")
        conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator")
        PadInput_reindex_pad_shared = T.alloc_buffer((12544, 160), "float16", scope="shared")
        weight_reindex_pad_shared = T.alloc_buffer((160, 64), "float16", scope="shared")
        PadInput_reindex_pad_shared_wmma_matrix_a = T.alloc_buffer((12544, 160), "float16", scope="wmma.matrix_a")
        weight_reindex_pad_shared_wmma_matrix_b = T.alloc_buffer((160, 64), "float16", scope="wmma.matrix_b")
        for ax0_0_0_ax1_0_0_fused in T.thread_binding(14, thread="blockIdx.y"):
            for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"):
                for ax0_0_2_ax1_0_2_fused in T.thread_binding(8, thread="threadIdx.y"):
                    for ax2_0_0 in range(10):
                        # Cooperative fetch of the (implicitly padded) input tile
                        # into shared memory; both spatial and channel padding are
                        # folded into the if_then_else predicate.
                        for ax0_ax1_fused in range(28672):
                            with T.block("PadInput_reindex_pad_shared"):
                                v0 = T.axis.spatial(12544, ax0_0_0_ax1_0_0_fused // 2 * 1792 + ax0_ax1_fused // 16)
                                v1 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused % 16)
                                T.reads(inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3])
                                T.writes(PadInput_reindex_pad_shared[v0, v1])
                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4})
                                PadInput_reindex_pad_shared[v0, v1] = T.if_then_else(v1 < 147, T.if_then_else(3 <= v0 // 112 * 2 + v1 // 21 and v0 // 112 * 2 + v1 // 21 < 227 and 3 <= v0 % 112 * 2 + v1 % 21 // 3 and v0 % 112 * 2 + v1 % 21 // 3 < 227, inputs[0, v0 // 112 * 2 + v1 // 21 - 3, v0 % 112 * 2 + v1 % 21 // 3 - 3, v1 % 3], T.float16(0)), T.float16(0))
                        # Cooperative fetch of the weight tile (zero-filled past
                        # the real reduction extent 147).
                        for ax0_ax1_fused in range(512):
                            with T.block("weight_reindex_pad_shared"):
                                v0 = T.axis.spatial(160, ax2_0_0 * 16 + ax0_ax1_fused // 32)
                                v1 = T.axis.spatial(64, ax0_0_0_ax1_0_0_fused % 2 * 32 + ax0_ax1_fused % 32)
                                T.reads(weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1])
                                T.writes(weight_reindex_pad_shared[v0, v1])
                                T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2})
                                weight_reindex_pad_shared[v0, v1] = T.if_then_else(v0 < 147, weight[v0 // 21, v0 % 21 // 3, v0 % 3, v1], T.float16(0))
                        for ax2_0_1 in range(1):
                            # Load shared-memory tiles into wmma matrix_a fragments.
                            for ax0_0, ax1_0 in T.grid(14, 1):
                                with T.block("PadInput_reindex_pad_shared_wmma.matrix_a_o"):
                                    v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0)
                                    v1_o = T.axis.spatial(10, ax2_0_0 + ax1_0)
                                    T.reads(PadInput_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"})
                                    for ax0_1, ax1_1 in T.grid(16, 16):
                                        with T.block("PadInput_reindex_pad_shared_wmma.matrix_a"):
                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
                                            T.reads(PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                            T.writes(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                            PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                            # Load weight tiles into wmma matrix_b fragments.
                            for ax0_0, ax1_0 in T.grid(1, 2):
                                with T.block("weight_reindex_pad_shared_wmma.matrix_b_o"):
                                    v0_o = T.axis.spatial(10, ax2_0_0 + ax0_0)
                                    v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0)
                                    T.reads(weight_reindex_pad_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"})
                                    for ax0_1, ax1_1 in T.grid(16, 16):
                                        with T.block("weight_reindex_pad_shared_wmma.matrix_b"):
                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
                                            T.reads(weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                            T.writes(weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
                                            weight_reindex_pad_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_pad_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
                            # Tensorized main compute: wmma fill (init) + sync (update).
                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(7, 2, 1, 2, 1):
                                with T.block("conv2d_nhwc_o"):
                                    v0_o = T.axis.spatial(784, ax0_0_0_ax1_0_0_fused // 2 * 112 + ax0_0_2_ax1_0_2_fused * 14 + ax0_0_3 * 2 + ax0_0_4)
                                    v1_o = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused % 2 * 2 + ax1_0_3 + ax1_0_4)
                                    v2_o = T.axis.reduce(10, ax2_0_0 + ax2_0_1 + ax2_0_2)
                                    T.reads(PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16])
                                    T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, 0:16, 0:16])
                                    T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1})
                                    with T.init():
                                        for ax0_1, ax1_1 in T.grid(16, 16):
                                            with T.block("conv2d_nhwc_init"):
                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
                                                T.reads()
                                                T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init])
                                                conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0)
                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
                                        with T.block("conv2d_nhwc"):
                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
                                            T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i], PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
                                            T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i])
                                            T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"})
                                            conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o // 14, v1_o // 2, v0_o % 14, v1_o % 2, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_pad_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_pad_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
                # Epilogue: accumulator -> shared -> global.  NOTE(review): this
                # loop nest rebinds threadIdx.y, so it sits as a sibling of the
                # main threadIdx.y loop above — indentation reconstructed from the
                # scraped (whitespace-stripped) diff; confirm against upstream.
                for ax2 in range(14):
                    for ax0_ax1_fused in T.thread_binding(8, thread="threadIdx.y"):
                        for ax2_1, ax3 in T.grid(1, 2):
                            with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"):
                                v0_o = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_fused)
                                v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
                                v2_o = T.axis.spatial(14, ax2 + ax2_1)
                                v3_o = T.axis.spatial(2, ax3)
                                v4_o = T.axis.spatial(1, 0)
                                v5_o = T.axis.spatial(1, 0)
                                T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
                                T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16])
                                T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"})
                                for ax4, ax5 in T.grid(16, 16):
                                    with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
                                        v4_i, v5_i = T.axis.remap("SS", [ax4, ax5])
                                        T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                        T.writes(conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i])
                                        conv2d_nhwc_reindex_shared[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, v4_i, v5_i]
                    # Final copy from the shared cache to the global output buffer.
                    for ax0_ax1_ax3_ax4_ax5_fused in range(4096):
                        with T.block("conv2d_nhwc_reindex_shared"):
                            v0 = T.axis.spatial(56, ax0_0_0_ax1_0_0_fused // 2 * 8 + ax0_ax1_ax3_ax4_ax5_fused // 512)
                            v1 = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused % 2)
                            v2 = T.axis.spatial(14, ax2)
                            v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused % 512 // 256)
                            v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
                            v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
                            T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5])
                            T.writes(conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32])
                            T.block_attr({"meta_schedule.cooperative_fetch": 3})
                            conv2d_nhwc[0, (v4 + v2 * 16 + v0 * 224) // 112, (v4 + v2 * 16 + v0 * 224) % 112, v5 + v3 * 16 + v1 * 32] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
    # fmt: on

    # Sampling decisions that must accompany the expected sketch above
    # (tile splits for the two spatial axes and the reduction axis, followed
    # by three categorical picks — presumably cooperative-fetch vector
    # lengths; check_sketches verifies them pairwise with expected_mods).
    decision_0 = [
        ("SamplePerfectTile", [7, 1, 8, 7, 2]),
        ("SamplePerfectTile", [2, 1, 1, 2, 1]),
        ("SamplePerfectTile", [10, 1, 1]),
        ("SampleCategorical", 2),
        ("SampleCategorical", 2),
        ("SampleCategorical", 1),
    ]
    # Workload: N=1, H=W=224, C_in=3, C_out=64, kernel=7, stride=2, pad=3,
    # fp16 inputs accumulating into fp32.
    mod = te.create_prim_func(
        te_workload.conv2d_nhwc(
            1,
            224,
            224,
            3,
            64,
            7,
            2,
            3,
            in_dtype="float16",
            out_dtype="float32",
        )
    )
    # Generate the design space with the tensor-core multi-level-tiling rule
    # (shared-memory write reuse) plus the standard CUDA auto-inline rule.
    actual = generate_design_space(
        kind="cuda",
        mod=mod,
        target=tvm.target.Target("cuda --arch=sm_70"),
        types=None,
        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
        + get_rules("cuda", ms.schedule_rule.AutoInline),
    )
    check_sketches(
        mod,
        sketches=actual,
        expected_mods=[padded_conv2d_0],
        expected_decisions=[decision_0],
    )


# Allow running this test module directly; tvm.testing.main() dispatches to
# pytest over the test functions defined in this file.
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit 18ff9ff

Please sign in to comment.