Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Testing] Allow Capitalized name in CompareBeforeAfter #15568

Merged
merged 3 commits into from
Aug 17, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1890,7 +1890,7 @@ class CompareBeforeAfter:
input, apply a transformation, then either compare against an
expected output or assert that the transformation raised an error.
A test should subclass CompareBeforeAfter, defining class members
`before`, `transform`, and `expected`. CompareBeforeAfter will
`before` / `Before`, `transform`, and `expected` / `Expected`. CompareBeforeAfter will
then use these members to define a test method and test fixture.

`transform` may be one of the following.
Expand All @@ -1901,7 +1901,7 @@ class CompareBeforeAfter:

- A pytest fixture that returns a `tvm.ir.transform.Pass`

`before` may be any one of the following.
`before` / `Before` may be any one of the following.

- An instance of `tvm.tir.PrimFunc`. This is allowed, but is not
the preferred method, as any errors in constructing the
Expand All @@ -1916,13 +1916,13 @@ class CompareBeforeAfter:

- A pytest fixture that returns a `tvm.tir.PrimFunc`

`expected` may be any one of the following. The type of
`expected` defines the test being performed. If `expected`
`expected` / `Expected` may be any one of the following. The type of
`expected` / `Expected` defines the test being performed. If `expected`
provides a `tvm.tir.PrimFunc`, the result of the transformation
must match `expected`. If `expected` is an exception, then the
transformation must raise that exception type.

- Any option supported for `before`.
- Any option supported for `before` / `Before`.

- The `Exception` class object, or a class object that inherits
from `Exception`.
Expand Down Expand Up @@ -1953,10 +1953,13 @@ def expected(A: T.Buffer(1, "int32")):
"""

def __init_subclass__(cls):
if hasattr(cls, "before"):
cls.before = cls._normalize_before(cls.before)
if hasattr(cls, "expected"):
cls.expected = cls._normalize_expected(cls.expected)
for name in ["before", "Before"]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add an assert that at most one of the two options are present?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if should be that strong of an assert, as that would prevent defining an intermediate test that defines transform and expected, but doesn't define before. This could be useful for defining several tests that should all normalize to produce the same result. This would also break existing cases where the transform is defined, then used across several tests in a file.

What if instead, we were to assert that no duplicates names exist?

assert len([getattr(cls,name) for name in ['before','Before'] if hasattr(cls,name)]) <= 1
assert len([getattr(cls,name) for name in ['expected','Expected'] if hasattr(cls,name)]) <= 1

if hasattr(cls, name):
cls.before = cls._normalize_before(getattr(cls, name))
break
for name in ["expected", "Expected"]:
if hasattr(cls, name):
cls.expected = cls._normalize_expected(getattr(cls, name))
if hasattr(cls, "transform"):
cls.transform = cls._normalize_transform(cls.transform)

Expand Down