diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 776538adbc0fa..07dcace0b381a 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -56,6 +56,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { private: /*! \brief Iteration range for loop_vars */ std::unordered_map dom_map_; + /*! \brief Extra iteration range hint for free vars */ + std::unordered_map hint_map_; /*! \brief The buffers that the current block reads */ std::vector read_buffers_; /*! \brief The buffers that the current block writes */ @@ -96,6 +98,9 @@ class BlockReadWriteDetector : public StmtExprVisitor { /*! \brief Helper function to update a opaque access. */ void UpdateOpaque(const Var& buffer_var); + /*! \brief Helper function to relax the buffer indices */ + arith::IntSet RelaxAccessIndex(const PrimExpr& index); + void VisitStmt_(const ForNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; @@ -140,10 +145,22 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { ExprVisitor::VisitExpr_(op); } +arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) { + arith::IntSet relaxed = arith::EvalSet(index, dom_map_); + if (!hint_map_.empty()) { + // take non-relaxed var bound hints into considerations + // eg, if i * 4 + j with i >= 10 and j in [0, 4), only j in domain scope + // then the index region can be relaxed to [i*4, i*4+4) ^ [40, inf) + arith::IntSet hint_bound = arith::EvalSet(relaxed, hint_map_); + relaxed = arith::Intersect({relaxed, hint_bound}); + } + return relaxed; +} + void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + relaxed_region.push_back(RelaxAccessIndex(index)); } Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); ExprVisitor::VisitExpr_(op); @@ -160,12 +177,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { VisitExpr(op->condition); { // Visit then branch - With ctx(op->condition, &dom_map_, true); + With ctx(op->condition, &dom_map_, &hint_map_, true); StmtExprVisitor::VisitStmt(op->then_case); } if (op->else_case.defined()) { // Visit else branch - With ctx(op->condition, &dom_map_, false); + With ctx(op->condition, &dom_map_, &hint_map_, false); StmtExprVisitor::VisitStmt(op->else_case); } } @@ -175,12 +192,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { VisitExpr(op->args[0]); { // Visit then branch - With ctx(op->args[0], &dom_map_, true); + With ctx(op->args[0], &dom_map_, &hint_map_, true); StmtExprVisitor::VisitExpr(op->args[1]); } { // Visit else branch - With ctx(op->args[0], &dom_map_, false); + With ctx(op->args[0], &dom_map_, &hint_map_, false); StmtExprVisitor::VisitExpr(op->args[2]); } return; @@ -196,7 +213,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + relaxed_region.push_back(RelaxAccessIndex(index)); } Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); StmtVisitor::VisitStmt_(op); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 07f977860d933..20ddd7f84a35d 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -123,12 +123,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr(op->condition); { // Visit then branch - With ctx(op->condition, &dom_map_, true); + With ctx(op->condition, &dom_map_, &hint_map_, true); StmtExprVisitor::VisitStmt(op->then_case); } if (op->else_case.defined()) { // Visit else branch - With ctx(op->condition, &dom_map_, false); + With ctx(op->condition, &dom_map_, &hint_map_, false); StmtExprVisitor::VisitStmt(op->else_case); } } @@ -139,12 +139,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor { StmtExprVisitor::VisitExpr(op->args[0]); { // Visit then branch - With ctx(op->args[0], &dom_map_, true); + With ctx(op->args[0], &dom_map_, &hint_map_, true); StmtExprVisitor::VisitExpr(op->args[1]); } { // Visit else branch - With ctx(op->args[0], &dom_map_, false); + With ctx(op->args[0], &dom_map_, &hint_map_, false); StmtExprVisitor::VisitExpr(op->args[2]); } return; @@ -282,6 +282,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { /*! \brief The map from loop vars to their iter range. */ std::unordered_map dom_map_; + /*! \brief Extra map from free vars to their iter range hints. */ + std::unordered_map hint_map_; /*! \brief The analyzer aware of loop domains. */ arith::Analyzer dom_analyzer_; /*! \brief The map from Buffer to it's relaxed access set. */ diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 2423b09d4fb7c..bc2f7ad6f357f 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -300,11 +300,18 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { Array vars = Array(var_set.begin(), var_set.end()); Map ranges; for (const Var& v : vars) { - auto it = dom_map_->find(v.get()); - if (it != dom_map_->end()) { - const auto& int_set = it->second; - ranges.Set(v, Range::FromMinExtent(int_set.min(), - analyzer.Simplify(int_set.max() - int_set.min() + 1))); + arith::IntSet dom; + auto relax_it = relax_map_->find(v.get()); + if (relax_it != relax_map_->end()) { + dom = relax_it->second; + } else { + auto hint_it = hint_map_->find(v.get()); + if (hint_it != hint_map_->end()) { + dom = hint_it->second; + } + } + if (dom.defined()) { + ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1))); } } // solve constraints @@ -314,24 +321,53 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } ConditionalBoundsContext::ConditionalBoundsContext( - const PrimExpr& condition, std::unordered_map* dom_map, - bool is_true_branch) - : condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {} + const PrimExpr& condition, std::unordered_map* relax_map, + std::unordered_map* hint_map, bool is_true_branch) + : condition_(condition), + relax_map_(relax_map), + hint_map_(hint_map), + is_true_branch_(is_true_branch) {} void ConditionalBoundsContext::EnterWithScope() { for (const auto& p : GetVarBoundsFromCondition()) { const auto* var = p.first.get(); - auto it = dom_map_->find(var); - if (it != dom_map_->end()) { - origin_map_.emplace(var, it->second); - it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)}); + arith::IntSet new_dom = arith::IntSet::FromRange(p.second); + auto relax_it = relax_map_->find(var); + if (relax_it != relax_map_->end()) { + // this is a bound for relaxed var + origin_map_.emplace(var, relax_it->second); + relax_it->second = arith::Intersect({relax_it->second, new_dom}); + } else { + // this is a bound for free var + auto hint_it = hint_map_->find(var); + if (hint_it != hint_map_->end()) { + origin_map_.emplace(var, hint_it->second); + hint_it->second = arith::Intersect({hint_it->second, new_dom}); + } else { + origin_map_.emplace(var, arith::IntSet::Nothing()); + hint_map_->insert(hint_it, {var, new_dom}); + } } } } void ConditionalBoundsContext::ExitWithScope() { for (const auto& p : origin_map_) { - (*dom_map_)[p.first] = p.second; + const auto* var = p.first; + auto relax_it = relax_map_->find(var); + if (relax_it != relax_map_->end()) { + // recover bound for relaxed var + relax_it->second = p.second; + } else { + // recover bound for free var + auto hint_it = hint_map_->find(var); + ICHECK(hint_it != hint_map_->end()); + if (p.second.IsNothing()) { + hint_map_->erase(hint_it); + } else { + hint_it->second = p.second; + } + } } } diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 7b1d34c8162de..da52a82a2f087 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -231,9 +231,9 @@ Bool IsFromLegacyTESchedule(PrimFunc f); *\brief Context helper to update domain map within conditional scope. * * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is - *[0, 8]. Then `With ctx(&dom_map, bounds, true)` step into scope where - *dom_map[i] is [0, 8] and `With ctx(&dom_map, bounds, false)` step into - *scope where dom_map[i] is [9, 20] + * [0, 8]. Then `With ctx(condition, &relax_map, &hint_map, true)` step + *into scope where dom_map[i] is [0, 8] and `With ctx(condition, + *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20] */ class ConditionalBoundsContext { private: @@ -241,11 +241,13 @@ class ConditionalBoundsContext { /*! * \brief Construct a condition bounds context. * \param condition The condition holds on true branch. - * \param dom_map The global domain map to be updated. + * \param relax_map The domain map for relaxed vars to update. + * \param hint_map The domain map for free vars to update. * \param is_true_branch Whether step into the branch where condition bounds holds. */ ConditionalBoundsContext(const PrimExpr& condition, - std::unordered_map* dom_map, + std::unordered_map* relax_map, + std::unordered_map* hint_map, bool is_true_branch); void EnterWithScope(); void ExitWithScope(); @@ -255,8 +257,10 @@ class ConditionalBoundsContext { /*! \brief the condition holds on true branch. */ const PrimExpr& condition_; - /*! \brief global domain map to updated */ - std::unordered_map* dom_map_; + /*! \brief domain map for relaxed vars to update */ + std::unordered_map* relax_map_; + /*! \brief domain map for free vars to update */ + std::unordered_map* hint_map_; /*! \brief whether is on true branch */ bool is_true_branch_; /*! \brief used to record and restore original var bounds */ diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index e508fbb0f7477..54037541016d3 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -130,6 +130,41 @@ def access_in_branch_func() -> None: B[i] = A[i - 1] +@T.prim_func +def access_of_padding_pattern() -> None: + X = T.alloc_buffer([28, 28]) + X_pad = T.alloc_buffer([32, 32]) + Y = T.alloc_buffer([28, 28]) + for i, j in T.grid(32, 32): + with T.block("padding"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads( + [ + X[ + T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1, + T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1, + ] + ] + ) + T.writes([X_pad[vi, vj]]) + X_pad[vi, vj] = T.if_then_else( + 2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32" + ) + with T.block("padding_reverse"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]]) + T.writes( + [ + Y[ + T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1, + T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1, + ] + ] + ) + if 2 <= vi and vi < 30 and 2 <= vj and vj < 30: + Y[vi - 2, vj - 2] = X_pad[vi, vj] + + def test_block_access_region_detector(): block = func.body.block.body.block alloc_buffers = func.body.block.alloc_buffers @@ -220,6 +255,36 @@ def test_access_in_branch_func(): tvm.ir.assert_structural_equal(ret0[1], ret1[1]) +def test_access_of_padding_pattern(): + s = tvm.tir.schedule.Schedule(access_of_padding_pattern) + alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + + def do_compare_buffer_region(region, expect): + assert region.buffer == expect.buffer + analyzer = tvm.arith.Analyzer() + for k, rng in enumerate(region.region): + tvm.ir.assert_structural_equal( + analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min) + ) + tvm.ir.assert_structural_equal( + analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent) + ) + + def do_check_block(block_name): + block = s.get_sref(s.get_block(block_name)).stmt + expect_reads = block.reads + expect_writes = block.writes + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + for i, read in enumerate(ret[0]): + do_compare_buffer_region(read, expect_reads[i]) + for i, write in enumerate(ret[1]): + do_compare_buffer_region(write, expect_writes[i]) + + do_check_block("padding") + do_check_block("padding_reverse") + + if __name__ == "__main__": test_block_access_region_detector() test_opaque_block() @@ -227,3 +292,4 @@ def test_access_in_branch_func(): test_match_buffer() test_access_in_if_then_else_func() test_access_in_branch_func() + test_access_of_padding_pattern() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 57c87e5dedf4a..9b844853f2438 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -24,6 +24,7 @@ def _check(original, transformed): mod = tvm.IRModule.from_expr(func) mod = tvm.tir.transform.CompactBufferAllocation()(mod) mod = tvm.tir.transform.Simplify()(mod) + transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"] tvm.ir.assert_structural_equal(mod["main"], transformed)