Skip to content

Commit

Permalink
[TIR] Encode conditional accesses info into block read/write regions (a…
Browse files Browse the repository at this point in the history
…pache#9880)

* encode conditional accesses info into block read/write regions

* compare ir after simplify
  • Loading branch information
wrongtest-intellif authored and crazydemo committed Jan 27, 2022
1 parent d42277d commit 6cce656
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 30 deletions.
29 changes: 23 additions & 6 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
private:
/*! \brief Iteration range for loop_vars */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
/*! \brief Extra iteration range hint for free vars */
std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
/*! \brief The buffers that the current block reads */
std::vector<Buffer> read_buffers_;
/*! \brief The buffers that the current block writes */
Expand Down Expand Up @@ -96,6 +98,9 @@ class BlockReadWriteDetector : public StmtExprVisitor {
/*! \brief Helper function to update a opaque access. */
void UpdateOpaque(const Var& buffer_var);

/*! \brief Helper function to relax the buffer indices */
arith::IntSet RelaxAccessIndex(const PrimExpr& index);

void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const IfThenElseNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
Expand Down Expand Up @@ -140,10 +145,22 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
ExprVisitor::VisitExpr_(op);
}

/*!
 * \brief Relax one buffer access index into an integer set.
 *
 * The index is first evaluated over the domains recorded in dom_map_. When
 * hint_map_ holds extra bounds, that set is evaluated again over the hints
 * and the two sets are intersected.
 *
 * \param index The index expression of a buffer access.
 * \return The relaxed integer set for the index.
 */
arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) {
  arith::IntSet index_set = arith::EvalSet(index, dom_map_);
  if (hint_map_.empty()) {
    // No extra bound hints are recorded, so the plain relaxation is the answer.
    return index_set;
  }
  // Apply the bound hints of vars that are not relaxed by dom_map_.
  // E.g. for i * 4 + j with i >= 10 and j in [0, 4), where only j is in the
  // domain scope, the index region becomes [i*4, i*4+4) ^ [40, inf).
  arith::IntSet hinted_set = arith::EvalSet(index_set, hint_map_);
  return arith::Intersect({index_set, hinted_set});
}

void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(index, dom_map_));
relaxed_region.push_back(RelaxAccessIndex(index));
}
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
ExprVisitor::VisitExpr_(op);
Expand All @@ -160,12 +177,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
VisitExpr(op->condition);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
}
}
Expand All @@ -175,12 +192,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
VisitExpr(op->args[0]);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
Expand All @@ -196,7 +213,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
std::vector<arith::IntSet> relaxed_region;
for (const PrimExpr& index : op->indices) {
relaxed_region.push_back(arith::EvalSet(index, dom_map_));
relaxed_region.push_back(RelaxAccessIndex(index));
}
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
StmtVisitor::VisitStmt_(op);
Expand Down
10 changes: 6 additions & 4 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr(op->condition);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
}
}
Expand All @@ -139,12 +139,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
StmtExprVisitor::VisitExpr(op->args[0]);
{
// Visit then branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitExpr(op->args[1]);
}
{
// Visit else branch
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitExpr(op->args[2]);
}
return;
Expand Down Expand Up @@ -282,6 +282,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {

/*! \brief The map from loop vars to their iter range. */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
/*! \brief Extra map from free vars to their iter range hints. */
std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
/*! \brief The analyzer aware of loop domains. */
arith::Analyzer dom_analyzer_;
/*! \brief The map from Buffer to its relaxed access set. */
Expand Down
62 changes: 49 additions & 13 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,18 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
Array<Var> vars = Array<Var>(var_set.begin(), var_set.end());
Map<Var, Range> ranges;
for (const Var& v : vars) {
auto it = dom_map_->find(v.get());
if (it != dom_map_->end()) {
const auto& int_set = it->second;
ranges.Set(v, Range::FromMinExtent(int_set.min(),
analyzer.Simplify(int_set.max() - int_set.min() + 1)));
arith::IntSet dom;
auto relax_it = relax_map_->find(v.get());
if (relax_it != relax_map_->end()) {
dom = relax_it->second;
} else {
auto hint_it = hint_map_->find(v.get());
if (hint_it != hint_map_->end()) {
dom = hint_it->second;
}
}
if (dom.defined()) {
ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
}
}
// solve constraints
Expand All @@ -314,24 +321,53 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
}

ConditionalBoundsContext::ConditionalBoundsContext(
const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
bool is_true_branch)
: condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {}
const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch)
: condition_(condition),
relax_map_(relax_map),
hint_map_(hint_map),
is_true_branch_(is_true_branch) {}

void ConditionalBoundsContext::EnterWithScope() {
for (const auto& p : GetVarBoundsFromCondition()) {
const auto* var = p.first.get();
auto it = dom_map_->find(var);
if (it != dom_map_->end()) {
origin_map_.emplace(var, it->second);
it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)});
arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
auto relax_it = relax_map_->find(var);
if (relax_it != relax_map_->end()) {
// this is a bound for relaxed var
origin_map_.emplace(var, relax_it->second);
relax_it->second = arith::Intersect({relax_it->second, new_dom});
} else {
// this is a bound for free var
auto hint_it = hint_map_->find(var);
if (hint_it != hint_map_->end()) {
origin_map_.emplace(var, hint_it->second);
hint_it->second = arith::Intersect({hint_it->second, new_dom});
} else {
origin_map_.emplace(var, arith::IntSet::Nothing());
hint_map_->insert(hint_it, {var, new_dom});
}
}
}
}

void ConditionalBoundsContext::ExitWithScope() {
for (const auto& p : origin_map_) {
(*dom_map_)[p.first] = p.second;
const auto* var = p.first;
auto relax_it = relax_map_->find(var);
if (relax_it != relax_map_->end()) {
// recover bound for relaxed var
relax_it->second = p.second;
} else {
// recover bound for free var
auto hint_it = hint_map_->find(var);
ICHECK(hint_it != hint_map_->end());
if (p.second.IsNothing()) {
hint_map_->erase(hint_it);
} else {
hint_it->second = p.second;
}
}
}
}

Expand Down
18 changes: 11 additions & 7 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,23 @@ Bool IsFromLegacyTESchedule(PrimFunc f);
*\brief Context helper to update domain map within conditional scope.
*
* Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is
*[0, 8]. Then `With<ConditionalBoundsContext> ctx(&dom_map, bounds, true)` step into scope where
*dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(&dom_map, bounds, false)` step into
*scope where dom_map[i] is [9, 20]
* [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` steps
* into a scope where relax_map[i] is [0, 8], and `With<ConditionalBoundsContext> ctx(condition,
* &relax_map, &hint_map, false)` steps into a scope where relax_map[i] is [9, 20].
*/
class ConditionalBoundsContext {
private:
friend class With<ConditionalBoundsContext>;
/*!
* \brief Construct a condition bounds context.
* \param condition The condition holds on true branch.
* \param dom_map The global domain map to be updated.
* \param relax_map The domain map for relaxed vars to update.
* \param hint_map The domain map for free vars to update.
* \param is_true_branch Whether step into the branch where condition bounds holds.
*/
ConditionalBoundsContext(const PrimExpr& condition,
std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
bool is_true_branch);
void EnterWithScope();
void ExitWithScope();
Expand All @@ -255,8 +257,10 @@ class ConditionalBoundsContext {

/*! \brief the condition holds on true branch. */
const PrimExpr& condition_;
/*! \brief global domain map to updated */
std::unordered_map<const VarNode*, arith::IntSet>* dom_map_;
/*! \brief domain map for relaxed vars to update */
std::unordered_map<const VarNode*, arith::IntSet>* relax_map_;
/*! \brief domain map for free vars to update */
std::unordered_map<const VarNode*, arith::IntSet>* hint_map_;
/*! \brief whether is on true branch */
bool is_true_branch_;
/*! \brief used to record and restore original var bounds */
Expand Down
66 changes: 66 additions & 0 deletions tests/python/unittest/test_tir_analysis_get_block_access_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,41 @@ def access_in_branch_func() -> None:
B[i] = A[i - 1]


# TVMScript fixture with two blocks whose buffer accesses are guarded by a condition
# on the block vars. The T.reads / T.writes annotations below are the regions that
# test_access_of_padding_pattern compares against the detector's output.
@T.prim_func
def access_of_padding_pattern() -> None:
    X = T.alloc_buffer([28, 28])
    X_pad = T.alloc_buffer([32, 32])
    Y = T.alloc_buffer([28, 28])
    for i, j in T.grid(32, 32):
        # Write X_pad[vi, vj] from X[vi - 2, vj - 2] when 2 <= vi, vj < 30, else 0.0.
        with T.block("padding"):
            vi, vj = T.axis.remap("SS", [i, j])
            # The declared read region clamps vi - 2 and vj - 2 into [0, 27].
            T.reads(
                [
                    X[
                        T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
                        T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
                    ]
                ]
            )
            T.writes([X_pad[vi, vj]])
            X_pad[vi, vj] = T.if_then_else(
                2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
            )
        # Copy X_pad[vi, vj] into Y[vi - 2, vj - 2], only when 2 <= vi, vj < 30.
        with T.block("padding_reverse"):
            vi, vj = T.axis.remap("SS", [i, j])
            # The declared read region clamps vi and vj into [2, 29].
            T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]])
            # The declared write region clamps vi - 2 and vj - 2 into [0, 27].
            T.writes(
                [
                    Y[
                        T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
                        T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
                    ]
                ]
            )
            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
                Y[vi - 2, vj - 2] = X_pad[vi, vj]


def test_block_access_region_detector():
block = func.body.block.body.block
alloc_buffers = func.body.block.alloc_buffers
Expand Down Expand Up @@ -220,10 +255,41 @@ def test_access_in_branch_func():
tvm.ir.assert_structural_equal(ret0[1], ret1[1])


def test_access_of_padding_pattern():
    """Check detected access regions of the padding fixture against its annotations.

    For each block of ``access_of_padding_pattern``, the regions returned by
    ``tir.analysis.get_block_access_region`` must match the block's declared
    ``reads`` / ``writes`` after simplification.
    """
    s = tvm.tir.schedule.Schedule(access_of_padding_pattern)
    alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers
    buffer_var_map = {buf.data: buf for buf in alloc_buffers}

    def do_compare_buffer_region(region, expect):
        """Assert one detected buffer region equals the expected one, per dimension."""
        assert region.buffer == expect.buffer
        # Guard against a silently truncated comparison when dimensions differ.
        assert len(region.region) == len(expect.region)
        analyzer = tvm.arith.Analyzer()
        for rng, expect_rng in zip(region.region, expect.region):
            # Compare simplified forms so equivalent expressions are accepted.
            tvm.ir.assert_structural_equal(
                analyzer.simplify(rng.min), analyzer.simplify(expect_rng.min)
            )
            tvm.ir.assert_structural_equal(
                analyzer.simplify(rng.extent), analyzer.simplify(expect_rng.extent)
            )

    def do_check_block(block_name):
        """Detect the access regions of the named block and compare with its annotations."""
        block = s.get_sref(s.get_block(block_name)).stmt
        ret = tir.analysis.get_block_access_region(block, buffer_var_map)
        # ret[0] holds the detected reads and ret[1] the detected writes.
        for detected, expected in ((ret[0], block.reads), (ret[1], block.writes)):
            # Without this check an empty or short detection result would pass vacuously.
            assert len(detected) == len(expected)
            for region, expect in zip(detected, expected):
                do_compare_buffer_region(region, expect)

    do_check_block("padding")
    do_check_block("padding_reverse")


# When this file is executed directly as a script, call each test function in turn.
if __name__ == "__main__":
    test_block_access_region_detector()
    test_opaque_block()
    test_opaque_access()
    test_match_buffer()
    test_access_in_if_then_else_func()
    test_access_in_branch_func()
    test_access_of_padding_pattern()
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func)
mod = tvm.tir.transform.CompactBufferAllocation()(mod)
mod = tvm.tir.transform.Simplify()(mod)
transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"]
tvm.ir.assert_structural_equal(mod["main"], transformed)


Expand Down

0 comments on commit 6cce656

Please sign in to comment.