From 70139e0fea9af325a47e6f5faa4453fc02af7ebd Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Thu, 20 Jan 2022 03:48:02 +0800 Subject: [PATCH] [MetaSchedule] Schedule Rule: Add RFactor (#9975) * add rfactor * format * fix ci --- include/tvm/meta_schedule/schedule_rule.h | 10 + .../meta_schedule/schedule_rule/__init__.py | 1 + .../schedule_rule/add_rfactor.py | 49 + .../meta_schedule/testing/schedule_rule.py | 8 + .../tvm/meta_schedule/testing/te_workload.py | 877 ++++++++++++++++++ .../schedule_rule/add_rfactor.cc | 122 +++ src/meta_schedule/utils.h | 20 + src/target/target_kind.cc | 1 + src/tir/schedule/analysis.h | 38 + src/tir/schedule/analysis/analysis.cc | 186 ++++ src/tir/schedule/utils.h | 54 ++ ...meta_schedule_schedule_rule_add_rfactor.py | 80 ++ 12 files changed, 1446 insertions(+) create mode 100644 python/tvm/meta_schedule/schedule_rule/add_rfactor.py create mode 100644 python/tvm/meta_schedule/testing/te_workload.py create mode 100644 src/meta_schedule/schedule_rule/add_rfactor.cc create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 6ee394791991..95fce13df02f 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -152,6 +152,16 @@ class ScheduleRule : public runtime::ObjectRef { Optional vector_load_max_len, // Optional> reuse_read, // Optional> reuse_write); + /*! + * \brief Create a rule: add-rfactor to some blocks if needed + * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the + * uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable + * parallelism. + * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, // + Optional max_innermost_factor); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The rule created diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index 9ad3c0627ea9..475c43a3fda1 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -16,6 +16,7 @@ Meta Schedule schedule rules are used for modification of blocks in a schedule. See also PostOrderApply. """ +from .add_rfactor import AddRFactor from .auto_inline import AutoInline from .schedule_rule import PyScheduleRule, ScheduleRule from .random_compute_location import RandomComputeLocation diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py new file mode 100644 index 000000000000..72f9fc92f96e --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add-rfactor Rule that add-rfactor to some blocks if needed""" +from typing import Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.AddRFactor") +class AddRFactor(ScheduleRule): + """Rules for add-rfactor to some blocks if needed. + + Parameters + ---------- + max_jobs_per_core: int + The maximum number of jobs to be launched per CPU core. It sets the uplimit of CPU + parallelism, i.e. `num_cores * max_jobs_per_core`. + Use -1 to disable parallelism. + max_innermost_factor: Optional[int] = None + The maximum size of the innermost factor. None means no limit. + """ + + def __init__( + self, + max_jobs_per_core: int = 16, + max_innermost_factor: Optional[int] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleAddRFactor, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + max_innermost_factor, + ) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index e69be1333092..020869da4b10 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -16,6 +16,7 @@ # under the License. """Default schedule rules""" from tvm.meta_schedule.schedule_rule import ( + AddRFactor, AutoInline, ScheduleRule, ) @@ -45,3 +46,10 @@ def auto_inline(target: Target) -> ScheduleRule: disallow_op=None, ) raise NotImplementedError(f"{target.kind.name} is not supported") + + +def add_rfactor(target: Target) -> ScheduleRule: + """Default schedule rules for with add_rfactor""" + if target.kind.name == "llvm": + return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64) + raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py new file mode 100644 index 000000000000..49a60a27526a --- /dev/null +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -0,0 +1,877 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Workloads in TE""" +# pylint: disable=missing-docstring +from typing import Tuple + +from tvm import te, tir, topi + + +def batch_matmul_nkkm( # pylint: disable=invalid-name,missing-docstring + B: int, + N: int, + M: int, + K: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((B, N, K), name="X") + y = te.placeholder((B, K, M), name="Y") + k = te.reduce_axis((0, K), name="k") + z = te.compute( # pylint: disable=invalid-name + (B, N, M), + lambda b, i, j: te.sum(x[b][i][k] * y[b][k][j], axis=[k]), + name="Z", + ) + return (x, y, z) + + +def conv1d_nlc( # pylint: disable=invalid-name,missing-docstring + N: int, + L: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, L, CI), name="inputs") + weight = te.placeholder((kernel_size, CI // groups, CO), name="weight") + + batch_size, in_len, _ = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rl = te.reduce_axis((0, k_len), name="rl") + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + ( + padded[ + n, + l * stride + rl * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rl, rc, co] + ), + axis=[rl, rc], + ), + name="conv1d_nlc", + ) + return (inputs, weight, output) + + +def conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI // groups, CO), name="weight") + batch_size, in_h, in_w, _ = inputs.shape + k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, co: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rh, rw, rc, co] + ), + axis=[rh, rw, rc], + ), + name="conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv3d_ndhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + D: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, D, H, W, CI), name="inputs") + weight = te.placeholder( + (kernel_size, kernel_size, kernel_size, CI // groups, CO), name="weight" + ) + batch_size, in_d, in_h, in_w, _ = inputs.shape + k_d, k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // 
stride + 1 + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rd = te.reduce_axis((0, k_d), name="rd") + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0]) + output = te.compute( + (batch_size, out_d, out_h, out_w, out_channel), + lambda n, d, h, w, co: te.sum( + ( + padded[ + n, + d * stride + rd * dilation, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rd, rh, rw, rc, co] + ), + axis=[rd, rh, rw, rc], + ), + name="conv3d_ndhwc", + ) + return (inputs, weight, output) + + +def depthwise_conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + C: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + factor: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, C)) + weight = te.placeholder((factor, kernel_size, kernel_size, C)) + batch_size, in_h, in_w, in_channel = inputs.shape + factor, k_h, k_w, in_channel = weight.shape + out_channel = in_channel * factor + assert int(factor) == 1, "Not optimized for factor != 1" + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, c: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + c // factor, + ] + * weight[c % factor, rh, rw, c // factor] + ), + axis=[rh, rw], + ), + name="depth_conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_transpose_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI, CO), name="weight") + + batch, in_h, in_w, in_c = inputs.shape + filter_h, filter_w, in_c, out_c = weight.shape + stride_h, stride_w = (stride, stride) + + # compute padding + fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple( + padding, (filter_h, filter_w) + ) + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + # padding stage + padded = topi.nn.pad( + inputs, + [ + 0, + (bpad_top + stride_h - 1) // stride_h, + (bpad_left + stride_w - 1) // stride_w, + 0, + ], + [ + 0, + (bpad_bottom + stride_h - 1) // stride_h, + (bpad_right + stride_w - 1) // stride_w, + 0, + ], + ) + + # remove extra padding introduced by dilatation + idx_div = te.indexdiv + idx_mod = te.indexmod + border_h = idx_mod(stride_h - idx_mod(bpad_top, stride_h), stride_h) + border_w = idx_mod(stride_w - idx_mod(bpad_left, stride_w), stride_w) + + # dilation stage + strides = [1, stride_h, stride_w, 1] + n = len(padded.shape) + + # We should embed this dilation directly into te.compute rather than creating a new te.compute. 
+ # Only in this way can we use unroll to eliminate the multiplication of zeros. + def _dilate(*indices): + not_zero = [] + index_tuple = [] + for i in range(n): + if not strides[i] == 1: + index_tuple.append(idx_div(indices[i], strides[i])) + not_zero.append(idx_mod(indices[i], strides[i]).equal(0)) + else: + index_tuple.append(indices[i]) + if not_zero: + not_zero = te.all(*not_zero) + return te.if_then_else(not_zero, padded(*index_tuple), tir.const(0.0, padded.dtype)) + return padded(*index_tuple) + + # convolution stage + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + rc = te.reduce_axis((0, in_c), name="rc") + rh = te.reduce_axis((0, filter_h), name="rh") + rw = te.reduce_axis((0, filter_w), name="rw") + + output = te.compute( + (batch, out_h, out_w, out_c), + lambda n, h, w, co: te.sum( + _dilate(n, h + rh + border_h, w + rw + border_w, rc) + * weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], + axis=[rh, rw, rc], + ), + name="conv2d_transpose_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_capsule_nhwijc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + capsule_size: int = 4, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name="inputs") + weight = te.placeholder( + (kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name="weight" + ) + batch_size, in_h, in_w, _, _, in_channel = inputs.shape + k_h, k_w, _, _, _, out_channel = weight.shape + + out_h = (in_h + 2 * padding - kernel_size) // stride + 1 + out_w = (in_w + 2 * padding - kernel_size) // stride + 1 + + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + cap_k = te.reduce_axis((0, capsule_size), name="cap_k") + rc = te.reduce_axis((0, in_channel), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0]) + output = te.compute( + (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel), + lambda n, h, w, cap_i, cap_j, co: te.sum( + ( + padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc] + * weight[rh, rw, cap_k, cap_j, rc, co] + ), + axis=[rh, rw, cap_k, rc], + ), + name="conv2d_capsule_nhwijc", + ) + return (inputs, weight, output) + + +def norm_bmn( # pylint: disable=invalid-name,missing-docstring + B: int, + M: int, + N: int, +) -> Tuple[te.Tensor, te.Tensor]: + a = te.placeholder((B, M, N), name="A") + i = te.reduce_axis((0, M), name="i") + j = te.reduce_axis((0, N), name="j") + c = te.compute( + (B,), + lambda b: te.sum(a[b][i][j] * a[b][i][j], axis=[i, j]), + name="C", + ) + d = te.compute((B,), lambda b: te.sqrt(c[b]), name="D") + return (a, d) + + +def conv2d_nhwc_without_layout_rewrite( # pylint: disable=invalid-name + Input: int, + Filter: int, + stride: int, + padding: int, + dilation: int, + out_dtype="float32", +): + """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute. + We use this in single op and subgraph evaluation + because we don't want to introduce graph level optimization. 
+ """ + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape # type: ignore + kernel_h, kernel_w, _channel, num_filter = Filter.shape # type: ignore + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel = num_filter + out_height = topi.utils.simplify( + (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1 + ) + out_width = topi.utils.simplify( + (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1 + ) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput") + rc = te.reduce_axis((0, in_channel), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + PaddedInput[ + nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc + ].astype(out_dtype) + * Filter[ry, rx, rc, ff].astype(out_dtype), # type: ignore + axis=[ry, rx, rc], + ), + name="Conv2dOutput", + tag="conv2d_nhwc", + ) + return Output + + +def conv2d_nhwc_bn_relu( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + strides: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + data = te.placeholder((N, H, W, CI), name="data") + kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name="kernel") + bias = te.placeholder((CO,), name="bias") + bn_scale = te.placeholder((CO,), name="bn_scale") + bn_offset = te.placeholder((CO,), name="bn_offset") + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bias[l], name="bias_add" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], name="bn_mul" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], name="bn_add" + ) + out = topi.nn.relu(conv) + return (data, kernel, bias, bn_offset, bn_scale, out) + + +def transpose_batch_matmul( # pylint: disable=invalid-name,missing-docstring + batch: int, + seq_len: int, + n_head: int, + n_dim: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + query = te.placeholder((batch, seq_len, n_head, n_dim), name="query") + value = te.placeholder((batch, seq_len, n_head, n_dim), name="value") + query_T = te.compute( + (batch, n_head, seq_len, n_dim), + lambda b, h, l, d: query[b, l, h, d], + name="query_T", + ) + value_T = te.compute( + (batch, n_head, n_dim, seq_len), + lambda b, h, d, l: value[b, l, h, d], + name="value_T", + ) + k = te.reduce_axis((0, n_dim), name="k") + out = te.compute( + (batch, n_head, seq_len, seq_len), + lambda b, h, i, j: 
te.sum(query_T[b, h, i, k] * value_T[b, h, k, j], axis=[k]), + name="C", + ) + return (query, value, out) + + +def conv2d_winograd_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + tile_size = 4 # _infer_tile_size(data, kernel) + inputs = te.placeholder((N, H, W, CI), name="inputs") + N, H, W, CI = topi.utils.get_const_tuple(inputs.shape) + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation" + + KH = KW = kernel_size + HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride + assert HSTR == 1 and WSTR == 1 and KH == KW + + data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") + + r = KW + m = tile_size + alpha = m + r - 1 + A, B, _G = topi.nn.winograd_util.winograd_transform_matrices(m, r, "float32") + + H = (H + 2 * HPAD - KH) // HSTR + 1 + W = (W + 2 * WPAD - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + _rkh = te.reduce_axis((0, KH), name="r_kh") + _rkw = te.reduce_axis((0, KW), name="r_kw") + kshape = (alpha, alpha, CI, CO) + kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") + + idxdiv = te.indexdiv + idxmod = te.indexmod + # pack input tile + input_tile = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps][ + idxmod(p, nW) * m + nu + ][ci], + name="input_tile", + ) + + # transform data + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + data_pack = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: te.sum( + input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] + ), + name="data_pack", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]}, + ) + + # do batch gemm + ci = te.reduce_axis((0, CI), name="ci") + bgemm = te.compute( + (alpha, alpha, P, CO), + lambda eps, nu, p, co: te.sum( + data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci] + ), + name="bgemm", + ) + + # inverse transform + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + inverse = te.compute( + (m, m, P, CO), + lambda vh, vw, p, co: te.sum( + bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] + ), + name="inverse", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]}, + ) + + # output + output = te.compute( + (N, H, W, CO), + lambda n, h, w, co: inverse[ + idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co + ], + name="conv2d_winograd", + ) + + return (inputs, kernel_pack, output) + + +def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((k, m), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + return (a, b, c) + + +def matmul_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A", dtype="float16") + b = te.placeholder((k, m), name="B", dtype="float16") + k = te.reduce_axis((0, k), name="k") + + def f_compute(i, j): + v_a = 
tir.Cast(dtype="float32", value=a[i, k]) + v_b = tir.Cast(dtype="float32", value=b[k, j]) + return te.sum(v_a * v_b, axis=[k]) + + c = te.compute((n, m), f_compute, name="C") + return (a, b, c) + + +def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((m, k), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + d = topi.nn.relu(c) # pylint: disable=invalid-name + return (a, b, d) + + +def matmul_relu_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A", dtype="float16") + b = te.placeholder((k, m), name="B", dtype="float16") + k = te.reduce_axis((0, k), name="k") + + def f_compute(i, j): + v_a = tir.Cast(dtype="float32", value=a[i, k]) + v_b = tir.Cast(dtype="float32", value=b[k, j]) + return te.sum(v_a * v_b, axis=[k]) + + c = te.compute((n, m), f_compute, name="C") + d = topi.nn.relu(c) # pylint: disable=invalid-name + return (a, b, d) + + +def conv2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + return (x, w, y) + + +def conv2d_nchw_bias_bn_relu( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + oh = (h + 2 * padding - (kh - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + ow = (w + 2 * padding - (kw - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + b = te.placeholder((co, 1, 1), name="B") + bn_scale = te.placeholder((co, 1, 1), name="bn_scale") + bn_offset = te.placeholder((co, 1, 1), name="bn_offset") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + y = te.compute((n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + b[j, 0, 0], name="bias_add") + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] * bn_scale[j, 0, 0], name="bn_mul" + ) + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + bn_offset[j, 0, 0], name="bn_add" + ) + y = topi.nn.relu(y) + return (x, w, b, bn_scale, bn_offset, y) + + +def max_pool2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + padding: int, +) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + y = topi.nn.pool2d(x, [2, 2], [1, 1], [1, 1], [padding, padding, padding, padding], "max") + return (x, y) + + +def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name + a = te.placeholder((m, n), name="A") + b = topi.nn.softmax(a, axis=1) + + return (a, b) + + +def create_te_workload(name: str, idx: int) -> tir.PrimFunc: + workload_func, params = CONFIGS[name] + return te.create_prim_func(workload_func(*params[idx])) # type: ignore + + +CONFIGS = { + "C1D": ( + conv1d_nlc, + [ + # derived from conv2d_shapes + (1, 256, 64, 128, 3, 2, 1), + # (1, 256, 64, 128, 1, 2, 0), + 
# (1, 256, 64, 64, 1, 1, 0), + # (1, 128, 128, 256, 3, 2, 1), + (1, 128, 128, 256, 1, 2, 0), + # (1, 128, 128, 128, 3, 1, 1), + # (1, 64, 256, 512, 3, 2, 1), + # (1, 64, 256, 512, 1, 2, 0), + (1, 64, 256, 256, 5, 1, 2), + (1, 32, 512, 512, 3, 1, 1), + ], + ), + "C2D": ( + conv2d_nhwc, + [ + # all conv2d layers in resnet-18 + (1, 224, 224, 3, 64, 7, 2, 3), + # (1, 56, 56, 64, 128, 3, 2, 1), + # (1, 56, 56, 64, 128, 1, 2, 0), + # (1, 56, 56, 64, 64, 3, 1, 1), + (1, 56, 56, 64, 64, 1, 1, 0), + # (1, 28, 28, 128, 256, 3, 2, 1), + # (1, 28, 28, 128, 256, 1, 2, 0), + # (1, 28, 28, 128, 128, 3, 1, 1), + # (1, 14, 14, 256, 512, 3, 2, 1), + # (1, 14, 14, 256, 512, 1, 2, 0), + (1, 14, 14, 256, 256, 3, 1, 1), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "C3D": ( + conv3d_ndhwc, + [ + # Derived from conv2d_shapes. Use depth=16 for all configurations + (1, 16, 224, 224, 3, 64, 7, 2, 3), + # (1, 16, 56, 56, 64, 128, 3, 2, 1), + # (1, 16, 56, 56, 64, 128, 1, 2, 0), + # (1, 16, 56, 56, 64, 64, 3, 1, 1), + (1, 16, 56, 56, 64, 64, 1, 1, 0), + # (1, 16, 28, 28, 128, 256, 3, 2, 1), + # (1, 16, 28, 28, 128, 256, 1, 2, 0), + # (1, 16, 28, 28, 128, 128, 3, 1, 1), + # (1, 16, 14, 14, 256, 512, 3, 2, 1), + # (1, 16, 14, 14, 256, 512, 1, 2, 0), + (1, 16, 14, 14, 256, 256, 3, 1, 1), + (1, 16, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "GMM": ( + batch_matmul_nkkm, + [ + (1, 128, 128, 128), + (1, 512, 32, 512), + (1, 512, 512, 512), + (1, 1024, 1024, 1024), + ], + ), + "GRP": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. Use group=4 for all configurations + (1, 56, 56, 64, 128, 3, 2, 1, 1, 4), + # (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4), + # (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4), + (1, 56, 56, 64, 64, 1, 1, 0, 1, 4), + # (1, 28, 28, 128, 256, 3, 2, 1, 1, 4), + # (1, 28, 28, 128, 256, 1, 2, 0, 1, 4), + # (1, 28, 28, 128, 128, 3, 1, 1, 1, 4), + # (1, 14, 14, 256, 512, 3, 2, 1, 1, 4), + # (1, 14, 14, 256, 512, 1, 2, 0, 1, 4), + (1, 14, 14, 256, 256, 3, 1, 1, 1, 4), + (1, 7, 7, 512, 512, 3, 1, 1, 1, 4), + ], + ), + "DIL": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. 
Use dilation=2 for all configurations + (1, 224, 224, 3, 64, 7, 2, 3, 2), + # (1, 56, 56, 64, 128, 3, 2, 1 , 2), + # (1, 56, 56, 64, 128, 1, 2, 0 , 2), + # (1, 56, 56, 64, 64, 3, 1, 1 , 2), + (1, 56, 56, 64, 64, 1, 1, 0, 2), + # (1, 28, 28, 128, 256, 3, 2, 1, 2), + # (1, 28, 28, 128, 256, 1, 2, 0, 2), + # (1, 28, 28, 128, 128, 3, 1, 1, 2), + # (1, 14, 14, 256, 512, 3, 2, 1, 2), + # (1, 14, 14, 256, 512, 1, 2, 0, 2), + (1, 14, 14, 256, 256, 3, 1, 1, 2), + (1, 7, 7, 512, 512, 3, 1, 1, 2), + ], + ), + "DEP": ( + depthwise_conv2d_nhwc, + [ + # all depthwise conv2d layers in mobilenet + (1, 112, 112, 32, 3, 1, 1), + (1, 112, 112, 64, 3, 2, 1), + # (1, 56, 56, 128, 3, 1, 1), + # (1, 56, 56, 128, 3, 2, 1), + # (1, 28, 28, 256, 3, 1, 1), + # (1, 28, 28, 256, 3, 2, 1), + # (1, 14, 14, 512, 3, 1, 1), + (1, 14, 14, 512, 3, 2, 1), + (1, 7, 7, 1024, 3, 1, 1), + ], + ), + "T2D": ( + conv2d_transpose_nhwc, + [ + # all conv2d tranpose layers in DCGAN + (1, 4, 4, 512, 256, 4, 2, 1), + (1, 8, 8, 256, 128, 4, 2, 1), + (1, 16, 16, 128, 64, 4, 2, 1), + (1, 32, 32, 64, 3, 4, 2, 1), + ], + ), + "CAP": ( + conv2d_capsule_nhwijc, + [ + # all conv2d capsule layers in matrix capsules withemrouting (ICLR 2018) + (1, 16, 16, 32, 32, 3, 2, 1), + (1, 8, 8, 32, 32, 3, 1, 1), + (1, 16, 16, 8, 16, 3, 2, 1), + (1, 8, 8, 16, 16, 3, 1, 1), + ], + ), + "NRM": ( + norm_bmn, + [ + (1, 256, 256), + (1, 512, 512), + (1, 1024, 1024), + (1, 4096, 1024), + ], + ), + "SFM": ( + softmax_mn, + [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), + ], + ), + "C2d-BN-RELU": ( + conv2d_nhwc_bn_relu, + [ + (1, 224, 224, 3, 64, 7, 2, 3), + (1, 56, 56, 64, 128, 3, 2, 1), + (1, 28, 28, 128, 256, 1, 2, 0), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "TBG": ( + transpose_batch_matmul, + [ + (1, 128, 12, 64), + (1, 128, 16, 64), + (1, 64, 12, 128), + (1, 128, 12, 128), + ], + ), +} diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc new file mode 100644 index 000000000000..5ef2ac3aad36 --- /dev/null +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class AddRFactorNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + this->max_parallel_basic_ = GetTargetNumCores(target); + if (this->max_jobs_per_core != -1) { + this->max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core; + } + } + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); + + public: + /*! 
+ * \brief The maximum number of jobs to be launched per core. + * It sets the uplimit of parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int max_jobs_per_core; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The number of uplimit of parallelism. */ + int max_parallel_extent_; + /*! \brief The number of cores. */ + int max_parallel_basic_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + v->Visit("max_innermost_factor", &max_innermost_factor); + // `max_parallel_extent_` is not visited + // `max_parallel_basic_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.AddRFactor"; + TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, + Optional max_innermost_factor) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->max_parallel_extent_ = -1; + n->max_parallel_basic_ = -1; + return ScheduleRule(n); +} + +Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) { + tir::StmtSRef block_sref = sch->GetSRef(block_rv); + if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_, + max_parallel_basic_)) { + return {sch}; + } + + // Make a copy of the original schedule. + tir::Schedule ori_sch = sch->Copy(); + ori_sch->Seed(sch->ForkSeed()); + + // Reorder the loop axes if reduction loops are not innermost. + // After the reordering, fuse all the reduction loops. + size_t num_spatial_loops; + tir::LoopRV fused_reduce_loop; + ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops); + + // Split the fused reduction loop. + Array factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); + const Array& split_loops = + sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); + + Array res; + for (const tir::LoopRV& split_loop : split_loops) { + tir::Schedule sch_tmp = sch->Copy(); + sch_tmp->Seed(sch->ForkSeed()); + try { + const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); + Array axes = sch_tmp->GetLoops(block_rf); + ICHECK_GT(axes.size(), num_spatial_loops); + + // Annotate that the rfactor block, which is now the producer of the original block, needs to + // be considered by the rule Random-Compute-Location. + sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Bool(true)); + res.push_back(sch_tmp); + } catch (const tvm::runtime::Error& e) { + } + } + + res.push_back(ori_sch); + return res; +} + +TVM_REGISTER_NODE_TYPE(AddRFactorNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") + .set_body_typed(ScheduleRule::AddRFactor); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index ef15f4995541..5b497695400a 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -318,6 +318,26 @@ struct ThreadedTraceApply { Item* items_; }; +/*! + * \brief Get the number of cores in CPU + * \param target The target + * \return The number of cores. 
+ */ +inline int GetTargetNumCores(const Target& target) { + int num_cores = target->GetAttr("num-cores").value_or(-1); + if (num_cores == -1) { + static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count"); + ICHECK(f_cpu_count) + << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\""; + num_cores = (*f_cpu_count)(false); + LOG(FATAL) + << "Target does not have attribute \"num-cores\", physical core number must be " + "defined! For example, on the local machine, the target must be \"llvm -num-cores " + << num_cores << "\""; + } + return num_cores; +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index e4bf48b2a51e..c562c78bd187 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -254,6 +254,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mabi") .add_attr_option("system-lib") .add_attr_option("runtime") + .add_attr_option("num-cores") .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9622e2dcd318..636cc7d0a5db 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -520,6 +520,44 @@ std::tuple AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region); +/*! + * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a data parallel block + */ +bool IsSpatial(const StmtSRef& block_sref); + +/*! + * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop, + * from outer to inner. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block has a trivial binding + */ +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if the given block has data reuse opportunity and thus multi-level tiling is + * beneficial. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has data reuse opportunity + */ +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if the rfactor or cross thread reduction is beneficial to the given block. + * \param self The schedule state. + * \param block_sref The block to be checked. + * \param max_parallel_extent The maximum parallel jobs on the target. + * \param max_parallel_basic The maximum cores on the target. + * \return A boolean indicating whether the operation is beneficial. 
+ */ +bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // + int64_t max_parallel_basic); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 052097314ee2..2053f8ddde93 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1661,5 +1661,191 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) { } } +bool IsSpatial(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != IterVarType::kDataPar) { + return false; + } + } + return true; +} + +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = GetLoops(block_sref); + Array binds = GetBlockRealize(self, block_sref)->iter_values; + if (loops.size() != binds.size()) { + return false; + } + for (int i = 0, n = loops.size(); i < n; ++i) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); + if (binds[i].get() != loop->loop_var.get()) { + return false; + } + } + return true; +} + +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || + !IsTrivialBinding(self, block_sref)) { + return false; + } + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + // Step 1. Sort out spatial block variables + std::vector spatial_block_vars; + spatial_block_vars.reserve(block->iter_vars.size()); + for (const IterVar& block_var : block->iter_vars) { + if (block_var->iter_type == IterVarType::kDataPar) { + spatial_block_vars.push_back(block_var->var.get()); + } + } + // Step 2. Enumerate each read region, check the number of block vars that are not used + // to index the read region + int total_unused_block_vars = 0; + std::unordered_set read_buffers; + read_buffers.reserve(block->reads.size()); + for (const BufferRegion& buffer_region : block->reads) { + const BufferNode* buffer = buffer_region->buffer.get(); + const Array& regions = buffer_region->region; + // Step 2.1. Duplication of read buffers are not allowed + if (read_buffers.insert(buffer).second == false) { + return false; + } + // Step 2.2. Skip the reduction buffer + if (buffer == write_buffer) { + continue; + } + // Step 2.3. Collect the block vars that are used to index the read region + std::unordered_set vars; + for (const Range& range : regions) { + if (as_const_int(range->extent) == nullptr) { + return false; + } + for (const Var& var : UndefinedVars(range->min)) { + vars.insert(var.get()); + } + } + // Step 2.4. Check if the block vars are not used to index the read region + int n_unused_block_vars = 0; + for (const VarNode* block_var : spatial_block_vars) { + if (vars.count(block_var) == 0) { + ++n_unused_block_vars; + } + } + total_unused_block_vars += n_unused_block_vars; + } + return total_unused_block_vars >= 1; +} + +std::pair GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref) { + Array loops = tir::GetLoops(block_sref); + int64_t cum_space_len = 1, cum_reduce_len = 1; + /* + * Return (-1, -1) if + * 1. there is some loop with type other than kDataPar and kCommReduce; + * 2. 
there is some loop which is dynamic. + */ + for (const tir::StmtSRef& loop_sref : loops) { + tir::IterVarType type = GetLoopIterType(loop_sref); + if (type == tir::kDataPar) { + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (*extent != -1) { + cum_space_len *= *extent; + } else { + return std::make_pair(-1, -1); + } + } else if (type == tir::kCommReduce) { + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (*extent != -1) { + cum_reduce_len *= *extent; + } else { + return std::make_pair(-1, -1); + } + } else { + return std::make_pair(-1, -1); + } + } + return std::make_pair(cum_space_len, cum_reduce_len); +} + +bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // + int64_t max_parallel_basic) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = tir::GetLoops(block_sref); + + // Cond 1. The block has only one write buffer + if (block->writes.size() != 1) { + return false; + } + + // Cond 2. The block is a reduction block and has trivial binding. + const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, // + /*require_stage_pipeline=*/false, // + /*require_subtree_compact_dataflow=*/false); + if (!(IsReductionBlock(self, block_sref, scope_sref) && // + IsTrivialBinding(self, block_sref))) { + return false; + } + + // Cond 3. Every the loop axis must be either spatial axis or reduction axis. + for (const tir::StmtSRef& loop_sref : loops) { + const tir::IterVarType& type = GetLoopIterType(loop_sref); + if (type != tir::kDataPar && type != tir::kCommReduce) { + return false; + } + } + + // Cond 4. Whether there is at least one reduction loop. + // Cond 5. The loops are continuous, and the body of the innermost loop is exactly the block. + bool has_reduction_loop = false; + for (size_t i = 0; i < loops.size(); ++i) { + // Cond 4. + if (GetLoopIterType(loops[i]) == tir::kCommReduce) { + has_reduction_loop = true; + } + + // Cond 5. + const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]); + if (i < loops.size() - 1) { + const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]); + if (loop_i->body.get() != loop_i1) { + return false; + } + } else { + const auto* block_realize = loop_i->body.as(); + if (!block_realize || block_realize->block.get() != block) { + return false; + } + } + } + if (!has_reduction_loop) { + return false; + } + + // Cond 6. Can successfully calculating the cumulative loop length. + int64_t cum_space_len, cum_reduce_len; + std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(self, block_sref); + if (cum_space_len == -1 || cum_reduce_len == -1) { + return false; + } + + // Cond 7. + if (NeedsMultiLevelTiling(self, block_sref)) { + // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops. + return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent); + } else if (cum_reduce_len > 1) { + // Always try rfactor/cross-thread-reduction for other reduction blocks. 
+ return cum_reduce_len > max_parallel_basic; + } else { + return false; + } +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 2acab384af0b..be6d5a18a47f 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -320,6 +320,60 @@ inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { return result.defined() && result.value()->value == ann_val; } +/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/ + +/*! + * \brief Reorder the reduction loops to innermost positions if needed. + * \param sch The schedule + * \param block_rv The block where to apply the reorder + * \param fused_reduce_loop The fusion-generated loop to return. + * \param num_spatial_loops The number of spatial loops to return. + * \note Before invoking this helper function, make sure that the block has only spatial and + * reduction loop axes. + */ +inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv, + tir::LoopRV* fused_reduce_loop, + size_t* num_spatial_loops) { + Array loops = sch->GetLoops(block_rv); + Array loop_srefs; + for (const tir::LoopRV& loop_rv : loops) { + loop_srefs.push_back(sch->GetSRef(loop_rv)); + } + + Array new_order; + // Step 1. Add spatial loops. + *num_spatial_loops = 0; + for (size_t i = 0; i < loops.size(); ++i) { + if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) { + new_order.push_back(loops[i]); + (*num_spatial_loops)++; + } + } + // Step 2. Add reduction loops. + Array reduction_loops; + for (size_t i = 0; i < loops.size(); ++i) { + if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) { + new_order.push_back(loops[i]); + reduction_loops.push_back(loops[i]); + } + } + // Step 3. Apply reordering if new_order differs from the original order. + ICHECK_EQ(new_order.size(), loops.size()); + for (size_t i = 0; i < loops.size(); ++i) { + if (!new_order[i].same_as(loops[i])) { + sch->Reorder(new_order); + break; + } + } + // Step 4. Fuse all the reduction loops if there are multiple reduction loops. + CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop"; + if (reduction_loops.size() > 1) { + *fused_reduce_loop = sch->Fuse(reduction_loops); + } else { + *fused_reduce_loop = reduction_loops[0]; + } +} + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py new file mode 100644 index 000000000000..5a8031220354 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.schedule_rule import add_rfactor +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.target import Target +from tvm.te.operation import create_prim_func + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cpu_matmul(): + expected = [ + [], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l6, l7 = sch.split(loop=l3, factors=[v4, v5])", + "b8 = sch.rfactor(loop=l7, factor_axis=2)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l6, l7 = sch.split(loop=l3, factors=[v4, v5])", + "b8 = sch.rfactor(loop=l6, factor_axis=2)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + ], + ] + target = Target("llvm --num-cores=32") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=4, + m=4, + k=512, + ) + ), + target=target, + rule=add_rfactor(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cpu_matmul()
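
A minimal usage sketch of the new `te_workload` helpers, outside the test suite. It assumes a TVM build with this patch applied; everything it calls (`te_workload.matmul`, `create_te_workload`, `te.create_prim_func`) appears in the patch above.

```python
# Build a TIR PrimFunc from the new TE workload helpers (assumes this patch is applied).
from tvm import te
from tvm.meta_schedule.testing import te_workload

# The same small matmul the unit test uses: C[i, j] = sum_k A[i, k] * B[k, j]
a, b, c = te_workload.matmul(n=4, m=4, k=512)
mod = te.create_prim_func([a, b, c])
print(mod.script())  # TVMScript: a 4x4 spatial domain with a 512-wide reduction

# The named configurations can be instantiated through the same entry point the
# meta schedule tests use, e.g. the first "GMM" (batch matmul) shape:
gmm = te_workload.create_te_workload("GMM", 0)
```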
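
The patch also registers a `num-cores` attribute on the `llvm` target kind, and `GetTargetNumCores` multiplies it by `max_jobs_per_core` to obtain the parallelism budget consumed by `NeedsRFactorOrCrossThreadReduction`. Below is a sketch of that arithmetic for the unit test's configuration; the `Target.attrs` lookup is only for illustration, and the numbers come from the test (`llvm --num-cores=32`, `matmul(4, 4, 512)`, `max_jobs_per_core=16`).

```python
# Sketch of how AddRFactor's parallelism budget is derived for the unit test's target.
from tvm.target import Target

target = Target("llvm --num-cores=32")               # same target string as the unit test
num_cores = int(target.attrs["num-cores"])           # attribute added in target_kind.cc by this patch
max_jobs_per_core = 16                               # default in testing.schedule_rule.add_rfactor
max_parallel_extent = num_cores * max_jobs_per_core  # 32 * 16 = 512

# For te_workload.matmul(n=4, m=4, k=512):
#   cumulative spatial length   = 4 * 4 = 16
#   cumulative reduction length = 512
# NeedsRFactorOrCrossThreadReduction accepts the block: it needs multi-level tiling,
# its spatial extent (16) is smaller than its reduction extent (512), and 16 does not
# exceed the parallel budget (512), so the rule proposes rfactor-ed design spaces.
```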
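
The design spaces checked in `test_cpu_matmul` can also be reproduced by hand with the TensorIR schedule API. The sketch below mirrors one branch of `AddRFactorNode::Apply` on the same matmul: split the single reduction loop, `rfactor` over the inner piece, and annotate the original block so the Random-Compute-Location rule can later place the rfactor producer. The fixed split factor 64 stands in for one `SamplePerfectTile` draw; the rule itself emits one such schedule per split loop and also keeps the unmodified schedule, which is why the test expects three spaces.

```python
# Hand-written sketch of one schedule AddRFactor proposes for the matmul in the test.
from tvm import te, tir
from tvm.meta_schedule.testing import te_workload

mod = te.create_prim_func(te_workload.matmul(n=4, m=4, k=512))
sch = tir.Schedule(mod)

block = sch.get_block("C")
i, j, k = sch.get_loops(block)             # two spatial loops and one reduction loop
ko, ki = sch.split(k, factors=[None, 64])  # stands in for one SamplePerfectTile draw
rf_block = sch.rfactor(ki, factor_axis=2)  # factor the reduction over the inner loop
sch.annotate(block, ann_key="meta_schedule.random_compute_producer", ann_val=1)

print(sch.mod.script())  # the rfactor-ed PrimFunc
print(sch.trace)         # compare against the second expected trace in test_cpu_matmul
```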
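
`ReorderAndFuseReductionLoops` in `src/tir/schedule/utils.h` has a compact Python analogue. The sketch below is an illustration rather than an existing TVM API (the function name is made up for this note): it assumes the block has a trivial binding with only data-parallel and reduction loops, which is the same precondition the C++ helper documents, classifies each loop through the block iter var it is bound to, moves reduction loops innermost, and fuses them.

```python
# Illustrative Python analogue of ReorderAndFuseReductionLoops (not an existing TVM API).
from typing import List, Tuple

from tvm import tir
from tvm.tir.schedule import BlockRV, LoopRV


def reorder_and_fuse_reduction_loops(
    sch: tir.Schedule, block: BlockRV
) -> Tuple[LoopRV, int]:
    """Move reduction loops innermost and fuse them; return (fused_loop, num_spatial_loops)."""
    loops: List[LoopRV] = sch.get_loops(block)
    # With a trivial binding (checked by the rule), loop i is bound to block var i,
    # so the block var's iter_type tells us whether the loop is spatial or reduction.
    iter_vars = sch.get(block).iter_vars
    spatial = [lp for lp, iv in zip(loops, iter_vars) if iv.iter_type == tir.IterVar.DataPar]
    reduction = [lp for lp, iv in zip(loops, iter_vars) if iv.iter_type == tir.IterVar.CommReduce]
    assert reduction, "There should be at least one reduction loop"
    new_order = spatial + reduction
    # Only reorder if the order actually changes, mirroring the C++ helper.
    if any(not new.same_as(old) for new, old in zip(new_order, loops)):
        sch.reorder(*new_order)
    fused = reduction[0] if len(reduction) == 1 else sch.fuse(*reduction)
    return fused, len(spatial)
```

For the matmul workload there is a single reduction loop that is already innermost, so this reduces to returning that loop unchanged, consistent with the expected traces in the test, which contain no `reorder` or `fuse` steps.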