Skip to content

Commit

Permalink
Fix bound deducer to account for trunc div
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jun 20, 2019
1 parent 5274a51 commit 47a1a33
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 70 deletions.
9 changes: 6 additions & 3 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,15 @@ IntSet Intersect(const Array<IntSet>& sets);
* give the domain of each variables. Return undefined IntSet to
* represent failure.
*
* \note The returned set may be smaller than set that
* contains all possible values of v that satisfies the bound.
*
* \param v The target variable to be deduced.
* \param cond The conditional expression.
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
* The deduce bound must implies e for all value in relax_map
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const Map<Var, IntSet>& hint_map,
Expand All @@ -631,7 +634,7 @@ IntSet DeduceBound(Expr v, Expr cond,
* \param hint_map The domain of variable, used to help deduce.
* \param relax_map The domain of each variable, used to relax the domain,
* The deduce bound mush implies e for all value in relax_map
* \return An integer set that can cover all the possible values.
* \return An integer set that always satisfies the condition.
*/
IntSet DeduceBound(Expr v, Expr cond,
const std::unordered_map<const Variable*, IntSet>& hint_map,
Expand Down
144 changes: 83 additions & 61 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ class BoundDeducer: public IRVisitor {
void Deduce();

void Visit(const NodeRef& e) final {
if (!success) return;
if (!success_) return;
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
success = false;
success_ = false;
return;
}
}
Expand All @@ -111,62 +111,85 @@ class BoundDeducer: public IRVisitor {

void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a;
result_ -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}

void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result += op->b;
result_ += op->b;
} else {
result -= op->a;
result = - result;
is_greater = !is_greater;
result_ -= op->a;
result_ = - result_;
is_greater_ = !is_greater_;
}
Visit(left ? op->a : op->b);
}

void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
Expr target_var = left ? op->a : op->b;

SignType sign;
SignType sign_operand;
if (operand.type().is_uint()) {
sign = kPositive;
sign_operand = kPositive;
} else {
sign = expr_map_[operand].sign_type();
sign_operand = expr_map_[operand].sign_type();
}

if (sign == SignType::kNegative) {
is_greater = !is_greater;
} else if (sign == SignType::kUnknown) {
if (sign_operand == SignType::kNegative) {
is_greater_ = !is_greater_;
} else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand
success = false;
success_ = false;
return;
}

// always use relax bound
bool divided = analyzer_.CanProve(result % operand == 0);
result = result / operand;
// since system will round down when not divided
// eg. 2/4 -> 0; -2/4 -> -1
// no need fix for !is_greater:
// eg. a <= 2/4 -> a <= 0
// eg. a <= 0/4 -> a <= 0
// so just fix for not divided and is_greater
// eg. a >= 2/4 -> a >= 0 + 1
// eg. a >= 0/4 -> a >= 0
if (is_greater && !divided) {
result += 1;
bool divided = analyzer_.CanProve(result_ % operand == 0);

result_ = result_ / operand;

if (!divided) {
// Handle non-divisible case
// NOTE: this accounts for truc div behavior.
bool target_is_non_neg = expr_map_[target_var].can_prove_non_negative();

if (is_greater_) {
result_ += 1;
} else {
// NOTE: this is a bit sutble hack.
//
// condition:
// - x * operand <= result
// - operand > 0
// - x >= 0
//
// Then it is fine to deduce that x <= result / operand.
// - if result > 0, this division round down
// - if result < 0, (result / operand) rounds up and may violate the constraint
// however, given that x is always non-negative,
// it is fine to have this relaxed bound, given that the user of deduce bound
// will respect the bound of x
//
// TODO(tvm-team): think about a better API to incorporate constraint of x.
// e.g. specify an interval of x and return a bound
// that is in the interval and satisfies the condition.
if (target_is_non_neg && sign_operand == kPositive) {
// do nothing
} else {
result_ -= 1;
}

}
}

Visit(left ? op->a : op->b);
}

Expr result;
bool is_greater{true};
bool success{true};
Expr result_;
bool is_greater_{true};
bool success_{true};

private:
void Init();
Expand Down Expand Up @@ -204,7 +227,7 @@ class BoundDeduceInputChecker: public IRVisitor {

void BoundDeducer::Init() {
BoundDeduceInputChecker checker;
if (!checker.Check(this)) success = false;
if (!checker.Check(this)) success_ = false;
Transform();
}

Expand All @@ -213,93 +236,92 @@ void BoundDeducer::Transform() {
if (const LT* op = expr_.as<LT>()) {
if (GetPath(target_, op->a).empty()) {
// a < b -> b >= a + 1
is_greater = true;
is_greater_ = true;
expr_ = op->b;
result = op->a + 1;
result_ = op->a + 1;
} else {
// a < b -> a <= b - 1
is_greater = false;
is_greater_ = false;
expr_ = op->a;
result = op->b - 1;
result_ = op->b - 1;
}
} else if (const LE* op = expr_.as<LE>()) {
if (GetPath(target_, op->a).empty()) {
// a <= b -> b >= a
is_greater = true;
is_greater_ = true;
expr_ = op->b;
result = op->a;
result_ = op->a;
} else {
is_greater = false;
is_greater_ = false;
expr_ = op->a;
result = op->b;
result_ = op->b;
}
} else if (const GT* op = expr_.as<GT>()) {
if (GetPath(target_, op->a).empty()) {
// a > b -> b <= a - 1
is_greater = false;
is_greater_ = false;
expr_ = op->b;
result = op->a - 1;
result_ = op->a - 1;
} else {
// a > b -> a >= b + 1
is_greater = true;
is_greater_ = true;
expr_ = op->a;
result = op->b + 1;
result_ = op->b + 1;
}
} else if (const GE* op = expr_.as<GE>()) {
if (GetPath(target_, op->a).empty()) {
// a >= b -> b <= a
is_greater = false;
is_greater_ = false;
expr_ = op->b;
result = op->a;
result_ = op->a;
} else {
is_greater = true;
is_greater_ = true;
expr_ = op->a;
result = op->b;
result_ = op->b;
}
} else {
success = false;
success_ = false;
}
}

void BoundDeducer::Deduce() {
Init();
if (!success) return;
if (!success_) return;
Relax();
if (!success) return;
if (!success_) return;
// get the path
path_ = GetPath(target_, expr_);
if (!path_.size()) {
success = false;
success_ = false;
return;
}

expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);

Visit(expr_);
}

void BoundDeducer::Relax() {
IntSet a = EvalSet(expr_, relax_map_);
IntSet b = EvalSet(result, relax_map_);
IntSet b = EvalSet(result_, relax_map_);
if (a.is_everything() || b.is_everything()) {
success = false;
success_ = false;
return;
}
expr_ = is_greater ? a.min() : a.max();
result = is_greater ? b.max() : b.min();
expr_ = is_greater_ ? a.min() : a.max();
result_ = is_greater_ ? b.max() : b.min();
}

IntSet DeduceBound(Expr v, Expr e,
const std::unordered_map<const Variable*, IntSet>& hint_map,
const std::unordered_map<const Variable*, IntSet>& relax_map) {
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
if (!d.success_) return IntSet::nothing();
Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) {
min = d.result;
if (d.is_greater_) {
min = d.result_;
} else {
max = d.result;
max = d.result_;
}
return IntSet::interval(min, max);
}
Expand Down
7 changes: 4 additions & 3 deletions src/arithmetic/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ template<>
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
TVM_ARITH_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if (pa && pb && pa->value >= 0 && pb->value > 0) {
if (pa && pb) {
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
CHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm::make(rtype, pa->value / pb->value);
}
if (pa) {
Expand Down
8 changes: 8 additions & 0 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_REWRITE(min(x - z, y) + z, min(x, y + z));
TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y));
TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z));

TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y);
TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y);
TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y);
Expand Down Expand Up @@ -188,6 +189,9 @@ Mutate_(const Add* op, const Expr& self) {
TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE((y % c1) + x * c1, x * c1 + (y % c1));

TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);
}

// condition rules.
Expand Down Expand Up @@ -455,6 +459,10 @@ Mutate_(const Div* op, const Expr& self) {
}
}

TVM_TRY_REWRITE(x / x, OneWithTypeLike(x));
TVM_TRY_REWRITE(x * c1 / x, c1);
TVM_TRY_REWRITE(c1 * x / x, c1);

// Rules involving 2-operands.
TVM_TRY_REWRITE_IF((x * c1 + y) / c2, x * (c1 / c2) + y / c2,
c1.Eval()->value >= 0 &&
Expand Down
9 changes: 7 additions & 2 deletions src/arithmetic/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand Down Expand Up @@ -121,6 +121,11 @@ class RewriteSimplifier::Impl : public IRMutator {
PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 0);
}

template<typename TA>
PConstWithTypeLike<TA> OneWithTypeLike(const Pattern<TA>& pattern) {
return PConstWithTypeLike<TA>(pattern.derived(), 1);
}
};


Expand Down
4 changes: 3 additions & 1 deletion src/arithmetic/stmt_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ Expr Simplify(Expr expr, Map<Var, Range> vrange) {
for (auto kv : vrange) {
analyzer.Bind(kv.first, kv.second);
}
return analyzer.canonical_simplify(expr);
expr = analyzer.canonical_simplify(expr);
expr = analyzer.rewrite_simplify(expr);
return expr;
}

Stmt Simplify(Stmt stmt, Map<Var, Range> vrange) {
Expand Down
6 changes: 6 additions & 0 deletions src/pass/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ class LoopPartitioner : public IRMutator {
}

Stmt Mutate_(const For* op, const Stmt& stmt) {

if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, stmt, op->loop_var,
op->min, op->min + op->extent - 1, op->body, false);
Expand Down Expand Up @@ -466,8 +467,13 @@ Stmt LoopPartitioner::TryPartition(const Node* node,
Stmt body,
bool partition_thread_scope) {
using namespace arith;
// include hint of var.
hint_map_.insert({var.get(), IntSet::interval(min, max)});

PartitionFinder finder(var, hint_map_, relax_map_);
finder.Visit(body);

hint_map_.erase(var.get());
if (finder.partitions.empty()) return Stmt();

arith::IntervalSet for_interval(min, max);
Expand Down

0 comments on commit 47a1a33

Please sign in to comment.