Skip to content

Commit

Permalink
Change to recursive depth limit. address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Mar 9, 2019
1 parent 82a9385 commit d88051b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
22 changes: 13 additions & 9 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class RewriteSimplifier::Impl : public IRMutator {
// Run simplification in post order
Expr PostOrderSimplify(Expr expr, int max_iter = 2) {
for (int i = 0; i < max_iter; ++i) {
recur_counter_ = 0;
Expr new_expr = this->Mutate(expr);
if (new_expr.same_as(expr)) return expr;
expr = new_expr;
Expand All @@ -80,12 +79,12 @@ class RewriteSimplifier::Impl : public IRMutator {
private:
// reference to the main analyzer
Analyzer* parent_;
// counter to record recursive rewrite times.
int recur_counter_{0};
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
std::unordered_map<Var, Expr, ExprHash, ExprEqual> var_map_;
// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurCount = 10;
static const constexpr int kMaxRecurDepth = 5;
// Whether x >= val
bool CanProveGreaterEqual(const Expr& x, int64_t val) {
return parent_->CanProveGreaterEqual(x, val);
Expand All @@ -100,13 +99,16 @@ class RewriteSimplifier::Impl : public IRMutator {
return false;
}
// Recursive rewrite x
// we limit maximum number of recursive rewrite allowed to
// we limit maximum depth of recursive rewrite allowed to
// avoid infinite loop
Expr RecursiveRewrite(Expr x) {
if (recur_counter_ >= kMaxRecurCount) return x;
++recur_counter_;
return Mutate(x);
Expr RecursiveRewrite(const Expr& x) {
if (recur_depth_ >= kMaxRecurDepth) return x;
++recur_depth_;
Expr res = Mutate(x);
--recur_depth_;
return res;
}

template<typename TA>
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
Expand Down Expand Up @@ -152,6 +154,8 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));
TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y);
TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y);
TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y);
TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y);

TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y),
c1.Eval()->value == -c2.Eval()->value);
Expand Down
5 changes: 4 additions & 1 deletion tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def test_select_simplify():
tvm.expr.Select(x > 0, y + 1, z))
ck.verify(tvm.expr.Select(x > 0, y, 1) - tvm.expr.Select(x > 0, 1, z),
tvm.expr.Select(x > 0, y + (-1), 1 - z))

ck.verify(tvm.expr.Select(x > 0, y, z) - y,
tvm.expr.Select(x > 0, 0, z - y))
ck.verify(tvm.expr.Select(x > 0, y, z) - z,
tvm.expr.Select(x > 0, y - z, 0))


def test_add_index_simplify():
Expand Down

0 comments on commit d88051b

Please sign in to comment.