Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Implement Rewriter class for pattern-rewrite #17149

Merged
merged 17 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,47 @@ class BlockBuilderNode : public Object {
* \brief Begin a new scope, with optional parameters that
* are visible within the scope.
*
* Symbolic variables from the parent scope are not available.
*
* \param params Parameters that are visible within the scope.
*
* \note This function should be called when new scope is introduced
* (function, seq) to properly track the variable availability
* and help the best effort deduction.
* (e.g. function bodies) to properly track the variable
* availability and help the best effort deduction.
*
* \sa EndScope
*/
virtual void BeginScope(Optional<Array<Var>> params) = 0;

/*!
* \brief Begin a new scope, which inherits visible parameters from
* its parent scope.
*
* Symbolic variables from the parent scope are available.
*
* \note This function should be called when an inner scope is
* introduced (e.g. conditional branches) to properly track
* the variable availability and help the best effort
* deduction.
*
* \sa EndScope
*/
virtual void BeginInnerScope() = 0;

/*!
* \brief Append a definition to the current scope.
*
* \param var A variable within the current scope.
*
* \note This function should be called when a new variable is
* defined that may impact struct inference (e.g. MatchCast)
* to properly track the variable availability and help the
* best effort deduction.
*
* \sa EndScope
*/
virtual void AddDefinitionToScope(Var var) = 0;

/*! \brief End the previously defined scope. */
virtual void EndScope() = 0;

Expand Down
21 changes: 20 additions & 1 deletion include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,10 @@ class ExprMutator : public ExprMutatorBase {
void ReEmitBinding(const VarBindingNode* binding, Expr new_value);

/*!
* \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If.
* \brief Rewrite the expr with a new scope, used in a Function's body.
*
* Visit an expression that may neither access variables from the
* current scope, nor may export definitions into the current scope.
*
* \param body_expr The body to be visited.
* \param params Optional parameters that are visible within the scope.
Expand All @@ -504,6 +507,22 @@ class ExprMutator : public ExprMutatorBase {
*/
Expr VisitWithNewScope(const Expr& body_expr, Optional<Array<Var>> params = NullOpt);

/*!
* \brief Rewrite the expr with a new scope, used in the branches of If.
*
* Visit an expression that may access variables from the current
* scope, but may not export definitions into the current scope.
*
* \param body_expr The body to be visited.
*
* \return The expr after visiting.
*
* \sa VisitWithNewScope
*
* \note The body_expr must be an SeqExpr in the normal form.
*/
Expr VisitWithInnerScope(const Expr& body_expr);

/*!
* \brief Look up the value bound to a variable.
* \param var The var to be looked up.
Expand Down
1 change: 1 addition & 0 deletions include/tvm/script/ir_builder/relax/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class FunctionFrameNode : public SeqExprFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode);

public:
void EnterWithScope() final;
void ExitWithScope() final;
};

Expand Down
8 changes: 7 additions & 1 deletion python/tvm/relax/dpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,10 @@

from .pattern import *
from .context import *
from .rewrite import rewrite_call, rewrite_bindings
from .rewrite import (
rewrite_call,
rewrite_bindings,
PatternMatchingRewriter,
ExprPatternRewriter,
OrRewriter,
)
186 changes: 183 additions & 3 deletions python/tvm/relax/dpl/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,196 @@
# specific language governing permissions and limitations
# under the License.
"""APIs for pattern-based rewriting."""
from typing import Dict, Callable

from typing import Dict, Callable, Union

from tvm.ir import IRModule
from tvm.runtime import Object
from tvm._ffi import register_object

from .pattern import DFPattern
from .context import PatternContext

from ..expr import Expr, Function, Var
from . import _ffi as ffi


@register_object("relax.dpl.PatternMatchingRewriter")
class PatternMatchingRewriter(Object):
"""A pattern-matching rewriter for Relax"""

@staticmethod
def from_pattern(
pattern: DFPattern,
func: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
) -> "PatternMatchingRewriter":
"""Construct from a pattern and rewriter-function

The replacements performed by the rewriter will be equivalent
to using the `pattern` and `func` as arguments to
`rewrite_call`.

Parameters
----------
pattern: DFPattern

The pattern to be matched against.

func: Callable[[Expr, Dict[DFPattern, Expr]], Expr]

A function that returns the rewritten expression. See
`rewrite_call` for details and examples.


Returns
-------
rewriter_obj: PatternMatchingRewriter

The rewriter object

"""
return ffi.PatternMatchingRewriterFromPattern(
pattern,
func,
) # type: ignore

@staticmethod
def from_module(mod: IRModule) -> "PatternMatchingRewriter":
"""Construct a rewriter from an IRModule

The IRModule must have two publicly-exposed functions,
`pattern` and `replacement`, where `pattern` and `replacement`
have the same function signature, as shown in the example
below.

.. code-block:: python

@I.ir_module
class RewriteAddIntoMultiply:
@R.function
def pattern(A: R.Tensor):
B = A + A
return B

@R.function
def replacement(A: R.Tensor):
B = A * 2
return B

rewriter = PatternMatchingRewriter.from_module(RewriteAddIntoMultiply)
rewritten_ir_module = rewriter(ir_module)

To support the common case of defining an IRModule with
TVMScript, then immediately turning it into a rewriter, the
`@R.rewriter` annotation can be used.

.. code-block:: python

@R.rewriter
class RewriteAddIntoMultiply:
@R.function
def pattern(A: R.Tensor):
B = A + A
return B

@R.function
def replacement(A: R.Tensor):
B = A * 2
return B

rewritten_ir_module = RewriteAddIntoMultiply(ir_module)

Parameters
----------
mod: IRModule

A module with `pattern` and `replacement` functions,
sunggg marked this conversation as resolved.
Show resolved Hide resolved
defining a rewrite rule.


Returns
-------
rewriter_obj: PatternMatchingRewriter

The rewriter object

"""
return ffi.PatternMatchingRewriterFromModule(mod) # type: ignore

def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]:
"""Apply the rewriter

Parameters
----------
obj: Union[Expr, IRModule])

The object to be rewritten. May be applied to either a
relax expression, or an IRModule.

Returns
-------
updated: Union[Expr, IRModule]

The rewritten object

"""
return ffi.PatternMatchingRewriterApply(self, obj)

def __or__(self, other: "PatternMatchingRewriter") -> "PatternMatchingRewriter":
"""Compose two rewriters

Composing two rewrite rules together allows them to be applied
in a single Relax-level transformation.

Parameters
----------
other: PatternMatchingRewriter

Another rewrite rule

Returns
-------
PatternMatchingRewriter

A rewriter that will apply either rewrite pattern

"""
return OrRewriter(self, other)


@register_object("relax.dpl.ExprPatternRewriter")
class ExprPatternRewriter(PatternMatchingRewriter):
def __init__(self, pattern, func):
self.__init_handle_by_constructor__(
ffi.PatternRewriter,
pattern,
func,
) # type: ignore


@register_object("relax.dpl.OrRewriter")
class OrRewriter(PatternMatchingRewriter):
def __init__(self, lhs, rhs):
self.__init_handle_by_constructor__(
ffi.OrRewriter,
lhs,
rhs,
) # type: ignore


@register_object("relax.dpl.TupleRewriter")
class TupleRewriter(PatternMatchingRewriter):
def __init__(self, patterns, func):
self.__init_handle_by_constructor__(
ffi.TupleRewriter,
patterns,
func,
) # type: ignore


def rewrite_call(
pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
pattern: DFPattern,
rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
func: Function,
) -> Function:
"""
Rewrite a function with the given pattern and the rewriter function.
Expand Down
48 changes: 46 additions & 2 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import builtins
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type

import tvm
from tvm import DataType, relax
from tvm.ir import PrimExpr, VDevice
from tvm.ir import PrimExpr, VDevice, IRModule
from tvm.relax import (
Call,
Expr,
Expand All @@ -35,6 +35,7 @@
VarBinding,
const,
)
from tvm.relax.dpl import PatternMatchingRewriter

############################### Operators ###############################
from tvm.relax.op import (
Expand Down Expand Up @@ -306,6 +307,48 @@ def func_ret_value(value: Expr) -> None:
return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member


def rewriter(rewriter_mod: Union[IRModule, Type]) -> PatternMatchingRewriter:
"""Define a pattern-rewrite rule

The IRModule must have two publicly-exposed functions, `pattern`
and `replacement`, where `pattern` and `replacement` have the same
function signature.

.. code-block:: python

@R.rewriter
class RewriteAddIntoMultiply:
@R.function
def pattern(A: R.Tensor):
B = A + A
return B

@R.function
def replacement(A: R.Tensor):
B = A * 2
return B

Parameters
----------
rewriter_mod: Union[IRModule, Type]

Either an IRModule that defines a rewrite pattern, or a
TVMScript class that can be parsed into an IRModule.

Returns
-------
rewriter: PatternMatchingRewriter

A rewriter object, which can be applied either to a Relax
function or to an entire IRModule.

"""
if not isinstance(rewriter_mod, IRModule):
rewriter_mod = tvm.script.ir_module(rewriter_mod)

return PatternMatchingRewriter.from_module(rewriter_mod)


############################# BindingBlock ##############################


Expand Down Expand Up @@ -765,6 +808,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"dequantize",
"repeat",
"reshape",
"rewriter",
"tensor_to_shape",
"shape_to_tensor",
"rocm",
Expand Down
Loading
Loading