diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py new file mode 100644 index 000000000000..c815282b74fc --- /dev/null +++ b/python/tvm/tir/schedule/_type_checker.py @@ -0,0 +1,193 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Type checking functionality""" +import functools +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import typing + + +def _is_none_type(type_: Any) -> bool: + return type_ is None or type_ is type(None) + + +if hasattr(typing, "_GenericAlias"): + + class _Subtype: + @staticmethod + def _origin(type_: Any) -> Any: + if isinstance(type_, typing._GenericAlias): # type: ignore # pylint: disable=protected-access + return type_.__origin__ + return None + + @staticmethod + def list_(type_: Any) -> Any: + if _Subtype._origin(type_) is list: + (subtype,) = type_.__args__ + return [subtype] + return None + + @staticmethod + def optional(type_: Any) -> Optional[List[type]]: + if _Subtype._origin(type_) is Union: + subtypes = type_.__args__ + if len(subtypes) == 2 and _is_none_type(subtypes[1]): + return [subtypes[0]] + return None + + @staticmethod + def union(type_: Any) -> Optional[List[type]]: + if _Subtype._origin(type_) is Union: + subtypes = type_.__args__ + if len(subtypes) != 2 or not _is_none_type(subtypes[1]): + return list(subtypes) + return None + + +elif hasattr(typing, "_Union"): + + class _Subtype: # type: ignore + @staticmethod + def list_(type_: Any) -> Optional[List[type]]: + if isinstance(type_, typing.GenericMeta): # type: ignore # pylint: disable=no-member + if type_.__name__ == "List": + (subtype,) = type_.__args__ # type: ignore # pylint: disable=no-member + return [subtype] + return None + + @staticmethod + def optional(type_: Any) -> Optional[List[type]]: + if isinstance(type_, typing._Union): # type: ignore # pylint: disable=no-member,protected-access + subtypes = type_.__args__ + if len(subtypes) == 2 and _is_none_type(subtypes[1]): + return [subtypes[0]] + return None + + @staticmethod + def union(type_: Any) -> Optional[List[type]]: + if isinstance(type_, typing._Union): # type: ignore # pylint: disable=no-member,protected-access + subtypes = type_.__args__ + if len(subtypes) != 2 or not _is_none_type(subtypes[1]): + return list(subtypes) + return None + + +def _dispatcher(type_: Any) -> Tuple[str, List[type]]: + if _is_none_type(type_): + return "none", [] + + subtype = _Subtype.list_(type_) + if subtype is not None: + return "list", subtype + + subtype = _Subtype.optional(type_) + if subtype is not None: + return "optional", subtype + + subtype = _Subtype.union(type_) + if subtype is not None: + return "union", subtype + + return "atomic", [type_] + + +_TYPE2STR: Dict[Any, Callable] = { + "none": lambda: "None", + "atomic": lambda t: str(t.__name__), + "list": lambda t: f"List[{_type2str(t)}]", + "optional": lambda t: f"Optional[{_type2str(t)}]", + "union": lambda *t: f"Union[{', '.join([_type2str(x) for x in t])}]", +} + + +def _type2str(type_: Any) -> str: + key, subtypes = _dispatcher(type_) + return _TYPE2STR[key](*subtypes) + + +def _type_check_err(x: Any, name: str, expected: Any) -> str: + return ( + f'"{name}" has wrong type. ' + f'Expected "{_type2str(expected)}", ' + f'but gets: "{_type2str(type(x))}"' + ) + + +def _type_check_vtable() -> Dict[str, Callable]: + def _type_check_none(v: Any, name: str) -> Optional[str]: + return None if v is None else _type_check_err(v, name, None) + + def _type_check_atomic(v: Any, name: str, type_: Any) -> Optional[str]: + return None if isinstance(v, type_) else _type_check_err(v, name, type_) + + def _type_check_list(v: List[Any], name: str, type_: Any) -> Optional[str]: + if not isinstance(v, (list, tuple)): + return _type_check_err(v, name, list) + for i, x in enumerate(v): + error_msg = _type_check(x, f"{name}[{i}]", type_) + if error_msg is not None: + return error_msg + return None + + def _type_check_optional(v: Any, name: str, type_: Any) -> Optional[str]: + return None if v is None else _type_check(v, name, type_) + + def _type_check_union(v: Any, name: str, *types: Any) -> Optional[str]: + for type_ in types: + error_msg = _type_check(v, name, type_) + if error_msg is None: + return None + return _type_check_err(v, name, types) + + return { + "none": _type_check_none, + "atomic": _type_check_atomic, + "list": _type_check_list, + "optional": _type_check_optional, + "union": _type_check_union, + } + + +_TYPE_CHECK: Dict[Any, Callable] = _type_check_vtable() + + +def _type_check(v: Any, name: str, type_: Any) -> Optional[str]: + key, subtypes = _dispatcher(type_) + return _TYPE_CHECK[key](v, name, *subtypes) + + +def type_checked(func: Callable) -> Callable: + """Type check the input arguments of a function.""" + sig = inspect.signature(func) + + @functools.wraps(func) + def wrap(*args, **kwargs): + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + for param in sig.parameters.values(): + if param.annotation != inspect.Signature.empty: + error_msg = _type_check( + bound_args.arguments[param.name], + param.name, + param.annotation, + ) + if error_msg is not None: + error_msg = f'In "{func.__qualname__}", {error_msg}' + raise TypeError(error_msg) + return func(*args, **kwargs) + + return wrap diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 884eeb7c612c..a4bb544557d5 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -26,6 +26,7 @@ from . import _ffi_api from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod from .trace import Trace +from ._type_checker import type_checked @register_error @@ -104,6 +105,7 @@ class Schedule(Object): Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html """ + @type_checked def __init__( self, mod: Union[PrimFunc, IRModule], @@ -198,6 +200,7 @@ def copy(self) -> "Schedule": """ return _ffi_api.ScheduleCopy(self) # type: ignore # pylint: disable=no-member + @type_checked def seed(self, seed: int) -> None: """Seed the randomness @@ -218,6 +221,7 @@ def fork_seed(self) -> int: """ return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member + @type_checked def show(self, rand_var: RAND_VAR_TYPE) -> str: """Returns a string representation of the value that the random variable evaluates to @@ -235,6 +239,7 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str: ########## Lookup ########## + @type_checked def get( self, rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef], @@ -263,6 +268,7 @@ def get( result = result.value return result + @type_checked def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]: """Returns the corresponding sref to the given 1) LoopRV @@ -284,6 +290,7 @@ def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Opti self, rand_var_or_stmt ) + @type_checked def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None: """Remove a random variable from the symbol table @@ -296,6 +303,7 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None: ########## Schedule: Sampling ########## + @type_checked def sample_categorical( self, candidates: List[int], @@ -325,6 +333,7 @@ def sample_categorical( decision, ) + @type_checked def sample_perfect_tile( self, loop: LoopRV, @@ -350,15 +359,18 @@ def sample_perfect_tile( result : List[ExprRV] A list of length `n`, the random perfect tile sizes sampled """ - return _ffi_api.ScheduleSamplePerfectTile( # type: ignore # pylint: disable=no-member - self, - loop, - n, - max_innermost_factor, - decision, + return list( + _ffi_api.ScheduleSamplePerfectTile( # type: ignore # pylint: disable=no-member + self, + loop, + n, + max_innermost_factor, + decision, + ) ) ########## Schedule: Get blocks & loops ########## + @type_checked def get_block( self, name: str, @@ -385,6 +397,7 @@ def get_block( func_name, ) + @type_checked def get_loops(self, block: BlockRV) -> List[LoopRV]: """Get the parent loops of the block in its scope, from outer to inner @@ -398,8 +411,9 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: loops : List[LoopRV] A list of loops above the given block in its scope, from outer to inner """ - return _ffi_api.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member + return list(_ffi_api.ScheduleGetLoops(self, block)) # type: ignore # pylint: disable=no-member + @type_checked def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockRV]: """Get the leaf blocks of a specific block/loop @@ -413,8 +427,9 @@ def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockR blocks : List[LoopRV] A list of leaf blocks inside a specific block/loop """ - return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # type: ignore # pylint: disable=no-member + return list(_ffi_api.ScheduleGetChildBlocks(self, block_or_loop)) # type: ignore # pylint: disable=no-member + @type_checked def get_producers(self, block: BlockRV) -> List[BlockRV]: """Get the producers of a specific block @@ -428,8 +443,9 @@ def get_producers(self, block: BlockRV) -> List[BlockRV]: producers : List[BlockRV] A list of producers of the given block """ - return _ffi_api.ScheduleGetProducers(self, block) # type: ignore # pylint: disable=no-member + return list(_ffi_api.ScheduleGetProducers(self, block)) # type: ignore # pylint: disable=no-member + @type_checked def get_consumers(self, block: BlockRV) -> List[BlockRV]: """Get the consumers of a specific block @@ -443,9 +459,10 @@ def get_consumers(self, block: BlockRV) -> List[BlockRV]: consumers : List[BlockRV] A list of consumers of the given block """ - return _ffi_api.ScheduleGetConsumers(self, block) # type: ignore # pylint: disable=no-member + return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member ########## Schedule: Transform loops ########## + @type_checked def fuse(self, *loops: List[LoopRV]) -> LoopRV: """Fuse a list of consecutive loops into one. It requires: 1) The loops can't have annotations or thread bindings. @@ -506,10 +523,11 @@ def after_fuse(a: T.handle, b: T.handle) -> None: """ return _ffi_api.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member + @type_checked def split( self, loop: LoopRV, - factors: List[Union[ExprRV, None]], + factors: List[Union[int, ExprRV, None]], ) -> List[LoopRV]: """Split a loop into a list of consecutive loops. It requires: 1) The loop can't have annotation or thread binding. @@ -523,12 +541,12 @@ def split( loop : LoopRV The loop to be split - factors: List[Union[ExprRV, None]] + factors: List[Union[int, ExprRV, None]] The splitting factors Potential inputs are: - None - ExprRV - - Nonnegative constant integers + - Non-negative constant integers Returns ------- @@ -578,8 +596,9 @@ def after_split(a: T.handle, b: T.handle) -> None: """ # it will be checked later in C++ implementation # that there is at most one None in `factors` - return _ffi_api.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member + return list(_ffi_api.ScheduleSplit(self, loop, factors)) # type: ignore # pylint: disable=no-member + @type_checked def reorder(self, *ordered_loops: List[LoopRV]) -> None: """ Reorder a list of loops. It doesn't require the loops to be consecutive. @@ -641,6 +660,7 @@ def after_reorder(a: T.handle, b: T.handle) -> None: ########## Schedule: Manipulate ForKind ########## + @type_checked def parallel(self, loop: LoopRV) -> None: """Parallelize the input loop. It requires: 1) The scope block that the loop is in should have stage-pipeline property @@ -695,6 +715,7 @@ def after_parallel(a: T.handle, b: T.handle) -> None: """ _ffi_api.ScheduleParallel(self, loop) # type: ignore # pylint: disable=no-member + @type_checked def vectorize(self, loop: LoopRV) -> None: """Vectorize the input loop. It requires: 1) The scope block that the loop is in should have stage-pipeline property @@ -749,6 +770,7 @@ def after_vectorize(a: T.handle, b: T.handle) -> None: """ _ffi_api.ScheduleVectorize(self, loop) # type: ignore # pylint: disable=no-member + @type_checked def bind(self, loop: LoopRV, thread_axis: str) -> None: """Bind the input loop to the given thread axis. It requires: 1) The scope block that the loop is in should have stage-pipeline property @@ -812,6 +834,7 @@ def after_bind(a: T.handle, b: T.handle) -> None: """ _ffi_api.ScheduleBind(self, loop, thread_axis) # type: ignore # pylint: disable=no-member + @type_checked def unroll(self, loop: LoopRV) -> None: """Unroll the input loop. It requires nothing @@ -863,6 +886,7 @@ def after_unroll(a: T.handle, b: T.handle) -> None: ########## Schedule: Insert cache stages ########## + @type_checked def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV: """Create a block that reads a buffer region into a read cache. It requires: @@ -933,6 +957,7 @@ def after_cache_read(a: T.handle, b: T.handle) -> None: self, block, read_buffer_index, storage_scope ) + @type_checked def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV: """Create a block that reads a buffer region into a write cache. It requires: @@ -1006,6 +1031,7 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: ########## Schedule: Compute location ########## + @type_checked def compute_at( self, block: BlockRV, @@ -1098,6 +1124,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None: preserve_unit_loops, ) + @type_checked def reverse_compute_at( self, block: BlockRV, @@ -1187,6 +1214,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None: preserve_unit_loops, ) + @type_checked def compute_inline(self, block: BlockRV) -> None: """Inline a block into its consumer(s). It requires: @@ -1250,6 +1278,7 @@ def after_inline(a: T.handle, c: T.handle) -> None: """ _ffi_api.ScheduleComputeInline(self, block) # type: ignore # pylint: disable=no-member + @type_checked def reverse_compute_inline(self, block: BlockRV) -> None: """Inline a block into its only producer. It requires: @@ -1318,6 +1347,7 @@ def after_inline(a: T.handle, c: T.handle) -> None: ########## Schedule: Reduction ########## + @type_checked def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV: """Decompose a reduction block into two separate blocks. @@ -1394,6 +1424,7 @@ def after_decompose(a: ty.handle, c: ty.handle) -> None: """ return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member + @type_checked def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV: """Factorize an associative reduction block by the specified loop. @@ -1544,6 +1575,7 @@ def after_rfactor(a: T.handle, b: T.handle) -> None: ######## Schedule: Block annotation ######## + @type_checked def storage_align( # pylint: disable=too-many-arguments self, block: BlockRV, @@ -1634,6 +1666,7 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: ########## Schedule: Misc ########## + @type_checked def enter_postproc(self) -> None: """A no-op that marks the start of postprocessing phase of scheduling""" _ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 035c16f506cf..e4df5f893ae9 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -204,7 +204,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { << (obj.defined() ? obj->GetTypeKey() : "None"); } if (sref->stmt == nullptr) { - LOG(FATAL) << "ValueError: The StmtSRef has expired"; + LOG(FATAL) << "ValueError: The block no longer exists in the IRModule"; } return GetRef(sref); } @@ -229,7 +229,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { << (obj.defined() ? obj->GetTypeKey() : "None"); } if (sref->stmt == nullptr) { - LOG(FATAL) << "ValueError: The StmtSRef has expired"; + LOG(FATAL) << "ValueError: The loop no longer exists in the IRModule"; } return GetRef(sref); }