Skip to content

Commit

Permalink
[UnitTest][TIR] Support IRModule comparisons in CompareBeforeAfter
Browse files Browse the repository at this point in the history
A follow-up commit from apache#12264.
This allows the before/expected fixtures generated by
`tvm.testing.CompareBeforeAfter` to be `IRModule` instances as well as
`PrimFunc`.  This is intended to allow testing that requires comparing
more than one function (e.g. hoisting/fusing a PrimFunc).
  • Loading branch information
Lunderberg committed Sep 27, 2022
1 parent 7dbc68d commit 4768d8b
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 37 deletions.
103 changes: 67 additions & 36 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1856,13 +1856,11 @@ def __init_subclass__(cls):
if hasattr(cls, "expected"):
cls.expected = cls._normalize_expected(cls.expected)
if hasattr(cls, "transform"):
cls._transform_orig = cls.transform
cls.transform = cls._normalize_transform(cls.transform)

@classmethod
def _normalize_before(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func

def _normalize_ir_module(cls, func):
if isinstance(func, tvm.tir.PrimFunc):

def inner(self):
Expand All @@ -1875,6 +1873,22 @@ def inner(self):
# pylint: disable=unused-argument
return func(self)

elif inspect.isclass(func):

def inner(self):
# pylint: disable=unused-argument
func_dict = {}
for name, method in func.__dict__.items():
if name.startswith("_"):
pass
elif isinstance(method, tvm.ir.function.BaseFunc):
func_dict[name] = method
else:
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
prim_func = tvm.script.from_source(source_code)
func_dict[name] = prim_func
return tvm.IRModule(func_dict)

else:

def inner(self):
Expand All @@ -1884,50 +1898,61 @@ def inner(self):

return pytest.fixture(inner)

@classmethod
def _normalize_before(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func
else:
return cls._normalize_ir_module(func)

@classmethod
def _normalize_expected(cls, func):
if hasattr(func, "_pytestfixturefunction"):
return func

if isinstance(func, tvm.tir.PrimFunc) or (
inspect.isclass(func) and issubclass(func, Exception)
):
elif inspect.isclass(func) and issubclass(func, Exception):

def inner(self):
# pylint: disable=unused-argument
return func

elif cls._is_method(func):

def inner(self):
# pylint: disable=unused-argument
return func(self)
return pytext.fixture(inner)

else:

def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code)

return pytest.fixture(inner)
return cls._normalize_ir_module(func)

@classmethod
def _normalize_transform(cls, transform):
def apply(module_transform):
def inner(obj):
if isinstance(obj, tvm.IRModule):
return module_transform(obj)
elif isinstance(obj, tvm.tir.PrimFunc):
mod = tvm.IRModule({"main": obj})
mod = module_transform(mod)
return mod["main"]
else:
raise TypeError(f"Expected IRModule or PrimFunc, but received {type(obj)}")

return inner

if hasattr(transform, "_pytestfixturefunction"):
return transform

if isinstance(transform, tvm.ir.transform.Pass):
def inner(self, _transform_orig):
# pylint: disable=unused-argument
return apply(_transform_orig)

elif isinstance(transform, tvm.ir.transform.Pass):

def inner(self):
# pylint: disable=unused-argument
return transform
return apply(transform)

elif cls._is_method(transform):

def inner(self):
# pylint: disable=unused-argument
return transform(self)
return apply(transform(self))

else:

Expand All @@ -1945,42 +1970,48 @@ def _is_method(func):
def test_compare(self, before, expected, transform):
"""Unit test to compare the expected TIR PrimFunc to actual"""

before_mod = tvm.IRModule.from_expr(before)
def pprint(name, obj):
script = obj.script()
if isinstance(obj, tvm.IRModule):
return script.replace("class Module", f"class {name}")
else:
return script.replace("def func", f"def {name}")

if inspect.isclass(expected) and issubclass(expected, Exception):
with pytest.raises(expected):
after_mod = transform(before_mod)
after = transform(before)

# This portion through pytest.fail isn't strictly
# necessary, but gives a better error message that
# includes the before/after.
after = after_mod["main"]
script = tvm.IRModule({"after": after, "before": before}).script()
before_str = pprint("before", before)
after_str = pprint("after", after)

pytest.fail(
msg=(
f"Expected {expected.__name__} to be raised from transformation, "
f"instead received TIR\n:{script}"
f"instead received TIR\n:{before_str}\n{after_str}"
)
)

elif isinstance(expected, tvm.tir.PrimFunc):
after_mod = transform(before_mod)
after = after_mod["main"]
elif isinstance(expected, (tvm.tir.PrimFunc, tvm.ir.IRModule)):
after = transform(before)

try:
tvm.ir.assert_structural_equal(after, expected)
except ValueError as err:
script = tvm.IRModule(
{"expected": expected, "after": after, "before": before}
).script()
before_str = pprint("before", before)
after_str = pprint("after", after)
expected_str = pprint("expected", expected)
raise ValueError(
f"TIR after transformation did not match expected:\n{script}"
f"TIR after transformation did not match expected:\n"
f"{before_str}\n{after_str}\n{expected_str}"
) from err

else:
raise TypeError(
f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
f"to return either `Exception`, an `Exception` subclass, "
f"or an instance of `tvm.tir.PrimFunc`. "
f"Instead, received {type(exception)}."
f"Instead, received {type(expected)}."
)
49 changes: 48 additions & 1 deletion tests/python/unittest/test_tvm_testing_before_after.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import tvm
import tvm.testing
from tvm.script import tir as T
from tvm.script import tir as T, ir_module


class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
Expand Down Expand Up @@ -79,5 +79,52 @@ def func(A: T.Buffer[n, "float32"]):
expected = before


class TestBeforeAfterIRModule(BaseBeforeAfter):
"""The preferred form for writing TIR unit tests
All evaluation is done at test-time, with the minimal amount of
additional lines. The `@tvm.testing.fixture`, `@ir_module`, and
`@T.prim_func` annotations are handled by
`tvm.testing.CompareBeforeAfter`.
"""

class before:
def func_A(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0

def func_B(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
A[i] = 42

expected = before


class TestBeforeAfterIRModuleExplicitFixture(BaseBeforeAfter):
"""Like TestBeforeAfterIRModule, but with an explicit fixture
If the IRModule depends on additional fixtures, this form can be
used.
"""

@tvm.testing.fixture
def before(self):
@ir_module
class mod:
@T.prim_func
def func_A(A: T.Buffer[16, "float32"]):
for i in T.serial(16):
A[i] = 0.0

@T.prim_func
def func_B(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
A[i] = 42

return mod

expected = before


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 4768d8b

Please sign in to comment.