Skip to content

Commit

Permalink
[TIR] Affine utility support iter lowerbound and diagnostics (apache#…
Browse files Browse the repository at this point in the history
…9699)

* Enable freevars, iter lowerbound and diagnostics in affine utility

* fix lint issues and compare bug

* update to use iter shift instead of itermark min for lowerbound

* add testcase of fused iters sum with multiple lowerbounds

* add more affine check testcases, fix bug for single iter and duplicate constraints on iter

* add a newline to comment

* forbid unmatched predicates

Co-authored-by: baoxinqi <[email protected]>
  • Loading branch information
2 people authored and ylc committed Jan 7, 2022
1 parent c2e4707 commit 375abf6
Show file tree
Hide file tree
Showing 9 changed files with 720 additions and 122 deletions.
8 changes: 6 additions & 2 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#define TVM_ARITH_ITER_AFFINE_MAP_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/expr.h>
#include <tvm/tir/var.h>

Expand Down Expand Up @@ -275,13 +276,14 @@ class IterSumExpr : public IterMapExpr {
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param diag_ctx Diagnostic context.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);
arith::Analyzer* analyzer, DiagnosticContext diag_ctx);
/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
Expand Down Expand Up @@ -333,6 +335,7 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param diag_ctx Diagnostic context.
*
* \return The result list has length len(bindings) + 1
[0, len(bindings)): The iter map matching result. The inner list is of length 2.
Expand All @@ -344,7 +347,8 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
bool require_bijective, arith::Analyzer* analyzer);
bool require_bijective, arith::Analyzer* analyzer,
DiagnosticContext diag_ctx);

} // namespace arith
} // namespace tvm
Expand Down
4 changes: 3 additions & 1 deletion src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -835,9 +835,10 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
for (const Range& range : region) {
affine_indices.push_back(range->min);
}
DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
iter_sum_exprs = DetectIterMap(
/*indices=*/affine_indices, /*input_iters=*/var_dom,
/*predicate=*/predicate, /*require_bijective=*/false, analyzer);
/*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx);
}
if (iter_sum_exprs.empty()) {
return NullOpt;
Expand All @@ -857,6 +858,7 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
if (!analyzer->CanProve(range->extent >= split->scale)) {
return NullOpt;
}

const PrimExpr& base = sum_expr->base;
// IterSplitExpr: (source // lower_factor) % extent * scale
// where `(source // lower_factor) % extent` is within [0, extent - 1]
Expand Down
471 changes: 361 additions & 110 deletions src/arith/iter_affine_map.cc

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,14 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
if (loop_var_ranges.empty()) {
return true;
}
DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
Array<arith::IterSumExpr> results = arith::DetectIterMap(
/*indices=*/realize->iter_values,
/*input_iters=*/loop_var_ranges,
/*predicate=*/realize->predicate,
/*require_bijective=*/false,
/*analyzer=*/analyzer);
/*analyzer=*/analyzer,
/*diag_ctx*/ diag_ctx);
if (results.empty()) {
return false;
}
Expand Down
87 changes: 87 additions & 0 deletions tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
import tvm
from tvm import te
from tvm import tir
from tvm.ir.base import structural_equal


class IntSetChecker:
Expand Down Expand Up @@ -233,6 +235,90 @@ def test_region_lower_bound_negative_scale():
assert int_set_1.max_value.value == 35


def test_region_lower_bound_for_non_perfect_tile():
    """Exercise estimate_region_lower_bound on predicate-bounded point accesses.

    Each case accesses a single point (extent 1) whose index combines the free
    outer variable ``h3`` with inner loop variables, constrained by predicates
    on the index. The last case uses a second predicate over a different
    expression than the accessed index and expects the estimation to return
    ``None``.
    """
    h1 = tvm.tir.Var("h1", "int32")
    h2 = tvm.tir.Var("h2", "int32")
    h3 = tvm.tir.Var("h3", "int32")
    analyzer = tvm.arith.Analyzer()

    def simplifies_to_zero(diff, binding):
        # Substitute the concrete bindings (possibly none) and check that the
        # difference between expected and estimated bound vanishes.
        return analyzer.simplify(tir.stmt_functor.substitute(diff, binding), 3) == 0

    def do_test_point_access(point, predicates, var_dom, expect):
        point_range = tvm.ir.Range.from_min_extent(min_value=point, extent=1)
        regions = tvm.arith.estimate_region_lower_bound(
            region=[point_range],
            var_dom=var_dom,
            predicate=tvm.tir.all(*predicates),
        )
        if expect is None:
            # A failed estimation is reported as None.
            assert regions is None
            return
        assert len(regions) == 1
        for binding, expect_min, expect_max in expect:
            assert simplifies_to_zero(expect_min - regions[0].min_value, binding)
            assert simplifies_to_zero(expect_max - regions[0].max_value, binding)

    # Expected bounds shared by the first two cases:
    #   h3 == 0:      region is [1, 9]
    #   0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9]
    #   h3 > 26:      region is [h3 * 8, 223]
    symbolic_min = tvm.tir.max(h3 * 8, 1)
    symbolic_max = (
        tvm.tir.max(h3 * 8, 1)
        - tvm.tir.max(h3 * 8, 214)
        - tvm.tir.max(1 - h3 * 8, 0)
        + 223
    )
    tile_expect = [
        ({}, symbolic_min, symbolic_max),
        ({h3: 0}, 1, 9),
        ({h3: 10}, h3 * 8, h3 * 8 + 9),
        ({h3: 27}, h3 * 8, 223),
    ]

    # non-uniform tiling, single inner variable
    single_inner = h3 * 8 + h2
    do_test_point_access(
        point=single_inner,
        predicates=[1 <= single_inner, single_inner < 224],
        var_dom={h2: tvm.ir.Range(begin=0, end=10)},
        expect=tile_expect,
    )

    # non-uniform tiling, two inner variables
    double_inner = h3 * 8 + h2 * 5 + h1
    two_var_dom = {
        h2: tvm.ir.Range(begin=0, end=2),
        h1: tvm.ir.Range(begin=0, end=5),
    }
    do_test_point_access(
        point=double_inner,
        predicates=[1 <= double_inner, double_inner < 224],
        var_dom=two_var_dom,
        expect=tile_expect,
    )

    # should fail on incompatible predicates
    do_test_point_access(
        point=h3 * 8 + h2 * 5 + h1,
        predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224],
        var_dom={
            h2: tvm.ir.Range(begin=0, end=2),
            h1: tvm.ir.Range(begin=0, end=5),
        },
        expect=None,
    )


def test_union_lower_bound():
neg_inf = tvm.arith.int_set.neg_inf()
pos_inf = tvm.arith.int_set.pos_inf()
Expand All @@ -257,4 +343,5 @@ def test_union_lower_bound():
test_region_lower_bound_split_predicate()
test_region_lower_bound_multiple_variables()
test_region_lower_bound_negative_scale()
test_region_lower_bound_for_non_perfect_tile()
test_union_lower_bound()
Loading

0 comments on commit 375abf6

Please sign in to comment.