From e7160d569a19aa00b0fd605abd970d0e9ed8b1d0 Mon Sep 17 00:00:00 2001 From: "yin.changsheng" Date: Mon, 5 Dec 2022 14:15:56 +0800 Subject: [PATCH] Add recursive on loop with marked kUnrolled (#13536) Currently, when a loop is marked kUnrolled, the LoopPartition pass returns directly without recursing into the loop body. This PR enhances the LoopPartition pass to continue recursing into loops marked kUnrolled. --- src/tir/transforms/loop_partition.cc | 3 +- .../test_tir_transform_loop_partition.py | 69 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 1d995ef26ed8..0d088526694d 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -597,7 +597,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim if (!opt_cond_value.has_value()) { if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ && analyzer_.CanProve(max - min > 0)) { - return For(var, min, max - min + 1, ForKind::kUnrolled, body); + auto new_body = VisitAndMutate(body); + return For(var, min, max - min + 1, ForKind::kUnrolled, new_body); } return Stmt(); } diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index fe48aa7d8fd4..7dd8e794103e 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -677,6 +677,75 @@ def partitioned_main( assert tvm.ir.structural_equal(mod["main"], partitioned_main) +def test_loop_partition_recursive_unroll_hint(): + @T.prim_func + def main(): + placeholder_0_dm = T.decl_buffer([1, 32, 32, 16], dtype="int8") + for i3_0 in T.serial(5, annotations={"pragma_loop_partition_hint": 1}): + for i2_0 in T.serial(2, annotations={"pragma_loop_partition_hint": 1}): + pad_temp = T.decl_buffer([1, 16, 16, 16], dtype="int8") + for ax0, ax1, ax2 in T.grid(16, 16, 16): + if ( 
+ 6 <= i2_0 * 4 + ax0 + and i2_0 * 4 + ax0 < 26 + and 6 <= i3_0 * 4 + ax1 + and i3_0 * 4 + ax1 < 26 + ): + pad_temp[ + 0, + i2_0 * 4 + ax0 - 6 + 6 - i2_0 * 4, + i3_0 * 4 + ax1 - 6 + 6 - i3_0 * 4, + ax2, + ] = placeholder_0_dm[ + 0, + i2_0 * 4 + ax0 - 6 - -6, + i3_0 * 4 + ax1 - 6 - -6, + ax2, + ] + + @T.prim_func + def partitioned_main(): + placeholder_0_dm = T.allocate([16384], "int8", "global") + placeholder_0_dm_1 = T.buffer_decl([16384], dtype="int8", data=placeholder_0_dm) + for i3_0 in T.unroll(2): + for i2_0 in T.unroll(2): + pad_temp = T.allocate([4096], "int8", "global") + pad_temp_1 = T.buffer_decl([4096], dtype="int8", data=pad_temp) + for ax0, ax1, ax2 in T.grid(16, 16, 16): + if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1: + pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + ] + for i2_0 in T.unroll(2): + pad_temp_2 = T.allocate([4096], "int8", "global") + pad_temp_3 = T.buffer_decl([4096], dtype="int8", data=pad_temp_2) + for ax0, ax1, ax2 in T.grid(16, 16, 16): + if 6 <= i2_0 * 4 + ax0: + pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128 + ] + for i3_0 in T.unroll(2): + for i2_0 in T.unroll(2): + pad_temp_4 = T.allocate([4096], "int8", "global") + pad_temp_5 = T.buffer_decl([4096], dtype="int8", data=pad_temp_4) + for ax0, ax1, ax2 in T.grid(16, 16, 16): + if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14: + pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192 + ] + + mod = partition_from_scheduled_tir( + main, + { + "tir.LoopPartition": { + "partition_const_loop": True, + "unroll_loop_with_partition_hint_no_interval": True, + } + }, + ) + assert tvm.ir.structural_equal(mod["main"], partitioned_main) + + def test_loop_partition_keep_loop_annotations(): @T.prim_func def before(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None: