From 2b9cb33f3d3d26f83d2ea98cfaa0345c7a37c44d Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 10 Feb 2017 06:32:32 +0000 Subject: [PATCH 01/23] [PYTHON/API] Add compare and logic build-in op for Expr --- python/tvm/expr.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index c3b0845aecb6..3c6c9f023906 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -37,6 +37,30 @@ def __rtruediv__(self, other): def __neg__(self): return self.__mul__(-1) + def __lt__(self, other): + return _make.LT(self, other) + + def __le__(self, other): + return _make.LE(self, other) + + def __eq__(self, other): + return _make.EQ(self, other) + + def __ne__(self, other): + return _make.NE(self, other) + + def __gt__(self, other): + return _make.GT(self, other) + + def __ge__(self, other): + return _make.GE(self, other) + + def __and__(self, other): + return _make.And(self, other) + + def __or__(self, other): + return _make.And(self, other) + class Expr(NodeBase, ExprOp): pass From 461d7788aa2a820bc3a47318139d383c406cc34a Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 10 Feb 2017 16:45:21 +0000 Subject: [PATCH 02/23] remove 'and', 'or' --- python/tvm/expr.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 3c6c9f023906..8ae1d45d8b5e 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -55,12 +55,6 @@ def __gt__(self, other): def __ge__(self, other): return _make.GE(self, other) - def __and__(self, other): - return _make.And(self, other) - - def __or__(self, other): - return _make.And(self, other) - class Expr(NodeBase, ExprOp): pass From bccdfce51825f4b943318dfc28b5070aa826e5bd Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 10 Feb 2017 22:17:54 +0000 Subject: [PATCH 03/23] add deducer --- include/tvm/ir_pass.h | 2 + include/tvm/ir_visitor.h | 3 + src/api/api_pass.cc | 5 + src/pass/ir_visitor.cc | 26 +++- src/pass/partition_loops.cc | 167 ++++++++++++++++++++++ tests/python/unittest/test_pass_deduce.py | 12 ++ 6 files changed, 211 insertions(+), 4 deletions(-) create mode 100644 src/pass/partition_loops.cc create mode 100644 tests/python/unittest/test_pass_deduce.py diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 4ce90e3b7739..f0ad94fca903 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -55,6 +55,8 @@ bool VerifySSA(const Stmt& ir); */ bool HasSideEffect(const Expr& e); +Expr Deduce(Var v, Expr e); + /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index e5711f65ff86..0df935660742 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -46,6 +46,9 @@ class IRVisitor { virtual void Visit_(const Let* op); virtual void Visit_(const Free* op); virtual void Visit_(const Call* op); + virtual void Visit_(const Mul* op); + virtual void Visit_(const Add* op); + virtual void Visit_(const LT* op); }; /*! diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index b8f3cbc3bd9e..91092b135908 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -12,6 +12,11 @@ namespace tvm { namespace ir { +TVM_REGISTER_API(_pass_Deduce) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = Deduce(args[0].operator Var(), args[1].operator Expr()); + }); + TVM_REGISTER_API(_pass_Simplify) .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsNodeType()) { diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 5baaa851970e..3f77e1f55d27 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -67,7 +67,10 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(Store) .DISPATCH_TO_VISIT(Let) .DISPATCH_TO_VISIT(Call) -.DISPATCH_TO_VISIT(Free); +.DISPATCH_TO_VISIT(Free) +.DISPATCH_TO_VISIT(Add) +.DISPATCH_TO_VISIT(Mul) +.DISPATCH_TO_VISIT(LT); void IRVisitor::Visit_(const Variable* op) {} @@ -128,6 +131,21 @@ void IRVisitor::Visit_(const Call *op) { VisitArray(op->args, this); } +void IRVisitor::Visit_(const Add* op) { + this->Visit(op->a); + this->Visit(op->b); +} + +void IRVisitor::Visit_(const Mul* op) { + this->Visit(op->a); + this->Visit(op->b); +} + +void IRVisitor::Visit_(const LT* op) { + this->Visit(op->a); + this->Visit(op->b); +} + TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .set_dispatch([](const Reduce* op, IRVisitor* v) { VisitRDom(op->axis, v); @@ -151,16 +169,16 @@ inline void Binary(const T* op, IRVisitor* v) { } TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) -.set_dispatch(Binary) +// .set_dispatch(Binary) .set_dispatch(Binary) -.set_dispatch(Binary) +// .set_dispatch(Binary) .set_dispatch
(Binary
) .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch(Binary) -.set_dispatch(Binary) +// .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch(Binary) diff --git a/src/pass/partition_loops.cc b/src/pass/partition_loops.cc new file mode 100644 index 000000000000..6ebf3a96093e --- /dev/null +++ b/src/pass/partition_loops.cc @@ -0,0 +1,167 @@ +#include +#include +#include + + +namespace tvm { +namespace ir { + +class VariableFinder: public IRVisitor { + public: + VariableFinder(Var target) : target_(target) {} + + void Visit(const NodeRef& node) final { + if (finded) return; + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + + path_.push_back(node.get()); + if (node.same_as(target_)) finded = true; + IRVisitor::Visit(node); + if (!finded) path_.pop_back(); + } + + std::vector path_; + + private: + bool finded{false}; + Var target_; + std::unordered_set visited_; +}; + + +// Get the path to the variable +std::vector GetPath(Var target, Expr expr) { + VariableFinder v(target); + v.Visit(expr); + return v.path_; +} + + + +// IRVisitor version +class Deducer: public IRVisitor { + public: + Expr Deduce(Var target, Expr expr) { + path_ = GetPath(target, expr); + target_ = target; + iter = 0; + + LOG(INFO) << "Path"; + for (const Node* n : path_) { + LOG(INFO) << n->type_key(); + } + Visit(expr); + return result; + } + + void Visit(const NodeRef& e) final { + if (e.get() == path_[iter++]) { + LOG(INFO) << "Deduce " << e->type_key(); + IRVisitor::Visit(e); + } else { + LOG(INFO) << "ERROR " << e->type_key(); + } + } + + void Visit_(const LT* op) final { + result = op->b; + Visit(op->a); + } + + void Visit_(const Add* op) final { + bool left = op->a.get() == path_[iter]; + result -= left ? op->b : op->a; + Visit(left ? op->a : op->b); + } + + void Visit_(const Mul* op) final { + bool left = op->a.get() == path_[iter]; + result /= left ? op->b : op->a; + Visit(left ? op->a : op->b); + } + + Expr result; + private: + Var target_; + std::vector path_; + size_t iter; +}; + + +// IRMutator version +class DeduceMutator { + public: + Expr Deduce(Var target, Expr expr) { + this->path_ = GetPath(target, expr); + this->target = target; + this->iter = 0; + + LOG(INFO) << "Path"; + for (const Node* n : path_) { + LOG(INFO) << n->type_key(); + } + return Mutate(expr, expr); + } + + Expr Mutate(const NodeRef& node, Expr result) { + if (node.get() == path_[iter++]) { + LOG(INFO) << "Deduce " << node->type_key(); + static const FMutateExpr& f = vtable_expr(); + return f(node, result, this); + } else { + LOG(INFO) << "Error " << node->type_key(); + return result; + } + } + + const Node* GetCurrentNode() { + return path_[iter]; + } + + using FMutateExpr = IRFunctor; + static FMutateExpr& vtable_expr(); + + Var target; + private: + std::vector path_; + size_t iter; +}; + +DeduceMutator::FMutateExpr& DeduceMutator::vtable_expr() { // NOLINT(*) + static FMutateExpr inst; return inst; +} + +TVM_STATIC_IR_FUNCTOR(DeduceMutator, vtable_expr) +.set_dispatch([](const LT* op, Expr& res, DeduceMutator* m) { + return m->Mutate(op->a, op->b); +}) +.set_dispatch([](const Mul* op, Expr& res, DeduceMutator* m) { + bool left = op->a.get() == m->GetCurrentNode(); + res /= left ? op->b : op->a; + return m->Mutate(left ? op->a : op->b, res); +}) +.set_dispatch([](const Add* op, Expr& res, DeduceMutator* m) { + bool left = op->a.get() == m->GetCurrentNode(); + res -= left ? op->b : op->a; + return m->Mutate(left ? op->a : op->b, res); +}) +.set_dispatch([](const Variable* op, Expr& res, DeduceMutator* m) { + return res; +}); + + + + +Expr Deduce(Var v, Expr e) { + // x*y+z < a + LOG(INFO) << "Deduce"; + // Deducer deducer; + // deducer.Deduce(v, e); + // return deducer.result; + DeduceMutator deducer; + return deducer.Deduce(v, e); +} + +} +} diff --git a/tests/python/unittest/test_pass_deduce.py b/tests/python/unittest/test_pass_deduce.py new file mode 100644 index 000000000000..c4d8e68a559f --- /dev/null +++ b/tests/python/unittest/test_pass_deduce.py @@ -0,0 +1,12 @@ +import tvm + +x = tvm.Var('x') +y = tvm.Var('y') +z = tvm.Var('z') +a = tvm.Var('a') +b = tvm.Var('b') + +e0 = (x*y+z Date: Sun, 12 Feb 2017 01:12:01 +0000 Subject: [PATCH 04/23] [WIP] bound_deducer.cc --- include/tvm/ir_pass.h | 2 - include/tvm/ir_visitor.h | 5 +- src/api/api_pass.cc | 5 - src/arithmetic/bound_deducer.cc | 133 +++++++++++++++++ src/pass/ir_visitor.cc | 24 +++- src/pass/partition_loops.cc | 167 ---------------------- tests/python/unittest/test_pass_deduce.py | 4 +- 7 files changed, 160 insertions(+), 180 deletions(-) create mode 100644 src/arithmetic/bound_deducer.cc delete mode 100644 src/pass/partition_loops.cc diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index f0ad94fca903..4ce90e3b7739 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -55,8 +55,6 @@ bool VerifySSA(const Stmt& ir); */ bool HasSideEffect(const Expr& e); -Expr Deduce(Var v, Expr e); - /*! * \brief Convert a IR node to be SSA form. * \param stmt The source statement to be converted. diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index 0df935660742..1cfce55ef33e 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -46,8 +46,11 @@ class IRVisitor { virtual void Visit_(const Let* op); virtual void Visit_(const Free* op); virtual void Visit_(const Call* op); - virtual void Visit_(const Mul* op); virtual void Visit_(const Add* op); + virtual void Visit_(const Sub* op); + virtual void Visit_(const Mul* op); + virtual void Visit_(const Div* op); + virtual void Visit_(const Mod* op); virtual void Visit_(const LT* op); }; diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 91092b135908..b8f3cbc3bd9e 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -12,11 +12,6 @@ namespace tvm { namespace ir { -TVM_REGISTER_API(_pass_Deduce) -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = Deduce(args[0].operator Var(), args[1].operator Expr()); - }); - TVM_REGISTER_API(_pass_Simplify) .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsNodeType()) { diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc new file mode 100644 index 000000000000..54ca077f35ed --- /dev/null +++ b/src/arithmetic/bound_deducer.cc @@ -0,0 +1,133 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file bound_deducer.cc + */ +#include +#include +#include +#include +#include +#include "./int_set.h" + +namespace tvm { +namespace arith { + +using namespace ir; +using Halide::Internal::Interval; + +// a visitor to find the path to the target variable +// from a expression. +class VariableFinder: public IRVisitor { + public: + explicit VariableFinder(Var target) : target_(target) {} + + void Visit(const NodeRef& node) final { + if (finded_) return; + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + + path_.push_back(node.get()); + if (node.same_as(target_)) finded_ = true; + IRVisitor::Visit(node); + if (!finded_) path_.pop_back(); + } + + std::vector path_; + + private: + bool finded_{false}; + Var target_; + std::unordered_set visited_; +}; + + +// get the path to the variable +std::vector GetPath(Var target, Expr expr) { + VariableFinder v(target); + v.Visit(expr); + return v.path_; +} + + +// a visitor to deduce the bound of a variable from a expression +class BoundDeducer: public IRVisitor { + public: + Expr Deduce(Var target, Expr expr) { + path_ = GetPath(target, expr); + target_ = target; + iter_ = 0; + result = make_zero(expr.type()); + + Visit(expr); + return result; + } + + void Visit(const NodeRef& e) final { + if (e.get() == path_[iter_++]) { + IRVisitor::Visit(e); + } else { + LOG(FATAL) << "the current node is not match with the deduced path"; + } + } + + void Visit_(const Add* op) final { + bool left = op->a.get() == path_[iter_]; + result -= left ? op->b : op->a; + Visit(left ? op->a : op->b); + } + + void Visit_(const Sub* op) final { + bool left = op->a.get() == path_[iter_]; + if (left) { + result += op->b; + } else { + result -= op->a; + result = -1 * result; + is_greater = !is_greater; + } + Visit(left ? op->a : op->b); + } + + void Visit_(const Mul* op) final { + bool left = op->a.get() == path_[iter_]; + Expr operand = left ? op->b : op->a; + if (is_negative_const(operand)) is_greater = !is_greater; + result /= operand; + Visit(left ? op->a : op->b); + } + + void Visit_(const Div* op) final { + bool left = op->a.get() == path_[iter_]; + Expr operand = left ? op->b : op->a; + if (is_negative_const(operand)) is_greater = !is_greater; + result = left ? result * operand : operand / result; + Visit(left ? op->a : op->b); + } + + Expr result; + bool is_greater{true}; + + private: + Var target_; + std::vector path_; + size_t iter_; +}; + +// Assuming e >= 0, deduce the bound of variable from it. +IntSet DeduceBound(Var v, Expr e) { + BoundDeducer deducer; + deducer.Deduce(v, e); + Type t = deducer.result.type(); + return deducer.is_greater ? + IntSet::range(Range(deducer.result, Cast::make(t, Interval::pos_inf))) : + IntSet::range(Range(Cast::make(t, Interval::neg_inf), deducer.result)); +} + +TVM_REGISTER_API(_pass_DeduceBound) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = DeduceBound(args[0].operator Var(), args[1].operator Expr()); + }); + + +} // namespace arith +} // namespace tvm diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index 3f77e1f55d27..0ebaa2f9c6a5 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -69,7 +69,10 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .DISPATCH_TO_VISIT(Call) .DISPATCH_TO_VISIT(Free) .DISPATCH_TO_VISIT(Add) +.DISPATCH_TO_VISIT(Sub) .DISPATCH_TO_VISIT(Mul) +.DISPATCH_TO_VISIT(Div) +.DISPATCH_TO_VISIT(Mod) .DISPATCH_TO_VISIT(LT); void IRVisitor::Visit_(const Variable* op) {} @@ -136,11 +139,26 @@ void IRVisitor::Visit_(const Add* op) { this->Visit(op->b); } +void IRVisitor::Visit_(const Sub* op) { + this->Visit(op->a); + this->Visit(op->b); +} + void IRVisitor::Visit_(const Mul* op) { this->Visit(op->a); this->Visit(op->b); } +void IRVisitor::Visit_(const Div* op) { + this->Visit(op->a); + this->Visit(op->b); +} + +void IRVisitor::Visit_(const Mod* op) { + this->Visit(op->a); + this->Visit(op->b); +} + void IRVisitor::Visit_(const LT* op) { this->Visit(op->a); this->Visit(op->b); @@ -170,10 +188,10 @@ inline void Binary(const T* op, IRVisitor* v) { TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) // .set_dispatch(Binary) -.set_dispatch(Binary) +// .set_dispatch(Binary) // .set_dispatch(Binary) -.set_dispatch
(Binary
) -.set_dispatch(Binary) +// .set_dispatch
(Binary
) +// .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch(Binary) .set_dispatch(Binary) diff --git a/src/pass/partition_loops.cc b/src/pass/partition_loops.cc deleted file mode 100644 index 6ebf3a96093e..000000000000 --- a/src/pass/partition_loops.cc +++ /dev/null @@ -1,167 +0,0 @@ -#include -#include -#include - - -namespace tvm { -namespace ir { - -class VariableFinder: public IRVisitor { - public: - VariableFinder(Var target) : target_(target) {} - - void Visit(const NodeRef& node) final { - if (finded) return; - if (visited_.count(node.get()) != 0) return; - visited_.insert(node.get()); - - path_.push_back(node.get()); - if (node.same_as(target_)) finded = true; - IRVisitor::Visit(node); - if (!finded) path_.pop_back(); - } - - std::vector path_; - - private: - bool finded{false}; - Var target_; - std::unordered_set visited_; -}; - - -// Get the path to the variable -std::vector GetPath(Var target, Expr expr) { - VariableFinder v(target); - v.Visit(expr); - return v.path_; -} - - - -// IRVisitor version -class Deducer: public IRVisitor { - public: - Expr Deduce(Var target, Expr expr) { - path_ = GetPath(target, expr); - target_ = target; - iter = 0; - - LOG(INFO) << "Path"; - for (const Node* n : path_) { - LOG(INFO) << n->type_key(); - } - Visit(expr); - return result; - } - - void Visit(const NodeRef& e) final { - if (e.get() == path_[iter++]) { - LOG(INFO) << "Deduce " << e->type_key(); - IRVisitor::Visit(e); - } else { - LOG(INFO) << "ERROR " << e->type_key(); - } - } - - void Visit_(const LT* op) final { - result = op->b; - Visit(op->a); - } - - void Visit_(const Add* op) final { - bool left = op->a.get() == path_[iter]; - result -= left ? op->b : op->a; - Visit(left ? op->a : op->b); - } - - void Visit_(const Mul* op) final { - bool left = op->a.get() == path_[iter]; - result /= left ? op->b : op->a; - Visit(left ? op->a : op->b); - } - - Expr result; - private: - Var target_; - std::vector path_; - size_t iter; -}; - - -// IRMutator version -class DeduceMutator { - public: - Expr Deduce(Var target, Expr expr) { - this->path_ = GetPath(target, expr); - this->target = target; - this->iter = 0; - - LOG(INFO) << "Path"; - for (const Node* n : path_) { - LOG(INFO) << n->type_key(); - } - return Mutate(expr, expr); - } - - Expr Mutate(const NodeRef& node, Expr result) { - if (node.get() == path_[iter++]) { - LOG(INFO) << "Deduce " << node->type_key(); - static const FMutateExpr& f = vtable_expr(); - return f(node, result, this); - } else { - LOG(INFO) << "Error " << node->type_key(); - return result; - } - } - - const Node* GetCurrentNode() { - return path_[iter]; - } - - using FMutateExpr = IRFunctor; - static FMutateExpr& vtable_expr(); - - Var target; - private: - std::vector path_; - size_t iter; -}; - -DeduceMutator::FMutateExpr& DeduceMutator::vtable_expr() { // NOLINT(*) - static FMutateExpr inst; return inst; -} - -TVM_STATIC_IR_FUNCTOR(DeduceMutator, vtable_expr) -.set_dispatch([](const LT* op, Expr& res, DeduceMutator* m) { - return m->Mutate(op->a, op->b); -}) -.set_dispatch([](const Mul* op, Expr& res, DeduceMutator* m) { - bool left = op->a.get() == m->GetCurrentNode(); - res /= left ? op->b : op->a; - return m->Mutate(left ? op->a : op->b, res); -}) -.set_dispatch([](const Add* op, Expr& res, DeduceMutator* m) { - bool left = op->a.get() == m->GetCurrentNode(); - res -= left ? op->b : op->a; - return m->Mutate(left ? op->a : op->b, res); -}) -.set_dispatch([](const Variable* op, Expr& res, DeduceMutator* m) { - return res; -}); - - - - -Expr Deduce(Var v, Expr e) { - // x*y+z < a - LOG(INFO) << "Deduce"; - // Deducer deducer; - // deducer.Deduce(v, e); - // return deducer.result; - DeduceMutator deducer; - return deducer.Deduce(v, e); -} - -} -} diff --git a/tests/python/unittest/test_pass_deduce.py b/tests/python/unittest/test_pass_deduce.py index c4d8e68a559f..b1aae31d5d5b 100644 --- a/tests/python/unittest/test_pass_deduce.py +++ b/tests/python/unittest/test_pass_deduce.py @@ -6,7 +6,7 @@ a = tvm.Var('a') b = tvm.Var('b') -e0 = (x*y+z Date: Mon, 13 Feb 2017 20:22:00 +0000 Subject: [PATCH 05/23] move IntervalSet and StrideSet into int_set_internal.h --- src/arithmetic/bound_deducer.cc | 27 ++++++++------- src/arithmetic/int_set.cc | 41 +---------------------- src/arithmetic/int_set_internal.h | 55 +++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 51 deletions(-) create mode 100644 src/arithmetic/int_set_internal.h diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 54ca077f35ed..2329833cd9f5 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -8,6 +8,7 @@ #include #include #include "./int_set.h" +#include "./int_set_internal.h" namespace tvm { namespace arith { @@ -17,38 +18,38 @@ using Halide::Internal::Interval; // a visitor to find the path to the target variable // from a expression. -class VariableFinder: public IRVisitor { +// (TODO) look out for errors when a variable appears in +// multiple locations in the expression +class VariablePathFinder: public IRVisitor { public: - explicit VariableFinder(Var target) : target_(target) {} + explicit VariablePathFinder(Var target) : target_(target) {} void Visit(const NodeRef& node) final { - if (finded_) return; + if (found_) return; if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); path_.push_back(node.get()); - if (node.same_as(target_)) finded_ = true; + if (node.same_as(target_)) found_ = true; IRVisitor::Visit(node); - if (!finded_) path_.pop_back(); + if (!found_) path_.pop_back(); } std::vector path_; private: - bool finded_{false}; + bool found_{false}; Var target_; std::unordered_set visited_; }; - // get the path to the variable std::vector GetPath(Var target, Expr expr) { - VariableFinder v(target); + VariablePathFinder v(target); v.Visit(expr); return v.path_; } - // a visitor to deduce the bound of a variable from a expression class BoundDeducer: public IRVisitor { public: @@ -93,6 +94,10 @@ class BoundDeducer: public IRVisitor { Expr operand = left ? op->b : op->a; if (is_negative_const(operand)) is_greater = !is_greater; result /= operand; + // (TODO) There will be problem of rounding in here. + // if it is a lower bound and rounds toward 0, then + // it becomes problematic. Maybe we should consider + // find out the direction first, before doing deduction Visit(left ? op->a : op->b); } @@ -119,8 +124,8 @@ IntSet DeduceBound(Var v, Expr e) { deducer.Deduce(v, e); Type t = deducer.result.type(); return deducer.is_greater ? - IntSet::range(Range(deducer.result, Cast::make(t, Interval::pos_inf))) : - IntSet::range(Range(Cast::make(t, Interval::neg_inf), deducer.result)); + IntervalSet::make(deducer.result, Interval::pos_inf) : + IntervalSet::make(Interval::neg_inf, deducer.result); } TVM_REGISTER_API(_pass_DeduceBound) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index d60504f2c51e..6e2948017efd 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -8,53 +8,14 @@ #include #include "./int_set.h" #include "./compute_expr.h" +#include "./int_set_internal.h" namespace tvm { namespace arith { using Halide::Internal::Interval; - using namespace ir; -/*! \brief Set of continuous interval */ -struct IntervalSet : public IntSetNode { - /*! \brief the internal interval*/ - Interval i; - - static IntSet make(Interval i) { - std::shared_ptr n = - std::make_shared(); - n->i = i; - return IntSet(n); - } - static IntSet make(Expr min, Expr max) { - std::shared_ptr n = - std::make_shared(); - n->i.min = min; - n->i.max = max; - return IntSet(n); - } - - static constexpr const char* _type_key = "IntervalSet"; - TVM_DECLARE_NODE_TYPE_INFO(IntervalSet); -}; - -/*! - * \brief set represented by strided integers - * Reserved for cases where strided access is supported. - */ -struct StrideSet : public IntSetNode { - /*! \brief the base inetrval */ - Interval base; - /*! \brief additional extents in positive number */ - Array extents; - /*! \brief additional strides in positive number */ - Array strides; - - static constexpr const char* _type_key = "StrideSet"; - TVM_DECLARE_NODE_TYPE_INFO(StrideSet); -}; - inline IntSet IntSet::cover_interval() const { if ((*this).as()) return *this; const StrideSet* s = (*this).as(); diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h new file mode 100644 index 000000000000..e00dae3b8a01 --- /dev/null +++ b/src/arithmetic/int_set_internal.h @@ -0,0 +1,55 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file int_set_internal.h + * \brief (TODO) + */ +#include +#include +#include "./int_set.h" + +namespace tvm { +namespace arith { + +using Halide::Internal::Interval; + +/*! \brief Set of continuous interval */ +struct IntervalSet : public IntSetNode { + /*! \brief the internal interval*/ + Interval i; + + static IntSet make(Interval i) { + std::shared_ptr n = + std::make_shared(); + n->i = i; + return IntSet(n); + } + static IntSet make(Expr min, Expr max) { + std::shared_ptr n = + std::make_shared(); + n->i.min = min; + n->i.max = max; + return IntSet(n); + } + + static constexpr const char* _type_key = "IntervalSet"; + TVM_DECLARE_NODE_TYPE_INFO(IntervalSet); +}; + +/*! + * \brief set represented by strided integers + * Reserved for cases where strided access is supported. + */ +struct StrideSet : public IntSetNode { + /*! \brief the base inetrval */ + Interval base; + /*! \brief additional extents in positive number */ + Array extents; + /*! \brief additional strides in positive number */ + Array strides; + + static constexpr const char* _type_key = "StrideSet"; + TVM_DECLARE_NODE_TYPE_INFO(StrideSet); +}; + +} +} From cf9f3baf1b6d57ebf396b6337567f40c4be2f248 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 13 Feb 2017 22:50:44 +0000 Subject: [PATCH 06/23] add multiple failure for VariablePathFinder, add EvalSign --- src/arithmetic/bound_deducer.cc | 28 ++++++++++++++++------------ src/arithmetic/int_set.cc | 17 +++++++++++++++++ src/arithmetic/int_set.h | 12 ++++++++++++ src/arithmetic/int_set_internal.h | 4 ++-- 4 files changed, 47 insertions(+), 14 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 2329833cd9f5..b3325889b871 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -18,24 +18,32 @@ using Halide::Internal::Interval; // a visitor to find the path to the target variable // from a expression. -// (TODO) look out for errors when a variable appears in -// multiple locations in the expression class VariablePathFinder: public IRVisitor { public: explicit VariablePathFinder(Var target) : target_(target) {} void Visit(const NodeRef& node) final { + if (!success) return; if (found_) return; if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); - path_.push_back(node.get()); - if (node.same_as(target_)) found_ = true; + if (!found_) path_.push_back(node.get()); + if (node.same_as(target_)) { + if (!found_) { + found_ = true; + } else { + // target variable appears at multiple location + success = false; + return; + } + } IRVisitor::Visit(node); if (!found_) path_.pop_back(); } std::vector path_; + bool success{true}; private: bool found_{false}; @@ -93,11 +101,8 @@ class BoundDeducer: public IRVisitor { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; if (is_negative_const(operand)) is_greater = !is_greater; + // (TODO) round result /= operand; - // (TODO) There will be problem of rounding in here. - // if it is a lower bound and rounds toward 0, then - // it becomes problematic. Maybe we should consider - // find out the direction first, before doing deduction Visit(left ? op->a : op->b); } @@ -121,11 +126,10 @@ class BoundDeducer: public IRVisitor { // Assuming e >= 0, deduce the bound of variable from it. IntSet DeduceBound(Var v, Expr e) { BoundDeducer deducer; - deducer.Deduce(v, e); - Type t = deducer.result.type(); + Expr res = deducer.Deduce(v, e); return deducer.is_greater ? - IntervalSet::make(deducer.result, Interval::pos_inf) : - IntervalSet::make(Interval::neg_inf, deducer.result); + IntervalSet::make(res, Interval::pos_inf) : + IntervalSet::make(Interval::neg_inf, res); } TVM_REGISTER_API(_pass_DeduceBound) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 6e2948017efd..b91d0813b6f7 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -60,6 +60,11 @@ bool IntSet::can_prove_positive() const { return (s_int && is_positive_const(ir::Simplify(s_int->i.min))); } +bool IntSet::can_prove_negative() const { + const IntervalSet* s_int = (*this).as(); + return (s_int && is_negative_const(ir::Simplify(s_int->i.max))); +} + Expr IntSet::point_value() const { const IntervalSet* s_int = (*this).as(); CHECK(s_int && s_int->i.is_single_point()); @@ -424,6 +429,18 @@ IntSet EvalSet(Range r, return Combine(min_set, ext_set); } +SignType EvalSign(Expr r, + const Map& dom_map) { + IntSet set = EvalSet(r, dom_map); + if (set.can_prove_positive()) { + return kPositive; + } else if (set.can_prove_negative()) { + return kNegative; + } else { + return kUnknown; + } +} + TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const IntervalSet *op, IRPrinter *p) { p->stream << "interval-set[" diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 979d138af9e2..f3ce9534517e 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -46,6 +46,8 @@ class IntSet : public NodeRef { bool is_single_point() const; /*! \return Whether the set is proved to be bigger than 0 */ bool can_prove_positive() const; + /*! \return Whether the set is proved to be smaller than 0 */ + bool can_prove_negative() const; /*! * \brief The single point value, call only if is_single_point is true * \return The point value. @@ -104,6 +106,16 @@ IntSet EvalSet(Expr e, IntSet EvalSet(Range r, const Map& dom_map); + +enum SignType { + kPositive, + kNegative, + kUnknown +}; + +SignType EvalSign(Expr r, + const Map& dom_map); + /*! * \brief Create an union set of all sets * \param sets The sets to be unioned diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h index e00dae3b8a01..d5135a38cf74 100644 --- a/src/arithmetic/int_set_internal.h +++ b/src/arithmetic/int_set_internal.h @@ -51,5 +51,5 @@ struct StrideSet : public IntSetNode { TVM_DECLARE_NODE_TYPE_INFO(StrideSet); }; -} -} +} // namespace arith +} // namespace tvm From 8f72e502c821086aaa0d14170eb40dcd3866cb29 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 14 Feb 2017 00:24:15 +0000 Subject: [PATCH 07/23] consider round in deduce, add success flag --- src/arithmetic/bound_deducer.cc | 54 ++++++++++++++--------- src/arithmetic/int_set.cc | 2 + src/arithmetic/int_set.h | 9 ++++ tests/python/unittest/test_pass_deduce.py | 12 ----- 4 files changed, 45 insertions(+), 32 deletions(-) delete mode 100644 tests/python/unittest/test_pass_deduce.py diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index b3325889b871..b5b3a1998a9a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -5,7 +5,6 @@ #include #include #include -#include #include #include "./int_set.h" #include "./int_set_internal.h" @@ -24,7 +23,6 @@ class VariablePathFinder: public IRVisitor { void Visit(const NodeRef& node) final { if (!success) return; - if (found_) return; if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); @@ -55,27 +53,34 @@ class VariablePathFinder: public IRVisitor { std::vector GetPath(Var target, Expr expr) { VariablePathFinder v(target); v.Visit(expr); - return v.path_; + return v.success ? v.path_ : std::vector(); } // a visitor to deduce the bound of a variable from a expression class BoundDeducer: public IRVisitor { public: - Expr Deduce(Var target, Expr expr) { - path_ = GetPath(target, expr); + void Deduce(Var target, Expr expr, + const Map& dom_map) { target_ = target; + dom_map_ = dom_map; + path_ = GetPath(target, expr); + if (path_.empty()) { + success = false; + return; + } iter_ = 0; result = make_zero(expr.type()); Visit(expr); - return result; } void Visit(const NodeRef& e) final { + if (!success) return; if (e.get() == path_[iter_++]) { IRVisitor::Visit(e); } else { - LOG(FATAL) << "the current node is not match with the deduced path"; + success = false; + return; } } @@ -100,9 +105,20 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; - if (is_negative_const(operand)) is_greater = !is_greater; - // (TODO) round - result /= operand; + SignType sign = EvalSign(operand, dom_map_); + if (sign == SignType::kNegative) { + is_greater = !is_greater; + } else if (sign == SignType::kUnknown) { + // unable to get the sign of operand + success = false; + return; + } + // always use relax bound + if (is_greater) { + result = result / operand + 1; + } else { + result = result / operand - 1; + } Visit(left ? op->a : op->b); } @@ -116,27 +132,25 @@ class BoundDeducer: public IRVisitor { Expr result; bool is_greater{true}; + bool success{true}; private: Var target_; + Map dom_map_; std::vector path_; size_t iter_; }; // Assuming e >= 0, deduce the bound of variable from it. -IntSet DeduceBound(Var v, Expr e) { +IntSet DeduceBound(Var v, Expr e, + const Map& dom_map) { BoundDeducer deducer; - Expr res = deducer.Deduce(v, e); + deducer.Deduce(v, e, dom_map); + if (!deducer.success) return IntSet(); return deducer.is_greater ? - IntervalSet::make(res, Interval::pos_inf) : - IntervalSet::make(Interval::neg_inf, res); + IntervalSet::make(deducer.result, Interval::pos_inf) : + IntervalSet::make(Interval::neg_inf, deducer.result); } -TVM_REGISTER_API(_pass_DeduceBound) -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = DeduceBound(args[0].operator Var(), args[1].operator Expr()); - }); - - } // namespace arith } // namespace tvm diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index b91d0813b6f7..d0eafb8460a4 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -436,6 +436,8 @@ SignType EvalSign(Expr r, return kPositive; } else if (set.can_prove_negative()) { return kNegative; + } else if (set.is_single_point() && is_zero(set.point_value())) { + return kZero; } else { return kUnknown; } diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index f3ce9534517e..0b43cbee2a7d 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -110,9 +110,18 @@ IntSet EvalSet(Range r, enum SignType { kPositive, kNegative, + kZero, kUnknown }; +/*! + * \brief Find the sign of the expr, given the domain of each + * iteration variables. + * + * \param e The expression to be evaluated. + * \param dom_map The domain of each variable. + * \return the sign type of the expression. + */ SignType EvalSign(Expr r, const Map& dom_map); diff --git a/tests/python/unittest/test_pass_deduce.py b/tests/python/unittest/test_pass_deduce.py deleted file mode 100644 index b1aae31d5d5b..000000000000 --- a/tests/python/unittest/test_pass_deduce.py +++ /dev/null @@ -1,12 +0,0 @@ -import tvm - -x = tvm.Var('x') -y = tvm.Var('y') -z = tvm.Var('z') -a = tvm.Var('a') -b = tvm.Var('b') - -e0 = (-a+x*y+z-b) -print(type(e0)) -e1 = tvm.ir_pass.DeduceBound(x, e0) -print(e1) From 91069446472178aa1a706ec200cbd54cf9d2aaa0 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 14 Feb 2017 00:26:05 +0000 Subject: [PATCH 08/23] remove Visit_(Div) --- src/arithmetic/bound_deducer.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index b5b3a1998a9a..cf578c163d69 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -122,14 +122,6 @@ class BoundDeducer: public IRVisitor { Visit(left ? op->a : op->b); } - void Visit_(const Div* op) final { - bool left = op->a.get() == path_[iter_]; - Expr operand = left ? op->b : op->a; - if (is_negative_const(operand)) is_greater = !is_greater; - result = left ? result * operand : operand / result; - Visit(left ? op->a : op->b); - } - Expr result; bool is_greater{true}; bool success{true}; From 1f1ff8fa420274cc62befbfb1b9e867f756b5469 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Tue, 14 Feb 2017 00:30:48 +0000 Subject: [PATCH 09/23] add comment, update HalideIR --- HalideIR | 2 +- src/arithmetic/bound_deducer.cc | 6 ++++-- src/arithmetic/int_set_internal.h | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/HalideIR b/HalideIR index 642ae50ac749..e68ae61cd541 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 642ae50ac749c91c04483db04500163304d4334e +Subproject commit e68ae61cd541ac29efc9fafe2ad061479bcaa9c9 diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index cf578c163d69..110d4081c362 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -49,7 +49,8 @@ class VariablePathFinder: public IRVisitor { std::unordered_set visited_; }; -// get the path to the variable +// get the path to the variable, +// return empty vector to represent failure std::vector GetPath(Var target, Expr expr) { VariablePathFinder v(target); v.Visit(expr); @@ -133,7 +134,8 @@ class BoundDeducer: public IRVisitor { size_t iter_; }; -// Assuming e >= 0, deduce the bound of variable from it. +// assuming e >= 0, deduce the bound of variable from it. +// return empty set to represent deduce failure. IntSet DeduceBound(Var v, Expr e, const Map& dom_map) { BoundDeducer deducer; diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h index d5135a38cf74..fe958c07f581 100644 --- a/src/arithmetic/int_set_internal.h +++ b/src/arithmetic/int_set_internal.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2017 by Contributors * \file int_set_internal.h - * \brief (TODO) + * \brief Implementations of integer set */ #include #include From b40904096b3070f2a8e3f3f7d33ef8d6ee1abab6 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Wed, 15 Feb 2017 16:43:20 +0000 Subject: [PATCH 10/23] expose intset to python --- python/tvm/__init__.py | 1 + python/tvm/_ctypes/_function.py | 1 + python/tvm/arith.py | 17 ++++++++++++++ src/api/api_arith.cc | 20 +++++++++++++++++ src/arithmetic/bound_deducer.cc | 36 +++++++++++++++++++++--------- src/arithmetic/int_set.cc | 37 +++++++++++++++++++++---------- src/arithmetic/int_set.h | 6 +++-- src/arithmetic/int_set_internal.h | 5 +++++ 8 files changed, 98 insertions(+), 25 deletions(-) create mode 100644 python/tvm/arith.py create mode 100644 src/api/api_arith.cc diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index c676e5cfeb67..33c148f92d97 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -4,6 +4,7 @@ from ._ctypes._node import register_node from . import tensor +from . import arith from . import expr from . import stmt from . import make diff --git a/python/tvm/_ctypes/_function.py b/python/tvm/_ctypes/_function.py index 6393d1fed58d..7e0927cac6a3 100644 --- a/python/tvm/_ctypes/_function.py +++ b/python/tvm/_ctypes/_function.py @@ -244,6 +244,7 @@ def _init_api_functions(root_namespace): module_internal = sys.modules["%s._api_internal" % root_namespace] namespace_match = { "_make_": sys.modules["%s.make" % root_namespace], + "_arith_": sys.modules["%s.arith" % root_namespace], "_pass_": sys.modules["%s.ir_pass" % root_namespace], "_codegen_": sys.modules["%s.codegen" % root_namespace], "_schedule_": sys.modules["%s.schedule" % root_namespace] diff --git a/python/tvm/arith.py b/python/tvm/arith.py new file mode 100644 index 000000000000..cff13724143b --- /dev/null +++ b/python/tvm/arith.py @@ -0,0 +1,17 @@ +# pylint: disable=protected-access, no-member +"""Arithmetic"""" +from __future__ import absolute_import as _abs +from ._ctypes._node import NodeBase, register_node + +@register_node +class IntSet(NodeBase): + pass + +@register_node +class IntervalSet(IntSet): + pass + +@register_node +class StrideSet(IntSet): + pass + diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc new file mode 100644 index 000000000000..d7b83e8bd16f --- /dev/null +++ b/src/api/api_arith.cc @@ -0,0 +1,20 @@ +/*! + * Copyright (c) 2016 by Contributors + * (TODO) + * \file api_arith.cc + */ +#include +#include +#include +#include "../arithmetic/int_set.h" + +namespace tvm { +namespace arith { + +TVM_REGISTER_API(_arith_single_point) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IntSet::single_point(args[0]); + }); + +} // namespace arith +} // namespace tvm diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 110d4081c362..d84b97b27110 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -5,10 +5,13 @@ #include #include #include +#include #include +#include #include "./int_set.h" #include "./int_set_internal.h" + namespace tvm { namespace arith { @@ -61,9 +64,10 @@ std::vector GetPath(Var target, Expr expr) { class BoundDeducer: public IRVisitor { public: void Deduce(Var target, Expr expr, - const Map& dom_map) { + const Map& dom_map) { target_ = target; dom_map_ = dom_map; + // get the path path_ = GetPath(target, expr); if (path_.empty()) { success = false; @@ -71,6 +75,9 @@ class BoundDeducer: public IRVisitor { } iter_ = 0; result = make_zero(expr.type()); + // get the sign of every subexpr + sign_map_ = EvalSign(expr, dom_map); + LOG(INFO) << "get the sign map"; Visit(expr); } @@ -106,14 +113,14 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; - SignType sign = EvalSign(operand, dom_map_); - if (sign == SignType::kNegative) { - is_greater = !is_greater; - } else if (sign == SignType::kUnknown) { - // unable to get the sign of operand - success = false; - return; - } + // SignType sign = EvalSign(operand, dom_map_); + // if (sign == SignType::kNegative) { + // is_greater = !is_greater; + // } else if (sign == SignType::kUnknown) { + // // unable to get the sign of operand + // success = false; + // return; + // } // always use relax bound if (is_greater) { result = result / operand + 1; @@ -129,7 +136,8 @@ class BoundDeducer: public IRVisitor { private: Var target_; - Map dom_map_; + Map dom_map_; + std::unordered_map sign_map_; std::vector path_; size_t iter_; }; @@ -137,7 +145,7 @@ class BoundDeducer: public IRVisitor { // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. IntSet DeduceBound(Var v, Expr e, - const Map& dom_map) { + const Map& dom_map) { BoundDeducer deducer; deducer.Deduce(v, e, dom_map); if (!deducer.success) return IntSet(); @@ -146,5 +154,11 @@ IntSet DeduceBound(Var v, Expr e, IntervalSet::make(Interval::neg_inf, deducer.result); } + +TVM_REGISTER_API(_pass_DeduceBound) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = DeduceBound(args[0], args[1], args[2]); + }); + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index d0eafb8460a4..7d7aedd224fd 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include "./int_set.h" #include "./compute_expr.h" #include "./int_set_internal.h" @@ -351,6 +352,7 @@ class IntSetEvaluator { }; inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) { + LOG(INFO) << e->type_key() << " " << e; return IntSet::single_point(e); } @@ -361,6 +363,7 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) .set_dispatch([](const Variable* op, const Expr& e, IntSetEvaluator* m) { + LOG(INFO) << e->type_key() << " " << e; auto it = m->dom_map.find(op); if (it != m->dom_map.end()) { return it->second; @@ -372,12 +375,15 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) // binary operator template inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) { + LOG(INFO) << e->type_key() << " " << e; IntSet a = m->Eval(op->a); IntSet b = m->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntSet::single_point(e); } + LOG(INFO) << "Before Combine"; IntSet r = Combine(a, b); + LOG(INFO) << "After Combine"; return r; } @@ -429,23 +435,30 @@ IntSet EvalSet(Range r, return Combine(min_set, ext_set); } -SignType EvalSign(Expr r, - const Map& dom_map) { - IntSet set = EvalSet(r, dom_map); - if (set.can_prove_positive()) { - return kPositive; - } else if (set.can_prove_negative()) { - return kNegative; - } else if (set.is_single_point() && is_zero(set.point_value())) { - return kZero; - } else { - return kUnknown; +std::unordered_map EvalSetForSubExpr(Expr e, + std::unordered_map& dom_map) { + IntSetEvaluator m(dom_map); + m.Eval(e); + LOG(INFO) << "Eval Finished"; + return std::unordered_map(); +} + +std::unordered_map EvalSign(Expr e, + const Map& dom_map) { + LOG(INFO) << e; + // LOG(INFO) << dom_map; + std::unordered_map dmap; + for (auto kv : dom_map) { + LOG(INFO) << kv.second->type_key(); + dmap[kv.first.get()] = kv.second; } + auto m = EvalSetForSubExpr(e, dmap); + return std::unordered_map(); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const IntervalSet *op, IRPrinter *p) { - p->stream << "interval-set[" + p->stream << "interval-set" << "[" << op->i.min << ", " << op->i.max << ']'; }); diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 0b43cbee2a7d..2a44c0c4156b 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -80,6 +80,8 @@ class IntSet : public NodeRef { * \brief Base class of all IntSet containers. */ struct IntSetNode : public Node { + // static constexpr const char* _type_key = "IntSet"; + // TVM_DECLARE_NODE_TYPE_INFO(IntSetNode); }; /*! @@ -122,8 +124,8 @@ enum SignType { * \param dom_map The domain of each variable. * \return the sign type of the expression. */ -SignType EvalSign(Expr r, - const Map& dom_map); +std::unordered_map EvalSign(Expr r, + const Map& dom_map); /*! * \brief Create an union set of all sets diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h index fe958c07f581..b00e6bdc0ee4 100644 --- a/src/arithmetic/int_set_internal.h +++ b/src/arithmetic/int_set_internal.h @@ -3,6 +3,9 @@ * \file int_set_internal.h * \brief Implementations of integer set */ +#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_ +#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_ + #include #include #include "./int_set.h" @@ -53,3 +56,5 @@ struct StrideSet : public IntSetNode { } // namespace arith } // namespace tvm + +#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_ From e4bee2777f6d60aa1b9d8008f39edf73d5b390e4 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 16 Feb 2017 06:14:45 +0800 Subject: [PATCH 11/23] check the sign of every expr --- python/tvm/arith.py | 2 +- src/api/api_arith.cc | 7 +++- src/arithmetic/bound_deducer.cc | 18 ++++----- src/arithmetic/int_set.cc | 47 ++++++++++++++++------ src/arithmetic/int_set.h | 18 +++++---- tests/python/unittest/test_arith_intset.py | 13 ++++++ 6 files changed, 74 insertions(+), 31 deletions(-) create mode 100644 tests/python/unittest/test_arith_intset.py diff --git a/python/tvm/arith.py b/python/tvm/arith.py index cff13724143b..525a91981a19 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -1,5 +1,5 @@ # pylint: disable=protected-access, no-member -"""Arithmetic"""" +"""Arithmetic""" from __future__ import absolute_import as _abs from ._ctypes._node import NodeBase, register_node diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index d7b83e8bd16f..16450d24ccca 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -11,10 +11,15 @@ namespace tvm { namespace arith { -TVM_REGISTER_API(_arith_single_point) +TVM_REGISTER_API(_arith_intset_single_point) .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = IntSet::single_point(args[0]); }); +TVM_REGISTER_API(_arith_intset_range) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = IntSet::range(args[0], args[1]); + }); + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index d84b97b27110..d7d1188e398f 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -113,14 +113,14 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; - // SignType sign = EvalSign(operand, dom_map_); - // if (sign == SignType::kNegative) { - // is_greater = !is_greater; - // } else if (sign == SignType::kUnknown) { - // // unable to get the sign of operand - // success = false; - // return; - // } + SignType sign = sign_map_[operand.get()]; + if (sign == SignType::kNegative) { + is_greater = !is_greater; + } else if (sign == SignType::kUnknown) { + // unable to get the sign of operand + success = false; + return; + } // always use relax bound if (is_greater) { result = result / operand + 1; @@ -137,7 +137,7 @@ class BoundDeducer: public IRVisitor { private: Var target_; Map dom_map_; - std::unordered_map sign_map_; + std::unordered_map sign_map_; std::vector path_; size_t iter_; }; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 7d7aedd224fd..0542e29ce05a 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -66,6 +66,17 @@ bool IntSet::can_prove_negative() const { return (s_int && is_negative_const(ir::Simplify(s_int->i.max))); } +SignType IntSet::sign_type() const { + if (can_prove_positive()) { + return kPositive; + } else if (can_prove_negative()) { + return kNegative; + } else if (is_single_point() && is_zero(point_value())) { + return kZero; + } else { + return kUnknown; + } +} Expr IntSet::point_value() const { const IntervalSet* s_int = (*this).as(); CHECK(s_int && s_int->i.is_single_point()); @@ -92,6 +103,10 @@ IntSet IntSet::range(Range r) { return IntervalSet::make(r->min, (r->extent + r->min) - 1); } +IntSet IntSet::range(Expr min, Expr max) { + return IntervalSet::make(min, max); +} + // Check if a is created from b. bool IntSet::match_range(const Range& b) const { const IntSet& a = *this; @@ -349,11 +364,14 @@ class IntSetEvaluator { } const std::unordered_map& dom_map; + std::unordered_map expr_map; }; -inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator*) { +inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator* m) { LOG(INFO) << e->type_key() << " " << e; - return IntSet::single_point(e); + IntSet res = IntSet::single_point(e); + m->expr_map[e.get()] = res; + return res; } TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) @@ -364,12 +382,15 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) .set_dispatch([](const Variable* op, const Expr& e, IntSetEvaluator* m) { LOG(INFO) << e->type_key() << " " << e; + IntSet res; auto it = m->dom_map.find(op); if (it != m->dom_map.end()) { - return it->second; + res = it->second; } else { - return IntSet::single_point(e); + res = IntSet::single_point(e); } + m->expr_map[e.get()] = res; + return res; }); // binary operator @@ -381,9 +402,8 @@ inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) { if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntSet::single_point(e); } - LOG(INFO) << "Before Combine"; IntSet r = Combine(a, b); - LOG(INFO) << "After Combine"; + m->expr_map[e.get()] = r; return r; } @@ -435,15 +455,14 @@ IntSet EvalSet(Range r, return Combine(min_set, ext_set); } -std::unordered_map EvalSetForSubExpr(Expr e, +std::unordered_map EvalSetForSubExpr(Expr e, std::unordered_map& dom_map) { IntSetEvaluator m(dom_map); m.Eval(e); - LOG(INFO) << "Eval Finished"; - return std::unordered_map(); + return m.expr_map; } -std::unordered_map EvalSign(Expr e, +std::unordered_map EvalSign(Expr e, const Map& dom_map) { LOG(INFO) << e; // LOG(INFO) << dom_map; @@ -452,8 +471,12 @@ std::unordered_map EvalSign(Expr e, LOG(INFO) << kv.second->type_key(); dmap[kv.first.get()] = kv.second; } - auto m = EvalSetForSubExpr(e, dmap); - return std::unordered_map(); + auto expr_map = EvalSetForSubExpr(e, dmap); + std::unordered_map res; + for (auto kv : expr_map) { + res[kv.first] = kv.second.sign_type(); + } + return res; } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 2a44c0c4156b..4d45a2716b06 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -12,6 +12,13 @@ namespace tvm { namespace arith { +enum SignType { + kPositive, + kNegative, + kZero, + kUnknown +}; + // internal node container of int set. class IntSetNode; @@ -48,6 +55,7 @@ class IntSet : public NodeRef { bool can_prove_positive() const; /*! \return Whether the set is proved to be smaller than 0 */ bool can_prove_negative() const; + SignType sign_type() const; /*! * \brief The single point value, call only if is_single_point is true * \return The point value. @@ -74,6 +82,7 @@ class IntSet : public NodeRef { * \return constructed set. */ static IntSet range(Range r); + static IntSet range(Expr min, Expr max); }; /*! @@ -109,13 +118,6 @@ IntSet EvalSet(Range r, const Map& dom_map); -enum SignType { - kPositive, - kNegative, - kZero, - kUnknown -}; - /*! * \brief Find the sign of the expr, given the domain of each * iteration variables. @@ -124,7 +126,7 @@ enum SignType { * \param dom_map The domain of each variable. * \return the sign type of the expression. */ -std::unordered_map EvalSign(Expr r, +std::unordered_map EvalSign(Expr r, const Map& dom_map); /*! diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py new file mode 100644 index 000000000000..642dd38c6673 --- /dev/null +++ b/tests/python/unittest/test_arith_intset.py @@ -0,0 +1,13 @@ +import tvm + +x = tvm.Var('x') +y = tvm.Var('y') +z = tvm.Var('z') + +ys = tvm.arith.intset_range(2, 3) +zs = tvm.arith.intset_range(2, 3) + + +e0 = (-z)*x+y +e1 = tvm.ir_pass.DeduceBound(x, e0, {y: ys, z: zs}) +print(e1) From e3a5f9edbd527ca75741ec79dda469c86bed902a Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 16 Feb 2017 10:57:07 +0800 Subject: [PATCH 12/23] set return type as ExprSignType --- src/arithmetic/bound_deducer.cc | 5 ++-- src/arithmetic/int_set.cc | 41 +++++++++++++++------------------ src/arithmetic/int_set.h | 7 ++++-- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index d7d1188e398f..64c3a6d6ba70 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -77,7 +77,6 @@ class BoundDeducer: public IRVisitor { result = make_zero(expr.type()); // get the sign of every subexpr sign_map_ = EvalSign(expr, dom_map); - LOG(INFO) << "get the sign map"; Visit(expr); } @@ -113,7 +112,7 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; - SignType sign = sign_map_[operand.get()]; + SignType sign = sign_map_[operand]; if (sign == SignType::kNegative) { is_greater = !is_greater; } else if (sign == SignType::kUnknown) { @@ -137,7 +136,7 @@ class BoundDeducer: public IRVisitor { private: Var target_; Map dom_map_; - std::unordered_map sign_map_; + ExprSignMap sign_map_; std::vector path_; size_t iter_; }; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 0542e29ce05a..d7e0691807f8 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -17,6 +17,9 @@ namespace arith { using Halide::Internal::Interval; using namespace ir; +using ExprIntSetMap = std::unordered_map; + inline IntSet IntSet::cover_interval() const { if ((*this).as()) return *this; const StrideSet* s = (*this).as(); @@ -364,13 +367,12 @@ class IntSetEvaluator { } const std::unordered_map& dom_map; - std::unordered_map expr_map; + ExprIntSetMap expr_map; }; inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator* m) { - LOG(INFO) << e->type_key() << " " << e; IntSet res = IntSet::single_point(e); - m->expr_map[e.get()] = res; + m->expr_map[e] = res; return res; } @@ -381,7 +383,6 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) .set_dispatch([](const Variable* op, const Expr& e, IntSetEvaluator* m) { - LOG(INFO) << e->type_key() << " " << e; IntSet res; auto it = m->dom_map.find(op); if (it != m->dom_map.end()) { @@ -389,21 +390,20 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) } else { res = IntSet::single_point(e); } - m->expr_map[e.get()] = res; + m->expr_map[e] = res; return res; }); // binary operator template inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) { - LOG(INFO) << e->type_key() << " " << e; IntSet a = m->Eval(op->a); IntSet b = m->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntSet::single_point(e); } IntSet r = Combine(a, b); - m->expr_map[e.get()] = r; + m->expr_map[e] = r; return r; } @@ -455,26 +455,23 @@ IntSet EvalSet(Range r, return Combine(min_set, ext_set); } -std::unordered_map EvalSetForSubExpr(Expr e, - std::unordered_map& dom_map) { - IntSetEvaluator m(dom_map); - m.Eval(e); - return m.expr_map; -} - -std::unordered_map EvalSign(Expr e, +ExprSignMap EvalSign(Expr e, const Map& dom_map) { - LOG(INFO) << e; - // LOG(INFO) << dom_map; std::unordered_map dmap; for (auto kv : dom_map) { - LOG(INFO) << kv.second->type_key(); dmap[kv.first.get()] = kv.second; } - auto expr_map = EvalSetForSubExpr(e, dmap); - std::unordered_map res; - for (auto kv : expr_map) { - res[kv.first] = kv.second.sign_type(); + + IntSetEvaluator m(dmap); + m.Eval(e); + + ExprSignMap res; + for (auto kv : m.expr_map) { + if (kv.first.type().is_uint()) { + res[kv.first] = kPositive; + } else { + res[kv.first] = kv.second.sign_type(); + } } return res; } diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 4d45a2716b06..6976db474f17 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -19,6 +19,9 @@ enum SignType { kUnknown }; +using ExprSignMap = std::unordered_map; + // internal node container of int set. class IntSetNode; @@ -126,8 +129,8 @@ IntSet EvalSet(Range r, * \param dom_map The domain of each variable. * \return the sign type of the expression. */ -std::unordered_map EvalSign(Expr r, - const Map& dom_map); +ExprSignMap EvalSign(Expr r, + const Map& dom_map); /*! * \brief Create an union set of all sets From b1617a8ae4dd5b4a3108d67ec70bfc69e07b28b2 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 16 Feb 2017 12:38:46 +0800 Subject: [PATCH 13/23] fine tune --- python/tvm/arith.py | 2 +- src/api/api_arith.cc | 8 +++- src/arithmetic/bound_deducer.cc | 52 +++++++++++++--------- src/arithmetic/int_set.cc | 28 +++--------- src/arithmetic/int_set.h | 38 +++++++++++----- tests/python/unittest/test_arith_intset.py | 2 +- 6 files changed, 74 insertions(+), 56 deletions(-) diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 525a91981a19..6ff05ff6a0a5 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -1,5 +1,5 @@ # pylint: disable=protected-access, no-member -"""Arithmetic""" +"""Arithmetic data structure and utility""" from __future__ import absolute_import as _abs from ._ctypes._node import NodeBase, register_node diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 16450d24ccca..23cb641c6ff4 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2016 by Contributors - * (TODO) + * Implementation of API functions related to arith * \file api_arith.cc */ #include @@ -21,5 +21,11 @@ TVM_REGISTER_API(_arith_intset_range) *ret = IntSet::range(args[0], args[1]); }); +TVM_REGISTER_API(_arith_DeduceBound) +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = DeduceBound(args[0], args[1], args[2]); + }); + + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 64c3a6d6ba70..b0576fb34f99 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -1,6 +1,7 @@ /*! * Copyright (c) 2017 by Contributors * \file bound_deducer.cc + * \brief Utility to deduce bound of expression */ #include #include @@ -11,7 +12,6 @@ #include "./int_set.h" #include "./int_set_internal.h" - namespace tvm { namespace arith { @@ -26,7 +26,10 @@ class VariablePathFinder: public IRVisitor { void Visit(const NodeRef& node) final { if (!success) return; - if (visited_.count(node.get()) != 0) return; + if (visited_.count(node.get()) != 0 && + !node.same_as(target_)) { + return; + } visited_.insert(node.get()); if (!found_) path_.push_back(node.get()); @@ -63,10 +66,9 @@ std::vector GetPath(Var target, Expr expr) { // a visitor to deduce the bound of a variable from a expression class BoundDeducer: public IRVisitor { public: - void Deduce(Var target, Expr expr, - const Map& dom_map) { - target_ = target; - dom_map_ = dom_map; + BoundDeducer(Var target, Expr expr, + const std::unordered_map& dom_map) + : target_(target), expr_(expr), dom_map_(dom_map) { // get the path path_ = GetPath(target, expr); if (path_.empty()) { @@ -76,7 +78,7 @@ class BoundDeducer: public IRVisitor { iter_ = 0; result = make_zero(expr.type()); // get the sign of every subexpr - sign_map_ = EvalSign(expr, dom_map); + expr_map_ = EvalSetForEachSubExpr(expr, dom_map); Visit(expr); } @@ -112,7 +114,14 @@ class BoundDeducer: public IRVisitor { void Visit_(const Mul* op) final { bool left = op->a.get() == path_[iter_]; Expr operand = left ? op->b : op->a; - SignType sign = sign_map_[operand]; + + SignType sign; + if (operand.type().is_uint()) { + sign = kPositive; + } else { + sign = expr_map_[operand].sign_type(); + } + if (sign == SignType::kNegative) { is_greater = !is_greater; } else if (sign == SignType::kUnknown) { @@ -120,6 +129,7 @@ class BoundDeducer: public IRVisitor { success = false; return; } + // always use relax bound if (is_greater) { result = result / operand + 1; @@ -135,29 +145,27 @@ class BoundDeducer: public IRVisitor { private: Var target_; - Map dom_map_; - ExprSignMap sign_map_; + Expr expr_; + const std::unordered_map& dom_map_; std::vector path_; size_t iter_; + ExprIntSetMap expr_map_; }; // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. IntSet DeduceBound(Var v, Expr e, const Map& dom_map) { - BoundDeducer deducer; - deducer.Deduce(v, e, dom_map); - if (!deducer.success) return IntSet(); - return deducer.is_greater ? - IntervalSet::make(deducer.result, Interval::pos_inf) : - IntervalSet::make(Interval::neg_inf, deducer.result); + std::unordered_map dmap; + for (auto kv : dom_map) { + dmap[kv.first.get()] = kv.second; + } + BoundDeducer deducer(v, e, dmap); + if (!deducer.success) return IntSet(); + return deducer.is_greater ? + IntervalSet::make(deducer.result, Interval::pos_inf) : + IntervalSet::make(Interval::neg_inf, deducer.result); } - -TVM_REGISTER_API(_pass_DeduceBound) -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = DeduceBound(args[0], args[1], args[2]); - }); - } // namespace arith } // namespace tvm diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index d7e0691807f8..8b2b35112015 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -17,9 +17,6 @@ namespace arith { using Halide::Internal::Interval; using namespace ir; -using ExprIntSetMap = std::unordered_map; - inline IntSet IntSet::cover_interval() const { if ((*this).as()) return *this; const StrideSet* s = (*this).as(); @@ -107,6 +104,9 @@ IntSet IntSet::range(Range r) { } IntSet IntSet::range(Expr min, Expr max) { + if (min.same_as(max)) { + return IntSet::single_point(min); + } return IntervalSet::make(min, max); } @@ -455,25 +455,11 @@ IntSet EvalSet(Range r, return Combine(min_set, ext_set); } -ExprSignMap EvalSign(Expr e, - const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first.get()] = kv.second; - } - - IntSetEvaluator m(dmap); +ExprIntSetMap EvalSetForEachSubExpr(Expr e, + const std::unordered_map& dom_map) { + IntSetEvaluator m(dom_map); m.Eval(e); - - ExprSignMap res; - for (auto kv : m.expr_map) { - if (kv.first.type().is_uint()) { - res[kv.first] = kPositive; - } else { - res[kv.first] = kv.second.sign_type(); - } - } - return res; + return m.expr_map; } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 6976db474f17..426f39236820 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -19,9 +19,6 @@ enum SignType { kUnknown }; -using ExprSignMap = std::unordered_map; - // internal node container of int set. class IntSetNode; @@ -58,6 +55,7 @@ class IntSet : public NodeRef { bool can_prove_positive() const; /*! \return Whether the set is proved to be smaller than 0 */ bool can_prove_negative() const; + /*! \return The sign of the elements in the integer set */ SignType sign_type() const; /*! * \brief The single point value, call only if is_single_point is true @@ -85,6 +83,12 @@ class IntSet : public NodeRef { * \return constructed set. */ static IntSet range(Range r); + /*! + * \brief Construct a set representing a range. + * \param min The minimum value of the range. + * \param max The maximum value of the range. + * \return constructed set. + */ static IntSet range(Expr min, Expr max); }; @@ -92,10 +96,11 @@ class IntSet : public NodeRef { * \brief Base class of all IntSet containers. */ struct IntSetNode : public Node { - // static constexpr const char* _type_key = "IntSet"; - // TVM_DECLARE_NODE_TYPE_INFO(IntSetNode); }; +using ExprIntSetMap = std::unordered_map; + /*! * \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. @@ -122,15 +127,15 @@ IntSet EvalSet(Range r, /*! - * \brief Find the sign of the expr, given the domain of each - * iteration variables. + * \brief Find the integer set of every sub-expression, given the + * domain of each iteration variables. * * \param e The expression to be evaluated. * \param dom_map The domain of each variable. - * \return the sign type of the expression. + * \return the map from the expression to its possible value. */ -ExprSignMap EvalSign(Expr r, - const Map& dom_map); +ExprIntSetMap EvalSetForEachSubExpr(Expr r, + const std::unordered_map& dom_map); /*! * \brief Create an union set of all sets @@ -144,6 +149,19 @@ inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); } +/*! + * \brief Deduce the bound of the target variable in a expression, + * give the domain of each variables. Return undefined IntSet to + * represent failure. + * + * \param v The target variable to be deduced. + * \param e The conditional expression. + * \param dom_map The domain of each variable. + * \return An integer set that can cover all the possible values. + */ +IntSet DeduceBound(Var v, Expr e, + const Map& dom_map); + } // namespace arith } // namespace tvm diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 642dd38c6673..26537c4ae1e7 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -9,5 +9,5 @@ e0 = (-z)*x+y -e1 = tvm.ir_pass.DeduceBound(x, e0, {y: ys, z: zs}) +e1 = tvm.arith.DeduceBound(x, e0, {y: ys, z: zs}) print(e1) From 7abe378e7fc2717092860950c7d864845d8e8ae9 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 16 Feb 2017 13:01:32 +0800 Subject: [PATCH 14/23] add min & max python api for interval set --- python/tvm/arith.py | 7 ++++++- src/api/api_arith.cc | 12 ++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 6ff05ff6a0a5..54fdf98f315a 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -2,6 +2,7 @@ """Arithmetic data structure and utility""" from __future__ import absolute_import as _abs from ._ctypes._node import NodeBase, register_node +from . import _api_internal @register_node class IntSet(NodeBase): @@ -9,7 +10,11 @@ class IntSet(NodeBase): @register_node class IntervalSet(IntSet): - pass + def min(self): + return _api_internal._IntervalSetGetMin(self) + + def max(self): + return _api_internal._IntervalSetGetMax(self) @register_node class StrideSet(IntSet): diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 23cb641c6ff4..35689e7dfbc9 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -7,6 +7,7 @@ #include #include #include "../arithmetic/int_set.h" +#include "../arithmetic/int_set_internal.h" namespace tvm { namespace arith { @@ -26,6 +27,17 @@ TVM_REGISTER_API(_arith_DeduceBound) *ret = DeduceBound(args[0], args[1], args[2]); }); +TVM_REGISTER_API(_IntervalSetGetMin) +.set_body([](TVMArgs args, TVMRetValue *ret) { + IntSet s = args[0].operator IntSet(); + *ret = s.as()->i.min; + }); + +TVM_REGISTER_API(_IntervalSetGetMax) +.set_body([](TVMArgs args, TVMRetValue *ret) { + IntSet s = args[0].operator IntSet(); + *ret = s.as()->i.max; + }); } // namespace arith } // namespace tvm From f829694abd55388bc3ad395e27293a789a71d4f6 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 16 Feb 2017 13:38:38 +0800 Subject: [PATCH 15/23] support for conditional expr --- src/arithmetic/bound_deducer.cc | 44 ++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index b0576fb34f99..a708b8eb4d3e 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -10,7 +10,6 @@ #include #include #include "./int_set.h" -#include "./int_set_internal.h" namespace tvm { namespace arith { @@ -93,6 +92,34 @@ class BoundDeducer: public IRVisitor { } } + void Visit_(const LT* op) final { + is_greater = false; + is_equal = false; + result = op->b; + Visit(op->a); + } + + void Visit_(const LE* op) final { + is_greater = false; + is_equal = true; + result = op->b; + Visit(op->a); + } + + void Visit_(const GT* op) final { + is_greater = true; + is_equal = false; + result = op->b; + Visit(op->a); + } + + void Visit_(const GE* op) final { + is_greater = true; + is_equal = true; + result = op->b; + Visit(op->a); + } + void Visit_(const Add* op) final { bool left = op->a.get() == path_[iter_]; result -= left ? op->b : op->a; @@ -141,6 +168,7 @@ class BoundDeducer: public IRVisitor { Expr result; bool is_greater{true}; + bool is_equal{true}; bool success{true}; private: @@ -160,11 +188,15 @@ IntSet DeduceBound(Var v, Expr e, for (auto kv : dom_map) { dmap[kv.first.get()] = kv.second; } - BoundDeducer deducer(v, e, dmap); - if (!deducer.success) return IntSet(); - return deducer.is_greater ? - IntervalSet::make(deducer.result, Interval::pos_inf) : - IntervalSet::make(Interval::neg_inf, deducer.result); + BoundDeducer d(v, e, dmap); + if (!d.success) return IntSet(); + Expr min = Interval::neg_inf, max = Interval::pos_inf; + if (d.is_greater) { + min = d.is_equal ? d.result : d.result+1; + } else { + max = d.is_equal ? d.result : d.result-1; + } + return IntSet::range(min, max); } } // namespace arith From 96ded33b17c8a3cef6ee63f155f5129db0aee854 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Thu, 16 Feb 2017 13:38:56 +0800 Subject: [PATCH 16/23] refactor test --- tests/python/unittest/test_arith_intset.py | 30 ++++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 26537c4ae1e7..63084b916d15 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -1,13 +1,27 @@ import tvm -x = tvm.Var('x') -y = tvm.Var('y') -z = tvm.Var('z') +def test_basic(): + s = tvm.arith.intset_range(2, 3) + assert s.min().value == 2 + assert s.max().value == 3 -ys = tvm.arith.intset_range(2, 3) -zs = tvm.arith.intset_range(2, 3) +def test_deduce(): + a = tvm.Var('a') + b = tvm.Var('b') + c = tvm.Var('c') + d = tvm.Var('d') + b_s = tvm.arith.intset_range(2, 3) + c_s = tvm.arith.intset_range(5, 7) + d_s = tvm.arith.intset_range(-3, -1) -e0 = (-z)*x+y -e1 = tvm.arith.DeduceBound(x, e0, {y: ys, z: zs}) -print(e1) + e0 = (-b)*a+c-d*b + res = tvm.arith.DeduceBound(a, e0, {b: b_s, c: c_s, d: d_s}) + ans = ((0+d*b)-c)/(-b)-1 + print(res) + print(ans) + # assert print(res.max() == ans) # will print False + +if __name__ == "__main__": + test_basic() + test_deduce() From 71349f61d8202e57a6f8dd50d321ea39bc55b227 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 17 Feb 2017 07:54:58 +0800 Subject: [PATCH 17/23] add checker for BoundDeducer --- python/tvm/arith.py | 5 ++ src/arithmetic/bound_deducer.cc | 97 ++++++++++++++++++++++++--------- 2 files changed, 77 insertions(+), 25 deletions(-) diff --git a/python/tvm/arith.py b/python/tvm/arith.py index 54fdf98f315a..b75939b60728 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -6,17 +6,22 @@ @register_node class IntSet(NodeBase): + """Represent a set of integer in one dimension.""" pass @register_node class IntervalSet(IntSet): + """Represent set of continuous interval""" def min(self): + """get the minimum value""" return _api_internal._IntervalSetGetMin(self) def max(self): + """get the maximum value""" return _api_internal._IntervalSetGetMax(self) @register_node class StrideSet(IntSet): + """Represent set of strided integers""" pass diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index a708b8eb4d3e..9d9707d5724f 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -62,25 +62,17 @@ std::vector GetPath(Var target, Expr expr) { return v.success ? v.path_ : std::vector(); } +class Checker; + // a visitor to deduce the bound of a variable from a expression class BoundDeducer: public IRVisitor { public: + friend class Checker; BoundDeducer(Var target, Expr expr, const std::unordered_map& dom_map) - : target_(target), expr_(expr), dom_map_(dom_map) { - // get the path - path_ = GetPath(target, expr); - if (path_.empty()) { - success = false; - return; - } - iter_ = 0; - result = make_zero(expr.type()); - // get the sign of every subexpr - expr_map_ = EvalSetForEachSubExpr(expr, dom_map); + : target_(target), expr_(expr), dom_map_(dom_map) {} - Visit(expr); - } + void Deduce(); void Visit(const NodeRef& e) final { if (!success) return; @@ -132,7 +124,7 @@ class BoundDeducer: public IRVisitor { result += op->b; } else { result -= op->a; - result = -1 * result; + result = - result; is_greater = !is_greater; } Visit(left ? op->a : op->b); @@ -158,11 +150,7 @@ class BoundDeducer: public IRVisitor { } // always use relax bound - if (is_greater) { - result = result / operand + 1; - } else { - result = result / operand - 1; - } + result = result / operand + (is_greater ? 1 : -1); Visit(left ? op->a : op->b); } @@ -175,11 +163,69 @@ class BoundDeducer: public IRVisitor { Var target_; Expr expr_; const std::unordered_map& dom_map_; - std::vector path_; - size_t iter_; ExprIntSetMap expr_map_; + std::vector path_; + size_t iter_{0}; +}; + +class Checker: public IRVisitor { + public: + bool Check(BoundDeducer* deducer) { + deducer_ = deducer; + Visit(deducer_->expr_); + return target_count == 1 && cmp_count == 1; + } + + void Visit(const NodeRef& e) final { + if (e.same_as(deducer_->target_)) ++target_count; + IRVisitor::Visit(e); + } + + void Visit_(const LT* op) final { + ++cmp_count; + Visit(op->a); + Visit(op->b); + } + + void Visit_(const LE* op) final { + ++cmp_count; + Visit(op->a); + Visit(op->b); + } + + void Visit_(const GT* op) final { + ++cmp_count; + Visit(op->a); + Visit(op->b); + } + + void Visit_(const GE* op) final { + ++cmp_count; + Visit(op->a); + Visit(op->b); + } + + private: + BoundDeducer* deducer_; + size_t target_count{0}; + size_t cmp_count{0}; }; +void BoundDeducer::Deduce() { + result = make_zero(expr_.type()); + // get the path + path_ = GetPath(target_, expr_); + Checker checker; + if (!checker.Check(this) || path_.empty()) { + success = false; + return; + } + // get the sign of every subexpr + expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); + + Visit(expr_); +} + // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. IntSet DeduceBound(Var v, Expr e, @@ -189,15 +235,16 @@ IntSet DeduceBound(Var v, Expr e, dmap[kv.first.get()] = kv.second; } BoundDeducer d(v, e, dmap); + d.Deduce(); if (!d.success) return IntSet(); Expr min = Interval::neg_inf, max = Interval::pos_inf; if (d.is_greater) { - min = d.is_equal ? d.result : d.result+1; + min = d.is_equal ? d.result : d.result + 1; } else { - max = d.is_equal ? d.result : d.result-1; + max = d.is_equal ? d.result : d.result - 1; } return IntSet::range(min, max); } -} // namespace arith -} // namespace tvm +} // namespace arith +} // namespace tvm From f3e3fa9474ebc4265bff255555483d10d641e6b6 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 17 Feb 2017 12:09:20 +0800 Subject: [PATCH 18/23] add python check test --- python/tvm/arith.py | 8 +++++- src/api/api_arith.cc | 12 ++++++++ src/arithmetic/bound_deducer.cc | 2 +- src/arithmetic/int_set.cc | 9 ++++++ src/arithmetic/int_set.h | 6 +++- tests/python/unittest/test_arith_intset.py | 33 ++++++++++++++++++---- 6 files changed, 61 insertions(+), 9 deletions(-) diff --git a/python/tvm/arith.py b/python/tvm/arith.py index b75939b60728..a18e05d00ba1 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -7,7 +7,13 @@ @register_node class IntSet(NodeBase): """Represent a set of integer in one dimension.""" - pass + def is_nothing(self): + """Whether the set represent nothing""" + return _api_internal._IntSetIsNothing(self) + + def is_everything(self): + """Whether the set represent everything""" + return _api_internal._IntSetIsEverything(self) @register_node class IntervalSet(IntSet): diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 35689e7dfbc9..8571dd62384f 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -39,5 +39,17 @@ TVM_REGISTER_API(_IntervalSetGetMax) *ret = s.as()->i.max; }); +TVM_REGISTER_API(_IntSetIsNothing) +.set_body([](TVMArgs args, TVMRetValue *ret) { + IntSet s = args[0].operator IntSet(); + *ret = s.is_nothing(); + }); + +TVM_REGISTER_API(_IntSetIsEverything) +.set_body([](TVMArgs args, TVMRetValue *ret) { + IntSet s = args[0].operator IntSet(); + *ret = s.is_everything(); + }); + } // namespace arith } // namespace tvm diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 9d9707d5724f..8bdc4cbf9950 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -236,7 +236,7 @@ IntSet DeduceBound(Var v, Expr e, } BoundDeducer d(v, e, dmap); d.Deduce(); - if (!d.success) return IntSet(); + if (!d.success) return IntSet::nothing(); Expr min = Interval::neg_inf, max = Interval::pos_inf; if (d.is_greater) { min = d.is_equal ? d.result : d.result + 1; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index d0c05ca19c54..35fa611fcfc0 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -46,6 +46,11 @@ Range IntSet::cover_range(Range max_range) const { return max_range; } +bool IntSet::is_nothing() const { + const IntervalSet* s_int = (*this).as(); + return (s_int && s_int->i.is_empty()); +} + bool IntSet::is_everything() const { const IntervalSet* s_int = (*this).as(); return (s_int && s_int->i.is_everything()); @@ -83,6 +88,10 @@ Expr IntSet::point_value() const { return s_int->i.min; } +IntSet IntSet::nothing() { + return IntervalSet::make(Interval::nothing()); +} + IntSet IntSet::everything() { return IntervalSet::make(Interval::everything()); } diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index eb28440bf7d1..dcc74c691803 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -47,6 +47,8 @@ class IntSet : public NodeRef { * \return The covering interval set. */ IntSet cover_interval() const; + /*! \return Whether the set represent nothing */ + bool is_nothing() const; /*! \return Whether the set represent everything */ bool is_everything() const; /*! \return Whether the set is a single point */ @@ -69,7 +71,9 @@ class IntSet : public NodeRef { * \return true if we can prove they are the same. */ bool match_range(const Range& r) const; - /*! \return Whether the set contains everything */ + /*! \return The set contains nothing */ + static IntSet nothing(); + /*! \return The set contains everything */ static IntSet everything(); /*! * \brief construct a point set. diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 63084b916d15..4e00469a3277 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -15,13 +15,34 @@ def test_deduce(): c_s = tvm.arith.intset_range(5, 7) d_s = tvm.arith.intset_range(-3, -1) - e0 = (-b)*a+c-d*b - res = tvm.arith.DeduceBound(a, e0, {b: b_s, c: c_s, d: d_s}) - ans = ((0+d*b)-c)/(-b)-1 - print(res) - print(ans) - # assert print(res.max() == ans) # will print False + e0 = (-b)*a+c-d + res = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}) + ans = (d-c)/(-b)+(-1) + assert str(tvm.ir_pass.Simplify(res.max())) == str(ans) + +def test_check(): + a = tvm.Var('a') + b = tvm.Var('b') + c = tvm.Var('c') + d = tvm.Var('d') + + b_s = tvm.arith.intset_range(2, 3) + c_s = tvm.arith.intset_range(5, 7) + d_s = tvm.arith.intset_range(-3, -1) + + # no compare operator + res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}) + assert res1.is_nothing() + + # multiple compare operators + res2 = tvm.arith.DeduceBound(a, a+b>3>c , {b: b_s, c: c_s}) + assert res1.is_nothing() + + # multiple target variable + res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}) + assert res1.is_nothing() if __name__ == "__main__": test_basic() test_deduce() + test_check() From 35683a88350be1fd4a11ff50e8f05312b308b24b Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 17 Feb 2017 05:55:22 +0000 Subject: [PATCH 19/23] fix --- src/api/api_arith.cc | 12 +-- src/arithmetic/bound_deducer.cc | 89 +++++++++------- src/arithmetic/int_set.cc | 116 +++++++++++++++++---- src/arithmetic/int_set.h | 4 + tests/python/unittest/test_arith_intset.py | 13 ++- 5 files changed, 166 insertions(+), 68 deletions(-) diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 8571dd62384f..df48b3124d9f 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -29,26 +29,22 @@ TVM_REGISTER_API(_arith_DeduceBound) TVM_REGISTER_API(_IntervalSetGetMin) .set_body([](TVMArgs args, TVMRetValue *ret) { - IntSet s = args[0].operator IntSet(); - *ret = s.as()->i.min; + *ret = args[0].operator IntSet().min(); }); TVM_REGISTER_API(_IntervalSetGetMax) .set_body([](TVMArgs args, TVMRetValue *ret) { - IntSet s = args[0].operator IntSet(); - *ret = s.as()->i.max; + *ret = args[0].operator IntSet().max(); }); TVM_REGISTER_API(_IntSetIsNothing) .set_body([](TVMArgs args, TVMRetValue *ret) { - IntSet s = args[0].operator IntSet(); - *ret = s.is_nothing(); + *ret = args[0].operator IntSet().is_nothing(); }); TVM_REGISTER_API(_IntSetIsEverything) .set_body([](TVMArgs args, TVMRetValue *ret) { - IntSet s = args[0].operator IntSet(); - *ret = s.is_everything(); + *ret = args[0].operator IntSet().is_everything(); }); } // namespace arith diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 8bdc4cbf9950..94c636ae4641 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -63,11 +63,13 @@ std::vector GetPath(Var target, Expr expr) { } class Checker; +class Converter; // a visitor to deduce the bound of a variable from a expression class BoundDeducer: public IRVisitor { public: friend class Checker; + friend class Converter; BoundDeducer(Var target, Expr expr, const std::unordered_map& dom_map) : target_(target), expr_(expr), dom_map_(dom_map) {} @@ -85,31 +87,19 @@ class BoundDeducer: public IRVisitor { } void Visit_(const LT* op) final { - is_greater = false; - is_equal = false; - result = op->b; - Visit(op->a); + LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } void Visit_(const LE* op) final { - is_greater = false; - is_equal = true; - result = op->b; - Visit(op->a); + LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } void Visit_(const GT* op) final { - is_greater = true; - is_equal = false; - result = op->b; - Visit(op->a); + LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } void Visit_(const GE* op) final { - is_greater = true; - is_equal = true; - result = op->b; - Visit(op->a); + LOG(FATAL) << "unable to deduce due to multiple comparison operator"; } void Visit_(const Add* op) final { @@ -173,7 +163,7 @@ class Checker: public IRVisitor { bool Check(BoundDeducer* deducer) { deducer_ = deducer; Visit(deducer_->expr_); - return target_count == 1 && cmp_count == 1; + return target_count == 1; } void Visit(const NodeRef& e) final { @@ -181,45 +171,72 @@ class Checker: public IRVisitor { IRVisitor::Visit(e); } + private: + BoundDeducer* deducer_; + size_t target_count{0}; +}; + +class Converter: public IRVisitor { + public: + void Convert(BoundDeducer* deducer) { + deducer_ = deducer; + Visit(deducer_->expr_); + } + + void Visit(const NodeRef& e) final { + IRVisitor::Visit(e); + } + void Visit_(const LT* op) final { - ++cmp_count; - Visit(op->a); - Visit(op->b); + has_cmp = true; + deducer_->is_greater = false; + deducer_->is_equal = false; + deducer_->expr_ = op->a; + deducer_->result = op->b; } void Visit_(const LE* op) final { - ++cmp_count; - Visit(op->a); - Visit(op->b); + has_cmp = true; + deducer_->is_greater = false; + deducer_->is_equal = true; + deducer_->expr_ = op->a; + deducer_->result = op->b; } void Visit_(const GT* op) final { - ++cmp_count; - Visit(op->a); - Visit(op->b); + has_cmp = true; + deducer_->is_greater = true; + deducer_->is_equal = false; + deducer_->expr_ = op->a; + deducer_->result = op->b; } void Visit_(const GE* op) final { - ++cmp_count; - Visit(op->a); - Visit(op->b); + has_cmp = true; + deducer_->is_greater = true; + deducer_->is_equal = true; + deducer_->expr_ = op->a; + deducer_->result = op->b; } + bool has_cmp{false}; private: BoundDeducer* deducer_; - size_t target_count{0}; - size_t cmp_count{0}; }; + void BoundDeducer::Deduce() { result = make_zero(expr_.type()); + Checker checker; + if (!checker.Check(this)) success = false; + Converter converter; + converter.Convert(this); + // no compare op + if (!converter.has_cmp) success = false; + if (!success) return; + // get the path path_ = GetPath(target_, expr_); - Checker checker; - if (!checker.Check(this) || path_.empty()) { - success = false; - return; - } // get the sign of every subexpr expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 35fa611fcfc0..4e64896a124c 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -46,6 +46,18 @@ Range IntSet::cover_range(Range max_range) const { return max_range; } +Expr IntSet::min() const { + const IntervalSet* s_int = (*this).as(); + CHECK(s_int); + return s_int->i.min; +} + +Expr IntSet::max() const { + const IntervalSet* s_int = (*this).as(); + CHECK(s_int); + return s_int->i.max; +} + bool IntSet::is_nothing() const { const IntervalSet* s_int = (*this).as(); return (s_int && s_int->i.is_empty()); @@ -366,7 +378,7 @@ class IntSetEvaluator { return f(expr, expr, this); } else { LOG(WARNING) << "cannot evaluate set type " << expr->type_key(); - return IntSet::everything(); + return IntSet::nothing(); } } @@ -376,13 +388,10 @@ class IntSetEvaluator { } const std::unordered_map& dom_map; - ExprIntSetMap expr_map; }; inline IntSet ConstOp(const NodeRef&, const Expr& e, IntSetEvaluator* m) { - IntSet res = IntSet::single_point(e); - m->expr_map[e] = res; - return res; + return IntSet::single_point(e); } TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) @@ -392,15 +401,12 @@ TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) .set_dispatch([](const Variable* op, const Expr& e, IntSetEvaluator* m) { - IntSet res; auto it = m->dom_map.find(op); if (it != m->dom_map.end()) { - res = it->second; + return it->second; } else { - res = IntSet::single_point(e); + return IntSet::single_point(e); } - m->expr_map[e] = res; - return res; }); // binary operator @@ -411,9 +417,7 @@ inline IntSet Binary(const T* op, const Expr& e, IntSetEvaluator* m) { if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntSet::single_point(e); } - IntSet r = Combine(a, b); - m->expr_map[e] = r; - return r; + return Combine(a, b); } TVM_STATIC_IR_FUNCTOR(IntSetEvaluator, vtable) @@ -447,13 +451,6 @@ IntSet EvalSet(Expr e, return EvalSet(e, dmap); } -ExprIntSetMap EvalSetForEachSubExpr(Expr e, - const std::unordered_map& dom_map) { - IntSetEvaluator m(dom_map); - m.Eval(e); - return m.expr_map; -} - IntSet EvalSet(Range r, const std::unordered_map& dom_map) { IntSetEvaluator m(dom_map); @@ -465,6 +462,85 @@ IntSet EvalSet(Range r, return Combine(min_set, ext_set); } +class SubExprIntSetEvaluator : public IntSetEvaluator { + public: + explicit SubExprIntSetEvaluator(const std::unordered_map& dom_map) + : IntSetEvaluator(dom_map) {} + + inline IntSet Eval(Expr expr) { + static const FType& f = vtable(); + if (f.can_dispatch(expr)) { + IntSet res = f(expr, expr, this); + expr_map[expr] = res; + return res; + } else { + LOG(WARNING) << "cannot evaluate set type " << expr->type_key(); + return IntSet::nothing(); + } + } + + using FType = tvm::IRFunctor; + static FType& vtable() { // NOLINT(*) + static FType inst; return inst; + } + + ExprIntSetMap expr_map; +}; + +inline IntSet SubExprConstOp(const NodeRef&, const Expr& e, SubExprIntSetEvaluator* m) { + return IntSet::single_point(e); +} + +TVM_STATIC_IR_FUNCTOR(SubExprIntSetEvaluator, vtable) +.set_dispatch(SubExprConstOp) +.set_dispatch(SubExprConstOp) +.set_dispatch(SubExprConstOp); + +TVM_STATIC_IR_FUNCTOR(SubExprIntSetEvaluator, vtable) +.set_dispatch([](const Variable* op, const Expr& e, SubExprIntSetEvaluator* m) { + auto it = m->dom_map.find(op); + if (it != m->dom_map.end()) { + return it->second; + } else { + return IntSet::single_point(e); + } + }); + +// binary operator +template +inline IntSet SubExprBinary(const T* op, const Expr& e, SubExprIntSetEvaluator* m) { + IntSet a = m->Eval(op->a); + IntSet b = m->Eval(op->b); + if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { + return IntSet::single_point(e); + } + return Combine(a, b); +} + +TVM_STATIC_IR_FUNCTOR(SubExprIntSetEvaluator, vtable) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch
(SubExprBinary
) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary) +.set_dispatch(SubExprBinary); + +ExprIntSetMap EvalSetForEachSubExpr(Expr e, + const std::unordered_map& dom_map) { + SubExprIntSetEvaluator m(dom_map); + m.Eval(e); + return m.expr_map; +} + IntSet EvalSet(Range r, const Map& dom_map) { std::unordered_map dmap; diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index dcc74c691803..173bf248ce25 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -47,6 +47,10 @@ class IntSet : public NodeRef { * \return The covering interval set. */ IntSet cover_interval() const; + /*! \return Lower bound of the set */ + Expr min() const; + /*! \return upper bound of the set */ + Expr max() const; /*! \return Whether the set represent nothing */ bool is_nothing() const; /*! \return Whether the set represent everything */ diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 4e00469a3277..e116a93db7b6 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -12,13 +12,18 @@ def test_deduce(): d = tvm.Var('d') b_s = tvm.arith.intset_range(2, 3) - c_s = tvm.arith.intset_range(5, 7) + c_s = tvm.arith.intset_range(10, 15) d_s = tvm.arith.intset_range(-3, -1) e0 = (-b)*a+c-d - res = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}) - ans = (d-c)/(-b)+(-1) - assert str(tvm.ir_pass.Simplify(res.max())) == str(ans) + res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}) + ans0 = (d-c)/(-b)+(-1) + assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) + + e1 = (a*4+b < c) + res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}) + ans1 = (c-b)/4+(-2) + assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) def test_check(): a = tvm.Var('a') From d9794bbc36ab40819cc226a20fed42541acbecaa Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 17 Feb 2017 06:09:48 +0000 Subject: [PATCH 20/23] fix --- src/arithmetic/bound_deducer.cc | 19 ++------- src/arithmetic/int_set.cc | 68 +++------------------------------ 2 files changed, 8 insertions(+), 79 deletions(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 94c636ae4641..2476972f99e1 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -24,29 +24,16 @@ class VariablePathFinder: public IRVisitor { explicit VariablePathFinder(Var target) : target_(target) {} void Visit(const NodeRef& node) final { - if (!success) return; - if (visited_.count(node.get()) != 0 && - !node.same_as(target_)) { - return; - } + if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); if (!found_) path_.push_back(node.get()); - if (node.same_as(target_)) { - if (!found_) { - found_ = true; - } else { - // target variable appears at multiple location - success = false; - return; - } - } + if (node.same_as(target_)) found_ = true; IRVisitor::Visit(node); if (!found_) path_.pop_back(); } std::vector path_; - bool success{true}; private: bool found_{false}; @@ -59,7 +46,7 @@ class VariablePathFinder: public IRVisitor { std::vector GetPath(Var target, Expr expr) { VariablePathFinder v(target); v.Visit(expr); - return v.success ? v.path_ : std::vector(); + return v.path_; } class Checker; diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 4e64896a124c..a75e11d486f9 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -372,7 +372,7 @@ class IntSetEvaluator { explicit IntSetEvaluator(const std::unordered_map& dom_map) : dom_map(dom_map) {} - inline IntSet Eval(Expr expr) { + inline virtual IntSet Eval(Expr expr) { static const FType& f = vtable(); if (f.can_dispatch(expr)) { return f(expr, expr, this); @@ -467,73 +467,15 @@ class SubExprIntSetEvaluator : public IntSetEvaluator { explicit SubExprIntSetEvaluator(const std::unordered_map& dom_map) : IntSetEvaluator(dom_map) {} - inline IntSet Eval(Expr expr) { - static const FType& f = vtable(); - if (f.can_dispatch(expr)) { - IntSet res = f(expr, expr, this); - expr_map[expr] = res; - return res; - } else { - LOG(WARNING) << "cannot evaluate set type " << expr->type_key(); - return IntSet::nothing(); - } - } - - using FType = tvm::IRFunctor; - static FType& vtable() { // NOLINT(*) - static FType inst; return inst; + inline IntSet Eval(Expr expr) override { + IntSet ret = IntSetEvaluator::Eval(expr); + expr_map[expr] = ret; + return ret; } ExprIntSetMap expr_map; }; -inline IntSet SubExprConstOp(const NodeRef&, const Expr& e, SubExprIntSetEvaluator* m) { - return IntSet::single_point(e); -} - -TVM_STATIC_IR_FUNCTOR(SubExprIntSetEvaluator, vtable) -.set_dispatch(SubExprConstOp) -.set_dispatch(SubExprConstOp) -.set_dispatch(SubExprConstOp); - -TVM_STATIC_IR_FUNCTOR(SubExprIntSetEvaluator, vtable) -.set_dispatch([](const Variable* op, const Expr& e, SubExprIntSetEvaluator* m) { - auto it = m->dom_map.find(op); - if (it != m->dom_map.end()) { - return it->second; - } else { - return IntSet::single_point(e); - } - }); - -// binary operator -template -inline IntSet SubExprBinary(const T* op, const Expr& e, SubExprIntSetEvaluator* m) { - IntSet a = m->Eval(op->a); - IntSet b = m->Eval(op->b); - if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { - return IntSet::single_point(e); - } - return Combine(a, b); -} - -TVM_STATIC_IR_FUNCTOR(SubExprIntSetEvaluator, vtable) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch
(SubExprBinary
) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary) -.set_dispatch(SubExprBinary); - ExprIntSetMap EvalSetForEachSubExpr(Expr e, const std::unordered_map& dom_map) { SubExprIntSetEvaluator m(dom_map); From 696976ab1daa6f4e313d4d6e0b62263ef1d6b253 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 17 Feb 2017 07:46:57 +0000 Subject: [PATCH 21/23] change range to interval; remove converter --- src/api/api_arith.cc | 4 +- src/arithmetic/bound_deducer.cc | 104 ++++++++------------- src/arithmetic/int_set.cc | 2 +- src/arithmetic/int_set.h | 12 +-- tests/python/unittest/test_arith_intset.py | 14 +-- 5 files changed, 56 insertions(+), 80 deletions(-) diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index df48b3124d9f..db64a3c8e586 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -17,9 +17,9 @@ TVM_REGISTER_API(_arith_intset_single_point) *ret = IntSet::single_point(args[0]); }); -TVM_REGISTER_API(_arith_intset_range) +TVM_REGISTER_API(_arith_intset_interval) .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = IntSet::range(args[0], args[1]); + *ret = IntSet::interval(args[0], args[1]); }); TVM_REGISTER_API(_arith_DeduceBound) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 2476972f99e1..cc3d5f8550a1 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -49,18 +49,19 @@ std::vector GetPath(Var target, Expr expr) { return v.path_; } -class Checker; +class BoundDeduceIntputChecker; class Converter; // a visitor to deduce the bound of a variable from a expression class BoundDeducer: public IRVisitor { public: - friend class Checker; + friend class BoundDeduceInputChecker; friend class Converter; BoundDeducer(Var target, Expr expr, const std::unordered_map& dom_map) : target_(target), expr_(expr), dom_map_(dom_map) {} + bool Init(); void Deduce(); void Visit(const NodeRef& e) final { @@ -145,7 +146,7 @@ class BoundDeducer: public IRVisitor { size_t iter_{0}; }; -class Checker: public IRVisitor { +class BoundDeduceInputChecker: public IRVisitor { public: bool Check(BoundDeducer* deducer) { deducer_ = deducer; @@ -163,71 +164,46 @@ class Checker: public IRVisitor { size_t target_count{0}; }; -class Converter: public IRVisitor { - public: - void Convert(BoundDeducer* deducer) { - deducer_ = deducer; - Visit(deducer_->expr_); - } - - void Visit(const NodeRef& e) final { - IRVisitor::Visit(e); - } - - void Visit_(const LT* op) final { - has_cmp = true; - deducer_->is_greater = false; - deducer_->is_equal = false; - deducer_->expr_ = op->a; - deducer_->result = op->b; - } - - void Visit_(const LE* op) final { - has_cmp = true; - deducer_->is_greater = false; - deducer_->is_equal = true; - deducer_->expr_ = op->a; - deducer_->result = op->b; - } - - void Visit_(const GT* op) final { - has_cmp = true; - deducer_->is_greater = true; - deducer_->is_equal = false; - deducer_->expr_ = op->a; - deducer_->result = op->b; - } - - void Visit_(const GE* op) final { - has_cmp = true; - deducer_->is_greater = true; - deducer_->is_equal = true; - deducer_->expr_ = op->a; - deducer_->result = op->b; +bool BoundDeducer::Init() { + BoundDeduceInputChecker checker; + if (!checker.Check(this)) success = false; + + if (const LT* op = expr_.as()) { + is_greater = false; + is_equal = false; + expr_ = op->a; + result = op->b; + } else if (const LE* op = expr_.as()) { + is_greater = false; + is_equal = true; + expr_ = op->a; + result = op->b; + } else if (const GT* op = expr_.as()) { + is_greater = true; + is_equal = false; + expr_ = op->a; + result = op->b; + } else if (const GE* op = expr_.as()) { + is_greater = true; + is_equal = true; + expr_ = op->a; + result = op->b; + } else { + success = false; } - - bool has_cmp{false}; - private: - BoundDeducer* deducer_; -}; - + return success; +} void BoundDeducer::Deduce() { - result = make_zero(expr_.type()); - Checker checker; - if (!checker.Check(this)) success = false; - Converter converter; - converter.Convert(this); - // no compare op - if (!converter.has_cmp) success = false; - if (!success) return; + Init(); + if (!success) return; - // get the path - path_ = GetPath(target_, expr_); - // get the sign of every subexpr - expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); + // get the path + path_ = GetPath(target_, expr_); + // get the sign of every subexpr + expr_map_ = EvalSetForEachSubExpr(expr_, dom_map_); - Visit(expr_); + Visit(expr_); } // assuming e >= 0, deduce the bound of variable from it. @@ -247,7 +223,7 @@ IntSet DeduceBound(Var v, Expr e, } else { max = d.is_equal ? d.result : d.result - 1; } - return IntSet::range(min, max); + return IntSet::interval(min, max); } } // namespace arith diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index a75e11d486f9..8fdba6650f25 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -124,7 +124,7 @@ IntSet IntSet::range(Range r) { return IntervalSet::make(r->min, (r->extent + r->min) - 1); } -IntSet IntSet::range(Expr min, Expr max) { +IntSet IntSet::interval(Expr min, Expr max) { if (min.same_as(max)) { return IntSet::single_point(min); } diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h index 173bf248ce25..0c34fd71fae5 100644 --- a/src/arithmetic/int_set.h +++ b/src/arithmetic/int_set.h @@ -92,12 +92,12 @@ class IntSet : public NodeRef { */ static IntSet range(Range r); /*! - * \brief Construct a set representing a range. - * \param min The minimum value of the range. - * \param max The maximum value of the range. + * \brief Construct a set representing a interval. + * \param min The minimum value of the interval. + * \param max The maximum value of the interval. * \return constructed set. */ - static IntSet range(Expr min, Expr max); + static IntSet interval(Expr min, Expr max); }; /*! @@ -166,11 +166,11 @@ inline const IntSetNode* IntSet::operator->() const { * represent failure. * * \param v The target variable to be deduced. - * \param e The conditional expression. + * \param cond The conditional expression. * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet DeduceBound(Var v, Expr e, +IntSet DeduceBound(Var v, Expr cond, const Map& dom_map); } // namespace arith diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index e116a93db7b6..b60ed0d510b4 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -1,7 +1,7 @@ import tvm def test_basic(): - s = tvm.arith.intset_range(2, 3) + s = tvm.arith.intset_interval(2, 3) assert s.min().value == 2 assert s.max().value == 3 @@ -11,9 +11,9 @@ def test_deduce(): c = tvm.Var('c') d = tvm.Var('d') - b_s = tvm.arith.intset_range(2, 3) - c_s = tvm.arith.intset_range(10, 15) - d_s = tvm.arith.intset_range(-3, -1) + b_s = tvm.arith.intset_interval(2, 3) + c_s = tvm.arith.intset_interval(10, 15) + d_s = tvm.arith.intset_interval(-3, -1) e0 = (-b)*a+c-d res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}) @@ -31,9 +31,9 @@ def test_check(): c = tvm.Var('c') d = tvm.Var('d') - b_s = tvm.arith.intset_range(2, 3) - c_s = tvm.arith.intset_range(5, 7) - d_s = tvm.arith.intset_range(-3, -1) + b_s = tvm.arith.intset_interval(2, 3) + c_s = tvm.arith.intset_interval(5, 7) + d_s = tvm.arith.intset_interval(-3, -1) # no compare operator res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}) From 434835abaf703ff4d5d58651f7640b0022d54b03 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 17 Feb 2017 17:56:22 +0000 Subject: [PATCH 22/23] remove converter declaration --- src/arithmetic/bound_deducer.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index cc3d5f8550a1..b83215c4a36a 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -50,7 +50,6 @@ std::vector GetPath(Var target, Expr expr) { } class BoundDeduceIntputChecker; -class Converter; // a visitor to deduce the bound of a variable from a expression class BoundDeducer: public IRVisitor { From 2527b2e8d9d1ad81ee0394a612936e56cc3f1357 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Fri, 17 Feb 2017 17:58:36 +0000 Subject: [PATCH 23/23] remove int_set_internal.h --- src/api/api_arith.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index db64a3c8e586..7edbe3eec2a8 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -7,7 +7,6 @@ #include #include #include "../arithmetic/int_set.h" -#include "../arithmetic/int_set_internal.h" namespace tvm { namespace arith {