From f0f3d31bc48536d694cee6d2b7931321b4c185c9 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik Date: Tue, 19 Mar 2019 18:45:23 +0300 Subject: [PATCH] [ARITH] RewriteSimplifier: improved cmp simplification --- src/arithmetic/rewrite_simplify.cc | 20 ++++++++++++------- .../unittest/test_arith_rewrite_simplify.py | 15 ++++++++++++++ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index 17f8e010f393..f031e094d84a 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -96,6 +96,8 @@ class RewriteSimplifier::Impl : public IRMutator { kEQ, kGT, kLT, + kGE, + kLE, kNE }; // reference to the main analyzer @@ -140,6 +142,12 @@ class RewriteSimplifier::Impl : public IRMutator { if (dbound->max_value < val) { return kLT; } + if (dbound->min_value >= val) { + return kGE; + } + if (dbound->max_value <= val) { + return kLE; + } return kUnknown; } @@ -994,12 +1002,10 @@ Mutate_(const EQ* op, const Expr& self) { if (IsIndexType(op->a.type())) { CompareResult result = TryCompare(op->a - op->b, 0); - if (result != kUnknown) { - if (result == kEQ) { - return make_const(op->type, true); - } else { - return make_const(op->type, false); - } + if (result == kEQ) { + return make_const(op->type, true); + } else if (result == kNE || result == kGT || result == kLT) { + return make_const(op->type, false); } TVM_TRY_REWRITE(x - c1 == 0, x == c1); TVM_TRY_REWRITE(c1 - x == 0, x == c1); @@ -1055,7 +1061,7 @@ Mutate_(const LT* op, const Expr& self) { if (result == kLT) { return make_const(op->type, true); } - if (result == kEQ || result == kGT) { + if (result == kEQ || result == kGT || result == kGE) { return make_const(op->type, false); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 274676449cab..62e6ea9c6c8e 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -450,6 +450,21 @@ def test_cmp_simplify(): ck.verify(tvm.max(8, x) > 10, tvm.expr.LT(10, x)) ck.verify(x + 1 < tvm.max(8, x), x < 7) + ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10), override=True) + ck.analyzer.update(y, tvm.arith.ConstIntBound(-10, 0), override=True) + ck.analyzer.update(z, tvm.arith.ConstIntBound(-5, 5), override=True) + + ck.verify(x < 11, tvm.const(1, "bool")) + ck.verify(x <= 10, tvm.const(1, "bool")) + ck.verify(z <= 5, tvm.const(1, "bool")) + ck.verify(x + y <= 10, tvm.const(1, "bool")) + ck.verify(x + y >= -10, tvm.const(1, "bool")) + ck.verify(z - 5 <= y + 10, tvm.const(1, "bool")) + ck.verify(tvm.all(x > -1, z <= x + 5), tvm.const(1, "bool")) + ck.verify(x*y <= 0, tvm.const(1, "bool")) + ck.verify((x + 1)*(y - 1) < 0, tvm.const(1, "bool")) + ck.verify(y*y >= 0, tvm.const(1, "bool")) + def test_logical_simplify(): ck = RewriteChecker()