Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][PASS] Enable decorating python class as Pass #3364

Merged
merged 1 commit into from
Jun 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
bind = expr.bind
module_pass = transform.module_pass
function_pass = transform.function_pass
alpha_equal = ir_pass.alpha_equal

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
Expand Down
185 changes: 150 additions & 35 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
Relay pass transformation infrastructure.
"""
import types
import inspect
import functools

from tvm._ffi.runtime_ctypes import TVMContext
from . import _transform
Expand Down Expand Up @@ -444,16 +446,47 @@ def PartialEvaluate():
return _transform.PartialEvaluate()


def _wrap_class_module_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyModulePass(ModulePass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(mod, ctx):
return inst.transform_module(mod, ctx)
self.__init_handle_by_constructor__(
_transform.MakeModulePass, _pass_func, pass_info)
self._inst = inst

def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)

functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__)
PyModulePass.__name__ = pass_cls.__name__
PyModulePass.__doc__ = pass_cls.__doc__
PyModulePass.__module__ = pass_cls.__module__
return PyModulePass


def module_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a module pass. This function returns a callback when pass_func
is provided. Otherwise, it returns the created module level pass using the
given optimization function.
"""Decorate a module pass.

This function returns a callback when pass_func is provided.
Otherwise, it serves a decorator function.

pass_func can also be a class type with a method transform_module.
This function will create a decorated ModulePass using transform_module
as the pass function.

Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
pass_func : Optional[Callable[(Module, PassContext) ->Module]]
The transformation function or class.

opt_level : int
The optimization level of this module pass.
Expand All @@ -468,14 +501,39 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
Returns
-------
create_module_pass : Union[Callable, ModulePass]
The callable that will create a module pass is returned when
pass_func is not passed in. Otherwise, a ModulePass object will be
directly created.
A decorator will be returned if pass_func is not provided,
otherwise return the decorated result.
The returned decorator has two behaviors depending on the input:
A new ModulePass will be returned when we decorate a pass function.
A new ModulePass class will be returned when we decorate a class type.

Examples
--------
The following code creates a module level pass and adds an abs function to
the module.
The following code block decorates a module pass class.

.. code-block:: python

@relay.transform.module_pass
class CustomPipeline:
def __init__(self, enable_fold):
self.enable_fold = enable_fold
self.cse = relay.transform.EliminateCommonSubexpr()
self.const_fold = relay.transform.FoldConstant()

def transform_module(self, mod, ctx):
mod = self.cse(mod, ctx)
if self.enable_fold:
mod = self.const_fold(mod, ctx)
return mod

# create an instance of customized pipeline
pipeline = CustomPipeline(enable_fold=False)
assert isinstance(pipeline, transform.ModulePass)
# run the pipeline.
output_module = pipeline(input_module)

The following code creates a module pass by decorating
a user defined transform function.

.. code-block:: python

Expand All @@ -497,7 +555,6 @@ def transform(mod, ctx):
updated_mod = module_pass(m)
# Now a function abs should be added to the module m.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the module pass.")

Expand All @@ -506,30 +563,59 @@ def transform(mod, ctx):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_module_pass(pass_func):
def create_module_pass(pass_arg):
"""Internal function that creates a module pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

fname = name if name else pass_func.__name__
fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeModulePass(pass_func, info)
if inspect.isclass(pass_arg):
return _wrap_class_module_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.MakeModulePass(pass_arg, info)

if pass_func:
return create_module_pass(pass_func)
return create_module_pass


def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyFunctionPass(FunctionPass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(func, mod, ctx):
return inst.transform_function(func, mod, ctx)
self.__init_handle_by_constructor__(
_transform.MakeFunctionPass, _pass_func, pass_info)
self._inst = inst

def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)

functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
PyFunctionPass.__name__ = pass_cls.__name__
PyFunctionPass.__doc__ = pass_cls.__doc__
PyFunctionPass.__module__ = pass_cls.__module__
return PyFunctionPass


def function_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Create a function pass. This function returns a callback when pass_func
"""Decorate a function pass.

This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.

Parameters
----------
pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
The implemented optimization pass.
pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
The transformation function or class.

opt_level : int
The optimization level of this module pass.
Expand All @@ -544,20 +630,48 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
The callable that will create a function pass is returned when
pass_func is not passed in. Otherwise, a FunctionPass object will be
created.

A decorator will be returned if pass_func is not provided,
otherwise return the decorated result.
The returned decorator has two behaviors depending on the input:
A new FunctionPass will be returned when we decorate a pass function.
A new FunctionPass class will be returned when we decorate a class type.

Examples
--------
The following code creates a function level pass that performs constant
folding.
The following code block decorates a function pass class.

.. code-block:: python

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
def __init__(self, new_func):
self.new_func = new_func

def transform_function(self, func, mod, ctx):
# just for demo purposes
# transform func to new_func
return self.new_func

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)


The following code creates a function pass by decorating
a user defined transform function.

.. code-block:: python

@relay.transform.function_pass(opt_level=2)
def transform(func, ctx):
return ir_pass.fold_constant(func)
def transform(func, mod, ctx):
# my transformations here.
return func

function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
Expand All @@ -577,14 +691,15 @@ def transform(func, ctx):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_function_pass(pass_func):
def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

fname = name if name else pass_func.__name__
fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
return _transform.MakeFunctionPass(pass_func, info)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _transform.MakeFunctionPass(pass_arg, info)

if pass_func:
return create_function_pass(pass_func)
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relay/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,29 @@ def test_pass_run():
test_pass_run()


def test_function_class_pass():
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
"""Simple test function to replace one argument to another."""
def __init__(self, new_func):
self.new_func = new_func

def transform_function(self, func, mod, ctx):
return self.new_func

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
fpass = TestReplaceFunc(f1)
assert fpass.info.opt_level == 1
assert fpass.info.name == "TestReplaceFunc"
mod = relay.Module.from_expr(f2)
mod = fpass(mod)
# wrap in expr
mod2 = relay.Module.from_expr(f1)
assert relay.alpha_equal(mod["main"], mod2["main"])


def test_function_pass():
shape = (10, )
dtype = 'float32'
Expand Down Expand Up @@ -259,6 +282,30 @@ def test_pass_run():
test_pass_run()


def test_module_class_pass():
@relay.transform.module_pass(opt_level=1)
class TestPipeline:
"""Simple test function to replace one argument to another."""
def __init__(self, new_mod, replace):
self.new_mod = new_mod
self.replace = replace

def transform_module(self, mod, ctx):
if self.replace:
return self.new_mod
return mod

x = relay.var("x", shape=(10, 20))
m1 = relay.Module.from_expr(relay.Function([x], x))
m2 = relay.Module.from_expr(relay.Function([x], relay.log(x)))
fpass = TestPipeline(m2, replace=True)
assert fpass.info.name == "TestPipeline"
mod3 = fpass(m1)
assert mod3.same_as(m2)
mod4 = TestPipeline(m2, replace=False)(m1)
assert mod4.same_as(m1)


def test_pass_info():
info = relay.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
Expand Down Expand Up @@ -451,6 +498,8 @@ def expected():


if __name__ == "__main__":
test_function_class_pass()
test_module_class_pass()
test_module_pass()
test_function_pass()
test_sequential_pass()
Expand Down