Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Affine utility support iter lowerbound and diagnostics #9699

Merged
merged 7 commits into from
Dec 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#define TVM_ARITH_ITER_AFFINE_MAP_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/expr.h>
#include <tvm/tir/var.h>

Expand Down Expand Up @@ -275,13 +276,14 @@ class IterSumExpr : public IterMapExpr {
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param diag_ctx Diagnostic context.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);
arith::Analyzer* analyzer, DiagnosticContext diag_ctx);
/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
Expand Down Expand Up @@ -333,6 +335,7 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param diag_ctx Diagnostic context.
*
* \return The result list has length len(bindings) + 1
[0, len(bindings)): The iter map matching result. The inner list is of length 2.
Expand All @@ -344,7 +347,8 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
bool require_bijective, arith::Analyzer* analyzer);
bool require_bijective, arith::Analyzer* analyzer,
DiagnosticContext diag_ctx);

} // namespace arith
} // namespace tvm
Expand Down
4 changes: 3 additions & 1 deletion src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -835,9 +835,10 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
for (const Range& range : region) {
affine_indices.push_back(range->min);
}
DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
iter_sum_exprs = DetectIterMap(
/*indices=*/affine_indices, /*input_iters=*/var_dom,
/*predicate=*/predicate, /*require_bijective=*/false, analyzer);
/*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx);
}
if (iter_sum_exprs.empty()) {
return NullOpt;
Expand All @@ -857,6 +858,7 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
if (!analyzer->CanProve(range->extent >= split->scale)) {
return NullOpt;
}

const PrimExpr& base = sum_expr->base;
// IterSplitExpr: (source // lower_factor) % extent * scale
// where `(source // lower_factor) % extent` is within [0, extent - 1]
Expand Down
471 changes: 361 additions & 110 deletions src/arith/iter_affine_map.cc

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,14 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
if (loop_var_ranges.empty()) {
return true;
}
DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
Array<arith::IterSumExpr> results = arith::DetectIterMap(
/*indices=*/realize->iter_values,
/*input_iters=*/loop_var_ranges,
/*predicate=*/realize->predicate,
/*require_bijective=*/false,
/*analyzer=*/analyzer);
/*analyzer=*/analyzer,
/*diag_ctx*/ diag_ctx);
if (results.empty()) {
return false;
}
Expand Down
87 changes: 87 additions & 0 deletions tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
import tvm
from tvm import te
from tvm import tir
from tvm.ir.base import structural_equal


class IntSetChecker:
Expand Down Expand Up @@ -233,6 +235,90 @@ def test_region_lower_bound_negative_scale():
assert int_set_1.max_value.value == 35


def test_region_lower_bound_for_non_perfect_tile():
h1 = tvm.tir.Var("h1", "int32")
h2 = tvm.tir.Var("h2", "int32")
h3 = tvm.tir.Var("h3", "int32")
analyzer = tvm.arith.Analyzer()

def do_test_point_access(point, predicates, var_dom, expect):
regions = tvm.arith.estimate_region_lower_bound(
region=[
tvm.ir.Range.from_min_extent(min_value=point, extent=1),
],
var_dom=var_dom,
predicate=tvm.tir.all(*predicates),
)
if expect is None: # expect a failure
assert regions is None
else:
assert len(regions) == 1
for binding, expect_min, expect_max in expect:
min_diff = expect_min - regions[0].min_value
assert analyzer.simplify(tir.stmt_functor.substitute(min_diff, binding), 3) == 0
max_diff = expect_max - regions[0].max_value
assert analyzer.simplify(tir.stmt_functor.substitute(max_diff, binding), 3) == 0

# non-uniform tiling, single inner variable
# h3 == 0: region is [1, 9]
# 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9]
# h3 > 26: region is [h3 * 8, 223]
do_test_point_access(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great if we can add 3 more test cases for each case of h3 :)

point=h3 * 8 + h2,
predicates=[1 <= h3 * 8 + h2, h3 * 8 + h2 < 224],
var_dom={
h2: tvm.ir.Range(begin=0, end=10),
},
expect=[
(
{},
tvm.tir.max(h3 * 8, 1),
tvm.tir.max(h3 * 8, 1)
- tvm.tir.max(h3 * 8, 214)
- tvm.tir.max(1 - h3 * 8, 0)
+ 223,
),
({h3: 0}, 1, 9),
({h3: 10}, h3 * 8, h3 * 8 + 9),
({h3: 27}, h3 * 8, 223),
],
)

# non-uniform tiling, two inner variables
do_test_point_access(
point=h3 * 8 + h2 * 5 + h1,
predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224],
var_dom={
h2: tvm.ir.Range(begin=0, end=2),
h1: tvm.ir.Range(begin=0, end=5),
},
expect=[
(
{},
tvm.tir.max(h3 * 8, 1),
tvm.tir.max(h3 * 8, 1)
- tvm.tir.max(h3 * 8, 214)
- tvm.tir.max(1 - h3 * 8, 0)
+ 223,
),
({h3: 0}, 1, 9),
({h3: 10}, h3 * 8, h3 * 8 + 9),
({h3: 27}, h3 * 8, 223),
],
)

# should fail on incompatible predicates
do_test_point_access(
point=h3 * 8 + h2 * 5 + h1,
predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224],
var_dom={
h2: tvm.ir.Range(begin=0, end=2),
h1: tvm.ir.Range(begin=0, end=5),
},
expect=None,
)


def test_union_lower_bound():
neg_inf = tvm.arith.int_set.neg_inf()
pos_inf = tvm.arith.int_set.pos_inf()
Expand All @@ -257,4 +343,5 @@ def test_union_lower_bound():
test_region_lower_bound_split_predicate()
test_region_lower_bound_multiple_variables()
test_region_lower_bound_negative_scale()
test_region_lower_bound_for_non_perfect_tile()
test_union_lower_bound()
Loading