From 832ffdcd32fc7f67ca2c2526256be5f27819111e Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Fri, 17 Feb 2023 17:14:24 +0530 Subject: [PATCH 1/3] [TIR] [Bugfix] Pass the correct block_sref_reuse to Replace A mismatch between the blocks present in the `result` vs the blocks passed in `block_sref_to_reuse` caused the bug mentioned in #13974. This patch tries to fix that bug by collecting only the blocks that are part of result and also present in the block replacement map `new_block_to_old_`. Since the scope block is `result`, only that block and its child blocks would be replaced, and any replaced block would be present in `rewriter.new_block_to_old_`. Thus, collecting the replaced blocks from among child blocks of `result` guarantees that the `block_sref_reuse` would contain all the replaced blocks and that they'll point to the correct block in `result` thus avoiding the missing SRef error. --- .../primitive/layout_transformation.cc | 49 ++++++++++++++----- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 742384fc798f..cffa563af9cb 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -704,6 +704,42 @@ class TransformLayoutPlanner : private StmtExprVisitor { Buffer old_buffer_; }; +/** + * @brief Collect blocks that are part of root block to be passed to ScheduleState::Replace for SRef + * reuse + */ +class ReuseBlocksCollector : public tir::StmtVisitor { + public: + static Map Collect(Block result, Map new_block_to_old) { + return ReuseBlocksCollector(new_block_to_old).Run(result); + } + + private: + /*! \brief Entry point */ + Map Run(const Block result) { + VisitStmt(result); + return block_sref_reuse_; + } + /*! \brief Constructor */ + explicit ReuseBlocksCollector(Map new_block_to_old) + : new_block_to_old_(new_block_to_old) {} + + /*! \brief Override the Stmt visiting behaviour */ + void VisitStmt_(const tir::BlockNode* block) override { + Block block_ref = GetRef(block); + auto it = new_block_to_old_.find(block_ref); + if (it != new_block_to_old_.end()) { + block_sref_reuse_.Set((*it).second, (*it).first); + } + StmtVisitor::VisitStmt_(block); + } + + /*! \brief New map to be filled with just blocks from scope block */ + Map block_sref_reuse_; + /*! \brief All block replacements collected so far */ + Map new_block_to_old_; +}; + class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { public: /*! @@ -730,17 +766,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body}); } - Map block_sref_reuse; - for (auto [after, before] : rewriter.new_block_to_old_) { - while (auto opt = rewriter.new_block_to_old_.Get(before)) { - before = opt.value(); - } - while (auto opt = block_sref_reuse.Get(after)) { - after = opt.value(); - } - - block_sref_reuse.Set(before, after); - } + Map block_sref_reuse = + ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_); return {result, block_sref_reuse}; } From 47f9a933ee260f9006ea91ba19175a76cf9eab6d Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Fri, 17 Feb 2023 20:47:19 +0530 Subject: [PATCH 2/3] Add regression test --- .../test_tir_schedule_transform_layout.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index ace2b58acb0b..d866de33f100 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -173,6 +173,57 @@ def two_elementwise_unit_dim(A: T.Buffer((1, 128), "float32"), C: T.Buffer((1, 1 vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 +class TestTransformLayoutWithCacheWriteAndAxisSeparators(tvm.testing.CompareBeforeAfter): + """ + transform_layout with axis_separator on a buffer from cache_write should work as expected + """ + + @pytest.fixture + def transform(self): + def transform(mod): + + def transform_fn(x, y): + return [x // 32, y, tvm.te.AXIS_SEPARATOR, x % 32] + + sch = tvm.tir.Schedule(mod, debug_mask="all") + block_rv = sch.get_block("T_add") + sch.cache_write(block_rv, 0, "global") + sch.transform_layout(block_rv, ("write", 0), transform_fn, pad_value=0.0) + return sch.mod + + return transform + + def before( + p0: T.Buffer((T.int64(33), T.int64(128)), "float32"), + p1: T.Buffer((T.int64(33), T.int64(128)), "float32"), + T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(33), T.int64(128)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(p0[v_ax0, v_ax1], p1[v_ax0, v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = p0[v_ax0, v_ax1] + p1[v_ax0, v_ax1] + + def expected(p0: T.Buffer((T.int64(33), T.int64(128)), "float32"), p1: T.Buffer((T.int64(33), T.int64(128)), "float32"), T_add: T.Buffer((T.int64(33), T.int64(128)), "float32")): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # with T.block("root"): + T_add_global = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(32)), axis_separators=[2]) + for axis0, axis1, axis2 in T.grid(T.int64(2), T.int64(128), T.int64(32)): + with T.block("T_add"): + v_axis0, v_axis1, v_axis2 = T.axis.remap("SSS", [axis0, axis1, axis2]) + T.reads(p0[v_axis0 * T.int64(32) + v_axis2, v_axis1], p1[v_axis0 * T.int64(32) + v_axis2, v_axis1]) + T.writes(T_add_global[v_axis0, v_axis1, v_axis2]) + T_add_global[v_axis0, v_axis1, v_axis2] = T.if_then_else(v_axis0 == T.int64(1) and T.int64(1) <= v_axis2, T.float32(0), p0[v_axis0 * T.int64(32) + v_axis2, v_axis1] + p1[v_axis0 * T.int64(32) + v_axis2, v_axis1]) + for ax0, ax1 in T.grid(T.int64(33), T.int64(128)): + with T.block("T_add_global"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)]) + T.writes(T_add[v0, v1]) + T_add[v0, v1] = T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)] + # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on From 5c09f8cbbc92f396f6460f5f70d75efe234e1ece Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Fri, 17 Feb 2023 20:57:00 +0530 Subject: [PATCH 3/3] Fix doxygen comment format --- src/tir/schedule/primitive/layout_transformation.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index cffa563af9cb..0e993d06dcf1 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -704,8 +704,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { Buffer old_buffer_; }; -/** - * @brief Collect blocks that are part of root block to be passed to ScheduleState::Replace for SRef +/*! + * \brief Collect blocks that are part of root block to be passed to ScheduleState::Replace for SRef * reuse */ class ReuseBlocksCollector : public tir::StmtVisitor {