Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] DeduceBound #40

Merged
merged 27 commits into from
Feb 17, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/ir_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class IRVisitor {
virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op);
virtual void Visit_(const Add* op);
virtual void Visit_(const Sub* op);
virtual void Visit_(const Mul* op);
virtual void Visit_(const Div* op);
virtual void Visit_(const Mod* op);
virtual void Visit_(const LT* op);
};

/*!
Expand Down
133 changes: 133 additions & 0 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*!
* Copyright (c) 2017 by Contributors
* \file bound_deducer.cc
*/
#include <tvm/expr.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>
#include <unordered_set>
#include "./int_set.h"

namespace tvm {
namespace arith {

using namespace ir;
using Halide::Internal::Interval;

// a visitor to find the path to the target variable
// from a expression.
class VariableFinder: public IRVisitor {
public:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VariablePathFinder

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to look out for errors when a variable appears in multiple locations in the expression

explicit VariableFinder(Var target) : target_(target) {}

void Visit(const NodeRef& node) final {
if (finded_) return;
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());

path_.push_back(node.get());
if (node.same_as(target_)) finded_ = true;
IRVisitor::Visit(node);
if (!finded_) path_.pop_back();
}

std::vector<const Node*> path_;

private:
bool finded_{false};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

found

Var target_;
std::unordered_set<const Node*> visited_;
};


// get the path to the variable
std::vector<const Node*> GetPath(Var target, Expr expr) {
VariableFinder v(target);
v.Visit(expr);
return v.path_;
}


// a visitor to deduce the bound of a variable from a expression
class BoundDeducer: public IRVisitor {
public:
Expr Deduce(Var target, Expr expr) {
path_ = GetPath(target, expr);
target_ = target;
iter_ = 0;
result = make_zero(expr.type());

Visit(expr);
return result;
}

void Visit(const NodeRef& e) final {
if (e.get() == path_[iter_++]) {
IRVisitor::Visit(e);
} else {
LOG(FATAL) << "the current node is not match with the deduced path";
}
}

void Visit_(const Add* op) final {
bool left = op->a.get() == path_[iter_];
result -= left ? op->b : op->a;
Visit(left ? op->a : op->b);
}

void Visit_(const Sub* op) final {
bool left = op->a.get() == path_[iter_];
if (left) {
result += op->b;
} else {
result -= op->a;
result = -1 * result;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to -result, or 0- result, negation should be overloaded already?

is_greater = !is_greater;
}
Visit(left ? op->a : op->b);
}

void Visit_(const Mul* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
if (is_negative_const(operand)) is_greater = !is_greater;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There could be cases where we cannot prove either it is negative_const or is positive const, which results in a deduction failure

result /= operand;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There will be problem of rounding in here. if it is a lower bound and rounds toward 0, then it becomes problematic.

Maybe we should consider find out the direction first, before doing deduction

Visit(left ? op->a : op->b);
}

void Visit_(const Div* op) final {
bool left = op->a.get() == path_[iter_];
Expr operand = left ? op->b : op->a;
if (is_negative_const(operand)) is_greater = !is_greater;
result = left ? result * operand : operand / result;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment about rounding

Visit(left ? op->a : op->b);
}

Expr result;
bool is_greater{true};

private:
Var target_;
std::vector<const Node*> path_;
size_t iter_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

directly initialize it to 0 here

};

// Assuming e >= 0, deduce the bound of variable from it.
IntSet DeduceBound(Var v, Expr e) {
BoundDeducer deducer;
deducer.Deduce(v, e);
Type t = deducer.result.type();
return deducer.is_greater ?
IntSet::range(Range(deducer.result, Cast::make(t, Interval::pos_inf))) :
IntSet::range(Range(Cast::make(t, Interval::neg_inf), deducer.result));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not use range, we can use IntervalSet::make, maybe we can consider to expose IntervalSet under arith

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider to put IntervalSet and other data structure in int_set_internal.h

}

TVM_REGISTER_API(_pass_DeduceBound)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0].operator Var(), args[1].operator Expr());
});


} // namespace arith
} // namespace tvm
50 changes: 43 additions & 7 deletions src/pass/ir_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,13 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Free);
.DISPATCH_TO_VISIT(Free)
.DISPATCH_TO_VISIT(Add)
.DISPATCH_TO_VISIT(Sub)
.DISPATCH_TO_VISIT(Mul)
.DISPATCH_TO_VISIT(Div)
.DISPATCH_TO_VISIT(Mod)
.DISPATCH_TO_VISIT(LT);

void IRVisitor::Visit_(const Variable* op) {}

Expand Down Expand Up @@ -128,6 +134,36 @@ void IRVisitor::Visit_(const Call *op) {
VisitArray(op->args, this);
}

void IRVisitor::Visit_(const Add* op) {
this->Visit(op->a);
this->Visit(op->b);
}

void IRVisitor::Visit_(const Sub* op) {
this->Visit(op->a);
this->Visit(op->b);
}

void IRVisitor::Visit_(const Mul* op) {
this->Visit(op->a);
this->Visit(op->b);
}

void IRVisitor::Visit_(const Div* op) {
this->Visit(op->a);
this->Visit(op->b);
}

void IRVisitor::Visit_(const Mod* op) {
this->Visit(op->a);
this->Visit(op->b);
}

void IRVisitor::Visit_(const LT* op) {
this->Visit(op->a);
this->Visit(op->b);
}

TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->axis, v);
Expand All @@ -151,16 +187,16 @@ inline void Binary(const T* op, IRVisitor* v) {
}

TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Add>(Binary<Add>)
.set_dispatch<Sub>(Binary<Sub>)
.set_dispatch<Mul>(Binary<Mul>)
.set_dispatch<Div>(Binary<Div>)
.set_dispatch<Mod>(Binary<Mod>)
// .set_dispatch<Add>(Binary<Add>)
// .set_dispatch<Sub>(Binary<Sub>)
// .set_dispatch<Mul>(Binary<Mul>)
// .set_dispatch<Div>(Binary<Div>)
// .set_dispatch<Mod>(Binary<Mod>)
.set_dispatch<Min>(Binary<Min>)
.set_dispatch<Max>(Binary<Max>)
.set_dispatch<EQ>(Binary<EQ>)
.set_dispatch<NE>(Binary<NE>)
.set_dispatch<LT>(Binary<LT>)
// .set_dispatch<LT>(Binary<LT>)
.set_dispatch<LE>(Binary<LE>)
.set_dispatch<GT>(Binary<GT>)
.set_dispatch<GE>(Binary<GE>)
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_pass_deduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import tvm

x = tvm.Var('x')
y = tvm.Var('y')
z = tvm.Var('z')
a = tvm.Var('a')
b = tvm.Var('b')

e0 = (-a+x*y+z-b)
print(type(e0))
e1 = tvm.ir_pass.DeduceBound(x, e0)
print(e1)