Skip to content

Commit

Permalink
[Meta Schedule] Add Auto-Thread Binding Rule
Browse files Browse the repository at this point in the history
  • Loading branch information
Hzfengsy authored and junrushao committed May 20, 2022
1 parent a6a3404 commit 121c17e
Show file tree
Hide file tree
Showing 23 changed files with 826 additions and 139 deletions.
10 changes: 9 additions & 1 deletion include/tvm/meta_schedule/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,21 @@ class Mutator : public runtime::ObjectRef {
* \return The created mutator.
*/
TVM_DLL static Mutator MutateParallel(int64_t max_jobs_per_core);
/*! \brief Create a Mutator that mutates auto unroll step */
/*!
* \brief Create a Mutator that mutates auto unroll step
* \return The mutator created
*/
TVM_DLL static Mutator MutateUnroll();
/*!
* \brief Create a Mutator that mutates the outcome of SampleComputeLocation
* \return The mutator created
*/
TVM_DLL static Mutator MutateComputeLocation();
/*!
* \brief Create a Mutator that mutates auto thread binding.
* \return The mutator created
*/
TVM_DLL static Mutator MutateThreadBinding();
/*!
* \brief Create a mutator with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc RewriteReductionBlock();
/*!
* \brief Create a postprocessor that adds thread binding to unbound blocks
* \param max_threadblock The max number of threadblocks in the cuda device.
* \param max_threadblocks The max number of threadblocks in the cuda device.
* \return The postprocessor created.
*/
TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblock);
TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblocks);
/*!
* \brief Create a postprocessor that applies tensorization to annotated blocks
* \param vectorize_init_loop Whether or not vectorize the initialization loop produced by
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ class ScheduleRule : public runtime::ObjectRef {
int max_vectorize_extent, //
Array<Integer> unroll_max_steps, //
bool unroll_explicit);
/*!
* \brief Auto bind loops around the block to BlockIdx and ThreadIdx
 * \param max_threadblocks The maximum number of threadblocks on the GPU
* \param thread_extents Candidates of thread axis extent.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array<Integer> thread_extents);
/*!
* \brief Create a schedule rule with customized methods on the python-side.
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/mutator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@
from .mutator import Mutator, PyMutator
from .mutate_compute_location import MutateComputeLocation
from .mutate_tile_size import MutateTileSize
from .mutate_thread_binding import MutateThreadBinding
from .mutate_parallel import MutateParallel
from .mutate_unroll import MutateUnroll
32 changes: 32 additions & 0 deletions python/tvm/meta_schedule/mutator/mutate_thread_binding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mutator that mutates the thread binding extent"""
from tvm._ffi.registry import register_object

from .. import _ffi_api
from .mutator import Mutator


@register_object("meta_schedule.MutateThreadBinding")
class MutateThreadBinding(Mutator):
    """Mutator that perturbs the extent of a sampled thread binding."""

    def __init__(self) -> None:
        """Construct the mutator through its FFI constructor."""
        ffi_ctor = _ffi_api.MutateThreadBinding  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(ffi_ctor)
5 changes: 3 additions & 2 deletions python/tvm/meta_schedule/postproc/rewrite_unbound_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""A postprocessor that adds thread binding to unbound blocks"""

from tvm._ffi.registry import register_object

from .. import _ffi_api
from .postproc import Postproc

Expand All @@ -25,8 +26,8 @@
class RewriteUnboundBlock(Postproc):
"""A postprocessor that adds thread binding to unbound blocks"""

def __init__(self, max_threadblock: int = 256) -> None:
def __init__(self, max_threadblocks: int = 256) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocRewriteUnboundBlock, # type: ignore # pylint: disable=no-member
max_threadblock,
max_threadblocks,
)
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
blocks in a schedule. See also PostOrderApply.
"""
from .add_rfactor import AddRFactor
from .auto_bind import AutoBind
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
Expand Down
49 changes: 49 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/auto_bind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Auto-bind Rule that binds blocks to threads if needed"""
from typing import List, Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .schedule_rule import ScheduleRule


@register_object("meta_schedule.AutoBind")
class AutoBind(ScheduleRule):
    """Automatically bind the loops around a block to BlockIdx and ThreadIdx.

    Parameters
    ----------
    max_threadblocks: int
        The maximum number of threadblocks on the GPU.
    thread_extents: Optional[List[int]]
        Candidate extents for the thread axis.
    """

    def __init__(
        self,
        max_threadblocks: int = 256,
        thread_extents: Optional[List[int]] = None,
    ) -> None:
        # Fall back to the standard candidate extents when none are given.
        extents = [32, 64, 128, 256, 512, 1024] if thread_extents is None else thread_extents
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleAutoBind,  # type: ignore # pylint: disable=no-member
            max_threadblocks,
            extents,
        )
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def conv2d_winograd_cpu(
vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
"SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
)
T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse.llvm"})
T.reads(
[
inverse[vh, vw, p_3, co_1],
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def conv2d_winograd_cuda( # type: ignore
vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
"SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
)
T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse.cuda"})
T.reads(
[
inverse[vh, vw, p_3, co_1],
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""Default schedule rules"""
from tvm.meta_schedule.schedule_rule import (
AddRFactor,
AutoBind,
AutoInline,
CrossThreadReduction,
MultiLevelTiling,
Expand All @@ -28,6 +29,13 @@
from tvm.target import Target


def auto_bind(target: Target) -> ScheduleRule:
    """Default schedule rules for auto bind"""
    # Only CUDA targets have a default auto-bind rule; reject anything else.
    if target.kind.name != "cuda":
        raise NotImplementedError(f"{target.kind.name} is not supported")
    return AutoBind(max_threadblocks=256, thread_extents=[32, 64, 128, 256, 512, 1024])


def auto_inline(target: Target) -> ScheduleRule:
"""Default schedule rules for auto inline"""
if target.kind.name == "llvm":
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def _sch_rules() -> List[ScheduleRule]:
unroll_max_steps=[0, 16, 64, 512, 1024],
unroll_explicit=True,
),
M.AutoBind(
max_threadblocks=256,
thread_extents=[32, 64, 128, 256, 512, 1024],
),
]

@staticmethod
Expand All @@ -177,7 +181,8 @@ def _mutator_probs() -> Dict[Mutator, float]:

return {
M.MutateTileSize(): 0.9,
M.MutateUnroll(): 0.1,
M.MutateUnroll(): 0.08,
M.MutateThreadBinding(): 0.02,
}


Expand Down Expand Up @@ -842,6 +847,7 @@ def tune_relay(
"""
# pylint: disable=import-outside-toplevel
from tvm.relay import build as relay_build

from .relay_integration import extract_task_from_relay

# pylint: disable=protected-access, enable=import-outside-toplevel
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/conv2d_nhwc_winograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def nhwc_winograd_cuda(
bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
),
name="inverse",
attrs={"schedule_rule": "meta_schedule.winograd_inverse"},
attrs={"schedule_rule": "meta_schedule.winograd_inverse.cuda"},
)

# Output
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/conv2d_winograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]
),
name="inverse",
attrs={"schedule_rule": "meta_schedule.winograd_inverse"},
attrs={"schedule_rule": "meta_schedule.winograd_inverse.cuda"},
)

# output
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,11 @@ def _conv2d_winograd_nhwc_impl(
bgemm = auto_scheduler.rewrite_compute_body(bgemm, auto_scheduler_rewritten_layout)

# inverse transform
if target is not None:
target_kind = "meta_schedule.winograd_inverse." + target.kind.name
else:
target_kind = "None"

r_a = te.reduce_axis((0, alpha), "r_a")
r_b = te.reduce_axis((0, alpha), "r_b")
inverse = te.compute(
Expand All @@ -1106,7 +1111,7 @@ def _conv2d_winograd_nhwc_impl(
name="inverse",
attrs={
"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"],
"schedule_rule": "meta_schedule.winograd_inverse",
"schedule_rule": target_kind,
},
# the attrs are necessary hints for the auto-scheduler
)
Expand Down
Loading

0 comments on commit 121c17e

Please sign in to comment.