Skip to content

Commit

Permalink
addr cmts
Browse files Browse the repository at this point in the history
  • Loading branch information
shingjan committed Nov 18, 2021
1 parent ed9ba73 commit d705401
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
16 changes: 14 additions & 2 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,20 @@ def __init__(self, vtype):
def evaluate(self):
return tvm.ir.PrimType(self.type)

def __call__(self, shape, dtype, elem_offset):
pass
def __call__(
self,
shape,
dtype="float32",
data=None,
strides=None,
elem_offset=None,
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
span=None,
):
self.name = "match_buffer"


class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
Expand Down
34 changes: 26 additions & 8 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,53 @@
import sys

import pytest
import tvm
from tvm.script import tir as T

# match buffer - use kwargs
# match buffer - no syntax sugar
@T.prim_func
def elementwise(
def elementwise_handle(
a: T.handle,
b: T.handle,
) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
B = T.match_buffer(b, (128, 128, 128, 128))
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0


# match buffer - use buffer with kwargs
@T.prim_func
def elementwise_buffer_kwargs(
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=1),
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=2),
) -> None:
# A = T.match_buffer(a, (128, 128, 128, 128))
# B = T.match_buffer(b, (128, 128, 128, 128))
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0


# match buffer - no kwargs
# match buffer - use buffer without kwargs
@T.prim_func
def elementwise(
def elementwise_buffer_no_kwargs(
a: T.Buffer[(128, 128, 128, 128), "float32"],
b: T.Buffer[(128, 128, 128, 128), "float32"],
) -> None:
# A = T.match_buffer(a, (128, 128, 128, 128))
# B = T.match_buffer(b, (128, 128, 128, 128))
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0


def test_match_buffer_syntax_sugar():
# with kwargs
tvm.ir.assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
# without kwargs
tvm.ir.assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit d705401

Please sign in to comment.