Skip to content

Commit

Permalink
PR-13024 [TIR][Arith] Use TryCompare to narrow inequalities if possible
Browse files Browse the repository at this point in the history
Prior to this commit, the result of TryCompare would only be used if
it could definitively prove a conditional to be either true or false.
For example, if it is known that `0 <= i`, a conditional of `i <= 0`
would be left as-is.

This commit introduces rewrite rules to preferentially simplify
into more restrictive conditions.  Using the same example, if it is
known that `0 <= i`, a conditional of `i <= 0` would be simplified
into `i == 0`.  Similarly, if it is known that `0 <= i`, a
conditional of `i != 0` would be simplified into `0 < i`.

Because this change does not introduce significant overhead, as the
results of `RewriteSimplifier::Impl::TryCompare` are already
available, this change is enabled for all use cases and does not
require a call to `RewriteSimplifier::SetEnabledExtensions`.
  • Loading branch information
Lunderberg committed Oct 31, 2022
1 parent c9b10a8 commit 2bf9ef3
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 32 deletions.
144 changes: 119 additions & 25 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/tir/op.h>

#include <algorithm>
#include <utility>

#include "../target/datatype/registry.h"
#include "conjunctive_normal_form.h"
Expand Down Expand Up @@ -1384,80 +1385,164 @@ Optional<PrimExpr> RewriteSimplifier::Impl::TryMatchLiteralConstraint(const Prim
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<EQNode>();
EQ ret = Downcast<EQ>(IRMutatorWithAnalyzer::VisitExpr_(op));
op = ret.get();

if (auto const_res = TryConstFold<EQ>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

return ApplyRewriteRules(ret);
}

PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1;
PVar<int> lanes;

// vector rule
if (op->dtype.lanes() != 1) {
if (ret->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes));
}

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a, op->b);
if (IsIndexType(ret->a.dtype())) {
CompareResult result = TryCompare(ret->a, ret->b);
if (result == CompareResult::kEQ) {
return make_const(op->dtype, true);
return make_const(ret->dtype, true);
} else if (result == CompareResult::kNE || result == CompareResult::kGT ||
result == CompareResult::kLT) {
return make_const(op->dtype, false);
return make_const(ret->dtype, false);
}
TVM_TRY_REWRITE(c1 == x, x == c1);

TVM_TRY_REWRITE(x - c1 == 0, x == c1);
TVM_TRY_REWRITE(c1 - x == 0, x == c1);
TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1);
TVM_TRY_REWRITE(x * y == 0, x == 0 || y == 0);
}
return ret;
return std::move(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) {
return this->VisitExpr(Not(op->a == op->b));
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<NENode>();

if (auto const_res = TryConstFold<NE>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a, op->b);
if (result == CompareResult::kNE || result == CompareResult::kGT ||
result == CompareResult::kLT) {
return make_const(op->dtype, true);
} else if (result == CompareResult::kEQ) {
return make_const(op->dtype, false);
} else if (result == CompareResult::kGE) {
// Known: a >= b
//
// a != b
// (a < b) or (b < a)
// False or (b < a)
// b < a
return ApplyRewriteRules(LT(op->b, op->a));
} else if (result == CompareResult::kLE) {
// Known: a <= b
//
// a != b
// (a < b) or (b < a)
// (a < b) or False
// a < b
return ApplyRewriteRules(LT(op->a, op->b));
}
}

return ApplyRewriteRules(Not(ApplyRewriteRules(EQ(op->a, op->b))));
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) {
return this->VisitExpr(Not(op->b < op->a));
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LENode>();
ICHECK(op);

if (auto const_res = TryConstFold<LE>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Check for applicable rewrites before attempting to prove/disprove
// the inequality. This preserves earlier behavior, where (A<=B*x)
// simplifies to (ceildiv(A,B)<=x) when (A%B!=0). Performing the
// TryCompare first would simplify to the equivalent
// (floordiv(A,B)<x) in these cases instead.
ret = ApplyRewriteRules(Not(ApplyRewriteRules(LT(op->b, op->a))));

if (auto op = ret.as<LENode>(); op && IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a, op->b);
if (result == CompareResult::kLE || result == CompareResult::kLT ||
result == CompareResult::kEQ) {
return make_const(op->dtype, true);
} else if (result == CompareResult::kGT) {
return make_const(op->dtype, false);
} else if (result == CompareResult::kNE) {
// Known: a != b
//
// a <= b
// (a < b) or (a == b)
// (a < b) or False
// a < b
return ApplyRewriteRules(LT(op->a, op->b));
} else if (result == CompareResult::kGE) {
// Known: a >= b
//
// a <= b
// (a < b) or (a == b)
// False or (a == b)
// a == b
return ApplyRewriteRules(EQ(op->a, op->b));
}
}

return ret;
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) {
return this->VisitExpr(op->b < op->a);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) {
return this->VisitExpr(Not(op->a < op->b));
return this->VisitExpr(op->b <= op->a);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<LTNode>();
LT node = Downcast<LT>(IRMutatorWithAnalyzer::VisitExpr_(op));
op = node.get();

if (auto const_res = TryConstFold<LT>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
if (auto match = TryMatchLiteralConstraint(node)) return match.value();

return ApplyRewriteRules(node);
}

PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
// Pattern var to match any expression
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;

// vector rule
if (op->dtype.lanes() != 1) {
if (ret->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes));
TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes));
}

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a, op->b);
if (IsIndexType(ret->a.dtype())) {
CompareResult result = TryCompare(ret->a, ret->b);
if (result == CompareResult::kLT) {
return make_const(op->dtype, true);
return make_const(ret->dtype, true);
}
if (result == CompareResult::kEQ || result == CompareResult::kGT ||
result == CompareResult::kGE) {
return make_const(op->dtype, false);
return make_const(ret->dtype, false);
}

// clang-format off
Expand Down Expand Up @@ -1561,19 +1646,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
TVM_TRY_REWRITE(x - c1 < 0, x < c1);
// clang-format on
}
return ret;
return std::move(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<NotNode>();
if (auto const_res = TryConstFold<Not>(op->a)) return const_res.value();
Not ret = Downcast<Not>(IRMutatorWithAnalyzer::VisitExpr_(op));
if (auto const_res = TryConstFold<Not>(ret->a)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

return ApplyRewriteRules(ret);
}

PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
if (ret->dtype.lanes() != 1) {
TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
}

Expand All @@ -1586,7 +1674,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
TVM_TRY_REWRITE(!(x != y), x == y);
TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y));
TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y));
return ret;
return std::move(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
Expand Down Expand Up @@ -1762,6 +1850,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {

TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);

TVM_TRY_RECURSIVE_REWRITE(x < y || x == y, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x < y || y == x, x <= y);
TVM_TRY_RECURSIVE_REWRITE(x == y || x < y, x <= y);
TVM_TRY_RECURSIVE_REWRITE(y == x || x < y, x <= y);

return ret;
}

Expand Down
21 changes: 21 additions & 0 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,27 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
*/
Optional<PrimExpr> TryMatchLiteralConstraint(const PrimExpr& expr) const;

/*! \brief Rewrite rules for Less Than comparisons
*
* These are separate from the VisitExpr_(const LTNode*) method, as
* they may required from rewrites of LT or LE.
*/
PrimExpr ApplyRewriteRules(LT node);

/*! \brief Rewrite rules for Equal comparisons
*
* These are separate from the VisitExpr_(const EQNode*) method, as
* they may required from rewrites of LE or NE.
*/
PrimExpr ApplyRewriteRules(EQ node);

/*! \brief Rewrite rules for Equal comparisons
*
* These are separate from the VisitExpr_(const EQNode*) method, as
* they may required from rewrites of LT, LE, or NE.
*/
PrimExpr ApplyRewriteRules(Not node);

private:
CompareResult TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y);
CompareResult TryCompareUsingConstIntBounds(const PrimExpr& x, const PrimExpr y);
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ def test_cmp_simplify():
ck.verify(fld(x, 2) <= -1, tvm.tir.LE(x, -1))

ck.verify(fld(x, 4) * 4 < x, tvm.tir.LT(0, flm(x, 4)))
ck.verify(fld(x, 4) * 4 >= x, tvm.tir.LE(flm(x, 4), 0))
ck.verify(fld(x, 4) * 4 >= x, tvm.tir.EQ(flm(x, 4), 0))

ck.verify(fld(x, 4) * 4 < x + y, tvm.tir.LT(0, flm(x, 4) + y))
ck.verify(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4)))
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_index_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_nonbijective_inverse_gives_error():
inverse=lambda i, j: [4 * i + j],
pre_shape=[15],
post_shape=[4, 4],
padding=lambda i, j: tvm.tir.And(i == 3, j >= 3),
padding=lambda i, j: tvm.tir.And(i == 3, tvm.runtime.convert(3) == j),
),
"left_padding": dict(
forward=lambda i: [(i + 1) // 4, (i + 1) % 4],
Expand All @@ -107,7 +107,7 @@ def test_nonbijective_inverse_gives_error():
post_shape=[4, 4],
padding=lambda i, j: tvm.tir.Or(
tvm.tir.And(i == 0, j < 1),
tvm.tir.And(i == 3, j >= 3),
tvm.tir.And(i == 3, tvm.runtime.convert(3) == j),
),
),
"dynamic_size": dict(
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_nonbijective_inverse_gives_error():
padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or(
tvm.tir.Or(
tvm.tir.And(i_outer == 0, i_inner < 1),
tvm.tir.And(i_outer == 3, i_inner >= 3),
tvm.tir.And(i_outer == 3, tvm.runtime.convert(3) == i_inner),
),
tvm.tir.Or(
tvm.tir.And(j_outer == 0, j_inner < 5),
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_nonbijective_inverse_gives_error():
inverse=lambda i, j: [i * 4 + j],
pre_shape=[3],
post_shape=[1, 4],
padding=lambda i, j: 3 <= j,
padding=lambda i, j: tvm.runtime.convert(3) == j,
),
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def transformed_three_stage_compute(
T.writes(B[0:2, tx, 0])
B[i, tx, 0] = A[tx, i] * T.float32(2)
with T.block():
T.where(1 <= i)
T.where(i == 1)
T.reads(B[0:2, tx, 0])
T.writes(C[0:2, tx, 0])
C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
Expand Down Expand Up @@ -1349,7 +1349,7 @@ def ref(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]) -> N
with T.attr(0, "async_scope", 1):
B[i % 2, tx, 0] = A[tx, i] * T.float32(2)
with T.block():
T.where(1 <= i and i - 1 < 16)
T.where(i == 1 and i - 1 < 16)
T.reads(B[(i + 1) % 2, tx, 0])
T.writes(C[(i + 1) % 2, tx, 0])
with T.attr(0, "async_commit_queue_scope", 1):
Expand Down
46 changes: 46 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,5 +1003,51 @@ def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
A[0] = True


class TestMostRestrictiveConditional(BaseBeforeAfter):
"""Preferentially prove part of a compound conditional.
Even if we cannot prove a conditional as true or false on its own,
proving that a conditional must satisfy a stronger condition may
allow for later rewrites. For example, if it is known that `a <= b`,
then `a >= b` cannot be proven, but can be reduced to `a == b`.
"""

i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"]
tir_int = tvm.tir.IntImm("int32", 0)

test_case = tvm.testing.parameter(
(i <= tir_int, tir_int <= i, i == tir_int),
(i <= tir_int, i != tir_int, i < tir_int),
(i != tir_int, i <= tir_int, i < tir_int),
(i != tir_int, tir_int <= i, tir_int < i),
(i <= j, j <= i, j == i),
(i <= j, i != j, i < j),
(i != j, i <= j, i < j),
(i != j, j <= i, j < i),
)

@tvm.testing.fixture
def before(self, test_case):
priors, expr_before, _ = test_case

@T.prim_func
def func(A: T.Buffer[1, "bool"]):
if priors:
A[0] = expr_before

return func

@tvm.testing.fixture
def expected(self, test_case):
priors, _, expr_after = test_case

@T.prim_func
def func(A: T.Buffer[1, "bool"]):
if priors:
A[0] = expr_after

return func


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 2bf9ef3

Please sign in to comment.