Skip to content

Commit

Permalink
[TIR] [Bugfix] Pass the correct block_sref_reuse to Replace
Browse files Browse the repository at this point in the history
A mismatch between the blocks present in the `result` vs the blocks
passed in block_sref_to_reuse caused the bug mentioned in #13974. This
patch tries to fix that bug by collecting only the blocks that are part
of result and also present in the block replacement map
new_block_to_old_
  • Loading branch information
quic-sanirudh committed Feb 17, 2023
1 parent d7253fb commit fbb30f0
Showing 1 changed file with 38 additions and 11 deletions.
49 changes: 38 additions & 11 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,42 @@ class TransformLayoutPlanner : private StmtExprVisitor {
Buffer old_buffer_;
};

/**
* @brief Collect blocks that are part of root block to be passed to ScheduleState::Replace for SRef
* reuse
*/
class ReuseBlocksCollector : public tir::StmtVisitor {
public:
static Map<Block, Block> Collect(Block result, Map<Block, Block> new_block_to_old) {
return ReuseBlocksCollector(new_block_to_old).Run(result);
}

private:
/*! \brief Entry point */
Map<Block, Block> Run(const Block result) {
VisitStmt(result);
return block_sref_reuse_;
}
/*! \brief Constructor */
explicit ReuseBlocksCollector(Map<Block, Block> new_block_to_old)
: new_block_to_old_(new_block_to_old) {}

/*! \brief Override the Stmt visiting behaviour */
void VisitStmt_(const tir::BlockNode* block) override {
Block block_ref = GetRef<Block>(block);
auto it = new_block_to_old_.find(block_ref);
if (it != new_block_to_old_.end()) {
block_sref_reuse_.Set((*it).second, (*it).first);
}
StmtVisitor::VisitStmt_(block);
}

/*! \brief New map to be filled with just blocks from scope block */
Map<Block, Block> block_sref_reuse_;
/*! \brief All block replacements collected so far */
Map<Block, Block> new_block_to_old_;
};

class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
public:
/*!
Expand All @@ -730,17 +766,8 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
}

Map<Block, Block> block_sref_reuse;
for (auto [after, before] : rewriter.new_block_to_old_) {
while (auto opt = rewriter.new_block_to_old_.Get(before)) {
before = opt.value();
}
while (auto opt = block_sref_reuse.Get(after)) {
after = opt.value();
}

block_sref_reuse.Set(before, after);
}
Map<Block, Block> block_sref_reuse =
ReuseBlocksCollector::Collect(result, rewriter.new_block_to_old_);

return {result, block_sref_reuse};
}
Expand Down

0 comments on commit fbb30f0

Please sign in to comment.