diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 6f8155f9de59f..9f61bf58a5e25 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -613,12 +613,15 @@ IntSet Intersect(const Array& sets); * give the domain of each variables. Return undefined IntSet to * represent failure. * + * \note The returned set may be smaller than set that + * contains all possible values of v that satisfies the bound. + * * \param v The target variable to be deduced. * \param cond The conditional expression. * \param hint_map The domain of variable, used to help deduce. * \param relax_map The domain of each variable, used to relax the domain, - * The deduce bound mush implies e for all value in relax_map - * \return An integer set that can cover all the possible values. + * The deduce bound must implies e for all value in relax_map + * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, const Map& hint_map, @@ -631,7 +634,7 @@ IntSet DeduceBound(Expr v, Expr cond, * \param hint_map The domain of variable, used to help deduce. * \param relax_map The domain of each variable, used to relax the domain, * The deduce bound mush implies e for all value in relax_map - * \return An integer set that can cover all the possible values. + * \return An integer set that always satisfies the condition. */ IntSet DeduceBound(Expr v, Expr cond, const std::unordered_map& hint_map, diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index e85c71057e6c2..955554d8f408b 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor { void Deduce(); void Visit(const NodeRef& e) final { - if (!success) return; + if (!success_) return; if (e.get() == path_[iter_++]) { IRVisitor::Visit(e); } else { - success = false; + success_ = false; return; } } @@ -111,18 +111,18 @@ class BoundDeducer: public IRVisitor { void Visit_(const Add* op) final { bool left = op->a.get() == path_[iter_]; - result -= left ? op->b : op->a; + result_ -= left ? op->b : op->a; Visit(left ? op->a : op->b); } void Visit_(const Sub* op) final { bool left = op->a.get() == path_[iter_]; if (left) { - result += op->b; + result_ += op->b; } else { - result -= op->a; - result = - result; - is_greater = !is_greater; + result_ -= op->a; + result_ = - result_; + is_greater_ = !is_greater_; } Visit(left ? op->a : op->b); } @@ -130,43 +130,66 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; + Expr target_var = left ? op->a : op->b; - SignType sign; + SignType sign_operand; if (operand.type().is_uint()) { - sign = kPositive; + sign_operand = kPositive; } else { - sign = expr_map_[operand].sign_type(); + sign_operand = expr_map_[operand].sign_type(); } - if (sign == SignType::kNegative) { - is_greater = !is_greater; - } else if (sign == SignType::kUnknown) { + if (sign_operand == SignType::kNegative) { + is_greater_ = !is_greater_; + } else if (sign_operand == SignType::kUnknown) { // unable to get the sign of operand - success = false; + success_ = false; return; } - // always use relax bound - bool divided = analyzer_.CanProve(result % operand == 0); - result = result / operand; - // since system will round down when not divided - // eg. 2/4 -> 0; -2/4 -> -1 - // no need fix for !is_greater: - // eg. a <= 2/4 -> a <= 0 - // eg. a <= 0/4 -> a <= 0 - // so just fix for not divided and is_greater - // eg. a >= 2/4 -> a >= 0 + 1 - // eg. a >= 0/4 -> a >= 0 - if (is_greater && !divided) { - result += 1; + bool divided = analyzer_.CanProve(result_ % operand == 0); + + result_ = result_ / operand; + + if (!divided) { + // Handle non-divisible case + // NOTE: this accounts for truc div behavior. + bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative(); + + if (is_greater_) { + result_ += 1; + } else { + // NOTE: this is a bit sutble hack. + // + // condition: + // - x * operand <= result + // - operand > 0 + // - x >= 0 + // + // Then it is fine to deduce that x <= result / operand. + // - if result > 0, this division round down + // - if result < 0, (result / operand) rounds up and may violate the constraint + // however, given that x is always non-negative, + // it is fine to have this relaxed bound, given that the user of deduce bound + // will respect the bound of x + // + // TODO(tvm-team): think about a better API to incorporate constraint of x. + // e.g. specify an interval of x and return a bound + // that is in the interval and satisfies the condition. + if (target_is_non_neg && sign_operand == kPositive) { + // do nothing + } else { + result_ -= 1; + } + + } } - Visit(left ? op->a : op->b); } - Expr result; - bool is_greater{true}; - bool success{true}; + Expr result_; + bool is_greater_{true}; + bool success_{true}; private: void Init(); @@ -204,7 +227,7 @@ class BoundDeduceInputChecker: public IRVisitor { void BoundDeducer::Init() { BoundDeduceInputChecker checker; - if (!checker.Check(this)) success = false; + if (!checker.Check(this)) success_ = false; Transform(); } @@ -213,66 +236,65 @@ void BoundDeducer::Transform() { if (const LT* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a < b -> b >= a + 1 - is_greater = true; + is_greater_ = true; expr_ = op->b; - result = op->a + 1; + result_ = op->a + 1; } else { // a < b -> a <= b - 1 - is_greater = false; + is_greater_ = false; expr_ = op->a; - result = op->b - 1; + result_ = op->b - 1; } } else if (const LE* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a <= b -> b >= a - is_greater = true; + is_greater_ = true; expr_ = op->b; - result = op->a; + result_ = op->a; } else { - is_greater = false; + is_greater_ = false; expr_ = op->a; - result = op->b; + result_ = op->b; } } else if (const GT* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a > b -> b <= a - 1 - is_greater = false; + is_greater_ = false; expr_ = op->b; - result = op->a - 1; + result_ = op->a - 1; } else { // a > b -> a >= b + 1 - is_greater = true; + is_greater_ = true; expr_ = op->a; - result = op->b + 1; + result_ = op->b + 1; } } else if (const GE* op = expr_.as()) { if (GetPath(target_, op->a).empty()) { // a >= b -> b <= a - is_greater = false; + is_greater_ = false; expr_ = op->b; - result = op->a; + result_ = op->a; } else { - is_greater = true; + is_greater_ = true; expr_ = op->a; - result = op->b; + result_ = op->b; } } else { - success = false; + success_ = false; } } void BoundDeducer::Deduce() { Init(); - if (!success) return; + if (!success_) return; Relax(); - if (!success) return; + if (!success_) return; // get the path path_ = GetPath(target_, expr_); if (!path_.size()) { - success = false; + success_ = false; return; } - expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); Visit(expr_); @@ -280,13 +302,13 @@ void BoundDeducer::Deduce() { void BoundDeducer::Relax() { IntSet a = EvalSet(expr_, relax_map_); - IntSet b = EvalSet(result, relax_map_); + IntSet b = EvalSet(result_, relax_map_); if (a.is_everything() || b.is_everything()) { - success = false; + success_ = false; return; } - expr_ = is_greater ? a.min() : a.max(); - result = is_greater ? b.max() : b.min(); + expr_ = is_greater_ ? a.min() : a.max(); + result_ = is_greater_ ? b.max() : b.min(); } IntSet DeduceBound(Expr v, Expr e, @@ -294,12 +316,12 @@ IntSet DeduceBound(Expr v, Expr e, const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); - if (!d.success) return IntSet::nothing(); + if (!d.success_) return IntSet::nothing(); Expr min = neg_inf(), max = pos_inf(); - if (d.is_greater) { - min = d.result; + if (d.is_greater_) { + min = d.result_; } else { - max = d.result; + max = d.result_; } return IntSet::interval(min, max); } diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index ec50aef5c51ed..dc6b80a31c7bd 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -155,9 +155,10 @@ template<> inline Expr TryConstFold(Expr a, Expr b) { TVM_ARITH_CONST_PROPAGATION({ const Type& rtype = a.type(); - // due to division and mod can have different modes - // only constant fold positive number where rule is fixed. - if (pa && pb && pa->value >= 0 && pb->value > 0) { + if (pa && pb) { + // due to division and mod can have different modes + // NOTE: this will assumes truc div. + CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm::make(rtype, pa->value / pb->value); } if (pa) { diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index dae5301e0338d..8b23b2eb2e02e 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -147,6 +147,7 @@ Mutate_(const Add* op, const Expr& self) { TVM_TRY_REWRITE(min(x - z, y) + z, min(x, y + z)); TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); + TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y); TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y); TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y); @@ -188,6 +189,9 @@ Mutate_(const Add* op, const Expr& self) { TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1); TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1); TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1)); + + TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x); + TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x); } // condition rules. @@ -455,6 +459,10 @@ Mutate_(const Div* op, const Expr& self) { } } + TVM_TRY_REWRITE(x / x, OneWithTypeLike(x)); + TVM_TRY_REWRITE(x * c1 / x, c1); + TVM_TRY_REWRITE(c1 * x / x, c1); + // Rules involving 2-operands. TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2, c1.Eval()->value >= 0 && diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index 82c65f97f782c..476e42220c237 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -121,6 +121,11 @@ class RewriteSimplifier::Impl : public IRMutator { PConstWithTypeLike ZeroWithTypeLike(const Pattern& pattern) { return PConstWithTypeLike(pattern.derived(), 0); } + + template + PConstWithTypeLike OneWithTypeLike(const Pattern& pattern) { + return PConstWithTypeLike(pattern.derived(), 1); + } }; diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index 89298ed6d101c..d4c6da6081559 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -149,7 +149,9 @@ Expr Simplify(Expr expr, Map vrange) { for (auto kv : vrange) { analyzer.Bind(kv.first, kv.second); } - return analyzer.canonical_simplify(expr); + expr = analyzer.canonical_simplify(expr); + expr = analyzer.rewrite_simplify(expr); + return expr; } Stmt Simplify(Stmt stmt, Map vrange) { diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index b79016ecbc160..e303cea5636cc 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -314,6 +314,7 @@ class LoopPartitioner : public IRMutator { } Stmt Mutate_(const For* op, const Stmt& stmt) { + if (selector.candidates.count(op)) { Stmt s = TryPartition(op, stmt, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); @@ -466,8 +467,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Stmt body, bool partition_thread_scope) { using namespace arith; + // include hint of var. + hint_map_.insert({var.get(), IntSet::interval(min, max)}); + PartitionFinder finder(var, hint_map_, relax_map_); finder.Visit(body); + + hint_map_.erase(var.get()); if (finder.partitions.empty()) return Stmt(); arith::IntervalSet for_interval(min, max);