Skip to content

Commit

Permalink
[Arith] Updated incorrect simplification rule (apache#13922)
Browse files Browse the repository at this point in the history
The rules that rewrite `min(floordiv(x + (A-1), A) * A, max(x, A))`
and `min(truncdiv(x + (A-1), A) * A, max(x, A))` into `max(x, A)` did
not have sufficiently tight bounds.  The `truncdiv` rule required that
`x >= 0`, while the `floordiv` rule had no requirement on `x`.  In
both cases, the simplification was incorrect when `x==0`, as it would
result in a rewrite from `min(0, max(0, A))` into `max(0, A)`.

This commit updates the rules to require that `x >= 0` for each of
these rules.
  • Loading branch information
Lunderberg authored Feb 7, 2023
1 parent 5456fae commit 282f175
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
10 changes: 6 additions & 4 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1096,24 +1096,26 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 0));
CanProveGreaterEqual(x.Eval(), 1));

TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x,
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 0));
CanProveGreaterEqual(x.Eval(), 1));

// Divide up rounding: floor div
TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x,
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 1));

TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x,
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value &&
CanProveGreaterEqual(x.Eval(), 1));

TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, c2.Eval()->value > 0);
TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, c2.Eval()->value > 0);
Expand Down
9 changes: 6 additions & 3 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,20 +649,23 @@ def test_min_index_simplify():
# truc div
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.verify(tvm.te.min(tdiv(x + 3, 4) * 4, x), x)
ck.verify(tvm.te.min(tdiv(x + 3, 4) * 4, tvm.te.max(x, 4)), tvm.te.max(x, 4))
ck.verify(tvm.te.min(x, tdiv(x + 3, 4) * 4), x)
ck.verify(tvm.te.min(tvm.te.max(x, 4), tdiv(x + 3, 4) * 4), tvm.te.max(x, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
ck.verify(tvm.te.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.te.min(x, y), 10))
ck.verify(tvm.te.min(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.te.max(x, y), (-10)))
ck.analyzer.update(x, tvm.arith.ConstIntBound(1, 1000), True)
ck.verify(tvm.te.min(tdiv(x + 3, 4) * 4, tvm.te.max(x, 4)), tvm.te.max(x, 4))
ck.verify(tvm.te.min(tvm.te.max(x, 4), tdiv(x + 3, 4) * 4), tvm.te.max(x, 4))

# floor div
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
ck.verify(tvm.te.min(fld(x + 3, 4) * 4, x), x)
ck.verify(tvm.te.min(fld(x + 3, 4) * 4, tvm.te.max(x, 4)), tvm.te.max(x, 4))
ck.verify(tvm.te.min(x, fld(x + 3, 4) * 4), x)
ck.verify(tvm.te.min(x, fld(x, 4) * 4), fld(x, 4) * 4)
ck.analyzer.update(x, tvm.arith.ConstIntBound(1, 1000), True)
ck.verify(tvm.te.min(fld(x + 3, 4) * 4, tvm.te.max(x, 4)), tvm.te.max(x, 4))
ck.verify(tvm.te.min(tvm.te.max(x, 4), fld(x + 3, 4) * 4), tvm.te.max(x, 4))
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), True)
ck.verify(tvm.te.min(fld(x, 10), fld(y, 10)), fld(tvm.te.min(x, y), 10))
ck.verify(tvm.te.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.te.max(x, y), (-10)))

Expand Down

0 comments on commit 282f175

Please sign in to comment.