Skip to content

Commit

Permalink
[Relay][Pass] Only allow Module -> Module for opts managed by pass in…
Browse files Browse the repository at this point in the history
…fra (#3430)

* [Relay][Pass] Only allow Module -> Module for opts managed by pass infra

* revert gradient pass
  • Loading branch information
zhiics authored and jroesch committed Jul 1, 2019
1 parent 6c81d78 commit f2a6851
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 349 deletions.
86 changes: 4 additions & 82 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,23 +140,6 @@ TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
*/
TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);

/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \param e The original function.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return the new function with abstraction
*/
TVM_DLL Expr EtaExpand(const Expr& e, const Module& mod);

/*!
* \brief Check that each Var is only bound once.
*
Expand Down Expand Up @@ -288,24 +271,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);

/*! \brief Remove expressions which does not effect the program result.
*
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
*
* For example, this pass should turn `let a = 1 in 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1,
* if the flag is turned on.
*
* \param e the expression to optimize.
* \param inline_once whether or not to inline binding used one.
*
* \return the optimized expression.
*/
TVM_DLL Expr DeadCodeElimination(const Expr& e, bool inline_once = false);

/*!
* \brief Fold constant expressions.
*
Expand Down Expand Up @@ -387,38 +352,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);

/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
* It will turn an expression that is in a graph form (with sharing implicit),
* to an expression with explicit sharing (A-Normal Form).
*
* The scope of the root expression is the global scope.
*
* The scope of any non root expression is the least common ancestor of all it's scope.
*
* Values are ordered by post-DFS order in each scope.
*
* \param e the expression to observably share.
* \param mod The module used for referencing global functions, can be
* None.
*
* \return expression in A-Normal Form.
*/
TVM_DLL Expr ToANormalForm(const Expr& e, const Module& mod);

/*!
* \brief Remove let binding and directly share via pointer instead.
*
* It will remove all let binding,
* and turn all of the variable bound by let into direct pointer reference.
*
* \param e the expression.
*
* \return the expression in graph normal form.
*/
TVM_DLL Expr ToGraphNormalForm(const Expr& e);

/*!
* \brief Finds cases that the given match expression does not catch, if any.
*
Expand All @@ -432,28 +365,17 @@ TVM_DLL Expr ToGraphNormalForm(const Expr& e);
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const Module& mod);

/*!
* \brief Aggressive constant propagation/constant folding/inlining.
* It will do as much computation in compile time as possible.
* It has two benefit: remove runtime overhead, and allow more optimization (typically fusion).
* As a side effect, code size will explode.
*
* \param e the expression
* \param mod the module
*
* \return the optimized expression.
*/
TVM_DLL Expr PartialEval(const Expr& e, const Module& mod);

/*
* \brief Bind function parameters or free variables.
* \brief Bind the free variables to a Relay expression.
*
* Parameter binding can only happen if expr is a Function.
* binds cannot change internal arguments of internal functions.
*
* \param expr The function to be binded.
* \param binds The map of arguments to
*
* \return The expression with all free vars bound.
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& bind_map);
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,33 @@ TVM_DLL Pass AlterOpLayout();
*/
TVM_DLL Pass CanonicalizeCast();

/*!
* \brief Add abstraction over a function
*
* For example: `square` is transformed to
* `fun x -> square x`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B7-conversion
* for more details.
*
* \return The pass.
*/
TVM_DLL Pass EtaExpand();

/*!
* \brief This is a helper function that runs a some optimization passes on
* a certain expression and returns the optimized version. With the help of this
* function, users don't need to manually construct a module, then perform
* passes, and finally and extract the target function/expression from the
* returned module frequently.
*
* \param expr The expression to be optimized.
* \param passes The passses that will be applied on the given expression.
*
* \return The optimized expression.
*/
TVM_DLL Expr OptimizeOnExpr(const Expr& expr, const Array<Pass>& passes);

} // namespace transform
} // namespace relay
} // namespace tvm
Expand Down
96 changes: 0 additions & 96 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,6 @@ def backward_fold_scale_axis(expr):
"""
return _ir_pass.backward_fold_scale_axis(expr)

def eta_expand(expr, mod):
"""Add abstraction over a function.
Parameters
----------
expr : tvm.relay.Expr
The input expression, we expect that expr's types
should be fully inferred by infer_type.
mod : tvm.relay.Module
The global module.
Returns
-------
expanded_expr : tvm.relay.Expr
The expression after eta expansion.
"""
return _ir_pass.eta_expand(expr, mod)

def forward_fold_scale_axis(expr):
"""Fold the scaling of axis into weights of conv2d/dense.
Expand Down Expand Up @@ -318,25 +301,6 @@ def canonicalize_ops(expr):
return _ir_pass.canonicalize_ops(expr)


def dead_code_elimination(expr, inline_once=False):
""" Remove expressions which does not effect the program result (dead code).
Parameters
----------
expr : tvm.relay.Expr
The input Expression
inline_once : Optional[Bool]
Whether to inline binding that occur only once.
Returns
-------
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
return _ir_pass.dead_code_elimination(expr, inline_once)


def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
Expand Down Expand Up @@ -534,46 +498,6 @@ def collect_device_annotation_ops(expr):
return _ir_pass.CollectDeviceAnnotationOps(expr)


def to_a_normal_form(expr, mod=None):
"""
Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
The scope of any non root expression is the least common ancestor of all it's scope.
Values are ordered by post-DFS order in each scope.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)


def to_graph_normal_form(expr):
"""Turn A Normal Form expression into Graph Normal Form expression
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)


def gradient(expr, mod=None, mode='higher_order'):
"""
Transform the input function,
Expand Down Expand Up @@ -642,26 +566,6 @@ def eliminate_common_subexpr(expr, fskip=None):
return _ir_pass.eliminate_common_subexpr(expr, fskip)


def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns
-------
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr, mod)


def unmatched_cases(match, mod=None):
"""
Finds cases that the match expression does not catch, if any.
Expand Down
40 changes: 36 additions & 4 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,20 @@ def CanonicalizeOps():
return _transform.CanonicalizeOps()


def DeadCodeElimination():
""" Remove expressions which does not effect the program result (dead code).
def DeadCodeElimination(inline_once=False):
"""Remove expressions which does not effect the program result (dead code).
Parameters
----------
inline_once: Optional[Bool]
Whether to inline binding that occurs only once.
Returns
-------
ret: tvm.relay.Pass
The registered pass that eliminates the dead code in a Relay program.
"""
return _transform.DeadCodeElimination()
return _transform.DeadCodeElimination(inline_once)


def FoldConstant():
Expand Down Expand Up @@ -406,6 +411,7 @@ def ToANormalForm():
"""
return _transform.ToANormalForm()


def EtaExpand():
"""Add abstraction over a function
Expand All @@ -416,6 +422,7 @@ def EtaExpand():
"""
return _transform.EtaExpand()


def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression
Expand Down Expand Up @@ -449,7 +456,7 @@ def PartialEvaluate():
Returns
-------
ret : tvm.relay.Pass
ret: tvm.relay.Pass
The registered pass that performs partial evaluation on an expression.
"""
return _transform.PartialEvaluate()
Expand All @@ -465,6 +472,31 @@ def CanonicalizeCast():
"""
return _transform.CanonicalizeCast()


def OptimizeOnExpr(expr, passes):
"""Perform optimization passes on an expressioin.
Parameters
----------
expr: tvm.relay.Expr
The expression for optimization.
passes: Union[Pass, List[Pass]]
The list of optimizations to be applied.
Returns
-------
ret: tvm.relay.Expr
The optimized expression.
"""
if isinstance(passes, Pass):
passes = [passes]
if not isinstance(passes, (list, tuple)):
raise TypeError("passes must be a pass or a list of pass objects.")

return _transform.OptimizeOnExpr(expr, passes)


def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyModulePass(ModulePass):
Expand Down
3 changes: 0 additions & 3 deletions src/relay/pass/dead_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,6 @@ Expr DeadCodeElimination(const Expr& e, bool inline_once) {
return CalcDep::Eliminate(e, inline_once);
}

TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body_typed(DeadCodeElimination);

namespace transform {

Pass DeadCodeElimination(bool inline_once) {
Expand Down
Loading

0 comments on commit f2a6851

Please sign in to comment.