From 2ba9aad5cf5b5b4aaec71d67c2f704426f04c48a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 9 Nov 2021 00:52:51 -0800 Subject: [PATCH] [MetaSchedule] Multi-Level-Tiling & Auto-Inline (#503) --- include/tvm/meta_schedule/schedule_rule.h | 47 +- include/tvm/tir/stmt.h | 22 +- .../meta_schedule/schedule_rule/__init__.py | 4 +- .../schedule_rule/auto_inline.py | 71 +++ .../schedule_rule/multi_level_tiling.py | 84 ++++ .../schedule_rule/schedule_rule.py | 8 +- python/tvm/meta_schedule/testing/__init__.py | 2 + .../meta_schedule/testing/schedule_rule.py | 131 ++++++ .../tvm/meta_schedule/testing/te_workload.py | 89 +++- python/tvm/meta_schedule/tune_context.py | 7 +- .../schedule_rule/auto_inline.cc | 190 ++++++++ .../schedule_rule/multi_level_tiling.cc | 418 ++++++++++++++++++ .../space_generator/post_order_apply.cc | 3 +- src/meta_schedule/utils.h | 2 +- src/support/array.h | 23 + src/tir/schedule/analysis.h | 100 +++++ src/tir/schedule/analysis/analysis.cc | 238 ++++++++++ src/tir/schedule/concrete_schedule.h | 2 +- src/tir/schedule/primitive/compute_inline.cc | 63 ++- src/tir/schedule/primitive/for_kind.cc | 2 +- src/tir/schedule/primitive/sampling.cc | 8 +- src/tir/schedule/utils.h | 69 ++- .../test_meta_schedule_schedule_rule.py | 24 +- ...meta_schedule_schedule_rule_auto_inline.py | 17 + .../unittest/test_meta_schedule_sketch_cpu.py | 392 ++++++++++++++++ .../test_meta_schedule_sketch_cuda.py | 333 ++++++++++++++ .../test_tir_schedule_compute_inline.py | 28 ++ 27 files changed, 2323 insertions(+), 54 deletions(-) create mode 100644 python/tvm/meta_schedule/schedule_rule/auto_inline.py create mode 100644 python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py create mode 100644 python/tvm/meta_schedule/testing/schedule_rule.py create mode 100644 src/meta_schedule/schedule_rule/auto_inline.cc create mode 100644 src/meta_schedule/schedule_rule/multi_level_tiling.cc create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py create mode 100644 tests/python/unittest/test_meta_schedule_sketch_cpu.py create mode 100644 tests/python/unittest/test_meta_schedule_sketch_cuda.py diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 92aa46beeaf6..b9e6d8777449 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -76,11 +76,11 @@ class PyScheduleRuleNode : public ScheduleRuleNode { */ using FAsString = runtime::TypedPackedFunc; - /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ FInitializeWithTuneContext f_initialize_with_tune_context; - /*! \brief The packed function to the `Apply` funcion. */ + /*! \brief The packed function to the `Apply` function. */ FApply f_apply; - /*! \brief The packed function to the `AsString` funcion. */ + /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; void VisitAttrs(tvm::AttrVisitor* v) { @@ -110,6 +110,47 @@ class PyScheduleRuleNode : public ScheduleRuleNode { */ class ScheduleRule : public runtime::ObjectRef { public: + /*! 
+   * \brief Create an auto-inline rule that inlines spatial blocks if they satisfy some conditions
+   * \param into_producer Whether to allow inlining a block into its producer
+   * \param into_consumer Whether to allow inlining a block into its consumer
+   * \param into_cache_only Whether to only allow inlining into a block generated by cache_read/write
+   * \param inline_const_tensor Always inline constant tensors
+   * \param disallow_if_then_else Always disallow if-then-else-like constructs
+   * \param require_injective Always require the read-to-write mapping to be injective
+   * \param require_ordered Always require the read-to-write mapping to be ordered
+   * \param disallow_op The operators that are disallowed in auto inline
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule AutoInline(bool into_producer,          //
+                                         bool into_consumer,          //
+                                         bool into_cache_only,        //
+                                         bool inline_const_tensor,    //
+                                         bool disallow_if_then_else,  //
+                                         bool require_injective,      //
+                                         bool require_ordered,        //
+                                         Optional<Array<String>> disallow_op);
+  /*!
+   * \brief Create a mega rule: multi-level tiling with data reuse
+   * \param structure The tiling structure. Recommended:
+   * - 'SSRSRS' on CPU
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
+   * - NullOpt on CPU
+   * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
+   * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching.
+   * NullOpt means disable vectorization
+   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule MultiLevelTiling(String structure,                            //
+                                               Optional<Array<String>> tile_binds,          //
+                                               Optional<Integer> max_innermost_factor,      //
+                                               Optional<Integer> vector_load_max_len,       //
+                                               Optional<Map<String, ObjectRef>> reuse_read, //
+                                               Optional<Map<String, ObjectRef>> reuse_write);
   /*!
    * \brief Create a schedule rule with customized methods on the python-side.
    * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 4f5772822d9e..170fc8662e2a 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1224,7 +1224,7 @@ class BlockRealize : public Stmt {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
 };
 
-/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
+/*! \brief namespace of possible attributes in AttrStmt.attr_key */
 namespace attr {
 // The above attr does not pass to ir stage.
 /*! \brief Mark launching extent of thread, used by device API. */
@@ -1355,6 +1355,26 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_
  */
 constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
 
+/*!
+ * \brief Mark that the loop should be further split and bound to environment threads to enable
+ * cooperative fetching.
+ */
+constexpr const char* meta_schedule_lazy_cooperative_fetch = "meta_schedule.lazy_cooperative_fetch";
+
+/*!
+ * \brief Mark a block as generated by cache_read or cache_write.
+ * 0 means cache_read; 1 means cache_write.
+ * \sa meta_schedule_cache_type_read
+ * \sa meta_schedule_cache_type_write
+ */
+constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
+
+/*! \sa meta_schedule_cache_type */
+constexpr const int meta_schedule_cache_type_read = 0;
+
+/*!
\sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_write = 1; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index 34a7590b60c0..a00a8d861924 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -16,4 +16,6 @@ Meta Schedule schedule rules are used for modification of blocks in a schedule. See also PostOrderApply. """ -from .schedule_rule import ScheduleRule, PyScheduleRule +from .auto_inline import AutoInline +from .multi_level_tiling import MultiLevelTiling, ReuseType +from .schedule_rule import PyScheduleRule, ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py new file mode 100644 index 000000000000..83828586bfb2 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions""" +from typing import List, Optional + +from tvm._ffi import register_object + +from .. 
import _ffi_api
+from .schedule_rule import ScheduleRule
+
+
+@register_object("meta_schedule.AutoInline")
+class AutoInline(ScheduleRule):
+    """Rule that inlines spatial blocks if they satisfy some conditions
+
+    Parameters
+    ----------
+    into_producer : bool
+        Whether to allow inlining a block into its producer
+    into_consumer : bool
+        Whether to allow inlining a block into its consumer
+    into_cache_only : bool
+        Whether to only allow inlining into a block generated by cache_read/write
+    inline_const_tensor : bool
+        Always inline constant tensors
+    disallow_if_then_else : bool
+        Always disallow if-then-else-like constructs
+    require_injective : bool
+        Always require the read-to-write mapping to be injective
+    require_ordered : bool
+        Always require the read-to-write mapping to be ordered
+    disallow_op : Optional[List[str]]
+        The operators that are disallowed in auto inline
+    """
+
+    def __init__(
+        self,
+        into_producer: bool,
+        into_consumer: bool,
+        into_cache_only: bool,
+        inline_const_tensor: bool,
+        disallow_if_then_else: bool,
+        require_injective: bool,
+        require_ordered: bool,
+        disallow_op: Optional[List[str]] = None,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleAutoInline,  # type: ignore # pylint: disable=no-member
+            into_producer,
+            into_consumer,
+            into_cache_only,
+            inline_const_tensor,
+            disallow_if_then_else,
+            require_injective,
+            require_ordered,
+            disallow_op,
+        )
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
new file mode 100644
index 000000000000..669ede242e06
--- /dev/null
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Multi-level tiling with reuse."""
+from typing import Any, Dict, List, Literal, NamedTuple, Optional
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .schedule_rule import ScheduleRule
+
+
+class ReuseType(NamedTuple):
+    """Reuse type."""
+
+    req: Literal["no", "may", "must"]
+    levels: List[int]
+    scope: str
+
+    def as_dict(self) -> Dict[str, Any]:
+        """Return the dict representation of the reuse type."""
+        return {
+            "req": self.req,
+            "levels": self.levels,
+            "scope": self.scope,
+        }
+
+
+@register_object("meta_schedule.MultiLevelTiling")
+class MultiLevelTiling(ScheduleRule):
+    """Multi-level tiling with reuse.
+
+    Parameters
+    ----------
+    structure : str
+        The tiling structure. Recommended:
+        - 'SSRSRS' on CPU
+        - 'SSSRRSRS' on GPU
+    tile_binds : Optional[List[str]]
+        For each level of tiles, which thread axis it is bound to. Recommended:
+        - None on CPU
+        - [blockIdx.x, vthread.x, threadIdx.x] on GPU
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor.
None means no limit + vector_load_max_len : Optional[int] + The length of vector lane in vectorized cooperative fetching. + None means disable vectorization + reuse_read : Optional[ReuseType] + Data reuse configuration for reading. None means no reuse. + reuse_write : Optional[ReuseType] + Data reuse configuration for writing. None means no reuse. + """ + + def __init__( + self, + structure: str, + tile_binds: Optional[List[str]] = None, + max_innermost_factor: Optional[int] = None, + vector_load_max_len: Optional[int] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member + structure, + tile_binds, + max_innermost_factor, + vector_load_max_len, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index ec101410f671..b995e5acb6fc 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -47,7 +47,7 @@ def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: self, tune_context ) - def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: """Apply a schedule rule to the specific block in the given schedule. Parameters @@ -62,7 +62,9 @@ def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]: design_spaces : List[Schedule] The list of schedules generated by applying the schedule rule. """ - return _ffi_api.ScheduleRuleApply(self, schedule, block) + return _ffi_api.ScheduleRuleApply( # type: ignore # pylint: disable=no-member + self, sch, block + ) @register_object("meta_schedule.PyScheduleRule") @@ -91,4 +93,4 @@ def f_as_string() -> str: ) def __str__(self) -> str: - return f"PyScheduleRule({_get_hex_address(self.handle)})" + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 6a7b27b1f070..b64891a3858d 100644 --- a/python/tvm/meta_schedule/testing/__init__.py +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Testing utilities in meta schedule""" +from . import te_workload +from . import schedule_rule from .local_rpc import LocalRPC from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model from .te_workload import create_te_workload diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py new file mode 100644 index 000000000000..f92768d14251 --- /dev/null +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default schedule rules""" +from typing import List + +from tvm.meta_schedule.schedule_rule import ( + AutoInline, + MultiLevelTiling, + ReuseType, + ScheduleRule, +) +from tvm.target import Target + + +def get(target: Target) -> List[ScheduleRule]: + """Default schedule rules""" + if target.kind.name == "llvm": + return [ + auto_inline(target), + multi_level_tiling(target), + ] + if target.kind.name == "cuda": + return [ + auto_inline(target), + multi_level_tiling(target), + auto_inline_after_tiling(target), + ] + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def auto_inline(target: Target) -> ScheduleRule: + """Default schedule rules for auto inline""" + if target.kind.name == "llvm": + return AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ) + if target.kind.name == "cuda": + return AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def auto_inline_after_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for auto inline after tiling""" + if target.kind.name == "llvm": + return AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ) + if target.kind.name == "cuda": + return AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def multi_level_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for with multi-level tiling and reuse""" + if target.kind.name == "llvm": + return MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_max_len=None, + reuse_read=None, + reuse_write=ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ) + if target.kind.name == "cuda": + return MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_max_len=4, + reuse_read=ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ReuseType( + req="must", + levels=[3], + scope="local", + ), + ) + raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py index e146750e259b..d57bea86e44b 100644 --- a/python/tvm/meta_schedule/testing/te_workload.py +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under 
the License. """Workloads in TE""" +# pylint: disable=missing-docstring from typing import Tuple + from tvm import te, tir, topi @@ -575,6 +577,92 @@ def conv2d_winograd_nhwc( # pylint: disable=invalid-name,missing-docstring return (inputs, kernel_pack, output) +def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((k, m), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + return (a, b, c) + + +def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((m, k), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + d = topi.nn.relu(c) # pylint: disable=invalid-name + return (a, b, d) + + +def conv2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + return (x, w, y) + + +def conv2d_nchw_bias_bn_relu( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + oh = (h + 2 * padding - (kh - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + ow = (w + 2 * padding - (kw - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + b = te.placeholder((co, 1, 1), name="B") + bn_scale = te.placeholder((co, 1, 1), name="bn_scale") + bn_offset = te.placeholder((co, 1, 1), name="bn_offset") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + y = te.compute((n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + b[j, 0, 0], name="bias_add") + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] * bn_scale[j, 0, 0], name="bn_mul" + ) + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + bn_offset[j, 0, 0], name="bn_add" + ) + y = topi.nn.relu(y) + return (x, w, b, bn_scale, bn_offset, y) + + + +def max_pool2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + padding: int, +) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + y = topi.nn.pool2d(x, [2, 2], [1, 1], [1, 1], [padding, padding, padding, padding], "max") + return (x, y) + def create_te_workload(name: str, idx: int) -> tir.PrimFunc: workload_func, params = CONFIGS[name] @@ -741,4 +829,3 @@ def create_te_workload(name: str, idx: int) -> tir.PrimFunc: ], ), } - diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index af219086395a..fe56198f0dbd 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,13 +16,14 @@ # under the License. 
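# A minimal usage sketch of the rules added above (illustrative assumptions: the block name
# "C" comes from te_workload.matmul, and the rule is applied directly to a TIR schedule
# rather than through a full tuning pipeline):
from tvm import te, tir
from tvm.meta_schedule.schedule_rule import MultiLevelTiling, ReuseType
from tvm.meta_schedule.testing import te_workload

rule = MultiLevelTiling(
    structure="SSRSRS",  # CPU-style tiling structure
    max_innermost_factor=64,
    reuse_write=ReuseType(req="may", levels=[1, 2], scope="global"),
)
sch = tir.Schedule(te.create_prim_func(te_workload.matmul(512, 512, 512)))
design_spaces = rule.apply(sch, sch.get_block("C"))  # list of candidate schedules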
"""Meta Schedule tuning context.""" -from typing import TYPE_CHECKING, Optional, List +from typing import List, Optional, TYPE_CHECKING from tvm import IRModule from tvm._ffi import register_object from tvm.meta_schedule.utils import cpu_count from tvm.runtime import Object from tvm.target import Target +from tvm.tir import PrimFunc from . import _ffi_api @@ -114,7 +115,7 @@ def __init__( sch_rules : List[ScheduleRule] = [] The schedule rules. postproc : List[Postproc] = [] - The post processings. + The post-processors. mutator : List[Mutator] = [] The mutators. task_name : Optional[str] = None @@ -125,6 +126,8 @@ def __init__( num_threads : Optional[int] = None The number of threads to be used, None means using the logical cpu count. """ + if isinstance(mod, PrimFunc): + mod = IRModule.from_expr(mod) if num_threads is None: num_threads = cpu_count() diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc new file mode 100644 index 000000000000..711401591c3d --- /dev/null +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::Schedule; + +/*! \brief The type of inline to be performed on a specific block */ +enum class InlineType : int32_t { + /*! \brief No inline opportunity */ + kNoInline = 0, + /*! \brief Inline the block into its consumer */ + kInlineIntoConsumer = 1, + /*! \brief Inline the block into its producer */ + kInlineIntoProducer = 2, +}; + +/*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */ +class AutoInlineNode : public ScheduleRuleNode { + public: + /*! \brief Checks if the specific block should be inlined */ + inline InlineType CheckInline(const Schedule& sch, const BlockRV& block_rv); + + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from ScheduleRuleNode + Array Apply(const Schedule& sch, const BlockRV& block_rv) final { + InlineType inline_type = CheckInline(sch, block_rv); + if (inline_type == InlineType::kInlineIntoConsumer) { + sch->ComputeInline(block_rv); + } else if (inline_type == InlineType::kInlineIntoProducer) { + sch->ReverseComputeInline(block_rv); + } + return {sch}; + } + + public: + /*! \brief If allows to inline a block into its producer */ + bool into_producer; + /*! \brief If allows to inline a block into its consumer */ + bool into_consumer; + /*! \brief If it only allows to inline into a block generated by cache_read/write */ + bool into_cache_only; + /*! \brief Always inline constant tensors */ + bool inline_const_tensor; + /*! 
\brief Always disallow if-then-else-like constructs */ + bool disallow_if_then_else; + /*! \brief Always require the read-to-write mapping to be injective to do auto inline */ + bool require_injective; + /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */ + bool require_ordered; + /*! \brief The operators that are disallowed in auto inline */ + Array disallow_op; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("into_producer", &into_producer); + v->Visit("into_consumer", &into_consumer); + v->Visit("into_cache_only", &into_cache_only); + v->Visit("inline_const_tensor", &inline_const_tensor); + v->Visit("disallow_if_then_else", &disallow_if_then_else); + v->Visit("require_injective", &require_injective); + v->Visit("require_ordered", &require_ordered); + v->Visit("disallow_op", &disallow_op); + } + + static constexpr const char* _type_key = "meta_schedule.AutoInline"; + TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode); +}; + +inline InlineType AutoInlineNode::CheckInline(const Schedule& sch, const BlockRV& block_rv) { + using namespace tvm::tir; + StmtSRef block_sref = sch->GetSRef(block_rv); + ScheduleState state = sch->state(); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + BlockRealize realize = GetBlockRealize(state, block_sref); + // Cond 1. The block has only one write buffer + if (block->writes.size() != 1) { + return InlineType::kNoInline; + } + // Cond 2. The block is a spatial block + if (!IsSpatial(block_sref)) { + return InlineType::kNoInline; + } + // Cond 3. For a block that generates a constant tensor, ignore all other conditions + if (inline_const_tensor && block->reads.empty()) { + return InlineType::kInlineIntoConsumer; + } + // Cond 4. The block doesn't contain any disallowed operators + if (!disallow_op.empty() && HasOp(realize, disallow_op)) { + return InlineType::kNoInline; + } + // Cond 5. The block doesn't have any if-then-else-like constructs + if (disallow_if_then_else && HasIfThenElse(realize)) { + return InlineType::kNoInline; + } + // Cond 6. 
The mapping from read indices to write indices are injective and ordered + if (require_injective || require_ordered) { + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool injective, ordered; + constexpr auto _ = std::ignore; + std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_, + /*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region); + if (require_injective && injective == false) { + return InlineType::kNoInline; + } + if (require_ordered && ordered == false) { + return InlineType::kNoInline; + } + } + } + // Last cond: Check inline into the spatial consumer or the spatial producer + if (into_consumer) { + Array consumer_srefs = GetConsumers(state, block_sref); + if (consumer_srefs.size() == 1 && IsSpatial(consumer_srefs[0])) { + if (!into_cache_only || + tir::GetAnn(consumer_srefs[0], tir::attr::meta_schedule_cache_type).defined()) { + if (CanComputeInline(state, block_sref)) { + return InlineType::kInlineIntoConsumer; + } + } + } + } + if (into_producer) { + Array producer_srefs = GetProducers(state, block_sref); + if (producer_srefs.size() == 1 && IsSpatial(producer_srefs[0])) { + if (!into_cache_only || + tir::GetAnn(producer_srefs[0], tir::attr::meta_schedule_cache_type).defined()) { + if (CanReverseComputeInline(state, block_sref)) { + return InlineType::kInlineIntoProducer; + } + } + } + } + return InlineType::kNoInline; +} + +ScheduleRule ScheduleRule::AutoInline(bool into_producer, // + bool into_consumer, // + bool into_cache_only, // + bool inline_const_tensor, // + bool disallow_if_then_else, // + bool require_injective, // + bool require_ordered, // + Optional> disallow_op) { + ObjectPtr n = make_object(); + n->into_producer = into_producer; + n->into_consumer = into_consumer; + n->into_cache_only = into_cache_only; + n->inline_const_tensor = inline_const_tensor; + n->disallow_if_then_else = disallow_if_then_else; + n->require_injective = require_injective; + n->require_ordered = require_ordered; + n->disallow_op.clear(); + if (disallow_op.defined()) { + Array op_names = disallow_op.value(); + n->disallow_op.reserve(op_names.size()); + for (const String& op_name : op_names) { + n->disallow_op.push_back(Op::Get(op_name)); + } + } + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(AutoInlineNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") + .set_body_typed(ScheduleRule::AutoInline); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc new file mode 100644 index 000000000000..dabfc805268e --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { +/*! + * \brief Get the buffer dimensions for all the read buffers of a block, but marks the reduction + * buffers' dimensions as -1 + * \param block_sref The block to be processed + * \return The buffer dimensions for all the read buffers of a block, except for reduction buffers + * \note The method is not designed for generic analysis and relies on assumptions in the scenario + * of multi-level tiling, so it's intentionally kept inside this file not in the analysis header + */ +std::vector GetReadBufferNDims(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + int n = block->reads.size(); + std::vector results(n, -1); + for (int i = 0; i < n; ++i) { + const BufferNode* read_buffer = block->reads[i]->buffer.get(); + if (read_buffer != write_buffer) { + results[i] = read_buffer->shape.size(); + } + } + return results; +} +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::ExprRV; +using tir::IterVarType; +using tir::LoopRV; +using tir::Schedule; + +/*! + * \brief Configuration of data reuse type: + * 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed. + * 1) kMayReuse: reuse is allowed, but no reuse is explored. + * 2) kMustReuse: reuse is allowed and no reuse is not explored. + */ +enum class ReuseType : int32_t { + kNoReuse = 0, + kMayReuse = 1, + kMustReuse = 2, +}; + +/*! + * \brief Converts a string to ReuseType. + * \param str The string to be converted. + * \return The converted ReuseType. + */ +ReuseType Str2ReuseType(const String& str) { + if (str == "no") { + return ReuseType::kNoReuse; + } else if (str == "may") { + return ReuseType::kMayReuse; + } else if (str == "must") { + return ReuseType::kMustReuse; + } else { + LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; + throw; + } +} + +/*! \brief Configuration of data reuse patterns */ +struct ReuseConfig { + /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */ + ReuseType req; + /*! \brief Which levels are caching stage inserted at */ + std::vector levels; + /*! \brief The storage scope */ + String scope; + + /*! \brief Default constructor: no data reuse */ + ReuseConfig() : req(ReuseType::kNoReuse) {} + + /*! \brief Construct from a configuration dictionary */ + explicit ReuseConfig(const Map& config) + : req(Str2ReuseType(Downcast(config.at("req")))), + levels(support::AsVector(Downcast>(config.at("levels")))), + scope(Downcast(config.at("scope"))) { + ICHECK_EQ(config.size(), 3); + } +}; + +/*! \brief The state of auto scheduling for the multi-level tiling rule */ +struct State { + /*! \brief The schedule to date */ + Schedule sch; + /*! \brief The block to be tiled */ + BlockRV block_rv; + /*! \brief The write cache */ + Optional write_cache; + /*! \brief Indicating if the write cache is generated by cache_write */ + bool write_cache_is_added; + /*! \brief The loop tiles */ + Array> tiles; + + /*! \brief Default constructor */ + explicit State(Schedule sch, BlockRV block_rv, Optional write_cache = NullOpt, + bool write_cache_is_added = false, Array> tiles = {}) + : sch(sch), + block_rv(block_rv), + write_cache(write_cache), + write_cache_is_added(write_cache_is_added), + tiles(tiles) {} +}; + +/*! 
+ * \brief Helper to apply a sub-rule to a list of auto scheduling states + * \tparam FLambda The type of the sub-rule functor + * \param states The list of states to be applied + * \return The list of states after applying the sub-rule + */ +template +std::vector SubRule(std::vector states, FLambda sub_rule) { + std::vector results; + for (auto&& state : states) { + std::vector next = sub_rule(std::move(state)); + results.insert(results.end(), + std::make_move_iterator(next.begin()), // + std::make_move_iterator(next.end())); + } + return results; +} + +/*! + * \brief The mega rule: multi-level tiling with data reuse + */ +class MultiLevelTilingNode : public ScheduleRuleNode { + public: + // SubRule 1. add write cache + inline std::vector AddWriteReuse(State state) const; + // SubRule 2. tile the loop nest + inline std::vector TileLoopNest(State state) const; + // SubRule 3. add read cache + inline std::vector AddReadReuse(State state) const; + // SubRule 4. fuse write cache + inline std::vector FuseWriteReuse(State state) const; + // Do nothing; Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Entry of the mega rule; Inherited from ScheduleRuleNode + Array Apply(const Schedule& sch, const BlockRV& block_rv) final { + if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { + return {sch}; + } + std::vector states{State(sch, block_rv)}; + states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return FuseWriteReuse(state); }); + Array results; + for (auto&& state : states) { + results.push_back(std::move(state.sch)); + } + return results; + } + + public: + /*! + * \brief The tiling structure. Recommended: + * - 'SSRSRS' on CPU + * - 'SSSRRSRS' on GPU + */ + String structure; + /*! \brief For each level of tiles, which thread axis it is bound to */ + Array tile_binds; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The length of vector lane in vectorized cooperative fetching */ + int vector_load_max_len; + /*! \brief Data reuse configuration for reading */ + ReuseConfig reuse_read_; + /*! \brief Data reuse configuration for writing */ + ReuseConfig reuse_write_; + /*! \brief The indices of spatial tiles in `structure` */ + std::vector s_indices_; + /*! \brief The indices of reduction tiles in `structure` */ + std::vector r_indices_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("structure", &structure); + v->Visit("tile_binds", &tile_binds); + v->Visit("max_innermost_factor", &max_innermost_factor); + v->Visit("vector_load_max_len", &vector_load_max_len); + // `reuse_read_` is not visited + // `reuse_write_` is not visited + // `s_indices_` is not visited + // `r_indices_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); +}; + +inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + // Case 1. If the write cache is already there, we don't need to add another. 
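+  // This can only happen when reuse is optional (`kMayReuse`): the single consumer of the tiled
+  // block is recognized as an existing write cache and reused as-is. Otherwise the sub-rule forks
+  // the state below: optionally one state without any write cache (again only under `kMayReuse`),
+  // plus one state where a write cache is created via cache_write and annotated with
+  // `meta_schedule_cache_type_write`.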
+ if (config.req == ReuseType::kMayReuse) { + Array consumer_rvs = state.sch->GetConsumers(state.block_rv); + if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) { + state.write_cache = consumer_rvs[0]; + state.write_cache_is_added = false; + return {std::move(state)}; + } + } + std::vector results; + results.reserve(2); + // Case 2. No write cache is added + if (config.req == ReuseType::kMayReuse) { + State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv, + /*write_cache=*/NullOpt, + /*write_cache_is_added=*/false); + new_state.sch->Seed(state.sch->ForkSeed()); + results.emplace_back(std::move(new_state)); + } + // Case 3. Add one write cache + state.write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, + /*storage_scope=*/config.scope); + { + tir::Annotate(state.sch->state(), state.sch->GetSRef(state.write_cache.value()), // + tir::attr::meta_schedule_cache_type, // + Integer(tir::attr::meta_schedule_cache_type_write)); + } + + state.write_cache_is_added = true; + results.emplace_back(std::move(state)); + return results; +} + +inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const { + Schedule& sch = state.sch; + const BlockRV& block_rv = state.block_rv; + // Step 1. Assuming trivial binding, pair the loops and their iter-var-types + Array loops = sch->GetLoops(block_rv); + std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv)); + ICHECK_EQ(loops.size(), iter_types.size()); + // Step 2. For each loop axis, tile it + std::vector> tiles(s_indices_.size() + r_indices_.size()); + for (int i = 0, n = loops.size(); i < n; ++i) { + const std::vector* idx = nullptr; + if (iter_types[i] == IterVarType::kDataPar) { + idx = &s_indices_; + } else if (iter_types[i] == IterVarType::kCommReduce) { + idx = &r_indices_; + } else { + continue; + } + // Do the split + int n_tiles = idx->size(); + LoopRV loop = loops[i]; + Array factors = sch->SamplePerfectTile( + /*loop=*/loop, + /*n=*/n_tiles, + /*max_innermost_factor=*/max_innermost_factor); + Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); + // Put every tile to its slot + for (int j = 0; j < n_tiles; ++j) { + tiles[idx->at(j)].push_back(splits[j]); + } + } + // Step 3. Reorder to organize the tiles + sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end())); + // Step 4. 
Bind the tiles to threads + int n_binds = std::min(tile_binds.size(), tiles.size()); + for (int i = 0; i < n_binds; ++i) { + LoopRV fused = sch->Fuse(tiles[i]); + sch->Bind(fused, tile_binds[i]); + tiles[i] = {fused}; + } + state.tiles = Array>{tiles.begin(), tiles.end()}; + return {state}; +} + +inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const { + const ReuseConfig& config = this->reuse_read_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + ICHECK(config.req != ReuseType::kMayReuse); + const BlockRV& block_rv = state.block_rv; + std::vector results; + results.reserve(config.levels.size()); + for (int level : config.levels) { + Schedule sch = state.sch->Copy(); + sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = state.tiles[level - 1].back(); + // Enumerate all buffers that are read but not written + std::vector read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv)); + for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) { + int buffer_ndim = read_buffer_ndims[i]; + if (buffer_ndim == -1) { + continue; + } + // Do cache_read + BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope); + { + tir::Annotate(sch->state(), sch->GetSRef(cache_read_block), // + tir::attr::meta_schedule_cache_type, + Integer(tir::attr::meta_schedule_cache_type_read)); + } + // Insert cache_read block to the proper place + sch->ComputeAt(cache_read_block, loop_rv, true); + // Fuse the iterators of the cache_read + Array buffer_loops = sch->GetLoops(cache_read_block); + LoopRV fused = sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // + buffer_loops.end()}); + // Do cooperative fetching + if (vector_load_max_len > 0) { + // cooperative fetch + vectorized loading + // Split into inner and outer + Array factors = sch->SamplePerfectTile(fused, 2, vector_load_max_len); + Array splits = sch->Split(fused, {factors[0], factors[1]}); + // Vectorize the inner loop + sch->Vectorize(splits[1]); + fused = splits[0]; + } + // Add cooperative fetching + sch->Annotate(fused, tir::attr::meta_schedule_lazy_cooperative_fetch, Integer(1)); + } + State new_state = state; + new_state.sch = sch; + results.push_back(std::move(new_state)); + } + return results; +} + +inline std::vector MultiLevelTilingNode::FuseWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + // If the only-consumer does not exist, or is not elementwise, then do not do fusion + if (!state.write_cache.defined()) { + return {std::move(state)}; + } + std::vector results; + // Special case. + // Stages added by `cache_write` must be fused at some level, otherwise it has no benefit. 
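+  // (whether the stage came from `cache_write` is tracked in `write_cache_is_added`, which is set
+  // by AddWriteReuse above and checked below)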
+ // On the other hand, If the consumer stage is not added by `cache_write`, + // we may choose not to fuse by setting `must_cache_write = False` + if (!state.write_cache_is_added && config.req != ReuseType::kMustReuse) { + results.push_back(state); + } + BlockRV consumer = state.write_cache.value(); + // Enumerate the level of tile to be fused at + for (int level : config.levels) { + Schedule sch = state.sch->Copy(); + sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = state.tiles[level - 1].back(); + sch->ReverseComputeAt(consumer, loop_rv, true); + State new_state = state; + new_state.sch = sch; + results.push_back(std::move(new_state)); + } + return results; +} + +// Constructor + +ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> tile_binds, + Optional max_innermost_factor, + Optional vector_load_max_len, + Optional> reuse_read, + Optional> reuse_write) { + ObjectPtr n = make_object(); + n->structure = structure; + n->tile_binds = tile_binds.value_or({}); + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->vector_load_max_len = vector_load_max_len.value_or(Integer(-1))->value; + n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); + n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); + for (int i = 0, len = structure.size(); i < len; ++i) { + char c = structure.data()[i]; + if (c == 'S') { + n->s_indices_.push_back(i); + } else if (c == 'R') { + n->r_indices_.push_back(i); + } else { + LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; + } + } + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTiling") + .set_body_typed(ScheduleRule::MultiLevelTiling); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 41afbc57d79b..f1fb27f7a3a3 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -104,8 +104,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { /*mod=*/mod_, // /*rand_state=*/ForkSeed(&this->rand_state_), // /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail // - ); + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); std::vector stack; Array result{sch}; diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index be76d3e8db98..d8e96d0156e4 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -44,7 +44,7 @@ #include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" -#include "../tir/schedule/primitive.h" +#include "../tir/schedule/utils.h" namespace tvm { namespace meta_schedule { diff --git a/src/support/array.h b/src/support/array.h index 95b4f58a2e22..218150f9dba0 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -100,6 +100,29 @@ inline Array AsArray(const ShapeTuple& shape) { return result; } +/*! 
+ * \brief Concatenate a list of arrays into a single array + * \tparam T The type of elements in the arrays + * \tparam Iterator The type of the iterator into the list of arrays + * \param begin The begin iterator to the array list + * \param end The end iterator to the array list + * \return The concatenated array + */ +template +inline Array ConcatArrayList(Iterator begin, Iterator end) { + int size = 0; + for (Iterator it = begin; it != end; ++it) { + size += (*it).size(); + } + Array result; + result.reserve(size); + for (Iterator it = begin; it != end; ++it) { + const auto& item = *it; + result.insert(result.end(), item.begin(), item.end()); + } + return result; +} + /********** Implementation details of AsVector **********/ namespace details { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 5a2f46c910b4..d8327df295c8 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -21,6 +21,7 @@ #include +#include #include #include #include @@ -169,6 +170,27 @@ bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a data parallel block + */ +bool IsSpatial(const StmtSRef& block_sref); + +/*! + * \brief Extracts the types of the block vars + * \param block_sref The block to be checked + * \return A vector of types of the block vars + */ +std::vector GetBlockVarTypes(const StmtSRef& block_sref); + +/*! + * \brief Checks if a block could be considered as a "write cache" + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a write cache + */ +bool IsWriteCache(const StmtSRef& block_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. @@ -190,6 +212,15 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va */ void CheckAffineBinding(const ScheduleState& self, Block block); +/*! + * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop, + * from outer to inner. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block has a trivial binding + */ +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); + /*! * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path @@ -345,6 +376,75 @@ std::vector> GetReducerGetters(); bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); +/******** Misc ********/ + +/*! + * \brief Given the read/write region, extract the pattern of their index correspondence + * namely, the mapping from read index to the write index. + * \param read_region The read region + * \param write_region The write region + * \return A tuple of booleans, the extracted pattern + * 0) exists: if the pattern is found + * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once + * e.g. A[i, j] = B[i, i, j] + * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once. + * e.g. 
A[i, j] = B[i] + * 3) ordered: if the mapping is ordered + * 4) no_const_read: if there is no constant indexing in the read indices, + * e.g. A[i, j] = B[0, i, j] + * 5) no_shift_read: if there is no constant shift in the read indices, + * e.g. A[i, j] = B[i + 1, j] + */ +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region); + +/*! + * \brief Checks if the given block has data reuse opportunity and thus multi-level tiling is + * beneficial. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has data reuse opportunity + */ +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if the given AST contains the specific operators + * \param stmt The AST to be checked + * \param ops The list of operators to be checked + * \return A boolean indicating whether the AST contains the specific operators + */ +bool HasOp(const Stmt& stmt, const Array& ops); + +/*! + * \brief Checks if the given AST contains if-then-else, including + * 1) IfThenElse statement + * 2) Select expression + * 3) The operator `tir.if_then_else` + * 4) Block predicates + */ +bool HasIfThenElse(const Stmt& stmt); + +/*! + * \brief Checks if a block could be successfully computed inline into its consumer + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block could be successfully computed inline + */ +bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if a block could be successfully computed inline into its producer + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block could be successfully computed inline + */ +bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e3a535e9b3d4..8d960d68eed9 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -440,6 +440,43 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, } } +bool IsSpatial(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != IterVarType::kDataPar) { + return false; + } + } + return true; +} + +std::vector GetBlockVarTypes(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + std::vector results; + results.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + results.push_back(iter_var->iter_type); + } + return results; +} + +bool IsWriteCache(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1) { + return false; + } + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool exists, surjective, injective, ordered, no_const_read, no_shift_read; + std::tie(exists, surjective, injective, ordered, no_const_read, no_shift_read) = + AnalyzeReadWritePattern(read_region, write_region); + if (!(injective && ordered)) { + return false; + } + } + return true; +} + /******** Binding ********/ bool IsAffineBinding(const 
BlockRealize& realize, const Map& loop_var_ranges, @@ -487,6 +524,22 @@ void CheckAffineBinding(const ScheduleState& self, Block block) { } } +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = GetLoops(block_sref); + Array binds = GetBlockRealize(self, block_sref)->iter_values; + if (loops.size() != binds.size()) { + return false; + } + for (int i = 0, n = loops.size(); i < n; ++i) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); + if (binds[i].get() != loop->loop_var.get()) { + return false; + } + } + return true; +} + Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, const Optional& high_exclusive, const runtime::StorageScope& extra_relax_scope) { @@ -1172,5 +1225,190 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { return GetRef(p); } +/******** Misc ********/ + +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region) { + static constexpr const std::tuple kNotExist = { + false, false, false, false, false, false}; + // Step 1. Extract the write indices + int w_dim = write_region->buffer->shape.size(); + std::unordered_map var2idx; + var2idx.reserve(w_dim); + for (int i = 0; i < w_dim; ++i) { + const Range& dom = write_region->region[i]; + if (as_const_int(dom->extent) == nullptr) { + return kNotExist; + } + if (const auto* v = dom->min.as()) { + var2idx.emplace(v, i); + } else { + return kNotExist; + } + } + // Step 2. Map each read index to a write index + bool no_const_read = true; + bool no_shift_read = true; + int r_dim = read_region->buffer->shape.size(); + std::vector mapped(r_dim, -1); + for (int i = 0; i < r_dim; ++i) { + const Range& dom = read_region->region[i]; + if (as_const_int(dom->extent) == nullptr) { + return kNotExist; + } + // Case 1. Read index is a constant + if (as_const_int(dom->min) != nullptr) { + no_const_read = false; + continue; + } + // Case 2. Read index cannot be recognized as `var +/- const` + // where `var` is a write index and `const` is an optional constant shift + Optional opt_const = NullOpt; + const VarNode* var = + static_cast(AnalyzeVarWithShift(dom->min, &opt_const).get()); + if (var == nullptr || !var2idx.count(var)) { + return kNotExist; + } + // Case 3. Read index is `var +/- const` + mapped[i] = var2idx.at(var); + if (opt_const.defined()) { + no_shift_read = false; + } + } + // Step 3. Check if the mapping is ordered, and count how many times each var is mapped + std::vector mapped_counter(w_dim, 0); + bool ordered = true; + int last_mapped = -1; + for (int i : mapped) { + if (i != -1) { + ++mapped_counter[i]; + if (last_mapped != -1 && last_mapped > i) { + ordered = false; + } + last_mapped = i; + } + } + // Step 4. 
Check if the mapping is surjective or injective + // Surjective: each write index is mapped at least once + // Injective: each write index is mapped at most once + bool surjective = true; + bool injective = true; + for (int cnt : mapped_counter) { + if (cnt == 0) { + surjective = false; + } else if (cnt >= 2) { + injective = false; + } + } + return {/*exist=*/true, surjective, injective, ordered, no_const_read, no_shift_read}; +} + +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || + !IsTrivialBinding(self, block_sref)) { + return false; + } + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + // Step 1. Sort out spatial block variables + std::vector spatial_block_vars; + spatial_block_vars.reserve(block->iter_vars.size()); + for (const IterVar& block_var : block->iter_vars) { + if (block_var->iter_type == IterVarType::kDataPar) { + spatial_block_vars.push_back(block_var->var.get()); + } + } + // Step 2. Enumerate each read region, check the number of block vars that are not used + // to index the read region + int total_unused_block_vars = 0; + std::unordered_set read_buffers; + read_buffers.reserve(block->reads.size()); + for (const BufferRegion& buffer_region : block->reads) { + const BufferNode* buffer = buffer_region->buffer.get(); + const Array& regions = buffer_region->region; + // Step 2.1. Duplication of read buffers are not allowed + if (read_buffers.insert(buffer).second == false) { + return false; + } + // Step 2.2. Skip the reduction buffer + if (buffer == write_buffer) { + continue; + } + // Step 2.3. Collect the block vars that are used to index the read region + std::unordered_set vars; + for (const Range& range : regions) { + if (as_const_int(range->extent) == nullptr) { + return false; + } + for (const Var& var : UndefinedVars(range->min)) { + vars.insert(var.get()); + } + } + // Step 2.4. 
Check if the block vars are not used to index the read region + int n_unused_block_vars = 0; + for (const VarNode* block_var : spatial_block_vars) { + if (vars.count(block_var) == 0) { + ++n_unused_block_vars; + } + } + total_unused_block_vars += n_unused_block_vars; + } + return total_unused_block_vars >= 1; +} + +bool HasOp(const Stmt& stmt, const Array& ops) { + std::unordered_set op_set; + op_set.reserve(ops.size()); + for (const Op& op : ops) { + op_set.insert(op.operator->()); + } + bool found = false; + tir::PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool { + if (found) { + return false; + } + if (const auto* call = obj.as()) { + if (op_set.count(call->op.operator->())) { + found = true; + } + } + return !found; + }); + return found; +} + +bool HasIfThenElse(const Stmt& stmt) { + bool has_branch = false; + auto f_visit = [&has_branch](const ObjectRef& obj) -> bool { + if (has_branch) { + // stop visiting + return false; + } + if (const auto* realize = obj.as()) { + // Case 1: BlockRealize + if (!is_one(realize->predicate)) { + has_branch = true; + } + } else if (obj->IsInstance() || obj->IsInstance()) { + // Case 2: IfThenElse / Select + has_branch = true; + } else if (const auto* call = obj.as()) { + // Case 3: Call + static const Op& op_if_then_else = Op::Get("tir.if_then_else"); + if (call->op.same_as(op_if_then_else)) { + has_branch = true; + } + } + return !has_branch; + }; + tir::PreOrderVisit(stmt, f_visit); + return has_branch; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 07f3b59d4bee..6b08d1dd8e8c 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -129,7 +129,7 @@ class ConcreteScheduleNode : public ScheduleNode { void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; - void Unannotate(const BlockRV& loop_rv, const String& ann_key); + void Unannotate(const BlockRV& loop_rv, const String& ann_key) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 539a82f9ae5c..0edd39ac3a7b 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -60,11 +60,27 @@ class NotSingleReadWriteBuffer : public ScheduleError { bool is_read_; Block block_; - static Buffer GetSingleRead(const ScheduleState& self, const Block& block) { - if (block->reads.size() != 1) { + static Buffer GetSingleRead(const ScheduleState& self, const Block& block, + const StmtSRef& scope_root_sref) { + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers; + const BufferNode* read_buffer = nullptr; + for (const BufferRegion& read_region : block->reads) { + const BufferNode* buffer = read_region->buffer.get(); + if (buffer == read_buffer) { + continue; + } + if (buffer_writers.count(GetRef(buffer)) > 0) { + if (read_buffer != nullptr) { + throw NotSingleReadWriteBuffer(self->mod, true, block); + } + read_buffer = buffer; + } + } + if (read_buffer == nullptr) { throw NotSingleReadWriteBuffer(self->mod, true, block); } - return block->reads[0]->buffer; + return GetRef(read_buffer); } static 
Buffer GetSingleWrite(const ScheduleState& self, const Block& block) { @@ -167,7 +183,7 @@ class OpaqueAccessError : public ScheduleError { * \brief The base class of the inliner, which handles: * 1) Substitute a subtree with the specific block being inlined * 2) Update the block signature to reflect the changes of read/write/allocated buffers - * 3) Maintain a list of index variables and their substition of the buffer being inlined + * 3) Maintain a list of index variables and their substitution of the buffer being inlined */ class BaseInliner : public StmtExprMutator { protected: @@ -522,7 +538,7 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr producer_rhs_{nullptr}; }; -void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { +std::function ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref) { const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); @@ -531,6 +547,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { /*require_stage_pipeline=*/true, /*require_subtree_compact_dataflow=*/false); // Step 2. Check completeness + CheckNotOutputBlock(self, producer_block_sref, scope_root_sref); CheckCompleteBlock(self, producer_block_sref, scope_root_sref); // Step 3. Analyze the block body ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref); @@ -546,17 +563,32 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { throw OpaqueAccessError(self->mod, scope_root_sref); } // Step 6. Do the real mutation on the AST and the sref tree in the schedule state - self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + return [=]() -> void { self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); }; } -void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { +void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { + ComputeInlineImpl(self, producer_block_sref)(); +} + +bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) { + try { + ComputeInlineImpl(self, producer_block_sref); + } catch (const tvm::runtime::Error& e) { + return false; + } + return true; +} + +std::function ReverseComputeInlineImpl(ScheduleState self, + const StmtSRef& consumer_block_sref) { const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); Block consumer_block = GetRef(_consumer_block); - Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block); // Step 1. Get the scope block StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, // /*require_stage_pipeline=*/true, /*require_subtree_compact_dataflow=*/false); + Buffer inlined_buffer = + NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref); // Step 2. Check completeness CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); // Step 3. Check if the consumer has a single complete producer @@ -575,7 +607,20 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre throw OpaqueAccessError(self->mod, scope_root_sref); } // Step 7. 
Do the real mutation on the AST and the sref tree in the schedule state - self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + return [=]() -> void { self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); }; +} + +bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) { + try { + ReverseComputeInlineImpl(self, block_sref); + } catch (const tvm::runtime::Error& e) { + return false; + } + return true; +} + +void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { + ReverseComputeInlineImpl(self, consumer_block_sref)(); } /******** InstructionKind Registration ********/ diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 55869e12b6b2..acab85460a71 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -83,7 +83,7 @@ void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, const Block& block = block_realize->block; // Cond 1. The block is required to have affine bindings. - CheckAffineBinding(self, block); + /* CheckAffineBinding(self, block); */ // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed. ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 4acf61860112..0312b924fcb3 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -301,9 +301,9 @@ std::vector SamplePerfectTile( const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - int64_t extent = GetLoopIntExtent(loop); + const int64_t* extent = GetLoopIntExtent(loop); std::vector result; - if (extent == -1) { + if (extent == nullptr) { // Case 1. Handle loops with non-constant length result = std::vector(n_splits, 1); result[0] = -1; @@ -312,7 +312,7 @@ std::vector SamplePerfectTile( result = support::AsVector(decision->value()); int n = result.size(); ICHECK_GE(n, 2); - int64_t len = extent; + int64_t len = *extent; for (int i = n - 1; i > 0; --i) { int64_t& l = result[i]; // A previous decision could become invalid because of the change of outer tiles @@ -326,7 +326,7 @@ std::vector SamplePerfectTile( result[0] = len; } else { // Case 3. Use fresh new sampling result - result = SamplePerfectTile(rand_state, extent, n_splits, max_innermost_factor); + result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor); ICHECK_LE(result.back(), max_innermost_factor); } *decision = support::AsArray(result); diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index c66c2ca76693..71a8bdf5b8f5 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -210,26 +210,77 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } -/**************** Loop extents ****************/ +/**************** PrimExpr parsing and extents ****************/ /*! * \brief Get the extents of a loop * \param loop The loop to be queried - * \return The extents of the loop + * \return The extent of the loop, nullptr if the extent is not constant */ -inline int64_t GetLoopIntExtent(const ForNode* loop) { - const auto* int_extent = loop->extent.as(); - return int_extent ? int_extent->value : -1; -} +inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_int(loop->extent); } /*! 
* \brief Get the extents of a loop * \param loop_sref The loop to be queried - * \return The extents of the loop + * \return The extent of the loop, nullptr if the extent is not constant */ -inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) { +inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - return GetLoopIntExtent(loop); + return as_const_int(loop->extent); +} + +/*! + * \brief Check if an expression consists of a single variable, + * or a variable plus/minus an constant integer shift + * \param expr The expression to be checked + * \return result Output, the var if it satisfies the condition; otherwise NullOpt + */ +inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* constant) { + if (const auto* var = expr.as()) { + *constant = NullOpt; + return GetRef(var); + } + arith::PVar var; + arith::PVar shift; + // match: "var + shift" + if ((var + shift).Match(expr) || (shift + var).Match(expr)) { + *constant = shift.Eval(); + return var.Eval(); + } + // match: "var - shift" + if ((var - shift).Match(expr)) { + IntImm result = shift.Eval(); + *constant = IntImm(result->dtype, -result->value); + return var.Eval(); + } + return NullOpt; +} + +/******** Annotation ********/ + +/*! + * \brief Get the annotation on a Block/For + * \tparam TObjectRef The type of the annotation value + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be looked up + * \return NullOpt if not found; otherwise the annotation value + */ +template +inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) { + const Map* annotations = nullptr; + if (const auto* loop = sref->StmtAs()) { + annotations = &loop->annotations; + } else if (const auto* block = sref->StmtAs()) { + annotations = &block->annotations; + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + } + for (const auto& ann : *annotations) { + if (ann.first == ann_key) { + return Downcast(ann.second); + } + } + return NullOpt; } } // namespace tir diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule.py index e79ca69ca64d..1d34d94bfe05 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule.py @@ -15,20 +15,15 @@ # specific language governing permissions and limitations # under the License. 
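# A minimal sketch (editor's illustration, not part of this patch) of the annotation
# round trip that the new GetAnn helper in src/tir/schedule/utils.h resolves on the
# C++ side: MultiLevelTiling marks a loop with "meta_schedule.lazy_cooperative_fetch"
# via sch.annotate (the same call printed in the expected traces further below), and
# GetAnn later reads the value back from the loop's sref. The toy PrimFunc is made up
# for illustration, and the snippet assumes the Schedule.annotate binding on this branch.
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def toy_add(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] + 1.0


sch = tir.Schedule(toy_add)
loop_i, _loop_j = sch.get_loops(sch.get_block("B"))
sch.annotate(loop_i, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)
# The annotation now sits on the For node, which is exactly where GetAnn looks it up.
print(sch.get(loop_i).annotations)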
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring - -from typing import List - import math import re +from typing import List import tvm -from tvm.script import tir as T - -from tvm.meta_schedule.schedule_rule import PyScheduleRule from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.utils import _get_hex_address - -from tvm.tir.schedule import Schedule, BlockRV +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.script import tir as T +from tvm.tir.schedule import BlockRV, Schedule # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, @@ -61,7 +56,7 @@ def _check_correct(schedule: Schedule): def test_meta_schedule_schedule_rule(): class FancyScheduleRule(PyScheduleRule): - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, tune_context: TuneContext) -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -80,21 +75,18 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: try: tvm.ir.assert_structural_equal(mod, res[0].mod) raise Exception("The schedule rule did not change the schedule.") - except (ValueError): + except ValueError: _check_correct(res[0]) def test_meta_schedule_schedule_rule_as_string(): class YetStillSomeFancyScheduleRule(PyScheduleRule): - def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + def initialize_with_tune_context(self, tune_context: TuneContext) -> None: pass - def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]: + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: pass - def __str__(self) -> str: - return f"YetStillSomeFancyScheduleRule({_get_hex_address(self.handle)})" - sch_rule = YetStillSomeFancyScheduleRule() pattern = re.compile(r"YetStillSomeFancyScheduleRule\(0x[a-f|0-9]*\)") assert pattern.match(str(sch_rule)) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py new file mode 100644 index 000000000000..ae60803d08f3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring diff --git a/tests/python/unittest/test_meta_schedule_sketch_cpu.py b/tests/python/unittest/test_meta_schedule_sketch_cpu.py new file mode 100644 index 000000000000..b5dfdadaa15d --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_sketch_cpu.py @@ -0,0 +1,392 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List + +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing import te_workload +from tvm.target import Target +from tvm.te import create_prim_func +from tvm.tir.schedule import Trace +from tvm.tir.schedule.schedule import Schedule + + +def _create_context(mod): + from tvm.meta_schedule.testing import ( # pylint: disable=import-outside-toplevel + schedule_rule as sch_rules, + ) + + target = Target("llvm") + ctx = ms.TuneContext( + mod=mod, + target=target, + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=sch_rules.get(target), + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for rule in ctx.sch_rules: + rule.initialize_with_tune_context(ctx) + return ctx + + +def _check_trace(spaces: List[Schedule], expected: List[List[str]]): + expected_traces = {"\n".join(t) for t in expected} + actual_traces = set() + for space in spaces: + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + str_trace = "\n".join(str(trace).strip().splitlines()) + actual_traces.add(str_trace) + assert str_trace in expected_traces, "\n" + str_trace + assert len(expected_traces) == len(actual_traces) + + +def _debug_print(spaces): + for i, space in enumerate(spaces): + print(f"##### Space {i}") + print(space.mod.script()) + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + print(str(trace).strip().splitlines()) + + +def test_meta_schedule_cpu_sketch_matmul(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, 
factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ) + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + _check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + # pylint: enable=line-too-long + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ) + ) + spaces = 
ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + _check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_conv2d_nchw(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4, l5, l6, l7, l8 = sch.get_loops(block=b0)", + "v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l13, l14, l15, l16 = sch.split(loop=l2, factors=[v9, v10, v11, v12])", + "v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l21, l22, l23, l24 = sch.split(loop=l3, factors=[v17, v18, v19, v20])", + "v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l29, l30, l31, l32 = sch.split(loop=l4, factors=[v25, v26, v27, v28])", + "v33, v34, v35, v36 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l37, l38, l39, l40 = sch.split(loop=l5, factors=[v33, v34, v35, v36])", + "v41, v42 = sch.sample_perfect_tile(loop=l6, n=2, max_innermost_factor=64)", + "l43, l44 = sch.split(loop=l6, factors=[v41, v42])", + "v45, v46 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64)", + "l47, l48 = sch.split(loop=l7, factors=[v45, v46])", + "v49, v50 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l51, l52 = sch.split(loop=l8, factors=[v49, v50])", + "sch.reorder(l13, l21, l29, l37, l14, l22, l30, l38, l43, l47, l51, l15, l23, l31, l39, l44, l48, l52, l16, l24, l32, l40)", + "sch.reverse_compute_at(block=b1, loop=l38, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4, l5, l6, l7, l8 = sch.get_loops(block=b0)", + "v9, v10, v11, v12 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l13, l14, l15, l16 = sch.split(loop=l2, factors=[v9, v10, v11, v12])", + "v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l21, l22, l23, l24 = sch.split(loop=l3, factors=[v17, v18, v19, v20])", + "v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l29, l30, l31, l32 = sch.split(loop=l4, factors=[v25, v26, v27, v28])", + "v33, v34, v35, v36 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l37, l38, l39, l40 = sch.split(loop=l5, factors=[v33, v34, v35, v36])", + "v41, v42 = sch.sample_perfect_tile(loop=l6, n=2, max_innermost_factor=64)", + "l43, l44 = sch.split(loop=l6, factors=[v41, v42])", + "v45, v46 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64)", + "l47, l48 = sch.split(loop=l7, factors=[v45, v46])", + "v49, v50 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l51, l52 = sch.split(loop=l8, factors=[v49, v50])", + "sch.reorder(l13, l21, l29, l37, l14, l22, l30, l38, l43, l47, l51, l15, l23, l31, l39, l44, l48, l52, l16, l24, l32, l40)", + "sch.reverse_compute_at(block=b1, loop=l37, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + "l1, l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b0)", + "v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l12, l13, l14, l15 = sch.split(loop=l1, factors=[v8, v9, v10, v11])", + "v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l20, l21, l22, l23 = sch.split(loop=l2, 
factors=[v16, v17, v18, v19])", + "v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l28, l29, l30, l31 = sch.split(loop=l3, factors=[v24, v25, v26, v27])", + "v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l36, l37, l38, l39 = sch.split(loop=l4, factors=[v32, v33, v34, v35])", + "v40, v41 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l42, l43 = sch.split(loop=l5, factors=[v40, v41])", + "v44, v45 = sch.sample_perfect_tile(loop=l6, n=2, max_innermost_factor=64)", + "l46, l47 = sch.split(loop=l6, factors=[v44, v45])", + "v48, v49 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64)", + "l50, l51 = sch.split(loop=l7, factors=[v48, v49])", + "sch.reorder(l12, l20, l28, l36, l13, l21, l29, l37, l42, l46, l50, l14, l22, l30, l38, l43, l47, l51, l15, l23, l31, l39)", + ], + ] + # pylint: enable=line-too-long + ctx = _create_context( + create_prim_func( + te_workload.conv2d_nchw( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ) + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + _check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="bias_add", func_name="main")', + 'b1 = sch.get_block(name="bn_mul", func_name="main")', + 'b2 = sch.get_block(name="bn_add", func_name="main")', + "sch.compute_inline(block=b2)", + "sch.compute_inline(block=b1)", + "sch.compute_inline(block=b0)", + 'b3 = sch.get_block(name="compute", func_name="main")', + "b4, = sch.get_consumers(block=b3)", + "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b3)", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l5, factors=[v12, v13, v14, v15])", + "v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l24, l25, l26, l27 = sch.split(loop=l6, factors=[v20, v21, v22, v23])", + "v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l32, l33, l34, l35 = sch.split(loop=l7, factors=[v28, v29, v30, v31])", + "v36, v37, v38, v39 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l40, l41, l42, l43 = sch.split(loop=l8, factors=[v36, v37, v38, v39])", + "v44, v45 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l46, l47 = sch.split(loop=l9, factors=[v44, v45])", + "v48, v49 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l50, l51 = sch.split(loop=l10, factors=[v48, v49])", + "v52, v53 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l54, l55 = sch.split(loop=l11, factors=[v52, v53])", + "sch.reorder(l16, l24, l32, l40, l17, l25, l33, l41, l46, l50, l54, l18, l26, l34, l42, l47, l51, l55, l19, l27, l35, l43)", + "sch.reverse_compute_at(block=b4, loop=l41, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="bias_add", func_name="main")', + 'b1 = sch.get_block(name="bn_mul", func_name="main")', + 'b2 = sch.get_block(name="bn_add", func_name="main")', + "sch.compute_inline(block=b2)", + "sch.compute_inline(block=b1)", + "sch.compute_inline(block=b0)", + 'b3 = sch.get_block(name="compute", func_name="main")', + "b4, = sch.get_consumers(block=b3)", + "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b3)", + "v12, v13, v14, v15 = 
sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l5, factors=[v12, v13, v14, v15])", + "v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l24, l25, l26, l27 = sch.split(loop=l6, factors=[v20, v21, v22, v23])", + "v28, v29, v30, v31 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l32, l33, l34, l35 = sch.split(loop=l7, factors=[v28, v29, v30, v31])", + "v36, v37, v38, v39 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l40, l41, l42, l43 = sch.split(loop=l8, factors=[v36, v37, v38, v39])", + "v44, v45 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l46, l47 = sch.split(loop=l9, factors=[v44, v45])", + "v48, v49 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l50, l51 = sch.split(loop=l10, factors=[v48, v49])", + "v52, v53 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l54, l55 = sch.split(loop=l11, factors=[v52, v53])", + "sch.reorder(l16, l24, l32, l40, l17, l25, l33, l41, l46, l50, l54, l18, l26, l34, l42, l47, l51, l55, l19, l27, l35, l43)", + "sch.reverse_compute_at(block=b4, loop=l40, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="bias_add", func_name="main")', + 'b1 = sch.get_block(name="bn_mul", func_name="main")', + 'b2 = sch.get_block(name="bn_add", func_name="main")', + "sch.compute_inline(block=b2)", + "sch.compute_inline(block=b1)", + "sch.compute_inline(block=b0)", + 'b3 = sch.get_block(name="compute", func_name="main")', + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b3)", + "v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l15, l16, l17, l18 = sch.split(loop=l4, factors=[v11, v12, v13, v14])", + "v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l23, l24, l25, l26 = sch.split(loop=l5, factors=[v19, v20, v21, v22])", + "v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l31, l32, l33, l34 = sch.split(loop=l6, factors=[v27, v28, v29, v30])", + "v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l39, l40, l41, l42 = sch.split(loop=l7, factors=[v35, v36, v37, v38])", + "v43, v44 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l45, l46 = sch.split(loop=l8, factors=[v43, v44])", + "v47, v48 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l49, l50 = sch.split(loop=l9, factors=[v47, v48])", + "v51, v52 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l53, l54 = sch.split(loop=l10, factors=[v51, v52])", + "sch.reorder(l15, l23, l31, l39, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41, l46, l50, l54, l18, l26, l34, l42)", + ], + ] + # pylint: enable=line-too-long + ctx = _create_context( + create_prim_func( + te_workload.conv2d_nchw_bias_bn_relu( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ) + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + _check_trace(spaces, expected) + + +def test_meta_schedule_sketch_cpu_max_pool2d_nchw(): + expected: List[List[str]] = [[]] + ctx = _create_context( + create_prim_func( + te_workload.max_pool2d_nchw( + n=1, + h=56, + w=56, + ci=512, + padding=1, + ) + ) + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + _check_trace(spaces, expected) + + +if __name__ == 
"__main__": + test_meta_schedule_cpu_sketch_matmul() + test_meta_schedule_cpu_sketch_matmul_relu() + test_meta_schedule_cpu_sketch_conv2d_nchw() + test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu() + test_meta_schedule_sketch_cpu_max_pool2d_nchw() diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py new file mode 100644 index 000000000000..9d18b0f59678 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py @@ -0,0 +1,333 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List + +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing import te_workload +from tvm.target import Target +from tvm.te import create_prim_func +from tvm.tir.schedule import Trace +from tvm.tir.schedule.schedule import Schedule + + +def _create_context(mod): + from tvm.meta_schedule.testing import ( # pylint: disable=import-outside-toplevel + schedule_rule as sch_rules, + ) + + target = Target("cuda", host="llvm") + ctx = ms.TuneContext( + mod=mod, + target=target, + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=sch_rules.get(target), + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for rule in ctx.sch_rules: + rule.initialize_with_tune_context(ctx) + return ctx + + +def _check_trace(spaces: List[Schedule], expected: List[List[str]]): + expected_traces = {"\n".join(t) for t in expected} + actual_traces = set() + for space in spaces: + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + str_trace = "\n".join(str(trace).strip().splitlines()) + actual_traces.add(str_trace) + assert str_trace in expected_traces, "\n" + str_trace + assert len(expected_traces) == len(actual_traces) + + +def _debug_print(spaces: List[Schedule]) -> None: + for i, space in enumerate(spaces): + print(f"##### Space {i}") + print(space.mod.script()) + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + print(str(trace).strip().splitlines()) + + +def test_meta_schedule_cuda_sketch_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, 
v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", + "l44, l45 = sch.split(loop=l41, factors=[v42, v43])", + "sch.vectorize(loop=l45)", + 'sch.annotate(block_or_loop=l44, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b46, loop=l28, preserve_unit_loops=1)", + "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)", + "l53 = sch.fuse(l51, l52)", + "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)", + "l56, l57 = sch.split(loop=l53, factors=[v54, v55])", + "sch.vectorize(loop=l57)", + 'sch.annotate(block_or_loop=l56, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)", + ] + ] + # pylint: enable=line-too-long + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ) + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + _check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=1)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", + "l44, l45 = sch.split(loop=l41, factors=[v42, v43])", + "sch.vectorize(loop=l45)", + 'sch.annotate(block_or_loop=l44, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b46, loop=l28, preserve_unit_loops=1)", + "l47, l48, 
l49, l50, l51, l52 = sch.get_loops(block=b46)", + "l53 = sch.fuse(l51, l52)", + "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)", + "l56, l57 = sch.split(loop=l53, factors=[v54, v55])", + "sch.vectorize(loop=l57)", + 'sch.annotate(block_or_loop=l56, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=1)", + 'b58 = sch.get_block(name="compute", func_name="main")', + "sch.reverse_compute_inline(block=b58)", + ] + ] + # pylint: enable=line-too-long + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ) + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + _check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_conv2d_nchw(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="compute", func_name="main")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4, l5, l6, l7, l8 = sch.get_loops(block=b0)", + "v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l14, l15, l16, l17, l18 = sch.split(loop=l2, factors=[v9, v10, v11, v12, v13])", + "v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l24, l25, l26, l27, l28 = sch.split(loop=l3, factors=[v19, v20, v21, v22, v23])", + "v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l34, l35, l36, l37, l38 = sch.split(loop=l4, factors=[v29, v30, v31, v32, v33])", + "v39, v40, v41, v42, v43 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l44, l45, l46, l47, l48 = sch.split(loop=l5, factors=[v39, v40, v41, v42, v43])", + "v49, v50, v51 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64)", + "l52, l53, l54 = sch.split(loop=l6, factors=[v49, v50, v51])", + "v55, v56, v57 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64)", + "l58, l59, l60 = sch.split(loop=l7, factors=[v55, v56, v57])", + "v61, v62, v63 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64)", + "l64, l65, l66 = sch.split(loop=l8, factors=[v61, v62, v63])", + "sch.reorder(l14, l24, l34, l44, l15, l25, l35, l45, l16, l26, l36, l46, l52, l58, l64, l53, l59, l65, l17, l27, l37, l47, l54, l60, l66, l18, l28, l38, l48)", + "l67 = sch.fuse(l14, l24, l34, l44)", + 'sch.bind(loop=l67, thread_axis="blockIdx.x")', + "l68 = sch.fuse(l15, l25, l35, l45)", + 'sch.bind(loop=l68, thread_axis="vthread.x")', + "l69 = sch.fuse(l16, l26, l36, l46)", + 'sch.bind(loop=l69, thread_axis="threadIdx.x")', + 'b70 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b70, loop=l64, preserve_unit_loops=1)", + "l71, l72, l73, l74, l75, l76, l77, l78, l79, l80 = sch.get_loops(block=b70)", + "l81 = sch.fuse(l77, l78, l79, l80)", + "v82, v83 = sch.sample_perfect_tile(loop=l81, n=2, max_innermost_factor=4)", + "l84, l85 = sch.split(loop=l81, factors=[v82, v83])", + "sch.vectorize(loop=l85)", + 'sch.annotate(block_or_loop=l84, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'b86 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b86, loop=l64, preserve_unit_loops=1)", + "l87, l88, l89, l90, l91, l92, l93, l94, l95, l96 = sch.get_loops(block=b86)", + "l97 = sch.fuse(l93, l94, l95, l96)", + "v98, v99 = sch.sample_perfect_tile(loop=l97, n=2, 
max_innermost_factor=4)", + "l100, l101 = sch.split(loop=l97, factors=[v98, v99])", + "sch.vectorize(loop=l101)", + 'sch.annotate(block_or_loop=l100, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + "sch.reverse_compute_at(block=b1, loop=l69, preserve_unit_loops=1)", + 'b102 = sch.get_block(name="pad_temp", func_name="main")', + "sch.compute_inline(block=b102)", + ] + ] + # pylint: enable=line-too-long + ctx = _create_context( + create_prim_func( + te_workload.conv2d_nchw( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ) + ) + + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + _check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="bias_add", func_name="main")', + 'b1 = sch.get_block(name="bn_mul", func_name="main")', + 'b2 = sch.get_block(name="bn_add", func_name="main")', + "sch.compute_inline(block=b2)", + "sch.compute_inline(block=b1)", + "sch.compute_inline(block=b0)", + 'b3 = sch.get_block(name="compute", func_name="main")', + 'b4 = sch.cache_write(block=b3, write_buffer_index=0, storage_scope="local")', + "l5, l6, l7, l8, l9, l10, l11 = sch.get_loops(block=b3)", + "v12, v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l17, l18, l19, l20, l21 = sch.split(loop=l5, factors=[v12, v13, v14, v15, v16])", + "v22, v23, v24, v25, v26 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64)", + "l27, l28, l29, l30, l31 = sch.split(loop=l6, factors=[v22, v23, v24, v25, v26])", + "v32, v33, v34, v35, v36 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)", + "l37, l38, l39, l40, l41 = sch.split(loop=l7, factors=[v32, v33, v34, v35, v36])", + "v42, v43, v44, v45, v46 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)", + "l47, l48, l49, l50, l51 = sch.split(loop=l8, factors=[v42, v43, v44, v45, v46])", + "v52, v53, v54 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)", + "l55, l56, l57 = sch.split(loop=l9, factors=[v52, v53, v54])", + "v58, v59, v60 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64)", + "l61, l62, l63 = sch.split(loop=l10, factors=[v58, v59, v60])", + "v64, v65, v66 = sch.sample_perfect_tile(loop=l11, n=3, max_innermost_factor=64)", + "l67, l68, l69 = sch.split(loop=l11, factors=[v64, v65, v66])", + "sch.reorder(l17, l27, l37, l47, l18, l28, l38, l48, l19, l29, l39, l49, l55, l61, l67, l56, l62, l68, l20, l30, l40, l50, l57, l63, l69, l21, l31, l41, l51)", + "l70 = sch.fuse(l17, l27, l37, l47)", + 'sch.bind(loop=l70, thread_axis="blockIdx.x")', + "l71 = sch.fuse(l18, l28, l38, l48)", + 'sch.bind(loop=l71, thread_axis="vthread.x")', + "l72 = sch.fuse(l19, l29, l39, l49)", + 'sch.bind(loop=l72, thread_axis="threadIdx.x")', + 'b73 = sch.cache_read(block=b3, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b73, loop=l67, preserve_unit_loops=1)", + "l74, l75, l76, l77, l78, l79, l80, l81, l82, l83 = sch.get_loops(block=b73)", + "l84 = sch.fuse(l80, l81, l82, l83)", + "v85, v86 = sch.sample_perfect_tile(loop=l84, n=2, max_innermost_factor=4)", + "l87, l88 = sch.split(loop=l84, factors=[v85, v86])", + "sch.vectorize(loop=l88)", + 'sch.annotate(block_or_loop=l87, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + 'b89 = sch.cache_read(block=b3, read_buffer_index=2, storage_scope="shared")', + 
"sch.compute_at(block=b89, loop=l67, preserve_unit_loops=1)", + "l90, l91, l92, l93, l94, l95, l96, l97, l98, l99 = sch.get_loops(block=b89)", + "l100 = sch.fuse(l96, l97, l98, l99)", + "v101, v102 = sch.sample_perfect_tile(loop=l100, n=2, max_innermost_factor=4)", + "l103, l104 = sch.split(loop=l100, factors=[v101, v102])", + "sch.vectorize(loop=l104)", + 'sch.annotate(block_or_loop=l103, ann_key="meta_schedule.lazy_cooperative_fetch", ann_val=1)', + "sch.reverse_compute_at(block=b4, loop=l72, preserve_unit_loops=1)", + 'b105 = sch.get_block(name="pad_temp", func_name="main")', + 'b106 = sch.get_block(name="compute_1", func_name="main")', + "sch.reverse_compute_inline(block=b106)", + "sch.compute_inline(block=b105)", + ] + ] + # pylint: enable=line-too-long + ctx = _create_context( + create_prim_func( + te_workload.conv2d_nchw_bias_bn_relu( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ) + ) + + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + _check_trace(spaces, expected) + + +if __name__ == "__main__": + test_meta_schedule_cuda_sketch_matmul() + test_meta_schedule_cuda_sketch_matmul_relu() + test_meta_schedule_cuda_sketch_conv2d_nchw() + test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu() diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index 617c75b75cd9..f1b588013099 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -272,6 +272,27 @@ def elementwise_multi_loads_inlined(a: T.handle, c: T.handle) -> None: C[vi, vj] = A[vi, vj] * 2.0 + A[vi, vj + 1] * 2.0 + A[vi, vj + 2] * 2.0 +@T.prim_func +def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + compute = T.match_buffer(var_compute, [512, 512], dtype="float32") + C = T.alloc_buffer([512, 512], dtype="float32") + for i0, i1, i2 in T.grid(512, 512, 512): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads([C[i, j], A[i, k], B[k, j]]) + T.writes([C[i, j]]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + for i0, i1 in T.grid(512, 512): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads([C[i0_1, i1_1]]) + T.writes([compute[i0_1, i1_1]]) + compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) + # pylint: enable=no-member,invalid-name,unused-variable @@ -401,6 +422,13 @@ def test_buffer_matched(): sch.compute_inline(block_b) +def test_output_block(): + sch = tir.Schedule(matmul_relu, debug_mask="all") + block= sch.get_block("compute") + with pytest.raises(tvm.tir.ScheduleError): + sch.compute_inline(block) + + def test_compute_inline_predicate(): sch = tir.Schedule(elementwise_predicate, debug_mask="all") block_b = sch.get_block("B")