Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Enable T.macro decorateing class method #17435

Merged
merged 2 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def _find_parser_def(self):
def get_macro_def(self):
ast_module = self.source.as_ast()
for decl in ast_module.body:
if isinstance(decl, doc.FunctionDef) and decl.name == self.__name__:
if isinstance(decl, doc.FunctionDef) and decl.name == self.func.__name__:
return decl
raise RuntimeError(f"cannot find macro definition for {self.__name__}")
raise RuntimeError(f"cannot find macro definition for {self.func.__name__}")

def __call__(self, *args, **kwargs):
param_binding = inspect.signature(self.func).bind(*args, **kwargs)
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,11 @@ def macro(*args, hygienic: bool = True) -> _Callable:
def _decorator(func: _Callable) -> ScriptMacro:
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = RelaxMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
return obj

def wrapper(*args, **kwargs):
return obj(*args, **kwargs)

return wrapper

if len(args) == 0:
return _decorator
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,11 @@ def use2(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32")) -> None:
def _decorator(func: Callable) -> TIRMacro:
source, closure_vars = scan_macro(func, utils.inspect_function_capture(func))
obj = TIRMacro(source, closure_vars, func, hygienic)
obj.__name__ = func.__name__
return obj

def wrapper(*args, **kwargs):
return obj(*args, **kwargs)

return wrapper

if len(args) == 0:
return _decorator
Expand Down
42 changes: 38 additions & 4 deletions tests/python/tvmscript/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def evaluate0():
def func1():
T.evaluate(0)

assert func1.hygienic

@T.prim_func(private=True)
def use1():
func1()
Expand All @@ -129,8 +127,6 @@ def use1():
def func2():
T.evaluate(0)

assert func2.hygienic

@T.prim_func(private=True)
def use2():
func2()
Expand Down Expand Up @@ -212,6 +208,44 @@ def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32"
tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic)


def test_tir_macro_in_class():
class Object:
def __init__(self, x: T.Buffer):
self.local_x = T.alloc_buffer(x.shape, x.dtype)

@T.macro
def load(self, x: T.Buffer):
N, M = T.meta_var(self.local_x.shape)
for i, j in T.grid(N, M):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i, j])
self.local_x[vi, vj] = x[vi, vj]

@T.prim_func(private=True)
def func_w_macro(a: T.handle):
A = T.match_buffer(a, [128, 128])
o1 = T.meta_var(Object(A))
o1.load(A)
o2 = T.meta_var(Object(A))
o2.load(o1.local_x)

@T.prim_func(private=True)
def func_no_macro(a: T.handle):
A = T.match_buffer(a, [128, 128])
local_a = T.alloc_buffer([128, 128])
for i, j in T.grid(128, 128):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i, j])
local_a[vi, vj] = A[vi, vj]
local_b = T.alloc_buffer([128, 128])
for i, j in T.grid(128, 128):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i, j])
local_b[vi, vj] = local_a[vi, vj]

tvm.ir.assert_structural_equal(func_no_macro, func_w_macro)


def test_tir_starred_expression():
dims = (128, 128)

Expand Down
Loading