Skip to content

Commit

Permalink
[Relax][Refactor] Implement Rewriter class for pattern-rewrite
Browse files Browse the repository at this point in the history
Prior to this commit, the pattern to be matched and the rewrite to be
performed were provided as separate arguments.  This commit introduces
a new class `ExprRewriter`, which contains both parts.

This abstraction will make it easier to combine multiple different
rewrite rules, applying them in a single pass.
  • Loading branch information
Lunderberg committed Jul 10, 2024
1 parent 486bfca commit 7488ea1
Show file tree
Hide file tree
Showing 11 changed files with 2,712 additions and 217 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relax/dpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@

from .pattern import *
from .context import *
from .rewrite import rewrite_call, rewrite_bindings
from .rewrite import rewrite_call, rewrite_bindings, ExprRewriter, PatternRewriter, OrRewriter
63 changes: 61 additions & 2 deletions python/tvm/relax/dpl/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,75 @@
# specific language governing permissions and limitations
# under the License.
"""APIs for pattern-based rewriting."""
from typing import Dict, Callable

from typing import Dict, Callable, Union
from .pattern import DFPattern
from .context import PatternContext

from tvm.ir import IRModule
from tvm.runtime import Object
from tvm._ffi import register_object
from ..expr import Expr, Function, Var
from . import _ffi as ffi


@register_object("relax.dpl.ExprRewriter")
class ExprRewriter(Object):
@staticmethod
def from_pattern(
pattern: DFPattern,
func: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
) -> "ExprRewriter":
return ffi.ExprRewriterFromPattern(
pattern,
func,
) # type: ignore

@staticmethod
def from_module(mod: IRModule) -> "ExprRewriter":
return ffi.ExprRewriterFromModule(mod) # type: ignore

def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]:
return ffi.ExprRewriterApply(self, obj)

def __or__(self, other: "ExprRewriter") -> "ExprRewriter":
return OrRewriter(self, other)


@register_object("relax.dpl.PatternRewriter")
class PatternRewriter(ExprRewriter):
def __init__(self, pattern, func):
self.__init_handle_by_constructor__(
ffi.PatternRewriter,
pattern,
func,
) # type: ignore


@register_object("relax.dpl.OrRewriter")
class OrRewriter(ExprRewriter):
def __init__(self, lhs, rhs):
self.__init_handle_by_constructor__(
ffi.OrRewriter,
lhs,
rhs,
) # type: ignore


@register_object("relax.dpl.TupleRewriter")
class TupleRewriter(ExprRewriter):
def __init__(self, patterns, func):
self.__init_handle_by_constructor__(
ffi.TupleRewriter,
patterns,
func,
) # type: ignore


def rewrite_call(
pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
pattern: DFPattern,
rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
func: Function,
) -> Function:
"""
Rewrite a function with the given pattern and the rewriter function.
Expand Down
48 changes: 46 additions & 2 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import builtins
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type

import tvm
from tvm import DataType, relax
from tvm.ir import PrimExpr, VDevice
from tvm.ir import PrimExpr, VDevice, IRModule
from tvm.relax import (
Call,
Expr,
Expand All @@ -35,6 +35,7 @@
VarBinding,
const,
)
from tvm.relax.dpl import ExprRewriter

############################### Operators ###############################
from tvm.relax.op import (
Expand Down Expand Up @@ -306,6 +307,48 @@ def func_ret_value(value: Expr) -> None:
return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member


def rewriter(rewriter_mod: Union[IRModule, Type]) -> ExprRewriter:
"""Define a pattern-rewrite rule
The IRModule must have two publicly-exposed functions, `pattern`
and `replacement`, where `pattern` and `replacement` have the same
function signature.
.. code-block:: python
@R.rewriter
class RewriteAddIntoMultiply:
@R.function
def pattern(A: R.Tensor):
B = A + A
return B
@R.function
def replacement(A: R.Tensor):
B = A * 2
return B
Parameters
----------
rewriter_mod: Union[IRModule, Type]
Either an IRModule that defines a rewrite pattern, or a
TVMScript class that can be parsed into an IRModule.
Returns
-------
rewriter: ExprRewriter
A rewriter object, which can be applied either to a Relax
function or to an entire IRModule.
"""
if not isinstance(rewriter_mod, IRModule):
rewriter_mod = tvm.script.ir_module(rewriter_mod)

return ExprRewriter.from_module(rewriter_mod)


############################# BindingBlock ##############################


Expand Down Expand Up @@ -765,6 +808,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"dequantize",
"repeat",
"reshape",
"rewriter",
"tensor_to_shape",
"shape_to_tensor",
"rocm",
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def parse(
The parsed TVMScript program.
"""
if extra_vars is None:
extra_vars = _default_globals()
extra_vars = {}
extra_vars = {**extra_vars, **_default_globals()}

ann = {}
if inspect.isfunction(program):
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/script/parser/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,29 @@ def is_defined_in_class(frames: List[FrameType], obj: Any) -> bool:
res : bool
The result if the object is defined in a class scope.
"""

def _is_tvmscript_class_annotator(line: str) -> bool:
"""Checks if the line contains a TVMScript annotator for a class
These match either `@I.ir_module` or `@R.rewriter`, or their
imported names `@ir_module` or `@rewriter`.
"""

return line.startswith("@") and ("ir_module" in line or "rewriter" in line)

if len(frames) > 2:
frame_info = frames[2]
code_context = frame_info.code_context
if code_context is None:
return False
line = code_context[0].strip()
if line.startswith("@") and "ir_module" in line:
if _is_tvmscript_class_annotator(line):
return True
if line.startswith("class"):
lineno = frame_info.lineno
if lineno >= 2:
source, _ = findsource(obj)
line = source[lineno - 2].strip()
if line.startswith("@") and "ir_module" in line:
if _is_tvmscript_class_annotator(line):
return True
return False
Loading

0 comments on commit 7488ea1

Please sign in to comment.