Skip to content

Commit

Permalink
[TIR] [Bugfix] Pass the correct block_sref_reuse to Replace (#14023)
Browse files Browse the repository at this point in the history
* [TIR] [Bugfix] Pass the correct block_sref_reuse to Replace

A mismatch between the blocks present in the `result` vs the blocks
passed in `block_sref_to_reuse` caused the bug mentioned in #13974.

This patch tries to fix that bug by collecting only the blocks that are
part of result and also present in the block replacement map
`new_block_to_old_`. Since the scope block is `result`, only that block
and its child blocks would be replaced, and any replaced block would be
present in `rewriter.new_block_to_old_`. Thus, collecting the replaced
blocks from among child blocks of `result` guarantees that the
`block_sref_reuse` would contain all the replaced blocks and that
they'll point to the correct block in `result` thus avoiding the missing
SRef error.
  • Loading branch information
quic-sanirudh authored Feb 18, 2023
1 parent 14bc5e4 commit 6f232f9
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 11 deletions.
49 changes: 38 additions & 11 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,42 @@ class TransformLayoutPlanner : private StmtExprVisitor {
Buffer old_buffer_;
};

/*!
* \brief Collect blocks that are part of root block to be passed to ScheduleState::Replace for SRef
* reuse
*/
class ReuseBlocksCollector : public tir::StmtVisitor {
public:
static Map<Block, Block> Collect(Block result, Map<Block, Block> new_block_to_old) {
return ReuseBlocksCollector(new_block_to_old).Run(result);
}

private:
/*! \brief Entry point */
Map<Block, Block> Run(const Block result) {
VisitStmt(result);
return block_sref_reuse_;
}
/*! \brief Constructor */
explicit ReuseBlocksCollector(Map<Block, Block> new_block_to_old)
: new_block_to_old_(new_block_to_old) {}

/*! \brief Override the Stmt visiting behaviour */
void VisitStmt_(const tir::BlockNode* block) override {
Block block_ref = GetRef<Block>(block);
auto it = new_block_to_old_.find(block_ref);
if (it != new_block_to_old_.end()) {
block_sref_reuse_.Set((*it).second, (*it).first);
}
StmtVisitor::VisitStmt_(block);
}

/*! \brief New map to be filled with just blocks from scope block */
Map<Block, Block> block_sref_reuse_;
/*! \brief All block replacements collected so far */
Map<Block, Block> new_block_to_old_;
};

class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
public:
/*!
Expand All @@ -730,17 +766,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
}

Map<Block, Block> block_sref_reuse;
for (auto [after, before] : rewriter.new_block_to_old_) {
while (auto opt = rewriter.new_block_to_old_.Get(before)) {
before = opt.value();
}
while (auto opt = block_sref_reuse.Get(after)) {
after = opt.value();
}

block_sref_reuse.Set(before, after);
}
Map<Block, Block> block_sref_reuse =
ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_);

return {result, block_sref_reuse};
}
Expand Down
51 changes: 51 additions & 0 deletions tests/python/unittest/test_tir_schedule_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,57 @@ def two_elementwise_unit_dim(A: T.Buffer((1, 128), "float32"), C: T.Buffer((1, 1
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0

class TestTransformLayoutWithCacheWriteAndAxisSeparators(tvm.testing.CompareBeforeAfter):
"""
transform_layout with axis_separator on a buffer from cache_write should work as expected
"""

@pytest.fixture
def transform(self):
def transform(mod):

def transform_fn(x, y):
return [x // 32, y, tvm.te.AXIS_SEPARATOR, x % 32]

sch = tvm.tir.Schedule(mod, debug_mask="all")
block_rv = sch.get_block("T_add")
sch.cache_write(block_rv, 0, "global")
sch.transform_layout(block_rv, ("write", 0), transform_fn, pad_value=0.0)
return sch.mod

return transform

def before(
p0: T.Buffer((T.int64(33), T.int64(128)), "float32"),
p1: T.Buffer((T.int64(33), T.int64(128)), "float32"),
T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"),
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(p0[v_ax0, v_ax1], p1[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = p0[v_ax0, v_ax1] + p1[v_ax0, v_ax1]

def expected(p0: T.Buffer((T.int64(33), T.int64(128)), "float32"), p1: T.Buffer((T.int64(33), T.int64(128)), "float32"), T_add: T.Buffer((T.int64(33), T.int64(128)), "float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
# with T.block("root"):
T_add_global = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(32)), axis_separators=[2])
for axis0, axis1, axis2 in T.grid(T.int64(2), T.int64(128), T.int64(32)):
with T.block("T_add"):
v_axis0, v_axis1, v_axis2 = T.axis.remap("SSS", [axis0, axis1, axis2])
T.reads(p0[v_axis0 * T.int64(32) + v_axis2, v_axis1], p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
T.writes(T_add_global[v_axis0, v_axis1, v_axis2])
T_add_global[v_axis0, v_axis1, v_axis2] = T.if_then_else(v_axis0 == T.int64(1) and T.int64(1) <= v_axis2, T.float32(0), p0[v_axis0 * T.int64(32) + v_axis2, v_axis1] + p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
with T.block("T_add_global"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)])
T.writes(T_add[v0, v1])
T_add[v0, v1] = T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)]

# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
# fmt: on

Expand Down

0 comments on commit 6f232f9

Please sign in to comment.