Skip to content

Commit

Permalink
[RELAY][PASS] Enable decorating python class as Pass
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jun 13, 2019
1 parent a698ad7 commit ccfe250
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 36 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
bind = expr.bind
module_pass = transform.module_pass
function_pass = transform.function_pass
alpha_equal = ir_pass.alpha_equal

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
Expand Down
185 changes: 150 additions & 35 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
Relay pass transformation infrastructure.
"""
import types
import inspect
import functools

from tvm._ffi.runtime_ctypes import TVMContext
from . import _transform
Expand Down Expand Up @@ -444,16 +446,47 @@ def PartialEvaluate():
return _transform.PartialEvaluate()


def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyModulePass(ModulePass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(mod, ctx):
return inst.transform_module(mod, ctx)
self.__init_handle_by_constructor__(
_transform.MakeModulePass, _pass_func, pass_info)
self._inst = inst

def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)

functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__)
PyModulePass.__name__ = pass_cls.__name__
PyModulePass.__doc__ = pass_cls.__doc__
PyModulePass.__module__ = pass_cls.__module__
return PyModulePass


def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
"""Decorate a module pass.
This function returns a callback when pass_func is provided.
Otherwise, it serves a decorator function.
pass_func can also be a class type with a method transform_module.
This function will create a decorated ModulePass using transform_module
as the pass function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
pass_func : Optional[Callable[(Module, PassContext) ->Module]]
The transformation function or class.
opt_level : int
The optimization level of this module pass.
Expand All @@ -468,14 +501,39 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
A decorator will be returned if pass_func is not provided,
otherwise return the decorated result.
The returned decorator has two behaviors depending on the input:
A new ModulePass will be returned when we decorate a pass function.
A new ModulePass class will be returned when we decorate a class type.
Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
The following code block decorates a module pass class.
.. code-block:: python
@relay.transform.module_pass
class CustomPipeline:
def __init__(self, enable_fold):
self.enable_fold = enable_fold
self.cse = relay.transform.EliminateCommonSubexpr()
self.const_fold = relay.transform.FoldConstant()
def transform_module(self, mod, ctx):
mod = self.cse(mod, ctx)
if self.enable_fold:
mod = self.const_fold(mod, ctx)
return mod
# create an instance of customized pipeline
pipeline = CustomPipeline(enable_fold=False)
assert isinstance(pipeline, transform.ModulePass)
# run the pipeline.
output_module = pipeline(input_module)
The following code creates a module pass by decorating
a user defined transform function.
.. code-block:: python
Expand All @@ -497,7 +555,6 @@ def transform(mod, ctx):
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")

Expand All @@ -506,30 +563,59 @@ def transform(mod, ctx):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_module_pass(pass_func):
def create_module_pass(pass_arg):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

fname = name if name else pass_func.__name__
fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeModulePass(pass_func, info)
if inspect.isclass(pass_arg):
return _wrap_class_module_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.MakeModulePass(pass_arg, info)

if pass_func:
return create_module_pass(pass_func)
return create_module_pass


def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyFunctionPass(FunctionPass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(func, mod, ctx):
return inst.transform_function(func, mod, ctx)
self.__init_handle_by_constructor__(
_transform.MakeFunctionPass, _pass_func, pass_info)
self._inst = inst

def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)

functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
PyFunctionPass.__name__ = pass_cls.__name__
PyFunctionPass.__doc__ = pass_cls.__doc__
PyFunctionPass.__module__ = pass_cls.__module__
return PyFunctionPass


def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
"""Decorate a function pass.
This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
The transformation function or class.
opt_level : int
The optimization level of this module pass.
Expand All @@ -544,20 +630,48 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.
A decorator will be returned if pass_func is not provided,
otherwise return the decorated result.
The returned decorator has two behaviors depending on the input:
A new FunctionPass will be returned when we decorate a pass function.
A new FunctionPass class will be returned when we decorate a class type.
Examples
--------
The following code creates a function level pass that performs constant
folding.
The following code block decorates a function pass class.
.. code-block:: python
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
# just for demo purposes
# transform func to new_func
return self.new_func
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)
The following code creates a function pass by decorating
a user defined transform function.
.. code-block:: python
@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
def transform(func, mod, ctx):
# my transformations here.
return func
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
Expand All @@ -577,14 +691,15 @@ def transform(func, ctx):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_function_pass(pass_func):
def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

fname = name if name else pass_func.__name__
fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeFunctionPass(pass_func, info)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.MakeFunctionPass(pass_arg, info)

if pass_func:
return create_function_pass(pass_func)
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relay/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,29 @@ def test_pass_run():
test_pass_run()


def test_function_class_pass():
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
"""Simple test function to replace one argument to another."""
def __init__(self, new_func):
self.new_func = new_func

def transform_function(self, func, mod, ctx):
return self.new_func

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
fpass = TestReplaceFunc(f1)
assert fpass.info.opt_level == 1
assert fpass.info.name == "TestReplaceFunc"
mod = relay.Module.from_expr(f2)
mod = fpass(mod)
# wrap in expr
mod2 = relay.Module.from_expr(f1)
assert relay.alpha_equal(mod["main"], mod2["main"])


def test_function_pass():
shape = (10, )
dtype = 'float32'
Expand Down Expand Up @@ -259,6 +282,30 @@ def test_pass_run():
test_pass_run()


def test_module_class_pass():
@relay.transform.module_pass(opt_level=1)
class TestPipeline:
"""Simple test function to replace one argument to another."""
def __init__(self, new_mod, replace):
self.new_mod = new_mod
self.replace = replace

def transform_module(self, mod, ctx):
if self.replace:
return self.new_mod
return mod

x = relay.var("x", shape=(10, 20))
m1 = relay.Module.from_expr(relay.Function([x], x))
m2 = relay.Module.from_expr(relay.Function([x], relay.log(x)))
fpass = TestPipeline(m2, replace=True)
assert fpass.info.name == "TestPipeline"
mod3 = fpass(m1)
assert mod3.same_as(m2)
mod4 = TestPipeline(m2, replace=False)(m1)
assert mod4.same_as(m1)


def test_pass_info():
info = relay.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
Expand Down Expand Up @@ -451,6 +498,8 @@ def expected():


if __name__ == "__main__":
test_function_class_pass()
test_module_class_pass()
test_module_pass()
test_function_pass()
test_sequential_pass()
Expand Down

0 comments on commit ccfe250

Please sign in to comment.