Skip to content

Commit

Permalink
[TVMScript] Fixing T.buffer with typed positional arguments other tha…
Browse files Browse the repository at this point in the history
…n int32 (apache#10892)

* workaround for T.buffer with typed positional arguments

* address comments

* fix linting
  • Loading branch information
Yuanjing Shi authored and Lucien0 committed Apr 19, 2022
1 parent 879c138 commit e2823ca
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __init__(self, vtype):
else:
self.type = tvm.ir.PrimType(vtype)

def __call__(self, *args): # pylint: disable=arguments-differ
pass

def evaluate(self):
return self.type

Expand Down
37 changes: 37 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,42 @@ def test_dynamic_shape_gemm():
assert_structural_equal(gemm_dyn_shape, gemm_dyn_shape_roundtrip)


@T.prim_func
def match_buffer_int64(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32")
B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32")
C = T.match_buffer(c, (T.int64(128), T.int64(128)), dtype="float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(T.int64(128), T.int64(128)):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0


@T.prim_func
def match_buffer_int64_after_roundtrip(
A: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
C: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
) -> None:
B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(T.int64(128), T.int64(128)):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0


def test_match_buffer_int64():
original = match_buffer_int64
after_roundtrip = match_buffer_int64_after_roundtrip
assert_structural_equal(original, after_roundtrip, True)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit e2823ca

Please sign in to comment.