Skip to content

Commit

Permalink
[ARITH] Canonicalize comparison to move constant to one side (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and wweic committed Jul 11, 2019
1 parent 4f7336b commit a4dbaac
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,12 @@ Mutate_(const LT* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE(z < min(x, y), z < x && z < y);
TVM_TRY_RECURSIVE_REWRITE(z < max(x, y), z < x || z < y);

TVM_TRY_RECURSIVE_REWRITE(x < c1 - y, x + y < c1);
TVM_TRY_RECURSIVE_REWRITE(x < c1 + y, x - y < c1);
TVM_TRY_RECURSIVE_REWRITE(c1 - y < x, c1 < x + y);
TVM_TRY_RECURSIVE_REWRITE(c1 + y < x, c1 < x - y);


TVM_TRY_REWRITE(x - c1 < 0, x < c1);
TVM_TRY_REWRITE(x + c1 < c2, x < c2 - c1);
}
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,18 @@ def test_simplify_if_then_else():
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16,
x), y)

res2 = tvm.if_then_else((x * 4) >= 466036 - y,
tvm.if_then_else(24512 <= ((((x*4) + y) - 466036) % 24528),
(((((x*4) + y) - 466036) % 24528) -24512) % 16,
x), y)
expected = tvm.if_then_else(
tvm.expr.LE(466036, (x * 4 + y)),
tvm.if_then_else(tvm.expr.LE(24512, ((((x*4) + y) - 4) % 24528)),
(((x*4) + y) - 4) % 16,
x), y)
ck.verify(res, expected)
ck.verify(res2, expected)
# can only simplify if condition
res = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 100) % 3, (x + 100) % 3)
expected = tvm.expr.Select(tvm.all(x >= -1, y >= 0), (x + y + 1) % 3, (x + 100) % 3)
Expand Down

0 comments on commit a4dbaac

Please sign in to comment.