Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] [Bugfix] Pass the correct block_sref_reuse to Replace #14023

Merged
merged 3 commits into from
Feb 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,42 @@ class TransformLayoutPlanner : private StmtExprVisitor {
Buffer old_buffer_;
};

/*!
 * \brief Collects the subset of rewritten blocks that actually appear under the
 *        transformed root block, producing the sref-reuse map that must be handed
 *        to ScheduleState::Replace.
 */
class ReuseBlocksCollector : public tir::StmtVisitor {
 public:
  /*!
   * \brief Scan `result` and build the sref-reuse map restricted to blocks found in it.
   * \param result The transformed root block to scan.
   * \param new_block_to_old Mapping from each rewritten block to the block it replaced.
   * \return Map from old block to new block, containing only blocks present in `result`.
   */
  static Map<Block, Block> Collect(Block result, Map<Block, Block> new_block_to_old) {
    ReuseBlocksCollector collector(new_block_to_old);
    collector.VisitStmt(result);
    return collector.block_sref_reuse_;
  }

 private:
  /*! \brief Constructor */
  explicit ReuseBlocksCollector(Map<Block, Block> new_block_to_old)
      : new_block_to_old_(new_block_to_old) {}

  // Visit every block under the root; record a reuse entry for each one that
  // was produced by the rewrite.
  void VisitStmt_(const tir::BlockNode* block) override {
    Block current = GetRef<Block>(block);
    if (auto it = new_block_to_old_.find(current); it != new_block_to_old_.end()) {
      // new_block_to_old_ maps new -> old; the reuse map is keyed old -> new.
      block_sref_reuse_.Set((*it).second, (*it).first);
    }
    StmtVisitor::VisitStmt_(block);
  }

  /*! \brief Result map (old block -> new block), restricted to blocks in the root. */
  Map<Block, Block> block_sref_reuse_;
  /*! \brief All block replacements produced by the rewrite (new -> old). */
  Map<Block, Block> new_block_to_old_;
};

class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
public:
/*!
Expand All @@ -730,17 +766,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
}

Map<Block, Block> block_sref_reuse;
for (auto [after, before] : rewriter.new_block_to_old_) {
while (auto opt = rewriter.new_block_to_old_.Get(before)) {
before = opt.value();
}
while (auto opt = block_sref_reuse.Get(after)) {
after = opt.value();
}

block_sref_reuse.Set(before, after);
}
Map<Block, Block> block_sref_reuse =
ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_);

return {result, block_sref_reuse};
}
Expand Down
51 changes: 51 additions & 0 deletions tests/python/unittest/test_tir_schedule_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,57 @@ def two_elementwise_unit_dim(A: T.Buffer((1, 128), "float32"), C: T.Buffer((1, 1
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0

class TestTransformLayoutWithCacheWriteAndAxisSeparators(tvm.testing.CompareBeforeAfter):
    """
    transform_layout with axis_separator on a buffer from cache_write should work as expected

    The schedule rewrites the write buffer of "T_add" after a cache_write, so the
    block being transformed is re-created; `expected` must match the scheduled
    module structurally for the comparison to pass.
    """

    @pytest.fixture
    def transform(self):
        # Fixture returning the schedule transformation that CompareBeforeAfter
        # applies to `before` and checks against `expected`.
        def transform(mod):

            # Index map (x, y) -> (x // 32, y, AXIS_SEPARATOR, x % 32): splits the
            # first axis by 32 and places an axis separator before the inner part.
            def transform_fn(x, y):
                return [x // 32, y, tvm.te.AXIS_SEPARATOR, x % 32]

            # debug_mask="all" enables schedule-state verification on every step.
            sch = tvm.tir.Schedule(mod, debug_mask="all")
            block_rv = sch.get_block("T_add")
            sch.cache_write(block_rv, 0, "global")
            # pad_value=0.0 pads the 33-element axis up to the 2*32 transformed extent.
            sch.transform_layout(block_rv, ("write", 0), transform_fn, pad_value=0.0)
            return sch.mod

        return transform

    # Input module: element-wise add over a (33, 128) buffer. Defined without
    # `self` per the CompareBeforeAfter convention (parsed as TVMScript).
    def before(
        p0: T.Buffer((T.int64(33), T.int64(128)), "float32"),
        p1: T.Buffer((T.int64(33), T.int64(128)), "float32"),
        T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"),
    ):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(p0[v_ax0, v_ax1], p1[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = p0[v_ax0, v_ax1] + p1[v_ax0, v_ax1]

    # Expected result: the cache buffer T_add_global gets the transformed layout
    # (2, 128, 32) with axis_separators=[2]; padded elements are written as 0,
    # and a copy-back loop restores the original (33, 128) layout.
    def expected(p0: T.Buffer((T.int64(33), T.int64(128)), "float32"), p1: T.Buffer((T.int64(33), T.int64(128)), "float32"), T_add: T.Buffer((T.int64(33), T.int64(128)), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        T_add_global = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(32)), axis_separators=[2])
        for axis0, axis1, axis2 in T.grid(T.int64(2), T.int64(128), T.int64(32)):
            with T.block("T_add"):
                v_axis0, v_axis1, v_axis2 = T.axis.remap("SSS", [axis0, axis1, axis2])
                T.reads(p0[v_axis0 * T.int64(32) + v_axis2, v_axis1], p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
                T.writes(T_add_global[v_axis0, v_axis1, v_axis2])
                T_add_global[v_axis0, v_axis1, v_axis2] = T.if_then_else(v_axis0 == T.int64(1) and T.int64(1) <= v_axis2, T.float32(0), p0[v_axis0 * T.int64(32) + v_axis2, v_axis1] + p1[v_axis0 * T.int64(32) + v_axis2, v_axis1])
        for ax0, ax1 in T.grid(T.int64(33), T.int64(128)):
            with T.block("T_add_global"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)])
                T.writes(T_add[v0, v1])
                T_add[v0, v1] = T_add_global[v0 // T.int64(32), v1, v0 % T.int64(32)]

# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
# fmt: on

Expand Down