diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 294d22b812a1..79172c374316 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -140,23 +140,6 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2); */ TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2); -/*! - * \brief Add abstraction over a function - * - * For example: `square` is transformed to - * `fun x -> square x`. - * - * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion - * for more details. - * - * \param e The original function. - * \param mod The module used for referencing global functions, can be - * None. - * - * \return the new function with abstraction - */ -TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod); - /*! * \brief Check that each Var is only bound once. * @@ -288,24 +271,6 @@ TVM_DLL tvm::Array AllTypeVars(const Expr& expr, const Module& mod); */ TVM_DLL tvm::Array AllTypeVars(const Type& t, const Module& mod); -/*! \brief Remove expressions which does not effect the program result. - * - * It will remove let bindings which are not referenced, - * and inline let bindings that are only used once. - * - * For example, this pass should turn `let a = 1 in 2` into `2`, - * as the value of the expression does not depend on a. - * - * As another example, `let a = 1 in a` will be optimized into 1, - * if the flag is turned on. - * - * \param e the expression to optimize. - * \param inline_once whether or not to inline binding used one. - * - * \return the optimized expression. - */ -TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false); - /*! * \brief Fold constant expressions. * @@ -387,38 +352,6 @@ TVM_DLL Map CollectDeviceInfo(const Expr& expr); */ TVM_DLL Map CollectDeviceAnnotationOps(const Expr& expr); -/*! - * \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). - * - * It will turn an expression that is in a graph form (with sharing implicit), - * to an expression with explicit sharing (A-Normal Form). - * - * The scope of the root expression is the global scope. - * - * The scope of any non root expression is the least common ancestor of all it's scope. - * - * Values are ordered by post-DFS order in each scope. - * - * \param e the expression to observably share. - * \param mod The module used for referencing global functions, can be - * None. - * - * \return expression in A-Normal Form. - */ -TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod); - -/*! - * \brief Remove let binding and directly share via pointer instead. - * - * It will remove all let binding, - * and turn all of the variable bound by let into direct pointer reference. - * - * \param e the expression. - * - * \return the expression in graph normal form. - */ -TVM_DLL Expr ToGraphNormalForm(const Expr& e); - /*! * \brief Finds cases that the given match expression does not catch, if any. * @@ -432,28 +365,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e); TVM_DLL Array UnmatchedCases(const Match& match, const Module& mod); /*! - * \brief Aggressive constant propagation/constant folding/inlining. - * It will do as much computation in compile time as possible. - * It has two benefit: remove runtime overhead, and allow more optimization (typically fusion). - * As a side effect, code size will explode. - * - * \param e the expression - * \param mod the module - * - * \return the optimized expression. - */ -TVM_DLL Expr PartialEval(const Expr& e, const Module& mod); - -/* - * \brief Bind function parameters or free variables. + * \brief Bind the free variables to a Relay expression. * * Parameter binding can only happen if expr is a Function. * binds cannot change internal arguments of internal functions. * * \param expr The function to be binded. * \param binds The map of arguments to + * + * \return The expression with all free vars bound. */ -TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& bind_map); +TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); /*! \brief A hashing structure in the style of std::hash. */ struct StructuralHash { diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 04b4e64dc9c3..9ae71d824f94 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -541,6 +541,33 @@ TVM_DLL Pass AlterOpLayout(); */ TVM_DLL Pass CanonicalizeCast(); +/*! + * \brief Add abstraction over a function + * + * For example: `square` is transformed to + * `fun x -> square x`. + * + * See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion + * for more details. + * + * \return The pass. + */ +TVM_DLL Pass EtaExpand(); + +/*! + * \brief This is a helper function that runs a some optimization passes on + * a certain expression and returns the optimized version. With the help of this + * function, users don't need to manually construct a module, then perform + * passes, and finally and extract the target function/expression from the + * returned module frequently. + * + * \param expr The expression to be optimized. + * \param passes The passses that will be applied on the given expression. + * + * \return The optimized expression. + */ +TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array& passes); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 1748571cb316..52dc34d7aac9 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -84,23 +84,6 @@ def backward_fold_scale_axis(expr): """ return _ir_pass.backward_fold_scale_axis(expr) -def eta_expand(expr, mod): - """Add abstraction over a function. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression, we expect that expr's types - should be fully inferred by infer_type. - mod : tvm.relay.Module - The global module. - - Returns - ------- - expanded_expr : tvm.relay.Expr - The expression after eta expansion. - """ - return _ir_pass.eta_expand(expr, mod) def forward_fold_scale_axis(expr): """Fold the scaling of axis into weights of conv2d/dense. @@ -318,25 +301,6 @@ def canonicalize_ops(expr): return _ir_pass.canonicalize_ops(expr) -def dead_code_elimination(expr, inline_once=False): - """ Remove expressions which does not effect the program result (dead code). - - Parameters - ---------- - expr : tvm.relay.Expr - The input Expression - - inline_once : Optional[Bool] - Whether to inline binding that occur only once. - Returns - ------- - result : tvm.relay.Expr - An expression which is semantically equal to the input expression, - but with dead code removed. - """ - return _ir_pass.dead_code_elimination(expr, inline_once) - - def alpha_equal(lhs, rhs): """Compare two Relay expr for structural equivalence (alpha equivalence). @@ -534,46 +498,6 @@ def collect_device_annotation_ops(expr): return _ir_pass.CollectDeviceAnnotationOps(expr) -def to_a_normal_form(expr, mod=None): - """ - Turn Graph Normal Form expression into A Normal Form Expression. - - The scope of the root expression is the global scope. - - The scope of any non root expression is the least common ancestor of all it's scope. - - Values are ordered by post-DFS order in each scope. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - mod : Optional[tvm.relay.Module] - The global module. - - Returns - ------- - result : tvm.relay.Expr - The output expression. - """ - return _ir_pass.to_a_normal_form(expr, mod) - - -def to_graph_normal_form(expr): - """Turn A Normal Form expression into Graph Normal Form expression - Parameters - ---------- - expr : tvm.relay.Expr - The input expression - Returns - ------- - result : tvm.relay.Expr - The output expression - """ - return _ir_pass.to_graph_normal_form(expr) - - def gradient(expr, mod=None, mode='higher_order'): """ Transform the input function, @@ -642,26 +566,6 @@ def eliminate_common_subexpr(expr, fskip=None): return _ir_pass.eliminate_common_subexpr(expr, fskip) -def partial_evaluate(expr, mod=None): - """ - Evaluate the static fragment of the code. - - Parameters - ---------- - expr : tvm.relay.Expr - The input expression. - - mod : Optional[tvm.relay.Module] - The global module - - Returns - ------- - result : tvm.relay.Expr - The output expression. - """ - return _ir_pass.partial_evaluate(expr, mod) - - def unmatched_cases(match, mod=None): """ Finds cases that the match expression does not catch, if any. diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 5f47e5b446aa..ba4857dc4d36 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -302,15 +302,20 @@ def CanonicalizeOps(): return _transform.CanonicalizeOps() -def DeadCodeElimination(): - """ Remove expressions which does not effect the program result (dead code). +def DeadCodeElimination(inline_once=False): + """Remove expressions which does not effect the program result (dead code). + + Parameters + ---------- + inline_once: Optional[Bool] + Whether to inline binding that occurs only once. Returns ------- ret: tvm.relay.Pass The registered pass that eliminates the dead code in a Relay program. """ - return _transform.DeadCodeElimination() + return _transform.DeadCodeElimination(inline_once) def FoldConstant(): @@ -406,6 +411,7 @@ def ToANormalForm(): """ return _transform.ToANormalForm() + def EtaExpand(): """Add abstraction over a function @@ -416,6 +422,7 @@ def EtaExpand(): """ return _transform.EtaExpand() + def ToGraphNormalForm(): """Turn A Normal Form expression into Graph Normal Form expression @@ -449,7 +456,7 @@ def PartialEvaluate(): Returns ------- - ret : tvm.relay.Pass + ret: tvm.relay.Pass The registered pass that performs partial evaluation on an expression. """ return _transform.PartialEvaluate() @@ -465,6 +472,31 @@ def CanonicalizeCast(): """ return _transform.CanonicalizeCast() + +def OptimizeOnExpr(expr, passes): + """Perform optimization passes on an expressioin. + + Parameters + ---------- + expr: tvm.relay.Expr + The expression for optimization. + + passes: Union[Pass, List[Pass]] + The list of optimizations to be applied. + + Returns + ------- + ret: tvm.relay.Expr + The optimized expression. + """ + if isinstance(passes, Pass): + passes = [passes] + if not isinstance(passes, (list, tuple)): + raise TypeError("passes must be a pass or a list of pass objects.") + + return _transform.OptimizeOnExpr(expr, passes) + + def _wrap_class_module_pass(pass_cls, pass_info): """Wrap a python class as function pass""" class PyModulePass(ModulePass): diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 7e186f80df92..8799bf403375 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -156,9 +156,6 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) { return CalcDep::Eliminate(e, inline_once); } -TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") -.set_body_typed(DeadCodeElimination); - namespace transform { Pass DeadCodeElimination(bool inline_once) { diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index b95c5844f8a4..e7edbb3153d8 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -1086,27 +1086,30 @@ Expr PostProcess(const Expr& e) { } // namespace partial_eval -Expr PartialEval(const Expr& e, const Module& m) { - return TransformF([&](const Expr& e) { +Module PartialEval(const Module& m) { + CHECK(m->entry_func.defined()); + auto func = m->Lookup(m->entry_func); + Expr ret = + TransformF([&](const Expr& e) { return LetList::With([&](LetList* ll) { - relay::partial_eval::PartialEvaluator pe(FreeVars(e), m); - pe.InitializeFuncId(e); - return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic); - }); - }, e); + relay::partial_eval::PartialEvaluator pe(FreeVars(e), m); + pe.InitializeFuncId(e); + return relay::partial_eval::PostProcess(pe.VisitExpr(e, ll)->dynamic); + }); + }, func); + CHECK(ret->is_type()); + m->Update(m->entry_func, Downcast(ret)); + return m; } -TVM_REGISTER_API("relay._ir_pass.partial_evaluate") -.set_body_typed(PartialEval); - namespace transform { Pass PartialEval() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(PartialEval(f, m)); + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return PartialEval(m); }; - return CreateFunctionPass(pass_func, 1, "PartialEvaluate", {}); + return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } TVM_REGISTER_API("relay._transform.PartialEvaluate") diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index d63d9121fe27..a620316035c7 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -573,6 +573,18 @@ class PassContext::Internal { } }; +Expr OptimizeOnExpr(const Expr& expr, const Array& passes) { + auto mod = ModuleNode::FromExpr(expr); + Sequential seq(passes); + auto pass_ctx = PassContext::Create(); + pass_ctx->opt_level = 3; + tvm::With ctx_scope(pass_ctx); + mod = seq(mod); + CHECK(mod.defined()); + auto entry_func = mod->Lookup(mod->entry_func); + return expr.as() == nullptr ? entry_func->body : entry_func; +} + TVM_REGISTER_API("relay._transform.GetCurrentPassContext") .set_body_typed(PassContext::Current); @@ -582,6 +594,9 @@ TVM_REGISTER_API("relay._transform.EnterPassContext") TVM_REGISTER_API("relay._transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); +TVM_REGISTER_API("relay._transform.OptimizeOnExpr") +.set_body_typed(OptimizeOnExpr); + } // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 324eddd21c5c..b5a3f8552d8d 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -26,6 +26,8 @@ */ #include #include +#include +#include #include #include "let_list.h" #include "../../common/arena.h" @@ -35,10 +37,6 @@ namespace tvm { namespace relay { -Expr ToANormalForm(const Expr& e, - const Module& m, - std::unordered_set* gv); - struct ScopeNode; using Scope = std::shared_ptr; @@ -104,29 +102,21 @@ bool IsPrimitiveFunction(const Expr& e) { class Fill : ExprFunctor { public: static Expr ToANormalForm(const Expr& e, - const Module& m, const DependencyGraph& dg, - std::unordered_map* node_scope, - std::unordered_set* gv) { - Fill fi(m, dg, node_scope, gv); + std::unordered_map* node_scope) { + Fill fi(dg, node_scope); return fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); } private: - Module mod_; const DependencyGraph& dg_; std::unordered_map* node_scope_; - std::unordered_set* visited_; std::unordered_map memo; - Fill(Module mod, - const DependencyGraph& dg, - std::unordered_map* node_scope, - std::unordered_set* visited) : - mod_(mod), + Fill(const DependencyGraph& dg, + std::unordered_map* node_scope) : dg_(dg), - node_scope_(node_scope), - visited_(visited) { } + node_scope_(node_scope) { } Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); @@ -246,10 +236,6 @@ class Fill : ExprFunctor { Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { GlobalVar gv = GetRef(gvn); - if (visited_->count(gv) == 0) { - visited_->insert(gv); - mod_->Update(gv, Downcast(relay::ToANormalForm(mod_->Lookup(gv), mod_, visited_))); - } return Atomic(gv, gv, v); } @@ -276,9 +262,7 @@ class Fill : ExprFunctor { } }; -Expr ToANormalFormAux(const Expr& e, - const Module& m, - std::unordered_set* gv) { +Expr ToANormalFormAux(const Expr& e) { /* When you lift a lambda, what is inside is also being lift. * * So we must determine the scope of the lambda before determining the scope of it's body. @@ -301,46 +285,40 @@ Expr ToANormalFormAux(const Expr& e, * We do an additional pass to fill all the LetList and we are done. */ std::unordered_map node_scope = CalcScope(dg); - return Fill::ToANormalForm(e, m, dg, &node_scope, gv); + return Fill::ToANormalForm(e, dg, &node_scope); } -Expr ToANormalForm(const Expr& e, - const Module& m, - std::unordered_set* gv) { - DLOG(INFO) - << "ToANF:" << std::endl - << AsText(e, false); - - Expr ret = - TransformF([&](const Expr& e) { - return ToANormalFormAux(e, m, gv); - }, e); - - CHECK_EQ(FreeVars(ret).size(), 0); +Module ToANormalForm(const Module& m) { + DLOG(INFO) << "ToANF:" << std::endl << m; + + tvm::Map updates; + auto funcs = m->functions; + for (const auto& it : funcs) { + Expr ret = + TransformF([&](const Expr& e) { + return ToANormalFormAux(e); + }, it.second); + CHECK_EQ(FreeVars(ret).size(), 0); + updates.Set(it.first, Downcast(ret)); + } - DLOG(INFO) - << "ToANF: transformed" << std::endl - << AsText(ret, false); + for (auto pair : updates) { + m->Add(pair.first, pair.second, true); + } - return ret; -} + DLOG(INFO) << "ToANF: transformed" << std::endl << m; -Expr ToANormalForm(const Expr& e, const Module& m) { - std::unordered_set gv; - return ToANormalForm(e, m, &gv); + return m; } -TVM_REGISTER_API("relay._ir_pass.to_a_normal_form") -.set_body_typed(static_cast(ToANormalForm)); - namespace transform { Pass ToANormalForm() { - runtime::TypedPackedFunc pass_func = - [=](Function f, Module m, PassContext pc) { - return Downcast(ToANormalForm(f, m)); + runtime::TypedPackedFunc pass_func = + [=](Module m, PassContext pc) { + return ToANormalForm(m); }; - return CreateFunctionPass(pass_func, 1, "ToANormalForm", {}); + return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } TVM_REGISTER_API("relay._transform.ToANormalForm") diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index 9c166f98c1a5..c1ae19e92748 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -24,8 +24,8 @@ * * \brief Turn A normal form into graph normal form. */ -#include #include +#include #include "let_list.h" namespace tvm { @@ -76,9 +76,6 @@ Expr ToGraphNormalForm(const Expr& e) { return GNF()(e); } -TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form") -.set_body_typed(ToGraphNormalForm); - namespace transform { Pass ToGraphNormalForm() { diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 9158f0729d61..c3b12fea4486 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -18,20 +18,13 @@ import tvm from tvm import relay -from tvm.relay.ir_pass import dead_code_elimination, alpha_equal +from tvm.relay import Function, transform +from tvm.relay.ir_pass import alpha_equal, graph_equal, free_vars from tvm.relay.op import log, add, equal, subtract class env: def __init__(self): - self.a = relay.Var("a") - self.b = relay.Var("b") - self.c = relay.Var("c") - self.d = relay.Var("d") - self.e = relay.Var("e") - self.x = relay.Var("x") - self.y = relay.Var("y") - self.z = relay.Var("z") self.shape = tvm.convert([1, 2, 3]) self.tt = relay.TensorType(self.shape, "float32") self.int32 = relay.TensorType([], "int32") @@ -39,6 +32,14 @@ def __init__(self): self.one = relay.const(1.0) self.two = relay.const(2.0) self.three = relay.const(3.0) + self.a = relay.Var("a", self.float32) + self.b = relay.Var("b", self.float32) + self.c = relay.Var("c", self.float32) + self.d = relay.Var("d", self.float32) + self.e = relay.Var("e", self.float32) + self.x = relay.Var("x", self.int32) + self.y = relay.Var("y", self.int32) + self.z = relay.Var("z", self.int32) e = env() @@ -46,22 +47,27 @@ def __init__(self): def test_let(): orig = relay.Let(e.x, e.y, e.z) - assert alpha_equal(dead_code_elimination(orig), e.z) + orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z)) def test_used_let(): orig = relay.Let(e.c, e.one, e.c + e.c) - assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.one, e.c + e.c)) + orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + expected = relay.Let(e.c, e.one, e.c + e.c) + assert alpha_equal(Function([e.c], orig), Function([e.c], expected)) @nottest def test_inline(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c)) - assert alpha_equal(dead_code_elimination(orig), e.d) + orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d)) def test_chain_unused_let(): orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e)) - assert alpha_equal(dead_code_elimination(orig), e.e) + orig = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e)) # make sure we dont infinite loop @@ -78,27 +84,39 @@ def test_recursion(): f(2, 10000); """ f = relay.Var("f") + f1 = relay.Var("f1") n = relay.Var("n", e.int32) data = relay.Var("data", e.float32) funcbody = relay.If(equal(n, relay.const(0)), data, - relay.Call(f, [subtract(n, relay.const(1.0)), + relay.Call(f1, [subtract(n, relay.const(1)), log(data)])) value = relay.Function([n, data], funcbody, e.float32, []) - orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)])) - assert alpha_equal(dead_code_elimination(orig), orig) - assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three) + orig = relay.Let(f, value, relay.Call(f, [relay.const(2), relay.const(10000.0)])) + dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + orig = transform.OptimizeOnExpr(orig, transform.InferType()) + assert graph_equal(dced, orig) + dced = transform.OptimizeOnExpr(relay.Let(f, value, e.three), + transform.DeadCodeElimination()) + assert alpha_equal(dced, e.three) def test_op_let(): - assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three), e.two)), add(e.three, e.two)) + dced = transform.OptimizeOnExpr(add(relay.Let(e.a, e.one, e.three), e.two), + transform.DeadCodeElimination()) + assert alpha_equal(dced, add(e.three, e.two)) def test_tuple_get_item(): - t = relay.Var('t') + tt = relay.TupleType([e.float32, e.float32]) + t = relay.Var('t', tt) + a = relay.Var('a') g = relay.TupleGetItem(t, 0) - assert alpha_equal(dead_code_elimination(g), g) - assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t), 0)), g) + dced = transform.OptimizeOnExpr(g, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) + orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0) + dced = transform.OptimizeOnExpr(orig, transform.DeadCodeElimination()) + assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g), g)) if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index b3c0c28d26cb..f2aedd1905d4 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -18,17 +18,13 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import partial_evaluate, alpha_equal, infer_type, dead_code_elimination -from tvm.relay.ir_pass import gradient -from tvm.relay import op, create_executor -from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue +from tvm.relay.ir_pass import alpha_equal, gradient from tvm.relay.prelude import Prelude -from tvm.relay import create_executor -from nose.tools import nottest +from tvm.relay import op, create_executor, transform from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match -from tvm.relay import GlobalVar, Call, Type -from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr +from tvm.relay import GlobalVar, Call +from tvm.relay.testing import add_nat_definitions, make_nat_expr def check_eval(expr, expected_result, mod=None, rtol=1e-07): ctx = tvm.context("llvm", 0) @@ -38,8 +34,25 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07): np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) -def dcpe(expr, mod=None): - return dead_code_elimination(partial_evaluate(expr, mod=mod), inline_once=True) +def tipe(expr): + return transform.OptimizeOnExpr(expr, + [transform.InferType(), + transform.PartialEvaluate(), + transform.InferType()]) + + +def dcpe(expr, mod=None, grad=False): + passes = [transform.PartialEvaluate(), + transform.DeadCodeElimination(inline_once=True)] + if grad: + expr = gradient(expr) + if mod: + assert isinstance(expr, Function) + mod[mod.entry_func] = expr + seq = transform.Sequential(passes) + mod = seq(mod) + return mod[mod.entry_func] + return transform.OptimizeOnExpr(expr, passes) def test_tuple(): @@ -47,24 +60,31 @@ def test_tuple(): x = Var("x", t) body = TupleGetItem(relay.Tuple([relay.const(4.0), x]), 1) f = Function([x], body, None, [t]) - assert alpha_equal(dcpe(f), relay.Function([x], x, None, [t])) + expected = relay.Function([x], x, None, [t]) + expected = transform.OptimizeOnExpr(expected, transform.InferType()) + assert alpha_equal(dcpe(f), expected) + def test_const_inline(): - d = Var("d") + t = relay.TensorType([], "float32") + d = Var("d", t) double = Function([d], d + d) orig = double(const(4.0)) assert alpha_equal(dcpe(orig), const(8.0)) def test_ref(): - d = relay.Var("d") - r = relay.Var("r") + t = relay.TensorType([], "float32") + d = relay.Var("d", t) + r = relay.Var("r", relay.RefType(t)) x = relay.Var("x") body = relay.RefRead(r) body = Let(x, RefWrite(r, RefRead(r) * RefRead(r)), body) body = Let(r, RefCreate(d), body) square = Function([d], body) - assert alpha_equal(dcpe(square), Function([d], d * d)) + expected = transform.OptimizeOnExpr(Function([d], d * d), + transform.InferType()) + assert alpha_equal(dcpe(square), expected) def test_empty_ad(): @@ -73,17 +93,19 @@ def test_empty_ad(): t = TensorType(shape, dtype) d = Var("d", t) f = Function([d], d) - g = dcpe(gradient(f)) + g = dcpe(f, grad=True) expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])])) + expected = transform.OptimizeOnExpr(expected, transform.InferType()) assert alpha_equal(g, expected) + def test_ad(): shape = (10, 10) dtype = "float32" t = TensorType(shape, dtype) d = Var("d", t) f = Function([d], d * d) - g = dcpe(gradient(f)) + g = dcpe(f, grad=True) m = d * d x = relay.Var("x") o = op.ones_like(x) @@ -92,6 +114,7 @@ def test_ad(): body = Tuple([x, Tuple([grad])]) body = relay.Let(x1, o, body) expected = Function([d], relay.Let(x, m, body)) + expected = transform.OptimizeOnExpr(expected, transform.InferType()) assert alpha_equal(g, expected) @@ -107,8 +130,7 @@ def test_if_ref(): eff = Var("eff") body = Let(eff, body, RefRead(r)) f = Function([d], Let(r, RefCreate(const(1)), Let(u, update, body))) - f = infer_type(f) - pe_f = infer_type(partial_evaluate(f)) + pe_f = tipe(f) ex = create_executor() f_res = ex.evaluate(f)(const(True)) pe_f_res = ex.evaluate(pe_f)(const(True)) @@ -132,8 +154,7 @@ def test_function_invalidate(): body = Let(fet, fetch, body) body = Let(r, RefCreate(const(0)), body) f = Function([d], body) - f = infer_type(f) - pe_f = infer_type(partial_evaluate(f)) + pe_f = tipe(f) ex = create_executor() f_res = ex.evaluate(f)(const(True)) pe_f_res = ex.evaluate(pe_f)(const(True)) @@ -144,35 +165,30 @@ def test_function_invalidate(): def test_head_cons(): mod = Module() p = Prelude(mod) - def hd_impl(): - a = TypeVar("a") - x = Var("x", p.l(a)) - y = Var("y") - z = Var("z") - cons_case = Clause(PatternConstructor(p.cons, - [PatternVar(y), - PatternVar(z)]), - y) - y = Var("y") - z = Var("z") - return Function([x], Match(x, [cons_case]), a, [a]) + hd = p.hd t = TypeVar("t") x = Var("x", t) - hd = Var("hd") - body = Let(hd, hd_impl(), hd(p.cons(x, p.nil()))) + body = hd(p.cons(x, p.nil())) f = Function([x], body, None, [t]) - f = infer_type(f, mod=mod) - res = dcpe(f) + res = dcpe(f, mod) assert alpha_equal(res, Function([x], x, t, [t])) def test_map(): mod = Module() p = Prelude(mod) - f = Var("f") + f = GlobalVar("f") + t = TypeVar("t") + a = Var("a", t) + mod[f] = Function([a], a, t, [t]) orig = p.map(f, p.cons(const(1), p.cons(const(2), p.cons(const(3), p.nil())))) - expected = p.cons(f(const(1)), p.cons(f(const(2)), p.cons(f(const(3)), p.nil()))) - assert alpha_equal(dcpe(orig, mod=mod), expected) + expected = p.cons((const(1)), p.cons((const(2)), p.cons((const(3)), p.nil()))) + expected = Function([], expected) + mod[mod.entry_func] = expected + expected = mod[mod.entry_func] + orig = Function([], orig) + res = dcpe(orig, mod=mod) + assert alpha_equal(res.body, expected.body) def test_loop(): @@ -181,9 +197,12 @@ def test_loop(): x = Var("x", t) loop = GlobalVar("loop") mod[loop] = Function([x], loop(x), t, [t]) - res = dcpe(loop(const(1)), mod=mod) - expected = Call(loop, [const(1)], None, [None]) - assert alpha_equal(res, expected) + expected = Call(loop, [const(1)]) + mod[mod.entry_func] = Function([], expected) + expected = mod[mod.entry_func].body + call = Function([], loop(const(1))) + res = dcpe(call, mod=mod) + assert alpha_equal(res.body, expected) def test_swap_loop(): @@ -196,8 +215,9 @@ def test_swap_loop(): loop = GlobalVar("loop") mod[loop] = Function([x, y], loop(y, x), nat) prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2)) - res = dcpe(prog, mod=mod) - assert alpha_equal(prog, res) + res = Function([], prog) + res = dcpe(res, mod=mod) + assert alpha_equal(prog, res.body) def test_abs_diff(): @@ -217,8 +237,9 @@ def test_abs_diff(): x_s_case = Clause(PatternConstructor(p.s, [PatternVar(xp)]), Match(y, [y_z_case, y_s_case])) mod[diff] = Function([x, y], Match(x, [x_z_case, x_s_case])) orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3)) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 4)) + assert alpha_equal(res.body, make_nat_expr(p, 4)) def test_match_nat_id(): @@ -233,8 +254,9 @@ def test_match_nat_id(): s_case = Clause(PatternConstructor(p.s, [PatternVar(y)]), p.s(y)) mod[nat_id] = Function([x], Match(x, [z_case, s_case])) orig = nat_id(make_nat_expr(p, 3)) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 3)) + assert alpha_equal(res.body, make_nat_expr(p, 3)) def test_nat_id(): @@ -247,8 +269,9 @@ def test_nat_id(): nat_id = GlobalVar("nat_id") mod[nat_id] = Function([x], x) orig = nat_id(make_nat_expr(p, 3)) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 3)) + assert alpha_equal(res.body, make_nat_expr(p, 3)) def test_global_match_nat_id(): @@ -260,8 +283,9 @@ def test_global_match_nat_id(): z_case = Clause(PatternConstructor(p.z, []), p.z()) s_case = Clause(PatternConstructor(p.s, [PatternVar(x)]), p.s(x)) orig = Match(make_nat_expr(p, 3), [z_case, s_case]) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 3)) + assert alpha_equal(res.body, make_nat_expr(p, 3)) def test_double(): @@ -269,8 +293,9 @@ def test_double(): p = Prelude(mod) add_nat_definitions(p) orig = p.double(make_nat_expr(p, 3)) + orig = Function([], orig) res = dcpe(orig, mod=mod) - assert alpha_equal(res, make_nat_expr(p, 6)) + assert alpha_equal(res.body, make_nat_expr(p, 6)) if __name__ == '__main__': diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index 9a2570eabb11..e74168141e63 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -17,9 +17,8 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type, detect_feature -from tvm.relay import op, create_executor -from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue +from tvm.relay.ir_pass import alpha_equal, detect_feature +from tvm.relay import op, create_executor, transform from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count from tvm.relay.feature import Feature @@ -39,7 +38,7 @@ def test_explicit_bound(): z = op.add(y, y) f = relay.Function([], op.add(z, z)) assert not Feature.fLet in detect_feature(f) - anf = to_a_normal_form(f) + anf = transform.OptimizeOnExpr(f, transform.ToANormalForm()) assert Feature.fLet in detect_feature(anf) check_eval(f(), 8.0) check_eval(anf(), 8.0) @@ -53,7 +52,8 @@ def test_order(): x = relay.const(1) val = x + y * z check_eval(val, 7.0) - anf = infer_type(to_a_normal_form(val)) + anf = transform.OptimizeOnExpr(val, [transform.ToANormalForm(), + transform.InferType()]) a = relay.Var('a', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType()) @@ -65,14 +65,16 @@ def test_order(): expected_output = relay.Let(c, z, expected_output) expected_output = relay.Let(b, y, expected_output) expected_output = relay.Let(a, x, expected_output) - expected_output = infer_type(expected_output) + expected_output = transform.OptimizeOnExpr(expected_output, + transform.InferType()) assert alpha_equal(anf, expected_output) def test_if(): cond = relay.const(True) x = relay.If(cond, relay.const(2), relay.const(3)) - anf = infer_type(to_a_normal_form(x)) + anf = transform.OptimizeOnExpr(x, [transform.ToANormalForm(), + transform.InferType()]) a = relay.Var('a', relay.IncompleteType()) b = relay.Var('b', relay.IncompleteType()) c = relay.Var('c', relay.IncompleteType()) @@ -82,7 +84,8 @@ def test_if(): expected_output = relay.If(c, true_branch, false_branch) expected_output = relay.Let(d, expected_output, d) expected_output = relay.Let(c, cond, expected_output) - expected_output = infer_type(expected_output) + expected_output = transform.OptimizeOnExpr(expected_output, + transform.InferType()) assert alpha_equal(anf, expected_output) @@ -114,7 +117,8 @@ def test_recursion(): mod[f] = value check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) old_f = mod[f] - f = to_a_normal_form(f, mod=mod) + mod = transform.ToANormalForm()(mod) + f = mod[f] check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) @@ -129,7 +133,8 @@ def test_ref(): body = relay.Let(iv, relay.RefRead(i), body) body = relay.Let(i, relay.RefCreate(relay.const(1)), body) check_eval(body, 3) - check_eval(to_a_normal_form(body), 3) + opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm()) + check_eval(opt_body, 3) def test_nat_add(): @@ -144,7 +149,12 @@ def test_nat_add(): intrp = create_executor(mod=mod, ctx=ctx, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 - assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 + expr = add(s(z()), s(z())) + f = relay.GlobalVar("f") + mod[f] = relay.Function([], expr) + mod = transform.ToANormalForm()(mod) + expr = mod["f"] + assert count(p, intrp.evaluate(expr.body)) == 2 assert Feature.fLet in detect_feature(mod[add]) @@ -155,14 +165,16 @@ def test_let(): body = relay.Let(y, x, x + y) body = relay.Let(x, d, body) check_eval(body, 8) - check_eval(to_a_normal_form(body), 8) + opt_body = transform.OptimizeOnExpr(body, transform.ToANormalForm()) + check_eval(opt_body, 8) def test_function(): - x = relay.Var("x") + t = relay.TensorType((), 'float32') + x = relay.Var("x", t) f = relay.Function([x], x + x) d = relay.const(4.0, 'float32') - anf_f = to_a_normal_form(f) + anf_f = transform.OptimizeOnExpr(f, transform.ToANormalForm()) assert isinstance(anf_f, relay.Function) check_eval(f(d), 8) check_eval(anf_f(d), 8) diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index 6d9bd6ac254e..09db48f633d9 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -17,10 +17,9 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal, detect_feature -from tvm.relay import op, create_executor +from tvm.relay import op, create_executor, transform +from tvm.relay.ir_pass import detect_feature from tvm.relay.feature import Feature -from tvm.relay.backend.interpreter import Value, TupleValue def check_eval(expr, args, expected_result, mod=None, rtol=1e-07): @@ -41,9 +40,9 @@ def test_implicit_share(): body = relay.Let(z, op.add(y, y), op.add(z, z)) body = relay.Let(y, op.add(x, x), body) f = relay.Function([], relay.Let(x, relay.const(1), body)) - g = to_graph_normal_form(f) - assert "let" in f.astext() - assert not "let" in g.astext() + g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm()) + assert Feature.fLet in detect_feature(f) + assert not Feature.fLet in detect_feature(g) check_eval(f, [], 8.0) check_eval(g, [], 8.0) @@ -55,8 +54,8 @@ def test_round_trip(): body = relay.Let(z, op.add(y, y), op.add(z, z)) body = relay.Let(y, op.add(x, x), body) f = relay.Function([], relay.Let(x, relay.const(1), body)) - g = to_graph_normal_form(f) - h = to_a_normal_form(g) + g = transform.OptimizeOnExpr(f, transform.ToGraphNormalForm()) + h = transform.OptimizeOnExpr(g, transform.ToANormalForm()) assert Feature.fLet in detect_feature(f) assert not Feature.fLet in detect_feature(g) check_eval(f, [], 8.0)