From a5b8a3248cc340b52fe34e12e70d9f89edd9ff98 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 4 Feb 2017 16:53:51 -0800 Subject: [PATCH] [PASS] UnrollLoop, isolate arithmetic module. --- include/tvm/ir_pass.h | 10 ++- include/tvm/runtime/packed_func.h | 2 +- python/tvm/build.py | 1 - src/README.md | 1 + src/api/api_pass.cc | 10 +++ src/{schedule => arithmetic}/compute_expr.h | 10 +-- src/{schedule => arithmetic}/int_set.cc | 96 +++------------------ src/{schedule => arithmetic}/int_set.h | 75 ++++------------ src/pass/inline.cc | 22 ++++- src/pass/simple_passes.cc | 4 +- src/pass/unroll_loop.cc | 78 +++++++++++++++++ src/schedule/bound.cc | 78 ++++++++++++++++- src/schedule/graph.cc | 1 - src/schedule/schedule_ops.cc | 14 ++- tests/python/unittest/test_pass_unroll.py | 20 +++++ 15 files changed, 261 insertions(+), 161 deletions(-) rename src/{schedule => arithmetic}/compute_expr.h (94%) rename src/{schedule => arithmetic}/int_set.cc (82%) rename src/{schedule => arithmetic}/int_set.h (56%) create mode 100644 src/pass/unroll_loop.cc create mode 100644 tests/python/unittest/test_pass_unroll.py diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 8eaec0f52315..f8412dc3666b 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt); * \param value_map The map of new values. * \return The converted form. */ -Stmt Substitute(Stmt stmt, const Map& value_map); +Stmt Substitute(Stmt stmt, const Map& value_map); /*! * \brief inline all calls of f in stmt. @@ -97,6 +97,13 @@ Stmt Inline(Stmt stmt, Stmt StorageFlatten(Stmt stmt, Map extern_buffer); +/*! + * \brief unroll the constant loops + * \param stmt The statment to be unrolled. + * \param max_auto_step The maximum step to stop performing automatic unrolling. + */ +Stmt UnrollLoop(Stmt stmt, int max_auto_step); + /*! * \brief Make an user callable API LoweredFunc. * @@ -153,6 +160,7 @@ Array SplitHostDevice(LoweredFunc func); */ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); + } // namespace ir } // namespace tvm diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index eafc367fe3c5..3b1921ee8868 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const { CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" - << "but request arg" << i; + << " but request arg[" << i << "]."; return TVMArgValue(values[i], type_codes[i]); } diff --git a/python/tvm/build.py b/python/tvm/build.py index 29321eabe711..fbed0a33f849 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -70,7 +70,6 @@ def build(sch, fsplits = [x for x in fsplits] for i in range(1, len(fsplits)): fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared") - fsplits[i] = ir_pass.StorageSync(fsplits[i], "global") if record_codes is not None: output_ssa = False diff --git a/src/README.md b/src/README.md index 16dfc19d8f54..91cb47ece9ea 100644 --- a/src/README.md +++ b/src/README.md @@ -3,5 +3,6 @@ - api API functionr registration - lang The definition of DSL related data structure - schedule The operations on the schedule graph before converting to IR. +- arithmetic Arithmetic expression and set simplification - pass The optimization pass on the IR structure - runtime Minimum runtime related codes. diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 6e7bbd849171..df79996e4a6f 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include namespace tvm { @@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal) } }); +TVM_REGISTER_API(_pass_PostOrderVisit) +.set_body([](TVMArgs args, TVMRetValue *ret) { + PackedFunc f = args[1]; + ir::PostOrderVisit(args[0], [f](const NodeRef& n) { + f(n); + }); + }); + // make from two arguments #define REGISTER_PASS1(PassName) \ TVM_REGISTER_API(_pass_## PassName) \ @@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA); REGISTER_PASS1(VerifySSA); REGISTER_PASS4(Inline); REGISTER_PASS2(StorageFlatten); +REGISTER_PASS2(UnrollLoop); REGISTER_PASS2(StorageSync); REGISTER_PASS4(MakeAPI); REGISTER_PASS1(SplitHostDevice); diff --git a/src/schedule/compute_expr.h b/src/arithmetic/compute_expr.h similarity index 94% rename from src/schedule/compute_expr.h rename to src/arithmetic/compute_expr.h index ee1947b61039..9550c1c96d2c 100644 --- a/src/schedule/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -4,14 +4,14 @@ * \brief Utility integer expression with quick eager simplification. * This is weaker than Simplify but can be done Eagerly. */ -#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_ -#define TVM_SCHEDULE_COMPUTE_EXPR_H_ +#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_ +#define TVM_ARITHMETIC_COMPUTE_EXPR_H_ #include #include namespace tvm { -namespace schedule { +namespace arith { using Halide::Internal::add_would_overflow; using Halide::Internal::sub_would_overflow; @@ -104,6 +104,6 @@ inline Expr ComputeExpr(Expr a, Expr b) { return Halide::Internal::Interval::make_min(a, b); } -} // namespace schedule +} // namespace arith } // namespace tvm -#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_ +#endif // TVM_ARITHMETIC_COMPUTE_EXPR_H_ diff --git a/src/schedule/int_set.cc b/src/arithmetic/int_set.cc similarity index 82% rename from src/schedule/int_set.cc rename to src/arithmetic/int_set.cc index 0da1a39e7c60..04b40191de11 100644 --- a/src/schedule/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -1,6 +1,6 @@ /*! - * Copyright (c) 2016 by Contributors - * \file int_set_impl.cc + * Copyright (c) 2017 by Contributors + * \file int_set.cc * \brief The integer set functions */ #include @@ -10,7 +10,7 @@ #include "./compute_expr.h" namespace tvm { -namespace schedule { +namespace arith { using Halide::Internal::Interval; @@ -94,6 +94,12 @@ bool IntSet::is_single_point() const { return (s_int && s_int->i.is_single_point()); } +Expr IntSet::point_value() const { + const IntervalSet* s_int = (*this).as(); + CHECK(s_int && s_int->i.is_single_point()); + return s_int->i.min; +} + IntSet IntSet::everything() { return IntervalSet::make(Interval::everything()); } @@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) { } // Check if a is created from b. -inline bool MatchRange(const IntSet& a, - const Range& b) { +bool IntSet::match_range(const Range& b) const { + const IntSet& a = *this; const IntervalSet* a_int = a.as(); if (!a_int) return false; const Interval& i = a_int->i; @@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) { return CombineSets(a, b); } -// Implementation of Evaluations and passing. -void PassUp(const SplitNode* s, - const std::unordered_map& dom_map, - const IntSet& outer, - const IntSet& inner, - IntSet* parent) { - if (dom_map.count(s->outer) && - dom_map.count(s->inner) && - dom_map.count(s->parent) && - MatchRange(outer, dom_map.at(s->outer)) && - MatchRange(inner, dom_map.at(s->inner))) { - *parent = IntSet::range(dom_map.at(s->parent)); - return; - } - Expr factor = dom_map.at(s->inner)->extent; - Expr parent_min = dom_map.at(s->parent)->min; - CHECK(outer.defined()); - CHECK(inner.defined()); - CHECK(factor.defined()); - - *parent = Combine( - Combine( - Combine(outer, IntSet::single_point(factor)), inner), - IntSet::single_point(parent_min)); -} - -void PassUp(const FuseNode* s, - const std::unordered_map& dom_map, - const IntSet& fused, - IntSet* outer, - IntSet* inner) { - CHECK(dom_map.count(s->outer)); - CHECK(dom_map.count(s->inner)); - CHECK(dom_map.count(s->fused)); - - if (MatchRange(fused, dom_map.at(s->fused))) { - *outer = IntSet::range(dom_map.at(s->outer)); - *inner = IntSet::range(dom_map.at(s->inner)); - return; - } - - Expr outer_min = dom_map.at(s->outer)->min; - Expr inner_min = dom_map.at(s->inner)->min; - - const IntervalSet* fused_int = fused.as(); - - if (fused_int && fused_int->i.is_single_point()) { - Expr value = fused_int->i.min; - Expr factor = dom_map.at(s->inner)->extent; - Expr v_outer = value / factor; - Expr v_inner = value % factor; - if (!is_zero(outer_min)) v_outer = v_outer + outer_min; - if (!is_zero(inner_min)) v_inner = v_inner + inner_min; - *outer = IntSet::single_point(v_outer); - *inner = IntSet::single_point(v_inner); - } else { - LOG(WARNING) << "use fallback inference rule in fuse"; - // simply use the entire set, this rule can be enhanced. - *outer = IntSet::range(dom_map.at(s->outer)); - *inner = IntSet::range(dom_map.at(s->inner)); - return; - } -} - - -void PassUp(const RebaseNode* s, - const std::unordered_map& dom_map, - const IntSet& rebased, - IntSet* parent) { - CHECK(dom_map.count(s->parent)); - if (MatchRange(rebased, dom_map.at(s->rebased))) { - *parent = IntSet::range(dom_map.at(s->parent)); - return; - } - Expr parent_min = dom_map.at(s->parent)->min; - *parent = Combine(rebased, IntSet::single_point(parent_min)); -} - // Evaluator to evalute the epxression. class IntSetEvaluator { public: @@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); -} // namespace schedule +} // namespace arith } // namespace tvm diff --git a/src/schedule/int_set.h b/src/arithmetic/int_set.h similarity index 56% rename from src/schedule/int_set.h rename to src/arithmetic/int_set.h index 5866c123d6c7..80c2fae79146 100644 --- a/src/schedule/int_set.h +++ b/src/arithmetic/int_set.h @@ -3,14 +3,14 @@ * \file int_set.h * \brief Abstraction for all integer set operations. */ -#ifndef TVM_SCHEDULE_INT_SET_H_ -#define TVM_SCHEDULE_INT_SET_H_ +#ifndef TVM_ARITHMETIC_INT_SET_H_ +#define TVM_ARITHMETIC_INT_SET_H_ #include #include namespace tvm { -namespace schedule { +namespace arith { // internal node container of int set. class IntSetNode; @@ -44,6 +44,18 @@ class IntSet : public NodeRef { bool is_everything() const; /*! \return Whether the set is a single point */ bool is_single_point() const; + /*! + * \brief The single point value, call only if is_single_point is true + * \return The point value. + */ + Expr point_value() const; + /*! + * \brief Try to match IntSet with range r. + * + * \note It is guanrateed that IntSet::range(r).match_range(r) == true + * \return true if we can prove they are the same. + */ + bool match_range(const Range& r) const; /*! \return Whether the set contains everything */ static IntSet everything(); /*! @@ -88,59 +100,6 @@ IntSet EvalSet(Expr e, IntSet EvalSet(Range r, const Map& dom_map); -/*! - * \brief Conditional upward message passing. - * - * Get domain of parent, condition on domain of children. - * Domain is represented as IntSet. - * - * \param s The Split relation node. - * \param dom_map The old domain result from downward message passing. - * Contains the domain set if all the children are full set. - * \param outer domain of outer iteration. - * \param inner domain of inner iteration. - * \param parent The result domain of parent. - */ -void PassUp(const SplitNode* s, - const std::unordered_map& dom_map, - const IntSet& outer, - const IntSet& inner, - IntSet* parent); -/*! - * \brief Conditional upward message passing. - * - * Get domain of parent, condition on domain of children. - * Domain is represented as IntSet. - * - * \param s The Fuse relation node. - * \param dom_map The old domain result from downward message passing. - * Contains the domain set if all the children are full set. - * \param fused domain of fused iteration. - * \param outer The result domain of outer iteration. - * \param inner The result domain of inner iteration. - */ -void PassUp(const FuseNode* s, - const std::unordered_map& dom_map, - const IntSet& fused, - IntSet* outer, - IntSet* inner); - -/*! - * \brief Conditional upward message passing. - * - * Get domain of parent, condition on domain of children. - * Domain is represented as IntSet. - * - * \param s The Fuse relation node. - * \param dom_map The old domain result from downward message passing. - * Contains the domain set if all the children are full set. - * \param rebased domain of rebased iteration. - * \param parent The result domain of parent iteration. - */ -void PassUp(const RebaseNode* s, - const std::unordered_map& dom_map, - const IntSet& fused, - IntSet* parent); /*! * \brief Create an union set of all sets * \param sets The sets to be unioned @@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); } -} // namespace schedule +} // namespace arith } // namespace tvm -#endif // TVM_SCHEDULE_INT_SET_H_ +#endif // TVM_ARITHMETIC_INT_SET_H_ diff --git a/src/pass/inline.cc b/src/pass/inline.cc index de452c364cd8..1dee4776e6ab 100644 --- a/src/pass/inline.cc +++ b/src/pass/inline.cc @@ -24,10 +24,24 @@ class IRInline : public IRMutator { if (op->func == f_) { CHECK_EQ(op->value_index, 0); Expr expr = body_; - CHECK_EQ(args_.size(), op->args.size()) - << op->args.size() << " vs " << args_.size(); - for (size_t i = 0; i < args_.size(); ++i) { - expr = Let::make(args_[i], op->args[i], expr); + CHECK_EQ(args_.size(), op->args.size()); + + bool has_side_effect = false; + for (size_t i = 0; i < op->args.size(); ++i) { + if (HasSideEffect(op->args[i])) has_side_effect = true; + } + + if (has_side_effect) { + for (size_t i = 0; i < args_.size(); ++i) { + expr = Let::make(args_[i], op->args[i], expr); + } + } else { + Map vmap; + for (size_t i = 0; i < args_.size(); ++i) { + vmap.Set(args_[i], op->args[i]); + } + expr = Substitute( + Evaluate::make(expr), vmap).as()->value; } return expr; } else { diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc index 0fe6b94ebd24..5fc928cdd32b 100644 --- a/src/pass/simple_passes.cc +++ b/src/pass/simple_passes.cc @@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator { std::unordered_map smap; }; -Stmt Substitute(Stmt stmt, const Map& value_map) { +Stmt Substitute(Stmt stmt, const Map& value_map) { IRSubstitue m; for (auto kv : value_map) { - m.smap[kv.first->var.get()] = kv.second; + m.smap[kv.first.get()] = kv.second; } return m.Mutate(stmt); } diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc new file mode 100644 index 000000000000..555e5b970a05 --- /dev/null +++ b/src/pass/unroll_loop.cc @@ -0,0 +1,78 @@ +/*! + * Copyright (c) 2016 by Contributors + * SSA related checks and pass. + * \file ssa.cc + */ +#include +#include +#include +#include +#include +#include +#include "../arithmetic//compute_expr.h" + +namespace tvm { +namespace ir { + +class LoopUnroller : public IRMutator { + public: + explicit LoopUnroller(int max_auto_step) + : max_auto_step_(max_auto_step) { + } + + Stmt Mutate_(const For* op, const Stmt& s) { + Stmt stmt = s; + // constant folding. + Expr extent = ir::Simplify(op->extent); + const IntImm* v1 = extent.as(); + const UIntImm* v2 = extent.as(); + int value = -1; + if (v1 != nullptr) { + value = static_cast(v1->value); + } + if (v2 != nullptr) { + value = static_cast(v2->value); + } + bool allow_unroll = value >= 0 && value <= max_auto_step_; + if (op->for_type == ForType::Unrolled) { + CHECK_GE(value, 0) + << "Cannot unroll non-constant loop"; + allow_unroll = true; + } + + if (allow_unroll) { + using arith::ComputeExpr; + if (value == 0) return Evaluate::make(0); + Stmt body = op->body; + Map vmap; + Stmt unrolled; + for (int i = 0; i < value; ++i) { + Var lv(op->loop_var.node_); + vmap.Set(lv, + ComputeExpr( + op->min, make_const(op->loop_var.type(), i))); + Stmt step = Substitute(body, vmap); + if (unrolled.defined()) { + unrolled = Block::make(unrolled, step); + } else { + unrolled = step; + } + } + return this->Mutate(unrolled); + } else { + return IRMutator::Mutate_(op, stmt); + } + } + + private: + int max_auto_step_; +}; + + +Stmt UnrollLoop(Stmt stmt, int max_auto_step) { + Stmt ret = LoopUnroller(max_auto_step).Mutate(stmt); + return ConvertSSA(ret); +} + +} // namespace ir +} // namespace tvm diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 706550843326..9fe530b6767a 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -7,13 +7,15 @@ #include #include #include -#include "./int_set.h" #include "./graph.h" +#include "../arithmetic/int_set.h" #include "../runtime/thread_storage_scope.h" namespace tvm { namespace schedule { +using namespace arith; + // result = ceil((a / b)), both a and b are positive integer inline Expr DivCeil(Expr a, Expr b) { return ir::Simplify((a + b - 1) / b); @@ -70,6 +72,80 @@ void PassDown(const Stage& s, // pass the integer set on each leave loop up to the root // dom_map is the result of PassDown, it records the domain of each IterVar. // dom_map can be used to get cached result in reverse construction. +// Implementation of Evaluations and passing. +void PassUp(const SplitNode* s, + const std::unordered_map& dom_map, + const IntSet& outer, + const IntSet& inner, + IntSet* parent) { + if (dom_map.count(s->outer) && + dom_map.count(s->inner) && + dom_map.count(s->parent) && + outer.match_range(dom_map.at(s->outer)) && + inner.match_range(dom_map.at(s->inner))) { + *parent = IntSet::range(dom_map.at(s->parent)); + return; + } + Expr factor = dom_map.at(s->inner)->extent; + Expr parent_min = dom_map.at(s->parent)->min; + CHECK(outer.defined()); + CHECK(inner.defined()); + CHECK(factor.defined()); + *parent = EvalSet( + s->outer->var * factor + s->inner->var + parent_min, + {{s->outer, outer}, {s->inner, inner}}); +} + +void PassUp(const FuseNode* s, + const std::unordered_map& dom_map, + const IntSet& fused, + IntSet* outer, + IntSet* inner) { + CHECK(dom_map.count(s->outer)); + CHECK(dom_map.count(s->inner)); + CHECK(dom_map.count(s->fused)); + + if (fused.match_range(dom_map.at(s->fused))) { + *outer = IntSet::range(dom_map.at(s->outer)); + *inner = IntSet::range(dom_map.at(s->inner)); + return; + } + Expr outer_min = dom_map.at(s->outer)->min; + Expr inner_min = dom_map.at(s->inner)->min; + + if (fused.is_single_point()) { + Expr value = fused.point_value(); + Expr factor = dom_map.at(s->inner)->extent; + Expr v_outer = value / factor; + Expr v_inner = value % factor; + if (!is_zero(outer_min)) v_outer = v_outer + outer_min; + if (!is_zero(inner_min)) v_inner = v_inner + inner_min; + *outer = IntSet::single_point(v_outer); + *inner = IntSet::single_point(v_inner); + } else { + LOG(WARNING) << "use fallback inference rule in fuse"; + // simply use the entire set, this rule can be enhanced. + *outer = IntSet::range(dom_map.at(s->outer)); + *inner = IntSet::range(dom_map.at(s->inner)); + return; + } +} + + +void PassUp(const RebaseNode* s, + const std::unordered_map& dom_map, + const IntSet& rebased, + IntSet* parent) { + CHECK(dom_map.count(s->parent)); + if (rebased.match_range(dom_map.at(s->rebased))) { + *parent = IntSet::range(dom_map.at(s->parent)); + return; + } + Expr parent_min = dom_map.at(s->parent)->min; + *parent = EvalSet(s->rebased->var + parent_min, + {{s->rebased, rebased}}); +} + void PassUp(const Stage& s, const std::unordered_map& dom_map, std::unordered_map* p_state) { diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index 530eecaac971..33272fceb222 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -6,7 +6,6 @@ #include #include #include -#include "./int_set.h" #include "./graph.h" namespace tvm { diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index e1390b5891f8..58d5f6bdb3a1 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -9,13 +9,13 @@ #include #include "../pass/ir_util.h" -#include "./int_set.h" +#include "../arithmetic/compute_expr.h" #include "./graph.h" -#include "./compute_expr.h" namespace tvm { namespace schedule { +using namespace arith; using namespace ir; /*! @@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch, return nest; } +Stmt Substitute(Stmt s, + const std::unordered_map& value_map) { + Map temp; + for (const auto& kv : value_map) { + temp.Set(kv.first->var, kv.second); + } + return ir::Substitute(s, temp); +} + Stmt MakeLoop(const Stage& s, const Map& dom_map, Stmt provide, @@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s, auto nest = MakeLoopNest(s, dom_map, 0, false, bound_state, {}, &value_map); - provide = Substitute(provide, value_map); if (init.defined()) { // try to find the location to insert the initialization. diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py new file mode 100644 index 000000000000..191377baaab6 --- /dev/null +++ b/tests/python/unittest/test_pass_unroll.py @@ -0,0 +1,20 @@ +import tvm + +def test_unroll_loop(): + dtype = 'int64' + n = tvm.Var('n') + Ab = tvm.Buffer((n, ), dtype) + i = tvm.Var('i') + j = tvm.Var('j') + # for i in 0 to n-1: + stmt = tvm.make.For( + i, n, 2, 0, 0, + tvm.make.For(j, 0, n, 0, 0, + tvm.make.Store(Ab.data, + tvm.make.Load(dtype, Ab.data, i) + 1, + j + 1))) + stmt = tvm.ir_pass.UnrollLoop(stmt, 8) + print(stmt) + +if __name__ == "__main__": + test_unroll_loop()