Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Add syntax sugar for T.handle and T.match_buffer #9492

Merged
merged 21 commits into from
Dec 8, 2021
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
synr==0.5.0 \
synr==0.6.0 \
six \
tornado
2 changes: 1 addition & 1 deletion python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
("synr", "==0.5.0"),
("synr", "==0.6.0"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
Expand Down
61 changes: 52 additions & 9 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tvm._ffi.base import TVMError
from tvm.ir import GlobalVar
from tvm.ir.function import BaseFunc
from tvm.tir import buffer
from tvm.tir.function import PrimFunc
from . import _ffi_api
from . import tir
Expand Down Expand Up @@ -154,10 +155,10 @@ class TVMScriptParser(Transformer):
ast.BuiltinOp.Not: tvm.tir.Not,
}

def __init__(self, base_lienno, tir_namespace):
def __init__(self, base_lineno, tir_namespace):
self.context = None

self.base_lineno = base_lienno
self.base_lineno = base_lineno
self.current_lineno = 0
self.current_col_offset = 0
self.tir_namespace = tir_namespace
Expand Down Expand Up @@ -249,20 +250,23 @@ def parse_arg_list(self, func, node_call):
func : Function
The function that provides the signature

node_call: ast.Call
node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
The AST call node that calls into the function.

Returns
-------
arg_list : list
The parsed positional argument.
"""
assert isinstance(node_call, ast.Call)
assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
# collect arguments
args = [self.transform(arg) for arg in node_call.params]
kw_args = {
self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
}
if isinstance(node_call, ast.TypeApply):
kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
else:
kw_args = {
self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
}
# get the name and parameter list of func
if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
func_name, param_list = func.signature()
Expand All @@ -276,6 +280,7 @@ def parse_arg_list(self, func, node_call):
reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
pos_only, kwargs, varargs = param_list
internal_args = list()

for i, arg_name in enumerate(pos_only):
internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
for i, arg_info in enumerate(kwargs):
Expand Down Expand Up @@ -439,8 +444,22 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:

# add parameters of function
for arg in node.params:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
self.context.update_symbol(arg.name, arg_var, node)
# Note that this case is for T.match_buffer syntax sugar
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
result = self.handle_match_buffer_type(arg.ty, arg.name)
if not isinstance(result, buffer.Buffer):
self.report_error(
"The result type of evaluating TypeCall and TypeApply stmt"
f" is wrong: {type(result)}. It should be a Buffer",
node.span,
)
arg_name_with_handle = arg.name + "_handle"
arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
self.context.func_buffer_map[arg_var] = result
self.context.update_symbol(arg.name, result, node)
else:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
self.context.update_symbol(arg.name, arg_var, node)
self.context.func_params.append(arg_var)

if not check_decorator(node.decorators):
Expand Down Expand Up @@ -1110,6 +1129,30 @@ def transform_TypeConstant(self, node):
"""
return node.value

def transform_TypeTuple(self, node):
"""Tuple value visitor for types.

Mostly used in `transform_TypeCall` and `transform_TypeApply`.
"""
return [self.transform(value) for value in node.values]

def handle_match_buffer_type(self, node, buffer_name):
"""special function to handle syntax sugar for match buffer.

This method is for buffer declarations in the function parameters.
"""
func = self.transform(node.func_name)
assert isinstance(func, SpecialStmt)

# parse args and kwargs for TypeCall and TypeApply
arg_list = self.parse_arg_list(func, node)
# Note that the third element in arg_list would always be the 'name'
# TODO: This index is hardcoded as a workaround. Better to make it programmatic
if arg_list[2] is None:
arg_list[2] = buffer_name
buf = func.handle(node, self.context, arg_list, node.func_name.span)
return buf

def transform_Return(self, node):
self.report_error(
"TVM script does not support return statements. Instead the last statement in any "
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@

# Type system
from .ty import int8, int16, int32, int64, float16, float32, float64
from .ty import boolean, handle, Ptr, Tuple
from .ty import boolean, handle, Ptr, Tuple, Buffer

from .prim_func import prim_func
73 changes: 73 additions & 0 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""
# pylint: disable=invalid-name
import tvm
from .special_stmt import SpecialStmt, convert_to_int


class TypeGeneric: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -67,6 +68,75 @@ def __getitem__(self, vtypes):
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))


class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods, abstract-method
"""TVM script typing class for uniform Type objects"""

def __init__(self, vtype):
def match_buffer_syntax_sugar(
shape,
dtype: str = "float32",
name: str = None,
data=None,
strides=None,
elem_offset=None,
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
span=None,
):
if strides is None:
strides = []
align = convert_to_int(align, "align", self.context.report_error, self.node.span)
offset_factor = convert_to_int(
offset_factor, "offset_factor", self.context.report_error, self.node.span
)
buffer = tvm.tir.decl_buffer(
shape,
dtype,
name,
data,
strides,
elem_offset,
scope,
align,
offset_factor,
buffer_type,
span=span,
)
return buffer

self.type = vtype
super().__init__(match_buffer_syntax_sugar, def_symbol=True)

def __call__(
self,
shape,
dtype="float32",
*,
name: str = None,
data=None,
strides=None,
elem_offset=None,
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
span=None,
):
"""
This function is for Buffer(...) syntax sugar.
"""
pass # pylint: disable=unnecessary-pass

def __getitem__(self, args):
"""
This function is for Buffer[...] syntax sugar
Note that args is the list of all arguments
"""
pass # pylint: disable=unnecessary-pass


int8 = ConcreteType("int8")
int16 = ConcreteType("int16")
int32 = ConcreteType("int32")
Expand All @@ -78,3 +148,6 @@ def __getitem__(self, vtypes):
handle = ConcreteType("handle")
Ptr = GenericPtrType()
Tuple = GenericTupleType()
# we don't have 'buffer' type on the cpp side
# thus 'handle' is used here for convenience's sake
Buffer = GenericBufferType("handle")
45 changes: 45 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,50 @@ def test_syntax_sugar_fail():
check_error(loop_syntax_sugar_fail, 3)


# match buffer - use kwargs
@T.prim_func
shingjan marked this conversation as resolved.
Show resolved Hide resolved
def elementwise_handle(
a: T.handle,
b: T.handle,
) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
B = T.match_buffer(b, (128, 128, 128, 128))
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0


# match buffer - use buffer with kwargs
@T.prim_func
def elementwise_buffer_kwargs(
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
) -> None:
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0


# match buffer - use buffer without kwargs
@T.prim_func
def elementwise_buffer_no_kwargs(
a: T.Buffer[(128, 128, 128, 128), "float32"],
b: T.Buffer[(128, 128, 128, 128), "float32"],
) -> None:
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0


def test_match_buffer_syntax_sugar():
# with kwargs
assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
# without kwargs
assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
2 changes: 1 addition & 1 deletion tests/scripts/task_ci_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ set -o pipefail
#
echo "Addtiional setup in" ${CI_IMAGE_NAME}

python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0
python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.6.0

# Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in
# Jenkinsfile. We expect config.cmake to be present from pack_lib().
Expand Down