Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Add syntax sugar for T.handle and T.match_buffer #9492

Merged
merged 21 commits into from
Dec 8, 2021
Prev Previous commit
Next Next commit
fix test
  • Loading branch information
shingjan committed Dec 8, 2021
commit 7c9b3edd468d0c45014de71e7ce0cdbcc9952962
4 changes: 3 additions & 1 deletion python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __call__(
span=span,
)

def __class_getitem__(shape, dtype: str):
def __getitem__(shape, dtype: str):
return tvm.tir.decl_buffer(shape=shape, dtype=dtype)

def evaluate(self):
Expand All @@ -159,4 +159,6 @@ def evaluate(self):
handle = ConcreteType("handle")
Ptr = GenericPtrType()
Tuple = GenericTupleType()
# we don't have 'buffer' type on the cpp side
# thus 'handle' is used here for convenience's sake
Buffer = GenericBufferType("handle")
6 changes: 3 additions & 3 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ def elementwise_buffer_kwargs(


# match buffer - use buffer without kwargs
# This function is commented out as it is supported yet
shingjan marked this conversation as resolved.
Show resolved Hide resolved
# @T.prim_func
# def elementwise_buffer_no_kwargs(
# a: T.Buffer[(128, 128, 128, 128), "float32"],
# b: T.Buffer[(128, 128, 128, 128), "float32"],
# a: T.Buffer((128, 128, 128, 128), "float32", "a"),
# b: T.Buffer((128, 128, 128, 128), "float32", "b"),
# ) -> None:
# for i, j, k, l in T.grid(128, 128, 128, 128):
# with T.block("B"):
Expand All @@ -144,7 +145,6 @@ def test_match_buffer_syntax_sugar():
assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
# without kwargs
assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
# without kwargs


if __name__ == "__main__":
Expand Down