From b3b32871b083540154f56fea48ca82c55cbba820 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sun, 26 Sep 2021 04:29:42 +0000 Subject: [PATCH] namespace Co-authored-by: Junru Shao Co-authored-by: Zihao Ye Co-authored-by: Tristan Konolige --- .../install/ubuntu_install_python_package.sh | 2 +- docker/install/ubuntu_install_sphinx.sh | 2 +- include/tvm/tir/function.h | 20 +- include/tvm/tir/stmt.h | 20 +- include/tvm/tir/transform.h | 8 +- python/gen_requirements.py | 2 +- python/tvm/ir/module.py | 15 + python/tvm/script/__init__.py | 4 +- python/tvm/script/context_maintainer.py | 48 +- python/tvm/script/parser.py | 243 +- python/tvm/script/registry.py | 2 +- python/tvm/script/tir/__init__.py | 23 + python/tvm/script/{ => tir}/intrin.py | 4 +- python/tvm/script/{ => tir}/node.py | 0 python/tvm/script/tir/prim_func.py | 45 + python/tvm/script/{ => tir}/scope_handler.py | 68 +- python/tvm/script/{ => tir}/special_stmt.py | 32 +- python/tvm/script/{ => tir}/ty.py | 2 +- python/tvm/script/tir/utils.py | 54 + python/tvm/script/utils.py | 37 +- python/tvm/te/operation.py | 16 +- python/tvm/tir/function.py | 35 +- python/tvm/tir/schedule/schedule.py | 496 ++-- python/tvm/tir/transform/transform.py | 8 +- src/printer/text_printer.h | 2 +- src/printer/tir_text_printer.cc | 4 +- src/printer/tvmscript_printer.cc | 161 +- src/tir/ir/specialize.cc | 4 +- src/tir/schedule/primitive/compute_inline.cc | 6 +- .../test_ethosu/test_encode_constants.py | 153 +- .../test_ethosu/test_replace_conv2d.py | 239 +- .../contrib/test_ethosu/test_replace_copy.py | 30 +- .../contrib/test_ethosu/test_vela_api.py | 83 +- tests/python/integration/test_lower.py | 182 +- .../unittest/test_aot_legalize_packed_call.py | 53 +- tests/python/unittest/test_lower_build.py | 76 +- .../unittest/test_meta_schedule_arg_info.py | 20 +- .../unittest/test_meta_schedule_builder.py | 83 +- .../unittest/test_meta_schedule_database.py | 55 +- .../unittest/test_meta_schedule_runner.py | 87 +- 
.../test_meta_schedule_search_strategy.py | 25 +- .../test_meta_schedule_space_generator.py | 26 +- .../test_meta_schedule_tune_context.py | 22 +- .../unittest/test_te_create_primfunc.py | 120 +- .../test_tir_analysis_calculate_workspace.py | 116 +- ...t_tir_analysis_detect_buffer_access_lca.py | 98 +- ...st_tir_analysis_get_block_access_region.py | 111 +- tests/python/unittest/test_tir_intrin.py | 37 +- .../unittest/test_tir_lower_match_buffer.py | 393 ++-- .../unittest/test_tir_schedule_block_scope.py | 50 +- .../test_tir_schedule_cache_read_write.py | 516 ++--- .../unittest/test_tir_schedule_compute_at.py | 960 ++++---- .../test_tir_schedule_compute_inline.py | 272 +-- .../unittest/test_tir_schedule_error.py | 20 +- .../unittest/test_tir_schedule_for_kind.py | 302 +-- .../unittest/test_tir_schedule_reduction.py | 580 ++--- .../unittest/test_tir_schedule_reorder.py | 266 +-- .../unittest/test_tir_schedule_sampling.py | 12 +- .../unittest/test_tir_schedule_split_fuse.py | 472 ++-- .../unittest/test_tir_schedule_state.py | 72 +- .../test_tir_schedule_state_cached_flags.py | 314 +-- .../test_tir_schedule_storage_align.py | 152 +- .../unittest/test_tir_schedule_trace.py | 26 +- .../unittest/test_tir_schedule_utilities.py | 20 +- tests/python/unittest/test_tir_specialize.py | 183 +- ...est_tir_transform_compact_buffer_region.py | 488 ++-- ..._tir_transform_convert_blocks_to_opaque.py | 58 +- .../test_tir_transform_flatten_buffer.py | 250 +- .../test_tir_transform_loop_partition.py | 23 +- .../test_tir_transform_lower_init_block.py | 80 +- ..._plan_update_buffer_allocation_location.py | 206 +- .../test_tir_transform_storage_flatten.py | 10 +- ...test_tir_transform_unify_thread_binding.py | 222 +- .../unittest/test_tvmscript_complete.py | 258 +-- .../unittest/test_tvmscript_error_report.py | 261 ++- tests/python/unittest/test_tvmscript_ops.py | 54 +- .../unittest/test_tvmscript_roundtrip.py | 2034 ++++++++--------- tests/python/unittest/test_tvmscript_spans.py | 31 +- 
tests/scripts/task_ci_setup.sh | 2 +- 79 files changed, 5836 insertions(+), 5730 deletions(-) create mode 100644 python/tvm/script/tir/__init__.py rename python/tvm/script/{ => tir}/intrin.py (98%) rename python/tvm/script/{ => tir}/node.py (100%) create mode 100644 python/tvm/script/tir/prim_func.py rename python/tvm/script/{ => tir}/scope_handler.py (91%) rename python/tvm/script/{ => tir}/special_stmt.py (95%) rename python/tvm/script/{ => tir}/ty.py (98%) create mode 100644 python/tvm/script/tir/utils.py diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index eff86a950b90d..d1fa340ac37d2 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -36,6 +36,6 @@ pip3 install \ pytest-xdist \ requests \ scipy \ - synr==0.4.0 \ + synr==0.4.1 \ six \ tornado diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index 12208bbe66436..66720d4118323 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -29,5 +29,5 @@ pip3 install \ matplotlib \ sphinx \ sphinx_autodoc_annotation \ - sphinx-gallery==0.4.0 \ + sphinx-gallery==0.4.1 \ sphinx_rtd_theme diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 55f4fc62649c9..23057f7140e4b 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -195,12 +195,12 @@ class LinkedParam : public ObjectRef { * \note We can define a Meta TIR function with symbolic shape: * * \code - * @tvm.script.tir - * def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: - * A = tir.match_buffer(a, (m, n), "float32") - * B = tir.match_buffer(b, (m, n), "float32") + * @T.prim_func + * def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: + * A = T.match_buffer(a, (m, n), "float32") + * B = T.match_buffer(b, (m, n), "float32") * - * with tir.block([m, n], "") as [vi, 
vj]: + * with T.block([m, n], "") as [vi, vj]: * B[vi, vj] = A[vi, vj] * \endcode * @@ -214,12 +214,12 @@ class LinkedParam : public ObjectRef { * \endcode * * \code {.language-id} - * @tvm.script.tir - * def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: - * A = tir.match_buffer(a, (16, 16), "float32") - * B = tir.match_buffer(b, (16, 16), "float32") + * @T.prim_func + * def mem_copy_16_16(a: T.handle, b: T.handle) -> None: + * A = T.match_buffer(a, (16, 16), "float32") + * B = T.match_buffer(b, (16, 16), "float32") * - * with tir.block([16, 16], "") as [vi, vj]: + * with T.block([16, 16], "") as [vi, vj]: * B[vi, vj] = A[vi, vj] * \endcode */ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 2ae2877b2f92d..94ba853c493ac 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1068,17 +1068,17 @@ class MatchBufferRegion : public ObjectRef { * \note Block's body is parameterized by iter vars. * \code * - * with tir.block([extent0, extent1, ...], name) as [v0, v1, ...]: - * tir.bind(v0, value0) - * tir.bind(v1, value1) + * with T.block([extent0, extent1, ...], name) as [v0, v1, ...]: + * T.bind(v0, value0) + * T.bind(v1, value1) * ... 
- * tir.reads([buffer0[start:end, ...], ...]) - * tir.writes([buffer1[start:end, ...], ...]) - * tir.where(predicate) - * buffer2 = tir.alloc_buffer(shape, dtype) - * buffer3 = tir.match_buffer(source_buffer[start:end, ...]) - * tir.attr({attr_key: attr_value, ...}) - * with tir.init(): + * T.reads([buffer0[start:end, ...], ...]) + * T.writes([buffer1[start:end, ...], ...]) + * T.where(predicate) + * buffer2 = T.alloc_buffer(shape, dtype) + * buffer3 = T.match_buffer(source_buffer[start:end, ...]) + * T.attr({attr_key: attr_value, ...}) + * with T.init(): * // init body * // body * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index b5998874f7e35..e94b966bc0fc0 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -388,8 +388,8 @@ TVM_DLL Pass ConvertBlocksToOpaque(); * \code * * for i in range(0, 16): - * with tir.block([]): - * B = tir.alloc_buffer(16, 16) + * with T.block([]): + * B = T.alloc_buffer(16, 16) * for j in range(0, 16): * B[i, j] = A[i, j] + 1 * for j in range(0, 16): @@ -404,8 +404,8 @@ TVM_DLL Pass ConvertBlocksToOpaque(); * \code * * for i in range(0, 16): - * with tir.block([]): - * B = tir.alloc_buffer(1, 16) + * with T.block([]): + * B = T.alloc_buffer(1, 16) * for j in range(0, 16): * B[0, j] = A[i, j] + 1 * for j in range(0, 16): diff --git a/python/gen_requirements.py b/python/gen_requirements.py index d6dd094f6a5b7..fee07388f234a 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -249,7 +249,7 @@ ("sphinx_autodoc_annotation", None), ("sphinx_gallery", None), ("sphinx_rtd_theme", None), - ("synr", "==0.4.0"), + ("synr", "==0.4.1"), ("tensorflow", None), ("tensorflow-estimator", None), ("tflite", None), diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index f8b6ff2953391..55dc847f03925 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -255,3 +255,18 @@ def __str__(self): def __repr__(self): return self.astext() + + def script(self, 
show_meta=False) -> str: + """Print IRModule into TVMScript + + Parameters + ---------- + show_meta : bool + Whether show meta + + Returns + ------- + script : str + The TVM Script of the IRModule + """ + return tvm._ffi.get_global_func("script.AsTVMScript")(self, show_meta) # type: ignore diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 4cf7828290a72..555659d0c55ed 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -16,4 +16,6 @@ # under the License. """TVM Script APIs of TVM Python Package, aimed to support TIR""" -from .parser import from_source, create_module, asscript, tir, module +from . import tir + +from .parser import ir_module, from_source diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 44c92b792f122..75566cf6e2c5e 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -24,7 +24,7 @@ from tvm.ir import Span from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion from tvm.runtime import Object -from .node import BufferSlice +from .tir.node import BufferSlice class BlockInfo: @@ -34,34 +34,34 @@ class BlockInfo: ---------- .. 
code-block:: python - @tvm.script.tir - def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (16, 16), "float32") - C = tir.match_buffer(a, (16, 16), "float32") + @T.prim_func + def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") - for i, j, k in tir.grid(16, 16, 16): - with tir.block([16, 16, tir.reduce_axis(16)], "matmul") as [vi, vj, vk]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k} + for i, j, k in T.grid(16, 16, 16): + with T.block([16, 16, T.reduce_axis(16)], "matmul") as [vi, vj, vk]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) # iter_bindings = {vi: i, vj: j, vk: k} - tir.where(True) # predicate of the block_realize + T.where(True) # predicate of the block_realize - tir.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block - tir.writes(C[0: 16, 0: 16]) # writes region of the block - tir.block_attr({"attr_key": "attr_value"}) # block annotations + T.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block + T.writes(C[0: 16, 0: 16]) # writes region of the block + T.block_attr({"attr_key": "attr_value"}) # block annotations # alloc_buffers inside the block - CC = tir.alloc_buffer((1, 1), dtype="float32") + CC = T.alloc_buffer((1, 1), dtype="float32") # match_buffers of the block, # which bind a sub-region of source buffer into a new buffer - D = tir.match_buffer(C[vi, vj], ()) + D = T.match_buffer(C[vi, vj], ()) # init part of the block, executed when all reduce axes are the beginning value - with tir.init(): - C[vi, vj] = tir.float32(0) + with T.init(): + C[vi, vj] = T.float32(0) # block body CC[0, 0] = A[vi, vk] * B[vj, vk] @@ -69,20 +69,20 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None: """ alloc_buffers: List[Buffer] = [] - 
"""List[Buffer]: list of tir.alloc_buffer statements in the block signature""" + """List[Buffer]: list of T.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] - """List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature""" + """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature""" iter_bindings: Mapping[Var, PrimExpr] = {} """Mapping[Var, PrimExpr]: map of block iter var to its values""" reads: Optional[List[BufferSlice]] = None """Optional[List[BufferSlice]]: - list of tir.reads statements in the block signature, None for not-visited""" + list of T.reads statements in the block signature, None for not-visited""" writes: Optional[List[BufferSlice]] = None """Optional[List[BufferSlice]]: - list of tir.writes statements in the block signature, None for not-visited""" + list of T.writes statements in the block signature, None for not-visited""" annotations: Optional[Mapping[str, Object]] = None """Optional[Mapping[str, Object]]: - list of tir.block_attr statements in the block signature, None for not-visited""" + list of T.block_attr statements in the block signature, None for not-visited""" predicate: Optional[PrimExpr] = None """Optional[PrimExpr]: block realize predicate, None for not-visited""" init: Optional[Stmt] = None diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 51ee0aed982c7..53aa11f2548ba 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -24,25 +24,28 @@ import json import operator import inspect -from typing import Union +from typing import Any, Callable, Dict, List, Optional, Union from synr import ast, Transformer, to_ast import tvm from tvm import IRModule from tvm._ffi.base import TVMError from tvm.ir import GlobalVar +from tvm.ir.function import BaseFunc +from tvm.tir.function import PrimFunc +from . import _ffi_api +from . import tir -from . 
import context_maintainer, ty -from .context_maintainer import BlockInfo +from .context_maintainer import BlockInfo, ContextMaintainer from .meta_unparser import MetaUnparser from .registry import Registry -from .intrin import Intrin -from .special_stmt import SpecialStmt -from .scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler -from . import _ffi_api from .diagnostics import TVMDiagnosticCtx from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting -from .node import Slice, BufferSlice + +from .tir.intrin import Intrin +from .tir.node import Slice, BufferSlice +from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler +from .tir.special_stmt import SpecialStmt class CallArgumentReader(object): @@ -151,17 +154,18 @@ class TVMScriptParser(Transformer): ast.BuiltinOp.Not: tvm.tir.Not, } - def __init__(self, base_lienno): + def __init__(self, base_lienno, tir_namespace): self.context = None self.base_lineno = base_lienno self.current_lineno = 0 self.current_col_offset = 0 + self.tir_namespace = tir_namespace self.meta = None def init_function_parsing_env(self): """Initialize function parsing environment""" - self.context = context_maintainer.ContextMaintainer(self.report_error) # scope emitter + self.context = ContextMaintainer(self.report_error) # scope emitter def init_meta(self, meta_dict): if meta_dict is not None: @@ -185,6 +189,10 @@ def transform(self, node): return transform_res + def match_tir_namespace(self, identifier: str) -> bool: + """Check if the namespace is equal to tvm.script.tir""" + return identifier in self.tir_namespace + def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]): """Report an error occuring at a location. @@ -324,13 +332,14 @@ def A(...): import tvm - @tvm.script.tir + @tvm.script.ir_module class MyMod(): - def A(...): - ... - - def B(...): - ... + @T.prim_func + def A(...): + ... + @T.prim_func + def B(...): + ... __tvm_meta__ = ... 
@@ -360,11 +369,11 @@ def transform_Class(self, node): ------- .. code-block:: python - @tvm.script.tir + @tvm.script.ir_module class MyClass: __tvm_meta__ = {} def A(): - tir.evaluate(0) + T.evaluate(0) """ if len(node.assignments) == 1: if not ( @@ -384,7 +393,7 @@ def A(): ast.Span.union([x.span for x in node.assignments[1:]]), ) - return create_module( + return IRModule( {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()} ) @@ -405,13 +414,25 @@ def transform_Function(self, node): ------- .. code-block:: python - @tvm.script.tir - def my_function(x: ty.handle): # 1. Argument types - tir.func_attr({"global_symbol": "mmult"}) # 2. Function attributes + @T.prim_func + def my_function(x: T.handle): # 1. Argument types + T.func_attr({"global_symbol": "mmult"}) # 2. Function attributes X_1 = tir.buffer_bind(x, [1024, 1024]) # 3. Buffer binding - tir.evaluate(0) # 4. This function returns 0 + T.evaluate(0) # 4. This function returns 0 """ + def check_decorator(decorators: List[ast.Expr]) -> bool: + """Check whether the only decorator is `T.prim_func`""" + if len(decorators) != 1: + return False + d: ast.Expr = decorators[0] + return ( + isinstance(d, ast.Attr) + and isinstance(d.object, ast.Var) + and self.match_tir_namespace(d.object.id.name) + and d.field.name == "prim_func" + ) + self.init_function_parsing_env() self.context.enter_scope(nodes=node.body.stmts) @@ -421,6 +442,12 @@ def my_function(x: ty.handle): # 1. Argument types self.context.update_symbol(arg.name, arg_var, node) self.context.func_params.append(arg_var) + if not check_decorator(node.decorators): + self.report_error( + "All functions should be decorated by `T.prim_func`", + node.span, + ) + # New Scope : Implicit root block # Each function contains an implicit root block in TensorIR, # so here we need a block scope for it. Please note that `enter_block_scope` @@ -437,10 +464,11 @@ def my_function(x: ty.handle): # 1. 
Argument types # return a tir.PrimFunc dict_attr = self.context.func_dict_attr + ret_type = self.parse_type(node.ret_type, node) if node.ret_type is not None else None func = tvm.tir.PrimFunc( self.context.func_params, body, - ret_type=self.parse_type(node.ret_type, node), + ret_type, buffer_map=self.context.func_buffer_map, attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None, span=tvm_span_from_synr(node.span), @@ -468,13 +496,13 @@ def transform_Assign(self, node): By now 3 patterns of Assign is supported: 1. special stmts with return value - 1.1 Buffer = tir.match_buffer()/tir.buffer_decl() - 1.2 Var = tir.var() - 1.3 Var = tir.env_thread() + 1.1 Buffer = T.match_buffer()/T.buffer_decl() + 1.2 Var = T.var() + 1.3 Var = T.env_thread() 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr 3. (Store) Var[PrimExpr] = PrimExpr 4. with scope handlers with concise scoping and var def - 4.1 var = tir.allocate() + 4.1 var = T.allocate() """ if isinstance(node.rhs, ast.Call): @@ -550,7 +578,7 @@ def transform_SubscriptAssign(self, node): def transform_Assert(self, node): """Assert visitor - Pattern corresponds to concise mode of :code:`with tir.Assert()`. + Pattern corresponds to concise mode of :code:`with T.Assert()`. """ condition = self.transform(node.condition) @@ -568,8 +596,8 @@ def transform_For(self, node): For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) By now 1 pattern of For is supported: 1. for scope handler - for name in tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll()/tir.range()/ - tir.grid()/tir.thread_binding() + for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/tir.range()/ + T.grid()/T.thread_binding() """ if not isinstance(node.rhs, ast.Call): @@ -614,9 +642,9 @@ def transform_With(self, node): withitem = (expr context_expr, expr? optional_vars) By now 2 patterns of With is supported: 1. 
with scope handler with symbol def - with tir.block(*axes)/tir.allocate() as targets: + with T.block(*axes)/T.allocate() as targets: 2. with scope handler without symbol def - with tir.let()/tir.Assert()/tir.attr()/tir.realize() + with T.let()/T.Assert()/T.attr()/T.realize() """ if not isinstance(node.rhs, ast.Call): @@ -736,10 +764,10 @@ def transform_UnassignedCall(self, node): -------- .. code-block:: python - @tvm.script.tir + @T.prim_func def f(): - A = tir.buffer_decl([10, 10]) - tir.realize(A[1:2, 1:2], "") # This is an UnassignedCall + A = T.buffer_decl([10, 10]) + T.realize(A[1:2, 1:2], "") # This is an UnassignedCall A[1, 1] = 2 # This is also an UnassignedCall """ # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign. @@ -755,9 +783,9 @@ def f(): func = self.transform(node.call.func_name) arg_list = self.parse_arg_list(func, node.call) - if isinstance(func, tvm.script.scope_handler.AssertHandler): + if isinstance(func, tir.scope_handler.AssertHandler): self.report_error( - "A standalone `tir.Assert` is not allowed. Use `assert condition, message` " + "A standalone `T.Assert` is not allowed. Use `assert condition, message` " "instead.", node.call.func_name.span, ) @@ -865,7 +893,7 @@ def transform_Attr(self, node): """ if isinstance(node.object, ast.Var): - if node.object.id.name == "tir": + if self.match_tir_namespace(node.object.id.name): func_name = "tir." + node.field.name res = Registry.lookup(func_name) if res is not None: @@ -895,16 +923,18 @@ def transform_TypeAttr(self, node): """Visitor for field access of the form `x.y` for types. We have two cases here: - 1. If the type is of the form `ty.something`, we look up the type in - the `ty` namespace in this module. + 1. If the type is of the form `T.something`, we look up the type in + the `tir` namespace in this module. 2. If the type is of the form `tvm.x.something` then we look up `tvm.x.something` in this modules namespace. 
""" if isinstance(node.object, ast.TypeVar): - if node.object.id.name == "ty": - if not hasattr(ty, node.field.name): - self.report_error(f"Invalid type annotation `ty.{node.field.name}`.", node.span) - return getattr(ty, node.field.name) + if self.match_tir_namespace(node.object.id.name): + if not hasattr(tir, node.field.name): + self.report_error( + f"Invalid type annotation `tir.{node.field.name}`.", node.span + ) + return getattr(tir, node.field.name) symbol = self.transform(node.object) if symbol is None: @@ -994,108 +1024,63 @@ def transform_Return(self, node): ) -def from_source(src): - """Parse function or string into TIR. - - If possible, pass the TVM script in as a function so that line numbers and - filename will be accurate. - - Parameters - ---------- - src : [str, function, class] - Pruned source of original script - - Returns - ------- - functions : PrimFunc or IRModule - The PrimFunc or IRModule in IR. - """ - if isinstance(src, str): - start_line = 0 - else: - _, start_line = inspect.getsourcelines(src) - parser = TVMScriptParser(start_line) - return to_ast(src, TVMDiagnosticCtx(), parser) - - -def create_module(functions=None): - """Construct a module from list of functions. +def get_tir_namespace(script: Union[Callable, type]) -> List[str]: + assert inspect.isfunction(script) or inspect.isclass(script) + env: Dict[str, Any] = script.__globals__ + return [key for key in env.keys() if env[key] == tir] - Parameters - ----------- - functions: Optional[dict]. - Map of GlobalVar or str to PrimFunc - Returns - ------- - mod : IRModule - An IRModule containing the passed definitions - """ +def from_source( + input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None +) -> Union[PrimFunc, IRModule]: + """Parse function or string into PrimFunc or IRModule. 
- return IRModule(functions=functions) - - -def asscript(input_ir, show_meta=False): - """Transform a PrimFunc or IRModule to python syntax script + If possible, pass the TVM script in as a function so that line numbers and + filename will be accurate. Parameters ---------- - input_ir : Union[PrimFunc, IRModule] - The PrimFunc or IRModule to be dumped - - show_meta : bool - Whether show meta - - Returns - ------- - script : str - The Python script - """ + input_func : Union[str, Callable] + The python function or source string to be parsed. - return _ffi_api.AsTVMScript(input_ir, show_meta) - - -def tir(script_in): - """Decorate a python function or class as tvm script. - - The tvm function or parsing support parsing to the internal TIR. + tir_prefix : Optional[List[str]] + The tir prefix list. Only used for str input; defaults to "tir" and "T". Returns ------- output : Union[Function, Module] The Function or Module in IR. """ - - if inspect.isfunction(script_in): - result = from_source(script_in) - elif inspect.isclass(script_in): - result = TVMScriptClass(script_in) + if isinstance(input_func, str): + tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix + return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix)) + elif inspect.isfunction(input_func): + _, start_line = inspect.getsourcelines(input_func) + env: Dict[str, Any] = input_func.__globals__ + namespace = [key for key in env.keys() if env[key] == tir] + parser = TVMScriptParser(start_line, namespace) + result = to_ast(input_func, TVMDiagnosticCtx(), parser) + return result else: - raise TypeError("Only function and class definitions are supported.") - result.__name__ = script_in.__name__ - result.__qualname__ = script_in.__qualname__ - return result + raise TypeError("Only function definitions are supported.") -def module(script_in): - """Decorate a python function or class as tvm script. +def ir_module(input_module: type) -> IRModule: + """Decorate a python class as tvm IRModule. 
- Alias for tvm.script.tir for now. + Parameters + ---------- + input_module : type + The python class to be parsed. Returns ------- - output : Union[Function, Module] - The Function or Module in IR. + output : IRModule + The result IRModule. """ - return tir(script_in) - - -class TVMScriptClass: - """Helper class for decorating a class""" - - def __init__(self, script_in): - self.script = script_in - - def __call__(self, *args, **kwargs): - # call the parser to transform tvm script into TIR - return from_source(self.script) + if inspect.isclass(input_module): + func_dict = { + name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc) + } + return IRModule(func_dict) + raise TypeError("Only class definitions are supported.") diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py index 245cc01051d52..e7d90dd515171 100644 --- a/python/tvm/script/registry.py +++ b/python/tvm/script/registry.py @@ -41,7 +41,7 @@ def register(inputs: Union[Callable, type]) -> type: registration: type if isinstance(inputs, types.FunctionType): # is function - from .intrin import Intrin + from .tir.intrin import Intrin def create_new_intrin(func) -> type: class NewIntrin(Intrin): diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py new file mode 100644 index 0000000000000..6aa7eb33ec8bf --- /dev/null +++ b/python/tvm/script/tir/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVMScript for TIR""" + +# Type system +from .ty import int8, int16, int32, int64, float16, float32, float64 +from .ty import boolean, handle, Ptr, Tuple + +from .prim_func import prim_func diff --git a/python/tvm/script/intrin.py b/python/tvm/script/tir/intrin.py similarity index 98% rename from python/tvm/script/intrin.py rename to python/tvm/script/tir/intrin.py index e2d44440e2ac7..4d7fe80b28b17 100644 --- a/python/tvm/script/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -19,8 +19,8 @@ from typing import List, Any import tvm.tir -from .registry import register -from .utils import get_param_list, tvm_span_from_synr +from ..registry import register +from ..utils import get_param_list, tvm_span_from_synr class Intrin: diff --git a/python/tvm/script/node.py b/python/tvm/script/tir/node.py similarity index 100% rename from python/tvm/script/node.py rename to python/tvm/script/tir/node.py diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/tir/prim_func.py new file mode 100644 index 0000000000000..923eb97d2758a --- /dev/null +++ b/python/tvm/script/tir/prim_func.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TVM Script Interface for PrimFunc""" + +import inspect +from typing import Callable + +from tvm.tir.function import PrimFunc +from ..parser import from_source + + +def prim_func(input_func: Callable) -> PrimFunc: + """Decorate a python function as tvm script. + + Parameters + ---------- + input_func : Callable + The function to be parsed. + + Returns + ------- + output : PrimFunc + The resulting PrimFunc. + """ + if inspect.isfunction(input_func): + result = from_source(input_func) + result.__name__ = input_func.__name__ + result.__qualname__ = input_func.__qualname__ + return result + + raise TypeError("Only function definitions are supported.") diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/tir/scope_handler.py similarity index 91% rename from python/tvm/script/scope_handler.py rename to python/tvm/script/tir/scope_handler.py index cba067990befd..1072809abf4bf 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -19,21 +19,21 @@ from typing import Tuple, Any, Callable, Optional, List, Union, Mapping import synr -from synr import ast import tvm.tir from tvm.runtime import Object from tvm.ir import Span, Range from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind -from .context_maintainer import ContextMaintainer -from .utils import ( +from .node import BufferSlice +from .utils import buffer_slice_to_region + +from ..context_maintainer import ContextMaintainer +from ..registry import register +from ..utils import ( get_param_list, tvm_span_from_synr, - 
buffer_slice_to_region, call_with_error_reporting, ) -from .registry import register -from .node import BufferSlice class ScopeHandler: @@ -81,14 +81,14 @@ def __init__(self, func, concise_scope, def_symbol): @staticmethod def get_optional_vars(node, context): - """Get a list ast.With's optional_vars""" + """Get a list synr.ast.With's optional_vars""" assert isinstance( - node, ast.With - ), f"WithScopeHandler expected ast.With but got {type(node)}" + node, synr.ast.With + ), f"WithScopeHandler expected synr.ast.With but got {type(node)}" if isinstance(node.lhs, list): for var in node.lhs: - if not isinstance(var, ast.Var): + if not isinstance(var, synr.ast.Var): context.report_error( f"Invalid optional var definition, expected Var but got {type(var)}", node.span, @@ -104,7 +104,7 @@ def get_optional_vars(node, context): @register class Allocate(WithScopeHandler): - """With scope handler tir.allocate(extents, dtype, scope, condition)""" + """With scope handler T.allocate(extents, dtype, scope, condition)""" def __init__(self): def allocate(extents, dtype, scope, condition=True, span=None): @@ -125,13 +125,13 @@ def enter_scope( span: synr.ast.Span, ): # define buffer vars in symbol table - if isinstance(node, ast.With): + if isinstance(node, synr.ast.With): vars = WithScopeHandler.get_optional_vars(node, context) if len(vars) != 1: context.report_error("Unexpected number of vars", node.span) name = vars[0].id.name var_span = vars[0].id.span - elif isinstance(node, ast.Assign): + elif isinstance(node, synr.ast.Assign): name = node.lhs.id.name var_span = node.lhs.id.span else: @@ -148,7 +148,7 @@ def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): @register class LaunchThread(WithScopeHandler): - """With scope handler tir.launch_thread(env_var, extent)""" + """With scope handler T.launch_thread(env_var, extent)""" def __init__(self): def launch_thread(env_var, extent, span): @@ -174,7 +174,7 @@ def launch_thread(env_var, extent, span): 
@register class Realize(WithScopeHandler): - """With scope handler tir.realize(buffer_bounds, scope, condition)""" + """With scope handler T.realize(buffer_bounds, scope, condition)""" def __init__(self): def realize( @@ -204,7 +204,7 @@ def realize( @register class Attr(WithScopeHandler): - """With scope handler tir.attr(attr_node, attr_key, value)""" + """With scope handler T.attr(attr_node, attr_key, value)""" def __init__(self): def attr(attr_node, attr_key, value, span): @@ -217,7 +217,7 @@ def attr(attr_node, attr_key, value, span): @register class AssertHandler(WithScopeHandler): - """With scope handler tir.Assert(condition, message)""" + """With scope handler T.Assert(condition, message)""" def __init__(self): def Assert(condition, message, span): @@ -228,7 +228,7 @@ def Assert(condition, message, span): @register class Let(WithScopeHandler): - """With scope handler tir.let(var, value)""" + """With scope handler T.let(var, value)""" def __init__(self): def let(var, value, span): @@ -239,7 +239,7 @@ def let(var, value, span): @register class Block(WithScopeHandler): - """With scope handler tir.block(extents, name) as iter_vars""" + """With scope handler T.block(extents, name) as iter_vars""" def __init__(self): def block(axes=None, name_hint: str = "", span: Optional[Span] = None): @@ -268,7 +268,7 @@ def block(axes=None, name_hint: str = "", span: Optional[Span] = None): block_iters.append(IterVar(axis.dom, self.block_vars[i], axis.iter_type)) else: self.context.report_error( - "Invalid argument of tir.block(), " + "Invalid argument of T.block(), " + f"expected PrimExpr, Range or IterVar, but got {type(axis)}", self.node.span, ) @@ -347,8 +347,8 @@ def enter_scope( ): # define block vars assert isinstance( - node, ast.With - ), f"BlockScopeHandler expected to work on ast.With but got {type(node)}" + node, synr.ast.With + ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}" vars = WithScopeHandler.get_optional_vars(node, context) 
self.block_vars = [tvm.te.var(var.id.name) for var in vars] @@ -358,7 +358,7 @@ def enter_scope( @register class InitBlock(WithScopeHandler): - """With scope handler tir.init()""" + """With scope handler T.init()""" def __init__(self): def init(span: Span = None): @@ -384,16 +384,18 @@ def enter_scope( arg_list: List[Any], span: synr.ast.Span, ): - assert isinstance(node, ast.For), f"ForScopeHandler expected ast.For but got {type(node)}" + assert isinstance( + node, synr.ast.For + ), f"ForScopeHandler expected synr.ast.For but got {type(node)}" loop_var_names = list() spans = list() - if isinstance(node.lhs, ast.Var): + if isinstance(node.lhs, synr.ast.Var): loop_var_names.append(node.lhs.id.name) spans.append(tvm_span_from_synr(node.lhs.id.span)) elif isinstance(node.lhs, list): for elt in node.lhs: - if not isinstance(elt, ast.Var): + if not isinstance(elt, synr.ast.Var): context.report_error( f"Invalid loop var. Expected a var, but got {type(elt)}", elt.span ) @@ -489,7 +491,7 @@ def create_loop( @register class Serial(ForScopeHandler): - """For scope handler tir.serial(begin, end, annotations)""" + """For scope handler T.serial(begin, end, annotations)""" def __init__(self): def serial( @@ -505,7 +507,7 @@ def serial( @register class Parallel(ForScopeHandler): - """For scope handler tir.parallel(begin, end, annotations)""" + """For scope handler T.parallel(begin, end, annotations)""" def __init__(self): def parallel( @@ -523,7 +525,7 @@ def parallel( @register class Vectorized(ForScopeHandler): - """For scope handler tir.vectorized(begin, end, annotations)""" + """For scope handler T.vectorized(begin, end, annotations)""" def __init__(self): def vectorized( @@ -541,7 +543,7 @@ def vectorized( @register class Unroll(ForScopeHandler): - """For scope handler tir.unroll(begin, end, annotations)""" + """For scope handler T.unroll(begin, end, annotations)""" def __init__(self): def unroll( @@ -559,7 +561,7 @@ def unroll( @register class 
ThreadBinding(ForScopeHandler): - """For scope handler tir.thread_binding(begin, end, thread, annotations)""" + """For scope handler T.thread_binding(begin, end, thread, annotations)""" def __init__(self): def thread_binding( @@ -585,7 +587,7 @@ def thread_binding( @register class RangeHandler(ForScopeHandler): """For scope handler range(begin, end, annotations) - Note that tir.range is totally the same as tir.serial + Note that tir.range is totally the same as T.serial """ def __init__(self): @@ -608,7 +610,7 @@ def signature(self): @register class Grid(ForScopeHandler): - """For scope handler tir.grid(extents)""" + """For scope handler T.grid(extents)""" def __init__(self): def grid(*extents: List[PrimExpr], span: Span): diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/tir/special_stmt.py similarity index 95% rename from python/tvm/script/special_stmt.py rename to python/tvm/script/tir/special_stmt.py index 25af7635742b2..69cf15f493de4 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -27,15 +27,17 @@ from tvm import te from tvm.ir import Span from tvm.tir import IntImm -from .utils import ( + +from .node import BufferSlice +from .utils import buffer_slice_to_region + +from ..context_maintainer import ContextMaintainer +from ..registry import register +from ..utils import ( get_param_list, tvm_span_from_synr, - buffer_slice_to_region, call_with_error_reporting, ) -from .registry import register -from .context_maintainer import ContextMaintainer -from .node import BufferSlice def convert_to_int( @@ -109,11 +111,11 @@ class MatchBuffer(SpecialStmt): ------- Match buffer from function parameter .. code-block:: python - A = tir.match_buffer(a, (128, 128), dtype="float32") + A = T.match_buffer(a, (128, 128), dtype="float32") Match buffer from Buffer subregion .. 
code-block:: python - A = tir.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") + A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") """ def __init__(self): @@ -183,7 +185,7 @@ class BufferDeclare(SpecialStmt): Example ------- .. code-block:: python - A = tir.buffer_decl((128, 128), dtype="float32") + A = T.buffer_decl((128, 128), dtype="float32") """ def __init__(self): @@ -239,7 +241,7 @@ class AllocBuffer(SpecialStmt): ------- .. code-block:: python - A = tir.alloc_buffer((128, 128), dtype="float32") + A = T.alloc_buffer((128, 128), dtype="float32") """ def __init__(self): @@ -294,7 +296,7 @@ class BlockVarBind(SpecialStmt): ------- .. code-block:: python - tir.bind(vx, i) + T.bind(vx, i) """ def __init__(self): @@ -315,7 +317,7 @@ class BlockReads(SpecialStmt): ------- .. code-block:: python - tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]]) + T.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]]) """ def __init__(self): @@ -350,7 +352,7 @@ class BlockWrites(SpecialStmt): ------- .. code-block:: python - tir.writes([C[vi: vi + 4, vj]) + T.writes([C[vi: vi + 4, vj]]) """ def __init__(self): @@ -387,7 +389,7 @@ class BlockAttr(SpecialStmt): ------- .. code-block:: python - tir.block_attr({"double_buffer_scope": 1}) + T.block_attr({"double_buffer_scope": 1}) """ def __init__(self): @@ -418,7 +420,7 @@ class BlockPredicate(SpecialStmt): ------- .. code-block:: python - tir.where(i < 4) + T.where(i < 4) """ def __init__(self): @@ -491,7 +493,7 @@ class FuncAttr(SpecialStmt): Example ------- ..
code-block:: python - tir.func_attr({"tir.noalias": True, "global_symbol"}) + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) """ def __init__(self): diff --git a/python/tvm/script/ty.py b/python/tvm/script/tir/ty.py similarity index 98% rename from python/tvm/script/ty.py rename to python/tvm/script/tir/ty.py index 960e090a163c5..6a4f7bc00cb6d 100644 --- a/python/tvm/script/ty.py +++ b/python/tvm/script/tir/ty.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM Script Parser Typing Class +"""TVM Script Parser Typing Class for TIR This module provides typing class for TVM script type annotation usage, it can be viewed as a wrapper for uniform Type system in IR diff --git a/python/tvm/script/tir/utils.py b/python/tvm/script/tir/utils.py new file mode 100644 index 0000000000000..da201229eb001 --- /dev/null +++ b/python/tvm/script/tir/utils.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+"""Helper functions in TVM Script Parser""" + +from typing import List, Optional, Union + +from tvm.arith import Analyzer +from tvm.ir import Range +from tvm.tir import PrimExpr, BufferRegion +from .node import BufferSlice + + +def buffer_slice_to_region( + buffer_slice: BufferSlice, analyzer: Optional[Analyzer] = None +) -> BufferRegion: + """Construct BufferRegion from BufferSlice + + Parameters + ---------- + buffer_slice : BufferSlice + The input BufferSlice + + analyzer : Optional[tvm.arith.Analyzer] + The analyzer for simplifying. If not provided, the method will construct a new one + + Returns + ------- + buffer_region : BufferRegion + The constructed BufferRegion. + """ + region: List[Range] = [] + for s in buffer_slice.slices: + start: Union[PrimExpr, int] = s.start + extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start + if not analyzer: + analyzer = Analyzer() + if isinstance(extent, PrimExpr): + extent = analyzer.simplify(extent) + region.append(Range.from_min_extent(start, extent, span=s.span)) + return BufferRegion(buffer_slice.buffer, region) diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py index f8a0f610d477f..c655a6223740f 100644 --- a/python/tvm/script/utils.py +++ b/python/tvm/script/utils.py @@ -16,16 +16,13 @@ # under the License. 
"""Helper functions in TVM Script Parser""" -from typing import Callable, List, Any, Optional, Tuple, Union +from typing import Callable, List, Any, Optional, Tuple import inspect import synr -from tvm.arith import Analyzer -from tvm.ir import Range, Span, SourceName -from tvm.tir import PrimExpr, BufferRegion +from tvm.ir import Span, SourceName from tvm.error import DiagnosticError -from .node import BufferSlice def get_param_list( @@ -68,36 +65,6 @@ def get_param_list( return pos_only, kwargs, full_arg_spec.varargs -def buffer_slice_to_region( - buffer_slice: BufferSlice, analyzer: Optional[Analyzer] = None -) -> BufferRegion: - """Construct BufferRegion from BufferSlice - - Parameters - ---------- - buffer_slice : BufferSlice - The input BufferSlice - - analyzer : Optional[tvm.arith.Analyzer] - The analyzer for simplifying. If not provided, the method will construct a new one - - Returns - ------- - buffer_region : BufferRegion - The constructed BufferRegion. - """ - region: List[Range] = [] - for s in buffer_slice.slices: - start: Union[PrimExpr, int] = s.start - extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start - if not analyzer: - analyzer = Analyzer() - if isinstance(extent, PrimExpr): - extent = analyzer.simplify(extent) - region.append(Range.from_min_extent(start, extent, span=s.span)) - return BufferRegion(buffer_slice.buffer, region) - - def tvm_span_from_synr(span: synr.ast.Span) -> Span: """Convert a synr span to a TVM span""" return Span( diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 4c361bca6c57e..681e322b20823 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -453,7 +453,7 @@ def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: k = te.reduce_axis((0, 128), "k") C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") func = create_prim_func([A, B, C]) - print(tvm.script.asscript(func)) + print(func.script()) If we want to use 
TensorIR schedule to do transformations on such kernel, we need to use `create_prim_func([A, B, C])` to create a schedulable PrimFunc. @@ -461,14 +461,14 @@ def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc: .. code-block:: python - @tvm.script.tir - def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) + @T.prim_func + def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]: - with tir.init(): + with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: + with T.init(): C[i, j] = 0.0 C[i, j] += A[i, k] * B[j, k] diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 68d967aa497df..ddbfa14f13c64 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -103,12 +103,12 @@ def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]): .. code-block:: python - @tvm.script.tir - def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: - A = tir.match_buffer(a, (m, n), "float32") - B = tir.match_buffer(b, (m, n), "float32") + @T.prim_func + def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: + A = T.match_buffer(a, (m, n), "float32") + B = T.match_buffer(b, (m, n), "float32") - with tir.block([m, n], "") as [vi, vj]: + with T.block([m, n], "") as [vi, vj]: B[vi, vj] = A[vi, vj] Then we can make it specialized with given shapes or buffers. @@ -124,12 +124,12 @@ def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None: .. 
code-block:: python - @tvm.script.tir - def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (16, 16), "float32") + @T.prim_func + def mem_copy_16_16(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") - with tir.block([16, 16], "") as [vi, vj]: + with T.block([16, 16], "") as [vi, vj]: B[vi, vj] = A[vi, vj] Returns @@ -138,3 +138,18 @@ def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None: The new function with parameter specialized """ return _ffi_api.Specialize(self, param_map) # type: ignore + + def script(self, show_meta=False) -> str: + """Print PrimFunc into TVMScript + + Parameters + ---------- + show_meta : bool + Whether to show meta information + + Returns + ------- + script : str + The TVM Script of the PrimFunc + """ + return tvm._ffi.get_global_func("script.AsTVMScript")(self, show_meta) # type: ignore diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d26ffc0b1efaa..6e27015648f06 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -392,12 +392,12 @@ def fuse(self, *loops: List[LoopRV]) -> LoopRV: .. 
code-block:: python - @tvm.script.tir - def before_fuse(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_fuse(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do fuse: @@ -407,21 +407,21 @@ def before_fuse(a: ty.handle, b: ty.handle) -> None: sch = tir.Schedule(before_fuse) i, j = sch.get_loops(sch.get_block("B")) sch.fuse(i, j) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying fuse, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_fuse(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) + @T.prim_func + def after_fuse(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) # the 2 loops are fused into 1 - for i_j_fused in tir.serial(0, 16384): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, tir.floordiv(i_j_fused, 128)) - tir.bind(vj, tir.floormod(i_j_fused, 128)) + for i_j_fused in T.serial(0, 16384): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, T.floordiv(i_j_fused, 128)) + T.bind(vj, T.floormod(i_j_fused, 128)) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -463,12 +463,12 @@ def split( ..
code-block:: python - @tvm.script.tir - def before_split(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do split: @@ -478,21 +478,21 @@ def before_split(a: ty.handle, b: ty.handle) -> None: sch = tir.Schedule(before_split) i, j = sch.get_loops(sch.get_block("B")) sch.split(i, factors=[2, 64]) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying split, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_split(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) + @T.prim_func + def after_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) # the original loop is split into 2 loops - for i0, i1, j in tir.grid(2, 64, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, ((i0*64) + i1)) - tir.bind(vj, j) + for i0, i1, j in T.grid(2, 64, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, ((i0*64) + i1)) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -524,12 +524,12 @@ def reorder(self, *ordered_loops: List[LoopRV]) -> None: .. 
code-block:: python - @tvm.script.tir - def before_reorder(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_reorder(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do reorder: @@ -539,21 +539,21 @@ def before_reorder(a: ty.handle, b: ty.handle) -> None: sch = tir.Schedule(before_reorder) i, j = sch.get_loops(sch.get_block("B")) sch.reorder(j, i) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying reorder, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_reorder(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) + @T.prim_func + def after_reorder(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) # Here j and i are reordered - for j, i in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + for j, i in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -581,14 +581,14 @@ def parallel(self, loop: LoopRV) -> None: .. 
code-block:: python - @tvm.script.tir - def before_parallel(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def before_parallel(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do parallel: @@ -603,15 +603,15 @@ def before_parallel(a: ty.handle, b: ty.handle) -> None: .. code-block:: python - @tvm.script.tir - def after_parallel(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i in tir.parallel(0, 128): - for j in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def after_parallel(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i in T.parallel(0, 128): + for j in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -637,14 +637,14 @@ def vectorize(self, loop: LoopRV) -> None: .. 
code-block:: python - @tvm.script.tir - def before_vectorize(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def before_vectorize(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do vectorize: @@ -659,15 +659,15 @@ def before_vectorize(a: ty.handle, b: ty.handle) -> None: .. code-block:: python - @tvm.script.tir - def after_vectorize(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i in tir.serial(0, 128): - for j in tir.vectorized(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def after_vectorize(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i in T.serial(0, 128): + for j in T.vectorized(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -701,14 +701,14 @@ def bind(self, loop: LoopRV, thread_axis: str) -> None: .. 
code-block:: python - @tvm.script.tir - def before_bind(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def before_bind(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do bind: @@ -724,15 +724,15 @@ def before_bind(a: ty.handle, b: ty.handle) -> None: .. code-block:: python - @tvm.script.tir - def after_bind(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i in tir.thread_binding(0, 128, thread = "blockIdx.x"): - for j in tir.thread_binding(0, 128, thread = "threadIdx.x"): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def after_bind(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i in T.thread_binding(0, 128, thread = "blockIdx.x"): + for j in T.thread_binding(0, 128, thread = "threadIdx.x"): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -753,14 +753,14 @@ def unroll(self, loop: LoopRV) -> None: .. 
code-block:: python - @tvm.script.tir - def before_unroll(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def before_unroll(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do unroll: @@ -775,15 +775,15 @@ def before_unroll(a: ty.handle, b: ty.handle) -> None: .. code-block:: python - @tvm.script.tir - def after_unroll(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i in tir.unroll(0, 128): - for j in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def after_unroll(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i in T.unroll(0, 128): + for j in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -820,12 +820,12 @@ def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) .. 
code-block:: python - @tvm.script.tir - def before_cache_read(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_cache_read(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and cache_read: @@ -835,22 +835,22 @@ def before_cache_read(a: ty.handle, b: ty.handle) -> None: sch = tir.Schedule(before_cache_read) block_b = sch.get_block("B") sch.cache_read(block_b, 0, "local") - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying cache_read, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_cache_read(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - A_local = tir.alloc_buffer((128, 128), scope="local") - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "A_local") as [vi, vj]: + @T.prim_func + def after_cache_read(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + A_local = T.alloc_buffer((128, 128), scope="local") + for i, j in T.grid(128, 128): + with T.block([128, 128], "A_local") as [vi, vj]: A_local[vi, vj] = A[vi, vj] - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A_local[vi, vj] * 2.0 """ @@ -888,12 +888,12 @@ def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: st .. 
code-block:: python - @tvm.script.tir - def before_cache_write(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_cache_write(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and cache_write: @@ -903,22 +903,22 @@ def before_cache_write(a: ty.handle, b: ty.handle) -> None: sch = tir.Schedule(before_cache_write) block_b = sch.get_block("B") sch.cache_write(block_b, 0, "local") - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying cache_write, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_cache_write(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - B_local = tir.alloc_buffer((128, 128), scope="local") - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "A_local") as [vi, vj]: + @T.prim_func + def after_cache_write(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + B_local = T.alloc_buffer((128, 128), scope="local") + for i, j in T.grid(128, 128): + with T.block([128, 128], "A_local") as [vi, vj]: B_local[vi, vj] = A[vi, vj] * 2.0 - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = B_local[vi, vj] """ @@ -969,14 +969,14 @@ def compute_at( .. 
code-block:: python - @tvm.script.tir - def before_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do compute-at: @@ -987,27 +987,27 @@ def before_compute_at(a: ty.handle, c: ty.handle) -> None: block = sch.get_block("B") loop, _ = sch.get_loops(sch.get_block("C")) sch.compute_at(block, loop, preserve_unit_loops=False) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying compute-at, the IR becomes: .. 
code-block:: python - @tvm.script.tir - def after_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - for i in tir.serial(0, 128): - for j in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def after_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 - for j in tir.serial(0, 128): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + for j in T.serial(0, 128): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) C[vi, vj] = B[vi, vj] + 1.0 """ @@ -1056,14 +1056,14 @@ def reverse_compute_at( .. 
code-block:: python - @tvm.script.tir - def before_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_reverse_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do reverse-compute-at: @@ -1074,27 +1074,27 @@ def before_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: block = sch.get_block("C") loop, _ = sch.get_loops(sch.get_block("B")) sch.reverse_compute_at(block, loop, preserve_unit_loops=False) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying reverse-compute-at, the IR becomes: .. 
code-block:: python - @tvm.script.tir - def after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - for i in tir.serial(0, 128): - for j in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + @T.prim_func + def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 - for j in tir.serial(0, 128): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + for j in T.serial(0, 128): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) C[vi, vj] = B[vi, vj] + 1.0 """ @@ -1130,14 +1130,14 @@ def compute_inline(self, block: BlockRV) -> None: .. 
code-block:: python - @tvm.script.tir - def before_inline(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_inline(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do compute-inline: @@ -1146,17 +1146,17 @@ def before_inline(a: ty.handle, c: ty.handle) -> None: sch = tir.Schedule(before_inline) sch.compute_inline(sch.get_block("B")) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying compute-inline, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_inline(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "C") as [vi, vj]: + @T.prim_func + def after_inline(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ @@ -1190,14 +1190,14 @@ def reverse_compute_inline(self, block: BlockRV) -> None: .. 
code-block:: python - @tvm.script.tir - def before_inline(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_inline(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do reverse-compute-inline: @@ -1206,17 +1206,17 @@ def before_inline(a: ty.handle, c: ty.handle) -> None: sch = tir.Schedule(before_inline) sch.reverse_compute_inline(sch.get_block("C")) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying reverse-compute-inline, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_inline(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "C") as [vi, vj]: + @T.prim_func + def after_inline(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 """ @@ -1304,13 +1304,13 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: .. 
code-block:: python - @tvm.script.tir - def before_rfactor(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128,)) - with tir.block([128, tir.reduce_axis(0, 128), - tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]: - with tir.init(): + @T.prim_func + def before_rfactor(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128,)) + with T.block([128, T.reduce_axis(0, 128), + T.reduce_axis(0, 128)], "B") as [vii, vi, vj]: + with T.init(): B[vii] = 0.0 B[vii] = B[vii] + A[vii, vi, vj] @@ -1321,23 +1321,23 @@ def before_rfactor(a: ty.handle, b: ty.handle) -> None: sch = tir.Schedule(before_rfactor) _, _, k = sch.get_loops(sch.get_block("B")) sch.rfactor(k, 0) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) After applying rfactor, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_rfactor(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128, 128]) - B = tir.match_buffer(b, [128]) - B_rf = tir.alloc_buffer([128, 128]) - with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]: - with tir.init(): + @T.prim_func + def after_rfactor(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128]) + B = T.match_buffer(b, [128]) + B_rf = T.alloc_buffer([128, 128]) + with T.block([128, 128, T.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]: + with T.init(): B_rf[vi2, vii] = 0.0 B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]: - with tir.init(): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]: + with T.init(): B[vii_1] = 0.0 B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1]) @@ -1402,14 +1402,14 @@ def storage_align( # pylint: disable=too-many-arguments .. 
code-block:: python - @tvm.script.tir - def before_storage_align(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: + @T.prim_func + def before_storage_align(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 Create the schedule and do storage_align: @@ -1418,21 +1418,21 @@ def before_storage_align(a: ty.handle, c: ty.handle) -> None: sch = tir.Schedule(before_storage_align) sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1) - print(tvm.script.asscript(sch.mod["main"])) + print(sch.mod["main"].script()) - After applying rfactor, the IR becomes: + After applying storage_align, the IR becomes: .. code-block:: python - @tvm.script.tir - def after_storage_align(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: - tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + @T.prim_func + def after_storage_align(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: + T.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 After lowering passes, buffer B will have strides as [129, 1].
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 2183319a006f5..f072f6b38a433 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -628,8 +628,8 @@ def CompactBufferAllocation(): .. code-block:: python for i in range(0, 16): - with tir.block([]): - B = tir.alloc_buffer(16, 16) + with T.block([]): + B = T.alloc_buffer(16, 16) for j in range(0, 16): B[i, j] = A[i, j] + 1 for j in range(0, 16): @@ -643,8 +643,8 @@ def CompactBufferAllocation(): .. code-block:: python for i in range(0, 16): - with tir.block([]): - B = tir.alloc_buffer(1, 16) + with T.block([]): + B = T.alloc_buffer(1, 16) for j in range(0, 16): B[0, j] = A[i, j] + 1 for j in range(0, 16): diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 7e4a56529ddc7..3514f3228e279 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -370,7 +370,7 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintBody(const Stmt& body, bool indent = true); }; -String AsTVMScript(const ObjectRef& mod, bool show_meta = false); +String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "tir", bool show_meta = false); } // namespace tir } // namespace tvm diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index f232994480f84..e664477270a15 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -565,8 +565,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { for (size_t i = 0; i < block_op->iter_vars.size(); ++i) block_attr_doc << Doc::NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", " << Print(op->iter_values[i]) << ")"; - block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")"; - block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")"; + block_attr_doc << Doc::NewLine() << "T.reads(" << Print(block_op->reads) << ")"; + block_attr_doc << 
Doc::NewLine() << "T.writes(" << Print(block_op->writes) << ")"; if (!block_op->annotations.empty()) { std::vector attr_docs; for (const auto& it : block_op->annotations) { diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 906dc258560a5..fdafdbfee0db5 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -76,9 +76,12 @@ class TVMScriptPrinter : public StmtFunctor, public ExprFunctor, public TypeFunctor { public: - explicit TVMScriptPrinter(bool show_meta, + explicit TVMScriptPrinter(const String& tir_prefix, bool show_meta, runtime::TypedPackedFunc annotate = nullptr) - : show_meta_(show_meta), annotate_(std::move(annotate)), meta_collector_(&meta_) {} + : tir_prefix_(tir_prefix), + show_meta_(show_meta), + annotate_(std::move(annotate)), + meta_collector_(&meta_) {} /*! * \brief Print the node. @@ -89,6 +92,8 @@ class TVMScriptPrinter : public StmtFunctor, TVM_DLL Doc Print(const ObjectRef& node); private: + /*! \brief The tir prefix */ + String tir_prefix_; /*! \brief whether show meta data */ bool show_meta_; /*! \brief additional comment function */ @@ -207,7 +212,7 @@ class TVMScriptPrinter : public StmtFunctor, * \param loop The for loop to be printed */ Doc PrintLoop(const For& loop); - /*! \brief Print all simple loops in stack into one line using tir.grid(). */ + /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ Doc PrintLoopStack(); /*! @@ -269,7 +274,7 @@ class TVMScriptPrinter : public StmtFunctor, * \param data The pointer to hold the data. */ template - static Doc PrintConstScalar(DataType dtype, const T* data) { + Doc PrintConstScalar(DataType dtype, const T* data) const { Doc doc; std::ostringstream os; if (dtype.is_float() || dtype.is_float16() || dtype.is_bfloat16()) { @@ -281,7 +286,8 @@ class TVMScriptPrinter : public StmtFunctor, } else if (dtype == DataType::Bool()) { doc << Doc::Text(data[0] ? "True" : "False"); } else { - doc << "tir." 
<< runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) << ")"; + doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) + << ")"; } return doc; } @@ -404,8 +410,8 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { const Buffer& buf = op->buffer; buf_not_in_headers_.insert(buf.get()); - Doc doc = Print(op->buffer) << " = tir.match_buffer(" << Print(op->source) << ", " - << memo_buf_decl_[op->buffer] << ")"; + Doc doc = Print(op->buffer) << " = " << tir_prefix_ << ".match_buffer(" << Print(op->source) + << ", " << memo_buf_decl_[op->buffer] << ")"; return doc; } @@ -470,7 +476,7 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_pr Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")"; + doc << tir_prefix_ << ".cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")"; return doc; } @@ -506,7 +512,7 @@ Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_preceden return doc; \ } -TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, "*", ExprPrecedence::kMultiplicationDivision) +TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", ExprPrecedence::kMultiplicationDivision) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", ExprPrecedence::kMultiplicationDivision) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(ModNode, " % ", ExprPrecedence::kMultiplicationDivision) TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", ExprPrecedence::kAdditionSubtraction) @@ -523,28 +529,28 @@ TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", ExprPrecedence::kOr) Doc TVMScriptPrinter::VisitExpr_(const FloorDivNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.floordiv(" << Print(op->a) << ", " << Print(op->b) << ")"; + doc << tir_prefix_ << 
".floordiv(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } Doc TVMScriptPrinter::VisitExpr_(const FloorModNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.floormod(" << Print(op->a) << ", " << Print(op->b) << ")"; + doc << tir_prefix_ << ".floormod(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } Doc TVMScriptPrinter::VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.min(" << Print(op->a) << ", " << Print(op->b) << ")"; + doc << tir_prefix_ << ".min(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } Doc TVMScriptPrinter::VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.max(" << Print(op->a) << ", " << Print(op->b) << ")"; + doc << tir_prefix_ << ".max(" << Print(op->a) << ", " << Print(op->b) << ")"; return doc; } @@ -558,7 +564,7 @@ Doc TVMScriptPrinter::VisitExpr_(const NotNode* op, ExprPrecedence* out_preceden Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.Select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " + doc << tir_prefix_ << ".Select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " << Print(op->false_value) << ")"; return doc; } @@ -587,7 +593,7 @@ Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op, ExprPrecedence* out_precede op->buffer_var->dtype == DataType::Float(32)) { doc << Print(op->buffer_var) << "[" << Print(op->index) << "]"; } else { - doc << "tir.load(" << PrintDType(op->dtype) << ", " << Print(op->buffer_var) << ", " + doc << tir_prefix_ << ".load(" << PrintDType(op->dtype) << ", " << Print(op->buffer_var) << ", " << Print(op->index); if (!is_one(op->predicate) || op->dtype.lanes() != 1) { doc << ", " << Print(op->predicate); 
@@ -600,21 +606,23 @@ Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op, ExprPrecedence* out_precede Doc TVMScriptPrinter::VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")"; + doc << tir_prefix_ << ".ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " + << op->lanes << ")"; return doc; } Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << op->lanes << ")"; return doc; } Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.let(" << Print(op->var) << ", " << Print(op->value) << ", " << Print(op->body) << ")"; + doc << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << ", " + << Print(op->body) << ")"; return doc; } @@ -622,7 +630,11 @@ Doc TVMScriptPrinter::VisitExpr_(const CallNode* op, ExprPrecedence* out_precede *out_precedence = ExprPrecedence::kIdentity; Doc doc; if (auto* ptr_op = op->op.as()) { - doc << Doc::Text(ptr_op->name) << "("; + std::string name = ptr_op->name; + if (name.find("tir.") == 0) { + name = tir_prefix_ + "." 
+ name.substr(4); + } + doc << name << "("; } else { auto* op_gvar = op->op.as(); ICHECK(op_gvar != nullptr); @@ -640,14 +652,14 @@ Doc TVMScriptPrinter::VisitExpr_(const CallNode* op, ExprPrecedence* out_precede Doc TVMScriptPrinter::VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")"; + doc << tir_prefix_ << ".shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")"; return doc; } Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; - doc << "tir.reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " + doc << tir_prefix_ << ".reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis) << ", " << op->value_index << ")"; return doc; } @@ -655,7 +667,7 @@ Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_prece Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with tir.let(" << Print(op->var) << ", " << Print(op->value) << "):"; + doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } else { if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get()); @@ -673,15 +685,15 @@ Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { const auto* realize = Downcast(op->body).get(); if (realize->buffer.same_as(op->node)) { if (current_num_ != num_child_ - 1) { - doc << "with tir.realize(" << Print(realize->buffer) << Print(realize->bounds) << ", " - << Print(op->value); + doc << "with " << tir_prefix_ << ".realize(" << Print(realize->buffer) + << Print(realize->bounds) << ", " << Print(op->value); if (!is_one(realize->condition)) { doc << ", " << 
Print(realize->condition); } doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(realize->body)); } else { - doc << "tir.realize(" << Print(realize->buffer) << Print(realize->bounds) << ", " - << Print(op->value); + doc << tir_prefix_ << ".realize(" << Print(realize->buffer) << Print(realize->bounds) + << ", " << Print(op->value); if (!is_one(realize->condition)) { doc << ", " << Print(realize->condition); } @@ -697,10 +709,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { var_not_in_headers_.insert(iter_var->var.get()); var_env_map_[iter_var->var] = iter_var->thread_tag; if (current_num_ != num_child_ - 1) { - doc << "with tir.launch_thread(" << Print(iter_var->var) << ", " << Print(op->value) << "):"; + doc << "with " << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", " + << Print(op->value) << "):"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } else { - doc << "tir.launch_thread(" << Print(iter_var->var) << ", " << Print(op->value) << ")"; + doc << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", " << Print(op->value) + << ")"; doc << Doc::NewLine() << PrintBody(op->body); } TryDeallocVar(iter_var->var); @@ -708,12 +722,12 @@ Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { } // default if (current_num_ != num_child_ - 1) { - doc << "with tir.attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) << ", " - << Print(op->value) << "):"; + doc << "with " << tir_prefix_ << ".attr(" << Print(op->node) << ", " + << Doc::StrLiteral(op->attr_key) << ", " << Print(op->value) << "):"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } else { - doc << "tir.attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) << ", " - << Print(op->value) << ")"; + doc << tir_prefix_ << ".attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) + << ", " << Print(op->value) << ")"; doc << Doc::NewLine() << PrintBody(op->body); } return doc; @@ -722,7 +736,8 @@ Doc 
TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) { Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with tir.Assert(" << Print(op->condition) << ", " << Print(op->message) << "):"; + doc << "with " << tir_prefix_ << ".Assert(" << Print(op->condition) << ", " + << Print(op->message) << "):"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } else { doc << "assert " << Print(op->condition) << ", " << Print(op->message); @@ -733,7 +748,7 @@ Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) { Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) { Doc doc; - doc << "tir.store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", " + doc << tir_prefix_ << ".store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", " << Print(op->value) << ", " << Print(op->predicate) << ")"; return doc; } @@ -749,16 +764,16 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto storage_scope = GetPtrStorageScope(op->buffer_var); if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) << ", " - << Print(storage_scope); + doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " + << PrintDType(op->dtype) << ", " << Print(storage_scope); if (!is_one(op->condition)) { doc << ", " << Print(op->condition); } doc << ") as " << Print(op->buffer_var) << ":"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); } else { - doc << Print(op->buffer_var) << " = tir.allocate(" << Print(op->extents) << ", " - << PrintDType(op->dtype) << ", " << Print(storage_scope); + doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents) + << ", " << PrintDType(op->dtype) << ", " << Print(storage_scope); if (!is_one(op->condition)) { doc << ", " << Print(op->condition); } @@ -789,7 +804,7 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) { Doc 
TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { Doc doc; - doc << "tir.evaluate(" << Print(op->value) << ")"; + doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")"; return doc; } @@ -827,7 +842,7 @@ Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) { Doc doc; - doc << "tir.prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; + doc << tir_prefix_ << ".prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; return doc; } @@ -840,13 +855,13 @@ Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) { Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) { Doc doc; - doc << "ty." << runtime::DLDataType2String(node->dtype); + doc << tir_prefix_ << "." << runtime::DLDataType2String(node->dtype); return doc; } Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) { Doc doc; - doc << "ty.Ptr["; + doc << tir_prefix_ << ".Ptr["; if (!node->storage_scope.empty()) { doc << node->storage_scope << " "; } @@ -862,7 +877,7 @@ Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) { for (Type field : node->fields) { fields.push_back(Print(field)); } - return Doc::Text("ty.Tuple[") << Doc::Concat(fields) << "]"; + return Doc::Text(tir_prefix_ + ".Tuple[") << Doc::Concat(fields) << "]"; } } @@ -878,14 +893,14 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { Doc TVMScriptPrinter::PrintBlockVar(const BlockNode* op) { Doc doc; - doc << "with tir.block(["; + doc << "with " << tir_prefix_ << ".block(["; std::vector block_var_docs; for (const auto& iter_var : op->iter_vars) { Doc block_var_doc; if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) { block_var_doc << Print(iter_var->dom->extent); } else { - block_var_doc << "tir."; + block_var_doc << tir_prefix_ << "."; switch (iter_var->iter_type) { case kDataPar: block_var_doc << "range"; @@ -930,15 +945,16 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) { Doc 
block_attr_doc; // print predicate, binding, read/write tensor region, annotations if (!is_one(op->predicate)) { - block_attr_doc << Doc::NewLine() << "tir.where(" << Print(op->predicate) << ")"; + block_attr_doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")"; } for (size_t i = 0; i < block_op->iter_vars.size(); ++i) - block_attr_doc << Doc::NewLine() << "tir.bind(" << Print(block_op->iter_vars[i]->var) << ", " - << Print(op->iter_values[i]) << ")"; - block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")"; - block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")"; + block_attr_doc << Doc::NewLine() << tir_prefix_ << ".bind(" + << Print(block_op->iter_vars[i]->var) << ", " << Print(op->iter_values[i]) + << ")"; + block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" << Print(block_op->reads) << ")"; + block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" << Print(block_op->writes) << ")"; if (!block_op->annotations.empty()) { - block_attr_doc << Doc::NewLine() << "tir.block_attr({"; + block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({"; block_attr_doc << PrintAnnotations(block_op->annotations); block_attr_doc << "})"; } @@ -949,15 +965,15 @@ Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { Doc body; for (const auto& alloc_buf : op->alloc_buffers) { buf_not_in_headers_.insert(alloc_buf.get()); - body << Print(alloc_buf) << " = tir.alloc_buffer(" << memo_buf_decl_[alloc_buf] << ")" - << Doc::NewLine(); + body << Print(alloc_buf) << " = " << tir_prefix_ << ".alloc_buffer(" + << memo_buf_decl_[alloc_buf] << ")" << Doc::NewLine(); } for (const auto& match_buf : op->match_buffers) { body << Print(match_buf) << Doc::NewLine(); } if (op->init.defined()) { Doc init_block; - init_block << "with tir.init():"; + init_block << "with " << tir_prefix_ << ".init():"; init_block << Doc::Indent(4, Doc::NewLine() << PrintBody(op->init.value())); body << 
init_block << Doc::NewLine(); } @@ -1010,6 +1026,7 @@ Doc TVMScriptPrinter::PrintBody(const Stmt& body) { Doc TVMScriptPrinter::PrintIRModule(const IRModule& module) { auto* op = module.operator->(); Doc doc; + doc << "@tvm.script.ir_module" << Doc::NewLine(); doc << "class Module:"; for (const auto& x : op->functions) { func2var_[x.second.operator->()] = x.first; @@ -1038,6 +1055,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { buf_not_in_headers_.clear(); // print signature Doc doc; + doc << "@" << tir_prefix_ << ".prim_func" << Doc::NewLine(); doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint) << "("; std::vector params; @@ -1053,13 +1071,13 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { auto it = op->buffer_map.find(param); if (it == op->buffer_map.end()) continue; buf_not_in_headers_.insert((*it).second.get()); - body << Print((*it).second) << " = tir.match_buffer("; + body << Print((*it).second) << " = " << tir_prefix_ << ".match_buffer("; body << Print((*it).first) << ", " << memo_buf_decl_[(*it).second]; body << ")" << Doc::NewLine(); } // print comm_reducer for (const auto& it : memo_reducer_) { - body << it.second << " = tir.comm_reducer("; + body << it.second << " = " << tir_prefix_ << ".comm_reducer("; var_not_in_headers_.insert(it.first->lhs[0].get()); var_not_in_headers_.insert(it.first->rhs[0].get()); body << "lambda " << Print(it.first->lhs[0]) << ", " << Print(it.first->rhs[0]) << ": " @@ -1071,7 +1089,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { if (op->body->IsInstance() && op->body.as()->iter_values.empty()) { // Skip print root block - body << "# tir.with block(\"root\")" << Doc::NewLine(); + body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine(); const BlockNode* block = op->body.as()->block.get(); body << PrintBlockBody(block); } else { @@ -1080,7 +1098,8 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { // print func attrs
Doc header_attr; if (primFunc->attrs.defined()) { - header_attr << Doc::NewLine() << "# function attr dict" << Doc::NewLine() << "tir.func_attr({"; + header_attr << Doc::NewLine() << "# function attr dict" << Doc::NewLine() << tir_prefix_ + << ".func_attr({"; std::vector attrs; for (const auto& it : op->attrs->dict) { attrs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); @@ -1101,7 +1120,8 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { return memo_buf_[GetRef(a)].str() < memo_buf_[GetRef(b)].str(); }); for (const auto& buf : bufs) { - header_buf << Doc::NewLine() << Print(GetRef(buf)) << " = tir.buffer_decl("; + header_buf << Doc::NewLine() << Print(GetRef(buf)) << " = " << tir_prefix_ + << ".buffer_decl("; header_buf << memo_buf_decl_[GetRef(buf)] << ")"; } } @@ -1116,7 +1136,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { if (!var_env_map_.empty()) { header_var << Doc::NewLine() << "# var definition"; for (const auto& it : var_env_map_) { - header_var << Doc::NewLine() << Print(it.first) << " = tir.env_thread(" + header_var << Doc::NewLine() << Print(it.first) << " = " << tir_prefix_ << ".env_thread(" << Doc::StrLiteral(it.second) << ")"; } } @@ -1129,11 +1149,12 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { if (auto* ptr_type = type.as()) { auto* prim_type = ptr_type->element_type.as(); ICHECK(prim_type); - header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.buffer_var("; + header_var << Doc::NewLine() << Print(GetRef(var)) << " = " << tir_prefix_ + << ".buffer_var("; header_var << PrintDType(prim_type->dtype) << ", " << Doc::StrLiteral(ptr_type->storage_scope) << ")"; } else { - header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; + header_var << Doc::NewLine() << Print(GetRef(var)) << " = " << tir_prefix_ << ".var("; header_var << PrintDType(var->dtype) << ")"; } } @@ -1157,7 +1178,7 @@ Doc TVMScriptPrinter::PrintArray(const ArrayNode* op) { Doc 
TVMScriptPrinter::PrintIterVar(const IterVarNode* op) { Doc doc; - doc << "tir.iter_var(" << Print(op->var); + doc << tir_prefix_ << ".iter_var(" << Print(op->var); if (op->dom.defined()) { doc << ", [" << Print(op->dom) << "], "; } else { @@ -1216,15 +1237,15 @@ Doc TVMScriptPrinter::PrintAnnotations(const Map& annotations Doc TVMScriptPrinter::PrintLoop(const For& loop) { Doc res; - res << "for " << Print(loop->loop_var) - << " in tir." + std::string(ForKind2String(loop->kind)) + "(" << Print(loop->min) << ", " + res << "for " << Print(loop->loop_var) << " in " << tir_prefix_ + << "." + std::string(ForKind2String(loop->kind)) + "(" << Print(loop->min) << ", " << Print(loop->min + loop->extent); if (loop->thread_binding.defined()) { - res << ", thread = "; + res << ", thread="; res << Print(loop->thread_binding.value()->thread_tag); } if (!loop->annotations.empty()) { - res << ", annotations = {"; + res << ", annotations={"; res << PrintAnnotations(loop->annotations); res << "}"; } @@ -1242,15 +1263,15 @@ Doc TVMScriptPrinter::PrintLoopStack() { vars.push_back(Print(loop->loop_var)); extents.push_back(Print(loop->extent)); } - res << "for " << PrintSep(vars, Doc::Text(", ")) << " in tir.grid(" + res << "for " << PrintSep(vars, Doc::Text(", ")) << " in " << tir_prefix_ << ".grid(" << PrintSep(extents, Doc::Text(", ")) << "):"; } return res; } -String AsTVMScript(const ObjectRef& mod, bool show_meta) { +String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { ICHECK(mod->IsInstance() || mod->IsInstance()); - return "@tvm.script.tir\n" + TVMScriptPrinter(show_meta).Print(mod).str() + "\n"; + return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n"; } TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript); diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 768787735a1fe..1e5b2f28b2d96 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -267,14 +267,14 @@ class 
PrimFuncSpecializer : public StmtExprMutator { * \param var_map The var mapping to be updated. * \note This function will match target buffer's shape, strides and element_offset * For example, we define a buffer in PrimFunc: - * A = tir.match_buffer(a, [m, n]) + * A = T.match_buffer(a, [m, n]) * * Then we match it with a buffer B = tir.decl_buffer((8, 16)) * * It means we have two var mappings here: m = 8 and n = 16 * * If the buffer signature is not a Var, the mapping will fail. - * e.g. A = tir.match_buffer(a, [m * 2, n + 1]) + * e.g. A = T.match_buffer(a, [m * 2, n + 1]) */ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf, VarMap* var_map) { diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index c2de78863d791..539a82f9ae5c4 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -290,9 +290,9 @@ class BaseInliner : public StmtExprMutator { /*! 
* \brief Update the following block signature: - * 1) tir.alloc_buffer, if the block is scope root - * 2) tir.reads, if the block is not scope root - * 3) tir.writes, if the block is not scope root + * 1) T.alloc_buffer, if the block is scope root + * 2) T.reads, if the block is not scope root + * 3) T.writes, if the block is not scope root * \param block The block to be updated * \param is_scope_root A flag indicating if a block is the scope root of the block to be inlined * \return The updated block diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index eb3a4d8cb4dad..32dde21230aaf 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -18,10 +18,8 @@ pytest.importorskip("ethosu.vela") import tvm -from tvm import tir -from tvm import script from tvm import relay -from tvm.script import ty +from tvm.script import tir as T from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute @@ -30,36 +28,37 @@ # fmt: off -@tvm.script.tir +@tvm.script.ir_module class WeightStreamOnly: - def main(placeholder: ty.handle, ethosu_write: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = 
tir.match_buffer(placeholder_7, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = tir.match_buffer(placeholder_2, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = tir.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_4 = tir.match_buffer(placeholder_5, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_9 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_5 = tir.match_buffer(placeholder_3, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_6 = tir.match_buffer(placeholder_1, [128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_7 = tir.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_7, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = T.match_buffer(placeholder_2, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = T.match_buffer(placeholder_5, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_9 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_5 = T.match_buffer(placeholder_3, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = 
T.match_buffer(placeholder_1, [128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = T.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = tir.allocate([128], "uint8", "global") - placeholder_d_global = tir.allocate([32], "uint8", "global") - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_6.data, 0), 128, tir.load("uint8", placeholder_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 128, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_5.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - 
tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_7.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + placeholder_global = T.allocate([128], "uint8", "global") + placeholder_d_global = T.allocate([32], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6.data, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", 
placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5.data, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4.data, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", 
T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_9.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -93,7 +92,7 @@ def _get_func(): func = _get_func() mod, consts = lower_to_tir(func, cascader=_planner) - script = tvm.script.asscript(mod, True) + script = mod.script(True) test_mod = tvm.script.from_source(script) reference_mod = WeightStreamOnly() tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) @@ -107,21 +106,22 @@ def _get_func(): # fmt: off -@tvm.script.tir +@tvm.script.ir_module class DirectReadOnly: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_3, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = tir.match_buffer(placeholder_2, 
[160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_3, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = tir.allocate([4096], "int8", "global") - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 592, 12, tir.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 160, 12, tir.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + ethosu_write_2 = T.allocate([4096], "int8", "global") + T.evaluate(T.call_extern("ethosu_conv2d", 
"int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 592, 12, T.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 160, 12, T.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -154,7 +154,7 @@ def _get_func(): func = _get_func() mod, consts = lower_to_tir(func) - script = tvm.script.asscript(mod, True) + script = mod.script(True) test_mod = tvm.script.from_source(script) reference_mod = DirectReadOnly() tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) @@ -168,40 +168,41 @@ def _get_func(): # fmt: off -@tvm.script.tir +@tvm.script.ir_module class MixedRead: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle, placeholder_9: ty.handle, placeholder_10: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, placeholder_7: T.handle, placeholder_8: T.handle, placeholder_9: T.handle, placeholder_10: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, 
"global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = tir.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = tir.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_4 = tir.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_5 = tir.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_11 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_6 = tir.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_7 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_8 = tir.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_9 = tir.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = T.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = T.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = T.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, 
offset_factor=1) + buffer_5 = T.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_11 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = T.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_8 = T.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_9 = T.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = tir.allocate([4096], "int8", "global") - placeholder_global = tir.allocate([80], "uint8", "global") - placeholder_d_global = tir.allocate([32], "uint8", "global") - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_11.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_6.data, 0), 592, 12, tir.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, 
"NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_5.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_8.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_9.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - 
tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + ethosu_write_2 = T.allocate([4096], "int8", "global") + placeholder_global = T.allocate([80], "uint8", "global") + placeholder_d_global = T.allocate([32], "uint8", "global") + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_11.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_6.data, 0), 592, 12, T.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5.data, 0), 
32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4.data, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9.data, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write_1.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -245,7 +246,7 @@ def 
_get_func(): func = _get_func() mod, consts = lower_to_tir(func, cascader=_planner) - script = tvm.script.asscript(mod, True) + script = mod.script(True) test_mod = tvm.script.from_source(script) reference_mod = MixedRead() tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index f66b21b92a03a..b357b5e38e556 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -18,8 +18,7 @@ pytest.importorskip("ethosu.vela") import tvm -import tvm.script -from tvm.script import tir, ty +from tvm.script import tir as T from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir @@ -190,85 +189,89 @@ def _visit(stmt): # fmt: off -@tvm.script.tir +@tvm.script.ir_module class Conv2dDoubleCascade1: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_3, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", 
elem_offset=0, align=128, offset_factor=1) - buffer_3 = tir.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_3, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = T.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = T.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = tir.allocate([1024], "int8", "global") - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 160, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 304, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, tir.load("int8", placeholder_5.data, 12), 0, 0, 0, tir.float32(0.5), 10, 
"NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 160, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 32), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 304, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + ethosu_write_2 = T.allocate([1024], "int8", "global") + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 160, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 0, 
0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 32), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 304, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None -@tvm.script.tir +@tvm.script.ir_module class Conv2dDoubleCascade2: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = tir.match_buffer(placeholder_1, [1312], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = tir.match_buffer(placeholder_3, [2608], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, 
offset_factor=1) + buffer_2 = T.match_buffer(placeholder_1, [1312], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = T.match_buffer(placeholder_3, [2608], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = tir.allocate([1536], "int8", "global") - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_2.data, 0), 1312, 12, tir.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 256), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 2608, 12, tir.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 48), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_2.data, 0), 1312, 12, tir.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, 
"NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 2608, 12, tir.load("uint8", buffer.data, 0), 80, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + ethosu_write_2 = T.allocate([1536], "int8", "global") + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2.data, 0), 1312, 12, T.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 2608, 12, T.load("uint8", buffer.data, 0), 80, 0, 1, 1, 
1, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None -@tvm.script.tir +@tvm.script.ir_module class Conv2dDoubleCascade3: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 20, 4, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = tir.match_buffer(placeholder_3, [1744], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = tir.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = tir.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 20, 4, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.match_buffer(placeholder_3, [1744], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = T.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = T.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder, [1, 16, 16, 3], 
dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = tir.allocate([2560], "int8", "global") - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, tir.load("int8", ethosu_write_2, 512), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, tir.load("int8", ethosu_write_2, 512), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, tir.load("int8", placeholder_5.data, 192), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, tir.load("int8", placeholder_5.data, 
576), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, tir.load("int8", ethosu_write_1.data, 512), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "NONE", dtype="handle")) + ethosu_write_2 = T.allocate([2560], "int8", "global") + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 
880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, T.load("int8", placeholder_5.data, 576), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3.data, 0), 880, 12, T.load("uint8", buffer_2.data, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, T.load("int8", ethosu_write_1.data, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer.data, 0), 1744, 12, T.load("uint8", buffer_1.data, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None -@tvm.script.tir +@tvm.script.ir_module class Conv2dDoubleCascade4: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_1, [1456], dtype="uint8", 
elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_2, [352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_5 = tir.match_buffer(placeholder, [1, 8, 1, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 2, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_2 = tir.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_3 = tir.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_1, [1456], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_2, [352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder, [1, 8, 1, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 2, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = T.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = T.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - ethosu_write_2 = tir.allocate([2304], "int8", "global") - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 384), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 1456, 12, tir.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", 
ethosu_write_2, 384), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 11040, 12, tir.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 256), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 1456, 12, tir.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 11040, 12, tir.load("uint8", buffer_2.data, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + ethosu_write_2 = T.allocate([2304], "int8", "global") + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, 
T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 1456, 12, T.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3.data, 0), 11040, 12, T.load("uint8", buffer_2.data, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -375,37 +378,39 @@ def _get_func( params = trial[1:] func = _get_func(*params[:-1]) mod, _ = lower_to_tir(func, cascader=total_cascader(params[-1])) - script = tvm.script.asscript(mod, True) + script = mod.script(True) mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) # fmt: off -@tvm.script.tir +@tvm.script.ir_module class Conv2dInlineCopy1: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - 
placeholder_3 = tir.match_buffer(placeholder, [1, 10, 12, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [1, 10, 12, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, tir.load("int8", placeholder_3.data, 120), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 848, 12, tir.load("uint8", buffer_1.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, T.load("int8", placeholder_3.data, 120), 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer.data, 0), 848, 12, T.load("uint8", buffer_1.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None -@tvm.script.tir +@tvm.script.ir_module class Conv2dInlineCopy2: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + 
def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 3, 5, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = tir.match_buffer(placeholder, [1, 7, 9, 5], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_1, [656], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 3, 5, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [1, 7, 9, 5], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_1, [656], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, tir.load("int8", placeholder_3.data, 146), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 656, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, T.load("int8", placeholder_3.data, 146), 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", 
buffer_1.data, 0), 656, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -432,69 +437,73 @@ def _get_func(ifm_shape, lower, upper, ofm_channels=16): params = trial[1:] func = _get_func(*params) mod, _ = lower_to_tir(func) - script = tvm.script.asscript(mod, True) + script = mod.script(True) mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) # fmt: off -@tvm.script.tir +@tvm.script.ir_module class Conv2dInlineReshape1: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = tir.match_buffer(placeholder, [4, 6, 8, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [4, 6, 8, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 
6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None -@tvm.script.tir +@tvm.script.ir_module class Conv2dInlineReshape2: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: 
T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = tir.match_buffer(placeholder, [1, 24, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [1, 24, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 
0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None -@tvm.script.tir +@tvm.script.ir_module class Conv2dInlineReshape3: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = tir.match_buffer(placeholder, [192, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + 
placeholder_3 = T.match_buffer(placeholder, [192, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, 
T.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None -@tvm.script.tir +@tvm.script.ir_module class Conv2dInlineReshape4: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = tir.match_buffer(placeholder, [192], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = T.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [192], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) # body - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) - 
tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1.data, 0), 848, 12, T.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -523,7 +532,7 @@ def _get_func(ifm_shape, reshaped, ifm_layout): params = trial[1:] func = _get_func(*params) mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16))) - script = tvm.script.asscript(mod, True) + script = mod.script(True) mod = tvm.script.from_source(script) tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 2d76cd654690d..cc69815f1abf3 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -18,8 +18,7 @@ 
pytest.importorskip("ethosu.vela") import tvm -import tvm.script -from tvm.script import tir, ty +from tvm.script import tir as T from tvm import relay from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir @@ -29,21 +28,22 @@ # fmt: off -@tvm.script.tir +@tvm.script.ir_module class ReferenceModule: - def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: # function attr dict - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - buffer = tir.match_buffer(placeholder_2, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - placeholder_3 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) - buffer_1 = tir.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) - ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = T.match_buffer(placeholder_2, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = T.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) # body - placeholder_global = tir.allocate([304], "uint8", "global") - placeholder_d_global = tir.allocate([80], "uint8", "global") - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 304, tir.load("uint8", placeholder_global, 
0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 80, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) - tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 304, 12, tir.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + placeholder_global = T.allocate([304], "uint8", "global") + placeholder_d_global = T.allocate([80], "uint8", "global") + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -67,7 +67,7 @@ def _get_func(): func = _get_func() mod, _ = lower_to_tir(func, cascader=copy_constants()) - script = tvm.script.asscript(mod, True) + script = mod.script(True) test_mod = tvm.script.from_source(script) reference_mod = ReferenceModule() tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index a86dd919d5caf..5a71dc4dc0ca2 100644 --- 
a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -22,8 +22,7 @@ from unittest.mock import patch import tvm -from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir import stmt_functor from tvm.relay.backend.contrib.ethosu import vela_api import tvm.relay.backend.contrib.ethosu.tir_to_cs_translator as tirtocs @@ -39,31 +38,32 @@ """Test case 1""" -@tvm.script.tir +@tvm.script.ir_module class Module1: + @T.prim_func def main( - placeholder: ty.handle, - placeholder_1: ty.handle, - placeholder_2: ty.handle, - ethosu_conv2d: ty.handle, + placeholder: T.handle, + placeholder_1: T.handle, + placeholder_2: T.handle, + ethosu_conv2d: T.handle, ) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_3 = tir.match_buffer( + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_3 = T.match_buffer( placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) - placeholder_4 = tir.match_buffer( + placeholder_4 = T.match_buffer( placeholder_1, [48], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) - placeholder_5 = tir.match_buffer( + placeholder_5 = T.match_buffer( placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1 ) - ethosu_conv2d_1 = tir.match_buffer( + ethosu_conv2d_1 = T.match_buffer( ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) # body - tir.evaluate( - tir.call_extern( + T.evaluate( + T.call_extern( "ethosu_conv2d", "uint8", 8, @@ -72,11 +72,11 @@ def main( 8, 0, 8, - tir.load("uint8", placeholder_3.data, 0), + T.load("uint8", placeholder_3.data, 0), 0, 0, 0, - tir.float32(0.5), + T.float32(0.5), 10, "NHWC", 24, @@ -89,11 +89,11 @@ def main( 8, 0, 8, - tir.load("uint8", ethosu_conv2d_1.data, 0), + T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, - tir.float32(0.25), + T.float32(0.25), 14, "NHWC", 
128, @@ -105,10 +105,10 @@ def main( 1, 1, 1, - tir.load("uint8", placeholder_4.data, 0), + T.load("uint8", placeholder_4.data, 0), 0, 12, - tir.load("uint8", placeholder_5.data, 0), + T.load("uint8", placeholder_5.data, 0), 0, 0, 0, @@ -128,36 +128,37 @@ def main( """Test case 2 with per-channel quantization""" -@tvm.script.tir +@tvm.script.ir_module class Module2: + @T.prim_func def main( - placeholder: ty.handle, - placeholder_1: ty.handle, - placeholder_2: ty.handle, - placeholder_6: ty.handle, - ethosu_conv2d: ty.handle, + placeholder: T.handle, + placeholder_1: T.handle, + placeholder_2: T.handle, + placeholder_6: T.handle, + ethosu_conv2d: T.handle, ) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - placeholder_3 = tir.match_buffer( + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_3 = T.match_buffer( placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) - placeholder_4 = tir.match_buffer( + placeholder_4 = T.match_buffer( placeholder_1, [16, 1, 1, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) - placeholder_5 = tir.match_buffer( + placeholder_5 = T.match_buffer( placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1 ) # Per-channel weight scales - placeholder_7 = tir.match_buffer( + placeholder_7 = T.match_buffer( placeholder_6, [16], dtype="float32", elem_offset=0, align=128, offset_factor=1 ) - ethosu_conv2d_1 = tir.match_buffer( + ethosu_conv2d_1 = T.match_buffer( ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1 ) # body - tir.evaluate( - tir.call_extern( + T.evaluate( + T.call_extern( "ethosu_conv2d", "uint8", 8, @@ -166,11 +167,11 @@ def main( 8, 0, 8, - tir.load("uint8", placeholder_3.data, 0), + T.load("uint8", placeholder_3.data, 0), 0, 0, 0, - tir.float32(0.5), + T.float32(0.5), 10, "NHWC", 24, @@ -183,11 +184,11 @@ def main( 8, 0, 8, - tir.load("uint8", 
ethosu_conv2d_1.data, 0), + T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, - tir.float32(0.25), + T.float32(0.25), 14, "NHWC", 128, @@ -199,10 +200,10 @@ def main( 1, 1, 1, - tir.load("uint8", placeholder_4.data, 0), + T.load("uint8", placeholder_4.data, 0), 0, 12, - tir.load("uint8", placeholder_5.data, 0), + T.load("uint8", placeholder_5.data, 0), 0, 0, 0, @@ -478,7 +479,7 @@ def extract_ethosu_conv2d_extern_calls(mod): def populate_ethosu_conv2d_calls(stmt): if ( isinstance(stmt, tvm.tir.Call) - and stmt.op.name == "tir.call_extern" + and stmt.op.name == "T.call_extern" and stmt.args[0] == "ethosu_conv2d" ): ethosu_conv2d_calls.append(stmt) diff --git a/tests/python/integration/test_lower.py b/tests/python/integration/test_lower.py index 3fa4795870d59..690258c2fa3b4 100644 --- a/tests/python/integration/test_lower.py +++ b/tests/python/integration/test_lower.py @@ -18,38 +18,38 @@ """Test workload for lowering and build""" import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T import tvm.testing import numpy as np -@tvm.script.tir -def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: +@T.prim_func +def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None: # match buffer - A = tir.match_buffer(a, [1024, 1024], "float16") - B = tir.match_buffer(b, [1024, 1024], "float16") - C = tir.match_buffer(c, [1024, 1024], "float32") + A = T.match_buffer(a, [1024, 1024], "float16") + B = T.match_buffer(b, [1024, 1024], "float16") + C = T.match_buffer(c, [1024, 1024], "float32") # body - for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"): - for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"): - with tir.block([16, 8]) as [bx, by]: - tir.bind(bx, blockIdx_x) - tir.bind(by, blockIdx_y) - shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared") - shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared") - wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a") 
- wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b") - wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator") - for ty in tir.thread_binding(0, 2, "threadIdx.y"): - for tz in tir.thread_binding(0, 2, "threadIdx.z"): - for i, j in tir.grid(2, 4): - with tir.block([64, 64]) as [vi, vj]: - tir.bind(vi, bx * 4 + ty * 2 + i) - tir.bind(vj, by * 8 + tz * 4 + j) - tir.reads([]) - tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - C0 = tir.match_buffer( + for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"): + for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"): + with T.block([16, 8]) as [bx, by]: + T.bind(bx, blockIdx_x) + T.bind(by, blockIdx_y) + shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared") + shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared") + wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a") + wmma_B = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b") + wmma_C = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator") + for ty in T.thread_binding(0, 2, "threadIdx.y"): + for tz in T.thread_binding(0, 2, "threadIdx.z"): + for i, j in T.grid(2, 4): + with T.block([64, 64]) as [vi, vj]: + T.bind(vi, bx * 4 + ty * 2 + i) + T.bind(vj, by * 8 + tz * 4 + j) + T.reads([]) + T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + C0 = T.match_buffer( wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16), "float32", @@ -57,52 +57,52 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="wmma.accumulator", offset_factor=1, ) - tir.evaluate( - tir.tvm_fill_fragment( + T.evaluate( + T.tvm_fill_fragment( C0.data, 16, 16, 16, i * 4 + j, - tir.float32(0), + T.float32(0), dtype="handle", ) ) for ko in range(0, 32): # copy data from global to shared - for tx in tir.thread_binding(0, 32, "threadIdx.x"): - for i0, j0 in tir.grid(1, 4): - for j1 in tir.vectorized(0, 4): - with tir.block([1024, 
1024]) as [vi, vj]: - tir.bind(vi, bx * 64 + ty * 32 + tx + i0) - tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + for tx in T.thread_binding(0, 32, "threadIdx.x"): + for i0, j0 in T.grid(1, 4): + for j1 in T.vectorized(0, 4): + with T.block([1024, 1024]) as [vi, vj]: + T.bind(vi, bx * 64 + ty * 32 + tx + i0) + T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) shared_A[vi, vj + 8] = A[vi, vj] - for i0, j0 in tir.grid(2, 4): - for j1 in tir.vectorized(0, 4): - with tir.block([1024, 1024]) as [vi, vj]: - tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0) - tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) + for i0, j0 in T.grid(2, 4): + for j1 in T.vectorized(0, 4): + with T.block([1024, 1024]) as [vi, vj]: + T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0) + T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1) shared_B[vi, vj + 8] = B[vi, vj] for ki in range(0, 2): for i in range(0, 2): - with tir.block([64, 64]) as [vi, vk]: - tir.bind(vi, bx * 4 + ty * 2 + i) - tir.bind(vk, ko * 2 + ki) - tir.reads( + with T.block([64, 64]) as [vi, vk]: + T.bind(vi, bx * 4 + ty * 2 + i) + T.bind(vk, ko * 2 + ki) + T.reads( shared_A[ vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16 + 8, ] ) - tir.writes( + T.writes( wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16] ) - s0 = tir.var("int32") - s1 = tir.var("int32") - A0 = tir.match_buffer( + s0 = T.var("int32") + s1 = T.var("int32") + A0 = T.match_buffer( shared_A[ vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16 + 8, @@ -113,7 +113,7 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="shared", offset_factor=1, ) - wmma_A0 = tir.match_buffer( + wmma_A0 = T.match_buffer( wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], (16, 16), "float16", @@ -121,15 +121,15 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="wmma.matrix_a", offset_factor=1, ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( wmma_A0.data, 16, 16, 16, i, - tir.tvm_access_ptr( - 
tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), A0.data, A0.elem_offset + 8, A0.strides[0], @@ -142,21 +142,21 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: ) ) for j in range(0, 4): - with tir.block([64, 64]) as [vj, vk]: - tir.bind(vj, by * 8 + tz * 4 + j) - tir.bind(vk, ko * 2 + ki) - tir.reads( + with T.block([64, 64]) as [vj, vk]: + T.bind(vj, by * 8 + tz * 4 + j) + T.bind(vk, ko * 2 + ki) + T.reads( shared_B[ vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16 + 8, ] ) - tir.writes( + T.writes( wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16] ) - s0 = tir.var("int32") - s1 = tir.var("int32") - B0 = tir.match_buffer( + s0 = T.var("int32") + s1 = T.var("int32") + B0 = T.match_buffer( shared_B[ vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16 + 8, @@ -167,7 +167,7 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="shared", offset_factor=1, ) - wmma_B0 = tir.match_buffer( + wmma_B0 = T.match_buffer( wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], (16, 16), "float16", @@ -175,15 +175,15 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="wmma.matrix_b", offset_factor=1, ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( wmma_B0.data, 16, 16, 16, j, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), B0.data, B0.elem_offset + 8, B0.strides[0], @@ -195,16 +195,16 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: dtype="handle", ) ) - for i, j in tir.grid(2, 4): - with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [ + for i, j in T.grid(2, 4): + with T.block([64, 64, T.reduce_axis(0, 64)]) as [ vi, vj, vk, ]: - tir.bind(vi, bx * 4 + ty * 2 + i) - tir.bind(vj, by * 8 + tz * 4 + j) - tir.bind(vk, ko * 2 + ki) - tir.reads( + T.bind(vi, bx * 4 + ty * 2 + i) + T.bind(vj, by * 8 + tz * 4 + j) + T.bind(vk, ko 
* 2 + ki) + T.reads( [ wmma_A[ vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16 @@ -217,10 +217,10 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: ], ] ) - tir.writes( + T.writes( wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16] ) - wmma_A1 = tir.match_buffer( + wmma_A1 = T.match_buffer( wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16], (16, 16), "float16", @@ -228,7 +228,7 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="wmma.matrix_a", offset_factor=1, ) - wmma_B1 = tir.match_buffer( + wmma_B1 = T.match_buffer( wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16], (16, 16), "float16", @@ -236,7 +236,7 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="wmma.matrix_b", offset_factor=1, ) - wmma_C1 = tir.match_buffer( + wmma_C1 = T.match_buffer( wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16), "float32", @@ -244,8 +244,8 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="wmma.accumulator", offset_factor=1, ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( wmma_C1.data, i * 4 + j, wmma_A1.data, @@ -257,15 +257,15 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: dtype="handle", ) ) - for i, j in tir.grid(2, 4): - with tir.block([64, 64]) as [vi, vj]: - tir.bind(vi, bx * 4 + ty * 2 + i) - tir.bind(vj, by * 8 + tz * 4 + j) - tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - s0 = tir.var("int32") - s1 = tir.var("int32") - wmma_C2 = tir.match_buffer( + for i, j in T.grid(2, 4): + with T.block([64, 64]) as [vi, vj]: + T.bind(vi, bx * 4 + ty * 2 + i) + T.bind(vj, by * 8 + tz * 4 + j) + T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + s0 = T.var("int32") + s1 = T.var("int32") + wmma_C2 = T.match_buffer( wmma_C[vi * 16 : vi * 16 
+ 16, vj * 16 : vj * 16 + 16], (16, 16), "float32", @@ -273,22 +273,22 @@ def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None: scope="wmma.accumulator", offset_factor=1, ) - C1 = tir.match_buffer( + C1 = T.match_buffer( C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16), "float32", strides=[s0, s1], offset_factor=1, ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( wmma_C2.data, 16, 16, 16, i * 4 + j, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), C1.data, C1.elem_offset, C1.strides[0], diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py index 626af0c96633c..222d647f4ea71 100644 --- a/tests/python/unittest/test_aot_legalize_packed_call.py +++ b/tests/python/unittest/test_aot_legalize_packed_call.py @@ -16,22 +16,22 @@ # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring import tvm -from tvm.script import ty -from tvm import te, tir -import numpy as np +from tvm.script import tir as T +from tvm import tir import tvm.testing import pytest -@tvm.script.tir +@tvm.script.ir_module class Module: + @T.prim_func def tir_packed_call() -> None: - A = tir.var("handle") - B = tir.var("handle") - C = tir.var("handle") + A = T.var("handle") + B = T.var("handle") + C = T.var("handle") # body - tir.evaluate( - tir.tvm_call_cpacked( + T.evaluate( + T.tvm_call_cpacked( "tvm_test_cpacked", A, B, @@ -41,25 +41,26 @@ def tir_packed_call() -> None: ) -@tvm.script.tir +@tvm.script.ir_module class Expected: + @T.prim_func def tir_packed_call() -> None: - A = tir.var("handle") - B = tir.var("handle") - C = tir.var("handle") + A = T.var("handle") + B = T.var("handle") + C = T.var("handle") # body - tvm_value_2 = tir.var("handle") - tvm_value_1 = tir.var("handle") - tvm_value_0 = tir.var("handle") - with tir.let(tvm_value_2, 
tir.tvm_stack_alloca("array", 1, dtype="handle")): - with tir.let(tvm_value_1, tir.tvm_stack_alloca("array", 1, dtype="handle")): - with tir.let(tvm_value_0, tir.tvm_stack_alloca("array", 1, dtype="handle")): - tir.evaluate(tir.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle")) - tir.evaluate(tir.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle")) - tir.evaluate(tir.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle")) - tir.evaluate( - tir.tvm_call_cpacked( + tvm_value_2 = T.var("handle") + tvm_value_1 = T.var("handle") + tvm_value_0 = T.var("handle") + with T.let(tvm_value_2, T.tvm_stack_alloca("array", 1, dtype="handle")): + with T.let(tvm_value_1, T.tvm_stack_alloca("array", 1, dtype="handle")): + with T.let(tvm_value_0, T.tvm_stack_alloca("array", 1, dtype="handle")): + T.evaluate(T.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle")) + T.evaluate(T.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle")) + T.evaluate( + T.tvm_call_cpacked( "tvm_test_cpacked", tvm_value_0, tvm_value_1, @@ -70,8 +71,8 @@ def tir_packed_call() -> None: def test_aot_packed_call(): - mod = Module() - expected = Expected() + mod = Module + expected = Expected out = tir.transform.LegalizePackedCalls()(mod) tvm.ir.assert_structural_equal(expected, out, map_free_vars=True) diff --git a/tests/python/unittest/test_lower_build.py b/tests/python/unittest/test_lower_build.py index e5528a8c4756c..6502f0c67de62 100644 --- a/tests/python/unittest/test_lower_build.py +++ b/tests/python/unittest/test_lower_build.py @@ -18,9 +18,9 @@ import numpy as np import tvm -from tvm import te, tir +from tvm import te from tvm.ir.module import IRModule -from tvm.script import ty +from tvm.script import tir as T import tvm.testing @@ -35,53 +35,53 @@ def _check_module_with_numpy(mod, shape=(128, 128, 128)): # pylint: disable=no-self-argument, missing-class-docstring, missing-function-docstring -@tvm.script.tir -def matmul(a: ty.handle, 
b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = tir.float32(0) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j in T.grid(128, 128): + with T.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = T.float32(0) for k in range(0, 128): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir +@tvm.script.ir_module class LoweredModule: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # function attr dict - tir.func_attr( - {"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True} - ) - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) + T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) # body - for x, y in tir.grid(128, 128): + for x, y in T.grid(128, 128): C.data[x * 128 + y] = 0.0 - for k in tir.serial(0, 128): - C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load( + for k in T.serial(0, 128): + C.data[x * 128 + y] = T.load("float32", C.data, x * 128 + y) + T.load( "float32", A.data, x * 128 + k - ) * tir.load("float32", B.data, y * 128 + k) + ) * T.load("float32", B.data, y * 128 + k) -@tvm.script.tir +@tvm.script.ir_module class LoweredTIRModule: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + @T.prim_func + def 
main(a: T.handle, b: T.handle, c: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) # body - for x, y in tir.grid(128, 128): + for x, y in T.grid(128, 128): C.data[x * 128 + y] = 0.0 - for k in tir.serial(0, 128): - C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load( + for k in T.serial(0, 128): + C.data[x * 128 + y] = T.load("float32", C.data, x * 128 + y) + T.load( "float32", A.data, x * 128 + k - ) * tir.load("float32", B.data, y * 128 + k) + ) * T.load("float32", B.data, y * 128 + k) def test_lower_build_te_schedule(): @@ -93,7 +93,7 @@ def test_lower_build_te_schedule(): s = te.create_schedule(C.op) # check lowering ir_mod = tvm.lower(s, [A, B, C]) - tvm.ir.assert_structural_equal(ir_mod, LoweredModule()) + tvm.ir.assert_structural_equal(ir_mod, LoweredModule) # check building mod = tvm.build(s, [A, B, C], target="llvm") _check_module_with_numpy(mod) @@ -102,7 +102,7 @@ def test_lower_build_te_schedule(): def test_lower_build_tir_func(): # check lowering ir_mod = tvm.lower(matmul) - tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule()) + tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule) # check building mod = tvm.build(matmul, target="llvm") _check_module_with_numpy(mod) @@ -114,7 +114,7 @@ def test_lower_build_tir_module(): ir_mod = IRModule({"main": func}) # check lowering lowered_mod = tvm.lower(ir_mod) - tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule()) + tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule) # check building mod = tvm.build(ir_mod, target="llvm") _check_module_with_numpy(mod) @@ -122,8 +122,8 @@ def test_lower_build_tir_module(): def 
test_lower_build_lowered_module(): # check lowering - ir_mod = tvm.lower(LoweredTIRModule()) - tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule()) + ir_mod = tvm.lower(LoweredTIRModule) + tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule) # check building mod = tvm.build(ir_mod, target="llvm") _check_module_with_numpy(mod) diff --git a/tests/python/unittest/test_meta_schedule_arg_info.py b/tests/python/unittest/test_meta_schedule_arg_info.py index 51ec9ea87ed3f..7bedea9082d14 100644 --- a/tests/python/unittest/test_meta_schedule_arg_info.py +++ b/tests/python/unittest/test_meta_schedule_arg_info.py @@ -16,22 +16,20 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -import tvm -from tvm import tir from tvm.meta_schedule.arg_info import ArgInfo, TensorInfo -from tvm.script import ty +from tvm.script import tir as T # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off -@tvm.script.tir -def Matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - tir.func_attr({"global_symbol": "main"}) - A = tir.match_buffer(a, (128, 256), "float32") - B = tir.match_buffer(b, (256, 512), "float32") - C = tir.match_buffer(c, (128, 512), "float32") - with tir.block([128, 256, tir.reduce_axis(0, 512)], "matmul") as [vi, vj, vk]: - with tir.init(): +@T.prim_func +def Matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (128, 256), "float32") + B = T.match_buffer(b, (256, 512), "float32") + C = T.match_buffer(c, (128, 512), "float32") + with T.block([128, 256, T.reduce_axis(0, 512)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] diff --git a/tests/python/unittest/test_meta_schedule_builder.py b/tests/python/unittest/test_meta_schedule_builder.py index f97ede881330e..fa09a092c8c46 100644 --- 
a/tests/python/unittest/test_meta_schedule_builder.py +++ b/tests/python/unittest/test_meta_schedule_builder.py @@ -23,7 +23,7 @@ import pytest -from tvm import tir, script +from tvm import script from tvm._ffi import register_func from tvm.meta_schedule.builder import ( BuilderInput, @@ -32,57 +32,58 @@ PyBuilder, ) from tvm.runtime import Module -from tvm.script import ty +from tvm.script import tir as T from tvm.target import Target # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring -@script.tir +@script.ir_module class MatmulModule: - def matmul( # pylint: disable=no-self-argument - a: ty.handle, b: ty.handle, c: ty.handle - ) -> None: - tir.func_attr({"global_symbol": "matmul", "tir.noalias": True}) - A = tir.match_buffer(a, (1024, 1024), "float32") - B = tir.match_buffer(b, (1024, 1024), "float32") - C = tir.match_buffer(c, (1024, 1024), "float32") - with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with tir.init(): + @T.prim_func + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "matmul", "tir.noalias": True}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] -@script.tir +@script.ir_module class MatmulReluModule: + @T.prim_func def matmul_relu( # pylint: disable=no-self-argument - a: ty.handle, b: ty.handle, d: ty.handle + a: T.handle, b: T.handle, d: T.handle ) -> None: - tir.func_attr({"global_symbol": "matmul_relu", "tir.noalias": True}) - A = tir.match_buffer(a, (1024, 1024), "float32") - B = tir.match_buffer(b, (1024, 1024), "float32") - D = tir.match_buffer(d, (1024, 1024), "float32") - C = tir.alloc_buffer((1024, 1024), "float32") - with 
tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with tir.init(): + T.func_attr({"global_symbol": "matmul_relu", "tir.noalias": True}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + D = T.match_buffer(d, (1024, 1024), "float32") + C = T.alloc_buffer((1024, 1024), "float32") + with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with tir.block([1024, 1024], "relu") as [vi, vj]: - D[vi, vj] = tir.max(C[vi, vj], 0.0) + with T.block([1024, 1024], "relu") as [vi, vj]: + D[vi, vj] = T.max(C[vi, vj], 0.0) -@script.tir +@script.ir_module class BatchMatmulModule: + @T.prim_func def batch_matmul( # pylint: disable=no-self-argument - a: ty.handle, b: ty.handle, c: ty.handle + a: T.handle, b: T.handle, c: T.handle ) -> None: - tir.func_attr({"global_symbol": "batch_matmul", "tir.noalias": True}) - A = tir.match_buffer(a, [16, 128, 128]) - B = tir.match_buffer(b, [16, 128, 128]) - C = tir.match_buffer(c, [16, 128, 128]) - with tir.block([16, 128, 128, tir.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]: - with tir.init(): + T.func_attr({"global_symbol": "batch_matmul", "tir.noalias": True}) + A = T.match_buffer(a, [16, 128, 128]) + B = T.match_buffer(b, [16, 128, 128]) + C = T.match_buffer(c, [16, 128, 128]) + with T.block([16, 128, 128, T.reduce_axis(0, 128)], "update") as [vn, vi, vj, vk]: + with T.init(): C[vn, vi, vj] = 0.0 C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] @@ -103,7 +104,7 @@ def _check_build_results(builder_results: List[BuilderResult]): def test_meta_schedule_single_build(): """Test meta schedule builder for a single build""" - mod = MatmulModule() + mod = MatmulModule builder = LocalBuilder() builder_inputs = [BuilderInput(mod, Target("llvm"))] builder_results = builder.build(builder_inputs) @@ -115,9 +116,9 @@ def test_meta_schedule_multiple_build(): 
"""Test meta schedule builder for multiple builds""" builder = LocalBuilder() builder_inputs = [ - BuilderInput(MatmulModule(), Target("llvm")), - BuilderInput(MatmulReluModule(), Target("llvm")), - BuilderInput(BatchMatmulModule(), Target("llvm")), + BuilderInput(MatmulModule, Target("llvm")), + BuilderInput(MatmulReluModule, Target("llvm")), + BuilderInput(BatchMatmulModule, Target("llvm")), ] builder_results = builder.build(builder_inputs) assert len(builder_results) == len(builder_inputs) @@ -136,9 +137,9 @@ def build( # pylint: disable=no-self-use builder = TestBuilder() builder_inputs = [ - BuilderInput(MatmulModule(), Target("llvm")), - BuilderInput(MatmulReluModule(), Target("llvm")), - BuilderInput(BatchMatmulModule(), Target("llvm")), + BuilderInput(MatmulModule, Target("llvm")), + BuilderInput(MatmulReluModule, Target("llvm")), + BuilderInput(BatchMatmulModule, Target("llvm")), ] builder_results = builder.build(builder_inputs) assert len(builder_results) == len(builder_inputs) @@ -158,7 +159,7 @@ def test_build(mod: Module, target: Target) -> None: # pylint: disable=unused-v raise ValueError("Builder intended Test Error (build func).") builder = LocalBuilder(f_build="meta_schedule.builder.test_build", initializer=initializer) - builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))] + builder_inputs = [BuilderInput(MatmulModule, Target("llvm"))] builder_results = builder.build(builder_inputs) assert len(builder_results) == len(builder_inputs) for result in builder_results: @@ -177,7 +178,7 @@ def test_build(mod: Module) -> str: # pylint: disable=unused-variable raise ValueError("Builder intended Test Error (export func).") builder = LocalBuilder(f_export="meta_schedule.builder.test_export", initializer=initializer) - builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))] + builder_inputs = [BuilderInput(MatmulModule, Target("llvm"))] builder_results = builder.build(builder_inputs) assert len(builder_results) == len(builder_inputs) for 
result in builder_results: @@ -200,7 +201,7 @@ def timeout_build(mod, target): # pylint: disable=unused-argument, unused-varia f_build="meta_schedule.builder.test_time_out", initializer=initializer, ) - builder_inputs = [BuilderInput(MatmulModule(), Target("llvm"))] + builder_inputs = [BuilderInput(MatmulModule, Target("llvm"))] builder_results = builder.build(builder_inputs) assert len(builder_results) == len(builder_inputs) for result in builder_results: diff --git a/tests/python/unittest/test_meta_schedule_database.py b/tests/python/unittest/test_meta_schedule_database.py index feef023675b04..cb39c91eaca46 100644 --- a/tests/python/unittest/test_meta_schedule_database.py +++ b/tests/python/unittest/test_meta_schedule_database.py @@ -28,39 +28,40 @@ from tvm.ir.module import IRModule from tvm.meta_schedule.arg_info import ArgInfo from tvm.meta_schedule.database import JSONDatabase, TuningRecord -from tvm.script import ty +from tvm.script import tir as T from tvm.tir import Schedule # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off - -@tvm.script.tir +@tvm.script.ir_module class Matmul: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - tir.func_attr({"global_symbol": "main"}) - A = tir.match_buffer(a, (1024, 1024), "float32") - B = tir.match_buffer(b, (1024, 1024), "float32") - C = tir.match_buffer(c, (1024, 1024), "float32") - with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with tir.init(): + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] -@tvm.script.tir +@tvm.script.ir_module class 
MatmulRelu: - def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None: # pylint: disable=no-self-argument - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (16, 16), "float32") - D = tir.match_buffer(d, (16, 16), "float32") - C = tir.alloc_buffer((16, 16), "float32") - with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with tir.init(): + @T.prim_func + def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") + D = T.match_buffer(d, (16, 16), "float32") + C = T.alloc_buffer((16, 16), "float32") + with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with tir.block([16, 16], "relu") as [vi, vj]: - D[vi, vj] = tir.max(C[vi, vj], 0.0) + with T.block([16, 16], "relu") as [vi, vj]: + D[vi, vj] = T.max(C[vi, vj], 0.0) # fmt: on @@ -102,7 +103,7 @@ def _equal_record(a: TuningRecord, b: TuningRecord): def test_meta_schedule_tuning_record_round_trip(): - mod: IRModule = Matmul() + mod: IRModule = Matmul with tempfile.TemporaryDirectory() as tmpdir: database = _create_tmp_database(tmpdir) workload = database.commit_workload(mod) @@ -126,7 +127,7 @@ def test_meta_schedule_database_create(): def test_meta_schedule_database_add_entry(): - mod: IRModule = Matmul() + mod: IRModule = Matmul with tempfile.TemporaryDirectory() as tmpdir: database = _create_tmp_database(tmpdir) workload = database.commit_workload(mod) @@ -144,8 +145,8 @@ def test_meta_schedule_database_add_entry(): def test_meta_schedule_database_missing(): - mod: IRModule = Matmul() - mod_2: IRModule = MatmulRelu() + mod: IRModule = Matmul + mod_2: IRModule = MatmulRelu with tempfile.TemporaryDirectory() as tmpdir: 
database = _create_tmp_database(tmpdir) workload = database.commit_workload(mod) @@ -163,7 +164,7 @@ def test_meta_schedule_database_missing(): def test_meta_schedule_database_sorting(): - mod: IRModule = Matmul() + mod: IRModule = Matmul with tempfile.TemporaryDirectory() as tmpdir: database = _create_tmp_database(tmpdir) token = database.commit_workload(mod) @@ -225,7 +226,7 @@ def test_meta_schedule_database_sorting(): def test_meta_schedule_database_reload(): - mod: IRModule = Matmul() + mod: IRModule = Matmul with tempfile.TemporaryDirectory() as tmpdir: database = _create_tmp_database(tmpdir) token = database.commit_workload(mod) diff --git a/tests/python/unittest/test_meta_schedule_runner.py b/tests/python/unittest/test_meta_schedule_runner.py index 3c8aee0c6d58f..0b12b66ee460c 100644 --- a/tests/python/unittest/test_meta_schedule_runner.py +++ b/tests/python/unittest/test_meta_schedule_runner.py @@ -25,7 +25,6 @@ import pytest import tvm -from tvm import tir from tvm._ffi import register_func from tvm.meta_schedule.arg_info import TensorInfo from tvm.meta_schedule.builder import BuilderInput, LocalBuilder @@ -44,7 +43,7 @@ from tvm.meta_schedule.utils import get_global_func_with_default_on_worker from tvm.rpc import RPCSession from tvm.runtime import Device, Module -from tvm.script import ty +from tvm.script import tir as T from tvm.target import Target import tvm.testing from tvm.tir import FloatImm @@ -55,56 +54,60 @@ # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring,unbalanced-tuple-unpacking -@tvm.script.tir +@tvm.script.ir_module class MatmulModule: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") - with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") 
as [vi, vj, vk]: - with tir.init(): + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] -@tvm.script.tir +@tvm.script.ir_module class MatmulReluModule: - def main(a: ty.handle, b: ty.handle, d: ty.handle) -> None: # pylint: disable=no-self-argument - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (16, 16), "float32") - D = tir.match_buffer(d, (16, 16), "float32") - C = tir.alloc_buffer((16, 16), "float32") - with tir.block([16, 16, tir.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: - with tir.init(): + @T.prim_func + def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") + D = T.match_buffer(d, (16, 16), "float32") + C = T.alloc_buffer((16, 16), "float32") + with T.block([16, 16, T.reduce_axis(0, 16)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with tir.block([16, 16], "relu") as [vi, vj]: - D[vi, vj] = tir.max(C[vi, vj], 0.0) + with T.block([16, 16], "relu") as [vi, vj]: + D[vi, vj] = T.max(C[vi, vj], 0.0) -@tvm.script.tir +@tvm.script.ir_module class BatchMatmulModule: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = tir.match_buffer(a, [16, 32, 32]) - B = tir.match_buffer(b, [16, 32, 32]) - C = 
tir.match_buffer(c, [16, 32, 32]) - with tir.block([16, 32, 32, tir.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]: - with tir.init(): + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [16, 32, 32]) + B = T.match_buffer(b, [16, 32, 32]) + C = T.match_buffer(c, [16, 32, 32]) + with T.block([16, 32, 32, T.reduce_axis(0, 32)], "update") as [vn, vi, vj, vk]: + with T.init(): C[vn, vi, vj] = 0.0 C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] -@tvm.script.tir +@tvm.script.ir_module class AddModule: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = tir.match_buffer(a, [32], "float32") - B = tir.match_buffer(b, [32], "float32") - C = tir.match_buffer(c, [32], "float32") - with tir.block([32], "add") as [vi]: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [32], "float32") + B = T.match_buffer(b, [32], "float32") + C = T.match_buffer(c, [32], "float32") + with T.block([32], "add") as [vi]: C[vi] = A[vi] + B[vi] @@ -122,7 +125,7 @@ def _clean_build(artifact_path: str) -> None: def test_meta_schedule_rpc_single_run(): """Test meta schedule rpc runner for a single run""" # Build the module - mod = MatmulModule() + mod = MatmulModule builder = LocalBuilder() (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) assert builder_result.artifact_path is not None @@ -169,9 +172,9 @@ def test_meta_schedule_rpc_multiple_runs(): """Test meta schedule rpc runner for multiple runs""" # Build the module mods = [ - MatmulModule(), - MatmulReluModule(), - BatchMatmulModule(), + MatmulModule, + MatmulReluModule, + BatchMatmulModule, ] builder = 
LocalBuilder() builder_inputs = [BuilderInput(mod, Target("llvm")) for mod in mods] @@ -407,7 +410,7 @@ def test_run_evaluator( return costs # Build the module - mod = MatmulModule() + mod = MatmulModule builder = LocalBuilder() (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) assert builder_result.artifact_path is not None @@ -519,7 +522,7 @@ def test_run_evaluator( return costs # Build the module - mod = AddModule() + mod = AddModule builder = LocalBuilder() (builder_result,) = builder.build([BuilderInput(mod, Target("llvm"))]) assert builder_result.artifact_path is not None diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 6e90bddb84b41..e12871391558c 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -26,9 +26,9 @@ from tvm.meta_schedule import TuneContext from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace +from tvm.meta_schedule.search_strategy import ReplayTrace -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -37,15 +37,16 @@ # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking # fmt: off -@tvm.script.tir +@tvm.script.ir_module class Matmul: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - tir.func_attr({"global_symbol": "main"}) - A = tir.match_buffer(a, (32, 32), "float32") - B = tir.match_buffer(b, (32, 32), "float32") - C = tir.match_buffer(c, (32, 32), "float32") - with tir.block([32, 32, tir.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: - with tir.init(): + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = 
T.match_buffer(a, (32, 32), "float32") + B = T.match_buffer(b, (32, 32), "float32") + C = T.match_buffer(c, (32, 32), "float32") + with T.block([32, 32, T.reduce_axis(0, 32)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @@ -73,9 +74,9 @@ def test_meta_schedule_replay_trace(): num_trials_per_iter = 7 num_trials_total = 20 - (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul()) + (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul()) + tune_context = TuneContext(mod=Matmul) replay.initialize_with_tune_context(tune_context) num_trials_each_round: List[int] = [] diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 3ab60aced197e..39bb1acf065f8 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -23,25 +23,25 @@ import pytest import tvm -from tvm import tir -from tvm.script import ty +from tvm.script import tir as T -from tvm.tir.schedule import Schedule, Trace +from tvm.tir.schedule import Schedule from tvm.meta_schedule.space_generator import ScheduleFn, SpaceGeneratorUnion # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off -@tvm.script.tir +@tvm.script.ir_module class Matmul: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - tir.func_attr({"global_symbol": "main"}) - A = tir.match_buffer(a, (1024, 1024), "float32") - B = tir.match_buffer(b, (1024, 1024), "float32") - C = tir.match_buffer(c, (1024, 1024), "float32") - with tir.block([1024, 1024, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with tir.init(): + @T.prim_func + def main(a: T.handle, b: T.handle, c: 
T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @@ -66,7 +66,7 @@ def _check_correct(schedule: Schedule): def test_meta_schedule_space_generator_schedule_fn(): - mod = Matmul() + mod = Matmul space_generator = ScheduleFn(sch_fn=schedule_matmul) design_spaces = space_generator.generate_design_space(mod) assert len(design_spaces) == 1 @@ -75,7 +75,7 @@ def test_meta_schedule_space_generator_schedule_fn(): def test_meta_schedule_design_space_generator_union(): - mod = Matmul() + mod = Matmul space_generator = ScheduleFn(sch_fn=schedule_matmul) space_generator_union = SpaceGeneratorUnion([space_generator, space_generator]) design_spaces = space_generator_union.generate_design_space(mod) diff --git a/tests/python/unittest/test_meta_schedule_tune_context.py b/tests/python/unittest/test_meta_schedule_tune_context.py index 2da4c85ab4218..44bb949b925b9 100644 --- a/tests/python/unittest/test_meta_schedule_tune_context.py +++ b/tests/python/unittest/test_meta_schedule_tune_context.py @@ -20,23 +20,23 @@ import pytest import tvm -from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.target import Target from tvm.meta_schedule import TuneContext # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring -@tvm.script.tir +@tvm.script.ir_module class Matmul: - def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=no-self-argument - tir.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = tir.match_buffer(a, (1024, 1024), "float32") - B = tir.match_buffer(b, (1024, 1024), "float32") - C = tir.match_buffer(c, (1024, 1024), "float32") - with tir.block([1024, 1024, 
tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: - with tir.init(): + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block([1024, 1024, T.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] @@ -45,7 +45,7 @@ def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=n def test_tune_context_create(): - mod = Matmul() + mod = Matmul context = TuneContext(mod=mod, target=Target("llvm"), task_name="Test Task") assert context.num_threads > 0 assert context.rand_state != -1 diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 2fdafe08e60f7..987898001a1b7 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -16,7 +16,7 @@ # under the License. 
# pylint: disable=missing-function-docstring,missing-module-docstring import tvm -from tvm.script import ty +from tvm.script import tir as T from tvm import te, tir import numpy as np import tvm.testing @@ -48,14 +48,14 @@ def te_matmul(): return [A, B, C] -@tvm.script.tir -def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) +@T.prim_func +def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]: - with tir.init(): + with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: + with T.init(): C[i, j] = 0.0 C[i, j] += A[i, k] * B[j, k] @@ -71,15 +71,15 @@ def te_element_wise(): return [A, C] -@tvm.script.tir -def tir_element_wise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - B = tir.alloc_buffer((128, 128)) +@T.prim_func +def tir_element_wise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) - with tir.block([128, 128]) as [i, j]: + with T.block([128, 128]) as [i, j]: B[i, j] = A[i, j] * 2.0 - with tir.block([128, 128]) as [i, j]: + with T.block([128, 128]) as [i, j]: C[i, j] = B[i, j] + 1.0 @@ -118,24 +118,24 @@ def te_conv2d(): return [A, W, B] -@tvm.script.tir -def tir_conv2d(a: ty.handle, w: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16, 14, 14]) - W = tir.match_buffer(w, [16, 3, 3, 32]) - B = tir.match_buffer(b, [16, 32, 14, 14]) - Apad = tir.alloc_buffer([16, 16, 16, 16]) +@T.prim_func +def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 14, 14]) + W = T.match_buffer(w, [16, 3, 3, 32]) + B = T.match_buffer(b, [16, 32, 14, 14]) + Apad 
= T.alloc_buffer([16, 16, 16, 16]) - with tir.block([16, 16, 16, 16], "Apad") as [nn, cc, yy, xx]: - Apad[nn, cc, yy, xx] = tir.if_then_else( + with T.block([16, 16, 16, 16], "Apad") as [nn, cc, yy, xx]: + Apad[nn, cc, yy, xx] = T.if_then_else( yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14, A[nn, cc, yy - 1, xx - 1], 0.0, dtype="float32", ) - with tir.block( - [16, 32, 14, 14, tir.reduce_axis(0, 16), tir.reduce_axis(0, 3), tir.reduce_axis(0, 3)], "B" + with T.block( + [16, 32, 14, 14, T.reduce_axis(0, 16), T.reduce_axis(0, 3), T.reduce_axis(0, 3)], "B" ) as [nn, ff, yy, xx, rc, ry, rx]: - with tir.init(): + with T.init(): B[nn, ff, yy, xx] = 0.0 B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] @@ -153,19 +153,19 @@ def te_multi_output(): return [A0, A1, B0, B1] -@tvm.script.tir -def tir_multi_output(a0: ty.handle, a1: ty.handle, b0: ty.handle, b1: ty.handle) -> None: - m = tir.var("int32") - n = tir.var("int32") - A0 = tir.match_buffer(a0, (m, n)) - A1 = tir.match_buffer(a1, (m, n)) - B0 = tir.match_buffer(b0, (m, n)) - B1 = tir.match_buffer(b1, (m, n)) +@T.prim_func +def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None: + m = T.var("int32") + n = T.var("int32") + A0 = T.match_buffer(a0, (m, n)) + A1 = T.match_buffer(a1, (m, n)) + B0 = T.match_buffer(b0, (m, n)) + B1 = T.match_buffer(b1, (m, n)) - for i0, i1 in tir.grid(m, n): - with tir.block([m, n], "B.v0") as [i, j]: + for i0, i1 in T.grid(m, n): + with T.block([m, n], "B.v0") as [i, j]: B0[i, j] = A0[i, j] + 2.0 - with tir.block([m, n], "B.v1") as [i, j]: + with T.block([m, n], "B.v1") as [i, j]: B1[i, j] = A1[i, j] * 3.0 @@ -187,39 +187,39 @@ def te_extern(): return [A, B, C] -@tvm.script.tir -def tir_extern(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) +@T.prim_func +def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None: + A 
= T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) # body - with tir.block([], "C"): - tir.reads([A[0:128, 0:128], B[0:128, 0:128]]) - tir.writes([C[0:128, 0:128]]) - tir.evaluate( - tir.tvm_call_packed( + with T.block([], "C"): + T.reads([A[0:128, 0:128], B[0:128, 0:128]]) + T.writes([C[0:128, 0:128]]) + T.evaluate( + T.tvm_call_packed( "tvm.contrib.cblas.matmul", - tir.tvm_stack_make_array( + T.tvm_stack_make_array( A.data, - tir.tvm_stack_make_shape(128, 128, dtype="handle"), + T.tvm_stack_make_shape(128, 128, dtype="handle"), 0, 2, 0.0, 0, dtype="handle", ), - tir.tvm_stack_make_array( + T.tvm_stack_make_array( B.data, - tir.tvm_stack_make_shape(128, 128, dtype="handle"), + T.tvm_stack_make_shape(128, 128, dtype="handle"), 0, 2, 0.0, 0, dtype="handle", ), - tir.tvm_stack_make_array( + T.tvm_stack_make_array( C.data, - tir.tvm_stack_make_shape(128, 128, dtype="handle"), + T.tvm_stack_make_shape(128, 128, dtype="handle"), 0, 2, 0.0, @@ -245,14 +245,14 @@ def te_reordered_matmul(): return [C, A, B] -@tvm.script.tir -def tir_reordered_matmul(c: ty.handle, a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) +@T.prim_func +def tir_reordered_matmul(c: T.handle, a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]: - with tir.init(): + with T.block([128, 128, T.reduce_axis(0, 128)]) as [i, j, k]: + with T.init(): C[i, j] = 0.0 C[i, j] += A[i, k] * B[j, k] diff --git a/tests/python/unittest/test_tir_analysis_calculate_workspace.py b/tests/python/unittest/test_tir_analysis_calculate_workspace.py index 190d1820c1f4e..4b61625014e2c 100644 --- a/tests/python/unittest/test_tir_analysis_calculate_workspace.py +++ 
b/tests/python/unittest/test_tir_analysis_calculate_workspace.py @@ -14,81 +14,79 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import numpy as np import pytest import tvm -from tvm import tir, script -from tvm.ir import Range -from tvm.script import ty +from tvm import tir +from tvm.script import tir as T # fmt: off -@tvm.script.tir -def primfunc_global_allocates(placeholder_144: ty.handle, placeholder_145: ty.handle, placeholder_146: ty.handle, T_cast_48: ty.handle) -> None: +@T.prim_func +def primfunc_global_allocates(placeholder_144: T.handle, placeholder_145: T.handle, placeholder_146: T.handle, T_cast_48: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_13", "tir.noalias": True}) - placeholder_147 = tir.match_buffer(placeholder_144, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_148 = tir.match_buffer(placeholder_145, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_149 = tir.match_buffer(placeholder_146, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_49 = tir.match_buffer(T_cast_48, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_13", "tir.noalias": True}) + placeholder_147 = T.match_buffer(placeholder_144, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_148 = T.match_buffer(placeholder_145, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_149 = T.match_buffer(placeholder_146, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_49 = T.match_buffer(T_cast_48, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) 
# body - PaddedInput_22 = tir.allocate([131072], "int16", "global") - DepthwiseConv2d_9 = tir.allocate([100352], "int32", "global") - for i1_29, i2_39, i3_40 in tir.grid(16, 16, 512): - PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = tir.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), tir.load("int16", placeholder_147.data, ((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)), tir.int16(0), dtype="int16") - for i_9, j_9, c_9 in tir.grid(14, 14, 512): + PaddedInput_22 = T.allocate([131072], "int16", "global") + DepthwiseConv2d_9 = T.allocate([100352], "int32", "global") + for i1_29, i2_39, i3_40 in T.grid(16, 16, 512): + PaddedInput_22[(((i1_29*8192) + (i2_39*512)) + i3_40)] = T.if_then_else(((((1 <= i1_29) and (i1_29 < 15)) and (1 <= i2_39)) and (i2_39 < 15)), T.load("int16", placeholder_147.data, ((((i1_29*7168) + (i2_39*512)) + i3_40) - 7680)), T.int16(0), dtype="int16") + for i_9, j_9, c_9 in T.grid(14, 14, 512): DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = 0 - for di_9, dj_9 in tir.grid(3, 3): - DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = (tir.load("int32", DepthwiseConv2d_9, (((i_9*7168) + (j_9*512)) + c_9)) + (tir.load("int16", PaddedInput_22, (((((i_9*8192) + (di_9*8192)) + (j_9*512)) + (dj_9*512)) + c_9)).astype("int32")*tir.load("int16", placeholder_148.data, (((di_9*1536) + (dj_9*512)) + c_9)).astype("int32"))) - for ax1_27, ax2_28, ax3_30 in tir.grid(14, 14, 512): - DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] = (tir.load("int32", DepthwiseConv2d_9, (((ax1_27*7168) + (ax2_28*512)) + ax3_30)) + tir.load("int32", placeholder_149.data, ax3_30)) - for i1_30, i2_40, i3_41 in tir.grid(14, 14, 512): - DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)] = tir.q_multiply_shift(tir.load("int32", DepthwiseConv2d_9, (((i1_30*7168) + (i2_40*512)) + i3_41)), 1269068532, 31, -4, dtype="int32") - for i1_31, i2_41, i3_42 in tir.grid(14, 14, 512): - DepthwiseConv2d_9[(((i1_31*7168) + 
(i2_41*512)) + i3_42)] = tir.max(tir.max(tir.load("int32", DepthwiseConv2d_9, (((i1_31*7168) + (i2_41*512)) + i3_42)), 255), 0) - for ax1_28, ax2_29, ax3_31 in tir.grid(14, 14, 512): - PaddedInput_22[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)] = tir.load("int32", DepthwiseConv2d_9, (((ax1_28*7168) + (ax2_29*512)) + ax3_31)).astype("uint8") - for ax1_29, ax2_30, ax3_32 in tir.grid(14, 14, 512): - T_cast_49.data[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)] = tir.load("uint8", PaddedInput_22, (((ax1_29*7168) + (ax2_30*512)) + ax3_32)).astype("int16") + for di_9, dj_9 in T.grid(3, 3): + DepthwiseConv2d_9[(((i_9*7168) + (j_9*512)) + c_9)] = (T.load("int32", DepthwiseConv2d_9, (((i_9*7168) + (j_9*512)) + c_9)) + (T.load("int16", PaddedInput_22, (((((i_9*8192) + (di_9*8192)) + (j_9*512)) + (dj_9*512)) + c_9)).astype("int32")*T.load("int16", placeholder_148.data, (((di_9*1536) + (dj_9*512)) + c_9)).astype("int32"))) + for ax1_27, ax2_28, ax3_30 in T.grid(14, 14, 512): + DepthwiseConv2d_9[(((ax1_27*7168) + (ax2_28*512)) + ax3_30)] = (T.load("int32", DepthwiseConv2d_9, (((ax1_27*7168) + (ax2_28*512)) + ax3_30)) + T.load("int32", placeholder_149.data, ax3_30)) + for i1_30, i2_40, i3_41 in T.grid(14, 14, 512): + DepthwiseConv2d_9[(((i1_30*7168) + (i2_40*512)) + i3_41)] = T.q_multiply_shift(T.load("int32", DepthwiseConv2d_9, (((i1_30*7168) + (i2_40*512)) + i3_41)), 1269068532, 31, -4, dtype="int32") + for i1_31, i2_41, i3_42 in T.grid(14, 14, 512): + DepthwiseConv2d_9[(((i1_31*7168) + (i2_41*512)) + i3_42)] = T.max(T.max(T.load("int32", DepthwiseConv2d_9, (((i1_31*7168) + (i2_41*512)) + i3_42)), 255), 0) + for ax1_28, ax2_29, ax3_31 in T.grid(14, 14, 512): + PaddedInput_22[(((ax1_28*7168) + (ax2_29*512)) + ax3_31)] = T.load("int32", DepthwiseConv2d_9, (((ax1_28*7168) + (ax2_29*512)) + ax3_31)).astype("uint8") + for ax1_29, ax2_30, ax3_32 in T.grid(14, 14, 512): + T_cast_49.data[(((ax1_29*7168) + (ax2_30*512)) + ax3_32)] = T.load("uint8", PaddedInput_22, (((ax1_29*7168) + 
(ax2_30*512)) + ax3_32)).astype("int16") # fmt: on # fmt: off -@tvm.script.tir -def primfunc_local_allocates(placeholder_162: ty.handle, placeholder_163: ty.handle, placeholder_164: ty.handle, T_cast_76: ty.handle) -> None: +@T.prim_func +def primfunc_local_allocates(placeholder_162: T.handle, placeholder_163: T.handle, placeholder_164: T.handle, T_cast_76: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_9", "tir.noalias": True}) - placeholder_165 = tir.match_buffer(placeholder_162, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_166 = tir.match_buffer(placeholder_163, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) - placeholder_167 = tir.match_buffer(placeholder_164, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) - T_cast_77 = tir.match_buffer(T_cast_76, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T.func_attr({"global_symbol": "fused_nn_conv2d_add_cast_fixed_point_multiply_clip_cast_cast_9", "tir.noalias": True}) + placeholder_165 = T.match_buffer(placeholder_162, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_166 = T.match_buffer(placeholder_163, [3, 3, 512, 1], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_167 = T.match_buffer(placeholder_164, [1, 1, 1, 512], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_77 = T.match_buffer(T_cast_76, [1, 14, 14, 512], dtype="int16", elem_offset=0, align=128, offset_factor=1) # body - PaddedInput_25 = tir.allocate([1, 16, 16, 512], "int16", "global") - for i1_35, i2_46, i3_47 in tir.grid(16, 16, 512): - PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = tir.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), tir.load("int16", placeholder_165.data, ((((i1_35*7168) + (i2_46*512)) + i3_47) - 
7680)), tir.int16(0), dtype="int16") - T_add_11 = tir.allocate([1, 14, 14, 512], "int32", "global") - with tir.allocate([1, 14, 14, 512], "int32", "global") as DepthwiseConv2d_11: - for i_11, j_11, c_11 in tir.grid(14, 14, 512): + PaddedInput_25 = T.allocate([1, 16, 16, 512], "int16", "global") + for i1_35, i2_46, i3_47 in T.grid(16, 16, 512): + PaddedInput_25[(((i1_35*8192) + (i2_46*512)) + i3_47)] = T.if_then_else(((((1 <= i1_35) and (i1_35 < 15)) and (1 <= i2_46)) and (i2_46 < 15)), T.load("int16", placeholder_165.data, ((((i1_35*7168) + (i2_46*512)) + i3_47) - 7680)), T.int16(0), dtype="int16") + T_add_11 = T.allocate([1, 14, 14, 512], "int32", "global") + with T.allocate([1, 14, 14, 512], "int32", "global") as DepthwiseConv2d_11: + for i_11, j_11, c_11 in T.grid(14, 14, 512): DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = 0 - for di_11, dj_11 in tir.grid(3, 3): - DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (tir.load("int32", DepthwiseConv2d_11, (((i_11*7168) + (j_11*512)) + c_11)) + (tir.load("int16", PaddedInput_25, (((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)).astype("int32")*tir.load("int16", placeholder_166.data, (((di_11*1536) + (dj_11*512)) + c_11)).astype("int32"))) - for ax1_44, ax2_45, ax3_47 in tir.grid(14, 14, 512): - T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (tir.load("int32", DepthwiseConv2d_11, (((ax1_44*7168) + (ax2_45*512)) + ax3_47)) + tir.load("int32", placeholder_167.data, ax3_47)) - compute_22 = tir.allocate([1, 14, 14, 512], "int32", "global") - with tir.allocate([1, 14, 14, 512], "int32", "global") as T_cast_78: - for ax1_45, ax2_46, ax3_48 in tir.grid(14, 14, 512): - T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = tir.load("int32", T_add_11, (((ax1_45*7168) + (ax2_46*512)) + ax3_48)) - for i1_36, i2_47, i3_48 in tir.grid(14, 14, 512): - compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = tir.q_multiply_shift(tir.load("int32", T_cast_78, (((i1_36*7168) + (i2_47*512)) + 
i3_48)), 1948805937, 31, -5, dtype="int32") - T_cast_79 = tir.allocate([1, 14, 14, 512], "uint8", "global") - with tir.allocate([1, 14, 14, 512], "int32", "global") as compute_23: - for i1_37, i2_48, i3_49 in tir.grid(14, 14, 512): - compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = tir.max(tir.max(tir.load("int32", compute_22, (((i1_37*7168) + (i2_48*512)) + i3_49)), 255), 0) - for ax1_46, ax2_47, ax3_49 in tir.grid(14, 14, 512): - T_cast_79[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)] = tir.load("int32", compute_23, (((ax1_46*7168) + (ax2_47*512)) + ax3_49)).astype("uint8") - for ax1_47, ax2_48, ax3_50 in tir.grid(14, 14, 512): - T_cast_77.data[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)] = tir.load("uint8", T_cast_79, (((ax1_47*7168) + (ax2_48*512)) + ax3_50)).astype("int16") + for di_11, dj_11 in T.grid(3, 3): + DepthwiseConv2d_11[(((i_11*7168) + (j_11*512)) + c_11)] = (T.load("int32", DepthwiseConv2d_11, (((i_11*7168) + (j_11*512)) + c_11)) + (T.load("int16", PaddedInput_25, (((((i_11*8192) + (di_11*8192)) + (j_11*512)) + (dj_11*512)) + c_11)).astype("int32")*T.load("int16", placeholder_166.data, (((di_11*1536) + (dj_11*512)) + c_11)).astype("int32"))) + for ax1_44, ax2_45, ax3_47 in T.grid(14, 14, 512): + T_add_11[(((ax1_44*7168) + (ax2_45*512)) + ax3_47)] = (T.load("int32", DepthwiseConv2d_11, (((ax1_44*7168) + (ax2_45*512)) + ax3_47)) + T.load("int32", placeholder_167.data, ax3_47)) + compute_22 = T.allocate([1, 14, 14, 512], "int32", "global") + with T.allocate([1, 14, 14, 512], "int32", "global") as T_cast_78: + for ax1_45, ax2_46, ax3_48 in T.grid(14, 14, 512): + T_cast_78[(((ax1_45*7168) + (ax2_46*512)) + ax3_48)] = T.load("int32", T_add_11, (((ax1_45*7168) + (ax2_46*512)) + ax3_48)) + for i1_36, i2_47, i3_48 in T.grid(14, 14, 512): + compute_22[(((i1_36*7168) + (i2_47*512)) + i3_48)] = T.q_multiply_shift(T.load("int32", T_cast_78, (((i1_36*7168) + (i2_47*512)) + i3_48)), 1948805937, 31, -5, dtype="int32") + T_cast_79 = T.allocate([1, 14, 14, 512], 
"uint8", "global") + with T.allocate([1, 14, 14, 512], "int32", "global") as compute_23: + for i1_37, i2_48, i3_49 in T.grid(14, 14, 512): + compute_23[(((i1_37*7168) + (i2_48*512)) + i3_49)] = T.max(T.max(T.load("int32", compute_22, (((i1_37*7168) + (i2_48*512)) + i3_49)), 255), 0) + for ax1_46, ax2_47, ax3_49 in T.grid(14, 14, 512): + T_cast_79[(((ax1_46*7168) + (ax2_47*512)) + ax3_49)] = T.load("int32", compute_23, (((ax1_46*7168) + (ax2_47*512)) + ax3_49)).astype("uint8") + for ax1_47, ax2_48, ax3_50 in T.grid(14, 14, 512): + T_cast_77.data[(((ax1_47*7168) + (ax2_48*512)) + ax3_50)] = T.load("uint8", T_cast_79, (((ax1_47*7168) + (ax2_48*512)) + ax3_50)).astype("int16") # fmt: on diff --git a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py index 8c2b2710f1ba5..1aae8cdd03e1a 100644 --- a/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py +++ b/tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py @@ -16,74 +16,72 @@ # under the License. 
import tvm from tvm import tir -from tvm.script import ty - - -@tvm.script.tir -def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.match_buffer(b, (128, 128), "float32") - C = tir.alloc_buffer((128, 128), "float32") - D = tir.alloc_buffer((128, 128), "float32") - with tir.block([128, 128]) as [i, j]: - A[i, j] = tir.float32(0) - with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]: - with tir.init(): - for ii, jj in tir.grid(4, 4): +from tvm.script import tir as T + + +@T.prim_func +def buffer_load_store_func(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (128, 128), "float32") + C = T.alloc_buffer((128, 128), "float32") + D = T.alloc_buffer((128, 128), "float32") + with T.block([128, 128]) as [i, j]: + A[i, j] = T.float32(0) + with T.block([32, 32, T.reduce_axis(0, 32)]) as [i, j, k]: + with T.init(): + for ii, jj in T.grid(4, 4): B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in tir.grid(4, 4): + for ii, jj in T.grid(4, 4): for kk in range(0, 4): B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] for kk in range(0, 4): B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] -@tvm.script.tir -def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None: - B = tir.match_buffer(b, [16, 16], "float32") - C = tir.match_buffer(c, [16, 16], "float32") +@T.prim_func +def buffer_opaque_access(b: T.handle, c: T.handle) -> None: + B = T.match_buffer(b, [16, 16], "float32") + C = T.match_buffer(c, [16, 16], "float32") - with tir.block([]): - tir.reads([]) - tir.writes(B[0:16, 0:16]) - A = tir.allocate([256], "float32", "global") - for i, j in tir.grid(16, 16): - tir.store(A, i * 16 + j, 1) + with T.block([]): + T.reads([]) + T.writes(B[0:16, 0:16]) + A = T.allocate([256], "float32", "global") + for i, j in T.grid(16, 16): + T.store(A, i * 16 + j, 1) for i in range(0, 16): for j in 
range(0, 16): - tir.evaluate(tir.load("float32", A, i * 16 + j)) + T.evaluate(T.load("float32", A, i * 16 + j)) for j in range(0, 16): - tir.evaluate( - tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, tir.float32(0), dtype="handle") - ) - - for i, j in tir.grid(16, 16): - with tir.block([16, 16]) as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, T.float32(0), dtype="handle")) + + for i, j in T.grid(16, 16): + with T.block([16, 16]) as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) C[vi, vj] = B[vi, vj] -@tvm.script.tir -def lca_is_func_root(a: ty.handle) -> None: - A = tir.match_buffer(a, [0, 0], "float32") +@T.prim_func +def lca_is_func_root(a: T.handle) -> None: + A = T.match_buffer(a, [0, 0], "float32") A.data[0] = 1.0 -@tvm.script.tir -def match_buffer_func(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.match_buffer(b, (128, 128), "float32") - with tir.block([8, 8], "block") as [vi, vj]: - tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) - tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) - B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) - with tir.block([16, 16], "AAA") as [i, j]: - AA = tir.match_buffer(A[i, j], ()) +@T.prim_func +def match_buffer_func(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (128, 128), "float32") + with T.block([8, 8], "block") as [vi, vj]: + T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with T.block([16, 16], "AAA") as [i, j]: + AA = T.match_buffer(A[i, j], ()) 
AA[()] = 1.0 - tir.evaluate(B0.data) - tir.evaluate(B1.data) + T.evaluate(B0.data) + T.evaluate(B1.data) def test_buffer_load_store(): diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py index bc421aa4d19b3..e3a63c3254344 100644 --- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -16,83 +16,84 @@ # under the License. import pytest import tvm -from tvm import tir, script +from tvm import tir +from tvm.script import tir as T from tvm.ir import Range -@tvm.script.tir +@T.prim_func def func() -> None: - A = tir.alloc_buffer((128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.alloc_buffer((128, 128), "float32") - D = tir.alloc_buffer((128, 128), "float32") - with tir.block([]): + A = T.alloc_buffer((128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.alloc_buffer((128, 128), "float32") + D = T.alloc_buffer((128, 128), "float32") + with T.block([]): # Need add read/write region manually to avoid triggering block access region detector - tir.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]]) - tir.writes([A[0:12, 0:12]]) - for i, j in tir.grid(8, 8): + T.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]]) + T.writes([A[0:12, 0:12]]) + for i, j in T.grid(8, 8): A[i, j] = B[0, 0] + C[0, 0] - with tir.block([2, 2]) as [vi, vj]: - tir.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) - tir.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) - for i, j in tir.grid(4, 4): + with T.block([2, 2]) as [vi, vj]: + T.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) + T.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) + for i, j in T.grid(4, 4): A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12] - tir.evaluate(D.data) + T.evaluate(D.data) -@tvm.script.tir +@T.prim_func def 
match_buffer_func() -> None: - with tir.block([], "root"): - A = tir.alloc_buffer((128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - tir.reads([]) - tir.writes([]) + with T.block([], "root"): + A = T.alloc_buffer((128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + T.reads([]) + T.writes([]) # Need add read/write region manually to avoid triggering block access region detector - with tir.block([8, 8], "block") as [vi, vj]: - tir.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) - tir.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - AA = tir.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) - B0 = tir.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) - B1 = tir.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) - with tir.block([16, 16], "AAA") as [i, j]: - tir.reads([]) - tir.writes(AA[i, j]) - AAA = tir.match_buffer(AA[i, j], ()) + with T.block([8, 8], "block") as [vi, vj]: + T.reads(B[vi * 16 + 2 : vi * 16 + 12, vj * 16 + 2 : vj * 16 + 16]) + T.writes(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + AA = T.match_buffer(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16], (16, 16)) + B0 = T.match_buffer(B[vi * 16 + 2 : vi * 16 + 6, vj * 16 + 2 : vj * 16 + 6], (4, 4)) + B1 = T.match_buffer(B[vi * 16 + 8 : vi * 16 + 12, vj * 16 + 8 : vj * 16 + 16], (4, 8)) + with T.block([16, 16], "AAA") as [i, j]: + T.reads([]) + T.writes(AA[i, j]) + AAA = T.match_buffer(AA[i, j], ()) AAA[()] = 1.0 - tir.evaluate(B0.data) - tir.evaluate(B1.data) + T.evaluate(B0.data) + T.evaluate(B1.data) -@tvm.script.tir +@T.prim_func def opaque_block_func() -> None: - with tir.block([], "root"): - A = tir.alloc_buffer((16, 16), "float32") - B = tir.alloc_buffer((16, 16), "float32") - tir.reads([]) - tir.writes([]) + with T.block([], "root"): + A = T.alloc_buffer((16, 16), "float32") + B = T.alloc_buffer((16, 16), "float32") + T.reads([]) + T.writes([]) # 
Need add read/write region manually to avoid triggering block access region detector for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes([B[i, 0:16]]) + with T.block([]): + T.reads(A[i, 0:16]) + T.writes([B[i, 0:16]]) for j in range(0, 16): - with tir.block([]): - tir.reads(A[i, j]) - tir.writes(B[i, j]) + with T.block([]): + T.reads(A[i, j]) + T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 -@tvm.script.tir +@T.prim_func def opaque_access_func() -> None: - A = tir.alloc_buffer([1024]) - B = tir.alloc_buffer([1024]) - for i in tir.serial(0, 8): - with tir.block([8]) as [v]: - tir.bind(v, i) - tir.reads([A[v * 128 : v * 128 + 128]]) - tir.writes([B[v * 128 : v * 128 + 128]]) - tir.evaluate( - tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32") + A = T.alloc_buffer([1024]) + B = T.alloc_buffer([1024]) + for i in T.serial(0, 8): + with T.block([8]) as [v]: + T.bind(v, i) + T.reads([A[v * 128 : v * 128 + 128]]) + T.writes([B[v * 128 : v * 128 + 128]]) + T.evaluate( + T.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32") ) diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index ecc30199c1a7d..3e9e7fd33fd90 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -19,7 +19,7 @@ from tvm import te, tir from tvm import topi from tvm.contrib import utils, clang -from tvm.script import ty +from tvm.script import tir as T import numpy as np import ctypes import math @@ -187,17 +187,18 @@ def clz_np(x, dtype): np.testing.assert_equal(b.numpy(), ref) -@tvm.script.tir +@tvm.script.ir_module class Module: - def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None: + @T.prim_func + def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True}) - n = tir.var("int32") - stride = 
tir.var("int32") - stride_1 = tir.var("int32") - stride_2 = tir.var("int32") - stride_3 = tir.var("int32") - A_1 = tir.match_buffer( + T.func_attr({"global_symbol": "test_fma", "tir.noalias": True}) + n = T.var("int32") + stride = T.var("int32") + stride_1 = T.var("int32") + stride_2 = T.var("int32") + stride_3 = T.var("int32") + A_1 = T.match_buffer( A, [n], strides=[stride], @@ -206,7 +207,7 @@ def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None offset_factor=1, type="auto", ) - B_1 = tir.match_buffer( + B_1 = T.match_buffer( B, [n], strides=[stride_1], @@ -215,7 +216,7 @@ def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None offset_factor=1, type="auto", ) - C_1 = tir.match_buffer( + C_1 = T.match_buffer( C, [n], strides=[stride_2], @@ -224,7 +225,7 @@ def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None offset_factor=1, type="auto", ) - d_1 = tir.match_buffer( + d_1 = T.match_buffer( d, [n], strides=[stride_3], @@ -234,11 +235,11 @@ def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None type="auto", ) # body - for i in tir.serial(0, n): + for i in T.serial(0, n): d_1.data[(i * stride_3)] = ( - tir.load("float32", A_1.data, (i * stride)) - * tir.load("float32", B_1.data, (i * stride_1)) - ) + tir.load("float32", C_1.data, (i * stride_2)) + T.load("float32", A_1.data, (i * stride)) + * T.load("float32", B_1.data, (i * stride_1)) + ) + T.load("float32", C_1.data, (i * stride_2)) def test_fma(): @@ -248,7 +249,7 @@ def test_fma(): tvm.tir.transform.LowerIntrin(), ] ) - mod = opt(Module()) + mod = opt(Module) assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin" diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index efb2073e08625..92da807680666 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ 
b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -18,8 +18,7 @@ import pytest import tvm -from tvm import tir -from tvm.script import ty +from tvm.script import tir as T def _check(original, transformed): @@ -35,33 +34,31 @@ def _check_fail(original): mod = tvm.tir.transform.LowerMatchBuffer()(mod) -@tvm.script.tir -def buffer_load_store(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16, 16)) - C = tir.match_buffer(c, (16, 16)) - for i, j, k in tir.grid(4, 16, 8): - with tir.block([]): - tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) - tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) - sub_A = tir.match_buffer( +@T.prim_func +def buffer_load_store(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16, 16)) + C = T.match_buffer(c, (16, 16)) + for i, j, k in T.grid(4, 16, 8): + with T.block([]): + T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) + T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) + sub_A = T.match_buffer( A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2], (4, 1, 2), offset_factor=1 ) - sub_C = tir.match_buffer( - C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2], (4, 2), offset_factor=1 - ) - for ii, kk in tir.grid(4, 2): + sub_C = T.match_buffer(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2], (4, 2), offset_factor=1) + for ii, kk in T.grid(4, 2): sub_A[ii, 0, kk] += sub_C[ii, kk] -@tvm.script.tir -def transformed_buffer_load_store(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16, 16)) - C = tir.match_buffer(c, (16, 16)) - for i, j, k in tir.grid(4, 16, 8): - with tir.block([]): - tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) - tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2]) - for ii, kk in tir.grid(4, 2): +@T.prim_func +def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16, 16)) + C = T.match_buffer(c, (16, 16)) + for i, j, k in T.grid(4, 16, 8): + with T.block([]): + T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2]) + T.writes(A[i * 4 : i 
* 4 + 4, j, k * 2 : k * 2 + 2]) + for ii, kk in T.grid(4, 2): A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk] @@ -70,22 +67,22 @@ def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1): return 0 -@tvm.script.tir -def opaque_access(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (32, 64, 128)) - B = tir.match_buffer(b, (64, 64, 64)) - for i, j, k in tir.grid(2, 64, 8): - with tir.block([]): - tir.reads([]) - tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) - sub_A = tir.match_buffer( +@T.prim_func +def opaque_access(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (32, 64, 128)) + B = T.match_buffer(b, (64, 64, 64)) + for i, j, k in T.grid(2, 64, 8): + with T.block([]): + T.reads([]) + T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) + sub_A = T.match_buffer( A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16], (16, 1, 16), strides=[8192, 128, 1], offset_factor=1, ) - tir.evaluate( - tir.intrin_test( + T.evaluate( + T.intrin_test( sub_A.data, sub_A.elem_offset, sub_A.strides[0], @@ -95,20 +92,20 @@ def opaque_access(a: ty.handle, b: ty.handle) -> None: dtype="handle", ) ) - for i, j, k in tir.grid(64, 2, 8): - with tir.block([]): - Bs_0 = tir.var("int32") - Bs_1 = tir.var("int32") - tir.reads([]) - tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) - sub_B = tir.match_buffer( + for i, j, k in T.grid(64, 2, 8): + with T.block([]): + Bs_0 = T.var("int32") + Bs_1 = T.var("int32") + T.reads([]) + T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) + sub_B = T.match_buffer( B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8], (32, 8), strides=[Bs_0, Bs_1], offset_factor=1, ) - tir.evaluate( - tir.intrin_test( + T.evaluate( + T.intrin_test( sub_B.data, sub_B.elem_offset, sub_B.strides[0], @@ -120,16 +117,16 @@ def opaque_access(a: ty.handle, b: ty.handle) -> None: ) -@tvm.script.tir -def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (32, 64, 128)) - B = 
tir.match_buffer(b, (64, 64, 64)) - for i, j, k in tir.grid(2, 64, 8): - with tir.block([]): - tir.reads([]) - tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) - tir.evaluate( - tir.intrin_test( +@T.prim_func +def transformed_opaque_access(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (32, 64, 128)) + B = T.match_buffer(b, (64, 64, 64)) + for i, j, k in T.grid(2, 64, 8): + with T.block([]): + T.reads([]) + T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16]) + T.evaluate( + T.intrin_test( A.data, i * 131072 + j * 128 + k * 16, 8192, @@ -139,12 +136,12 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: dtype="handle", ) ) - for i, j, k in tir.grid(64, 2, 8): - with tir.block([]): - tir.reads([]) - tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) - tir.evaluate( - tir.intrin_test( + for i, j, k in T.grid(64, 2, 8): + with T.block([]): + T.reads([]) + T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8]) + T.evaluate( + T.intrin_test( B.data, i * 4096 + j * 2048 + k * 8, 64, @@ -156,23 +153,23 @@ def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: ) -@tvm.script.tir -def high_dim_opaque_access(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 32, 64)) - for i, j, k in tir.grid(16, 2, 4): - with tir.block([]): - As_0 = tir.var("int32") - As_1 = tir.var("int32") - tir.reads([]) - tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) - sub_A = tir.match_buffer( +@T.prim_func +def high_dim_opaque_access(a: T.handle) -> None: + A = T.match_buffer(a, (16, 32, 64)) + for i, j, k in T.grid(16, 2, 4): + with T.block([]): + As_0 = T.var("int32") + As_1 = T.var("int32") + T.reads([]) + T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) + sub_A = T.match_buffer( A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], (16, 16), strides=[As_0, As_1], offset_factor=1, ) - tir.evaluate( - tir.intrin_test( + T.evaluate( + T.intrin_test( sub_A.data, sub_A.elem_offset, sub_A.strides[0], @@ -184,15 
+181,15 @@ def high_dim_opaque_access(a: ty.handle) -> None: ) -@tvm.script.tir -def transformed_high_dim_opaque_access(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 32, 64)) - for i, j, k in tir.grid(16, 2, 4): - with tir.block([]): - tir.reads([]) - tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) - tir.evaluate( - tir.intrin_test( +@T.prim_func +def transformed_high_dim_opaque_access(a: T.handle) -> None: + A = T.match_buffer(a, (16, 32, 64)) + for i, j, k in T.grid(16, 2, 4): + with T.block([]): + T.reads([]) + T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16]) + T.evaluate( + T.intrin_test( A.data, i * 2048 + j * 1024 + k * 16, 64, @@ -204,56 +201,56 @@ def transformed_high_dim_opaque_access(a: ty.handle) -> None: ) -@tvm.script.tir -def recursive_match(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (64, 64, 64)) - B = tir.match_buffer(b, (64, 64, 64)) - for i, j, k in tir.grid(64, 4, 4): - with tir.block([]): - tir.reads([]) - tir.writes( +@T.prim_func +def recursive_match(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (64, 64, 64)) + B = T.match_buffer(b, (64, 64, 64)) + for i, j, k in T.grid(64, 4, 4): + with T.block([]): + T.reads([]) + T.writes( [ A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], ] ) - As_0 = tir.var("int32") - As_1 = tir.var("int32") - sub_A = tir.match_buffer( + As_0 = T.var("int32") + As_1 = T.var("int32") + sub_A = T.match_buffer( A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], (16, 16), strides=[As_0, As_1], offset_factor=1, ) - sub_B = tir.match_buffer( + sub_B = T.match_buffer( B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], (16, 16), offset_factor=1, ) - for jj, kk in tir.grid(4, 4): - with tir.block([]): - tir.reads([]) - tir.writes( + for jj, kk in T.grid(4, 4): + with T.block([]): + T.reads([]) + T.writes( [ sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], ] ) - Ass_0 = 
tir.var("int32") - Ass_1 = tir.var("int32") - sub_sub_A = tir.match_buffer( + Ass_0 = T.var("int32") + Ass_1 = T.var("int32") + sub_sub_A = T.match_buffer( sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], (4, 4), strides=[Ass_0, Ass_1], offset_factor=1, ) - sub_sub_B = tir.match_buffer( + sub_sub_B = T.match_buffer( sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4], (4, 4), offset_factor=1, ) - tir.evaluate( - tir.intrin_test( + T.evaluate( + T.intrin_test( sub_sub_A.data, sub_sub_A.elem_offset, sub_sub_A.strides[0], @@ -263,27 +260,27 @@ def recursive_match(a: ty.handle, b: ty.handle) -> None: dtype="handle", ) ) - for jjj, kkk in tir.grid(4, 4): + for jjj, kkk in T.grid(4, 4): sub_sub_B[jjj, kkk] = 1 -@tvm.script.tir -def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (64, 64, 64)) - B = tir.match_buffer(b, (64, 64, 64)) - for i, j, k in tir.grid(64, 4, 4): - with tir.block([]): - tir.reads([]) - tir.writes( +@T.prim_func +def transformed_recursive_match(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (64, 64, 64)) + B = T.match_buffer(b, (64, 64, 64)) + for i, j, k in T.grid(64, 4, 4): + with T.block([]): + T.reads([]) + T.writes( [ A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16], ] ) - for jj, kk in tir.grid(4, 4): - with tir.block([]): - tir.reads([]) - tir.writes( + for jj, kk in T.grid(4, 4): + with T.block([]): + T.reads([]) + T.writes( [ A[ i, @@ -297,8 +294,8 @@ def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None: ], ] ) - tir.evaluate( - tir.intrin_test( + T.evaluate( + T.intrin_test( A.data, i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4, 64, @@ -308,29 +305,29 @@ def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None: dtype="handle", ) ) - for jjj, kkk in tir.grid(4, 4): + for jjj, kkk in T.grid(4, 4): B[i, j * 16 + jj * 4 + jjj, k * 16 + kk * 4 + kkk] = 1 -@tvm.script.tir -def symbolic_match(a: ty.handle, b: ty.handle, 
n: ty.int32, m: ty.int32) -> None: - A = tir.match_buffer(a, (n * m, m)) - B = tir.match_buffer(b, (n * 2, m * 4)) +@T.prim_func +def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n * m, m)) + B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with tir.block([]): - tir.reads([]) - tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) - Bs_0 = tir.var("int32") - Bs_1 = tir.var("int32") - sub_A = tir.match_buffer(A[i * m : i * m + m, 0:m], (m, m), offset_factor=1) - sub_B = tir.match_buffer( + with T.block([]): + T.reads([]) + T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) + Bs_0 = T.var("int32") + Bs_1 = T.var("int32") + sub_A = T.match_buffer(A[i * m : i * m + m, 0:m], (m, m), offset_factor=1) + sub_B = T.match_buffer( B[i * n : i * n + 2, 0 : m * 4], (2, m * 4), strides=[Bs_0, Bs_1], offset_factor=1 ) - for ii, jj in tir.grid(m, m): + for ii, jj in T.grid(m, m): sub_A[ii, jj] = 1 for j in range(0, 4): - tir.evaluate( - tir.intrin_test( + T.evaluate( + T.intrin_test( sub_B.data, sub_B.elem_offset, sub_B.strides[0], @@ -342,19 +339,19 @@ def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None ) -@tvm.script.tir -def transformed_symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None: - A = tir.match_buffer(a, (n * m, m)) - B = tir.match_buffer(b, (n * 2, m * 4)) +@T.prim_func +def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n * m, m)) + B = T.match_buffer(b, (n * 2, m * 4)) for i in range(0, n): - with tir.block([]): - tir.reads([]) - tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) - for ii, jj in tir.grid(m, m): + with T.block([]): + T.reads([]) + T.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]]) + for ii, jj in T.grid(m, m): A[i * m + ii, jj] = 1 for j in range(0, 4): - tir.evaluate( - tir.intrin_test( 
+ T.evaluate( + T.intrin_test( B.data, i * n * (m * 4), m * 4, @@ -366,19 +363,19 @@ def transformed_symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.in ) -@tvm.script.tir -def rank0_buffer(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (8, 8)) - B = tir.match_buffer(b, (8, 8)) - for i, j in tir.grid(8, 8): - with tir.block([]): - tir.reads([]) - tir.writes([A[i, j], B[i, j]]) - sub_A = tir.match_buffer(A[i, j], (), offset_factor=1) - sub_B = tir.match_buffer(B[i, j], (), offset_factor=1) +@T.prim_func +def rank0_buffer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + B = T.match_buffer(b, (8, 8)) + for i, j in T.grid(8, 8): + with T.block([]): + T.reads([]) + T.writes([A[i, j], B[i, j]]) + sub_A = T.match_buffer(A[i, j], (), offset_factor=1) + sub_B = T.match_buffer(B[i, j], (), offset_factor=1) sub_A[()] = 1 - tir.evaluate( - tir.intrin_test( + T.evaluate( + T.intrin_test( sub_B.data, sub_B.elem_offset, 0, @@ -390,17 +387,17 @@ def rank0_buffer(a: ty.handle, b: ty.handle) -> None: ) -@tvm.script.tir -def transformed_rank0_buffer(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (8, 8)) - B = tir.match_buffer(b, (8, 8)) - for i, j in tir.grid(8, 8): - with tir.block([]): - tir.reads([]) - tir.writes([A[i, j], B[i, j]]) +@T.prim_func +def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + B = T.match_buffer(b, (8, 8)) + for i, j in T.grid(8, 8): + with T.block([]): + T.reads([]) + T.writes([A[i, j], B[i, j]]) A[i, j] = 1 - tir.evaluate( - tir.intrin_test( + T.evaluate( + T.intrin_test( B.data, i * 8 + j, 0, @@ -412,49 +409,47 @@ def transformed_rank0_buffer(a: ty.handle, b: ty.handle) -> None: ) -@tvm.script.tir -def fail_match_load(a: ty.handle) -> None: - A = tir.match_buffer(a, (8, 8)) - for i, j in tir.grid(8, 8): - with tir.block([]): - tir.reads(A[i, j]) - tir.writes([]) - sub_A = tir.match_buffer(A[i, j], ()) - tir.evaluate(tir.load("float32", sub_A.data, 
0)) - - -@tvm.script.tir -def fail_match_store(a: ty.handle) -> None: - A = tir.match_buffer(a, (8, 8)) - for i, j in tir.grid(8, 8): - with tir.block([]): - tir.reads([]) - tir.writes(A[i, j]) - sub_A = tir.match_buffer(A[i, j], ()) +@T.prim_func +def fail_match_load(a: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + for i, j in T.grid(8, 8): + with T.block([]): + T.reads(A[i, j]) + T.writes([]) + sub_A = T.match_buffer(A[i, j], ()) + T.evaluate(T.load("float32", sub_A.data, 0)) + + +@T.prim_func +def fail_match_store(a: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + for i, j in T.grid(8, 8): + with T.block([]): + T.reads([]) + T.writes(A[i, j]) + sub_A = T.match_buffer(A[i, j], ()) sub_A.data[0] = 1 -@tvm.script.tir -def fail_buffer_bind(a: ty.handle) -> None: - A = tir.match_buffer(a, (8, 8)) - for i, j in tir.grid(8, 2): - with tir.block([]): - stride = tir.var("int32") - sub_A = tir.match_buffer( +@T.prim_func +def fail_buffer_bind(a: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + for i, j in T.grid(8, 2): + with T.block([]): + stride = T.var("int32") + sub_A = T.match_buffer( A[i, j * 4 : j * 4 + 4], (1, 4), strides=[stride, stride], offset_factor=1 ) for jj in range(0, 4): sub_A[i, j * 4 + jj] = 1 -@tvm.script.tir -def fail_match_func_param(a: ty.handle, m: ty.handle, n: ty.handle) -> None: - A = tir.match_buffer(a, (8, 8)) - for i, j in tir.grid(8, 2): - with tir.block([]): - sub_A = tir.match_buffer( - A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1 - ) +@T.prim_func +def fail_match_func_param(a: T.handle, m: T.handle, n: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + for i, j in T.grid(8, 2): + with T.block([]): + sub_A = T.match_buffer(A[i, j * 4 : j * 4 + 4], (1, 4), strides=[m, n], offset_factor=1) for jj in range(0, 4): sub_A[i, j * 4 + jj] = 1 diff --git a/tests/python/unittest/test_tir_schedule_block_scope.py b/tests/python/unittest/test_tir_schedule_block_scope.py index f66dca30d9980..2182c7b9f449e 100644 --- 
a/tests/python/unittest/test_tir_schedule_block_scope.py +++ b/tests/python/unittest/test_tir_schedule_block_scope.py @@ -20,47 +20,47 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule import DepKind from tvm.tir.stmt_functor import post_order_visit # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = tir.float32(0) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j in T.grid(128, 128): + with T.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = T.float32(0) for k in range(0, 128): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def war_dependency(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B 
= tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) +@T.prim_func +def war_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "C") as [vi, vj]: + for i, j in T.grid(128, 128): + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index d7eb8d864135a..ff5b61a135ebf 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -20,7 +20,7 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable @@ -28,62 +28,62 @@ ########## Function before schedule ########## -@tvm.script.tir -def elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def access_under_scope(b: ty.handle, c: ty.handle) -> None: - A = tir.alloc_buffer((128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) +@T.prim_func +def 
access_under_scope(b: T.handle, c: T.handle) -> None: + A = T.alloc_buffer((128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - with tir.block([8, 8], "scope") as [i, j]: - for x, y in tir.grid(16, 16): - with tir.block([128, 128], "A") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) + with T.block([8, 8], "scope") as [i, j]: + for x, y in T.grid(16, 16): + with T.block([128, 128], "A") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) A[vi, vj] = 1.0 - for x, y in tir.grid(16, 16): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) + for x, y in T.grid(16, 16): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) B[vi, vj] = A[vi, vj] + 1.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), dtype="float16") - B = tir.match_buffer(b, (128, 128), dtype="float16") - C = tir.match_buffer(c, (128, 128), dtype="float16") - D = tir.match_buffer(d, (128, 128), dtype="float16") - - with tir.block([128, 128], "load_store") as [vi, vj]: - tir.reads(A[vi, vj]) - tir.writes(D[vi, vj]) - D.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj) - with tir.block([8, 8], "opaque") as [vi, vj]: - tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.evaluate( - tir.tvm_load_matrix_sync( +@T.prim_func +def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128), dtype="float16") + B = T.match_buffer(b, (128, 128), dtype="float16") + C = T.match_buffer(c, (128, 128), dtype="float16") + D = T.match_buffer(d, (128, 128), dtype="float16") + + with T.block([128, 128], "load_store") as [vi, 
vj]: + T.reads(A[vi, vj]) + T.writes(D[vi, vj]) + D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + with T.block([8, 8], "opaque") as [vi, vj]: + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( B.data, 16, 16, 16, vi * 8 + vj, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), A.data, vi * 2048 + vj * 16, 128, @@ -95,10 +95,10 @@ def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> Non dtype="handle", ) ) - with tir.block([8, 8], "match_buffer") as [vi, vj]: - tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = tir.match_buffer( + with T.block([8, 8], "match_buffer") as [vi, vj]: + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( A[ vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16, @@ -108,7 +108,7 @@ def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> Non strides=[128, 1], offset_factor=1, ) - C0 = tir.match_buffer( + C0 = T.match_buffer( C[ vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16, @@ -118,15 +118,15 @@ def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> Non strides=[128, 1], offset_factor=1, ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( C0.data, 16, 16, 16, vi * 8 + vj, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), A0.data, A0.elem_offset, A0.strides[0], @@ -140,113 +140,113 @@ def opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> Non ) -@tvm.script.tir +@T.prim_func def func_multi_consumer() -> None: - A = tir.alloc_buffer((128)) - B = tir.alloc_buffer((128)) - C = 
tir.alloc_buffer((128)) - for i in tir.grid(8): - for j in tir.grid(16): - with tir.block([128], "A") as [vi]: - tir.bind(vi, i * 16 + j) + A = T.alloc_buffer((128)) + B = T.alloc_buffer((128)) + C = T.alloc_buffer((128)) + for i in T.grid(8): + for j in T.grid(16): + with T.block([128], "A") as [vi]: + T.bind(vi, i * 16 + j) A[vi] = 1.0 - for j in tir.grid(16): - with tir.block([128], "B") as [vi]: - tir.bind(vi, i * 16 + j) + for j in T.grid(16): + with T.block([128], "B") as [vi]: + T.bind(vi, i * 16 + j) B[vi] = A[vi] + 1.0 - for i in tir.grid(128): - with tir.block([128], "C") as [vi]: + for i in T.grid(128): + with T.block([128], "C") as [vi]: C[vi] = A[vi] -@tvm.script.tir +@T.prim_func def func_multi_producer() -> None: - A = tir.alloc_buffer((128)) - B = tir.alloc_buffer((128)) - with tir.block([128], "A0") as [vi]: + A = T.alloc_buffer((128)) + B = T.alloc_buffer((128)) + with T.block([128], "A0") as [vi]: A[vi] = 1.0 - with tir.block([128], "A1") as [vi]: + with T.block([128], "A1") as [vi]: A[vi] = 2.0 - with tir.block([128], "B") as [vi]: + with T.block([128], "B") as [vi]: B[vi] = A[vi] ########## Expected function after cache_read ########## -@tvm.script.tir -def cache_read_elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - B = tir.alloc_buffer((128, 128)) - A_global = tir.alloc_buffer((128, 128)) - B_local = tir.alloc_buffer((128, 128), scope="local") - with tir.block([128, 128], "A_global") as [vi, vj]: +@T.prim_func +def cache_read_elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + A_global = T.alloc_buffer((128, 128)) + B_local = T.alloc_buffer((128, 128), scope="local") + with T.block([128, 128], "A_global") as [vi, vj]: A_global[vi, vj] = A[vi, vj] - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A_global[vi, vj] * 
2.0 - with tir.block([128, 128], "B_local") as [vi, vj]: + with T.block([128, 128], "B_local") as [vi, vj]: B_local[vi, vj] = B[vi, vj] - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B_local[vi, vj] + 1.0 -@tvm.script.tir -def cache_read_under_scope(b: ty.handle, c: ty.handle) -> None: - A = tir.alloc_buffer((128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - A_global = tir.alloc_buffer((128, 128)) +@T.prim_func +def cache_read_under_scope(b: T.handle, c: T.handle) -> None: + A = T.alloc_buffer((128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + A_global = T.alloc_buffer((128, 128)) - with tir.block([8, 8], "scope") as [i, j]: - A_local = tir.alloc_buffer((128, 128), scope="local") - for x, y in tir.grid(16, 16): - with tir.block([128, 128], "A") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) + with T.block([8, 8], "scope") as [i, j]: + A_local = T.alloc_buffer((128, 128), scope="local") + for x, y in T.grid(16, 16): + with T.block([128, 128], "A") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) A[vi, vj] = 1.0 - for x, y in tir.grid(16, 16): - with tir.block([128, 128], "A_local") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) + for x, y in T.grid(16, 16): + with T.block([128, 128], "A_local") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) A_local[vi, vj] = A[vi, vj] - for x, y in tir.grid(16, 16): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) + for x, y in T.grid(16, 16): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) B[vi, vj] = A_local[vi, vj] + 1.0 - with tir.block([128, 128], "A_global") as [vi, vj]: + with T.block([128, 128], "A_global") as [vi, vj]: A_global[vi, vj] = A[vi, vj] - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as 
[vi, vj]: C[vi, vj] = A_global[vi, vj] * 2.0 -@tvm.script.tir -def cache_read_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), dtype="float16") - B = tir.match_buffer(b, (128, 128), dtype="float16") - C = tir.match_buffer(c, (128, 128), dtype="float16") - D = tir.match_buffer(d, (128, 128), dtype="float16") - A_global = tir.alloc_buffer((128, 128), dtype="float16") +@T.prim_func +def cache_read_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128), dtype="float16") + B = T.match_buffer(b, (128, 128), dtype="float16") + C = T.match_buffer(c, (128, 128), dtype="float16") + D = T.match_buffer(d, (128, 128), dtype="float16") + A_global = T.alloc_buffer((128, 128), dtype="float16") - with tir.block([128, 128], "A_global") as [vi, vj]: + with T.block([128, 128], "A_global") as [vi, vj]: A_global[vi, vj] = A[vi, vj] - with tir.block([128, 128], "load_store") as [vi, vj]: - tir.reads(A_global[vi, vj]) - tir.writes(D[vi, vj]) - D.data[vi * 128 + vj] = tir.load("float16", A_global.data, vi * 128 + vj) - with tir.block([8, 8], "opaque") as [vi, vj]: - tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.evaluate( - tir.tvm_load_matrix_sync( + with T.block([128, 128], "load_store") as [vi, vj]: + T.reads(A_global[vi, vj]) + T.writes(D[vi, vj]) + D.data[vi * 128 + vj] = T.load("float16", A_global.data, vi * 128 + vj) + with T.block([8, 8], "opaque") as [vi, vj]: + T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( B.data, 16, 16, 16, vi * 8 + vj, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), A_global.data, vi * 2048 + vj * 16, 128, @@ -258,10 +258,10 @@ def cache_read_opaque_access(a: 
ty.handle, b: ty.handle, c: ty.handle, d: ty.han dtype="handle", ) ) - with tir.block([8, 8], "match_buffer") as [vi, vj]: - tir.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = tir.match_buffer( + with T.block([8, 8], "match_buffer") as [vi, vj]: + T.reads(A_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( A_global[ vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16, @@ -271,7 +271,7 @@ def cache_read_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.han strides=[128, 1], offset_factor=1, ) - C0 = tir.match_buffer( + C0 = T.match_buffer( C[ vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16, @@ -281,15 +281,15 @@ def cache_read_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.han strides=[128, 1], offset_factor=1, ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( C0.data, 16, 16, 16, vi * 8 + vj, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), A0.data, A0.elem_offset, A0.strides[0], @@ -303,130 +303,130 @@ def cache_read_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.han ) -@tvm.script.tir +@T.prim_func def cache_read_multi_consumer() -> None: - A = tir.alloc_buffer((128)) - B = tir.alloc_buffer((128)) - C = tir.alloc_buffer((128)) - A_global = tir.alloc_buffer((128)) - for i in tir.grid(8): - for j in tir.grid(16): - with tir.block([128], "A") as [vi]: - tir.bind(vi, i * 16 + j) + A = T.alloc_buffer((128)) + B = T.alloc_buffer((128)) + C = T.alloc_buffer((128)) + A_global = T.alloc_buffer((128)) + for i in T.grid(8): + for j in T.grid(16): + with T.block([128], "A") as [vi]: + T.bind(vi, i * 16 + j) A[vi] = 1.0 - for j in tir.grid(16): - with tir.block([128], "A") as [vi]: - tir.bind(vi, i * 16 + j) + for j in T.grid(16): + with T.block([128], "A") as 
[vi]: + T.bind(vi, i * 16 + j) A_global[vi] = A[vi] - for j in tir.grid(16): - with tir.block([128], "B") as [vi]: - tir.bind(vi, i * 16 + j) + for j in T.grid(16): + with T.block([128], "B") as [vi]: + T.bind(vi, i * 16 + j) B[vi] = A_global[vi] + 1.0 - for i in tir.grid(128): - with tir.block([128], "C") as [vi]: + for i in T.grid(128): + with T.block([128], "C") as [vi]: C[vi] = A_global[vi] -@tvm.script.tir -def continuous_cache_read(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - B = tir.alloc_buffer((128, 128)) - B_shared = tir.alloc_buffer((128, 128), scope="shared") - B_local = tir.alloc_buffer((128, 128), scope="local") - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def continuous_cache_read(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + B_shared = T.alloc_buffer((128, 128), scope="shared") + B_local = T.alloc_buffer((128, 128), scope="local") + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "B_shared") as [vi, vj]: + with T.block([128, 128], "B_shared") as [vi, vj]: B_shared[vi, vj] = B[vi, vj] - with tir.block([128, 128], "B_local") as [vi, vj]: + with T.block([128, 128], "B_local") as [vi, vj]: B_local[vi, vj] = B_shared[vi, vj] - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B_local[vi, vj] + 1.0 ########## Expected function after cache_write ########## -@tvm.script.tir -def cache_write_elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - B = tir.alloc_buffer((128, 128)) - B_global = tir.alloc_buffer((128, 128), scope="local") - C_local = tir.alloc_buffer((128, 128)) - with tir.block([128, 128], "B_global") as [vi, vj]: +@T.prim_func +def cache_write_elementwise(a: T.handle, c: T.handle) -> None: + 
A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + B_global = T.alloc_buffer((128, 128), scope="local") + C_local = T.alloc_buffer((128, 128)) + with T.block([128, 128], "B_global") as [vi, vj]: B_global[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = B_global[vi, vj] - with tir.block([128, 128], "C_local") as [vi, vj]: + with T.block([128, 128], "C_local") as [vi, vj]: C_local[vi, vj] = B[vi, vj] + 1.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = C_local[vi, vj] -@tvm.script.tir -def cache_write_under_scope(b: ty.handle, c: ty.handle) -> None: - A = tir.alloc_buffer((128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - A_global = tir.alloc_buffer((128, 128)) - - with tir.block([8, 8], "scope") as [i, j]: - A_local = tir.alloc_buffer((128, 128), scope="local") - B_global = tir.alloc_buffer((128, 128)) - for x, y in tir.grid(16, 16): - with tir.block([128, 128], "A_local") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) +@T.prim_func +def cache_write_under_scope(b: T.handle, c: T.handle) -> None: + A = T.alloc_buffer((128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + A_global = T.alloc_buffer((128, 128)) + + with T.block([8, 8], "scope") as [i, j]: + A_local = T.alloc_buffer((128, 128), scope="local") + B_global = T.alloc_buffer((128, 128)) + for x, y in T.grid(16, 16): + with T.block([128, 128], "A_local") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) A_local[vi, vj] = 1.0 - for x, y in tir.grid(16, 16): - with tir.block([128, 128], "A") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) + for x, y in T.grid(16, 16): + with T.block([128, 128], "A") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) A_global[vi, vj] = A_local[vi, vj] - for x, 
y in tir.grid(16, 16): - with tir.block([128, 128], "B_global") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) + for x, y in T.grid(16, 16): + with T.block([128, 128], "B_global") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) B_global[vi, vj] = A_global[vi, vj] + 1.0 - for x, y in tir.grid(16, 16): - with tir.block([128, 128], "B_global") as [vi, vj]: - tir.bind(vi, i * 16 + x) - tir.bind(vj, j * 16 + y) + for x, y in T.grid(16, 16): + with T.block([128, 128], "B_global") as [vi, vj]: + T.bind(vi, i * 16 + x) + T.bind(vj, j * 16 + y) B[vi, vj] = B_global[vi, vj] - with tir.block([128, 128], "A_global") as [vi, vj]: + with T.block([128, 128], "A_global") as [vi, vj]: A[vi, vj] = A_global[vi, vj] - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), dtype="float16") - B = tir.match_buffer(b, (128, 128), dtype="float16") - C = tir.match_buffer(c, (128, 128), dtype="float16") - D = tir.match_buffer(d, (128, 128), dtype="float16") - D_global = tir.alloc_buffer((128, 128), dtype="float16") - B_global = tir.alloc_buffer((128, 128), dtype="float16") - C_global = tir.alloc_buffer((128, 128), dtype="float16") - - with tir.block([128, 128], "load_store") as [vi, vj]: - tir.reads(A[vi, vj]) - tir.writes(D_global[vi, vj]) - D_global.data[vi * 128 + vj] = tir.load("float16", A.data, vi * 128 + vj) - with tir.block([8, 8], "opaque") as [vi, vj]: - tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.evaluate( - tir.tvm_load_matrix_sync( +@T.prim_func +def cache_write_opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128), dtype="float16") + B = T.match_buffer(b, (128, 128), dtype="float16") + 
C = T.match_buffer(c, (128, 128), dtype="float16") + D = T.match_buffer(d, (128, 128), dtype="float16") + D_global = T.alloc_buffer((128, 128), dtype="float16") + B_global = T.alloc_buffer((128, 128), dtype="float16") + C_global = T.alloc_buffer((128, 128), dtype="float16") + + with T.block([128, 128], "load_store") as [vi, vj]: + T.reads(A[vi, vj]) + T.writes(D_global[vi, vj]) + D_global.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj) + with T.block([8, 8], "opaque") as [vi, vj]: + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.evaluate( + T.tvm_load_matrix_sync( B_global.data, 16, 16, 16, vi * 8 + vj, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), A.data, vi * 2048 + vj * 16, 128, @@ -438,10 +438,10 @@ def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.ha dtype="handle", ) ) - with tir.block([8, 8], "match_buffer") as [vi, vj]: - tir.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - tir.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) - A0 = tir.match_buffer( + with T.block([8, 8], "match_buffer") as [vi, vj]: + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C_global[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + A0 = T.match_buffer( A[ vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16, @@ -451,7 +451,7 @@ def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.ha strides=[128, 1], offset_factor=1, ) - C0 = tir.match_buffer( + C0 = T.match_buffer( C_global[ vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16, @@ -461,15 +461,15 @@ def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.ha strides=[128, 1], offset_factor=1, ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( C0.data, 16, 16, 16, vi * 8 + vj, - tir.tvm_access_ptr( - 
tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), A0.data, A0.elem_offset, A0.strides[0], @@ -482,53 +482,53 @@ def cache_write_opaque_access(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.ha ) ) - with tir.block([128, 128], "D") as [vi, vj]: + with T.block([128, 128], "D") as [vi, vj]: D[vi, vj] = D_global[vi, vj] - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = B_global[vi, vj] - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = C_global[vi, vj] -@tvm.script.tir +@T.prim_func def cache_write_multi_consumer() -> None: - A = tir.alloc_buffer((128)) - B = tir.alloc_buffer((128)) - C = tir.alloc_buffer((128)) - A_global = tir.alloc_buffer((128)) - for i in tir.grid(8): - for j in tir.grid(16): - with tir.block([128], "A_global") as [vi]: - tir.bind(vi, i * 16 + j) + A = T.alloc_buffer((128)) + B = T.alloc_buffer((128)) + C = T.alloc_buffer((128)) + A_global = T.alloc_buffer((128)) + for i in T.grid(8): + for j in T.grid(16): + with T.block([128], "A_global") as [vi]: + T.bind(vi, i * 16 + j) A_global[vi] = 1.0 - for j in tir.grid(16): - with tir.block([128], "A") as [vi]: - tir.bind(vi, i * 16 + j) + for j in T.grid(16): + with T.block([128], "A") as [vi]: + T.bind(vi, i * 16 + j) A[vi] = A_global[vi] - for j in tir.grid(16): - with tir.block([128], "B") as [vi]: - tir.bind(vi, i * 16 + j) + for j in T.grid(16): + with T.block([128], "B") as [vi]: + T.bind(vi, i * 16 + j) B[vi] = A[vi] + 1.0 - for i in tir.grid(128): - with tir.block([128], "C") as [vi]: + for i in T.grid(128): + with T.block([128], "C") as [vi]: C[vi] = A[vi] -@tvm.script.tir -def continuous_cache_write(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - B_shared = tir.alloc_buffer((128, 128), scope="shared") - B_local = tir.alloc_buffer((128, 128), 
scope="local") - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def continuous_cache_write(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + B_shared = T.alloc_buffer((128, 128), scope="shared") + B_local = T.alloc_buffer((128, 128), scope="local") + with T.block([128, 128], "B") as [vi, vj]: B_local[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B_shared[vi, vj] = B_local[vi, vj] - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = B_shared[vi, vj] - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index a4f8b2e77078b..5235664595add 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -21,651 +21,651 @@ import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks -@tvm.script.tir -def two_elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def two_elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, 
vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def two_elementwise_after_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") +@T.prim_func +def two_elementwise_after_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): - for ax0, ax1 in tir.grid(1, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i + ax0) - tir.bind(vj, ax1) + for ax0, ax1 in T.grid(1, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i + ax0) + T.bind(vj, ax1) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def blockized_1(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128], "float32") - B = tir.alloc_buffer([128, 128], "float32") - C = tir.match_buffer(c, [128, 128], "float32") - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def blockized_1(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], "float32") + B = T.alloc_buffer([128, 128], "float32") + C = T.match_buffer(c, [128, 128], "float32") + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([8, 8], "C_outer") as [vi_o, vj_o]: - tir.reads([B[ + with T.block([8, 8], "C_outer") as [vi_o, vj_o]: + T.reads([B[ vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16, ]]) - tir.writes([C[ + T.writes([C[ vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16 ]]) - for i_i, j_i in tir.grid(16, 16): - with tir.block([128, 128], "C_inner") as [vi, vj]: - tir.bind(vi, vi_o * 16 + i_i) - tir.bind(vj, vj_o * 16 + j_i) + for i_i, j_i in T.grid(16, 16): + with 
T.block([128, 128], "C_inner") as [vi, vj]: + T.bind(vi, vi_o * 16 + i_i) + T.bind(vj, vj_o * 16 + j_i) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def blockized_after_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128], "float32") - B = tir.alloc_buffer([128, 128], "float32") - C = tir.match_buffer(c, [128, 128], "float32") - for i0_0, i1_0 in tir.grid(8, 8): - for ax0, ax1 in tir.grid(16, 16): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i0_0 * 16 + ax0) - tir.bind(vj, i1_0 * 16 + ax1) +@T.prim_func +def blockized_after_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], "float32") + B = T.alloc_buffer([128, 128], "float32") + C = T.match_buffer(c, [128, 128], "float32") + for i0_0, i1_0 in T.grid(8, 8): + for ax0, ax1 in T.grid(16, 16): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i0_0 * 16 + ax0) + T.bind(vj, i1_0 * 16 + ax1) B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([8, 8], "C_outer") as [vi_o, vj_o]: - tir.bind(vi_o, i0_0) - tir.bind(vj_o, i1_0) - tir.reads([B[ + with T.block([8, 8], "C_outer") as [vi_o, vj_o]: + T.bind(vi_o, i0_0) + T.bind(vj_o, i1_0) + T.reads([B[ vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16, ]]) - tir.writes([C[ + T.writes([C[ vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16 ]]) - for i0_1, i1_1 in tir.grid(16, 16): - with tir.block([128, 128], "C_inner") as [vi, vj]: - tir.bind(vi, vi_o * 16 + i0_1) - tir.bind(vj, vj_o * 16 + i1_1) + for i0_1, i1_1 in T.grid(16, 16): + with T.block([128, 128], "C_inner") as [vi, vj]: + T.bind(vi, vi_o * 16 + i0_1) + T.bind(vj, vj_o * 16 + i1_1) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def blockized_2(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128], "float32") - B = tir.alloc_buffer([128, 128], "float32") - C = tir.match_buffer(c, [128, 128], "float32") - for i_o, j_o in tir.grid(8, 8): - with tir.block([8, 8], "B_outer") as [vio, vjo]: - tir.bind(vio, i_o) - 
tir.bind(vjo, j_o) - tir.reads([A[ +@T.prim_func +def blockized_2(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], "float32") + B = T.alloc_buffer([128, 128], "float32") + C = T.match_buffer(c, [128, 128], "float32") + for i_o, j_o in T.grid(8, 8): + with T.block([8, 8], "B_outer") as [vio, vjo]: + T.bind(vio, i_o) + T.bind(vjo, j_o) + T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, ]]) - tir.writes([B[ + T.writes([B[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16 ]]) - for i_i, j_i in tir.grid(16, 16): - with tir.block([128, 128], "B_inner") as [vi, vj]: - tir.bind(vi, vio * 16 + i_i) - tir.bind(vj, vjo * 16 + j_i) + for i_i, j_i in T.grid(16, 16): + with T.block([128, 128], "B_inner") as [vi, vj]: + T.bind(vi, vio * 16 + i_i) + T.bind(vj, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 - for i_o, j_o, i_i, j_i in tir.grid(4, 4, 32, 32): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i_o * 32 + i_i) - tir.bind(vj, j_o * 32 + j_i) + for i_o, j_o, i_i, j_i in T.grid(4, 4, 32, 32): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i_o * 32 + i_i) + T.bind(vj, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def blockized_2_after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128], "float32") - B = tir.alloc_buffer([128, 128], "float32") - C = tir.match_buffer(c, [128, 128], "float32") - for i_o, j_o in tir.grid(8, 8): - with tir.block([8, 8], "B_outer") as [vio, vjo]: - tir.bind(vio, i_o) - tir.bind(vjo, j_o) - tir.reads([A[ +@T.prim_func +def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], "float32") + B = T.alloc_buffer([128, 128], "float32") + C = T.match_buffer(c, [128, 128], "float32") + for i_o, j_o in T.grid(8, 8): + with T.block([8, 8], "B_outer") as [vio, vjo]: + T.bind(vio, i_o) + T.bind(vjo, j_o) + T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, ]]) - tir.writes([B[ + 
T.writes([B[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16 ]]) - for i_i, j_i in tir.grid(16, 16): - with tir.block([128, 128], "B_inner") as [vi, vj]: - tir.bind(vi, vio * 16 + i_i) - tir.bind(vj, vjo * 16 + j_i) + for i_i, j_i in T.grid(16, 16): + with T.block([128, 128], "B_inner") as [vi, vj]: + T.bind(vi, vio * 16 + i_i) + T.bind(vj, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 - for ax0, ax1 in tir.grid(16, 16): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i_o * 16 + ax0) - tir.bind(vj, j_o * 16 + ax1) - tir.reads([B[vi, vj]]) - tir.writes([C[vi, vj]]) + for ax0, ax1 in T.grid(16, 16): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i_o * 16 + ax0) + T.bind(vj, j_o * 16 + ax1) + T.reads([B[vi, vj]]) + T.writes([C[vi, vj]]) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def blockized_2_after_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128], "float32") - B = tir.alloc_buffer([128, 128], "float32") - C = tir.match_buffer(c, [128, 128], "float32") - for i_o, j_o in tir.grid(4, 4): - for ax0, ax1 in tir.grid(2, 2): - with tir.block([8, 8], "blockized_B") as [vio, vjo]: - tir.bind(vio, i_o * 2 + ax0) - tir.bind(vjo, j_o * 2 + ax1) - tir.reads([A[ +@T.prim_func +def blockized_2_after_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], "float32") + B = T.alloc_buffer([128, 128], "float32") + C = T.match_buffer(c, [128, 128], "float32") + for i_o, j_o in T.grid(4, 4): + for ax0, ax1 in T.grid(2, 2): + with T.block([8, 8], "blockized_B") as [vio, vjo]: + T.bind(vio, i_o * 2 + ax0) + T.bind(vjo, j_o * 2 + ax1) + T.reads([A[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, ]]) - tir.writes([B[ + T.writes([B[ vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16, ]]) - for i_i, j_i in tir.grid(16, 16): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, vio * 16 + i_i) - tir.bind(vj, vjo * 16 + j_i) + for i_i, j_i in T.grid(16, 16): + with T.block([128, 128], "B") as 
[vi, vj]: + T.bind(vi, vio * 16 + i_i) + T.bind(vj, vjo * 16 + j_i) B[vi, vj] = A[vi, vj] * 2.0 - for i_i, j_i in tir.grid(32, 32): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i_o * 32 + i_i) - tir.bind(vj, j_o * 32 + j_i) + for i_i, j_i in T.grid(32, 32): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i_o * 32 + i_i) + T.bind(vj, j_o * 32 + j_i) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def cuda_matmul_0(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable - A = tir.match_buffer(a, [2048, 2048], "float32") - B = tir.match_buffer(b, [2048, 2048], "float32") - C = tir.match_buffer(c, [2048, 2048], "float32") - A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - with tir.block([2048, 2048], "A_shared") as [v0, v1]: +@T.prim_func +def cuda_matmul_0(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + with T.block([2048, 2048], "A_shared") as [v0, v1]: A_shared[v0, v1] = A[v0, v1] - with tir.block([2048, 2048], "B_shared") as [v0, v1]: + with T.block([2048, 2048], "B_shared") as [v0, v1]: B_shared[v0, v1] = B[v0, v1] - with tir.block([2048, 2048], 
"A_shared_local") as [v0, v1]: + with T.block([2048, 2048], "A_shared_local") as [v0, v1]: A_shared_local[v0, v1] = A_shared[v0, v1] - with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + with T.block([2048, 2048], "B_shared_local") as [v0, v1]: B_shared_local[v0, v1] = B_shared[v0, v1] - with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - with tir.init(): + with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] - for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): - for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): - for vy in tir.thread_binding(0, 2, thread = "vthread.y"): - for vx in tir.thread_binding(0, 2, thread = "vthread.x"): - for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): - for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): - for i, j in tir.grid(4, 4): - with tir.block([2048, 2048], "C_local") as [v0_4, v1_4]: - tir.bind(v0_4, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(v1_4, bx * 64 + vx * 32 + tx * 4 + j) + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for i, j in T.grid(4, 4): + with T.block([2048, 2048], "C_local") as [v0_4, v1_4]: + T.bind(v0_4, by * 64 + vy * 32 + ty * 4 + i) + T.bind(v1_4, bx * 64 + vx * 32 + tx * 4 + j) C[v0_4, v1_4] = C_local[v0_4, v1_4] -@tvm.script.tir -def cuda_matmul_0_after_compute_at(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable - A = tir.match_buffer(a, [2048, 2048], "float32") - B = tir.match_buffer(b, [2048, 2048], "float32") - C = tir.match_buffer(c, [2048, 
2048], "float32") - A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - with tir.block([2048, 2048], "A_shared") as [v0, v1]: +@T.prim_func +def cuda_matmul_0_after_compute_at(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + with T.block([2048, 2048], "A_shared") as [v0, v1]: A_shared[v0, v1] = A[v0, v1] - with tir.block([2048, 2048], "B_shared") as [v0, v1]: + with T.block([2048, 2048], "B_shared") as [v0, v1]: B_shared[v0, v1] = B[v0, v1] - with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + with T.block([2048, 2048], "A_shared_local") as [v0, v1]: A_shared_local[v0, v1] = A_shared[v0, v1] - with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + with T.block([2048, 2048], "B_shared_local") as [v0, v1]: B_shared_local[v0, v1] = B_shared[v0, v1] - for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): - for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): - for vy in tir.thread_binding(0, 2, thread = "vthread.y"): - for vx in tir.thread_binding(0, 2, thread = "vthread.x"): - for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): - for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): - 
for i, j, k in tir.grid(4, 4, 2048): - with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - tir.bind(vk, k) - with tir.init(): + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for i, j, k in T.grid(4, 4, 2048): + with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + T.bind(vk, k) + with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] - for i, j in tir.grid(4, 4): - with tir.block([2048, 2048], "C_local") as [vi, vj]: - tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(4, 4): + with T.block([2048, 2048], "C_local") as [vi, vj]: + T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] -@tvm.script.tir -def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable - A = tir.match_buffer(a, [2048, 2048], "float32") - B = tir.match_buffer(b, [2048, 2048], "float32") - C = tir.match_buffer(c, [2048, 2048], "float32") - A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - with tir.block([2048, 2048], 
"A_shared") as [v0, v1]: +@T.prim_func +def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + with T.block([2048, 2048], "A_shared") as [v0, v1]: A_shared[v0, v1] = A[v0, v1] - with tir.block([2048, 2048], "B_shared") as [v0, v1]: + with T.block([2048, 2048], "B_shared") as [v0, v1]: B_shared[v0, v1] = B[v0, v1] - with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: + with T.block([2048, 2048], "A_shared_local") as [v0, v1]: A_shared_local[v0, v1] = A_shared[v0, v1] - with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + with T.block([2048, 2048], "B_shared_local") as [v0, v1]: B_shared_local[v0, v1] = B_shared[v0, v1] - for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): - for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): - for vy in tir.thread_binding(0, 2, thread = "vthread.y"): - for vx in tir.thread_binding(0, 2, thread = "vthread.x"): - for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): - for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): - for k_0 in tir.serial(0, 256): - for k_1 in tir.unroll(0, 8): - for _, i, j in tir.grid(1, 4, 4): - with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - tir.bind(vk, k_0 * 8 + k_1) - with tir.init(): + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = 
"blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k_0 in T.serial(0, 256): + for k_1 in T.unroll(0, 8): + for _, i, j in T.grid(1, 4, 4): + with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + T.bind(vk, k_0 * 8 + k_1) + with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] - for i, j in tir.grid(4, 4): - with tir.block([2048, 2048], "C_local") as [vi, vj]: - tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(4, 4): + with T.block([2048, 2048], "C_local") as [vi, vj]: + T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) C[vi, vj] = C_local[vi, vj] -@tvm.script.tir -def cuda_matmul_2(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable - A = tir.match_buffer(a, [2048, 2048], "float32") - B = tir.match_buffer(b, [2048, 2048], "float32") - C = tir.match_buffer(c, [2048, 2048], "float32") - A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - with tir.block([2048, 2048], "A_shared") as [v0, v1]: +@T.prim_func +def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], 
"float32") + A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + with T.block([2048, 2048], "A_shared") as [v0, v1]: A_shared[v0, v1] = A[v0, v1] - with tir.block([2048, 2048], "B_shared") as [v0, v1]: + with T.block([2048, 2048], "B_shared") as [v0, v1]: B_shared[v0, v1] = B[v0, v1] - with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: + with T.block([2048, 2048], "B_shared_local") as [v0, v1]: B_shared_local[v0, v1] = B_shared[v0, v1] - for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): - for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): - for vy in tir.thread_binding(0, 2, thread = "vthread.y"): - for vx in tir.thread_binding(0, 2, thread = "vthread.x"): - for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): - for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): - for k_0 in tir.serial(0, 256): - for k_1 in tir.unroll(0, 8): - for i, j in tir.grid(1, 4): - with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: - tir.bind(v0, k_0 * 8 + k_1 + i) - tir.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k_0 in T.serial(0, 256): + for k_1 in T.unroll(0, 8): + for i, j in T.grid(1, 4): + with T.block([2048, 2048], "A_shared_local") as [v0, v1]: + T.bind(v0, k_0 * 8 + k_1 + i) + T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] - for _, i, j in 
tir.grid(1, 4, 4): - with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - tir.bind(vk, k_0 * 8 + k_1) - with tir.init(): - C_local[vi, vj] = tir.float32(0) + for _, i, j in T.grid(1, 4, 4): + with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + T.bind(vk, k_0 * 8 + k_1) + with T.init(): + C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] - for i, j in tir.grid(4, 4): - with tir.block([2048, 2048], "C_local") as [v0, v1]: - tir.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(4, 4): + with T.block([2048, 2048], "C_local") as [v0, v1]: + T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) + T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] -@tvm.script.tir -def cuda_matmul_3(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable - A = tir.match_buffer(a, [2048, 2048], "float32") - B = tir.match_buffer(b, [2048, 2048], "float32") - C = tir.match_buffer(c, [2048, 2048], "float32") - A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - with tir.block([2048, 2048], "A_shared") as [v0, v1]: +@T.prim_func +def cuda_matmul_3(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + A_shared = 
T.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + with T.block([2048, 2048], "A_shared") as [v0, v1]: A_shared[v0, v1] = A[v0, v1] - with tir.block([2048, 2048], "B_shared") as [v0, v1]: + with T.block([2048, 2048], "B_shared") as [v0, v1]: B_shared[v0, v1] = B[v0, v1] - for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): - for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): - for vy in tir.thread_binding(0, 2, thread = "vthread.y"): - for vx in tir.thread_binding(0, 2, thread = "vthread.x"): - for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): - for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): - for k0 in tir.serial(0, 256): - for k1 in tir.unroll(0, 8): - for i, j in tir.grid(1, 4): - with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: - tir.bind(v0, k0 * 8 + k1 + i) - tir.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in T.serial(0, 256): + for k1 in T.unroll(0, 8): + for i, j in T.grid(1, 4): + with T.block([2048, 2048], "A_shared_local") as [v0, v1]: + T.bind(v0, k0 * 8 + k1 + i) + T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] - for i, j in tir.grid(1, 4): - with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: - tir.bind(v0, k0 * 8 + k1 + i) - tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(1, 4): + with 
T.block([2048, 2048], "B_shared_local") as [v0, v1]: + T.bind(v0, k0 * 8 + k1 + i) + T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] - for _, i, j in tir.grid(1, 4, 4): - with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - tir.bind(vk, k0 * 8 + k1) - with tir.init(): - C_local[vi, vj] = tir.float32(0) + for _, i, j in T.grid(1, 4, 4): + with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + T.bind(vk, k0 * 8 + k1) + with T.init(): + C_local[vi, vj] = T.float32(0) C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] - for i, j in tir.grid(4, 4): - with tir.block([2048, 2048], "C_local") as [v0, v1]: - tir.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(4, 4): + with T.block([2048, 2048], "C_local") as [v0, v1]: + T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) + T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] -@tvm.script.tir -def cuda_matmul_4(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable - A = tir.match_buffer(a, [2048, 2048], "float32") - B = tir.match_buffer(b, [2048, 2048], "float32") - C = tir.match_buffer(c, [2048, 2048], "float32") - A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - with tir.block([2048, 2048], "B_shared") as [v0, v1]: +@T.prim_func +def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: 
disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + with T.block([2048, 2048], "B_shared") as [v0, v1]: B_shared[v0, v1] = B[v0, v1] - for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): - for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): - for vy in tir.thread_binding(0, 2, thread = "vthread.y"): - for vx in tir.thread_binding(0, 2, thread = "vthread.x"): - for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): - for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): - for k0 in tir.serial(0, 256): - for i, j in tir.grid(8, 64): - with tir.block([2048, 2048], "A_shared") as [v0, v1]: - tir.bind(v0, k0 * 8 + i) - tir.bind(v1, by * 64 + j) + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in T.serial(0, 256): + for i, j in T.grid(8, 64): + with T.block([2048, 2048], "A_shared") as [v0, v1]: + T.bind(v0, k0 * 8 + i) + T.bind(v1, by * 64 + j) A_shared[v0, v1] = A[v0, v1] - for k1 in tir.unroll(0, 8): - for i, j in tir.grid(1, 4): - with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: - tir.bind(v0, k0 * 8 + k1 + i) - tir.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + for k1 in T.unroll(0, 8): + for i, j in T.grid(1, 4): + with T.block([2048, 2048], 
"A_shared_local") as [v0, v1]: + T.bind(v0, k0 * 8 + k1 + i) + T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] - for i, j in tir.grid(1, 4): - with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: - tir.bind(v0, k0 * 8 + k1 + i) - tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(1, 4): + with T.block([2048, 2048], "B_shared_local") as [v0, v1]: + T.bind(v0, k0 * 8 + k1 + i) + T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] - for _, i, j in tir.grid(1, 4, 4): - with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - tir.bind(vk, k0 * 8 + k1) - with tir.init(): + for _, i, j in T.grid(1, 4, 4): + with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + T.bind(vk, k0 * 8 + k1) + with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] - for i, j in tir.grid(4, 4): - with tir.block([2048, 2048], "C_local") as [v0, v1]: - tir.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(4, 4): + with T.block([2048, 2048], "C_local") as [v0, v1]: + T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) + T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] -@tvm.script.tir -def cuda_matmul_5(a: ty.handle, b: ty.handle, c: ty.handle) -> None: # pylint: disable=undefined-loop-variable - A = tir.match_buffer(a, [2048, 2048], "float32") - B = tir.match_buffer(b, [2048, 2048], "float32") - C = tir.match_buffer(c, [2048, 2048], "float32") - A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared") - A_shared_local = tir.alloc_buffer([2048, 2048], "float32", 
scope="local") - B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local") - for by in tir.thread_binding(0, 32, thread = "blockIdx.y"): - for bx in tir.thread_binding(0, 32, thread = "blockIdx.x"): - for vy in tir.thread_binding(0, 2, thread = "vthread.y"): - for vx in tir.thread_binding(0, 2, thread = "vthread.x"): - for ty in tir.thread_binding(0, 8, thread = "threadIdx.y"): - for tx in tir.thread_binding(0, 8, thread = "threadIdx.x"): - for k0 in tir.serial(0, 256): - for i, j in tir.grid(8, 64): - with tir.block([2048, 2048], "A_shared") as [v0, v1]: - tir.bind(v0, k0 * 8 + i) - tir.bind(v1, by * 64 + j) +@T.prim_func +def cuda_matmul_5(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared") + A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + C_local = T.alloc_buffer([2048, 2048], "float32", scope="local") + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in T.serial(0, 256): + for i, j in T.grid(8, 64): + with T.block([2048, 2048], "A_shared") as [v0, v1]: + T.bind(v0, k0 * 8 + i) + T.bind(v1, by * 64 + j) A_shared[v0, v1] = A[v0, v1] - for i, j in tir.grid(8, 64): - with tir.block([2048, 2048], "B_shared") as [v0, v1]: - tir.bind(v0, k0 * 8 + i) - 
tir.bind(v1, bx * 64 + j) + for i, j in T.grid(8, 64): + with T.block([2048, 2048], "B_shared") as [v0, v1]: + T.bind(v0, k0 * 8 + i) + T.bind(v1, bx * 64 + j) B_shared[v0, v1] = B[v0, v1] - for k1 in tir.unroll(0, 8): - for i, j in tir.grid(1, 4): - with tir.block([2048, 2048], "A_shared_local") as [v0, v1]: - tir.bind(v0, k0 * 8 + k1 + i) - tir.bind(v1, by * 64 + vy * 32 + ty * 4 + j) + for k1 in T.unroll(0, 8): + for i, j in T.grid(1, 4): + with T.block([2048, 2048], "A_shared_local") as [v0, v1]: + T.bind(v0, k0 * 8 + k1 + i) + T.bind(v1, by * 64 + vy * 32 + ty * 4 + j) A_shared_local[v0, v1] = A_shared[v0, v1] - for i, j in tir.grid(1, 4): - with tir.block([2048, 2048], "B_shared_local") as [v0, v1]: - tir.bind(v0, k0 * 8 + k1 + i) - tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(1, 4): + with T.block([2048, 2048], "B_shared_local") as [v0, v1]: + T.bind(v0, k0 * 8 + k1 + i) + T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) B_shared_local[v0, v1] = B_shared[v0, v1] - for _, i, j in tir.grid(1, 4, 4): - with tir.block([2048, 2048, tir.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: - tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) - tir.bind(vk, k0 * 8 + k1) - with tir.init(): + for _, i, j in T.grid(1, 4, 4): + with T.block([2048, 2048, T.reduce_axis(0, 2048)], "C") as [vi, vj, vk]: + T.bind(vi, by * 64 + vy * 32 + ty * 4 + i) + T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j) + T.bind(vk, k0 * 8 + k1) + with T.init(): C_local[vi, vj] = 0.0 C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj] - for i, j in tir.grid(4, 4): - with tir.block([2048, 2048], "C_local") as [v0, v1]: - tir.bind(v0, by * 64 + vy * 32 + ty * 4 + i) - tir.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) + for i, j in T.grid(4, 4): + with T.block([2048, 2048], "C_local") as [v0, v1]: + T.bind(v0, by * 64 + vy * 32 + ty * 4 + i) + T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j) C[v0, v1] = C_local[v0, v1] -@tvm.script.tir 
-def tiled(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128], "float32") - B = tir.alloc_buffer([128, 128], "float32") - C = tir.match_buffer(c, [128, 128], "float32") - for i_0, j_0, i_1, j_1 in tir.grid(8, 8, 16, 16): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i_0 * 16 + i_1) - tir.bind(vj, j_0 * 16 + j_1) +@T.prim_func +def tiled(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], "float32") + B = T.alloc_buffer([128, 128], "float32") + C = T.match_buffer(c, [128, 128], "float32") + for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i_0 * 16 + i_1) + T.bind(vj, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def tiled_after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128], "float32") - B = tir.alloc_buffer([128, 128], "float32") - C = tir.match_buffer(c, [128, 128], "float32") - for i_0, j_0, i_1 in tir.grid(8, 8, 16): - for j_1 in tir.serial(0, 16): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i_0 * 16 + i_1) - tir.bind(vj, j_0 * 16 + j_1) +@T.prim_func +def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], "float32") + B = T.alloc_buffer([128, 128], "float32") + C = T.match_buffer(c, [128, 128], "float32") + for i_0, j_0, i_1 in T.grid(8, 8, 16): + for j_1 in T.serial(0, 16): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i_0 * 16 + i_1) + T.bind(vj, j_0 * 16 + j_1) B[vi, vj] = A[vi, vj] * 2.0 - for j_1 in tir.serial(0, 16): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i_0 * 16 + i_1) - tir.bind(vj, j_0 * 16 + j_1) + for j_1 in T.serial(0, 16): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i_0 * 16 + i_1) + T.bind(vj, j_0 * 16 + j_1) C[vi, vj] = B[vi, vj] + 1.0 
-@tvm.script.tir -def factorized(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16, 16], "float32") - B = tir.match_buffer(b, [16], "float32") - B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local") - for j in tir.thread_binding(0, 16, thread = "blockIdx.x"): - for i_o in tir.thread_binding(0, 4, thread = "threadIdx.x"): - for i_i, k in tir.grid(4, 16): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: - tir.bind(vi, i_o * 4 + i_i) - tir.bind(vj, j) - tir.bind(vk, k) - with tir.init(): +@T.prim_func +def factorized(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], "float32") + B = T.match_buffer(b, [16], "float32") + B_rf_local = T.alloc_buffer([16, 16], "float32", scope="local") + for j in T.thread_binding(0, 16, thread = "blockIdx.x"): + for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): + for i_i, k in T.grid(4, 16): + with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: + T.bind(vi, i_o * 4 + i_i) + T.bind(vj, j) + T.bind(vk, k) + with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] - for i, k in tir.grid(16, 16): - with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]: - tir.bind(vi, i) - tir.bind(vk, k) - with tir.init(): + for i, k in T.grid(16, 16): + with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]: + T.bind(vi, i) + T.bind(vk, k) + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] -@tvm.script.tir -def factorized_after_reverse_compute_at(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16, 16], "float32") - B = tir.match_buffer(b, [16], "float32") - B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local") - for j in tir.thread_binding(0, 16, thread = "blockIdx.x"): - for i_o in tir.thread_binding(0, 4, thread = "threadIdx.x"): - for i_i, k in tir.grid(4, 16): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: - tir.bind(vi, i_o 
* 4 + i_i) - tir.bind(vj, j) - tir.bind(vk, k) - with tir.init(): +@T.prim_func +def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], "float32") + B = T.match_buffer(b, [16], "float32") + B_rf_local = T.alloc_buffer([16, 16], "float32", scope="local") + for j in T.thread_binding(0, 16, thread = "blockIdx.x"): + for i_o in T.thread_binding(0, 4, thread = "threadIdx.x"): + for i_i, k in T.grid(4, 16): + with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]: + T.bind(vi, i_o * 4 + i_i) + T.bind(vj, j) + T.bind(vk, k) + with T.init(): B_rf_local[vi, vj] = 0.0 B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk] - for k in tir.serial(0, 4): - with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]: - tir.bind(vi, j) - tir.bind(vk, i_o * 4 + k) - with tir.init(): + for k in T.serial(0, 4): + with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]: + T.bind(vi, j) + T.bind(vk, i_o * 4 + k) + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + B_rf_local[vk, vi] -@tvm.script.tir -def fail_subtree_compact_dataflow(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") +@T.prim_func +def fail_subtree_compact_dataflow(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") for i in range(0, 128): for j in range(0, 64): - with tir.block([128, 128], "B_0") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + with T.block([128, 128], "B_0") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 for j in range(0, 64): - with tir.block([128, 128], "B_1") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j + 64) + with T.block([128, 128], "B_1") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j + 64) B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") 
as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def fail_all_consumers_under_loop(a: ty.handle, c: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - D = tir.match_buffer(d, (128, 128), "float32") - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def fail_all_consumers_under_loop(a: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + D = T.match_buffer(d, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "C") as [vi, vj]: + for i, j in T.grid(128, 128): + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "D") as [vi, vj]: + for i, j in T.grid(128, 128): + with T.block([128, 128], "D") as [vi, vj]: D[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def fail_all_producers_under_loop(a: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - C = tir.alloc_buffer((128, 128), "float32") - D = tir.match_buffer(d, (128, 128), "float32") - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def fail_all_producers_under_loop(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.alloc_buffer((128, 128), "float32") + D = T.match_buffer(d, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - for i, j in tir.grid(128, 128): - with 
tir.block([128, 128], "C") as [vi, vj]: + for i, j in T.grid(128, 128): + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] + 1.0 - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "D") as [vi, vj]: + for i, j in T.grid(128, 128): + with T.block([128, 128], "D") as [vi, vj]: D[vi, vj] = B[vi, vj] + C[vi, vj] -@tvm.script.tir -def read_out_of_bound(a: ty.handle, c:ty.handle) -> None: - A = tir.match_buffer(a, [16], "float32") - B = tir.alloc_buffer([16], "float32") - C = tir.match_buffer(c, [16], "float32") - for i in tir.serial(0, 16): - with tir.block([16], "B") as [v]: +@T.prim_func +def read_out_of_bound(a: T.handle, c:T.handle) -> None: + A = T.match_buffer(a, [16], "float32") + B = T.alloc_buffer([16], "float32") + C = T.match_buffer(c, [16], "float32") + for i in T.serial(0, 16): + with T.block([16], "B") as [v]: B[v] = A[v] - for j in tir.serial(0, 16): - with tir.block([16], "C") as [v]: - tir.reads(B[v : v + 2]) - C[v] = tir.if_then_else(v < 15, tir.max(B[v], B[v + 1]), B[v], dtype="float32") - - -@tvm.script.tir -def read_out_of_bound_after_compute_at(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [16], "float32") - B = tir.alloc_buffer([16], "float32") - C = tir.match_buffer(c, [16], "float32") - for j in tir.serial(0, 16): - for i in tir.serial(0, tir.min(1, 15 - j) + 1): - with tir.block([16], "B") as [v]: - tir.bind(v, j + i) + for j in T.serial(0, 16): + with T.block([16], "C") as [v]: + T.reads(B[v : v + 2]) + C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") + + +@T.prim_func +def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16], "float32") + B = T.alloc_buffer([16], "float32") + C = T.match_buffer(c, [16], "float32") + for j in T.serial(0, 16): + for i in T.serial(0, T.min(1, 15 - j) + 1): + with T.block([16], "B") as [v]: + T.bind(v, j + i) B[v] = A[v] - with tir.block([16], "C") as [v]: - tir.bind(v, j) - tir.reads([B[v : v + 
2]]) - C[v] = tir.if_then_else(v < 15, tir.max(B[v], B[v + 1]), B[v], dtype="float32") + with T.block([16], "C") as [v]: + T.bind(v, j) + T.reads([B[v : v + 2]]) + C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index ea322920b8466..f9049f6da732b 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -20,203 +20,203 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def elementwise_multi_producer_consumer(a: ty.handle, c: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - D = tir.match_buffer(d, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise_multi_producer_consumer(a: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = 
T.match_buffer(c, (128, 128)) + D = T.match_buffer(d, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 # B has two consumers - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 - with tir.block([128, 128], "D") as [vi, vj]: + with T.block([128, 128], "D") as [vi, vj]: D[vi, vj] = B[vi, vj] + 2.0 + C[vi, vj] # D has two producers -@tvm.script.tir -def elementwise_multi_consumer_inlined(a: ty.handle, c: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - D = tir.match_buffer(d, (128, 128)) - with tir.block([128, 128], "C") as [vi, vj]: +@T.prim_func +def elementwise_multi_consumer_inlined(a: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + D = T.match_buffer(d, (128, 128)) + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 - with tir.block([128, 128], "D") as [vi, vj]: + with T.block([128, 128], "D") as [vi, vj]: D[vi, vj] = A[vi, vj] * 2.0 + 2.0 + C[vi, vj] -@tvm.script.tir -def elementwise_standalone(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise_standalone(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] + 1.0 -@tvm.script.tir -def elementwise_standalone_dce(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "C") as [vi, vj]: +@T.prim_func +def 
elementwise_standalone_dce(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] + 1.0 -@tvm.script.tir -def elementwise_under_loop(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - B = tir.alloc_buffer((128, 128)) - for i in tir.serial(0, 128): - for j in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) +@T.prim_func +def elementwise_under_loop(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + for i in T.serial(0, 128): + for j in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 - for j in tir.serial(0, 128): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + for j in T.serial(0, 128): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def elementwise_inlined(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "C") as [vi, vj]: +@T.prim_func +def elementwise_inlined(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 -@tvm.script.tir -def fail_multi_reader_writer(a: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.alloc_buffer((128, 128)) - D = tir.match_buffer(d, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def fail_multi_reader_writer(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = 
T.alloc_buffer((128, 128)) + D = T.match_buffer(d, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 C[vi, vj] = A[vi, vj] + 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: D[vi, vj] = B[vi, vj] + C[vi, vj] -@tvm.script.tir -def elementwise_multi_reverse_loads(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise_multi_reverse_loads(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0 -@tvm.script.tir -def elementwise_multi_reverse_loads_inlined(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise_multi_reverse_loads_inlined(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: C[vi, vj] = (A[vi, vj] * 2.0 + 1.0) * (A[vi, vj] * 2.0 * 2.0) + 3.0 -@tvm.script.tir -def opaque_access_load(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def opaque_access_load(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: - 
tir.reads(B[0:128, 0:128]) - tir.writes(C[0:128, 0:128]) - C[vi, vj] = tir.load("float32", B.data, vi * 128 + vj) + 1.0 - - -@tvm.script.tir -def opaque_access_store(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: + T.reads(B[0:128, 0:128]) + T.writes(C[0:128, 0:128]) + C[vi, vj] = T.load("float32", B.data, vi * 128 + vj) + 1.0 + + +@T.prim_func +def opaque_access_store(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: - tir.reads(B[0:128, 0:128]) - tir.writes(C[0:128, 0:128]) - tir.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) - C[vi, vj] = tir.load("float32", B.data, vi * 16 + vj) + 1.0 - - -@tvm.script.tir -def buffer_matched(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: + T.reads(B[0:128, 0:128]) + T.writes(C[0:128, 0:128]) + T.store(C.data, vi * 128 + vj, B[vi, vj] + 1.0) + C[vi, vj] = T.load("float32", B.data, vi * 16 + vj) + 1.0 + + +@T.prim_func +def buffer_matched(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: - Bb = tir.match_buffer(B[vi : vi + 1, vj], (1, 1)) + with T.block([128, 128], "C") as [vi, vj]: + Bb = T.match_buffer(B[vi : vi + 1, vj], (1, 1)) C[vi, vj] = Bb[0, 0] + 1.0 -@tvm.script.tir -def elementwise_predicate(a: ty.handle, c: ty.handle) -> None: - 
A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise_predicate(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "C") as [vi, vj]: - tir.where(B[i, j] < 10.0) + for i, j in T.grid(128, 128): + with T.block([128, 128], "C") as [vi, vj]: + T.where(B[i, j] < 10.0) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def elementwise_predicate_inlined(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "C") as [vi, vj]: - tir.where(A[i, j] * 2.0 < 10.0) +@T.prim_func +def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "C") as [vi, vj]: + T.where(A[i, j] * 2.0 < 10.0) C[vi, vj] = A[vi, vj] * 2.0 + 1.0 -@tvm.script.tir -def elementwise_multi_loads(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise_multi_loads(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 126], "C") as [vi, vj]: + with T.block([128, 126], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + B[vi, vj + 1] + B[vi, vj + 2] -@tvm.script.tir -def elementwise_multi_loads_inlined(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 
128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 126], "C") as [vi, vj]: +@T.prim_func +def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 126], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 diff --git a/tests/python/unittest/test_tir_schedule_error.py b/tests/python/unittest/test_tir_schedule_error.py index 6fcd0dc2aedca..7a9c8e01d3554 100644 --- a/tests/python/unittest/test_tir_schedule_error.py +++ b/tests/python/unittest/test_tir_schedule_error.py @@ -20,21 +20,21 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = tir.float32(0) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j in T.grid(128, 128): + with T.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = T.float32(0) for k in range(0, 128): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py index 5649a06bd3b8d..60269ac01c14d 100644 --- a/tests/python/unittest/test_tir_schedule_for_kind.py +++ b/tests/python/unittest/test_tir_schedule_for_kind.py @@ -21,221 +21,221 @@ import tvm import tvm.testing from tvm import tir -from 
tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def element_wise(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) +@T.prim_func +def element_wise(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def element_wise_parallelized(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i0 in tir.parallel(0, 128): - for i1 in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i0) - tir.bind(vj, i1) +@T.prim_func +def element_wise_parallelized(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i0 in T.parallel(0, 128): + for i1 in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i0) + T.bind(vj, i1) B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def element_wise_i_bound(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - for i0 in tir.thread_binding(0, 128, thread="threadIdx.x"): - for i1 in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i0) - tir.bind(vj, i1) +@T.prim_func +def element_wise_i_bound(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i0 in T.thread_binding(0, 128, thread="threadIdx.x"): + for i1 in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i0) + T.bind(vj, i1) B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def element_wise_compute_at_split(a: ty.handle, c: ty.handle) -> None: - A = 
tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - B = tir.alloc_buffer((128, 128)) - for i in tir.serial(0, 128): - for j0 in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j0) +@T.prim_func +def element_wise_compute_at_split(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + for i in T.serial(0, 128): + for j0 in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j0) B[vi, vj] = A[vi, vj] * 2.0 - for j1o, j1i in tir.grid(32, 4): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j1o * 4 + j1i) + for j1o, j1i in T.grid(32, 4): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def element_wise_compute_at_split_vectorized(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - B = tir.alloc_buffer((128, 128)) - for i in tir.serial(0, 128): - for j0 in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j0) +@T.prim_func +def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + for i in T.serial(0, 128): + for j0 in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j0) B[vi, vj] = A[vi, vj] * 2.0 - for j1o in tir.serial(0, 32): - for j1i in tir.vectorized(0, 4): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j1o * 4 + j1i) + for j1o in T.serial(0, 32): + for j1i in T.vectorized(0, 4): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def 
element_wise_split_predicate(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - for i, j_0, j_1 in tir.grid(128, 13, 10): - with tir.block([128, 128], "B") as [vi, vj]: - tir.where(j_0 * 10 + j_1 < 128) - tir.bind(vi, i) - tir.bind(vj, j_0 * 10 + j_1) +@T.prim_func +def element_wise_split_predicate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + for i, j_0, j_1 in T.grid(128, 13, 10): + with T.block([128, 128], "B") as [vi, vj]: + T.where(j_0 * 10 + j_1 < 128) + T.bind(vi, i) + T.bind(vj, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def element_wise_split_predicate_parallelized(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - for i in tir.serial(0, 128): - for j_0 in tir.parallel(0, 13): - for j_1 in tir.serial(0, 10): - with tir.block([128, 128], "B") as [vi, vj]: - tir.where(j_0 * 10 + j_1 < 128) - tir.bind(vi, i) - tir.bind(vj, j_0 * 10 + j_1) +@T.prim_func +def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + for i in T.serial(0, 128): + for j_0 in T.parallel(0, 13): + for j_1 in T.serial(0, 10): + with T.block([128, 128], "B") as [vi, vj]: + T.where(j_0 * 10 + j_1 < 128) + T.bind(vi, i) + T.bind(vj, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def element_wise_split_predicate_vectorized(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - for i in tir.vectorized(0, 128): - for j_0, j_1 in tir.grid(13, 10): - with tir.block([128, 128], "B") as [vi, vj]: - tir.where(j_0 * 10 + j_1 < 128) - tir.bind(vi, i) - tir.bind(vj, j_0 * 10 + j_1) +@T.prim_func +def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = 
T.match_buffer(b, [128, 128]) + for i in T.vectorized(0, 128): + for j_0, j_1 in T.grid(13, 10): + with T.block([128, 128], "B") as [vi, vj]: + T.where(j_0 * 10 + j_1 < 128) + T.bind(vi, i) + T.bind(vj, j_0 * 10 + j_1) B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def element_wise_compute_at_split_j0_j1o_bound(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - B = tir.alloc_buffer((128, 128)) - for i in tir.serial(0, 128): - for j0 in tir.thread_binding(0, 128, thread="threadIdx.x"): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j0) +@T.prim_func +def element_wise_compute_at_split_j0_j1o_bound(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + for i in T.serial(0, 128): + for j0 in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j0) B[vi, vj] = A[vi, vj] * 2.0 - for j1o in tir.thread_binding(0, 32, thread="threadIdx.x"): - for j1i in tir.serial(0, 4): - with tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j1o * 4 + j1i) + for j1o in T.thread_binding(0, 32, thread="threadIdx.x"): + for j1i in T.serial(0, 4): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j1o * 4 + j1i) C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with tir.init(): + with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 
0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def rowsum(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - with tir.init(): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def rowsum_unrolled(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) - for i0 in tir.unroll(0, 128): - for i1 in tir.serial(0, 128): - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - tir.bind(vi, i0) - tir.bind(vk, i1) - with tir.init(): +@T.prim_func +def rowsum_unrolled(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) + for i0 in T.unroll(0, 128): + for i1 in T.serial(0, 128): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + T.bind(vi, i0) + T.bind(vk, i1) + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - for i, k in tir.grid(128, 16): - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - tir.bind(vi, i) - tir.bind(vk, tir.floordiv(k * k, 2)) - with tir.init(): + for i, k in T.grid(128, 16): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + T.bind(vi, i) + T.bind(vk, T.floordiv(k * k, 2)) + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def rowsum_not_compact_data_flow(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, 
(128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum_not_compact_data_flow(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - with tir.init(): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.init(): B[vk] = 0.0 B[vk] = B[vk] + A[vi, vk] -@tvm.script.tir -def rowsum_cross_thread_reduction(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) - for i0 in tir.serial(0, 128): - for i1 in tir.thread_binding(0, 128, thread="threadIdx.x"): - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - tir.bind(vi, i0) - tir.bind(vk, i1) - with tir.init(): +@T.prim_func +def rowsum_cross_thread_reduction(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) + for i0 in T.serial(0, 128): + for i1 in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + T.bind(vi, i0) + T.bind(vk, i1) + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def opaque_block(a: ty.handle) -> None: - A = tir.match_buffer(a, (16,)) - for i in tir.serial(0, 15): - with tir.block([], "opaque"): +@T.prim_func +def opaque_block(a: T.handle) -> None: + A = T.match_buffer(a, (16,)) + for i in T.serial(0, 15): + with T.block([], "opaque"): A[i + 1] = A[i + 1] + A[i] diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index bc054938d282f..d79338ace7266 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -21,434 +21,438 @@ import tvm import tvm.testing from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: 
disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg -@tvm.script.tir -def transformed_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - - for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - tir.bind(vi, i0) - tir.bind(vj, i1) - tir.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner)) - tir.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) - tir.writes([C[vi, vj]]) - with tir.init(): +@T.prim_func +def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + T.bind(vi, i0) + T.bind(vj, i1) + T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner)) + T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) + T.writes([C[vi, vj]]) + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) -@tvm.script.tir -def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - C_rf = tir.alloc_buffer([4, 128, 128]) - - for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4): - with tir.block( - [4, 128, 128, tir.reduce_axis(0, 4), tir.reduce_axis(0, 8)], "update_rf" - ) as [vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer]: - tir.bind(vi2_inner_inner, i2_inner_inner) - tir.bind(vi, i0) - tir.bind(vj, i1) - tir.bind(vi2_outer, i2_outer) - tir.bind(vi2_inner_outer, i2_inner_outer) - with tir.init(): +@T.prim_func +def matmul_rfactor(a: T.handle, b: T.handle, c: 
T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + C_rf = T.alloc_buffer([4, 128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): + with T.block([4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf") as [ + vi2_inner_inner, + vi, + vj, + vi2_outer, + vi2_inner_outer, + ]: + T.bind(vi2_inner_inner, i2_inner_inner) + T.bind(vi, i0) + T.bind(vj, i1) + T.bind(vi2_outer, i2_outer) + T.bind(vi2_inner_outer, i2_inner_outer) + with T.init(): C_rf[vi2_inner_inner, vi, vj] = 0.0 C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] ) - for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4): - with tir.block([tir.reduce_axis(0, 4), 128, 128], "update") as [ + for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4): + with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [ vi2_inner_inner_1, vi_1, vj_1, ]: - tir.bind(vi2_inner_inner_1, i2_inner_inner_1) - tir.bind(vi_1, i0_1) - tir.bind(vj_1, i1_1) - with tir.init(): + T.bind(vi2_inner_inner_1, i2_inner_inner_1) + T.bind(vi_1, i0_1) + T.bind(vj_1, i1_1) + with T.init(): C[vi_1, vj_1] = 0.0 C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] -@tvm.script.tir -def matmul_not_stage_pipeline(a: ty.handle, b: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, [256, 256]) - B = tir.match_buffer(b, [256, 256]) - D = tir.match_buffer(d, [256, 256]) - C = tir.alloc_buffer([256, 256]) +@T.prim_func +def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [256, 256]) + B = T.match_buffer(b, [256, 256]) + D = T.match_buffer(d, [256, 256]) + C = T.alloc_buffer([256, 256]) - with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with tir.init(): + with 
T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - with tir.block([256, 256], "D") as [vi, vj]: + with T.block([256, 256], "D") as [vi, vj]: D[vi, vj] = C[vi, vj] -@tvm.script.tir -def matmul_not_same_buffer_access(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) +@T.prim_func +def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) - with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with tir.init(): + with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] -@tvm.script.tir -def matmul_loop_multiple_children(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - D = tir.match_buffer(d, [128, 128]) +@T.prim_func +def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + D = T.match_buffer(d, [128, 128]) - for k, i, j in tir.grid(128, 128, 128): - with tir.block([tir.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: - tir.bind(ck, k) - tir.bind(ci, i) - tir.bind(cj, j) - with tir.init(): + for k, i, j in T.grid(128, 128, 128): + with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: + T.bind(ck, k) + T.bind(ci, i) + T.bind(cj, j) + with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj] - with tir.block([tir.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]: - tir.bind(dk, k) - tir.bind(di, i) - 
tir.bind(dj, j) - with tir.init(): + with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]: + T.bind(dk, k) + T.bind(di, i) + T.bind(dj, j) + with T.init(): D[di, dj] = 0.0 D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj] -@tvm.script.tir -def square_sum(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [16, 256, 256]) - C = tir.match_buffer(c, [16]) +@T.prim_func +def square_sum(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 256, 256]) + C = T.match_buffer(c, [16]) - with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]: - with tir.init(): + with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]: + with T.init(): C[b] = 0.0 C[b] = C[b] + A[b, i, j] * A[b, i, j] -@tvm.script.tir -def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [16, 256, 256]) - C = tir.match_buffer(c, [16]) - C_rf = tir.alloc_buffer([16, 256]) +@T.prim_func +def square_sum_rfactor(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 256, 256]) + C = T.match_buffer(c, [16]) + C_rf = T.alloc_buffer([16, 256]) - for i0, i1, i2 in tir.grid(16, 256, 256): - with tir.block([256, 16, tir.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]: - tir.bind(vi2, i2) - tir.bind(b, i0) - tir.bind(i, i1) - with tir.init(): + for i0, i1, i2 in T.grid(16, 256, 256): + with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]: + T.bind(vi2, i2) + T.bind(b, i0) + T.bind(i, i1) + with T.init(): C_rf[b, vi2] = 0.0 C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) - for i0_1, i2_1 in tir.grid(16, 256): - with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]: - tir.bind(vi2_1, i2_1) - tir.bind(b_1, i0_1) - with tir.init(): + for i0_1, i2_1 in T.grid(16, 256): + with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]: + T.bind(vi2_1, i2_1) + T.bind(b_1, i0_1) + with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[b_1, vi2_1] 
-@tvm.script.tir -def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, [16, 256, 256]) - D = tir.match_buffer(d, [16]) - C = tir.alloc_buffer([16]) - - for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1): - with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]: - tir.bind(b, i0) - tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256)) - tir.bind(j, tir.floormod(i1_i2_fused_outer, 256)) - tir.reads([C[b], A[b, i, j]]) - tir.writes([C[b]]) - with tir.init(): +@T.prim_func +def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [16, 256, 256]) + D = T.match_buffer(d, [16]) + C = T.alloc_buffer([16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): + with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]: + T.bind(b, i0) + T.bind(i, T.floordiv(i1_i2_fused_outer, 256)) + T.bind(j, T.floormod(i1_i2_fused_outer, 256)) + T.reads([C[b], A[b, i, j]]) + T.writes([C[b]]) + with T.init(): C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) - for i0_1 in tir.serial(0, 16): - with tir.block([16], "D") as [b_1]: - tir.bind(b_1, i0_1) - tir.reads([C[b_1]]) - tir.writes([D[b_1]]) - D[b_1] = tir.sqrt(C[b_1], dtype="float32") - - -@tvm.script.tir -def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, [16, 256, 256]) - D = tir.match_buffer(d, [16]) - C = tir.alloc_buffer([16]) - C_rf = tir.alloc_buffer([1, 16]) - - for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1): - with tir.block([1, 16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C_rf") as [ + for i0_1 in T.serial(0, 16): + with T.block([16], "D") as [b_1]: + T.bind(b_1, i0_1) + T.reads([C[b_1]]) + T.writes([D[b_1]]) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + + +@T.prim_func +def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, 
[16, 256, 256]) + D = T.match_buffer(d, [16]) + C = T.alloc_buffer([16]) + C_rf = T.alloc_buffer([1, 16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): + with T.block([1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf") as [ vi1_i2_fused_inner, b, i, j, ]: - tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner) - tir.bind(b, i0) - tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256)) - tir.bind(j, tir.floormod(i1_i2_fused_outer, 256)) - with tir.init(): + T.bind(vi1_i2_fused_inner, i1_i2_fused_inner) + T.bind(b, i0) + T.bind(i, T.floordiv(i1_i2_fused_outer, 256)) + T.bind(j, T.floormod(i1_i2_fused_outer, 256)) + with T.init(): C_rf[vi1_i2_fused_inner, b] = 0.0 C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) - for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1): - with tir.block([tir.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]: - tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1) - tir.bind(b_1, i0_1) - with tir.init(): + for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1): + with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]: + T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1) + T.bind(b_1, i0_1) + with T.init(): C[b_1] = 0.0 C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] - for i0_2 in tir.serial(0, 16): - with tir.block([16], "D") as [b_2]: - tir.bind(b_2, i0_2) - D[b_2] = tir.sqrt(C[b_2], dtype="float32") + for i0_2 in T.serial(0, 16): + with T.block([16], "D") as [b_2]: + T.bind(b_2, i0_2) + D[b_2] = T.sqrt(C[b_2], dtype="float32") -@tvm.script.tir -def element_wise(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) +@T.prim_func +def element_wise(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def 
rowsum(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - with tir.init(): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - for i, k in tir.grid(128, 16): - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - tir.bind(vi, i) - tir.bind(vk, tir.floordiv(k * k, 2)) - with tir.init(): + for i, k in T.grid(128, 16): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + T.bind(vi, i) + T.bind(vk, T.floordiv(k * k, 2)) + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def rowsum_not_dominant(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) +@T.prim_func +def rowsum_not_dominant(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - with tir.init(): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.init(): B[vi, vk] = 0.0 B[vi, vk] = B[vi, vk] + A[vi, vk] -@tvm.script.tir -def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum_not_serial(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - for i in tir.serial(0, 128): - for k in tir.parallel(0, 128): - 
with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - tir.bind(vi, i) - tir.bind(vk, k) - with tir.init(): + for i in T.serial(0, 128): + for k in T.parallel(0, 128): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + T.bind(vi, i) + T.bind(vk, k) + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def rowsum_wrong_reduce_pattern1(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - with tir.init(): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.init(): B[vi] = 1.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def rowsum_wrong_reduce_pattern2(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - with tir.init(): + with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]: + with T.init(): B[vi] = 0.0 B[vi] = B[vi] - A[vi, vk] -@tvm.script.tir -def rowsum_transformed(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def rowsum_transformed(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) - for io, ii_ko_fused, ki in tir.grid(32, 128, 4): - with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: - tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32)) - tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki) - with tir.init(): + for io, ii_ko_fused, ki in T.grid(32, 128, 4): + with T.block([128, T.reduce_axis(0, 128)], "B") as 
[vi, vk]: + T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32)) + T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki) + with T.init(): B[vi] = 0.0 B[vi] = B[vi] + A[vi, vk] -@tvm.script.tir -def rowsum_zero_dim(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [128]) - B = tir.match_buffer(b, []) +@T.prim_func +def rowsum_zero_dim(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128]) + B = T.match_buffer(b, []) - with tir.block([tir.reduce_axis(0, 128)], "B") as [k]: - with tir.init(): + with T.block([T.reduce_axis(0, 128)], "B") as [k]: + with T.init(): B[()] = 0.0 B[()] = B[()] + A[k] -@tvm.script.tir -def rowsum_zero_dim_rfactor(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [128]) - B = tir.match_buffer(b, []) - B_rf = tir.alloc_buffer([128]) +@T.prim_func +def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128]) + B = T.match_buffer(b, []) + B_rf = T.alloc_buffer([128]) - with tir.block([128], "B_rf") as [vi0]: - with tir.init(): + with T.block([128], "B_rf") as [vi0]: + with T.init(): B_rf[vi0] = 0.0 B_rf[vi0] = B_rf[vi0] + A[vi0] - with tir.block([tir.reduce_axis(0, 128)], "B") as [vi0_1]: - with tir.init(): + with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]: + with T.init(): B[()] = 0.0 B[()] = B[()] + B_rf[vi0_1] -@tvm.script.tir -def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16, 16)) - C = tir.alloc_buffer((16, 16)) - D = tir.alloc_buffer((16, 16)) - E = tir.alloc_buffer((16, 16)) - F = tir.match_buffer(f, (16, 16)) - - for i in tir.serial(0, 16): - for j1 in tir.serial(0, 16): - for k1o, k1i in tir.grid(4, 4): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "C") as [ci, cj, ck]: - tir.bind(ci, i) - tir.bind(cj, j1) - tir.bind(ck, k1o * 4 + k1i) - with tir.init(): +@T.prim_func +def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None: + A = T.match_buffer(a, (16, 16, 16)) + C = T.alloc_buffer((16, 16)) + D = 
T.alloc_buffer((16, 16)) + E = T.alloc_buffer((16, 16)) + F = T.match_buffer(f, (16, 16)) + + for i in T.serial(0, 16): + for j1 in T.serial(0, 16): + for k1o, k1i in T.grid(4, 4): + with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]: + T.bind(ci, i) + T.bind(cj, j1) + T.bind(ck, k1o * 4 + k1i) + with T.init(): C[ci, cj] = 0.0 C[ci, cj] = C[ci, cj] + A[ci, cj, ck] - for k2o, k2i in tir.grid(4, 4): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]: - tir.bind(di, i) - tir.bind(dj, j1) - tir.bind(dk, k2o * 4 + k2i) - with tir.init(): + for k2o, k2i in T.grid(4, 4): + with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]: + T.bind(di, i) + T.bind(dj, j1) + T.bind(dk, k2o * 4 + k2i) + with T.init(): D[di, dj] = 0.0 D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] - for j2 in tir.serial(0, 16): - for k3o, k3i in tir.grid(4, 4): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]: - tir.bind(ei, i) - tir.bind(ej, j2) - tir.bind(ek, k3o * 4 + k3i) - with tir.init(): + for j2 in T.serial(0, 16): + for k3o, k3i in T.grid(4, 4): + with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]: + T.bind(ei, i) + T.bind(ej, j2) + T.bind(ek, k3o * 4 + k3i) + with T.init(): E[ei, ej] = 0.0 E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] - for k4o, k4i in tir.grid(4, 4): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]: - tir.bind(fi, i) - tir.bind(fj, j2) - tir.bind(fk, k4o * 4 + k4i) - with tir.init(): + for k4o, k4i in T.grid(4, 4): + with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]: + T.bind(fi, i) + T.bind(fj, j2) + T.bind(fk, k4o * 4 + k4i) + with T.init(): F[fi, fj] = 0.0 F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] -@tvm.script.tir -def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16, 16]) - C = tir.alloc_buffer([16, 16]) - D = tir.alloc_buffer([16, 16]) - E = tir.alloc_buffer([16, 16]) - F = 
tir.match_buffer(f, [16, 16]) - C_rf = tir.alloc_buffer([16, 16, 4]) - - for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4): - with tir.block([4, 16, 16, tir.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]: - tir.bind(vk1o, k1o) - tir.bind(ci, i) - tir.bind(cj, j1) - tir.bind(vk1i, k1i) - with tir.init(): +@T.prim_func +def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16]) + C = T.alloc_buffer([16, 16]) + D = T.alloc_buffer([16, 16]) + E = T.alloc_buffer([16, 16]) + F = T.match_buffer(f, [16, 16]) + C_rf = T.alloc_buffer([16, 16, 4]) + + for i, j1, k1o, k1i in T.grid(16, 16, 4, 4): + with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]: + T.bind(vk1o, k1o) + T.bind(ci, i) + T.bind(cj, j1) + T.bind(vk1i, k1i) + with T.init(): C_rf[ci, cj, vk1o] = 0.0 C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)] - for i_1 in tir.serial(0, 16): - for j1_1 in tir.serial(0, 16): - for k1o_1 in tir.serial(0, 4): - with tir.block([tir.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]: - tir.bind(vk1o_1, k1o_1) - tir.bind(ci_1, i_1) - tir.bind(cj_1, j1_1) - with tir.init(): + for i_1 in T.serial(0, 16): + for j1_1 in T.serial(0, 16): + for k1o_1 in T.serial(0, 4): + with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]: + T.bind(vk1o_1, k1o_1) + T.bind(ci_1, i_1) + T.bind(cj_1, j1_1) + with T.init(): C[ci_1, cj_1] = 0.0 C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] - for k2o, k2i in tir.grid(4, 4): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]: - tir.bind(di, i_1) - tir.bind(dj, j1_1) - tir.bind(dk, (k2o * 4) + k2i) - with tir.init(): + for k2o, k2i in T.grid(4, 4): + with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]: + T.bind(di, i_1) + T.bind(dj, j1_1) + T.bind(dk, (k2o * 4) + k2i) + with T.init(): D[di, dj] = 0.0 D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] - for j2 in tir.serial(0, 16): - for 
k3o, k3i in tir.grid(4, 4): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]: - tir.bind(ei, i_1) - tir.bind(ej, j2) - tir.bind(ek, (k3o * 4) + k3i) - with tir.init(): + for j2 in T.serial(0, 16): + for k3o, k3i in T.grid(4, 4): + with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]: + T.bind(ei, i_1) + T.bind(ej, j2) + T.bind(ek, (k3o * 4) + k3i) + with T.init(): E[ei, ej] = 0.0 E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] - for k4o, k4i in tir.grid(4, 4): - with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]: - tir.bind(fi, i_1) - tir.bind(fj, j2) - tir.bind(fk, (k4o * 4) + k4i) - with tir.init(): + for k4o, k4i in T.grid(4, 4): + with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]: + T.bind(fi, i_1) + T.bind(fj, j2) + T.bind(fk, (k4o * 4) + k4i) + with T.init(): F[fi, fj] = 0.0 F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py index 091a77df20302..a60ab8dca9725 100644 --- a/tests/python/unittest/test_tir_schedule_reorder.py +++ b/tests/python/unittest/test_tir_schedule_reorder.py @@ -20,175 +20,175 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def elementwise(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128, 128)) - with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: +@T.prim_func +def elementwise(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 -@tvm.script.tir -def elementwise_not_affine(a: ty.handle, b: 
ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128, 128)) - for i, j, k, l in tir.grid(128, 128, 128, 8): - with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) - tir.bind(vl, l * 16) +@T.prim_func +def elementwise_not_affine(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 8): + with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.bind(vl, l * 16) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 -@tvm.script.tir -def elementwise_dependent_loop(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128, 128)) - for i in tir.serial(0, 128): - for j, k, l in tir.grid(128, i, 128): - with tir.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]: +@T.prim_func +def elementwise_dependent_loop(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i in T.serial(0, 128): + for j, k, l in T.grid(128, i, 128): + with T.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]: B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 -@tvm.script.tir -def elementwise_predicate(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128, 128)) - for i, j, k, l in tir.grid(128, 128, 128, 128): - with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100) +@T.prim_func +def elementwise_predicate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 128): + with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + T.where(i * 
2097152 + j * 16384 + k * 128 + l < 100) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 -@tvm.script.tir -def elementwise_non_single_branch(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - C = tir.alloc_buffer((128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for i, j in tir.grid(128, 128): - for k in tir.serial(0, 128): - with tir.block([128, 128, 128], "C") as [vi, vj, vk]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) +@T.prim_func +def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + C = T.alloc_buffer((128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block([128, 128, 128], "C") as [vi, vj, vk]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) C[vi, vj, vk] = A[vi, vj, vk] * 2.0 - for k in tir.serial(0, 128): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) + for k in T.serial(0, 128): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) B[vi, vj, vk] = C[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_with_loops_not_same_scope(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "A") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) - for k in tir.serial(0, 128): - with tir.block([128], "B") as [vk]: - tir.bind(vk, k) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_with_loops_not_same_scope(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + with T.block([128, 128], "A") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) + for k in T.serial(0, 128): + with T.block([128], "B") as [vk]: + T.bind(vk, k) + 
T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_with_wrong_block_var_type(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for i, j, k in tir.grid(128, 128, 128): - with tir.block([128, 128, tir.scan_axis(0, 128)], "B") as [vi, vj, vk]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_with_wrong_block_var_type(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j, k in T.grid(128, 128, 128): + with T.block([128, 128, T.scan_axis(0, 128)], "B") as [vi, vj, vk]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_reordered(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128, 128)) - for l, j, k, i in tir.grid(128, 128, 128, 128): - with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) - tir.bind(vl, l) +@T.prim_func +def elementwise_reordered(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for l, j, k, i in T.grid(128, 128, 128, 128): + with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.bind(vl, l) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 -@tvm.script.tir -def elementwise_reordered2(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128, 128)) - for k, j, i, l in tir.grid(128, 128, 128, 128): - with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - tir.bind(vi, i) - 
tir.bind(vj, j) - tir.bind(vk, k) - tir.bind(vl, l) +@T.prim_func +def elementwise_reordered2(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for k, j, i, l in T.grid(128, 128, 128, 128): + with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.bind(vl, l) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 -@tvm.script.tir -def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128, 128)) - for l, j, k, i in tir.grid(128, 128, 128, 128): - with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: - tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100) - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) - tir.bind(vl, l) +@T.prim_func +def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for l, j, k, i in T.grid(128, 128, 128, 128): + with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]: + T.where(i * 2097152 + j * 16384 + k * 128 + l < 100) + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.bind(vl, l) B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 -@tvm.script.tir -def opaque_access(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16], "float32") - B = tir.match_buffer(b, [16, 16], "float32") - with tir.block([16, 16], "A") as [vi, vj]: - tir.reads([]) - tir.writes([A[0:16, 0:16]]) - tir.store(A.data, vi * 16 + vj, 1) - with tir.block([16, 16], "B") as [vi, vj]: - tir.reads([]) - tir.writes([B[0:16, 0:16]]) - tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) - - -@tvm.script.tir -def opaque_access_reorder(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16], "float32") - B = tir.match_buffer(b, [16, 16], "float32") - for j, i 
in tir.grid(16, 16): - with tir.block([16, 16], "A") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.reads([]) - tir.writes([A[0:16, 0:16]]) - tir.store(A.data, vi * 16 + vj, 1) - for j, i in tir.grid(16, 16): - with tir.block([16, 16], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.reads([]) - tir.writes([B[0:16, 0:16]]) - tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) +@T.prim_func +def opaque_access(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16], "float32") + B = T.match_buffer(b, [16, 16], "float32") + with T.block([16, 16], "A") as [vi, vj]: + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + vj, 1) + with T.block([16, 16], "B") as [vi, vj]: + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@T.prim_func +def opaque_access_reorder(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16], "float32") + B = T.match_buffer(b, [16, 16], "float32") + for j, i in T.grid(16, 16): + with T.block([16, 16], "A") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + vj, 1) + for j, i in T.grid(16, 16): + with T.block([16, 16], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 2bfd68663c99e..c93c7ca63aa88 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -20,7 +20,7 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip from 
tvm.tir.schedule import Trace @@ -28,11 +28,11 @@ # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def elementwise(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: +@T.prim_func +def elementwise(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + with T.block([128, 128, 128], "B") as [vi, vj, vk]: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index d11e7f877ccca..79cdd5b6549e1 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -20,313 +20,309 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def elementwise(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: +@T.prim_func +def elementwise(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + with T.block([128, 128, 128], "B") as [vi, vj, vk]: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_dependent_loops(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for i in tir.serial(0, 128): - for j, k in tir.grid(i, 128): - with tir.block([128, i, 128], "B") as [vi, vj, vk]: +@T.prim_func +def elementwise_dependent_loops(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i 
in T.serial(0, 128): + for j, k in T.grid(i, 128): + with T.block([128, i, 128], "B") as [vi, vj, vk]: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None: - A = tir.match_buffer(a, (128, 128, n)) - B = tir.match_buffer(b, (128, 128, n)) - for i, j, k in tir.grid(128, 128, n): - with tir.block([128, 128, n], "B") as [vi, vj, vk]: +@T.prim_func +def elementwise_symbolic(a: T.handle, b: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (128, 128, n)) + B = T.match_buffer(b, (128, 128, n)) + for i, j, k in T.grid(128, 128, n): + with T.block([128, 128, n], "B") as [vi, vj, vk]: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None: - A = tir.match_buffer(a, (128, 128, n)) - B = tir.match_buffer(b, (128, 128, n)) - for i_j_k_fused in tir.serial(0, (n * 16384)): - with tir.block([128, 128, n], "B") as [vi, vj, vk]: - tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128))) - tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128)) - tir.bind(vk, tir.floormod(i_j_k_fused, n)) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (128, 128, n)) + B = T.match_buffer(b, (128, 128, n)) + for i_j_k_fused in T.serial(0, (n * 16384)): + with T.block([128, 128, n], "B") as [vi, vj, vk]: + T.bind(vi, T.floordiv(i_j_k_fused, (n * 128))) + T.bind(vj, T.floormod(T.floordiv(i_j_k_fused, n), 128)) + T.bind(vk, T.floormod(i_j_k_fused, n)) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None: - A = tir.match_buffer(a, (128, 128, n)) - B = tir.match_buffer(b, (128, 128, n)) - for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)): - with tir.block([128, 128, n], 
"B") as [vi, vj, vk]: - tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n)) - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1)) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_symbolic_split(a: T.handle, b: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (128, 128, n)) + B = T.match_buffer(b, (128, 128, n)) + for i, j, k0, k1 in T.grid(128, 128, 10, T.floordiv((n + 9), 10)): + with T.block([128, 128, n], "B") as [vi, vj, vk]: + T.where((((k0 * T.floordiv((n + 9), 10)) + k1) < n)) + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, ((k0 * T.floordiv((n + 9), 10)) + k1)) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - C = tir.alloc_buffer((128, 128, 128)) - for i, j in tir.grid(128, 128): - for k in tir.serial(0, 128): - with tir.block([128, 128, 128], "C") as [vi, vj, vk]: +@T.prim_func +def elementwise_with_seq(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + C = T.alloc_buffer((128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128): + with T.block([128, 128, 128], "C") as [vi, vj, vk]: C[vi, vj, vk] = A[vi, vj, vk] * 2.0 - for k in tir.serial(0, 128): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + for k in T.serial(0, 128): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: B[vi, vj, vk] = C[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for i, j in tir.grid(128, 128): - for k in tir.serial(0, 128, annotations={"useless_annotation": True}): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, i) - tir.bind(vj, j) 
- tir.bind(vk, k) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_with_anno(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.serial(0, 128, annotations={"useless_annotation": True}): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for i, j in tir.grid(128, 128): - for k in tir.thread_binding(0, 128, thread="threadIdx.x"): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_with_thread_binding(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j in T.grid(128, 128): + for k in T.thread_binding(0, 128, thread="threadIdx.x"): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for i, j in tir.grid(128, 128): - for k in tir.serial(10, 128): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_with_starting_point(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 
128)) + for i, j in T.grid(128, 128): + for k in T.serial(10, 128): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_with_opaque_block(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for i, j, k in tir.grid(128, 128, 128): - with tir.block([], "opaque"): - tir.reads([A[i, j, k]]) - tir.writes([B[i, j, k]]) - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, i) - tir.bind(vj, j) - tir.bind(vk, k) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_with_opaque_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for i, j, k in T.grid(128, 128, 128): + with T.block([], "opaque"): + T.reads([A[i, j, k]]) + T.writes([B[i, j, k]]) + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, i) + T.bind(vj, j) + T.bind(vk, k) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_fused(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128, 128)) - B = tir.match_buffer(b, (128, 128, 128)) - for fused in tir.serial(0, 2097152): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, tir.floordiv(fused, 16384)) - tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128)) - tir.bind(vk, tir.floormod(fused, 128)) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_fused(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128)) + for fused in T.serial(0, 2097152): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, T.floordiv(fused, 16384)) + T.bind(vj, T.floormod(T.floordiv(fused, 
128), 128)) + T.bind(vk, T.floormod(fused, 128)) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128, 128]) - B = tir.match_buffer(b, [128, 128, 128]) - for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, ((i1 * 64) + i3)) - tir.bind(vj, ((j1 * 32) + j2)) - tir.bind(vk, ((k1 * 8) + k2)) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_split_case0(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128]) + B = T.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, k1, k2 in T.grid(2, 1, 64, 4, 32, 16, 8): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, ((i1 * 64) + i3)) + T.bind(vj, ((j1 * 32) + j2)) + T.bind(vk, ((k1 * 8) + k2)) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128, 128]) - B = tir.match_buffer(b, [128, 128, 128]) - for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, i1 * 64 + i3) - tir.bind(vj, j1 * 64 + j3) - tir.bind(vk, k1 * 64 + k3) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_split_case1(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128, 128]) + B = T.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, j3, k1, k2, k3 in T.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, i1 * 64 + i3) + T.bind(vj, j1 * 64 + j3) + T.bind(vk, k1 * 64 + k3) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir 
-def elementwise_split_with_predicate(a: ty.handle, b: ty.handle) -> None: - B = tir.match_buffer(b, [128, 128, 128]) - A = tir.match_buffer(a, [128, 128, 128]) - for i0, i1, i2, j0, j1, k0, k1 in tir.grid(1000, 2, 3, 1, 129, 3, 43): - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.where( +@T.prim_func +def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: + B = T.match_buffer(b, [128, 128, 128]) + A = T.match_buffer(a, [128, 128, 128]) + for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.where( ( ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) and (((k0 * 43) + k1) < 128) ) ) - tir.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) - tir.bind(vj, j1) - tir.bind(vk, ((k0 * 43) + k1)) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) + T.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) + T.bind(vj, j1) + T.bind(vk, ((k0 * 43) + k1)) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_fuse_with_opaque_block(a: ty.handle, b: ty.handle) -> None: - B = tir.match_buffer(b, [128, 128, 128]) - A = tir.match_buffer(a, [128, 128, 128]) - for i_j_k_fused in tir.serial(0, 2097152): - with tir.block([], "opaque"): - tir.reads( +@T.prim_func +def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: + B = T.match_buffer(b, [128, 128, 128]) + A = T.match_buffer(a, [128, 128, 128]) + for i_j_k_fused in T.serial(0, 2097152): + with T.block([], "opaque"): + T.reads( [ A[ - tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), - tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), - tir.floormod(i_j_k_fused, 128), + T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128), + T.floormod(T.floordiv(i_j_k_fused, 128), 128), + T.floormod(i_j_k_fused, 128), ] ] ) - tir.writes( + T.writes( [ B[ - tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), - 
tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), - tir.floormod(i_j_k_fused, 128), + T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128), + T.floormod(T.floordiv(i_j_k_fused, 128), 128), + T.floormod(i_j_k_fused, 128), ] ] ) - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, tir.floordiv(i_j_k_fused, 16384)) - tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, 128), 128)) - tir.bind(vk, tir.floormod(i_j_k_fused, 128)) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, T.floordiv(i_j_k_fused, 16384)) + T.bind(vj, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) + T.bind(vk, T.floormod(i_j_k_fused, 128)) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def elementwise_split_with_opaque_block(a: ty.handle, b: ty.handle) -> None: - B = tir.match_buffer(b, [128, 128, 128]) - A = tir.match_buffer(a, [128, 128, 128]) - - for i0, i1, j, k in tir.grid(8, 16, 128, 128): - with tir.block([], "opaque"): - tir.reads([A[i0 * 16 + i1, j, k]]) - tir.writes([B[i0 * 16 + i1, j, k]]) - with tir.block([128, 128, 128], "B") as [vi, vj, vk]: - tir.bind(vi, i0 * 16 + i1) - tir.bind(vj, j) - tir.bind(vk, k) - tir.reads([A[vi, vj, vk]]) - tir.writes([B[vi, vj, vk]]) +@T.prim_func +def elementwise_split_with_opaque_block(a: T.handle, b: T.handle) -> None: + B = T.match_buffer(b, [128, 128, 128]) + A = T.match_buffer(a, [128, 128, 128]) + + for i0, i1, j, k in T.grid(8, 16, 128, 128): + with T.block([], "opaque"): + T.reads([A[i0 * 16 + i1, j, k]]) + T.writes([B[i0 * 16 + i1, j, k]]) + with T.block([128, 128, 128], "B") as [vi, vj, vk]: + T.bind(vi, i0 * 16 + i1) + T.bind(vj, j) + T.bind(vk, k) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -@tvm.script.tir -def opaque_access(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16], "float32") - B = tir.match_buffer(b, 
[16, 16], "float32") - with tir.block([16, 16], "A") as [vi, vj]: - tir.reads([]) - tir.writes([A[0:16, 0:16]]) - tir.store(A.data, vi * 16 + vj, 1) - with tir.block([16, 16], "B") as [vi, vj]: - tir.reads([]) - tir.writes([B[0:16, 0:16]]) - tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) - - -@tvm.script.tir -def opaque_access_fused(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16]) - B = tir.match_buffer(b, [16, 16]) - for i_j_fused in tir.serial(0, 256): - with tir.block([16, 16], "A") as [vi, vj]: - tir.bind(vi, tir.floordiv(i_j_fused, 16)) - tir.bind(vj, tir.floormod(i_j_fused, 16)) - tir.reads([]) - tir.writes([A[0:16, 0:16]]) - tir.store(A.data, ((vi * 16) + vj), 1, 1) - for i_j_fused in tir.serial(0, 256): - with tir.block([16, 16], "B") as [vi, vj]: - tir.bind(vi, tir.floordiv(i_j_fused, 16)) - tir.bind(vj, tir.floormod(i_j_fused, 16)) - tir.reads([]) - tir.writes([B[0:16, 0:16]]) - tir.evaluate( - tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") - ) - - -@tvm.script.tir -def opaque_access_split(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16)) - B = tir.match_buffer(b, (16, 16)) - for i, j0, j1 in tir.grid(16, 4, 4): - with tir.block([16, 16], "A") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, ((j0 * 4) + j1)) - tir.reads([]) - tir.writes([A[0:16, 0:16]]) - tir.store(A.data, ((vi * 16) + vj), 1, 1) - for i, j0, j1 in tir.grid(16, 4, 4): - with tir.block([16, 16], "B") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, ((j0 * 4) + j1)) - tir.reads([]) - tir.writes([B[0:16, 0:16]]) - tir.evaluate( - tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") - ) +@T.prim_func +def opaque_access(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16], "float32") + B = T.match_buffer(b, [16, 16], "float32") + with T.block([16, 16], "A") as [vi, vj]: + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, vi * 16 + 
vj, 1) + with T.block([16, 16], "B") as [vi, vj]: + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@T.prim_func +def opaque_access_fused(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [16, 16]) + B = T.match_buffer(b, [16, 16]) + for i_j_fused in T.serial(0, 256): + with T.block([16, 16], "A") as [vi, vj]: + T.bind(vi, T.floordiv(i_j_fused, 16)) + T.bind(vj, T.floormod(i_j_fused, 16)) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, ((vi * 16) + vj), 1, 1) + for i_j_fused in T.serial(0, 256): + with T.block([16, 16], "B") as [vi, vj]: + T.bind(vi, T.floordiv(i_j_fused, 16)) + T.bind(vj, T.floormod(i_j_fused, 16)) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")) + + +@T.prim_func +def opaque_access_split(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16)) + B = T.match_buffer(b, (16, 16)) + for i, j0, j1 in T.grid(16, 4, 4): + with T.block([16, 16], "A") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, ((j0 * 4) + j1)) + T.reads([]) + T.writes([A[0:16, 0:16]]) + T.store(A.data, ((vi * 16) + vj), 1, 1) + for i, j0, j1 in T.grid(16, 4, 4): + with T.block([16, 16], "B") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, ((j0 * 4) + j1)) + T.reads([]) + T.writes([B[0:16, 0:16]]) + T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle")) # pylint: enable=no-member,invalid-name,unused-variable diff --git a/tests/python/unittest/test_tir_schedule_state.py b/tests/python/unittest/test_tir_schedule_state.py index 856d6a5c17ebf..94e1b4a6b3959 100644 --- a/tests/python/unittest/test_tir_schedule_state.py +++ b/tests/python/unittest/test_tir_schedule_state.py @@ -22,54 +22,54 @@ import tvm from tvm import tir from tvm.ir import IRModule -from tvm.script import ty +from tvm.script import tir as T # pylint: disable=no-member,invalid-name,unused-variable 
-@tvm.script.tir -def elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = tir.float32(0) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j in T.grid(128, 128): + with T.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = T.float32(0) for k in range(0, 128): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - B = tir.match_buffer(b, (128, 128), "float32") - with tir.block([128], "B") as vi: - tir.reads([A[0:128, 0:128]]) - tir.writes([B[0:128, 0:128]]) +@T.prim_func +def block_in_opaque_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (128, 128), "float32") + with T.block([128], "B") as vi: + T.reads([A[0:128, 0:128]]) + 
T.writes([B[0:128, 0:128]]) B[vi, 0] = A[vi, 0] if A[vi, 0] == 0.0: - with tir.block([], "C"): - tir.reads([A[0:128, 0:128]]) - tir.writes([B[0:128, 0:128]]) - with tir.block([128], "D") as vj: + with T.block([], "C"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + with T.block([128], "D") as vj: B[vi, vj] = A[vi, vj] * 3.0 else: - with tir.block([], "E"): - tir.reads([A[0:128, 0:128]]) - tir.writes([B[0:128, 0:128]]) - with tir.block([128], "F") as vj: + with T.block([], "E"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + with T.block([128], "F") as vj: B[vi, vj] = A[vi, vj] * 2.0 @@ -77,7 +77,7 @@ def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None: def replace_ir_builder(deep_copy=False, realize=False): - new_func = tvm.script.from_source(tvm.script.asscript(elementwise)) + new_func = tvm.script.from_source(elementwise.script()) s = tir.ScheduleState(new_func, debug_mask="all") target = tvm.tir.Block( iter_vars=[], @@ -103,8 +103,8 @@ def replace_ir_builder(deep_copy=False, realize=False): def replace_ir_builder_module(deep_copy=False, realize=False): - new_func = tvm.script.from_source(tvm.script.asscript(elementwise)) - other_func = tvm.script.from_source(tvm.script.asscript(elementwise)) + new_func = tvm.script.from_source(elementwise.script()) + other_func = tvm.script.from_source(elementwise.script()) mod = IRModule(functions={"main": new_func, "other": other_func}) s = tir.ScheduleState(mod, debug_mask="all") target = tvm.tir.Block( @@ -131,7 +131,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): def replace_ir_builder_with_opaque(): - func = tvm.script.from_source(tvm.script.asscript(block_in_opaque_block)) + func = tvm.script.from_source(block_in_opaque_block.script()) s = tir.ScheduleState(func, debug_mask="all") gc.collect() return s diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py index 
075b6cd689c45..e2b39ce7c2895 100644 --- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py +++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py @@ -20,247 +20,247 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.state import CachedFlags from tvm.tir.stmt_functor import post_order_visit # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg -@tvm.script.tir -def elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "init") as [vi, vj]: +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j in T.grid(128, 128): + with T.block([128, 128], "init") as [vi, vj]: C[vi, vj] = 0.0 for k in range(0, 128): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def block_in_opaque_block(a: ty.handle, b: ty.handle) -> None: 
- A = tir.match_buffer(a, (128, 128), "float32") - B = tir.match_buffer(b, (128, 128), "float32") - with tir.block([128], "B") as vi: - tir.reads([A[0:128, 0:128]]) - tir.writes([B[0:128, 0:128]]) +@T.prim_func +def block_in_opaque_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.match_buffer(b, (128, 128), "float32") + with T.block([128], "B") as vi: + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) B[vi, 0] = A[vi, 0] if A[vi, 0] == 0.0: - with tir.block([], "C"): - tir.reads([A[0:128, 0:128]]) - tir.writes([B[0:128, 0:128]]) - with tir.block([128], "D") as vj: + with T.block([], "C"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + with T.block([128], "D") as vj: B[vi, vj] = A[vi, vj] * 3.0 else: - with tir.block([], "E"): - tir.reads([A[0:128, 0:128]]) - tir.writes([B[0:128, 0:128]]) - with tir.block([128], "F") as vj: + with T.block([], "E"): + T.reads([A[0:128, 0:128]]) + T.writes([B[0:128, 0:128]]) + with T.block([128], "F") as vj: B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def write_after_read(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "C") as [vi, vj]: +@T.prim_func +def write_after_read(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 - with tir.block([128, 128], "B") as [vi, vj]: + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 -@tvm.script.tir -def loop_carried_dependency(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128,)) - B = tir.match_buffer(b, (128,)) - C = tir.match_buffer(c, (128,)) +@T.prim_func +def loop_carried_dependency(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, 
(128,)) + B = T.match_buffer(b, (128,)) + C = T.match_buffer(c, (128,)) for i in range(0, 128): - with tir.block([128], "B") as vi: + with T.block([128], "B") as vi: B[vi] = A[vi] * 2.0 - with tir.block([128], "C") as vi: - C[vi] = tir.if_then_else(vi >= 1, B[vi - 1] + 1.0, 0.0, dtype="float32") + with T.block([128], "C") as vi: + C[vi] = T.if_then_else(vi >= 1, B[vi - 1] + 1.0, 0.0, dtype="float32") -@tvm.script.tir -def concatenate_multi_producer(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128,)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def concatenate_multi_producer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128,)) + B = T.match_buffer(b, (128,)) for i in range(0, 64): - with tir.block([64], "A_0") as vi: + with T.block([64], "A_0") as vi: A[vi] = vi + 1 for i in range(0, 64): - with tir.block([64], "A_1") as vi: - tir.bind(vi, i + 64) + with T.block([64], "A_1") as vi: + T.bind(vi, i + 64) A[vi] = vi + 2 - with tir.block([128], "B") as vi: + with T.block([128], "B") as vi: B[vi] = A[vi] * 2.0 -@tvm.script.tir -def concatenate_multi_producer_uncovered(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128,)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def concatenate_multi_producer_uncovered(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128,)) + B = T.match_buffer(b, (128,)) for i in range(0, 63): - with tir.block([63], "A_0") as vi: + with T.block([63], "A_0") as vi: A[vi] = vi + 1 for i in range(0, 64): - with tir.block([64], "A_1") as vi: - tir.bind(vi, i + 64) + with T.block([64], "A_1") as vi: + T.bind(vi, i + 64) A[vi] = vi + 2 - with tir.block([128], "B") as vi: + with T.block([128], "B") as vi: B[vi] = A[vi] * 2.0 -@tvm.script.tir -def lca_at_loop(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128,)) - B = tir.match_buffer(b, (128,)) - C = tir.match_buffer(c, (128,)) +@T.prim_func +def lca_at_loop(a: T.handle, b: T.handle, c: T.handle) -> None: + A 
= T.match_buffer(a, (128,)) + B = T.match_buffer(b, (128,)) + C = T.match_buffer(c, (128,)) for i in range(0, 128): - with tir.block([128], "B") as vi: + with T.block([128], "B") as vi: B[vi] = A[vi] * 2.0 - with tir.block([128], "C") as vi: + with T.block([128], "C") as vi: C[vi] = B[vi] + 1.0 -@tvm.script.tir -def multi_producer_consumer(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (128,)) - B = tir.match_buffer(b, (128,)) +@T.prim_func +def multi_producer_consumer(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128,)) + B = T.match_buffer(b, (128,)) for i in range(0, 64): - with tir.block([64], "A_0") as vi: + with T.block([64], "A_0") as vi: A[vi] = vi + 1 for i in range(0, 64): - with tir.block([64], "A_1") as vi: - tir.bind(vi, i + 64) + with T.block([64], "A_1") as vi: + T.bind(vi, i + 64) A[vi] = vi + 2 for i in range(0, 64): - with tir.block([64], "B_0") as vi: + with T.block([64], "B_0") as vi: B[vi] = A[vi] + 2.0 for i in range(0, 64): - with tir.block([64], "B_1") as vi: - tir.bind(vi, i + 64) + with T.block([64], "B_1") as vi: + T.bind(vi, i + 64) B[vi] = A[vi] + 3.0 -@tvm.script.tir -def elementwise_affine_producer(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - for i, j, k, l in tir.grid(16, 2, 32, 16): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i * 8 + j * 4 + k // 8) - tir.bind(vj, k % 8 * 16 + l) +@T.prim_func +def elementwise_affine_producer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + for i, j, k, l in T.grid(16, 2, 32, 16): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i * 8 + j * 4 + k // 8) + T.bind(vj, k % 8 * 16 + l) B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], 
"C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def elementwise_subblock(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - with tir.block([32, 32], "B") as [vi, vj]: - tir.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) - tir.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) - with tir.block([4, 4], "B_sub") as [vi_i, vj_i]: +@T.prim_func +def elementwise_subblock(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + with T.block([32, 32], "B") as [vi, vj]: + T.reads([A[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) + T.writes([B[vi * 4 : vi * 4 + 4, vj * 4 : vj * 4 + 4]]) + with T.block([4, 4], "B_sub") as [vi_i, vj_i]: B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def elementwise_subblock_uncovered(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") - with tir.block([32, 32], "B") as [vi, vj]: - tir.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) - tir.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) - with tir.block([2, 2], "B_sub") as [vi_i, vj_i]: +@T.prim_func +def elementwise_subblock_uncovered(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + with T.block([32, 32], "B") as [vi, vj]: + T.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) + T.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]]) + with T.block([2, 2], "B_sub") as [vi_i, vj_i]: B[vi * 4 + vi_i, vj * 4 
+ vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def bound_to_thread(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - B = tir.alloc_buffer([128, 128], scope="shared") - for i in tir.thread_binding(0, 128, thread="threadIdx.x"): - for j in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def bound_to_thread(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + C = T.match_buffer(c, [128, 128]) + B = T.alloc_buffer([128, 128], scope="shared") + for i in T.thread_binding(0, 128, thread="threadIdx.x"): + for j in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - for j in tir.serial(0, 128): - with tir.block([128, 128], "C") as [vi, vj]: + for j in T.serial(0, 128): + with T.block([128, 128], "C") as [vi, vj]: C[vj, vi] = B[vj, vi] + 1.0 -@tvm.script.tir -def equal_ranked_threads(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - B = tir.alloc_buffer([128, 128], scope="shared") - for i_o in tir.thread_binding(0, 16, thread="threadIdx.x"): - for i_i in tir.thread_binding(0, 8, thread="threadIdx.y"): - for j in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i_o * 8 + i_i) - tir.bind(vj, j) +@T.prim_func +def equal_ranked_threads(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + C = T.match_buffer(c, [128, 128]) + B = T.alloc_buffer([128, 128], scope="shared") + for i_o in T.thread_binding(0, 16, thread="threadIdx.x"): + for i_i in T.thread_binding(0, 8, thread="threadIdx.y"): + for j in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i_o * 8 + i_i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 - for j in tir.serial(0, 128): - with 
tir.block([128, 128], "C") as [vi, vj]: - tir.bind(vi, i_o * 8 + i_i) - tir.bind(vj, j) + for j in T.serial(0, 128): + with T.block([128, 128], "C") as [vi, vj]: + T.bind(vi, i_o * 8 + i_i) + T.bind(vj, j) C[vj, vi] = B[vj, vi] + 1.0 -@tvm.script.tir -def warp_memory(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - B = tir.alloc_buffer([128, 4, 32], scope="warp") - for i_o in tir.thread_binding(0, 4, thread="threadIdx.y"): - for i_i in tir.thread_binding(0, 32, thread="threadIdx.x"): - for j in tir.serial(0, 128): - with tir.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: +@T.prim_func +def warp_memory(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + C = T.match_buffer(c, [128, 128]) + B = T.alloc_buffer([128, 4, 32], scope="warp") + for i_o in T.thread_binding(0, 4, thread="threadIdx.y"): + for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): + for j in T.serial(0, 128): + with T.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 - for j in tir.serial(0, 128): - with tir.block([4, 32, 128], "C") as [warp_id, lane_id, vj]: + for j in T.serial(0, 128): + with T.block([4, 32, 128], "C") as [warp_id, lane_id, vj]: C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 -@tvm.script.tir -def warp_memory_negative(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - B = tir.alloc_buffer([128, 4, 32], scope="warp") - for i_o in tir.thread_binding(0, 4, thread="threadIdx.y"): - for i_i in tir.thread_binding(0, 32, thread="threadIdx.x"): - for j in tir.serial(0, 128): - with tir.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: +@T.prim_func +def warp_memory_negative(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + C = T.match_buffer(c, [128, 128]) + B = T.alloc_buffer([128, 4, 32], scope="warp") + for i_o in 
T.thread_binding(0, 4, thread="threadIdx.y"): + for i_i in T.thread_binding(0, 32, thread="threadIdx.x"): + for j in T.serial(0, 128): + with T.block([4, 32, 128], "B") as [warp_id, lane_id, vj]: B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0 - for i_o_prime in tir.thread_binding(0, 4, thread="threadIdx.y"): - for j in tir.serial(0, 128): - with tir.block([4, 32, 4, 128], "C") as [_warp_id, lane_id, warp_id, vj]: + for i_o_prime in T.thread_binding(0, 4, thread="threadIdx.y"): + for j in T.serial(0, 128): + with T.block([4, 32, 4, 128], "C") as [_warp_id, lane_id, warp_id, vj]: C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0 diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py index a0a069347f950..7d0e91f70e609 100644 --- a/tests/python/unittest/test_tir_schedule_storage_align.py +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -18,90 +18,90 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name -@tvm.script.tir -def element_wise(a: ty.handle, c: ty.handle) -> None: - C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) - A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) +@T.prim_func +def element_wise(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with tir.block([], "root"): - tir.reads([]) - tir.writes([]) - B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) - for i0 in tir.serial(0, 128): - for ax1 in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i0) - tir.bind(vj, 
ax1) - tir.reads([A[vi, vj]]) - tir.writes([B[vi, vj]]) - B[vi, vj] = (A[vi, vj]*tir.float32(2)) - for i1 in tir.serial(0, 128): - with tir.block([128, 128], "C") as [vi_1, vj_1]: - tir.bind(vi_1, i0) - tir.bind(vj_1, i1) - tir.reads([B[vi_1, vj_1]]) - tir.writes([C[vi_1, vj_1]]) - C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) - - -@tvm.script.tir -def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None: - C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) - A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + with T.block([], "root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in T.serial(0, 128): + for ax1 in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i0) + T.bind(vj, ax1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*T.float32(2)) + for i1 in T.serial(0, 128): + with T.block([128, 128], "C") as [vi_1, vj_1]: + T.bind(vi_1, i0) + T.bind(vj_1, i1) + T.reads([B[vi_1, vj_1]]) + T.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) + + +@T.prim_func +def element_wise_storage_align(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with tir.block([], "root"): - tir.reads([]) - tir.writes([]) - B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) - for i0 in tir.serial(0, 128): - for ax1 in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, i0) - tir.bind(vj, ax1) - tir.reads([A[vi, vj]]) - tir.writes([B[vi, vj]]) - tir.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) - B[vi, vj] = (A[vi, vj]*tir.float32(2)) - for i1 in tir.serial(0, 128): - with tir.block([128, 128], "C") as [vi_1, vj_1]: - tir.bind(vi_1, i0) - tir.bind(vj_1, i1) - 
tir.reads([B[vi_1, vj_1]]) - tir.writes([C[vi_1, vj_1]]) - C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) - - -@tvm.script.tir -def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None: - C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) - A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + with T.block([], "root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in T.serial(0, 128): + for ax1 in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.bind(vi, i0) + T.bind(vj, ax1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + T.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) + B[vi, vj] = (A[vi, vj]*T.float32(2)) + for i1 in T.serial(0, 128): + with T.block([128, 128], "C") as [vi_1, vj_1]: + T.bind(vi_1, i0) + T.bind(vj_1, i1) + T.reads([B[vi_1, vj_1]]) + T.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) + + +@T.prim_func +def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) # body - with tir.block([], "root"): - tir.reads([]) - tir.writes([]) - B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) - for i0 in tir.serial(0, 128): - for ax1 in tir.serial(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - tir.block_attr({"buffer_dim_align": [0]}) - tir.bind(vi, i0) - tir.bind(vj, ax1) - tir.reads([A[vi, vj]]) - tir.writes([B[vi, vj]]) - B[vi, vj] = (A[vi, vj]*tir.float32(2)) - for i1 in tir.serial(0, 128): - with tir.block([128, 128], "C") as [vi_1, vj_1]: - tir.bind(vi_1, i0) - tir.bind(vj_1, i1) - tir.reads([B[vi_1, vj_1]]) - tir.writes([C[vi_1, vj_1]]) - C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + with T.block([], "root"): + T.reads([]) + T.writes([]) + B = 
T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in T.serial(0, 128): + for ax1 in T.serial(0, 128): + with T.block([128, 128], "B") as [vi, vj]: + T.block_attr({"buffer_dim_align": [0]}) + T.bind(vi, i0) + T.bind(vj, ax1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*T.float32(2)) + for i1 in T.serial(0, 128): + with T.block([128, 128], "C") as [vi_1, vj_1]: + T.bind(vi_1, i0) + T.bind(vj_1, i1) + T.reads([B[vi_1, vj_1]]) + T.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1)) def test_storage_align(): diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index da7b096ade17e..36e05c6b51701 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -21,28 +21,28 @@ import pytest import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule import BlockRV, Instruction, InstructionKind, LoopRV, Trace # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def elementwise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - B = tir.alloc_buffer((128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.alloc_buffer((128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 128], "C") as [vi, vj]: + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def elementwise_inlined(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128)) - C = tir.match_buffer(c, (128, 128)) - with tir.block([128, 128], "C") as [vi, vj]: +@T.prim_func +def elementwise_inlined(a: T.handle, c: T.handle) -> None: + A = 
T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + with T.block([128, 128], "C") as [vi, vj]: C[vi, vj] = A[vi, vj] * 2.0 + 1.0 diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index dcaeaaad6164f..185d229b44e14 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -22,22 +22,22 @@ from tvm import tir from tvm.ir import IRModule -from tvm.script import ty +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = tir.float32(0) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j in T.grid(128, 128): + with T.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = T.float32(0) for k in range(0, 128): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] diff --git a/tests/python/unittest/test_tir_specialize.py b/tests/python/unittest/test_tir_specialize.py index d6cfadaf1fbcd..86dc5dffed9f1 100644 --- a/tests/python/unittest/test_tir_specialize.py +++ b/tests/python/unittest/test_tir_specialize.py @@ -17,149 +17,146 @@ # pylint: disable=missing-function-docstring, missing-module-docstring import tvm -from tvm import tir -from tvm.script import ty +from tvm.script import tir as T -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle, 
n: ty.int32) -> None: - m = tir.var("int32") - A = tir.match_buffer(a, [m, n]) - B = tir.match_buffer(b, [m, n]) - C = tir.match_buffer(c, [m, m]) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle, n: T.int32) -> None: + m = T.var("int32") + A = T.match_buffer(a, [m, n]) + B = T.match_buffer(b, [m, n]) + C = T.match_buffer(c, [m, m]) - with tir.block([m, m, tir.reduce_axis(0, n)], "update") as [vi, vj, vk]: - with tir.init(): + with T.block([m, m, T.reduce_axis(0, n)], "update") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def matmul_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) +@T.prim_func +def matmul_128(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with tir.init(): + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - m = tir.var("int32") - A = tir.match_buffer(a, [m, 128]) - B = tir.match_buffer(b, [m, 128]) - C = tir.match_buffer(c, [m, m]) +@T.prim_func +def matmul_m_128(a: T.handle, b: T.handle, c: T.handle) -> None: + m = T.var("int32") + A = T.match_buffer(a, [m, 128]) + B = T.match_buffer(b, [m, 128]) + C = T.match_buffer(c, [m, m]) - with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with tir.init(): + with T.block([m, m, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - x = 
tir.var("int32") - m = tir.var("int32") - A = tir.match_buffer(a, [m, x * 8]) - B = tir.match_buffer(b, [m, x * 8]) - C = tir.match_buffer(c, [m, m]) +@T.prim_func +def matmul_m_8x(a: T.handle, b: T.handle, c: T.handle) -> None: + x = T.var("int32") + m = T.var("int32") + A = T.match_buffer(a, [m, x * 8]) + B = T.match_buffer(b, [m, x * 8]) + C = T.match_buffer(c, [m, m]) - with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]: - with tir.init(): + with T.block([m, m, T.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]: + with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def element_wise(a: ty.handle, c: ty.handle) -> None: - m = tir.var("int32") - n = tir.var("int32") - A = tir.match_buffer(a, (m, n), "float32") - C = tir.match_buffer(c, (m, n), "float32") +@T.prim_func +def element_wise(a: T.handle, c: T.handle) -> None: + m = T.var("int32") + n = T.var("int32") + A = T.match_buffer(a, (m, n), "float32") + C = T.match_buffer(c, (m, n), "float32") - B = tir.alloc_buffer((m, n), "float32") + B = T.alloc_buffer((m, n), "float32") - with tir.block([m, n], "B") as [vi, vj]: + with T.block([m, n], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([m, n], "C") as [vi, vj]: + with T.block([m, n], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def element_wise_128_64(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 64), "float32") - C = tir.match_buffer(c, (128, 64), "float32") - B = tir.alloc_buffer((128, 64), "float32") +@T.prim_func +def element_wise_128_64(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 64), "float32") + C = T.match_buffer(c, (128, 64), "float32") + B = T.alloc_buffer((128, 64), "float32") - with tir.block([128, 64], "B") as [vi, vj]: + with T.block([128, 64], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, 64], "C") as [vi, vj]: + with T.block([128, 64], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 
1.0 -@tvm.script.tir -def element_wise_128_n(a: ty.handle, c: ty.handle) -> None: - n = tir.var("int32") - A = tir.match_buffer(a, (128, n), "float32") - C = tir.match_buffer(c, (128, n), "float32") - B = tir.alloc_buffer((128, n), "float32") +@T.prim_func +def element_wise_128_n(a: T.handle, c: T.handle) -> None: + n = T.var("int32") + A = T.match_buffer(a, (128, n), "float32") + C = T.match_buffer(c, (128, n), "float32") + B = T.alloc_buffer((128, n), "float32") - with tir.block([128, n], "B") as [vi, vj]: + with T.block([128, n], "B") as [vi, vj]: B[vi, vj] = A[vi, vj] * 2.0 - with tir.block([128, n], "C") as [vi, vj]: + with T.block([128, n], "C") as [vi, vj]: C[vi, vj] = B[vi, vj] + 1.0 -@tvm.script.tir -def mem_copy( - a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q: ty.int32 -) -> None: - A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q) - B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q) +@T.prim_func +def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int32, q: T.int32) -> None: + A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q) + B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q) - with tir.block([m, n], "") as [vi, vj]: + with T.block([m, n], "") as [vi, vj]: B[vi, vj] = A[vi, vj] -@tvm.script.tir -def mem_copy_16_16_8_4(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4) - B = tir.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4) +@T.prim_func +def mem_copy_16_16_8_4(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4) + B = T.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4) - with tir.block([16, 16], "") as [vi, vj]: + with T.block([16, 16], "") as [vi, vj]: B[vi, vj] = A[vi, vj] -@tvm.script.tir -def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: 
ty.int32, p: ty.int32) -> None: - A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n) - B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n) +@T.prim_func +def mem_copy_m_n_p_n(a: T.handle, b: T.handle, m: T.int32, n: T.int32, p: T.int32) -> None: + A = T.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n) + B = T.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n) - with tir.block([m, n], "") as [vi, vj]: + with T.block([m, n], "") as [vi, vj]: B[vi, vj] = A[vi, vj] -@tvm.script.tir -def param_in_arith_exprs(a: ty.handle, b: ty.handle) -> None: - n = tir.var("int32") - A = tir.match_buffer(a, [n // 8, 8], "int32") - B = tir.match_buffer(b, [n], "int32") - with tir.block([n - 1], "") as [vi]: +@T.prim_func +def param_in_arith_exprs(a: T.handle, b: T.handle) -> None: + n = T.var("int32") + A = T.match_buffer(a, [n // 8, 8], "int32") + B = T.match_buffer(b, [n], "int32") + with T.block([n - 1], "") as [vi]: B[vi] = A[vi // 8, vi % 8] + (n + 1) * 42 -@tvm.script.tir -def param_in_arith_exprs_n_16(a: ty.handle, b: ty.handle) -> None: - n = tir.var("int32") - A = tir.match_buffer(a, [2, 8], "int32") - B = tir.match_buffer(b, [16], "int32") - with tir.block([15], "") as [vi]: +@T.prim_func +def param_in_arith_exprs_n_16(a: T.handle, b: T.handle) -> None: + n = T.var("int32") + A = T.match_buffer(a, [2, 8], "int32") + B = T.match_buffer(b, [16], "int32") + with T.block([15], "") as [vi]: B[vi] = A[vi // 8, vi % 8] + 714 @@ -171,13 +168,13 @@ def test_specialize_nothing(): def test_specialize_matmul(): a, _, _, n = matmul.params # fully specialized - func = matmul.specialize({a: tir.decl_buffer((128, 128))}) + func = matmul.specialize({a: tvm.tir.decl_buffer((128, 128))}) tvm.ir.assert_structural_equal(func, matmul_128) # partially specialized func = matmul.specialize({n: 128}) tvm.ir.assert_structural_equal(func, matmul_m_128) # symbolic specialized - func = matmul.specialize({n: tir.Var("x", "int32") 
* 8}) + func = matmul.specialize({n: tvm.tir.Var("x", "int32") * 8}) tvm.ir.assert_structural_equal(func, matmul_m_8x) @@ -185,17 +182,17 @@ def test_specialize_elemwise(): a, c = element_wise.params C = element_wise.buffer_map[c] # fully specialized - func = element_wise.specialize({a: tir.decl_buffer((128, 64))}) + func = element_wise.specialize({a: tvm.tir.decl_buffer((128, 64))}) tvm.ir.assert_structural_equal(func, element_wise_128_64) # partially specialized - func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))}) + func = element_wise.specialize({c: tvm.tir.decl_buffer((128, C.shape[1]))}) tvm.ir.assert_structural_equal(func, element_wise_128_n) def test_specialize_mem_copy(): a, _, m, n, p, q = mem_copy.params # fully specialized - func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}) + func = mem_copy.specialize({a: tvm.tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}) tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4) func = mem_copy.specialize({n: 16, m: 16, p: 8, q: 4}) tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4) @@ -211,7 +208,7 @@ def test_specialize_recursive_load(): def test_specialize_with_const_folding(): b = param_in_arith_exprs.params[1] - func = param_in_arith_exprs.specialize({b: tir.decl_buffer([16])}) + func = param_in_arith_exprs.specialize({b: tvm.tir.decl_buffer([16])}) tvm.ir.assert_structural_equal(func, param_in_arith_exprs_n_16) diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index cefdb5fd8c6a4..0cfc724e41de2 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -16,7 +16,7 @@ # under the License. 
import tvm from tvm import tir, te -from tvm.script import ty +from tvm.script import tir as T def _check(original, transformed): @@ -27,359 +27,359 @@ def _check(original, transformed): tvm.ir.assert_structural_equal(mod["main"], transformed) -@tvm.script.tir -def elementwise_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def elementwise_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(C[i, 0:16]) - B = tir.alloc_buffer((16, 16), "float32") + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with tir.block([]) as []: - tir.reads(A[i, j]) - tir.writes(B[i, j]) + with T.block([]) as []: + T.reads(A[i, j]) + T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with tir.block([]) as []: - tir.reads(B[i, j]) - tir.writes(C[i, j]) + with T.block([]) as []: + T.reads(B[i, j]) + T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 -@tvm.script.tir -def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(C[i, 0:16]) - B = tir.alloc_buffer((1, 16), "float32") + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer((1, 16), "float32") for j in range(0, 16): - with tir.block() as []: - tir.reads(A[i, j]) - tir.writes(B[0, j]) + with T.block() as []: + T.reads(A[i, j]) + T.writes(B[0, j]) B[0, j] = A[i, j] + 1.0 for j in range(0, 16): - with 
tir.block() as []: - tir.reads(B[0, j]) - tir.writes(C[i, j]) + with T.block() as []: + T.reads(B[0, j]) + T.writes(C[i, j]) C[i, j] = B[0, j] * 2.0 -@tvm.script.tir -def unschedulable_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def unschedulable_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(C[i, 0:16]) - B = tir.alloc_buffer((16, 16), "float32") + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - tir.store(B.data, i * 16 + j, A[i, j] + 1.0) + T.store(B.data, i * 16 + j, A[i, j] + 1.0) for j in range(0, 16): C[i, j] = B[i, j] * 2.0 -@tvm.script.tir -def param_buffer_access_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (20, 20), "float32") - B = tir.match_buffer(c, (20, 20), "float32") +@T.prim_func +def param_buffer_access_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (20, 20), "float32") + B = T.match_buffer(c, (20, 20), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(B[i, 0:16]) + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(B[i, 0:16]) for j in range(0, 16): - with tir.block([]) as []: - tir.reads(A[i, j]) - tir.writes(B[i, j]) + with T.block([]) as []: + T.reads(A[i, j]) + T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 -@tvm.script.tir -def shared_mem_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") - for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"): - for i1 in tir.thread_binding(0, 2, thread="vthread"): - for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"): - with tir.block([]): - tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) - 
tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) - B = tir.alloc_buffer((16, 16), "float32", scope="shared") +@T.prim_func +def shared_mem_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): + for i1 in T.thread_binding(0, 2, thread="vthread"): + for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): + with T.block([]): + T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) + T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) + B = T.alloc_buffer((16, 16), "float32", scope="shared") for j in range(0, 16): - with tir.block([]) as []: - tir.reads(A[i0 * 8 + i1 * 4 + i2, j]) - tir.writes(B[i0 * 8 + i1 * 4 + i2, j]) + with T.block([]) as []: + T.reads(A[i0 * 8 + i1 * 4 + i2, j]) + T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with tir.block([]) as []: - tir.reads(B[i0 * 8 + i1 * 4 + i2, j]) - tir.writes(C[i0 * 8 + i1 * 4 + i2, j]) + with T.block([]) as []: + T.reads(B[i0 * 8 + i1 * 4 + i2, j]) + T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 -@tvm.script.tir -def compacted_shared_mem_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") - for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"): - for i1 in tir.thread_binding(0, 2, thread="vthread"): - for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"): - with tir.block([]): - tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) - tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) - B = tir.alloc_buffer((8, 16), "float32", scope="shared") +@T.prim_func +def compacted_shared_mem_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): + for i1 in T.thread_binding(0, 2, thread="vthread"): + for i2 
in T.thread_binding(0, 4, thread="threadIdx.x"): + with T.block([]): + T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) + T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) + B = T.alloc_buffer((8, 16), "float32", scope="shared") for j in range(0, 16): - with tir.block([]) as []: - tir.reads(A[i0 * 8 + i1 * 4 + i2, j]) - tir.writes(B[i1 * 4 + i2, j]) + with T.block([]) as []: + T.reads(A[i0 * 8 + i1 * 4 + i2, j]) + T.writes(B[i1 * 4 + i2, j]) B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with tir.block([]) as []: - tir.reads(B[i1 * 4 + i2, j]) - tir.writes(C[i0 * 8 + i1 * 4 + i2, j]) + with T.block([]) as []: + T.reads(B[i1 * 4 + i2, j]) + T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0 -@tvm.script.tir -def warp_mem_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") - for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"): - for i1 in tir.thread_binding(0, 2, thread="vthread"): - for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"): - with tir.block([]): - tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) - tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) - B = tir.alloc_buffer((16, 16), "float32", scope="warp") +@T.prim_func +def warp_mem_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): + for i1 in T.thread_binding(0, 2, thread="vthread"): + for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): + with T.block([]): + T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) + T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) + B = T.alloc_buffer((16, 16), "float32", scope="warp") for j in range(0, 16): - with tir.block([]) as []: - tir.reads(A[i0 * 8 + i1 * 4 + i2, j]) - tir.writes(B[i0 * 8 + i1 * 4 + i2, j]) + with T.block([]) as []: + T.reads(A[i0 * 8 + i1 * 4 + i2, j]) + T.writes(B[i0 * 8 + i1 * 4 + i2, j]) B[i0 * 8 + i1 * 4 + 
i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with tir.block([]) as []: - tir.reads(B[i0 * 8 + i1 * 4 + i2, j]) - tir.writes(C[i0 * 8 + i1 * 4 + i2, j]) + with T.block([]) as []: + T.reads(B[i0 * 8 + i1 * 4 + i2, j]) + T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 -@tvm.script.tir -def compacted_warp_mem_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") - for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"): - for i1 in tir.thread_binding(0, 2, thread="vthread"): - for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"): - with tir.block([]): - tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) - tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) - B = tir.alloc_buffer((4, 16), "float32", scope="warp") +@T.prim_func +def compacted_warp_mem_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i0 in T.thread_binding(0, 2, thread="blockIdx.x"): + for i1 in T.thread_binding(0, 2, thread="vthread"): + for i2 in T.thread_binding(0, 4, thread="threadIdx.x"): + with T.block([]): + T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) + T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) + B = T.alloc_buffer((4, 16), "float32", scope="warp") for j in range(0, 16): - with tir.block([]) as []: - tir.reads(A[i0 * 8 + i1 * 4 + i2, j]) - tir.writes(B[i2, j]) + with T.block([]) as []: + T.reads(A[i0 * 8 + i1 * 4 + i2, j]) + T.writes(B[i2, j]) B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 for j in range(0, 16): - with tir.block([]) as []: - tir.reads(B[i2, j]) - tir.writes(C[i0 * 8 + i1 * 4 + i2, j]) + with T.block([]) as []: + T.reads(B[i2, j]) + T.writes(C[i0 * 8 + i1 * 4 + i2, j]) C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0 -@tvm.script.tir -def symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: - A = tir.match_buffer(a, (n * 8,), "float32") - C = tir.match_buffer(c, (n 
* 8,), "float32") +@T.prim_func +def symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (n * 8,), "float32") + C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with tir.block([]): - tir.reads(A[i * 8 : i * 8 + 8]) - tir.writes(C[i * 8 : i * 8 + 8]) - B = tir.alloc_buffer((n * 8,), "float32") + with T.block([]): + T.reads(A[i * 8 : i * 8 + 8]) + T.writes(C[i * 8 : i * 8 + 8]) + B = T.alloc_buffer((n * 8,), "float32") for j in range(0, 8): - with tir.block([]) as []: - tir.reads(A[i * 8 + j]) - tir.writes(B[i * 8 + j]) + with T.block([]) as []: + T.reads(A[i * 8 + j]) + T.writes(B[i * 8 + j]) B[i * 8 + j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with tir.block([]) as []: - tir.reads(B[i * 8 + j]) - tir.writes(C[i * 8 + j]) + with T.block([]) as []: + T.reads(B[i * 8 + j]) + T.writes(C[i * 8 + j]) C[i * 8 + j] = B[i * 8 + j] * 2.0 -@tvm.script.tir -def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: - A = tir.match_buffer(a, (n * 8,), "float32") - C = tir.match_buffer(c, (n * 8,), "float32") +@T.prim_func +def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (n * 8,), "float32") + C = T.match_buffer(c, (n * 8,), "float32") for i in range(0, n): - with tir.block([]): - tir.reads(A[i * 8 : i * 8 + 8]) - tir.writes(C[i * 8 : i * 8 + 8]) - B = tir.alloc_buffer((8,), "float32") + with T.block([]): + T.reads(A[i * 8 : i * 8 + 8]) + T.writes(C[i * 8 : i * 8 + 8]) + B = T.alloc_buffer((8,), "float32") for j in range(0, 8): - with tir.block([]) as []: - tir.reads(A[i * 8 + j]) - tir.writes(B[j]) + with T.block([]) as []: + T.reads(A[i * 8 + j]) + T.writes(B[j]) B[j] = A[i * 8 + j] + 1.0 for j in range(0, 8): - with tir.block([]) as []: - tir.reads(B[j]) - tir.writes(C[i * 8 + j]) + with T.block([]) as []: + T.reads(B[j]) + T.writes(C[i * 8 + j]) C[i * 8 + j] = B[j] * 2.0 -@tvm.script.tir -def complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> 
None: - A = tir.match_buffer(a, (8, 8), "float32") - C = tir.match_buffer(c, (8, 8), "float32") +@T.prim_func +def complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (8, 8), "float32") + C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with tir.block([]): - tir.reads(A[0, 8]) - tir.writes(C[0, 8]) - B = tir.alloc_buffer((8, 8), "float32") + with T.block([]): + T.reads(A[0, 8]) + T.writes(C[0, 8]) + B = T.alloc_buffer((8, 8), "float32") for j in range(0, 4): - with tir.block([]) as []: - D = tir.alloc_buffer((8, 8), "float32") - tir.reads(A[i, j]) - tir.writes(B[i, j]) + with T.block([]) as []: + D = T.alloc_buffer((8, 8), "float32") + T.reads(A[i, j]) + T.writes(B[i, j]) for k in range(4, 8): D[k, j] = 1.0 for k in range(2, 4): - tir.store(B.data, j, A[i, j] + D[k, j]) + T.store(B.data, j, A[i, j] + D[k, j]) for j in range(3, 5): - with tir.block([]) as []: - tir.reads(B[i, j]) - tir.writes(C[i, j]) + with T.block([]) as []: + T.reads(B[i, j]) + T.writes(C[i, j]) C[i, j] = B[i, j] for j in range(6, 8): - with tir.block([]) as []: - tir.reads(B[i, j]) - tir.writes(C[i, j]) + with T.block([]) as []: + T.reads(B[i, j]) + T.writes(C[i, j]) C[i, j] = B[i, j] -@tvm.script.tir -def compacted_complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: - A = tir.match_buffer(a, (8, 8), "float32") - C = tir.match_buffer(c, (8, 8), "float32") +@T.prim_func +def compacted_complex_func(a: T.handle, c: T.handle, n: T.int32) -> None: + A = T.match_buffer(a, (8, 8), "float32") + C = T.match_buffer(c, (8, 8), "float32") for i in range(0, 8): - with tir.block([]): - tir.reads(A[0, 8]) - tir.writes(C[0, 8]) - B = tir.alloc_buffer((1, 8), "float32") + with T.block([]): + T.reads(A[0, 8]) + T.writes(C[0, 8]) + B = T.alloc_buffer((1, 8), "float32") for j in range(0, 4): - with tir.block([]) as []: - D = tir.alloc_buffer((6, 1), "float32") - tir.reads(A[i, j]) - tir.writes(B[0, j]) + with T.block([]) as []: + D = T.alloc_buffer((6, 1), 
"float32") + T.reads(A[i, j]) + T.writes(B[0, j]) for k in range(4, 8): D[k - 2, 0] = 1.0 for k in range(2, 4): - tir.store(B.data, j, A[i, j] + D[k - 2, 0]) + T.store(B.data, j, A[i, j] + D[k - 2, 0]) for j in range(3, 5): - with tir.block([]) as []: - tir.reads(B[0, j]) - tir.writes(C[i, j]) + with T.block([]) as []: + T.reads(B[0, j]) + T.writes(C[i, j]) C[i, j] = B[0, j] for j in range(6, 8): - with tir.block([]) as []: - tir.reads(B[0, j]) - tir.writes(C[i, j]) + with T.block([]) as []: + T.reads(B[0, j]) + T.writes(C[i, j]) C[i, j] = B[0, j] -@tvm.script.tir -def match_buffer_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16)) - C = tir.match_buffer(c, (16, 16)) +@T.prim_func +def match_buffer_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16)) + C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with tir.block([]): - A0 = tir.match_buffer(A[i, 0:16], (16)) - C0 = tir.match_buffer(C[i, 0:16], (16)) - B = tir.alloc_buffer((16, 16)) - with tir.block([]): - B0 = tir.match_buffer(B[i, 0:16], (16)) + with T.block([]): + A0 = T.match_buffer(A[i, 0:16], (16)) + C0 = T.match_buffer(C[i, 0:16], (16)) + B = T.alloc_buffer((16, 16)) + with T.block([]): + B0 = T.match_buffer(B[i, 0:16], (16)) for j in range(0, 16): - with tir.block([]) as []: - A1 = tir.match_buffer(A0[j], ()) - B1 = tir.match_buffer(B0[j], ()) + with T.block([]) as []: + A1 = T.match_buffer(A0[j], ()) + B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with tir.block([]) as []: - C1 = tir.match_buffer(C0[j], ()) - B2 = tir.match_buffer(B[i, j], ()) + with T.block([]) as []: + C1 = T.match_buffer(C0[j], ()) + B2 = T.match_buffer(B[i, j], ()) C1[()] = B2[()] * 2.0 -@tvm.script.tir -def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16)) - C = tir.match_buffer(c, (16, 16)) +@T.prim_func +def compacted_match_buffer_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, 
(16, 16)) + C = T.match_buffer(c, (16, 16)) for i in range(0, 16): - with tir.block([]): - A0 = tir.match_buffer(A[i, 0:16], (16)) - C0 = tir.match_buffer(C[i, 0:16], (16)) - B = tir.alloc_buffer((1, 16)) - with tir.block([]): - B0 = tir.match_buffer(B[0, 0:16], (16)) + with T.block([]): + A0 = T.match_buffer(A[i, 0:16], (16)) + C0 = T.match_buffer(C[i, 0:16], (16)) + B = T.alloc_buffer((1, 16)) + with T.block([]): + B0 = T.match_buffer(B[0, 0:16], (16)) for j in range(0, 16): - with tir.block([]) as []: - A1 = tir.match_buffer(A0[j], ()) - B1 = tir.match_buffer(B0[j], ()) + with T.block([]) as []: + A1 = T.match_buffer(A0[j], ()) + B1 = T.match_buffer(B0[j], ()) B1[()] = A1[()] + 1.0 for j in range(0, 16): - with tir.block([]) as []: - C1 = tir.match_buffer(C0[j], ()) - B2 = tir.match_buffer(B[0, j], ()) + with T.block([]) as []: + C1 = T.match_buffer(C0[j], ()) + B2 = T.match_buffer(B[0, j], ()) C1[()] = B2[()] * 2.0 -@tvm.script.tir -def storage_align_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def storage_align_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(C[i, 0:16]) - B = tir.alloc_buffer((16, 16), "float32") + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with tir.block([]) as []: - tir.reads(A[i, j]) - tir.writes(B[i, j]) - tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) + with T.block([]) as []: + T.reads(A[i, j]) + T.writes(B[i, j]) + T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[i, j] = A[i, j] + 1.0 for j in range(0, 16): - with tir.block([]) as []: - tir.reads(B[i, j]) - tir.writes(C[i, j]) + with T.block([]) as []: + T.reads(B[i, j]) + T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 
-@tvm.script.tir -def compacted_storage_align_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def compacted_storage_align_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(C[i, 0:16]) - B = tir.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32") + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32") for j in range(0, 16): - with tir.block() as []: - tir.reads(A[i, j]) - tir.writes(B[0, j]) - tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) + with T.block() as []: + T.reads(A[i, j]) + T.writes(B[0, j]) + T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) B[0, j] = A[i, j] + 1.0 for j in range(0, 16): - with tir.block() as []: - tir.reads(B[0, j]) - tir.writes(C[i, j]) + with T.block() as []: + T.reads(B[0, j]) + T.writes(C[i, j]) C[i, j] = B[0, j] * 2.0 diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py index cfdcc1a659114..287a30916520c 100644 --- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -16,7 +16,7 @@ # under the License. 
import tvm from tvm import tir, te -from tvm.script import ty +from tvm.script import tir as T def _check(original, transformed): @@ -27,45 +27,45 @@ def _check(original, transformed): tvm.ir.assert_structural_equal(mod["main"], transformed) -@tvm.script.tir -def elementwise_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def elementwise_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(C[i, 0:16]) - B = tir.alloc_buffer((16, 16), "float32") + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer((16, 16), "float32") for j in range(0, 16): - with tir.block([16, 16]) as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + with T.block([16, 16]) as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) B[vi, vj] = A[vi, vj] + 1.0 for j in range(0, 16): - with tir.block([16, 16]) as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, j) + with T.block([16, 16]) as [vi, vj]: + T.bind(vi, i) + T.bind(vj, j) C[vi, vj] = B[vi, vj] * 2.0 -@tvm.script.tir -def substituted_elementwise_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def substituted_elementwise_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(C[i, 0:16]) - B = tir.alloc_buffer([16, 16], "float32") + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer([16, 16], "float32") for j in range(0, 16): - with tir.block() as []: - tir.reads(A[i, j]) - tir.writes(B[i, j]) + with T.block() as []: + T.reads(A[i, j]) + T.writes(B[i, j]) B[i, j] = A[i, j] + 1.0 for j in 
range(0, 16): - with tir.block() as []: - tir.reads(B[i, j]) - tir.writes(C[i, j]) + with T.block() as []: + T.reads(B[i, j]) + T.writes(C[i, j]) C[i, j] = B[i, j] * 2.0 diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index c51b5319e85ff..9dd407a9c47b9 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import tir, te -from tvm.script import ty +from tvm.script import tir as T def _check(original, transformed): @@ -27,187 +27,187 @@ def _check(original, transformed): tvm.ir.assert_structural_equal(mod["main"], transformed, True) -@tvm.script.tir -def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def compacted_elementwise_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") for i in range(0, 16): - with tir.block([]): - tir.reads(A[i, 0:16]) - tir.writes(C[i, 0:16]) - B = tir.alloc_buffer([1, 16], "float32", scope="global") + with T.block([]): + T.reads(A[i, 0:16]) + T.writes(C[i, 0:16]) + B = T.alloc_buffer([1, 16], "float32", scope="global") for j in range(0, 16): - with tir.block() as []: - tir.reads(A[i, j]) - tir.writes(B[0, j]) + with T.block() as []: + T.reads(A[i, j]) + T.writes(B[0, j]) B[0, j] = A[i, j] + 1.0 for j in range(0, 16): - with tir.block() as []: - tir.reads(B[0, j]) - tir.writes(C[i, j]) + with T.block() as []: + T.reads(B[0, j]) + T.writes(C[i, j]) C[i, j] = B[0, j] * 2.0 -@tvm.script.tir -def flattened_elementwise_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") - for i in tir.serial(0, 16): - B_new = tir.allocate([16], "float32", 
"global") - for j in tir.serial(0, 16): - B_new[j] = tir.load("float32", A.data, ((i * 16) + j)) + 1.0 - for j in tir.serial(0, 16): - C.data[((i * 16) + j)] = tir.load("float32", B_new, j) * 2.0 - - -@tvm.script.tir -def compacted_gpu_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") - for i0 in tir.thread_binding(0, 4, thread="blockIdx.x"): - for i1 in tir.thread_binding(0, 2, thread="threadIdx.x"): - for i2 in tir.thread_binding(0, 2, thread="vthread"): - with tir.block([]): - tir.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) - tir.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) - B = tir.alloc_buffer([1, 16], "float32", scope="local") +@T.prim_func +def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i in T.serial(0, 16): + B_new = T.allocate([16], "float32", "global") + for j in T.serial(0, 16): + B_new[j] = T.load("float32", A.data, ((i * 16) + j)) + 1.0 + for j in T.serial(0, 16): + C.data[((i * 16) + j)] = T.load("float32", B_new, j) * 2.0 + + +@T.prim_func +def compacted_gpu_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") + for i0 in T.thread_binding(0, 4, thread="blockIdx.x"): + for i1 in T.thread_binding(0, 2, thread="threadIdx.x"): + for i2 in T.thread_binding(0, 2, thread="vthread"): + with T.block([]): + T.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) + T.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) + B = T.alloc_buffer([1, 16], "float32", scope="local") for j in range(0, 16): - with tir.block() as []: - tir.reads(A[i0 * 4 + i1 * 2 + i2, j]) - tir.writes(B[0, j]) + with T.block() as []: + T.reads(A[i0 * 4 + i1 * 2 + i2, j]) + T.writes(B[0, j]) B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 for j in range(0, 16): - with tir.block() as []: - tir.reads(B[0, j]) - tir.writes(C[i0 * 4 + i1 * 2 + i2, j]) + with T.block() 
as []: + T.reads(B[0, j]) + T.writes(C[i0 * 4 + i1 * 2 + i2, j]) C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 -@tvm.script.tir -def flattened_gpu_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def flattened_gpu_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") - i0 = tir.env_thread("blockIdx.x") - i1 = tir.env_thread("threadIdx.x") - i2 = tir.env_thread("vthread") + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") - tir.launch_thread(i0, 4) - tir.launch_thread(i1, 2) - tir.launch_thread(i2, 2) - B = tir.allocate([16], "float32", "local") + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.allocate([16], "float32", "local") for j in range(0, 16): - B[j] = tir.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0 + B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0 for j in range(0, 16): - C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = tir.load("float32", B, j) * 2.0 + C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = T.load("float32", B, j) * 2.0 -@tvm.script.tir -def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32) -> None: - A = tir.match_buffer(a, (n, m), "float32") - C = tir.match_buffer(c, (n, m), "float32") +@T.prim_func +def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n, m), "float32") + C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - with tir.block([]): - tir.reads(A[i, m]) - tir.writes(C[i, m]) - B = tir.alloc_buffer((m,), "float32", scope="global") + with T.block([]): + T.reads(A[i, m]) + T.writes(C[i, m]) + B = T.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): - with tir.block([]) as []: - tir.reads(A[i, j]) - tir.writes(B[j]) + with T.block([]) as []: + 
T.reads(A[i, j]) + T.writes(B[j]) B[j] = A[i, j] + 1.0 for j in range(0, m): - with tir.block([]) as []: - tir.reads(B[j]) - tir.writes(C[i, j]) + with T.block([]) as []: + T.reads(B[j]) + T.writes(C[i, j]) C[i, j] = B[j] * 2.0 -@tvm.script.tir -def flattened_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32) -> None: - A = tir.match_buffer(a, (n, m), "float32") - C = tir.match_buffer(c, (n, m), "float32") +@T.prim_func +def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n, m), "float32") + C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - B = tir.allocate([m], "float32", "global") + B = T.allocate([m], "float32", "global") for j in range(0, m): - B[j] = tir.load("float32", A.data, i * m + j) + 1.0 + B[j] = T.load("float32", A.data, i * m + j) + 1.0 for j in range(0, m): - C.data[i * m + j] = tir.load("float32", B, j) * 2.0 + C.data[i * m + j] = T.load("float32", B, j) * 2.0 -@tvm.script.tir -def compacted_predicate_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (32), "float32") - C = tir.match_buffer(c, (32), "float32") +@T.prim_func +def compacted_predicate_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + C = T.match_buffer(c, (32), "float32") - for i, j in tir.grid(5, 7): - with tir.block([]) as []: - tir.reads(A[i * 7 + j]) - tir.writes(C[i * 7 + j]) - tir.where(i * 7 + j < 32) + for i, j in T.grid(5, 7): + with T.block([]) as []: + T.reads(A[i * 7 + j]) + T.writes(C[i * 7 + j]) + T.where(i * 7 + j < 32) C[i * 7 + j] = A[i * 7 + j] + 1.0 -@tvm.script.tir -def flattened_predicate_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (32), "float32") - C = tir.match_buffer(c, (32), "float32") +@T.prim_func +def flattened_predicate_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + C = T.match_buffer(c, (32), "float32") - for i, j in tir.grid(5, 7): + for i, j in T.grid(5, 
7): if i * 7 + j < 32: - C.data[i * 7 + j] = tir.load("float32", A.data, i * 7 + j) + 1.0 + C.data[i * 7 + j] = T.load("float32", A.data, i * 7 + j) + 1.0 -@tvm.script.tir -def compacted_unit_loop_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (32), "float32") - C = tir.match_buffer(c, (32), "float32") +@T.prim_func +def compacted_unit_loop_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + C = T.match_buffer(c, (32), "float32") - for x, y, z in tir.grid(4, 1, 8): - with tir.block([]) as []: - tir.reads(A[x * 8 + y * 8 + z]) - tir.writes(C[x * 8 + y * 8 + z]) + for x, y, z in T.grid(4, 1, 8): + with T.block([]) as []: + T.reads(A[x * 8 + y * 8 + z]) + T.writes(C[x * 8 + y * 8 + z]) C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0 -@tvm.script.tir -def flattened_unit_loop_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (32), "float32") - C = tir.match_buffer(c, (32), "float32") +@T.prim_func +def flattened_unit_loop_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + C = T.match_buffer(c, (32), "float32") - for x, z in tir.grid(4, 8): - C.data[x * 8 + z] = tir.load("float32", A.data, x * 8 + z) + 1.0 + for x, z in T.grid(4, 8): + C.data[x * 8 + z] = T.load("float32", A.data, x * 8 + z) + 1.0 -@tvm.script.tir -def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (32), "float32") - D = tir.match_buffer(d, (32), "float32") +@T.prim_func +def compacted_multi_alloc_func(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - with tir.block([]) as []: - tir.reads(A[i]) - tir.writes(D[i]) - B = tir.alloc_buffer((32,), scope="global") - C = tir.alloc_buffer((32,), scope="global") + with T.block([]) as []: + T.reads(A[i]) + T.writes(D[i]) + B = T.alloc_buffer((32,), scope="global") + C = T.alloc_buffer((32,), scope="global") B[i] = A[i] + 1.0 C[i] 
= A[i] + B[i] D[i] = C[i] * 2.0 -@tvm.script.tir -def flattened_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: - A = tir.match_buffer(a, (32), "float32") - D = tir.match_buffer(d, (32), "float32") +@T.prim_func +def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, (32), "float32") + D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - B = tir.allocate((32,), "float32", "global") - C = tir.allocate((32,), "float32", "global") - B[i] = tir.load("float32", A.data, i) + 1.0 - C[i] = tir.load("float32", A.data, i) + tir.load("float32", B, i) - D.data[i] = tir.load("float32", C, i) * 2.0 + B = T.allocate((32,), "float32", "global") + C = T.allocate((32,), "float32", "global") + B[i] = T.load("float32", A.data, i) + 1.0 + C[i] = T.load("float32", A.data, i) + T.load("float32", B, i) + D.data[i] = T.load("float32", C, i) * 2.0 def test_elementwise(): diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index a219b8d964573..8b109172ea097 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -17,8 +17,7 @@ import tvm import tvm.testing from tvm import te -from tvm import tir -from tvm.script import ty +from tvm.script import tir as T import numpy @@ -539,16 +538,16 @@ def test_simple_rfactor(): assert not tvm.ir.structural_equal(stmt1.body, stmt2.body) -@tvm.script.tir -def partitioned_concat(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) - A = tir.match_buffer(a, [16], dtype="float32") - B = tir.match_buffer(b, [16], dtype="float32") - C = tir.match_buffer(c, [32], dtype="float32") - for i in tir.serial(0, 16): - tir.store(C.data, i, tir.load("float32", A.data, i), True) - for i in tir.serial(0, 16): - tir.store(C.data, i + 16, tir.load("float32", B.data, i 
+ 16), True) +@T.prim_func +def partitioned_concat(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(a, [16], dtype="float32") + B = T.match_buffer(b, [16], dtype="float32") + C = T.match_buffer(c, [32], dtype="float32") + for i in T.serial(0, 16): + T.store(C.data, i, T.load("float32", A.data, i), True) + for i in T.serial(0, 16): + T.store(C.data, i + 16, T.load("float32", B.data, i + 16), True) def test_explicit_partition_hint(): diff --git a/tests/python/unittest/test_tir_transform_lower_init_block.py b/tests/python/unittest/test_tir_transform_lower_init_block.py index 1f8a4adf70547..c1c4fb3d2e8fa 100644 --- a/tests/python/unittest/test_tir_transform_lower_init_block.py +++ b/tests/python/unittest/test_tir_transform_lower_init_block.py @@ -16,73 +16,77 @@ # under the License. import tvm from tvm import tir, te -from tvm.script import ty +from tvm.script import tir as T # pylint: disable=no-self-argument -@tvm.script.tir +@tvm.script.ir_module class WithInit: - def main(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [64, 64, 64]) - B = tir.match_buffer(b, [64]) - - with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: - with tir.init(): - B[i] = tir.float32(0) + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [64, 64, 64]) + B = T.match_buffer(b, [64]) + + with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: + with T.init(): + B[i] = T.float32(0) B[i] += A[i, j, k] -@tvm.script.tir +@tvm.script.ir_module class WithBranch: - def main(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [64, 64, 64]) - B = tir.match_buffer(b, [64]) + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [64, 64, 64]) + B = T.match_buffer(b, [64]) - with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, 
k]: + with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: if (j == 0) and (k == 32): - B[i] = tir.float32(0) + B[i] = T.float32(0) B[i] += A[i, j, k] -@tvm.script.tir +@tvm.script.ir_module class InitWithMatchBuffer: - def main(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [64, 64, 64]) - B = tir.match_buffer(b, [64]) - - with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: - BB = tir.match_buffer(B[i], ()) - AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) - with tir.init(): - BB[()] = tir.float32(0) + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [64, 64, 64]) + B = T.match_buffer(b, [64]) + + with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: + BB = T.match_buffer(B[i], ()) + AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) + with T.init(): + BB[()] = T.float32(0) BB[()] += AA[j, k] -@tvm.script.tir +@tvm.script.ir_module class BranchWithMatchBuffer: - def main(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [64, 64, 64]) - B = tir.match_buffer(b, [64]) - - with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]: - BB = tir.match_buffer(B[i], ()) - AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64)) + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [64, 64, 64]) + B = T.match_buffer(b, [64]) + + with T.block([64, T.reduce_axis(0, 64), T.reduce_axis(32, 64)]) as [i, j, k]: + BB = T.match_buffer(B[i], ()) + AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64)) if (j == 0) and (k == 32): - BB[()] = tir.float32(0) + BB[()] = T.float32(0) BB[()] += AA[j, k] def test_lower_reduction(): - origin_mod = WithInit() + origin_mod = WithInit mod = tvm.tir.transform.LowerInitBlock()(origin_mod) - tvm.ir.assert_structural_equal(mod, WithBranch(), True) + tvm.ir.assert_structural_equal(mod, WithBranch, True) def test_lower_match_buffer(): - origin_mod = InitWithMatchBuffer() 
+ origin_mod = InitWithMatchBuffer mod = tvm.tir.transform.LowerInitBlock()(origin_mod) - tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer(), True) + tvm.ir.assert_structural_equal(mod, BranchWithMatchBuffer, True) def test_lower_te(): diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py index 07140ab458e60..e55555305a09c 100644 --- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm import tir, te -from tvm.script import ty +from tvm.script import tir as T def _check(original, transformed): @@ -26,130 +26,130 @@ def _check(original, transformed): tvm.ir.assert_structural_equal(mod["main"], transformed) -@tvm.script.tir -def element_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16)) - C = tir.match_buffer(c, (16, 16)) - B = tir.alloc_buffer((16, 16)) +@T.prim_func +def element_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16)) + C = T.match_buffer(c, (16, 16)) + B = T.alloc_buffer((16, 16)) for i_0 in range(0, 16): for j_0 in range(0, 16): - with tir.block([16, 16]) as [i, j]: + with T.block([16, 16]) as [i, j]: B[i, j] = A[i, j] + 1.0 for j_0 in range(0, 16): - with tir.block([16, 16]) as [i, j]: + with T.block([16, 16]) as [i, j]: C[i, j] = B[i, j] * 2.0 -@tvm.script.tir -def transformed_element_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [16, 16]) - C = tir.match_buffer(c, [16, 16]) +@T.prim_func +def transformed_element_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16]) + C = T.match_buffer(c, [16, 16]) for i_0 in range(0, 16): - with tir.block([]): - tir.reads([A[i_0, 0:16]]) - tir.writes([C[i_0, 0:16]]) - B = tir.alloc_buffer([16, 16]) - for j_0 in 
tir.serial(0, 16): - with tir.block([16, 16], "") as [i, j]: - tir.bind(i, i_0) - tir.bind(j, j_0) + with T.block([]): + T.reads([A[i_0, 0:16]]) + T.writes([C[i_0, 0:16]]) + B = T.alloc_buffer([16, 16]) + for j_0 in T.serial(0, 16): + with T.block([16, 16], "") as [i, j]: + T.bind(i, i_0) + T.bind(j, j_0) B[i, j] = A[i, j] + 1.0 - for j_0 in tir.serial(0, 16): - with tir.block([16, 16], "") as [i, j]: - tir.bind(i, i_0) - tir.bind(j, j_0) + for j_0 in T.serial(0, 16): + with T.block([16, 16], "") as [i, j]: + T.bind(i, i_0) + T.bind(j, j_0) C[i, j] = B[i, j] * 2.0 -@tvm.script.tir +@T.prim_func def original_func() -> None: - A = tir.alloc_buffer((128, 128), "float32") - with tir.block([128, 128]) as [i, j]: - A[i, j] = tir.float32(0) - with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]: - B = tir.alloc_buffer((128, 128), "float32") - C = tir.alloc_buffer((128, 128), "float32") - D = tir.alloc_buffer((128, 128), "float32") + A = T.alloc_buffer((128, 128), "float32") + with T.block([128, 128]) as [i, j]: + A[i, j] = T.float32(0) + with T.block([32, 32, T.reduce_axis(0, 32)]) as [i, j, k]: + B = T.alloc_buffer((128, 128), "float32") + C = T.alloc_buffer((128, 128), "float32") + D = T.alloc_buffer((128, 128), "float32") if k == 0: - for ii, jj in tir.grid(4, 4): + for ii, jj in T.grid(4, 4): B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in tir.grid(4, 4): + for ii, jj in T.grid(4, 4): for kk in range(0, 4): B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] for kk in range(0, 4): B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] -@tvm.script.tir +@T.prim_func def transformed_func() -> None: - A = tir.alloc_buffer([128, 128]) - with tir.block([128, 128], "") as [i, j]: - A[i, j] = tir.float32(0) - with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]: - B = tir.alloc_buffer([128, 128]) + A = T.alloc_buffer([128, 128]) + with T.block([128, 128], "") as [i, j]: + A[i, j] = T.float32(0) + with 
T.block([32, 32, T.reduce_axis(0, 32)], "") as [i, j, k]: + B = T.alloc_buffer([128, 128]) if k == 0: - for ii, jj in tir.grid(4, 4): + for ii, jj in T.grid(4, 4): B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] - for ii, jj in tir.grid(4, 4): - with tir.block([], ""): - tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) - tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) - C = tir.alloc_buffer([128, 128]) - for kk in tir.serial(0, 4): + for ii, jj in T.grid(4, 4): + with T.block([], ""): + T.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) + T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + C = T.alloc_buffer([128, 128]) + for kk in T.serial(0, 4): B[((i * 4) + ii), ((j * 4) + jj)] = ( B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] ) - for kk in tir.serial(0, 4): - with tir.block([], ""): - tir.reads( + for kk in T.serial(0, 4): + with T.block([], ""): + T.reads( [ B[((i * 4) + ii), ((j * 4) + jj)], C[((i * 4) + ii), ((k * 4) + kk)], ] ) - tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) - D = tir.alloc_buffer([128, 128]) + T.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + D = T.alloc_buffer([128, 128]) B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + ( D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)] ) -@tvm.script.tir +@T.prim_func def match_buffer_func() -> None: - C = tir.alloc_buffer((128, 128)) - with tir.block([128]) as [vi]: - C0 = tir.match_buffer(C[vi, 0:128], (128)) - with tir.block([128]) as [jj]: - C1 = tir.match_buffer(C0[jj], ()) + C = T.alloc_buffer((128, 128)) + with T.block([128]) as [vi]: + C0 = T.match_buffer(C[vi, 0:128], (128)) + with T.block([128]) as [jj]: + C1 = T.match_buffer(C0[jj], ()) C1[()] = 0 -@tvm.script.tir +@T.prim_func def transformed_match_buffer_func() -> None: for i in range(0, 128): - with tir.block([128]) as [vi]: - tir.bind(vi, i) - C = tir.alloc_buffer((128, 128)) - C0 = tir.match_buffer(C[vi, 0:128], (128)) - with tir.block([128]) as [jj]: - C1 = 
tir.match_buffer(C0[jj], ()) + with T.block([128]) as [vi]: + T.bind(vi, i) + C = T.alloc_buffer((128, 128)) + C0 = T.match_buffer(C[vi, 0:128], (128)) + with T.block([128]) as [jj]: + C1 = T.match_buffer(C0[jj], ()) C1[()] = 0 -@tvm.script.tir -def opaque_access(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [1024]) - B = tir.match_buffer(b, [1024]) - A_cache = tir.alloc_buffer([1024]) - for i in tir.serial(0, 8): - with tir.block([8]) as [vi]: - with tir.block([8]) as [v]: - tir.bind(v, vi) - tir.reads([A[(v * 128) : ((v * 128) + 128)]]) - tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) - tir.evaluate( - tir.call_extern( +@T.prim_func +def opaque_access(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024]) + B = T.match_buffer(b, [1024]) + A_cache = T.alloc_buffer([1024]) + for i in T.serial(0, 8): + with T.block([8]) as [vi]: + with T.block([8]) as [v]: + T.bind(v, vi) + T.reads([A[(v * 128) : ((v * 128) + 128)]]) + T.writes([A_cache[(v * 128) : ((v * 128) + 128)]]) + T.evaluate( + T.call_extern( "test", A_cache.data, (v * 128), @@ -160,37 +160,37 @@ def opaque_access(a: ty.handle, b: ty.handle) -> None: dtype="float32", ) ) - for j in tir.serial(0, 128): - with tir.block([1024]) as [v]: - tir.bind(v, ((vi * 128) + j)) - tir.reads([A_cache[v]]) - tir.writes([B[v]]) + for j in T.serial(0, 128): + with T.block([1024]) as [v]: + T.bind(v, ((vi * 128) + j)) + T.reads([A_cache[v]]) + T.writes([B[v]]) B[v] = A_cache[v] -@tvm.script.tir -def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [1024]) - B = tir.match_buffer(b, [1024]) - for i in tir.serial(0, 8): - with tir.block([8]) as [vi]: - tir.reads(A[vi * 128 : vi * 128 + 128]) - tir.writes(B[vi * 128 : vi * 128 + 128]) - A_cache = tir.alloc_buffer([1024]) - with tir.block([8]) as [v]: - tir.bind(v, vi) - tir.reads([A[v * 128 : v * 128 + 128]]) - tir.writes([A_cache[v * 128 : v * 128 + 128]]) - tir.evaluate( - tir.call_extern( +@T.prim_func 
+def transformed_opaque_access(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024]) + B = T.match_buffer(b, [1024]) + for i in T.serial(0, 8): + with T.block([8]) as [vi]: + T.reads(A[vi * 128 : vi * 128 + 128]) + T.writes(B[vi * 128 : vi * 128 + 128]) + A_cache = T.alloc_buffer([1024]) + with T.block([8]) as [v]: + T.bind(v, vi) + T.reads([A[v * 128 : v * 128 + 128]]) + T.writes([A_cache[v * 128 : v * 128 + 128]]) + T.evaluate( + T.call_extern( "test", A_cache.data, v * 128, 128, A.data, v * 128, 128, dtype="float32" ) ) - for j in tir.serial(0, 128): - with tir.block([1024]) as [v]: - tir.bind(v, ((vi * 128) + j)) - tir.reads([A_cache[v]]) - tir.writes([B[v]]) + for j in T.serial(0, 128): + with T.block([1024]) as [v]: + T.bind(v, ((vi * 128) + j)) + T.reads([A_cache[v]]) + T.writes([B[v]]) B[v] = A_cache[v] diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index b57fa6c417b23..37223493a8b55 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -16,7 +16,7 @@ # under the License. 
import tvm from tvm import te -from tvm.script import ty +from tvm.script import tir as T from tvm.relay import GlobalVar @@ -134,10 +134,10 @@ def count_sync(op): assert count[0] == 4 -@tvm.script.tir -def tir_func(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, [2, 2]) - B = tir.match_buffer(a, [2, 2]) +@T.prim_func +def tir_func(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [2, 2]) + B = T.match_buffer(a, [2, 2]) A[0, 1] = B[1, 1] diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py index 8e0b6dc804aa7..1ce9b0cacd29d 100644 --- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -16,8 +16,8 @@ # under the License. import pytest import tvm -from tvm import tir, te -from tvm.script import ty +from tvm import te +from tvm.script import tir as T def _check(original, transformed): @@ -33,159 +33,159 @@ def _check_fail(original): tvm.tir.transform.UnifyThreadBinding()(mod) -@tvm.script.tir -def element_wise_thread_x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - j1_0 = tir.env_thread("threadIdx.x") - j0_0 = tir.env_thread("threadIdx.x") - i = tir.env_thread("blockIdx.x") - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - tir.launch_thread(i, 128) - with tir.launch_thread(j0_0, 4): - for j0_1 in tir.serial(0, 32): - tir.store( +@T.prim_func +def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: + j1_0 = T.env_thread("threadIdx.x") + j0_0 = T.env_thread("threadIdx.x") + i = T.env_thread("blockIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + T.launch_thread(i, 128) + with T.launch_thread(j0_0, 4): + for j0_1 in T.serial(0, 32): + T.store( B.data, i * 128 + j0_0 * 32 + j0_1, - tir.load("float32", A.data, i * 
128 + j0_0 * 32 + j0_1) * 2.0, + T.load("float32", A.data, i * 128 + j0_0 * 32 + j0_1) * 2.0, True, ) - tir.launch_thread(j1_0, 4) - for j1_1 in tir.serial(0, 32): - tir.store( + T.launch_thread(j1_0, 4) + for j1_1 in T.serial(0, 32): + T.store( C.data, i * 128 + j1_0 * 32 + j1_1, - tir.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0, + T.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0, True, ) -@tvm.script.tir -def unified_element_wise_thread_x(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - thread_x = tir.env_thread("threadIdx.x") - block_x = tir.env_thread("blockIdx.x") - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - tir.launch_thread(block_x, 128) - with tir.launch_thread(thread_x, 4): - for j0_1 in tir.serial(0, 32): - tir.store( +@T.prim_func +def unified_element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None: + thread_x = T.env_thread("threadIdx.x") + block_x = T.env_thread("blockIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + T.launch_thread(block_x, 128) + with T.launch_thread(thread_x, 4): + for j0_1 in T.serial(0, 32): + T.store( B.data, block_x * 128 + thread_x * 32 + j0_1, - tir.load("float32", A.data, block_x * 128 + thread_x * 32 + j0_1) * 2.0, + T.load("float32", A.data, block_x * 128 + thread_x * 32 + j0_1) * 2.0, True, ) - tir.launch_thread(thread_x, 4) - for j1_1 in tir.serial(0, 32): - tir.store( + T.launch_thread(thread_x, 4) + for j1_1 in T.serial(0, 32): + T.store( C.data, block_x * 128 + thread_x * 32 + j1_1, - tir.load("float32", A.data, block_x * 128 + thread_x * 32 + j1_1) + 1.0, + T.load("float32", A.data, block_x * 128 + thread_x * 32 + j1_1) + 1.0, True, ) -@tvm.script.tir -def element_wise_vthread_x(a: ty.handle, b: ty.handle) -> None: - i_0 = tir.env_thread("vthread.x") - i_1 = tir.env_thread("threadIdx.x") - j_0 = tir.env_thread("vthread.x") - A = 
tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - tir.launch_thread(i_0, 2) - tir.launch_thread(i_1, 64) - tir.launch_thread(j_0, 2) - for j_1 in tir.serial(0, 64): - tir.store( +@T.prim_func +def element_wise_vthread_x(a: T.handle, b: T.handle) -> None: + i_0 = T.env_thread("vthread.x") + i_1 = T.env_thread("threadIdx.x") + j_0 = T.env_thread("vthread.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + T.launch_thread(i_0, 2) + T.launch_thread(i_1, 64) + T.launch_thread(j_0, 2) + for j_1 in T.serial(0, 64): + T.store( B.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1, - tir.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1) * 2.0, + T.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1) * 2.0, True, ) -@tvm.script.tir -def unified_element_wise_vthread_x(a: ty.handle, b: ty.handle) -> None: - vthread_x = tir.env_thread("vthread.x") - thread_x = tir.env_thread("threadIdx.x") - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - tir.launch_thread(vthread_x, 2) - tir.launch_thread(thread_x, 64) - tir.launch_thread(vthread_x, 2) - for j_1 in tir.serial(0, 64): - tir.store( +@T.prim_func +def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None: + vthread_x = T.env_thread("vthread.x") + thread_x = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + T.launch_thread(vthread_x, 2) + T.launch_thread(thread_x, 64) + T.launch_thread(vthread_x, 2) + for j_1 in T.serial(0, 64): + T.store( B.data, vthread_x * 8256 + thread_x * 128 + j_1, - tir.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1) * 2.0, + T.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1) * 2.0, True, ) -@tvm.script.tir +@T.prim_func def element_wise_two_thread_x_in_same_kernel_not_equal( - a: ty.handle, b: ty.handle, c: ty.handle + a: T.handle, b: T.handle, c: T.handle ) -> None: - i = tir.env_thread("blockIdx.x") - j0 = 
tir.env_thread("threadIdx.x") - j1 = tir.env_thread("threadIdx.x") - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 64]) - tir.launch_thread(i, 128) - with tir.launch_thread(j0, 128): - tir.store(B.data, i * 64 + j0, tir.load("float32", A.data, i * 128 + j0) * 2.0, True) - tir.launch_thread(j1, 64) - tir.store(C.data, i * 64 + j1, tir.load("float32", A.data, i * 128 + j1) + 1.0, True) - - -@tvm.script.tir + i = T.env_thread("blockIdx.x") + j0 = T.env_thread("threadIdx.x") + j1 = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 64]) + T.launch_thread(i, 128) + with T.launch_thread(j0, 128): + T.store(B.data, i * 64 + j0, T.load("float32", A.data, i * 128 + j0) * 2.0, True) + T.launch_thread(j1, 64) + T.store(C.data, i * 64 + j1, T.load("float32", A.data, i * 128 + j1) + 1.0, True) + + +@T.prim_func def element_wise_kernels_with_different_size( - a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle + a: T.handle, b: T.handle, c: T.handle, d: T.handle ) -> None: - i0 = tir.env_thread("blockIdx.x") - j0 = tir.env_thread("threadIdx.x") - i1 = tir.env_thread("blockIdx.x") - j1 = tir.env_thread("threadIdx.x") - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [256, 256]) - D = tir.match_buffer(d, [256, 256]) - with tir.launch_thread(i0, 128): - tir.launch_thread(j0, 128) - tir.store(B.data, i0 * 128 + j0, tir.load("float32", A.data, i0 * 128 + j0) * 2.0, True) - tir.launch_thread(i1, 256) - tir.launch_thread(j1, 256) - tir.store(D.data, i1 * 256 + j1, tir.load("float32", C.data, i1 * 256 + j1) + 1.0, True) - - -@tvm.script.tir + i0 = T.env_thread("blockIdx.x") + j0 = T.env_thread("threadIdx.x") + i1 = T.env_thread("blockIdx.x") + j1 = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [256, 256]) + D = 
T.match_buffer(d, [256, 256]) + with T.launch_thread(i0, 128): + T.launch_thread(j0, 128) + T.store(B.data, i0 * 128 + j0, T.load("float32", A.data, i0 * 128 + j0) * 2.0, True) + T.launch_thread(i1, 256) + T.launch_thread(j1, 256) + T.store(D.data, i1 * 256 + j1, T.load("float32", C.data, i1 * 256 + j1) + 1.0, True) + + +@T.prim_func def unified_element_wise_kernels_with_different_size( - a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle + a: T.handle, b: T.handle, c: T.handle, d: T.handle ) -> None: - block_x = tir.env_thread("blockIdx.x") - thread_x = tir.env_thread("threadIdx.x") - block_x_1 = tir.env_thread("blockIdx.x") - thread_x_1 = tir.env_thread("threadIdx.x") - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [256, 256]) - D = tir.match_buffer(d, [256, 256]) - with tir.launch_thread(block_x, 128): - tir.launch_thread(thread_x, 128) - tir.store( + block_x = T.env_thread("blockIdx.x") + thread_x = T.env_thread("threadIdx.x") + block_x_1 = T.env_thread("blockIdx.x") + thread_x_1 = T.env_thread("threadIdx.x") + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [256, 256]) + D = T.match_buffer(d, [256, 256]) + with T.launch_thread(block_x, 128): + T.launch_thread(thread_x, 128) + T.store( B.data, block_x * 128 + thread_x, - tir.load("float32", A.data, block_x * 128 + thread_x) * 2.0, + T.load("float32", A.data, block_x * 128 + thread_x) * 2.0, True, ) - tir.launch_thread(block_x_1, 256) - tir.launch_thread(thread_x_1, 256) - tir.store( + T.launch_thread(block_x_1, 256) + T.launch_thread(thread_x_1, 256) + T.store( D.data, block_x_1 * 256 + thread_x_1, - tir.load("float32", C.data, block_x_1 * 256 + thread_x_1) + 1.0, + T.load("float32", C.data, block_x_1 * 256 + thread_x_1) + 1.0, True, ) diff --git a/tests/python/unittest/test_tvmscript_complete.py b/tests/python/unittest/test_tvmscript_complete.py index 4798e9e098655..7c521db21bb84 100644 --- 
a/tests/python/unittest/test_tvmscript_complete.py +++ b/tests/python/unittest/test_tvmscript_complete.py @@ -16,85 +16,83 @@ # under the License. import tvm -from tvm import tir from tvm.ir import Range -from tvm.script import ty, from_source -from tvm.ir.diagnostics import override_renderer +from tvm.script import tir as T -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with tir.init(): - C[vi, vj] = tir.float32(0) + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def matmul_original(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) +@T.prim_func +def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) - for i, j in tir.grid(32, 32): - with tir.block([32, 32], "init") as [vi, vj]: - for ii, jj in tir.grid(4, 4): - C[vi * 4 + ii, vj * 4 + jj] = tir.float32(0) + for i, j in T.grid(32, 32): + with T.block([32, 32], "init") as [vi, vj]: + for ii, jj in T.grid(4, 4): + C[vi * 4 + ii, vj * 4 + jj] = T.float32(0) for k in range(0, 32): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - for ii, jj, kk in tir.grid(4, 4, 4): + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + for ii, jj, kk in T.grid(4, 4, 4): C[vi * 4 + ii, vj * 4 + jj] = ( C[vi * 
4 + ii, vj * 4 + jj] + A[vi * 4 + ii, vk * 4 + kk] * B[vj * 4 + jj, vk * 4 + kk] ) -@tvm.script.tir -def elementwise_with_root(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) +@T.prim_func +def elementwise_with_root(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) - with tir.block([]) as []: - with tir.block([128, 128]) as [vi, vj]: - B[vi, vj] = A[vi, vj] + tir.float32(1) + with T.block([]) as []: + with T.block([128, 128]) as [vi, vj]: + B[vi, vj] = A[vi, vj] + T.float32(1) - with tir.block([128, 128]) as [vi, vj]: - C[vi, vj] = B[vi, vj] + tir.float32(1) + with T.block([128, 128]) as [vi, vj]: + C[vi, vj] = B[vi, vj] + T.float32(1) -def func_with_opaque_block(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) +def func_with_opaque_block(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) - with tir.block([]) as []: - with tir.block([]) as []: - B[0, 0] = A[0, 0] + tir.float32(1) + with T.block([]) as []: + with T.block([]) as []: + B[0, 0] = A[0, 0] + T.float32(1) - with tir.block([128, 128]) as [vi, vj]: - C[vi, vj] = B[vi, vj] + tir.float32(1) + with T.block([128, 128]) as [vi, vj]: + C[vi, vj] = B[vi, vj] + T.float32(1) -@tvm.script.tir -def func_with_part_access_region(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) +@T.prim_func +def func_with_part_access_region(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 
128]) - with tir.block([]) as []: - with tir.block([128, 128]) as [vi, vj]: - tir.reads(A[vi, vj]) - B[vi, vj] = A[vi, vj] + tir.float32(1) + with T.block([]) as []: + with T.block([128, 128]) as [vi, vj]: + T.reads(A[vi, vj]) + B[vi, vj] = A[vi, vj] + T.float32(1) - with tir.block([128, 128]) as [vi, vj]: - tir.writes(C[vi, vj]) - C[vi, vj] = B[vi, vj] + tir.float32(1) + with T.block([128, 128]) as [vi, vj]: + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + T.float32(1) def test_complete_matmul(): @@ -104,9 +102,9 @@ def test_complete_matmul(): block = func.body.block.body.body.body.body.block assert isinstance(block, tvm.tir.Block) vi, vj, vk = [x.var for x in block.iter_vars] - access_A = tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vk, 1)]) - access_B = tir.BufferRegion(B, [Range.from_min_extent(vj, 1), Range.from_min_extent(vk, 1)]) - access_C = tir.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)]) + access_A = tvm.tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vk, 1)]) + access_B = tvm.tir.BufferRegion(B, [Range.from_min_extent(vj, 1), Range.from_min_extent(vk, 1)]) + access_C = tvm.tir.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)]) tvm.ir.assert_structural_equal(block.reads, [access_C, access_A, access_B]) tvm.ir.assert_structural_equal(block.writes, [access_C]) @@ -118,7 +116,7 @@ def test_complete_matmul_original(): block1 = func.body.block.body.body.body[0].block assert isinstance(block1, tvm.tir.Block) vi, vj = [x.var for x in block1.iter_vars] - access_C = tir.BufferRegion( + access_C = tvm.tir.BufferRegion( C, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vj * 4, 4)] ) tvm.ir.assert_structural_equal(block1.reads, []) @@ -127,13 +125,13 @@ def test_complete_matmul_original(): block2 = func.body.block.body.body.body[1].body.block assert isinstance(block2, tvm.tir.Block) vi, vj, vk = [x.var for x in block2.iter_vars] - access_A = 
tir.BufferRegion( + access_A = tvm.tir.BufferRegion( A, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vk * 4, 4)] ) - access_B = tir.BufferRegion( + access_B = tvm.tir.BufferRegion( B, [Range.from_min_extent(vj * 4, 4), Range.from_min_extent(vk * 4, 4)] ) - access_C = tir.BufferRegion( + access_C = tvm.tir.BufferRegion( C, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vj * 4, 4)] ) tvm.ir.assert_structural_equal(block2.reads, [access_C, access_A, access_B]) @@ -149,11 +147,11 @@ def _check_elementwise(func): tvm.ir.assert_structural_equal( block1.reads, - [tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], + [tvm.tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) tvm.ir.assert_structural_equal( block1.writes, - [tir.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], + [tvm.tir.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) block2 = func.body.block.body[1].body.body.block @@ -161,11 +159,11 @@ def _check_elementwise(func): vi, vj = [x.var for x in block2.iter_vars] tvm.ir.assert_structural_equal( block2.reads, - [tir.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], + [tvm.tir.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) tvm.ir.assert_structural_equal( block2.writes, - [tir.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], + [tvm.tir.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) @@ -177,100 +175,96 @@ def test_complete_part_region(): _check_elementwise(func_with_part_access_region) -@tvm.script.tir -def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: - data_buf = tir.match_buffer(data, (16, 16), "float32") - index_buf = tir.match_buffer(index, (1,), "int32") - out_buf = tir.alloc_buffer((16, 16), "float32") +@T.prim_func +def 
func_with_bufferslice_indices(data: T.handle, index: T.handle) -> None: + data_buf = T.match_buffer(data, (16, 16), "float32") + index_buf = T.match_buffer(index, (1,), "int32") + out_buf = T.alloc_buffer((16, 16), "float32") - with tir.block([16, 16]) as [vi, vj]: + with T.block([16, 16]) as [vi, vj]: out_buf[vi, vj] = data_buf[vi, index_buf[0]] -@tvm.script.tir -def expected_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: - index_buf = tir.match_buffer( - index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1 - ) - data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) - with tir.block([], "root"): - tir.reads([]) - tir.writes([]) - out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) - for i0, i1 in tir.grid(16, 16): - with tir.block([16, 16], "") as [vi, vj]: - tir.bind(vi, i0) - tir.bind(vj, i1) - tir.reads([data_buf[vi, 0:16], index_buf[0]]) - tir.writes([out_buf[vi, vj]]) +@T.prim_func +def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None: + index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) + data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) + with T.block([], "root"): + T.reads([]) + T.writes([]) + out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) + for i0, i1 in T.grid(16, 16): + with T.block([16, 16], "") as [vi, vj]: + T.bind(vi, i0) + T.bind(vj, i1) + T.reads([data_buf[vi, 0:16], index_buf[0]]) + T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[vi, index_buf[0]] -@tvm.script.tir -def func_with_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: - data_buf = tir.match_buffer(data, (16, 16), "float32") - index_buf = tir.match_buffer(index, (1,), "int32") - out_buf = tir.alloc_buffer((16, 16), "float32") +@T.prim_func +def func_with_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None: + 
data_buf = T.match_buffer(data, (16, 16), "float32") + index_buf = T.match_buffer(index, (1,), "int32") + out_buf = T.alloc_buffer((16, 16), "float32") - with tir.block([16, 16]) as [vi, vj]: + with T.block([16, 16]) as [vi, vj]: out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] -@tvm.script.tir -def expected_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None: - index_buf = tir.match_buffer( - index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1 - ) - data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) - with tir.block([], "root"): - tir.reads([]) - tir.writes([]) - out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) - for i0, i1 in tir.grid(16, 16): - with tir.block([16, 16], "") as [vi, vj]: - tir.bind(vi, i0) - tir.bind(vj, i1) - tir.reads([data_buf[0:16, 0:16], index_buf[0]]) - tir.writes([out_buf[vi, vj]]) +@T.prim_func +def expected_recursive_bufferslice_indices(data: T.handle, index: T.handle) -> None: + index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1) + data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1) + with T.block([], "root"): + T.reads([]) + T.writes([]) + out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1) + for i0, i1 in T.grid(16, 16): + with T.block([16, 16], "") as [vi, vj]: + T.bind(vi, i0) + T.bind(vj, i1) + T.reads([data_buf[0:16, 0:16], index_buf[0]]) + T.writes([out_buf[vi, vj]]) out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]] def test_complete_buffer_indices(): - new_func = tvm.script.from_source(tvm.script.asscript(func_with_bufferslice_indices)) + new_func = tvm.script.from_source(func_with_bufferslice_indices.script()) tvm.ir.assert_structural_equal(new_func, expected_bufferslice_indices) - new_func = tvm.script.from_source(tvm.script.asscript(func_with_recursive_bufferslice_indices)) + new_func = 
tvm.script.from_source(func_with_recursive_bufferslice_indices.script()) tvm.ir.assert_structural_equal(new_func, expected_recursive_bufferslice_indices) -@tvm.script.tir -def match_buffer_func(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16)) +@T.prim_func +def match_buffer_func(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with tir.block([]): - A0 = tir.match_buffer(A[i, 0:16], (16)) - with tir.block([]): + with T.block([]): + A0 = T.match_buffer(A[i, 0:16], (16)) + with T.block([]): for j in range(0, 16): - with tir.block([]) as []: - A1 = tir.match_buffer(A0[j], ()) + with T.block([]) as []: + A1 = T.match_buffer(A0[j], ()) A1[()] = 1.0 -@tvm.script.tir -def expected_match_buffer_func(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16)) +@T.prim_func +def expected_match_buffer_func(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16)) for i in range(0, 16): - with tir.block([]): - tir.reads([]) - tir.writes(A[i, 0:16]) - A0 = tir.match_buffer(A[i, 0:16], (16)) - with tir.block([]): - tir.reads([]) - tir.writes(A0[0:16]) + with T.block([]): + T.reads([]) + T.writes(A[i, 0:16]) + A0 = T.match_buffer(A[i, 0:16], (16)) + with T.block([]): + T.reads([]) + T.writes(A0[0:16]) for j in range(0, 16): - with tir.block([]) as []: - tir.reads([]) - tir.writes(A0[j]) - A1 = tir.match_buffer(A0[j], ()) + with T.block([]) as []: + T.reads([]) + T.writes(A0[j]) + A1 = T.match_buffer(A0[j], ()) A1[()] = 1.0 diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 70a2aea11293f..99a22636b9272 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -18,27 +18,26 @@ import pytest import sys import tvm -from tvm import tir -from tvm.script import ty, from_source +from tvm.script import tir as T from tvm.ir.diagnostics import override_renderer import inspect -def buffer_bind_missing_args(a: 
ty.handle) -> None: - A = tir.match_buffer((16, 16), "float32") # error +def buffer_bind_missing_args(a: T.handle) -> None: + A = T.match_buffer((16, 16), "float32") # error def test_buffer_bind(): check_error(buffer_bind_missing_args, 2) -def range_missing_args(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") +def range_missing_args(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") - tir.attr(A, "realize_scope", "") - tir.realize(A[0:16, 0:16], "") - for i in tir.serial(16): # error - for j in tir.serial(0, 16): + T.attr(A, "realize_scope", "") + T.realize(A[0:16, 0:16], "") + for i in T.serial(16): # error + for j in T.serial(0, 16): A[i, j] = 0.0 @@ -46,13 +45,13 @@ def test_range_missing_args(): check_error(range_missing_args, 6) -def undefined_buffer(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") +def undefined_buffer(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") - tir.attr(A, "realize_scope", "") - tir.realize(C[0:16, 0:16], "") # error - for i in tir.serial(16): - for j in tir.serial(0, 16): + T.attr(A, "realize_scope", "") + T.realize(C[0:16, 0:16], "") # error + for i in T.serial(16): + for j in T.serial(0, 16): A[i, j] = 0.0 @@ -60,7 +59,7 @@ def test_undefined_buffer(): check_error(undefined_buffer, 5) -def unsupported_stmt(a: ty.int32) -> None: +def unsupported_stmt(a: T.int32) -> None: if a > 0: print("I love tvm") # error @@ -69,13 +68,13 @@ def test_unsupported_stmt(): check_error(unsupported_stmt, 3) -def unsupported_function_call(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") +def unsupported_function_call(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") - tir.attr(A, "realize_scope", "") - tir.realize(A[0:16, 0:16], "") - for i in tir.const_range(16): # error - for j in tir.serial(0, 16): + T.attr(A, "realize_scope", "") + T.realize(A[0:16, 0:16], "") + for i in T.const_range(16): # error + for j in T.serial(0, 16): A[i, j] = 
0.0 @@ -84,7 +83,7 @@ def test_unsupported_function_call(): def missing_type_annotation(a) -> None: # error - tir.evaluate(0.0) + T.evaluate(0.0) def test_missing_type_annotation(): @@ -92,18 +91,18 @@ def test_missing_type_annotation(): def invalid_expr_stmt() -> None: - tir.max(1, 2) # error + T.max(1, 2) # error def test_invalid_expr_stmt(): check_error(invalid_expr_stmt, 2) -def invalid_for_function(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") +def invalid_for_function(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") - for i in tir.evaluate(0.0): # error - for j in tir.serial(0, 16): + for i in T.evaluate(0.0): # error + for j in T.serial(0, 16): A[i, j] = 0.0 @@ -111,36 +110,36 @@ def test_invalid_for_function(): check_error(invalid_for_function, 4) -def invalid_block_function(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") +def invalid_block_function(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") - with tir.evaluate(0.0): # error - tir.evaluate(1.0) + with T.evaluate(0.0): # error + T.evaluate(1.0) def test_invalid_block_function(): check_error(invalid_block_function, 4) -def return_not_allowed(a: ty.handle) -> None: - return tir.evaluate(0) # error +def return_not_allowed(a: T.handle) -> None: + return T.evaluate(0) # error def test_return_not_allowed(): check_error(return_not_allowed, 2) -def tir_assert(a: ty.handle) -> None: - tir.Assert(0, "") # error +def tir_assert(a: T.handle) -> None: + T.Assert(0, "") # error def test_tir_assert(): check_error(tir_assert, 2) -def no_body(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - tir.realize(A, "") # error +def no_body(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + T.realize(A, "") # error def test_no_body(): @@ -148,8 +147,8 @@ def test_no_body(): def allocate_with_buffers() -> None: - with tir.allocate([1], "float32", "") as [A, B]: # error - tir.evaluate(1.0) + with 
T.allocate([1], "float32", "") as [A, B]: # error + T.evaluate(1.0) def test_allocate_with_buffers(): @@ -157,18 +156,18 @@ def test_allocate_with_buffers(): def inconsistent_binding() -> None: - with tir.block([128, 128]) as [vi]: # error - tir.evaluate(1.0) + with T.block([128, 128]) as [vi]: # error + T.evaluate(1.0) def test_inconsistent_binding(): check_error(inconsistent_binding, 2) -def invalid_block_axes(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - with tir.block([A]) as [vi]: # error - tir.evaluate(1.0) +def invalid_block_axes(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + with T.block([A]) as [vi]: # error + T.evaluate(1.0) def test_invalid_block_axes(): @@ -176,9 +175,9 @@ def test_invalid_block_axes(): def miss_block_bind() -> None: - with tir.block([16, 16]) as [vi, vj]: # error - tir.bind(vi, 1) - tir.evaluate(1.0) + with T.block([16, 16]) as [vi, vj]: # error + T.bind(vi, 1) + T.evaluate(1.0) def test_miss_block_bind(): @@ -187,7 +186,7 @@ def test_miss_block_bind(): def invalid_loop_var() -> None: for i, j in range(0, 16): # error - tir.evaluate(1.0) + T.evaluate(1.0) def test_invalid_loop_var(): @@ -195,8 +194,8 @@ def test_invalid_loop_var(): def inconsistent_grid() -> None: - for i in tir.grid(16, 16): # error - tir.evaluate(1.0) + for i in T.grid(16, 16): # error + T.evaluate(1.0) def test_inconsistent_grid(): @@ -204,9 +203,9 @@ def test_inconsistent_grid(): def invalid_match_buffer_region() -> None: - with tir.block([16, 16]) as [vi, vj]: - A = tir.match_buffer(vi) # error - tir.evaluate(1.0) + with T.block([16, 16]) as [vi, vj]: + A = T.match_buffer(vi) # error + T.evaluate(1.0) def test_invalid_match_buffer_region(): @@ -214,10 +213,10 @@ def test_invalid_match_buffer_region(): def duplicate_buffer() -> None: - A = tir.alloc_buffer((128, 128), "float32") - with tir.block([16, 16]) as [vi, vj]: - A = tir.alloc_buffer((128, 128), "float32") # error - tir.evaluate(1.0) + A = T.alloc_buffer((128, 
128), "float32") + with T.block([16, 16]) as [vi, vj]: + A = T.alloc_buffer((128, 128), "float32") # error + T.evaluate(1.0) def test_duplicate_buffer(): @@ -225,39 +224,39 @@ def test_duplicate_buffer(): def duplicate_reads() -> None: - A = tir.alloc_buffer((128, 128), "float32") - with tir.block([16, 16]) as [vi, vj]: - tir.reads(A[0:8, 0:8]) - tir.reads(A[0:16, 0:16]) # error - tir.evaluate(1.0) + A = T.alloc_buffer((128, 128), "float32") + with T.block([16, 16]) as [vi, vj]: + T.reads(A[0:8, 0:8]) + T.reads(A[0:16, 0:16]) # error + T.evaluate(1.0) def duplicate_writes() -> None: - A = tir.alloc_buffer((128, 128), "float32") - with tir.block([16, 16]) as [vi, vj]: - tir.writes(A[0:8, 0:8]) - tir.writes(A[0:16, 0:16]) # error - tir.evaluate(1.0) + A = T.alloc_buffer((128, 128), "float32") + with T.block([16, 16]) as [vi, vj]: + T.writes(A[0:8, 0:8]) + T.writes(A[0:16, 0:16]) # error + T.evaluate(1.0) def duplicate_predicate() -> None: - with tir.block([16, 16]) as [vi, vj]: - tir.where(1) - tir.where(0) # error + with T.block([16, 16]) as [vi, vj]: + T.where(1) + T.where(0) # error def duplicate_annotations() -> None: - with tir.block([16, 16]) as [vi, vj]: - tir.block_attr({}) - tir.block_attr({}) # error + with T.block([16, 16]) as [vi, vj]: + T.block_attr({}) + T.block_attr({}) # error def duplicate_init() -> None: - with tir.block([16, 16]) as [vi, vj]: - with tir.init(): - tir.evaluate(1.0) - with tir.init(): # error - tir.evaluate(1.0) + with T.block([16, 16]) as [vi, vj]: + with T.init(): + T.evaluate(1.0) + with T.init(): # error + T.evaluate(1.0) def test_duplicate_block_signature(): @@ -268,10 +267,10 @@ def test_duplicate_block_signature(): check_error(duplicate_init, 5) -def opaque_access_during_complete(a: ty.handle) -> None: # error - A = tir.match_buffer(a, (16, 16), "float32") - with tir.block([16, 16]) as [vi, vj]: - tir.evaluate(tir.load("float32", A.data, vi * 16 + vj)) +def opaque_access_during_complete(a: T.handle) -> None: # error + A = 
T.match_buffer(a, (16, 16), "float32") + with T.block([16, 16]) as [vi, vj]: + T.evaluate(T.load("float32", A.data, vi * 16 + vj)) def test_opaque_access_during_complete(): @@ -279,8 +278,8 @@ def test_opaque_access_during_complete(): def convert_slice_to_bufferload() -> None: - A = tir.alloc_buffer((128, 128), "float32") - with tir.block([16, 16]) as [vi, vj]: + A = T.alloc_buffer((128, 128), "float32") + with T.block([16, 16]) as [vi, vj]: A[vi, vj] = A[vi : vi + 2, vj] + 1 # error @@ -289,16 +288,16 @@ def test_convert_slice_to_bufferload(): def error_index_type() -> None: - A = tir.alloc_buffer((128, 128), "float32") - with tir.block([16, 16]) as [vi, vj]: + A = T.alloc_buffer((128, 128), "float32") + with T.block([16, 16]) as [vi, vj]: A[vi, vj] = A[vi, 0.0] + 1 # error def error_bufferslice_index_type() -> None: - A = tir.alloc_buffer((1,), "float32") - B = tir.alloc_buffer((16, 16), "float32") - C = tir.alloc_buffer((16, 16), "float32") - with tir.block([16, 16]) as [vi, vj]: + A = T.alloc_buffer((1,), "float32") + B = T.alloc_buffer((16, 16), "float32") + C = T.alloc_buffer((16, 16), "float32") + with T.block([16, 16]) as [vi, vj]: C[vi, vj] = B[vi, A[0]] # error @@ -308,16 +307,16 @@ def test_error_index_type(): def error_index_with_stop() -> None: - A = tir.alloc_buffer((128, 128), "float32") - with tir.block([16, 16]) as [vi, vj]: + A = T.alloc_buffer((128, 128), "float32") + with T.block([16, 16]) as [vi, vj]: A[vi, vj] = A[vi, 1:10] + 1 # error def error_bufferslice_index_with_stop() -> None: - A = tir.alloc_buffer((1,), "int32") - B = tir.alloc_buffer((16, 16), "float32") - C = tir.alloc_buffer((16, 16), "float32") - with tir.block([16, 16]) as [vi, vj]: + A = T.alloc_buffer((1,), "int32") + B = T.alloc_buffer((16, 16), "float32") + C = T.alloc_buffer((16, 16), "float32") + with T.block([16, 16]) as [vi, vj]: C[vi, vj] = B[vi, A[0:1]] # error @@ -327,10 +326,10 @@ def test_error_index_with_stop_slice(): def mismatch_args() -> None: - A = 
tir.alloc_buffer((128, 128), "float32") - with tir.block([16, 16]) as [vi, vj]: - tir.reads(A[0, 0], A[1, 1]) # error - tir.evaluate(1.0) + A = T.alloc_buffer((128, 128), "float32") + with T.block([16, 16]) as [vi, vj]: + T.reads(A[0, 0], A[1, 1]) # error + T.evaluate(1.0) def test_mismatch_args(): @@ -338,24 +337,24 @@ def test_mismatch_args(): def special_stmt_except() -> None: - A = tir.alloc_buffer("(128, 128)", "float32") # error - with tir.block([16, 16]) as [vi, vj]: - tir.evaluate(1.0) + A = T.alloc_buffer("(128, 128)", "float32") # error + with T.block([16, 16]) as [vi, vj]: + T.evaluate(1.0) def scope_handler_except() -> None: - for i in tir.serial("1", "1"): # error - tir.evaluate(1) + for i in T.serial("1", "1"): # error + T.evaluate(1) -def intrin_except_unassign(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - tir.evaluate(A) # error +def intrin_except_unassign(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + T.evaluate(A) # error -def intrin_except_assign(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - A[0, 0] = tir.load(A, A, A) # error +def intrin_except_assign(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + A[0, 0] = T.load(A, A, A) # error def test_tvm_exception_catch(): @@ -366,13 +365,13 @@ def test_tvm_exception_catch(): check_error(intrin_except_assign, 3) -def buffer_shape_mismatch(a: ty.handle) -> None: - A = tir.match_buffer(a, (8, 8)) - for i, j in tir.grid(8, 2): - with tir.block([]): - tir.reads([]) - tir.writes([A[i, j * 4 : j * 4 + 4]]) - sub_A = tir.match_buffer( +def buffer_shape_mismatch(a: T.handle) -> None: + A = T.match_buffer(a, (8, 8)) + for i, j in T.grid(8, 2): + with T.block([]): + T.reads([]) + T.writes([A[i, j * 4 : j * 4 + 4]]) + sub_A = T.match_buffer( A[i, j * 4 : j * 4 + 4], (5) ) # error: shape mismatched between 4 and 5 for jj in range(0, 4): @@ -384,9 +383,9 @@ def test_match_buffer_shape_mismatch(): def high_dim_store() -> 
None: - with tir.block([], "root"): - B = tir.allocate([256], "float32", "global") - for i, j in tir.grid(16, 16): + with T.block([], "root"): + B = T.allocate([256], "float32", "global") + for i, j in T.grid(16, 16): B[i, j] = 1.0 # error: Store is only allowed with one index @@ -394,10 +393,8 @@ def test_high_dim_store(): check_error(high_dim_store, 5) -def check_error(module, rel_lineno): +def check_error(func, rel_lineno): # Override the default renderer to accumulate errors - _, start_line = inspect.getsourcelines(module) - lineno = start_line + rel_lineno - 1 errors = [] def render(e): @@ -407,14 +404,16 @@ def render(e): override_renderer(render) # The diagnostic context throws an exception when it gets an error try: - mod = from_source(module) + source_code = inspect.getsource(func) + source_code = "@T.prim_func\n" + source_code + tvm.script.from_source(source_code) except tvm.error.DiagnosticError as e: pass assert len(errors) == 1, errors for d in errors: assert ( - d.span.line == lineno - ), f"Expected error to be on line {lineno}, but it was on {d.span.line}" + d.span.line - 1 == rel_lineno + ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" if __name__ == "__main__": diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py index 016e7f7427f64..c55fd7b692823 100644 --- a/tests/python/unittest/test_tvmscript_ops.py +++ b/tests/python/unittest/test_tvmscript_ops.py @@ -16,44 +16,43 @@ # under the License. 
import tvm -from tvm.script import ty -from tvm import te, tir +from tvm.script import tir as T import numpy as np import tvm.testing -@tvm.script.tir +@T.prim_func def get_valid_counts( - data: ty.handle, - valid_count: ty.handle, - out: ty.handle, - out_indices: ty.handle, - score_threshold: ty.float32, - id_index: ty.int32, - score_index: ty.int32, + data: T.handle, + valid_count: T.handle, + out: T.handle, + out_indices: T.handle, + score_threshold: T.float32, + id_index: T.int32, + score_index: T.int32, ) -> None: - data_buf = tir.match_buffer(data, (1, 2500, 6), "float32") - valid_count_buf = tir.match_buffer(valid_count, (1,), "int32") - out_buf = tir.match_buffer(out, (1, 2500, 6), "float32") - out_indices_buf = tir.match_buffer(out_indices, (1, 2500), "int32") + data_buf = T.match_buffer(data, (1, 2500, 6), "float32") + valid_count_buf = T.match_buffer(valid_count, (1,), "int32") + out_buf = T.match_buffer(out, (1, 2500, 6), "float32") + out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32") - with tir.block([1], "init") as [vi]: - valid_count_buf[vi] = tir.int32(0) - with tir.block([2500], "update") as [vj]: - tir.reads([data_buf[vi, vj, 6]]) - tir.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) + with T.block([1], "init") as [vi]: + valid_count_buf[vi] = T.int32(0) + with T.block([2500], "update") as [vj]: + T.reads([data_buf[vi, vj, 6]]) + T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]]) if (data_buf[vi, vj, score_index] > score_threshold) and ( - (id_index < 0) or (data_buf[vi, vj, id_index] >= tir.float32(0)) + (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0)) ): - for k in tir.serial(0, 6): + for k in T.serial(0, 6): out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k] out_indices_buf[vi, valid_count_buf[vi]] = vj valid_count_buf[vi] = valid_count_buf[vi] + 1 if vj >= valid_count_buf[vi]: - for k in tir.serial(0, 6): - out_buf[vi, vj, k] = tir.float32(-1) - 
out_indices_buf[vi, vj] = tir.int32(-1) + for k in T.serial(0, 6): + out_buf[vi, vj, k] = T.float32(-1) + out_indices_buf[vi, vj] = T.int32(-1) def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, score_index): @@ -81,7 +80,6 @@ def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, sco np_out3[i, j] = -1 in_data = tvm.nd.array(np_data, ctx) - score_threshold_data = tvm.nd.array(np.array([score_threshold], dtype=dtype), ctx) out1 = tvm.nd.array(np_out1, ctx) out2 = tvm.nd.array(np_out2, ctx) out3 = tvm.nd.array(np_out3, ctx) @@ -95,9 +93,9 @@ def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, sco def test_get_valid_counts_script_func(): device = "llvm" # check lowering - print(tvm.script.asscript(get_valid_counts)) - mod = tvm.script.create_module({"get_valid_counts": get_valid_counts}) - print(tvm.script.asscript(mod)) + print(get_valid_counts.script()) + mod = tvm.ir.IRModule({"get_valid_counts": get_valid_counts}) + print(mod.script()) # check building f = tvm.build(mod["get_valid_counts"], target=device) _check_get_valid_counts_with_numpy(f, (1, 2500, 6), 0.0, 0, 1) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 7c123afdc4d06..2a6c3c6faf73c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -20,145 +20,147 @@ import tvm from tvm import tir -from tvm.script import ty +from tvm.script import tir as T -@tvm.script.tir +@tvm.script.ir_module class Module1: - def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: + @T.prim_func + def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) # buffer definition - C_global = tir.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) - packedB 
= tir.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) - A_1 = tir.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = tir.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = tir.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_global = T.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) + packedB = T.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) + A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) # body - tir.realize(packedB[0:32, 0:1024, 0:32], "") - for x in tir.parallel(0, 32): - for y in tir.serial(0, 1024): - for z in tir.vectorized(0, 32): + T.realize(packedB[0:32, 0:1024, 0:32], "") + for x in T.parallel(0, 32): + for y in T.serial(0, 1024): + for z in T.vectorized(0, 32): packedB[x, y, z] = B_1[y, ((x * 32) + z)] - tir.realize(C_1[0:1024, 0:1024], "") - for x_outer in tir.parallel(0, 32): - for y_outer in tir.serial(0, 32): - tir.realize( + T.realize(C_1[0:1024, 0:1024], "") + for x_outer in T.parallel(0, 32): + for y_outer in T.serial(0, 32): + T.realize( C_global[ (x_outer * 32) : ((x_outer * 32) + 32), (y_outer * 32) : ((y_outer * 32) + 32), ], "global", ) - for x_c_init in tir.serial(0, 32): - for y_c_init in tir.vectorized(0, 32): + for x_c_init in T.serial(0, 32): + for y_c_init in T.vectorized(0, 32): C_global[ (x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32)) - ] = tir.float32(0) - for k_outer in tir.serial(0, 256): - for x_c in tir.serial(0, 32): - for k_inner in tir.unroll(0, 4): - for y_c in tir.vectorized(0, 32): + ] = T.float32(0) + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + for k_inner in T.unroll(0, 4): + for y_c in T.vectorized(0, 32): C_global[(x_c + (x_outer * 32)), 
(y_c + (y_outer * 32))] = C_global[ (x_c + (x_outer * 32)), (y_c + (y_outer * 32)) ] + ( A_1[(x_c + (x_outer * 32)), (k_inner + (k_outer * 4))] * packedB[ - tir.floordiv((y_c + (y_outer * 32)), 32), + T.floordiv((y_c + (y_outer * 32)), 32), (k_inner + (k_outer * 4)), - tir.floormod((y_c + (y_outer * 32)), 32), + T.floormod((y_c + (y_outer * 32)), 32), ] ) - for x_inner in tir.serial(0, 32): - for y_inner in tir.serial(0, 32): + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): C_1[(x_inner + (x_outer * 32)), (y_inner + (y_outer * 32))] = C_global[ (x_inner + (x_outer * 32)), (y_inner + (y_outer * 32)) ] def test_opt_gemm_normalize(): - mod = Module1() - rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True)) + mod = Module1 + rt_mod = tvm.script.from_source(mod.script(True)) tvm.ir.assert_structural_equal(mod, rt_mod, True) -@tvm.script.tir +@tvm.script.ir_module class Module2: - def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: + @T.prim_func + def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "mmult", "tir.noalias": True}) - A_1 = tir.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = tir.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = tir.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + T.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + A_1 = T.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = T.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) # body - packedB = tir.allocate([32768], "float32x32", "global") - for x in tir.parallel(0, 32): - for y in tir.serial(0, 1024): - tir.store( + packedB = T.allocate([32768], "float32x32", "global") + for x in T.parallel(0, 32): + for y in T.serial(0, 1024): + T.store( packedB, - tir.ramp(((x 
* 32768) + (y * 32)), 1, 32), - tir.load( + T.ramp(((x * 32768) + (y * 32)), 1, 32), + T.load( "float32x32", B_1.data, - tir.ramp(((y * 1024) + (x * 32)), 1, 32), - tir.broadcast(True, 32), + T.ramp(((y * 1024) + (x * 32)), 1, 32), + T.broadcast(True, 32), ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - for x_outer in tir.parallel(0, 32): - C_global = tir.allocate([1024], "float32", "global") - for y_outer in tir.serial(0, 32): - for x_c_init in tir.serial(0, 32): - tir.store( + for x_outer in T.parallel(0, 32): + C_global = T.allocate([1024], "float32", "global") + for y_outer in T.serial(0, 32): + for x_c_init in T.serial(0, 32): + T.store( C_global, - tir.ramp((x_c_init * 32), 1, 32), - tir.broadcast(tir.float32(0), 32), - tir.broadcast(True, 32), + T.ramp((x_c_init * 32), 1, 32), + T.broadcast(T.float32(0), 32), + T.broadcast(True, 32), ) - for k_outer in tir.serial(0, 256): - for x_c in tir.serial(0, 32): - tir.store( + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + T.store( C_global, - tir.ramp((x_c * 32), 1, 32), + T.ramp((x_c * 32), 1, 32), ( - tir.load( + T.load( "float32x32", C_global, - tir.ramp((x_c * 32), 1, 32), - tir.broadcast(True, 32), + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), ) + ( - tir.broadcast( - tir.load( + T.broadcast( + T.load( "float32", A_1.data, (((x_outer * 32768) + (x_c * 1024)) + (k_outer * 4)), ), 32, ) - * tir.load( + * T.load( "float32x32", packedB, - tir.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), - tir.broadcast(True, 32), + T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), + T.broadcast(True, 32), ) ) ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - tir.store( + T.store( C_global, - tir.ramp((x_c * 32), 1, 32), + T.ramp((x_c * 32), 1, 32), ( - tir.load( + T.load( "float32x32", C_global, - tir.ramp((x_c * 32), 1, 32), - tir.broadcast(True, 32), + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), ) + ( - tir.broadcast( - tir.load( + T.broadcast( + T.load( "float32", 
A_1.data, ( @@ -168,31 +170,29 @@ def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: ), 32, ) - * tir.load( + * T.load( "float32x32", packedB, - tir.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32 - ), - tir.broadcast(True, 32), + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32), + T.broadcast(True, 32), ) ) ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - tir.store( + T.store( C_global, - tir.ramp((x_c * 32), 1, 32), + T.ramp((x_c * 32), 1, 32), ( - tir.load( + T.load( "float32x32", C_global, - tir.ramp((x_c * 32), 1, 32), - tir.broadcast(True, 32), + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), ) + ( - tir.broadcast( - tir.load( + T.broadcast( + T.load( "float32", A_1.data, ( @@ -202,31 +202,29 @@ def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: ), 32, ) - * tir.load( + * T.load( "float32x32", packedB, - tir.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32 - ), - tir.broadcast(True, 32), + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32), + T.broadcast(True, 32), ) ) ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - tir.store( + T.store( C_global, - tir.ramp((x_c * 32), 1, 32), + T.ramp((x_c * 32), 1, 32), ( - tir.load( + T.load( "float32x32", C_global, - tir.ramp((x_c * 32), 1, 32), - tir.broadcast(True, 32), + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), ) + ( - tir.broadcast( - tir.load( + T.broadcast( + T.load( "float32", A_1.data, ( @@ -236,42 +234,41 @@ def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: ), 32, ) - * tir.load( + * T.load( "float32x32", packedB, - tir.ramp( - (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32 - ), - tir.broadcast(True, 32), + T.ramp((((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32), + T.broadcast(True, 32), ) ) ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - for x_inner in tir.serial(0, 32): - for y_inner in tir.serial(0, 32): + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): C_1.data[ 
((((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner) - ] = tir.load("float32", C_global, ((x_inner * 32) + y_inner)) + ] = T.load("float32", C_global, ((x_inner * 32) + y_inner)) def test_opt_gemm_lower(): - mod = Module2() - rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True)) + mod = Module2 + rt_mod = tvm.script.from_source(mod.script(True)) tvm.ir.assert_structural_equal(mod, rt_mod, True) -@tvm.script.tir +@tvm.script.ir_module class Module3: + @T.prim_func def mmult( - args: ty.handle, - arg_type_ids: ty.handle, - num_args: ty.int32, - out_ret_value: ty.handle, - out_ret_tcode: ty.handle, - ) -> ty.int32: + args: T.handle, + arg_type_ids: T.handle, + num_args: T.int32, + out_ret_value: T.handle, + out_ret_tcode: T.handle, + ) -> T.int32: # function attr dict - tir.func_attr( + T.func_attr( { "tir.noalias": True, "global_symbol": "mmult", @@ -280,29 +277,29 @@ def mmult( } ) # var definition - C_global = tir.buffer_var("float32", "global") - packedB = tir.buffer_var("float32", "global") + C_global = T.buffer_var("float32", "global") + packedB = T.buffer_var("float32", "global") # body assert num_args == 3, "mmult: num_args should be 3" - arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0) - arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: ty.int32 = tir.load("int32", arg_type_ids, 1) - arg2: ty.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: ty.int32 = tir.load("int32", arg_type_ids, 2) - A: ty.handle = tir.tvm_struct_get(arg0, 0, 1, dtype="handle") - tir.attr(A, "storage_alignment", 128) - arg0_shape: ty.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle") - arg0_strides: ty.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle") - dev_id: ty.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32") - B: ty.handle = tir.tvm_struct_get(arg1, 0, 1, dtype="handle") - tir.attr(B, "storage_alignment", 128) - 
arg1_shape: ty.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle") - arg1_strides: ty.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle") - C: ty.handle = tir.tvm_struct_get(arg2, 0, 1, dtype="handle") - tir.attr(C, "storage_alignment", 128) - arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle") - arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle") + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: T.int32 = T.load("int32", arg_type_ids, 0) + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: T.int32 = T.load("int32", arg_type_ids, 1) + arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: T.int32 = T.load("int32", arg_type_ids, 2) + A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + T.attr(A, "storage_alignment", 128) + arg0_shape: T.handle = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_strides: T.handle = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") + B: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + T.attr(B, "storage_alignment", 128) + arg1_shape: T.handle = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_strides: T.handle = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + C: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") + T.attr(C, "storage_alignment", 128) + arg2_shape: T.handle = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_strides: T.handle = T.tvm_struct_get(arg2, 0, 3, dtype="handle") assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( arg0_code == 4 ), "mmult: Expect arg[0] to be pointer" @@ -312,150 +309,136 @@ def mmult( assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( arg2_code == 4 ), "mmult: Expect arg[2] to be pointer" - assert 2 == tir.tvm_struct_get( - arg0, 0, 4, dtype="int32" - ), "arg0.ndim is expected to equal 2" - assert 2 == tir.tvm_struct_get( - arg0, 0, 4, 
dtype="int32" - ), "arg0.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 2" assert ( - (tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) - and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32)) + (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(32)) ) and ( - tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1) + T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) ), "arg0.dtype is expected to be float32" - assert 1024 == tir.cast( - tir.load("int64", arg0_shape, 0), "int32" + assert 1024 == T.cast( + T.load("int64", arg0_shape, 0), "int32" ), "Argument arg0.shape[0] has an unsatisfied constraint" - assert 1024 == tir.cast( - tir.load("int64", arg0_shape, 1), "int32" + assert 1024 == T.cast( + T.load("int64", arg0_shape, 1), "int32" ), "Argument arg0.shape[1] has an unsatisfied constraint" - if not (tir.isnullptr(arg0_strides, dtype="bool")): - assert (1 == tir.cast(tir.load("int64", arg0_strides, 1), "int32")) and ( - 1024 == tir.cast(tir.load("int64", arg0_strides, 0), "int32") + if not (T.isnullptr(arg0_strides, dtype="bool")): + assert (1 == T.cast(T.load("int64", arg0_strides, 1), "int32")) and ( + 1024 == T.cast(T.load("int64", arg0_strides, 0), "int32") ), "arg0.strides: expected to be compact array" - tir.evaluate(0) - assert tir.uint64(0) == tir.tvm_struct_get( + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( arg0, 0, 8, dtype="uint64" ), "Argument arg0.byte_offset has an unsatisfied constraint" - assert 1 == tir.tvm_struct_get( + assert 1 == T.tvm_struct_get( arg0, 0, 10, dtype="int32" ), "Argument arg0.device_type has an unsatisfied constraint" - assert 2 == tir.tvm_struct_get( - arg1, 0, 4, dtype="int32" - ), "arg1.ndim is expected to equal 2" - assert 2 == 
tir.tvm_struct_get( - arg1, 0, 4, dtype="int32" - ), "arg1.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 2" assert ( - (tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) - and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32)) + (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(32)) ) and ( - tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1) + T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) ), "arg1.dtype is expected to be float32" - assert 1024 == tir.cast( - tir.load("int64", arg1_shape, 0), "int32" + assert 1024 == T.cast( + T.load("int64", arg1_shape, 0), "int32" ), "Argument arg1.shape[0] has an unsatisfied constraint" - assert 1024 == tir.cast( - tir.load("int64", arg1_shape, 1), "int32" + assert 1024 == T.cast( + T.load("int64", arg1_shape, 1), "int32" ), "Argument arg1.shape[1] has an unsatisfied constraint" - if not (tir.isnullptr(arg1_strides, dtype="bool")): - assert (1 == tir.cast(tir.load("int64", arg1_strides, 1), "int32")) and ( - 1024 == tir.cast(tir.load("int64", arg1_strides, 0), "int32") + if not (T.isnullptr(arg1_strides, dtype="bool")): + assert (1 == T.cast(T.load("int64", arg1_strides, 1), "int32")) and ( + 1024 == T.cast(T.load("int64", arg1_strides, 0), "int32") ), "arg1.strides: expected to be compact array" - tir.evaluate(0) - assert tir.uint64(0) == tir.tvm_struct_get( + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( arg1, 0, 8, dtype="uint64" ), "Argument arg1.byte_offset has an unsatisfied constraint" - assert 1 == tir.tvm_struct_get( + assert 1 == T.tvm_struct_get( arg1, 0, 10, dtype="int32" ), "Argument arg1.device_type has an unsatisfied constraint" - assert dev_id == tir.tvm_struct_get( + assert dev_id == T.tvm_struct_get( arg1, 0, 9, 
dtype="int32" ), "Argument arg1.device_id has an unsatisfied constraint" - assert 2 == tir.tvm_struct_get( - arg2, 0, 4, dtype="int32" - ), "arg2.ndim is expected to equal 2" - assert 2 == tir.tvm_struct_get( - arg2, 0, 4, dtype="int32" - ), "arg2.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 2" + assert 2 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 2" assert ( - (tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) - and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32)) + (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) ) and ( - tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1) + T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) ), "arg2.dtype is expected to be float32" - assert 1024 == tir.cast( - tir.load("int64", arg2_shape, 0), "int32" + assert 1024 == T.cast( + T.load("int64", arg2_shape, 0), "int32" ), "Argument arg2.shape[0] has an unsatisfied constraint" - assert 1024 == tir.cast( - tir.load("int64", arg2_shape, 1), "int32" + assert 1024 == T.cast( + T.load("int64", arg2_shape, 1), "int32" ), "Argument arg2.shape[1] has an unsatisfied constraint" - if not (tir.isnullptr(arg2_strides, dtype="bool")): - assert (1 == tir.cast(tir.load("int64", arg2_strides, 1), "int32")) and ( - 1024 == tir.cast(tir.load("int64", arg2_strides, 0), "int32") + if not (T.isnullptr(arg2_strides, dtype="bool")): + assert (1 == T.cast(T.load("int64", arg2_strides, 1), "int32")) and ( + 1024 == T.cast(T.load("int64", arg2_strides, 0), "int32") ), "arg2.strides: expected to be compact array" - tir.evaluate(0) - assert tir.uint64(0) == tir.tvm_struct_get( + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( arg2, 0, 8, dtype="uint64" ), "Argument arg2.byte_offset has an unsatisfied constraint" - assert 1 == tir.tvm_struct_get( + assert 1 == 
T.tvm_struct_get( arg2, 0, 10, dtype="int32" ), "Argument arg2.device_type has an unsatisfied constraint" - assert dev_id == tir.tvm_struct_get( + assert dev_id == T.tvm_struct_get( arg2, 0, 9, dtype="int32" ), "Argument arg2.device_id has an unsatisfied constraint" - tir.attr(0, "compute_scope", "mmult_compute_") - tir.attr(packedB, "storage_scope", "global") - tir.attr(packedB, "storage_alignment", 128) - with tir.let( + T.attr(0, "compute_scope", "mmult_compute_") + T.attr(packedB, "storage_scope", "global") + T.attr(packedB, "storage_alignment", 128) + with T.let( packedB, - tir.TVMBackendAllocWorkspace(1, dev_id, tir.uint64(4194304), 2, 32, dtype="handle"), + T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), ): - if tir.isnullptr(packedB, dtype="bool"): - tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) - for x in tir.parallel(0, 32): - for y in tir.serial(0, 1024): - tir.store( + if T.isnullptr(packedB, dtype="bool"): + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + for x in T.parallel(0, 32): + for y in T.serial(0, 1024): + T.store( packedB, - tir.ramp(((x * 32768) + (y * 32)), 1, 32), - tir.load( + T.ramp(((x * 32768) + (y * 32)), 1, 32), + T.load( "float32x32", B, - tir.ramp(((y * 1024) + (x * 32)), 1, 32), - tir.broadcast(True, 32), + T.ramp(((y * 1024) + (x * 32)), 1, 32), + T.broadcast(True, 32), ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - for x_outer in tir.parallel(0, 32): - tir.attr(C_global, "storage_scope", "global") - tir.attr(C_global, "storage_alignment", 128) - with tir.let( + for x_outer in T.parallel(0, 32): + T.attr(C_global, "storage_scope", "global") + T.attr(C_global, "storage_alignment", 128) + with T.let( C_global, - tir.TVMBackendAllocWorkspace( - 1, dev_id, tir.uint64(4096), 2, 32, dtype="handle" - ), + T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4096), 2, 32, dtype="handle"), ): - if tir.isnullptr(C_global, dtype="bool"): - 
tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) - for y_outer in tir.serial(0, 32): - for x_c_init in tir.serial(0, 32): - tir.store( + if T.isnullptr(C_global, dtype="bool"): + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + for y_outer in T.serial(0, 32): + for x_c_init in T.serial(0, 32): + T.store( C_global, - tir.ramp((x_c_init * 32), 1, 32), - tir.broadcast(tir.float32(0), 32), - tir.broadcast(True, 32), + T.ramp((x_c_init * 32), 1, 32), + T.broadcast(T.float32(0), 32), + T.broadcast(True, 32), ) - for k_outer in tir.serial(0, 256): - for x_c in tir.serial(0, 32): - tir.store( + for k_outer in T.serial(0, 256): + for x_c in T.serial(0, 32): + T.store( C_global, - tir.ramp((x_c * 32), 1, 32), - tir.call_llvm_pure_intrin( - tir.uint32(97), - tir.uint32(3), - tir.broadcast( - tir.load( + T.ramp((x_c * 32), 1, 32), + T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + T.load( "float32", A, ( @@ -465,30 +448,30 @@ def mmult( ), 32, ), - tir.load( + T.load( "float32x32", packedB, - tir.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), - tir.broadcast(True, 32), + T.ramp(((y_outer * 32768) + (k_outer * 128)), 1, 32), + T.broadcast(True, 32), ), - tir.load( + T.load( "float32x32", C_global, - tir.ramp((x_c * 32), 1, 32), - tir.broadcast(True, 32), + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), ), dtype="float32x32", ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - tir.store( + T.store( C_global, - tir.ramp((x_c * 32), 1, 32), - tir.call_llvm_pure_intrin( - tir.uint32(97), - tir.uint32(3), - tir.broadcast( - tir.load( + T.ramp((x_c * 32), 1, 32), + T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + T.load( "float32", A, ( @@ -501,32 +484,32 @@ def mmult( ), 32, ), - tir.load( + T.load( "float32x32", packedB, - tir.ramp( + T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 32), 1, 32 ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ), - tir.load( + T.load( "float32x32", C_global, - 
tir.ramp((x_c * 32), 1, 32), - tir.broadcast(True, 32), + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), ), dtype="float32x32", ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - tir.store( + T.store( C_global, - tir.ramp((x_c * 32), 1, 32), - tir.call_llvm_pure_intrin( - tir.uint32(97), - tir.uint32(3), - tir.broadcast( - tir.load( + T.ramp((x_c * 32), 1, 32), + T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + T.load( "float32", A, ( @@ -539,32 +522,32 @@ def mmult( ), 32, ), - tir.load( + T.load( "float32x32", packedB, - tir.ramp( + T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 64), 1, 32 ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ), - tir.load( + T.load( "float32x32", C_global, - tir.ramp((x_c * 32), 1, 32), - tir.broadcast(True, 32), + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), ), dtype="float32x32", ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - tir.store( + T.store( C_global, - tir.ramp((x_c * 32), 1, 32), - tir.call_llvm_pure_intrin( - tir.uint32(97), - tir.uint32(3), - tir.broadcast( - tir.load( + T.ramp((x_c * 32), 1, 32), + T.call_llvm_pure_intrin( + T.uint32(97), + T.uint32(3), + T.broadcast( + T.load( "float32", A, ( @@ -577,128 +560,126 @@ def mmult( ), 32, ), - tir.load( + T.load( "float32x32", packedB, - tir.ramp( + T.ramp( (((y_outer * 32768) + (k_outer * 128)) + 96), 1, 32 ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ), - tir.load( + T.load( "float32x32", C_global, - tir.ramp((x_c * 32), 1, 32), - tir.broadcast(True, 32), + T.ramp((x_c * 32), 1, 32), + T.broadcast(True, 32), ), dtype="float32x32", ), - tir.broadcast(True, 32), + T.broadcast(True, 32), ) - for x_inner in tir.serial(0, 32): - for y_inner in tir.serial(0, 32): + for x_inner in T.serial(0, 32): + for y_inner in T.serial(0, 32): C[ ( (((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner ) - ] = tir.load("float32", C_global, ((x_inner * 32) + y_inner)) - if tir.TVMBackendFreeWorkspace(1, 
dev_id, C_global, dtype="int32") != 0: - tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) - if tir.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0: - tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + ] = T.load("float32", C_global, ((x_inner * 32) + y_inner)) + if T.TVMBackendFreeWorkspace(1, dev_id, C_global, dtype="int32") != 0: + T.evaluate(T.tvm_throw_last_error(dtype="int32")) + if T.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0: + T.evaluate(T.tvm_throw_last_error(dtype="int32")) def test_opt_gemm_mod_host(): - mod = Module3() - rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True)) + mod = Module3 + rt_mod = tvm.script.from_source(mod.script(True)) tvm.ir.assert_structural_equal(mod, rt_mod, True) -@tvm.script.tir -def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) -> None: +@T.prim_func +def opt_conv_tensorcore_normalize(A: T.handle, W: T.handle, Conv: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) # var definition - bx = tir.env_thread("blockIdx.x") - by = tir.env_thread("blockIdx.y") - bz = tir.env_thread("blockIdx.z") - tx = tir.env_thread("threadIdx.x") - ty = tir.env_thread("threadIdx.y") - tz = tir.env_thread("threadIdx.z") + bx = T.env_thread("blockIdx.x") + by = T.env_thread("blockIdx.y") + bz = T.env_thread("blockIdx.z") + tx = T.env_thread("threadIdx.x") + ty = T.env_thread("threadIdx.y") + tz = T.env_thread("threadIdx.z") # buffer definition - Apad_shared = tir.buffer_decl( + Apad_shared = T.buffer_decl( [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - Apad_shared_wmma_matrix_a = tir.buffer_decl( + Apad_shared_wmma_matrix_a = T.buffer_decl( [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - BA = tir.buffer_decl( + BA = T.buffer_decl( [16, 16], 
dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 ) - BB = tir.buffer_decl( + BB = T.buffer_decl( [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 ) - BC = tir.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) - Conv_wmma_accumulator = tir.buffer_decl( + BC = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + Conv_wmma_accumulator = T.buffer_decl( [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 ) - W_shared = tir.buffer_decl( + W_shared = T.buffer_decl( [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - W_shared_wmma_matrix_b = tir.buffer_decl( + W_shared_wmma_matrix_b = T.buffer_decl( [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - buffer = tir.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) - buffer_1 = tir.buffer_decl( + buffer = T.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) + buffer_1 = T.buffer_decl( [16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256 ) - buffer_2 = tir.buffer_decl( - [16, 16], dtype="float16", scope="shared", align=32, offset_factor=256 - ) - buffer_3 = tir.buffer_decl( + buffer_2 = T.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) + buffer_3 = T.buffer_decl( [16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256 ) - buffer_4 = tir.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) - buffer_5 = tir.buffer_decl([16, 16], align=32, offset_factor=256) - A_1 = tir.match_buffer( + buffer_4 = T.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + buffer_5 = T.buffer_decl([16, 16], align=32, offset_factor=256) + A_1 = T.match_buffer( A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - W_1 = 
tir.match_buffer( + W_1 = T.match_buffer( W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - Conv_1 = tir.match_buffer( + Conv_1 = T.match_buffer( Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 ) # body - tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "") - tir.launch_thread(bz, 196) - tir.launch_thread(bx, 2) - tir.launch_thread(by, 4) - tir.launch_thread(ty, 4) - tir.launch_thread(tz, 2) - tir.realize( + T.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "") + T.launch_thread(bz, 196) + T.launch_thread(bx, 2) + T.launch_thread(by, 4) + T.launch_thread(ty, 4) + T.launch_thread(tz, 2) + T.realize( Conv_wmma_accumulator[ ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), - tir.floordiv(bz, 14) : (tir.floordiv(bz, 14) + 1), - tir.floormod(bz, 14) : (tir.floormod(bz, 14) + 1), + T.floordiv(bz, 14) : (T.floordiv(bz, 14) + 1), + T.floormod(bz, 14) : (T.floormod(bz, 14) + 1), ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), 0:16, 0:16, ], "wmma.accumulator", ) - for n_c_init in tir.serial(0, 2): - for o_c_init in tir.serial(0, 4): - tir.attr( + for n_c_init in T.serial(0, 2): + for o_c_init in T.serial(0, 4): + T.attr( [BC, Conv_wmma_accumulator], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( (n_c_init + ((bx * 8) + (ty * 2))), 1, - tir.floordiv(bz, 14), + T.floordiv(bz, 14), 1, - tir.floormod(bz, 14), + T.floormod(bz, 14), 1, (o_c_init + ((by * 8) + (tz * 4))), 1, @@ -709,64 +690,64 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.evaluate( - tir.tvm_fill_fragment( + T.evaluate( + T.tvm_fill_fragment( BC.data, 16, 16, 16, - tir.floordiv(BC.elem_offset, 256), - tir.float32(0), + T.floordiv(BC.elem_offset, 256), + T.float32(0), dtype="handle", ) ) - for ic_outer in tir.serial(0, 8): - for kh in tir.serial(0, 3): - tir.realize( + for ic_outer in T.serial(0, 8): + for kh in T.serial(0, 3): + T.realize( Apad_shared[ (bx * 8) 
: ((bx * 8) + 8), - (tir.floordiv(bz, 14) + kh) : ((tir.floordiv(bz, 14) + kh) + 1), - tir.floormod(bz, 14) : (tir.floormod(bz, 14) + 3), + (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), + T.floormod(bz, 14) : (T.floormod(bz, 14) + 3), (ic_outer * 2) : ((ic_outer * 2) + 2), 0:16, 0:16, ], "shared", ) - for ax2 in tir.serial(0, 3): - for ax3 in tir.serial(0, 2): - for ax4_ax5_fused_outer in tir.serial(0, 8): - tir.launch_thread(tx, 32) + for ax2 in T.serial(0, 3): + for ax3 in T.serial(0, 2): + for ax4_ax5_fused_outer in T.serial(0, 8): + T.launch_thread(tx, 32) Apad_shared[ ((tz + (ty * 2)) + (bx * 8)), - (tir.floordiv(bz, 14) + kh), - (ax2 + tir.floormod(bz, 14)), + (T.floordiv(bz, 14) + kh), + (ax2 + T.floormod(bz, 14)), (ax3 + (ic_outer * 2)), - tir.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), - tir.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), - ] = tir.if_then_else( + T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), + T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), + ] = T.if_then_else( ( ( ( - ((tir.floordiv(bz, 14) + kh) >= 1) - and (((tir.floordiv(bz, 14) + kh) - 1) < 14) + ((T.floordiv(bz, 14) + kh) >= 1) + and (((T.floordiv(bz, 14) + kh) - 1) < 14) ) - and ((ax2 + tir.floormod(bz, 14)) >= 1) + and ((ax2 + T.floormod(bz, 14)) >= 1) ) - and (((ax2 + tir.floormod(bz, 14)) - 1) < 14) + and (((ax2 + T.floormod(bz, 14)) - 1) < 14) ), A_1[ ((tz + (ty * 2)) + (bx * 8)), - ((tir.floordiv(bz, 14) + kh) - 1), - ((ax2 + tir.floormod(bz, 14)) - 1), + ((T.floordiv(bz, 14) + kh) - 1), + ((ax2 + T.floormod(bz, 14)) - 1), (ax3 + (ic_outer * 2)), - tir.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), - tir.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), + T.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), + T.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), ], - tir.float16(0), + T.float16(0), dtype="float16", ) - tir.realize( + T.realize( W_shared[ kh : (kh + 1), 0:3, @@ -777,48 +758,48 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: 
ty.handle, Conv: ty.handle) - ], "shared", ) - for ax1 in tir.serial(0, 3): - for ax2_1 in tir.serial(0, 2): - tir.launch_thread(tx, 32) - for ax4_ax5_fused_inner in tir.vectorized(0, 8): + for ax1 in T.serial(0, 3): + for ax2_1 in T.serial(0, 2): + T.launch_thread(tx, 32) + for ax4_ax5_fused_inner in T.vectorized(0, 8): W_shared[ kh, ax1, (ax2_1 + (ic_outer * 2)), ((tz + (ty * 2)) + (by * 8)), - tir.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), - tir.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), + T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), + T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), ] = W_1[ kh, ax1, (ax2_1 + (ic_outer * 2)), ((tz + (ty * 2)) + (by * 8)), - tir.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), - tir.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), + T.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), + T.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), ] - for ic_inner in tir.serial(0, 2): - for kw in tir.serial(0, 3): - tir.realize( + for ic_inner in T.serial(0, 2): + for kw in T.serial(0, 3): + T.realize( Apad_shared_wmma_matrix_a[ ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), - (tir.floordiv(bz, 14) + kh) : ((tir.floordiv(bz, 14) + kh) + 1), - (kw + tir.floormod(bz, 14)) : ((kw + tir.floormod(bz, 14)) + 1), + (T.floordiv(bz, 14) + kh) : ((T.floordiv(bz, 14) + kh) + 1), + (kw + T.floormod(bz, 14)) : ((kw + T.floormod(bz, 14)) + 1), ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), 0:16, 0:16, ], "wmma.matrix_a", ) - for ax0 in tir.serial(0, 2): - tir.attr( + for ax0 in T.serial(0, 2): + T.attr( [buffer, Apad_shared], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( (ax0 + ((bx * 8) + (ty * 2))), 1, - (tir.floordiv(bz, 14) + kh), + (T.floordiv(bz, 14) + kh), 1, - (kw + tir.floormod(bz, 14)), + (kw + T.floormod(bz, 14)), 1, ((ic_outer * 2) + ic_inner), 1, @@ -829,15 +810,15 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.attr( + T.attr( [buffer_1, 
Apad_shared_wmma_matrix_a], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( (ax0 + ((bx * 8) + (ty * 2))), 1, - (tir.floordiv(bz, 14) + kh), + (T.floordiv(bz, 14) + kh), 1, - (kw + tir.floormod(bz, 14)), + (kw + T.floormod(bz, 14)), 1, ((ic_outer * 2) + ic_inner), 1, @@ -848,15 +829,15 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( buffer_1.data, 16, 16, 16, - tir.floordiv(buffer_1.elem_offset, 256), - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.floordiv(buffer_1.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), buffer.data, buffer.elem_offset, 256, @@ -868,7 +849,7 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ) ) - tir.realize( + T.realize( W_shared_wmma_matrix_b[ kh : (kh + 1), kw : (kw + 1), @@ -879,11 +860,11 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - ], "wmma.matrix_b", ) - for ax3_1 in tir.serial(0, 4): - tir.attr( + for ax3_1 in T.serial(0, 4): + T.attr( [buffer_2, W_shared], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( kh, 1, kw, @@ -899,10 +880,10 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.attr( + T.attr( [buffer_3, W_shared_wmma_matrix_b], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( kh, 1, kw, @@ -918,15 +899,15 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( buffer_3.data, 16, 16, 16, - tir.floordiv(buffer_3.elem_offset, 256), - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.floordiv(buffer_3.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), buffer_2.data, buffer_2.elem_offset, 256, @@ -938,17 +919,17 @@ 
def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ) ) - for n_c in tir.serial(0, 2): - for o_c in tir.serial(0, 4): - tir.attr( + for n_c in T.serial(0, 2): + for o_c in T.serial(0, 4): + T.attr( [BA, Apad_shared_wmma_matrix_a], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( (n_c + ((bx * 8) + (ty * 2))), 1, - (tir.floordiv(bz, 14) + kh), + (T.floordiv(bz, 14) + kh), 1, - (tir.floormod(bz, 14) + kw), + (T.floormod(bz, 14) + kw), 1, ((ic_outer * 2) + ic_inner), 1, @@ -959,10 +940,10 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.attr( + T.attr( [BB, W_shared_wmma_matrix_b], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( kh, 1, kw, @@ -978,15 +959,15 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.attr( + T.attr( [BC, Conv_wmma_accumulator], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( (n_c + ((bx * 8) + (ty * 2))), 1, - tir.floordiv(bz, 14), + T.floordiv(bz, 14), 1, - tir.floormod(bz, 14), + T.floormod(bz, 14), 1, (o_c + ((by * 8) + (tz * 4))), 1, @@ -997,30 +978,30 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( BC.data, - tir.floordiv(BC.elem_offset, 256), + T.floordiv(BC.elem_offset, 256), BA.data, - tir.floordiv(BA.elem_offset, 256), + T.floordiv(BA.elem_offset, 256), BB.data, - tir.floordiv(BB.elem_offset, 256), + T.floordiv(BB.elem_offset, 256), BC.data, - tir.floordiv(BC.elem_offset, 256), + T.floordiv(BC.elem_offset, 256), dtype="handle", ) ) - for n_inner in tir.serial(0, 2): - for o_inner in tir.serial(0, 4): - tir.attr( + for n_inner in T.serial(0, 2): + for o_inner in T.serial(0, 4): + T.attr( [buffer_4, Conv_wmma_accumulator], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( ((((bx * 4) + ty) * 2) + n_inner), 1, - tir.floordiv(bz, 14), 
+ T.floordiv(bz, 14), 1, - tir.floormod(bz, 14), + T.floormod(bz, 14), 1, ((((by * 2) + tz) * 4) + o_inner), 1, @@ -1031,15 +1012,15 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.attr( + T.attr( [buffer_5, Conv_1], "buffer_bind_scope", - tir.tvm_tuple( + T.tvm_tuple( ((((bx * 4) + ty) * 2) + n_inner), 1, - tir.floordiv(bz, 14), + T.floordiv(bz, 14), 1, - tir.floormod(bz, 14), + T.floormod(bz, 14), 1, ((((by * 2) + tz) * 4) + o_inner), 1, @@ -1050,15 +1031,15 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ), ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( buffer_4.data, 16, 16, 16, - tir.floordiv(buffer_4.elem_offset, 256), - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.floordiv(buffer_4.elem_offset, 256), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), buffer_5.data, buffer_5.elem_offset, 256, @@ -1074,82 +1055,82 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - def test_opt_conv_tensorcore_normalize(): mod = opt_conv_tensorcore_normalize - rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True)) + rt_mod = tvm.script.from_source(mod.script(True)) tvm.ir.assert_structural_equal(mod, rt_mod, True) -@tvm.script.tir -def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> None: +@T.prim_func +def opt_conv_tensorcore_lower(A: T.handle, W: T.handle, Conv: T.handle) -> None: # function attr dict - tir.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) # body - A_1 = tir.match_buffer( + A_1 = T.match_buffer( A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - W_1 = tir.match_buffer( + W_1 = T.match_buffer( W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - 
Conv_1 = tir.match_buffer( + Conv_1 = T.match_buffer( Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 ) - bx = tir.env_thread("blockIdx.x") - by = tir.env_thread("blockIdx.y") - bz = tir.env_thread("blockIdx.z") - tx = tir.env_thread("threadIdx.x") - ty = tir.env_thread("threadIdx.y") - tz = tir.env_thread("threadIdx.z") - tir.launch_thread(bz, 196) - Conv_wmma_accumulator = tir.allocate([2048], "float32", "wmma.accumulator") - Apad_shared = tir.allocate([12288], "float16", "shared") - W_shared = tir.allocate([12288], "float16", "shared") - Apad_shared_wmma_matrix_a = tir.allocate([512], "float16", "wmma.matrix_a") - W_shared_wmma_matrix_b = tir.allocate([1024], "float16", "wmma.matrix_b") - tir.launch_thread(bx, 2) - tir.launch_thread(by, 4) - tir.launch_thread(ty, 4) - tir.launch_thread(tz, 2) - tir.evaluate( - tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, tir.float32(0), dtype="handle") + bx = T.env_thread("blockIdx.x") + by = T.env_thread("blockIdx.y") + bz = T.env_thread("blockIdx.z") + tx = T.env_thread("threadIdx.x") + ty = T.env_thread("threadIdx.y") + tz = T.env_thread("threadIdx.z") + T.launch_thread(bz, 196) + Conv_wmma_accumulator = T.allocate([2048], "float32", "wmma.accumulator") + Apad_shared = T.allocate([12288], "float16", "shared") + W_shared = T.allocate([12288], "float16", "shared") + Apad_shared_wmma_matrix_a = T.allocate([512], "float16", "wmma.matrix_a") + W_shared_wmma_matrix_b = T.allocate([1024], "float16", "wmma.matrix_b") + T.launch_thread(bx, 2) + T.launch_thread(by, 4) + T.launch_thread(ty, 4) + T.launch_thread(tz, 2) + T.evaluate( + T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, T.float32(0), dtype="handle") ) - tir.evaluate( - tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 1, tir.float32(0), dtype="handle") + T.evaluate( + T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 1, T.float32(0), dtype="handle") ) - tir.evaluate( - tir.tvm_fill_fragment(Conv_wmma_accumulator, 
16, 16, 16, 2, tir.float32(0), dtype="handle") + T.evaluate( + T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 2, T.float32(0), dtype="handle") ) - tir.evaluate( - tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 3, tir.float32(0), dtype="handle") + T.evaluate( + T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 3, T.float32(0), dtype="handle") ) - tir.evaluate( - tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 4, tir.float32(0), dtype="handle") + T.evaluate( + T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 4, T.float32(0), dtype="handle") ) - tir.evaluate( - tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 5, tir.float32(0), dtype="handle") + T.evaluate( + T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 5, T.float32(0), dtype="handle") ) - tir.evaluate( - tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 6, tir.float32(0), dtype="handle") + T.evaluate( + T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 6, T.float32(0), dtype="handle") ) - tir.evaluate( - tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, tir.float32(0), dtype="handle") + T.evaluate( + T.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, T.float32(0), dtype="handle") ) - for ic_outer in tir.serial(0, 8): - for kh in tir.serial(0, 3): - for ax2 in tir.serial(0, 3): - with tir.launch_thread(tx, 32): + for ic_outer in T.serial(0, 8): + for kh in T.serial(0, 3): + for ax2 in T.serial(0, 3): + with T.launch_thread(tx, 32): Apad_shared[ ((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1175,24 +1156,24 @@ def 
opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61440 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1218,24 +1199,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61408 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1261,24 +1242,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61376 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= 
(ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1304,24 +1285,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61344 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1347,24 +1328,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61312 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1390,24 +1371,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61280 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= 
(tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1433,24 +1414,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61248 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1476,24 +1457,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61216 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1519,24 +1500,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61184 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with 
tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1562,24 +1543,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61152 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1605,24 +1586,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61120 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", 
A_1.data, ( @@ -1648,24 +1629,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61088 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1691,24 +1672,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61056 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1734,24 +1715,24 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 61024 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): + with T.launch_thread(tx, 32): Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) + (1 <= (T.floordiv(bz, 14) + kh)) + and ((T.floordiv(bz, 14) + kh) < 15) ) - and (1 <= 
(ax2 + tir.floormod(bz, 14))) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1777,24 +1758,21 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 60992 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - tir.launch_thread(tx, 32) + T.launch_thread(tx, 32) Apad_shared[ (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480) - ] = tir.if_then_else( + ] = T.if_then_else( ( ( - ( - (1 <= (tir.floordiv(bz, 14) + kh)) - and ((tir.floordiv(bz, 14) + kh) < 15) - ) - and (1 <= (ax2 + tir.floormod(bz, 14))) + ((1 <= (T.floordiv(bz, 14) + kh)) and ((T.floordiv(bz, 14) + kh) < 15)) + and (1 <= (ax2 + T.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(bz, 14)) < 15) + and ((ax2 + T.floormod(bz, 14)) < 15) ), - tir.load( + T.load( "float16", A_1.data, ( @@ -1817,17 +1795,17 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No - 60960 ), ), - tir.float16(0), + T.float16(0), dtype="float16", ) - with tir.launch_thread(tx, 32): - tir.store( + with T.launch_thread(tx, 32): + T.store( W_shared, - tir.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8), - tir.load( + T.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8), + T.load( "float16x8", W_1.data, - tir.ramp( + T.ramp( ( ( ( @@ -1841,18 +1819,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No 1, 8, ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ) - with tir.launch_thread(tx, 32): - tir.store( + with T.launch_thread(tx, 32): + T.store( W_shared, - tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8), - tir.load( + T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8), + T.load( "float16x8", W_1.data, - tir.ramp( + T.ramp( ( ( ( @@ -1869,18 +1847,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: 
ty.handle) -> No 1, 8, ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ) - with tir.launch_thread(tx, 32): - tir.store( + with T.launch_thread(tx, 32): + T.store( W_shared, - tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8), - tir.load( + T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8), + T.load( "float16x8", W_1.data, - tir.ramp( + T.ramp( ( ( ( @@ -1897,18 +1875,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No 1, 8, ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ) - with tir.launch_thread(tx, 32): - tir.store( + with T.launch_thread(tx, 32): + T.store( W_shared, - tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8), - tir.load( + T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8), + T.load( "float16x8", W_1.data, - tir.ramp( + T.ramp( ( ( ( @@ -1925,18 +1903,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No 1, 8, ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ) - with tir.launch_thread(tx, 32): - tir.store( + with T.launch_thread(tx, 32): + T.store( W_shared, - tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8), - tir.load( + T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8), + T.load( "float16x8", W_1.data, - tir.ramp( + T.ramp( ( ( ( @@ -1953,18 +1931,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No 1, 8, ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ) - with tir.launch_thread(tx, 32): - tir.store( + with T.launch_thread(tx, 32): + T.store( W_shared, - tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8), - tir.load( + T.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8), + T.load( "float16x8", W_1.data, - tir.ramp( + 
T.ramp( ( ( ( @@ -1981,21 +1959,21 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No 1, 8, ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ), - tir.broadcast(True, 8), + T.broadcast(True, 8), ) - for ic_inner in tir.serial(0, 2): - for kw in tir.serial(0, 3): - tir.evaluate( - tir.tvm_load_matrix_sync( + for ic_inner in T.serial(0, 2): + for kw in T.serial(0, 3): + T.evaluate( + T.tvm_load_matrix_sync( Apad_shared_wmma_matrix_a, 16, 16, 16, 0, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), Apad_shared, (((ty * 3072) + (kw * 512)) + (ic_inner * 256)), 256, @@ -2007,15 +1985,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( Apad_shared_wmma_matrix_a, 16, 16, 16, 1, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), Apad_shared, ((((ty * 3072) + (kw * 512)) + (ic_inner * 256)) + 1536), 256, @@ -2027,15 +2005,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( W_shared_wmma_matrix_b, 16, 16, 16, 0, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), W_shared, (((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)), 256, @@ -2047,15 +2025,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( W_shared_wmma_matrix_b, 16, 16, 16, 1, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), W_shared, ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 256), 256, @@ 
-2067,15 +2045,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( W_shared_wmma_matrix_b, 16, 16, 16, 2, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), W_shared, ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 512), 256, @@ -2087,15 +2065,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_load_matrix_sync( + T.evaluate( + T.tvm_load_matrix_sync( W_shared_wmma_matrix_b, 16, 16, 16, 3, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float16"), + T.tvm_access_ptr( + T.type_annotation(dtype="float16"), W_shared, ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 768), 256, @@ -2107,8 +2085,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( Conv_wmma_accumulator, 0, Apad_shared_wmma_matrix_a, @@ -2120,8 +2098,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( Conv_wmma_accumulator, 1, Apad_shared_wmma_matrix_a, @@ -2133,8 +2111,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( Conv_wmma_accumulator, 2, Apad_shared_wmma_matrix_a, @@ -2146,8 +2124,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( Conv_wmma_accumulator, 3, Apad_shared_wmma_matrix_a, @@ -2159,8 +2137,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - 
tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( Conv_wmma_accumulator, 4, Apad_shared_wmma_matrix_a, @@ -2172,8 +2150,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( Conv_wmma_accumulator, 5, Apad_shared_wmma_matrix_a, @@ -2185,8 +2163,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( Conv_wmma_accumulator, 6, Apad_shared_wmma_matrix_a, @@ -2198,8 +2176,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_mma_sync( + T.evaluate( + T.tvm_mma_sync( Conv_wmma_accumulator, 7, Apad_shared_wmma_matrix_a, @@ -2211,15 +2189,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( Conv_wmma_accumulator, 16, 16, 16, 0, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), Conv_1.data, (((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + (tz * 1024)), 256, @@ -2231,15 +2209,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( Conv_wmma_accumulator, 16, 16, 16, 1, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), Conv_1.data, ( ( @@ -2257,15 +2235,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( Conv_wmma_accumulator, 16, 16, 16, 2, - tir.tvm_access_ptr( - 
tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), Conv_1.data, ( ( @@ -2283,15 +2261,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( Conv_wmma_accumulator, 16, 16, 16, 3, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), Conv_1.data, ( ( @@ -2309,15 +2287,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( Conv_wmma_accumulator, 16, 16, 16, 4, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), Conv_1.data, ( ( @@ -2335,15 +2313,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( Conv_wmma_accumulator, 16, 16, 16, 5, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), Conv_1.data, ( ( @@ -2361,15 +2339,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( Conv_wmma_accumulator, 16, 16, 16, 6, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + T.type_annotation(dtype="float32"), Conv_1.data, ( ( @@ -2387,15 +2365,15 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No dtype="handle", ) ) - tir.evaluate( - tir.tvm_store_matrix_sync( + T.evaluate( + T.tvm_store_matrix_sync( Conv_wmma_accumulator, 16, 16, 16, 7, - tir.tvm_access_ptr( - tir.type_annotation(dtype="float32"), + T.tvm_access_ptr( + 
T.type_annotation(dtype="float32"), Conv_1.data, ( ( @@ -2417,21 +2395,21 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No def test_opt_conv_tensorcore_lower(): mod = opt_conv_tensorcore_lower - rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True)) + rt_mod = tvm.script.from_source(mod.script(True)) tvm.ir.assert_structural_equal(mod, rt_mod, True) -@tvm.script.tir +@T.prim_func def opt_conv_tensorcore_mod_host( - args: ty.handle, - arg_type_ids: ty.handle, - num_args: ty.int32, - out_ret_value: ty.handle, - out_ret_tcode: ty.handle, - resource_handle: ty.handle, -) -> ty.int32: + args: T.handle, + arg_type_ids: T.handle, + num_args: T.int32, + out_ret_value: T.handle, + out_ret_tcode: T.handle, + resource_handle: T.handle, +) -> T.int32: # function attr dict - tir.func_attr( + T.func_attr( { "tir.noalias": True, "global_symbol": "default_function", @@ -2440,28 +2418,28 @@ def opt_conv_tensorcore_mod_host( } ) # body - stack_tcode: ty.handle = tir.tvm_stack_alloca("arg_tcode", 10, dtype="handle") - stack_value: ty.handle = tir.tvm_stack_alloca("arg_value", 10, dtype="handle") + stack_tcode: T.handle = T.tvm_stack_alloca("arg_tcode", 10, dtype="handle") + stack_value: T.handle = T.tvm_stack_alloca("arg_value", 10, dtype="handle") assert num_args == 3, "default_function: num_args should be 3" - arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") - arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0) - arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle") - arg1_code: ty.int32 = tir.load("int32", arg_type_ids, 1) - arg2: ty.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle") - arg2_code: ty.int32 = tir.load("int32", arg_type_ids, 2) - A: ty.handle = tir.tvm_struct_get(arg0, 0, 1, dtype="handle") - tir.attr(A, "storage_alignment", 128) - arg0_shape: ty.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle") - arg0_strides: ty.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle") - dev_id: 
ty.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32") - W: ty.handle = tir.tvm_struct_get(arg1, 0, 1, dtype="handle") - tir.attr(W, "storage_alignment", 128) - arg1_shape: ty.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle") - arg1_strides: ty.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle") - Conv: ty.handle = tir.tvm_struct_get(arg2, 0, 1, dtype="handle") - tir.attr(Conv, "storage_alignment", 128) - arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle") - arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle") + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: T.int32 = T.load("int32", arg_type_ids, 0) + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: T.int32 = T.load("int32", arg_type_ids, 1) + arg2: T.handle = T.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: T.int32 = T.load("int32", arg_type_ids, 2) + A: T.handle = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + T.attr(A, "storage_alignment", 128) + arg0_shape: T.handle = T.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_strides: T.handle = T.tvm_struct_get(arg0, 0, 3, dtype="handle") + dev_id: T.int32 = T.tvm_struct_get(arg0, 0, 9, dtype="int32") + W: T.handle = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + T.attr(W, "storage_alignment", 128) + arg1_shape: T.handle = T.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_strides: T.handle = T.tvm_struct_get(arg1, 0, 3, dtype="handle") + Conv: T.handle = T.tvm_struct_get(arg2, 0, 1, dtype="handle") + T.attr(Conv, "storage_alignment", 128) + arg2_shape: T.handle = T.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_strides: T.handle = T.tvm_struct_get(arg2, 0, 3, dtype="handle") assert (((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or ( arg0_code == 4 ), "default_function: Expect arg[0] to be pointer" @@ -2471,189 +2449,187 @@ def opt_conv_tensorcore_mod_host( assert (((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or ( 
arg2_code == 4 ), "default_function: Expect arg[2] to be pointer" - assert 6 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" - assert 6 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg0, 0, 4, dtype="int32"), "arg0.ndim is expected to equal 6" assert ( - (tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) - and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(16)) + (T.tvm_struct_get(arg0, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg0, 0, 6, dtype="uint8") == T.uint8(16)) ) and ( - tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1) + T.tvm_struct_get(arg0, 0, 7, dtype="uint16") == T.uint16(1) ), "arg0.dtype is expected to be float16" - assert 16 == tir.cast( - tir.load("int64", arg0_shape, 0), "int32" + assert 16 == T.cast( + T.load("int64", arg0_shape, 0), "int32" ), "Argument arg0.shape[0] has an unsatisfied constraint" - assert 14 == tir.cast( - tir.load("int64", arg0_shape, 1), "int32" + assert 14 == T.cast( + T.load("int64", arg0_shape, 1), "int32" ), "Argument arg0.shape[1] has an unsatisfied constraint" - assert 14 == tir.cast( - tir.load("int64", arg0_shape, 2), "int32" + assert 14 == T.cast( + T.load("int64", arg0_shape, 2), "int32" ), "Argument arg0.shape[2] has an unsatisfied constraint" - assert 16 == tir.cast( - tir.load("int64", arg0_shape, 3), "int32" + assert 16 == T.cast( + T.load("int64", arg0_shape, 3), "int32" ), "Argument arg0.shape[3] has an unsatisfied constraint" - assert 16 == tir.cast( - tir.load("int64", arg0_shape, 4), "int32" + assert 16 == T.cast( + T.load("int64", arg0_shape, 4), "int32" ), "Argument arg0.shape[4] has an unsatisfied constraint" - assert 16 == tir.cast( - tir.load("int64", arg0_shape, 5), "int32" + assert 16 == T.cast( + T.load("int64", arg0_shape, 5), "int32" ), 
"Argument arg0.shape[5] has an unsatisfied constraint" - if not (tir.isnullptr(arg0_strides, dtype="bool")): + if not (T.isnullptr(arg0_strides, dtype="bool")): assert ( ( ( ( - (1 == tir.cast(tir.load("int64", arg0_strides, 5), "int32")) - and (16 == tir.cast(tir.load("int64", arg0_strides, 4), "int32")) + (1 == T.cast(T.load("int64", arg0_strides, 5), "int32")) + and (16 == T.cast(T.load("int64", arg0_strides, 4), "int32")) ) - and (256 == tir.cast(tir.load("int64", arg0_strides, 3), "int32")) + and (256 == T.cast(T.load("int64", arg0_strides, 3), "int32")) ) - and (4096 == tir.cast(tir.load("int64", arg0_strides, 2), "int32")) + and (4096 == T.cast(T.load("int64", arg0_strides, 2), "int32")) ) - and (57344 == tir.cast(tir.load("int64", arg0_strides, 1), "int32")) + and (57344 == T.cast(T.load("int64", arg0_strides, 1), "int32")) ) and ( - 802816 == tir.cast(tir.load("int64", arg0_strides, 0), "int32") + 802816 == T.cast(T.load("int64", arg0_strides, 0), "int32") ), "arg0.strides: expected to be compact array" - tir.evaluate(0) - assert tir.uint64(0) == tir.tvm_struct_get( + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( arg0, 0, 8, dtype="uint64" ), "Argument arg0.byte_offset has an unsatisfied constraint" - assert 2 == tir.tvm_struct_get( + assert 2 == T.tvm_struct_get( arg0, 0, 10, dtype="int32" ), "Argument arg0.device_type has an unsatisfied constraint" - assert 6 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" - assert 6 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg1, 0, 4, dtype="int32"), "arg1.ndim is expected to equal 6" assert ( - (tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) - and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(16)) + (T.tvm_struct_get(arg1, 0, 5, dtype="uint8") == T.uint8(2)) + and 
(T.tvm_struct_get(arg1, 0, 6, dtype="uint8") == T.uint8(16)) ) and ( - tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1) + T.tvm_struct_get(arg1, 0, 7, dtype="uint16") == T.uint16(1) ), "arg1.dtype is expected to be float16" - assert 3 == tir.cast( - tir.load("int64", arg1_shape, 0), "int32" + assert 3 == T.cast( + T.load("int64", arg1_shape, 0), "int32" ), "Argument arg1.shape[0] has an unsatisfied constraint" - assert 3 == tir.cast( - tir.load("int64", arg1_shape, 1), "int32" + assert 3 == T.cast( + T.load("int64", arg1_shape, 1), "int32" ), "Argument arg1.shape[1] has an unsatisfied constraint" - assert 16 == tir.cast( - tir.load("int64", arg1_shape, 2), "int32" + assert 16 == T.cast( + T.load("int64", arg1_shape, 2), "int32" ), "Argument arg1.shape[2] has an unsatisfied constraint" - assert 32 == tir.cast( - tir.load("int64", arg1_shape, 3), "int32" + assert 32 == T.cast( + T.load("int64", arg1_shape, 3), "int32" ), "Argument arg1.shape[3] has an unsatisfied constraint" - assert 16 == tir.cast( - tir.load("int64", arg1_shape, 4), "int32" + assert 16 == T.cast( + T.load("int64", arg1_shape, 4), "int32" ), "Argument arg1.shape[4] has an unsatisfied constraint" - assert 16 == tir.cast( - tir.load("int64", arg1_shape, 5), "int32" + assert 16 == T.cast( + T.load("int64", arg1_shape, 5), "int32" ), "Argument arg1.shape[5] has an unsatisfied constraint" - if not (tir.isnullptr(arg1_strides, dtype="bool")): + if not (T.isnullptr(arg1_strides, dtype="bool")): assert ( ( ( ( - (1 == tir.cast(tir.load("int64", arg1_strides, 5), "int32")) - and (16 == tir.cast(tir.load("int64", arg1_strides, 4), "int32")) + (1 == T.cast(T.load("int64", arg1_strides, 5), "int32")) + and (16 == T.cast(T.load("int64", arg1_strides, 4), "int32")) ) - and (256 == tir.cast(tir.load("int64", arg1_strides, 3), "int32")) + and (256 == T.cast(T.load("int64", arg1_strides, 3), "int32")) ) - and (8192 == tir.cast(tir.load("int64", arg1_strides, 2), "int32")) + and (8192 == 
T.cast(T.load("int64", arg1_strides, 2), "int32")) ) - and (131072 == tir.cast(tir.load("int64", arg1_strides, 1), "int32")) + and (131072 == T.cast(T.load("int64", arg1_strides, 1), "int32")) ) and ( - 393216 == tir.cast(tir.load("int64", arg1_strides, 0), "int32") + 393216 == T.cast(T.load("int64", arg1_strides, 0), "int32") ), "arg1.strides: expected to be compact array" - tir.evaluate(0) - assert tir.uint64(0) == tir.tvm_struct_get( + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( arg1, 0, 8, dtype="uint64" ), "Argument arg1.byte_offset has an unsatisfied constraint" - assert 2 == tir.tvm_struct_get( + assert 2 == T.tvm_struct_get( arg1, 0, 10, dtype="int32" ), "Argument arg1.device_type has an unsatisfied constraint" - assert dev_id == tir.tvm_struct_get( + assert dev_id == T.tvm_struct_get( arg1, 0, 9, dtype="int32" ), "Argument arg1.device_id has an unsatisfied constraint" - assert 6 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" - assert 6 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" + assert 6 == T.tvm_struct_get(arg2, 0, 4, dtype="int32"), "arg2.ndim is expected to equal 6" assert ( - (tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) - and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32)) + (T.tvm_struct_get(arg2, 0, 5, dtype="uint8") == T.uint8(2)) + and (T.tvm_struct_get(arg2, 0, 6, dtype="uint8") == T.uint8(32)) ) and ( - tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1) + T.tvm_struct_get(arg2, 0, 7, dtype="uint16") == T.uint16(1) ), "arg2.dtype is expected to be float32" - assert 16 == tir.cast( - tir.load("int64", arg2_shape, 0), "int32" + assert 16 == T.cast( + T.load("int64", arg2_shape, 0), "int32" ), "Argument arg2.shape[0] has an unsatisfied constraint" - assert 14 == tir.cast( - tir.load("int64", arg2_shape, 1), "int32" + 
assert 14 == T.cast( + T.load("int64", arg2_shape, 1), "int32" ), "Argument arg2.shape[1] has an unsatisfied constraint" - assert 14 == tir.cast( - tir.load("int64", arg2_shape, 2), "int32" + assert 14 == T.cast( + T.load("int64", arg2_shape, 2), "int32" ), "Argument arg2.shape[2] has an unsatisfied constraint" - assert 32 == tir.cast( - tir.load("int64", arg2_shape, 3), "int32" + assert 32 == T.cast( + T.load("int64", arg2_shape, 3), "int32" ), "Argument arg2.shape[3] has an unsatisfied constraint" - assert 16 == tir.cast( - tir.load("int64", arg2_shape, 4), "int32" + assert 16 == T.cast( + T.load("int64", arg2_shape, 4), "int32" ), "Argument arg2.shape[4] has an unsatisfied constraint" - assert 16 == tir.cast( - tir.load("int64", arg2_shape, 5), "int32" + assert 16 == T.cast( + T.load("int64", arg2_shape, 5), "int32" ), "Argument arg2.shape[5] has an unsatisfied constraint" - if not (tir.isnullptr(arg2_strides, dtype="bool")): + if not (T.isnullptr(arg2_strides, dtype="bool")): assert ( ( ( ( - (1 == tir.cast(tir.load("int64", arg2_strides, 5), "int32")) - and (16 == tir.cast(tir.load("int64", arg2_strides, 4), "int32")) + (1 == T.cast(T.load("int64", arg2_strides, 5), "int32")) + and (16 == T.cast(T.load("int64", arg2_strides, 4), "int32")) ) - and (256 == tir.cast(tir.load("int64", arg2_strides, 3), "int32")) + and (256 == T.cast(T.load("int64", arg2_strides, 3), "int32")) ) - and (8192 == tir.cast(tir.load("int64", arg2_strides, 2), "int32")) + and (8192 == T.cast(T.load("int64", arg2_strides, 2), "int32")) ) - and (114688 == tir.cast(tir.load("int64", arg2_strides, 1), "int32")) + and (114688 == T.cast(T.load("int64", arg2_strides, 1), "int32")) ) and ( - 1605632 == tir.cast(tir.load("int64", arg2_strides, 0), "int32") + 1605632 == T.cast(T.load("int64", arg2_strides, 0), "int32") ), "arg2.strides: expected to be compact array" - tir.evaluate(0) - assert tir.uint64(0) == tir.tvm_struct_get( + T.evaluate(0) + assert T.uint64(0) == T.tvm_struct_get( arg2, 0, 8, 
dtype="uint64" ), "Argument arg2.byte_offset has an unsatisfied constraint" - assert 2 == tir.tvm_struct_get( + assert 2 == T.tvm_struct_get( arg2, 0, 10, dtype="int32" ), "Argument arg2.device_type has an unsatisfied constraint" - assert dev_id == tir.tvm_struct_get( + assert dev_id == T.tvm_struct_get( arg2, 0, 9, dtype="int32" ), "Argument arg2.device_id has an unsatisfied constraint" - tir.evaluate(tir.tvm_struct_set(stack_value, 0, 12, tir.cast(2, "int64"), dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 0, 12, T.cast(2, "int64"), dtype="int32")) stack_tcode[0] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 1, 12, tir.cast(dev_id, "int64"), dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 1, 12, T.cast(dev_id, "int64"), dtype="int32")) stack_tcode[1] = 0 - tir.evaluate( - tir.tvm_call_packed_lowered( - "__tvm_set_device", stack_value, stack_tcode, 0, 2, dtype="int32" - ) + T.evaluate( + T.tvm_call_packed_lowered("__tvm_set_device", stack_value, stack_tcode, 0, 2, dtype="int32") ) - tir.attr(0, "compute_scope", "default_function_compute_") - tir.evaluate(tir.tvm_struct_set(stack_value, 0, 12, A, dtype="int32")) + T.attr(0, "compute_scope", "default_function_compute_") + T.evaluate(T.tvm_struct_set(stack_value, 0, 12, A, dtype="int32")) stack_tcode[0] = 3 - tir.evaluate(tir.tvm_struct_set(stack_value, 1, 12, W, dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 1, 12, W, dtype="int32")) stack_tcode[1] = 3 - tir.evaluate(tir.tvm_struct_set(stack_value, 2, 12, Conv, dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 2, 12, Conv, dtype="int32")) stack_tcode[2] = 3 - tir.evaluate(tir.tvm_struct_set(stack_value, 3, 12, tir.cast(196, "int64"), dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 3, 12, T.cast(196, "int64"), dtype="int32")) stack_tcode[3] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 4, 12, tir.cast(2, "int64"), dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 4, 12, T.cast(2, "int64"), 
dtype="int32")) stack_tcode[4] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 5, 12, tir.cast(4, "int64"), dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 5, 12, T.cast(4, "int64"), dtype="int32")) stack_tcode[5] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 6, 12, tir.cast(4, "int64"), dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 6, 12, T.cast(4, "int64"), dtype="int32")) stack_tcode[6] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 7, 12, tir.cast(2, "int64"), dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 7, 12, T.cast(2, "int64"), dtype="int32")) stack_tcode[7] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 8, 12, tir.cast(32, "int64"), dtype="int32")) + T.evaluate(T.tvm_struct_set(stack_value, 8, 12, T.cast(32, "int64"), dtype="int32")) stack_tcode[8] = 0 - tir.evaluate( - tir.tvm_call_packed_lowered( + T.evaluate( + T.tvm_call_packed_lowered( "default_function_kernel0", stack_value, stack_tcode, 0, 9, dtype="int32" ) ) @@ -2661,106 +2637,106 @@ def opt_conv_tensorcore_mod_host( def test_opt_conv_tensorcore_mod_host(): mod = opt_conv_tensorcore_mod_host - rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True)) + rt_mod = tvm.script.from_source(mod.script(True)) tvm.ir.assert_structural_equal(mod, rt_mod, True) -@tvm.script.tir -def vthread_func(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def vthread_func(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") - i0 = tir.env_thread("blockIdx.x") - i1 = tir.env_thread("threadIdx.x") - i2 = tir.env_thread("vthread") + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") - tir.launch_thread(i0, 4) - tir.launch_thread(i1, 2) - tir.launch_thread(i2, 2) - B = tir.allocate([16], "float32", "local") + T.launch_thread(i0, 4) + T.launch_thread(i1, 
2) + T.launch_thread(i2, 2) + B = T.allocate([16], "float32", "local") for j in range(16): - B[j] = tir.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + tir.float32(1) + B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + T.float32(1) for j in range(16): - C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = tir.load("float32", B, j) * tir.float32(2) + C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = T.load("float32", B, j) * T.float32(2) def test_vthread(): func = vthread_func - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func, True) -@tvm.script.tir -def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) - with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: - with tir.init(): - C[vi, vj] = tir.float32(0) + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def matmul_original(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) +@T.prim_func +def matmul_original(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) - for i, j in tir.grid(128, 128): - with tir.block([128, 128], "init") as [vi, vj]: - C[vi, vj] = tir.float32(0) + for i, j in T.grid(128, 128): + with T.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = T.float32(0) for k in range(128): - with tir.block([128, 128, 
tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -@tvm.script.tir -def element_wise(a: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") - C = tir.match_buffer(c, (128, 128), "float32") - B = tir.alloc_buffer((128, 128), "float32") +@T.prim_func +def element_wise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") - with tir.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * tir.float32(2) + with T.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * T.float32(2) - with tir.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + tir.float32(1) + with T.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + T.float32(1) -@tvm.script.tir -def predicate(b: ty.handle, c: ty.handle) -> None: - B = tir.match_buffer(b, (16, 16), "float32") - C = tir.match_buffer(c, (16, 16), "float32") +@T.prim_func +def predicate(b: T.handle, c: T.handle) -> None: + B = T.match_buffer(b, (16, 16), "float32") + C = T.match_buffer(c, (16, 16), "float32") - for i, jo, ji in tir.grid(16, 4, 5): - with tir.block([16, 16], "update") as [vi, vj]: - tir.bind(vi, i) - tir.bind(vj, jo * 4 + ji) - tir.where(jo * 4 + ji < 16) - C[vi, vj] = B[vi, vj] + tir.float32(1) + for i, jo, ji in T.grid(16, 4, 5): + with T.block([16, 16], "update") as [vi, vj]: + T.bind(vi, i) + T.bind(vj, jo * 4 + ji) + T.where(jo * 4 + ji < 16) + C[vi, vj] = B[vi, vj] + T.float32(1) def test_module_define(): - func1 = tvm.script.create_module({"matmul": matmul})["matmul"] - func2 = tvm.script.create_module({"element_wise": element_wise})["element_wise"] - func3 = tvm.script.create_module({"predicate": predicate})["predicate"] - mod1 = tvm.script.create_module({"func1": func1, "func2": func2, "func3": func3}) - mod2 = 
tvm.script.create_module({"func1": matmul, "func2": element_wise, "func3": predicate}) + func1 = tvm.ir.IRModule({"matmul": matmul})["matmul"] + func2 = tvm.ir.IRModule({"element_wise": element_wise})["element_wise"] + func3 = tvm.ir.IRModule({"predicate": predicate})["predicate"] + mod1 = tvm.ir.IRModule({"func1": func1, "func2": func2, "func3": func3}) + mod2 = tvm.ir.IRModule({"func1": matmul, "func2": element_wise, "func3": predicate}) tvm.ir.assert_structural_equal(mod1, mod2) def test_matmul(): func = matmul - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) def test_matmul_original(): func = matmul_original - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body.block, tir.stmt.Block) @@ -2774,7 +2750,7 @@ def test_matmul_original(): def test_element_wise(): func = element_wise - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body.block, tir.stmt.Block) @@ -2790,7 +2766,7 @@ def test_element_wise(): def test_predicate(): func = predicate - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body.block, tir.stmt.Block) @@ -2800,21 +2776,21 @@ def test_predicate(): assert isinstance(rt_func.body.block.body.body.body.body.block, tir.stmt.Block) -@tvm.script.tir -def for_thread_binding(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (16, 16), "float32") +@T.prim_func +def for_thread_binding(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), 
"float32") + B = T.match_buffer(b, (16, 16), "float32") - for i in tir.thread_binding(0, 16, thread="threadIdx.x"): - for j in tir.thread_binding( + for i in T.thread_binding(0, 16, thread="threadIdx.x"): + for j in T.thread_binding( 0, 16, thread="threadIdx.y", annotations={"attr_key": "attr_value"} ): - A[i, j] = B[i, j] + tir.float32(1) + A[i, j] = B[i, j] + T.float32(1) def test_for_thread_binding(): func = for_thread_binding - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body, tir.stmt.For) @@ -2826,22 +2802,22 @@ def test_for_thread_binding(): assert rt_func.body.body.annotations["attr_key"] == "attr_value" -@tvm.script.tir -def match_buffer_region(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16, 16), "float32") - B = tir.match_buffer(b, (1), "float32") +@T.prim_func +def match_buffer_region(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16, 16), "float32") + B = T.match_buffer(b, (1), "float32") - with tir.block([16, 4]) as [vi, vj]: - C = tir.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) - with tir.block([4]) as [vii]: - D = tir.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) - for i, j in tir.grid(4, 4): + with T.block([16, 4]) as [vi, vj]: + C = T.match_buffer(A[0:16, vi, vj * 4 : vj * 4 + 4], (16, 1, 4)) + with T.block([4]) as [vii]: + D = T.match_buffer(C[vii * 4 : vii * 4 + 4, 0, 0:4], (4, 1, 4)) + for i, j in T.grid(4, 4): B[0] += D[i, 0, j] def test_match_buffer_region(): func = match_buffer_region - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body, tir.stmt.BlockRealize) @@ -2863,27 +2839,27 @@ def test_match_buffer_region(): tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) 
-@tvm.script.tir -def block_elements(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (1, 1), "float32") - - with tir.block([1], "update") as [vi]: - tir.bind(vi, 0) - tir.where(True) - tir.reads(A[0:16, 0:16]) - tir.writes(B[0, 0]) - tir.block_attr({"attr_key": "attr_value"}) - C = tir.alloc_buffer((4, 4), dtype="float32") - D = tir.match_buffer(A[0:4, 0], (4, 1)) - with tir.init(): - B[0, 0] = tir.float32(0) +@T.prim_func +def block_elements(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (1, 1), "float32") + + with T.block([1], "update") as [vi]: + T.bind(vi, 0) + T.where(True) + T.reads(A[0:16, 0:16]) + T.writes(B[0, 0]) + T.block_attr({"attr_key": "attr_value"}) + C = T.alloc_buffer((4, 4), dtype="float32") + D = T.match_buffer(A[0:4, 0], (4, 1)) + with T.init(): + B[0, 0] = T.float32(0) B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2] def test_block_elements(): func = block_elements - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) assert isinstance(rt_func.body.block, tir.stmt.Block) @@ -2896,27 +2872,27 @@ def test_block_elements(): assert block.annotations["attr_key"] == "attr_value" -@tvm.script.tir -def opaque_block(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - B = tir.match_buffer(b, (16, 16), "float32") +@T.prim_func +def opaque_block(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + B = T.match_buffer(b, (16, 16), "float32") for i in range(16): for j in range(16): - with tir.block([]): - tir.reads([]) - tir.writes(A[i, j]) - A[i, j] = tir.float32(0) - with tir.block([]): - tir.reads([A[i, 0:16]]) - tir.writes([B[i, 0:16]]) + with T.block([]): + T.reads([]) + T.writes(A[i, j]) + A[i, j] = T.float32(0) + with T.block([]): + T.reads([A[i, 0:16]]) + 
T.writes([B[i, 0:16]]) for j in range(16): B[i, j] = A[i, j] def test_opaque_block(): func = opaque_block - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) root_block = rt_func.body.block @@ -2931,147 +2907,147 @@ def test_opaque_block(): assert len(root_block.body.body[1].block.iter_vars) == 0 -@tvm.script.tir -def rank0(a: ty.handle) -> None: - A = tir.match_buffer(a, (), "float32") - B = tir.alloc_buffer((), "float32") +@T.prim_func +def rank0(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.alloc_buffer((), "float32") A[()] = 2 B[()] = A[()] def test_rank0_buffers(): func = rank0 - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) -@tvm.script.tir -def rank0_block(a: ty.handle) -> None: - A = tir.match_buffer(a, (), "float32") - B = tir.alloc_buffer((), "float32") - tir.store(B.data, 0, tir.load("float32", A.data, 0)) +@T.prim_func +def rank0_block(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + B = T.alloc_buffer((), "float32") + T.store(B.data, 0, T.load("float32", A.data, 0)) - with tir.block([], "update") as []: - tir.reads([A[()]]) - tir.writes([B[()]]) + with T.block([], "update") as []: + T.reads([A[()]]) + T.writes([B[()]]) for i in range(1): B[()] = A[()] def test_rank0_blocks(): func = rank0_block - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) -@tvm.script.tir -def select(a: ty.handle) -> None: - A = tir.match_buffer(a, (), "float32") - A[()] = tir.Select(True, 1, 2) +@T.prim_func +def select(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.Select(True, 1, 2) def test_select(): func = select - rt_func = 
tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) -@tvm.script.tir -def minmax(a: ty.handle) -> None: - A = tir.match_buffer(a, (), "float32") - A[()] = tir.min(1, 2) - A[()] = tir.max(1, 2) +@T.prim_func +def minmax(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.min(1, 2) + A[()] = T.max(1, 2) def test_minmax(): func = minmax - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) -@tvm.script.tir -def abs(a: ty.handle) -> None: - A = tir.match_buffer(a, (128, 128), "float32") +@T.prim_func +def abs(a: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") - with tir.block([128, 128], "A") as [vi, vj]: - A[vi, vj] = tir.abs(A[vi, vj]) + with T.block([128, 128], "A") as [vi, vj]: + A[vi, vj] = T.abs(A[vi, vj]) def test_abs(): func = abs - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) -@tvm.script.tir -def constant_folding(a: ty.handle) -> None: - A = tir.match_buffer(a, (), "float32") - A[()] = tir.min(2.2, 5.2) - A[()] = tir.max(tir.float32(2.2), tir.float32(tir.float32(5.2))) - A[()] = tir.min(2.2, 5.0) +@T.prim_func +def constant_folding(a: T.handle) -> None: + A = T.match_buffer(a, (), "float32") + A[()] = T.min(2.2, 5.2) + A[()] = T.max(T.float32(2.2), T.float32(T.float32(5.2))) + A[()] = T.min(2.2, 5.0) def test_constant_folding(): func = constant_folding - rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + rt_func = tvm.script.from_source(func.script(True)) tvm.ir.assert_structural_equal(func, rt_func) -@tvm.script.tir +@T.prim_func def simplify_bracket() -> None: - a = tir.var("int32") - b = tir.var("int32") - c = tir.var("int32") - d = tir.var("int32") - tir.evaluate(a + 
b * (c + d)) + a = T.var("int32") + b = T.var("int32") + c = T.var("int32") + d = T.var("int32") + T.evaluate(a + b * (c + d)) def test_simplify_bracket(): func = simplify_bracket - out_str = tvm.script.asscript(func, True) - assert out_str.count("a + b*(c + d)") == 1 + out_str = func.script(True) + assert out_str.count("a + b * (c + d)") == 1 -@tvm.script.tir -def var_with_same_name(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - with tir.block([16, 16]) as [vi, vj]: +@T.prim_func +def var_with_same_name(a: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32") + with T.block([16, 16]) as [vi, vj]: A[vi, vj] = 0 - with tir.block([16, 16]) as [vi, vj]: + with T.block([16, 16]) as [vi, vj]: A[vi, vj] = 0 - for i, j in tir.grid(16, 16): - with tir.block([16, 16]) as [vi, vj]: + for i, j in T.grid(16, 16): + with T.block([16, 16]) as [vi, vj]: A[vi, vj] = 0 - for i, j in tir.grid(16, 16): - with tir.block([16, 16]) as [vi, vj]: + for i, j in T.grid(16, 16): + with T.block([16, 16]) as [vi, vj]: A[vi, vj] = 0 def test_same_name_var(): func = var_with_same_name - out_str = tvm.script.asscript(func, True) + out_str = func.script(True) rt_func = tvm.script.from_source(out_str) tvm.ir.assert_structural_equal(func, rt_func) - assert out_str.count("with tir.block([16, 16]) as [vi, vj]") == 4 + assert out_str.count("with T.block([16, 16]) as [vi, vj]") == 4 assert out_str.find("vi_") == -1 assert out_str.find("vj_") == -1 - assert out_str.count("for i0, i1 in tir.grid(16, 16)") == 2 + assert out_str.count("for i0, i1 in T.grid(16, 16)") == 2 assert out_str.find("i0_") == -1 assert out_str.find("i1_") == -1 - assert out_str.count("for i, j in tir.grid(16, 16)") == 2 + assert out_str.count("for i, j in T.grid(16, 16)") == 2 assert out_str.find("i_") == -1 assert out_str.find("i_") == -1 -@tvm.script.tir -def while_loop(a: ty.handle, b: ty.handle) -> None: - A = tir.match_buffer(a, (16,), "float32") - B = tir.match_buffer(b, (16,), "float32") - i 
= tir.alloc_buffer((), "int32", scope="local") - with tir.block([16]) as [vi]: +@T.prim_func +def while_loop(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (16,), "float32") + B = T.match_buffer(b, (16,), "float32") + i = T.alloc_buffer((), "int32", scope="local") + with T.block([16]) as [vi]: B[vi] = 0 while i[()] < 10: for j in range(16): @@ -3079,7 +3055,7 @@ def while_loop(a: ty.handle, b: ty.handle) -> None: def test_while_loop(): - rt_func = tvm.script.from_source(tvm.script.asscript(while_loop, True)) + rt_func = tvm.script.from_source(while_loop.script(True)) tvm.ir.assert_structural_equal(while_loop, rt_func) diff --git a/tests/python/unittest/test_tvmscript_spans.py b/tests/python/unittest/test_tvmscript_spans.py index 612389538d7e5..f863a4dd983e4 100644 --- a/tests/python/unittest/test_tvmscript_spans.py +++ b/tests/python/unittest/test_tvmscript_spans.py @@ -15,29 +15,27 @@ # specific language governing permissions and limitations # under the License. -import inspect -import tvm -import tvm.script -from tvm import tir +from tvm.script import tir as T +@T.prim_func def loops() -> None: - for i in tir.parallel(0, 2): - for j in tir.serial(0, 1): - for z in tir.vectorized(3, 4): - tir.evaluate(0) + for i in T.parallel(0, 2): + for j in T.serial(0, 1): + for z in T.vectorized(3, 4): + T.evaluate(0) def test_loops(): - _, start_line = inspect.getsourcelines(loops) - parsed = tvm.script.tir(loops) + start_line = 23 + parsed = loops assert parsed.span.line == start_line assert parsed.body.span.line == start_line + 1 - assert parsed.body.min.span.column == 27 - assert parsed.body.extent.span.column == 30 + assert parsed.body.min.span.column == 25 + assert parsed.body.extent.span.column == 28 assert parsed.body.extent.span.line == start_line + 1 assert parsed.body.body.span.line == start_line + 2 @@ -51,14 +49,15 @@ def test_loops(): assert parsed.body.body.body.body.span.column == 17 +@T.prim_func def statements() -> None: - tir.evaluate(1) - 
tir.evaluate("test") + T.evaluate(1) + T.evaluate("test") def test_statements(): - _, start_line = inspect.getsourcelines(statements) - parsed = tvm.script.tir(statements) + start_line = 53 + parsed = statements assert parsed.body.span.line == start_line + 1 diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index 01d5587e70ada..7138effe395a4 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -30,7 +30,7 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.4.0 +python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.4.1 # Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in # Jenkinsfile. We expect config.cmake to be present from pack_lib().