Skip to content

Commit

Permalink
fix lint issues and compare bug
Browse files Browse the repository at this point in the history
  • Loading branch information
baoxinqi committed Dec 10, 2021
1 parent d1900bd commit 1ef2e7d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 13 deletions.
1 change: 1 addition & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ class IterSumExpr : public IterMapExpr {
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param diag_ctx Diagnostic context.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
Expand Down
43 changes: 36 additions & 7 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) {
n->dtype = source->source->dtype;
n->source = std::move(source);
n->extent = n->source->extent;
if (is_zero(n->source->min)) {
if (!is_zero(n->source->min)) {
n->extent = n->extent + n->source->min;
}
n->lower_factor = one;
Expand Down Expand Up @@ -233,13 +233,20 @@ class IterMapRewriter : public ExprMutator {
collector.Collect(bindings);
for (const IterMark& mark : collector.visited_) {
if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) {
diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
<< "Fail to normalize iter mark splits: " << mark);
return false;
}
}
if (require_bijective) {
// all input marks must be visited
for (const IterMark& mark : input_marks_) {
if (collector.visited_.count(mark) == 0) return false;
if (collector.visited_.count(mark) == 0) {
diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
<< "The mapping is not bijective because input iter mark " << mark
<< " is not covered, ");
return false;
}
}
}
return true;
Expand Down Expand Up @@ -425,26 +432,48 @@ class IterMapRewriter : public ExprMutator {
}
if (j == splits.size()) {
// we do not allow incomplete split if the bindings should be bijective
if (require_bijective) return Array<IterSplitExpr>();
if (require_bijective) {
diag_ctx_.Emit(
Diagnostic::Error(mark->source->span)
<< "Do not allow incomplete split in bijective checking, expected_lower_factor="
<< expected_lower_factor);
return Array<IterSplitExpr>();
}
// look for the next split skipping this lower factor
// For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2]
// It is valid to only have [y / 6, y % 2] if bijective is not required
// We can skip (y / 2) % 6
j = SearchSkipLowerFactor(splits, used, expected_lower_factor);
// split not found
if (j == splits.size()) return Array<IterSplitExpr>();
if (j == splits.size()) {
diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
<< "Fail to find split skipping the lower factor in bijective-free "
"checking, expected_lower_factor="
<< expected_lower_factor);
return Array<IterSplitExpr>();
}
}
used[j] = true;
iters.push_back(splits[j]);
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
}

// Case 1. bijective is required.
// We check the extent we calculate is consistent with the extent of the mark
// Case 2. bijective is not required.
// We check the extent we calculate is a factor of the extent of the mark
// We check either
// (1) the extent we calculate is a factor of the extent of the mark
// For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not.
if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) ||
(!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) {
// (2) the extent we calculate is larger than the max of the mark
// For example, y \in [1, 8] [y / 18, y % 18] is valid.
if ((require_bijective &&
!(analyzer_->CanProveEqual(expected_lower_factor, mark->extent) && is_zero(mark->min))) ||
(!require_bijective &&
!(CanProveDivisible(mark->extent, expected_lower_factor) ||
analyzer_->CanProve(mark->min + mark->extent <= expected_lower_factor)))) {
diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
<< "Mark extent of " << mark
<< " is not compatible with expected_lower_factor=" << expected_lower_factor);
return Array<IterSplitExpr>();
}
return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
Expand Down
14 changes: 8 additions & 6 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def var_dom(iters):
return {var: tvm.ir.Range(0, ext) for var, ext in iters}


def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, min=0):
def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, mark_min=0, mark_extent=None):
"""Check the sum expr have the right pattern."""
assert isinstance(sum_expr, tvm.arith.IterSumExpr)
if extent == 1:
Expand All @@ -53,7 +53,9 @@ def assert_iter_sum_pattern(sum_expr, extent, base, scale=1, min=0):
assert len(sum_expr.args) == 1
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent)
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale)
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.min, min)
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.min, mark_min)
if mark_extent:
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].source.extent, mark_extent)
tvm.testing.assert_prim_expr_equal(sum_expr.base, base)


Expand Down Expand Up @@ -212,10 +214,10 @@ def test_predicate():
# lower bound only
res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5)
assert len(res) == 1
assert_iter_sum_pattern(res[0], 124, 0, min=6)
assert_iter_sum_pattern(res[0], 130, 0, mark_min=6, mark_extent=124)
res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6)
assert len(res) == 1
assert_iter_sum_pattern(res[0], 124, 0, min=6)
assert_iter_sum_pattern(res[0], 130, 0, mark_min=6, mark_extent=124)

# lower bound + upper bound
res = tvm.arith.detect_iter_map(
Expand All @@ -224,14 +226,14 @@ def test_predicate():
tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128),
)
assert len(res) == 1
assert_iter_sum_pattern(res[0], 122, 0, min=6)
assert_iter_sum_pattern(res[0], 128, 0, mark_min=6, mark_extent=122)
res = tvm.arith.detect_iter_map(
[x[0] * 10 + y[0]],
var_dom([x, y]),
tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127),
)
assert len(res) == 1
assert_iter_sum_pattern(res[0], 122, 0, min=6)
assert_iter_sum_pattern(res[0], 128, 0, mark_min=6, mark_extent=122)

# non-standard form of predicate
res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0])
Expand Down

0 comments on commit 1ef2e7d

Please sign in to comment.