-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ARITH] DeduceBound #40
Changes from 5 commits
2b9cb33
461d778
bccdfce
31478b2
2b4d091
2539265
0ea07ab
cf9f3ba
8f72e50
9106944
1f1ff8f
b409040
e4bee27
e3a5f9e
b1617a8
7abe378
f829694
96ded33
5a8fa91
71349f6
f3e3fa9
d5aedde
35683a8
d9794bb
696976a
434835a
2527b2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file bound_deducer.cc | ||
*/ | ||
#include <tvm/expr.h> | ||
#include <tvm/ir_pass.h> | ||
#include <tvm/ir_visitor.h> | ||
#include <tvm/api_registry.h> | ||
#include <unordered_set> | ||
#include "./int_set.h" | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
using namespace ir; | ||
using Halide::Internal::Interval; | ||
|
||
// a visitor to find the path to the target variable | ||
// from a expression. | ||
class VariableFinder: public IRVisitor { | ||
public: | ||
explicit VariableFinder(Var target) : target_(target) {} | ||
|
||
void Visit(const NodeRef& node) final { | ||
if (finded_) return; | ||
if (visited_.count(node.get()) != 0) return; | ||
visited_.insert(node.get()); | ||
|
||
path_.push_back(node.get()); | ||
if (node.same_as(target_)) finded_ = true; | ||
IRVisitor::Visit(node); | ||
if (!finded_) path_.pop_back(); | ||
} | ||
|
||
std::vector<const Node*> path_; | ||
|
||
private: | ||
bool finded_{false}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. found |
||
Var target_; | ||
std::unordered_set<const Node*> visited_; | ||
}; | ||
|
||
|
||
// get the path to the variable | ||
std::vector<const Node*> GetPath(Var target, Expr expr) { | ||
VariableFinder v(target); | ||
v.Visit(expr); | ||
return v.path_; | ||
} | ||
|
||
|
||
// a visitor to deduce the bound of a variable from a expression | ||
class BoundDeducer: public IRVisitor { | ||
public: | ||
Expr Deduce(Var target, Expr expr) { | ||
path_ = GetPath(target, expr); | ||
target_ = target; | ||
iter_ = 0; | ||
result = make_zero(expr.type()); | ||
|
||
Visit(expr); | ||
return result; | ||
} | ||
|
||
void Visit(const NodeRef& e) final { | ||
if (e.get() == path_[iter_++]) { | ||
IRVisitor::Visit(e); | ||
} else { | ||
LOG(FATAL) << "the current node is not match with the deduced path"; | ||
} | ||
} | ||
|
||
void Visit_(const Add* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
result -= left ? op->b : op->a; | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Sub* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
if (left) { | ||
result += op->b; | ||
} else { | ||
result -= op->a; | ||
result = -1 * result; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change to -result, or 0- result, negation should be overloaded already? |
||
is_greater = !is_greater; | ||
} | ||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Mul* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
Expr operand = left ? op->b : op->a; | ||
if (is_negative_const(operand)) is_greater = !is_greater; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There could be cases where we cannot prove either it is negative_const or is positive const, which results in a deduction failure |
||
result /= operand; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There will be problem of rounding in here. if it is a lower bound and rounds toward 0, then it becomes problematic. Maybe we should consider find out the direction first, before doing deduction |
||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
void Visit_(const Div* op) final { | ||
bool left = op->a.get() == path_[iter_]; | ||
Expr operand = left ? op->b : op->a; | ||
if (is_negative_const(operand)) is_greater = !is_greater; | ||
result = left ? result * operand : operand / result; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment about rounding |
||
Visit(left ? op->a : op->b); | ||
} | ||
|
||
Expr result; | ||
bool is_greater{true}; | ||
|
||
private: | ||
Var target_; | ||
std::vector<const Node*> path_; | ||
size_t iter_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. directly initialize it to 0 here |
||
}; | ||
|
||
// Assuming e >= 0, deduce the bound of variable from it. | ||
IntSet DeduceBound(Var v, Expr e) { | ||
BoundDeducer deducer; | ||
deducer.Deduce(v, e); | ||
Type t = deducer.result.type(); | ||
return deducer.is_greater ? | ||
IntSet::range(Range(deducer.result, Cast::make(t, Interval::pos_inf))) : | ||
IntSet::range(Range(Cast::make(t, Interval::neg_inf), deducer.result)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do not use range, we can use IntervalSet::make, maybe we can consider to expose IntervalSet under arith There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider to put IntervalSet and other data structure in int_set_internal.h |
||
} | ||
|
||
TVM_REGISTER_API(_pass_DeduceBound) | ||
.set_body([](TVMArgs args, TVMRetValue *ret) { | ||
*ret = DeduceBound(args[0].operator Var(), args[1].operator Expr()); | ||
}); | ||
|
||
|
||
} // namespace arith | ||
} // namespace tvm |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import tvm | ||
|
||
x = tvm.Var('x') | ||
y = tvm.Var('y') | ||
z = tvm.Var('z') | ||
a = tvm.Var('a') | ||
b = tvm.Var('b') | ||
|
||
e0 = (-a+x*y+z-b) | ||
print(type(e0)) | ||
e1 = tvm.ir_pass.DeduceBound(x, e0) | ||
print(e1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
VariablePathFinder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to look out for errors when a variable appears in multiple locations in the expression