diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 002fa51ee5e3d..d80fa70eee8a2 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -119,13 +119,21 @@ class Mutator : public runtime::ObjectRef { * \return The created mutator. */ TVM_DLL static Mutator MutateParallel(int64_t max_jobs_per_core); - /*! \brief Create a Mutator that mutates auto unroll step */ + /*! + * \brief Create a Mutator that mutates auto unroll step + * \return The mutator created + */ TVM_DLL static Mutator MutateUnroll(); /*! * \brief Create a Mutator that mutates the outcome of SampleComputeLocation * \return The mutator created */ TVM_DLL static Mutator MutateComputeLocation(); + /*! + * \brief Create a Mutator that mutates auto thread binding. + * \return The mutator created + */ + TVM_DLL static Mutator MutateThreadBinding(); /*! * \brief Create a mutator with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index 8b32ce460933c..195d558550170 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -144,10 +144,10 @@ class Postproc : public runtime::ObjectRef { TVM_DLL static Postproc RewriteReductionBlock(); /*! * \brief Create a postprocessor that adds thread binding to unbound blocks - * \param max_threadblock The max number of threadblocks in the cuda device. + * \param max_threadblocks The max number of threadblocks in the cuda device. * \return The postprocessor created. */ - TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblock); + TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblocks); /*! 
* \brief Create a postprocessor that applies tensorization to annotated blocks * \param vectorize_init_loop Whether or not vectorize the initialization loop produced by diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 2b2eefeb75742..b39c72e24db8e 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -212,6 +212,13 @@ class ScheduleRule : public runtime::ObjectRef { int max_vectorize_extent, // Array unroll_max_steps, // bool unroll_explicit); + /*! + * \brief Auto bind loops around the block to BlockIdx and ThreadIdx + * \param max_threadblocks The maximum number of threadblocks on the GPU + * \param thread_extents Candidates of thread axis extent. + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule AutoBind(int max_threadblocks, Array thread_extents); /*! * \brief Create a schedule rule with customized methods on the python-side. * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py index e534ba14346ee..a0f7bac357680 100644 --- a/python/tvm/meta_schedule/mutator/__init__.py +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -22,5 +22,6 @@ from .mutator import Mutator, PyMutator from .mutate_compute_location import MutateComputeLocation from .mutate_tile_size import MutateTileSize +from .mutate_thread_binding import MutateThreadBinding from .mutate_parallel import MutateParallel from .mutate_unroll import MutateUnroll diff --git a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py new file mode 100644 index 0000000000000..6a2553f94346b --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates the thread binding extent""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateThreadBinding") +class MutateThreadBinding(Mutator): + """Mutator that mutates the binding extent""" + + def __init__(self) -> None: + """Mutator that mutates the binding extent""" + self.__init_handle_by_constructor__( + _ffi_api.MutateThreadBinding, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py index c89bc4b0369ab..aef5bca690e47 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py @@ -17,6 +17,7 @@ """A postprocessor that adds thread binding to unbound blocks""" from tvm._ffi.registry import register_object + from .. 
import _ffi_api from .postproc import Postproc @@ -25,8 +26,8 @@ class RewriteUnboundBlock(Postproc): """A postprocessor that adds thread binding to unbound blocks""" - def __init__(self, max_threadblock: int = 256) -> None: + def __init__(self, max_threadblocks: int = 256) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteUnboundBlock, # type: ignore # pylint: disable=no-member - max_threadblock, + max_threadblocks, ) diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index a958fdc39db1f..18fc1de78c7b2 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -20,6 +20,7 @@ blocks in a schedule. See also PostOrderApply. """ from .add_rfactor import AddRFactor +from .auto_bind import AutoBind from .auto_inline import AutoInline from .cross_thread_reduction import CrossThreadReduction from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType diff --git a/python/tvm/meta_schedule/schedule_rule/auto_bind.py b/python/tvm/meta_schedule/schedule_rule/auto_bind.py new file mode 100644 index 0000000000000..c211093e92758 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/auto_bind.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Auto-bind Rule that binds blocks to threads if needed""" +from typing import List, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.AutoBind") +class AutoBind(ScheduleRule): + """Auto bind loops around the block to BlockIdx and ThreadIdx + + Parameters + ---------- + max_threadblocks: int + The maximum number of threadblocks on the GPU. + thread_extents: Optional[List[int]] + Candidates of thread axis extent. + """ + + def __init__( + self, + max_threadblocks: int = 256, + thread_extents: Optional[List[int]] = None, + ) -> None: + if thread_extents is None: + thread_extents = [32, 64, 128, 256, 512, 1024] + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleAutoBind, # type: ignore # pylint: disable=no-member + max_threadblocks, + thread_extents, + ) diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py index 261768c4897bf..d6242020726b0 100644 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py @@ -131,7 +131,7 @@ def conv2d_winograd_cpu( vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap( "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1] ) - T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"}) + T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse.llvm"}) T.reads( [ inverse[vh, vw, p_3, co_1], diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py index 530eadafc0f38..e737f9b04e622 100644 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py @@ -132,7 +132,7 @@ def conv2d_winograd_cuda( # type: ignore vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap( 
"SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1] ) - T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"}) + T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse.cuda"}) T.reads( [ inverse[vh, vw, p_3, co_1], diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index b149f20c52e3e..e159bfaaaa5ae 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -17,6 +17,7 @@ """Default schedule rules""" from tvm.meta_schedule.schedule_rule import ( AddRFactor, + AutoBind, AutoInline, CrossThreadReduction, MultiLevelTiling, @@ -28,6 +29,13 @@ from tvm.target import Target +def auto_bind(target: Target) -> ScheduleRule: + """Default schedule rules for auto bind""" + if target.kind.name == "cuda": + return AutoBind(max_threadblocks=256, thread_extents=[32, 64, 128, 256, 512, 1024]) + raise NotImplementedError(f"{target.kind.name} is not supported") + + def auto_inline(target: Target) -> ScheduleRule: """Default schedule rules for auto inline""" if target.kind.name == "llvm": diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 270c0dab8db43..9af237b3b7b86 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -156,6 +156,10 @@ def _sch_rules() -> List[ScheduleRule]: unroll_max_steps=[0, 16, 64, 512, 1024], unroll_explicit=True, ), + M.AutoBind( + max_threadblocks=256, + thread_extents=[32, 64, 128, 256, 512, 1024], + ), ] @staticmethod @@ -177,7 +181,8 @@ def _mutator_probs() -> Dict[Mutator, float]: return { M.MutateTileSize(): 0.9, - M.MutateUnroll(): 0.1, + M.MutateUnroll(): 0.08, + M.MutateThreadBinding(): 0.02, } @@ -842,6 +847,7 @@ def tune_relay( """ # pylint: disable=import-outside-toplevel from tvm.relay import build as relay_build + from .relay_integration import extract_task_from_relay # pylint: disable=protected-access, enable=import-outside-toplevel diff 
--git a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py index 80745a90d9ff0..8accbbe532737 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py +++ b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py @@ -440,7 +440,7 @@ def nhwc_winograd_cuda( bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] ), name="inverse", - attrs={"schedule_rule": "meta_schedule.winograd_inverse"}, + attrs={"schedule_rule": "meta_schedule.winograd_inverse.cuda"}, ) # Output diff --git a/python/tvm/topi/cuda/conv2d_winograd.py b/python/tvm/topi/cuda/conv2d_winograd.py index 4ff3f52b998f9..d2b373ba87a7d 100644 --- a/python/tvm/topi/cuda/conv2d_winograd.py +++ b/python/tvm/topi/cuda/conv2d_winograd.py @@ -152,7 +152,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] ), name="inverse", - attrs={"schedule_rule": "meta_schedule.winograd_inverse"}, + attrs={"schedule_rule": "meta_schedule.winograd_inverse.cuda"}, ) # output diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index c27ea81144ac2..b7ae9b3e1cd7c 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -1096,6 +1096,11 @@ def _conv2d_winograd_nhwc_impl( bgemm = auto_scheduler.rewrite_compute_body(bgemm, auto_scheduler_rewritten_layout) # inverse transform + if target is not None: + target_kind = "meta_schedule.winograd_inverse." 
+ target.kind.name + else: + target_kind = "None" + r_a = te.reduce_axis((0, alpha), "r_a") r_b = te.reduce_axis((0, alpha), "r_b") inverse = te.compute( @@ -1106,7 +1111,7 @@ def _conv2d_winograd_nhwc_impl( name="inverse", attrs={ "auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"], - "schedule_rule": "meta_schedule.winograd_inverse", + "schedule_rule": target_kind, }, # the attrs are necessary hints for the auto-scheduler ) diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc new file mode 100644 index 0000000000000..41207162ee1d4 --- /dev/null +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::InstructionKind; +using tir::Trace; + +/*! \brief A mutator that mutates the thread binding factor decision of SampleCategorical */ +class MutateThreadBindingNode : public MutatorNode { + public: + /*! 
\brief JSON representation of the workload */ + std::string json_mod_; + + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateThreadBinding"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateThreadBindingNode, MutatorNode); + + public: + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final { + this->json_mod_ = SaveJSON(context->mod.value()); + } + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; + + private: + struct Candidate { + /*! \brief The sampling instruction to be mutated */ + Instruction inst; + /*! \brief The probability */ + std::vector probs; + /*! \brief The decision made */ + int decision; + + explicit Candidate(Instruction inst, std::vector probs, int decision) + : inst(std::move(inst)), probs(std::move(probs)), decision(std::move(decision)) {} + }; + + std::vector FindCandidates(const Trace& trace, TRandState* rand_state); +}; + +/*! + * \brief Find Candidate with the following pattern: + * \code + * v = sch.sample_categorical(...) 
+ * l1, l2 = sch.split(loop=l0, factors=[None, v]) + * sch.bind(loop=l2, thread_axis="threadIdx.x") + * \endcode + * + * \param trace The trace from which to find the instructions + * \return All the candidate instructions + */ +std::vector MutateThreadBindingNode::FindCandidates( + const Trace& trace, TRandState* rand_state) { + using tir::InstructionNode; + + static InstructionKind inst_sample_categorical = InstructionKind::Get("SampleCategorical"); + static InstructionKind inst_split = InstructionKind::Get("Split"); + static InstructionKind inst_bind = InstructionKind::Get("Bind"); + + std::vector candidates; + std::unordered_map sample_insts; + std::unordered_map sampled_split_insts; + std::vector bind_insts; + + auto is_split_by_sample = [&sample_insts](const Instruction& inst) -> bool { + if (!inst->kind.same_as(inst_split)) { + return false; + } + // Only consider cases with 2 factors and the first one is None + if (inst->inputs.size() != 3 || inst->inputs[1].defined()) return false; + ICHECK(inst->inputs[2].defined()); + + return sample_insts.find(Downcast(inst->inputs[2]).get()) != sample_insts.end(); + }; + + auto is_thread_binding_by_sample = [&sampled_split_insts](const Instruction& inst) -> bool { + if (!inst->kind.same_as(inst_bind)) { + return false; + } + ICHECK_EQ(inst->inputs.size(), 1); + ICHECK_EQ(inst->attrs.size(), 1); + if (Downcast(inst->attrs[0]) != "threadIdx.x") return false; + + return sampled_split_insts.find(Downcast(inst->inputs[0]).get()) != + sampled_split_insts.end(); + }; + + for (const Instruction& inst : trace->insts) { + if (inst->kind.same_as(inst_sample_categorical)) { + ICHECK_EQ(inst->outputs.size(), 1); + const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + sample_insts[var_rv] = inst.get(); + } else if (is_split_by_sample(inst)) { + CHECK_EQ(inst->outputs.size(), 2); + // Only consider the inner loop, which can be bound to threadIdx.x + const tir::LoopRVNode* var_rv = TVM_TYPE_AS(var_rv, 
inst->outputs[1], tir::LoopRVNode); + sampled_split_insts[var_rv] = inst.get(); + } else if (is_thread_binding_by_sample(inst)) { + bind_insts.push_back(inst.get()); + } + } + + for (const InstructionNode* bind_inst : bind_insts) { + const auto* loop_rv = TVM_TYPE_AS(loop_rv, bind_inst->inputs[0], tir::LoopRVNode); + auto split_it = sampled_split_insts.find(loop_rv); + ICHECK(split_it != sampled_split_insts.end()); + const InstructionNode* split_inst = split_it->second; + + const auto* expr_rv = TVM_TYPE_AS(expr_rv, split_inst->inputs[2], PrimExprNode); + auto sample_it = sample_insts.find(expr_rv); + ICHECK(sample_it != sample_insts.end()); + const InstructionNode* sample_inst = sample_it->second; + + int decision = Downcast(trace->decisions[GetRef(sample_inst)])->value; + + std::vector probs = + support::AsVector(Downcast>(sample_inst->attrs[1])); + + candidates.emplace_back(GetRef(sample_inst), probs, decision); + } + return candidates; +} + +Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* rand_state) { + std::vector candidates = FindCandidates(trace, rand_state); + if (candidates.empty()) { + return NullOpt; + } + Candidate candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; + // Remove the current decision + candidate.probs.erase(candidate.probs.begin() + candidate.decision); + int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); + if (result >= candidate.decision) { + result += 1; + } + return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); +} + +Mutator Mutator::MutateThreadBinding() { return Mutator(make_object()); } + +TVM_REGISTER_NODE_TYPE(MutateThreadBindingNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding") + .set_body_typed(Mutator::MutateThreadBinding); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index 
73dc89d30e1fb..183f04e7ba239 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -16,84 +16,12 @@ * specific language governing permissions and limitations * under the License. */ +#include "../schedule_rule/auto_bind.h" #include "../utils.h" namespace tvm { namespace tir { -/*! \brief The rewrite type for an unbound block */ -enum class BindType : int32_t { - /*! \brief No additional thread binding is needed */ - kNoBind = 0, - /*! \brief Need to bind to blockIdx */ - kBindBlock = 1, - /*! \brief Need to bind to both blockIdx and threadIdx */ - kBindBlockThread = 2, -}; - -/*! - * \brief Check the combination of bindings to be added to the block - * \param block_sref The block to be checked - * \param fuse_first_num The number of loops to be fused - * \return The type of binding to be added to the block - */ -BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) { - Array loops = tir::GetLoops(block_sref); - int n = loops.size(); - if (n == 0) { - return BindType::kNoBind; - } - int i_block_idx = -1; - int i_thread_idx = -1; - int i_multi_child = -1; - int i_spatial_loop = -1; - for (int i = 0; i < n; ++i) { - const StmtSRef& loop_sref = loops[i]; - const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - runtime::ThreadScope thread_scope = GetThreadScope(loop); - if (IsBlockIdx(thread_scope)) { - if (i_block_idx == -1) { - i_block_idx = i; - } - } - if (IsThreadIdx(thread_scope)) { - if (i_thread_idx == -1) { - i_thread_idx = i; - } - } - if (loop->kind != tir::ForKind::kSerial) { - if (i_multi_child == -1) { - i_multi_child = i; - } - } - if (!IsSingleStmt(loop->body)) { - if (i_multi_child == -1) { - i_multi_child = i + 1; - } - } - if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) { - if (i_spatial_loop == i - 1) { - ++i_spatial_loop; - } - } - } - if (i_multi_child == -1) { - i_multi_child = n; - } - if ((i_block_idx != -1 && i_thread_idx != -1) || 
i_spatial_loop == -1) { - return BindType::kNoBind; - } else if (i_block_idx != -1 && i_thread_idx == -1) { - ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; - throw; - } else if (i_block_idx == -1 && i_thread_idx != -1) { - *fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); - return BindType::kBindBlock; - } else { // i_block_idx == -1 && i_thread_idx == -1 - *fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1); - return BindType::kBindBlockThread; - } -} - /*! \brief Find all the blocks that are not bound */ class UnboundBlockFinder : private StmtVisitor { public: @@ -159,11 +87,11 @@ class RewriteUnboundBlockNode : public PostprocNode { // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { CHECK(context->target.defined()) << "ValueError: target is not defined"; - Optional max_num_threads = + Optional max_threads_per_block = context->target.value()->GetAttr("max_threads_per_block"); - CHECK(max_num_threads.defined()) + CHECK(max_threads_per_block.defined()) << "ValueError: missing attribute `max_threads_per_block` in the target"; - this->max_num_threads_ = max_num_threads.value(); + this->max_threads_per_block_ = max_threads_per_block.value(); } // Inherited from PostprocNode @@ -171,13 +99,13 @@ class RewriteUnboundBlockNode : public PostprocNode { public: /*! \brief The max number of threads per block from Target */ - int max_num_threads_ = -1; + int max_threads_per_block_ = -1; /*! 
\brief The max number of threadblocks in the cuda device */ - int max_threadblock_ = -1; + int max_threadblocks_ = -1; void VisitAttrs(tvm::AttrVisitor* v) { - // `max_num_threads_` is not visited - // `max_threadblock_` is not visited + // `max_threads_per_block_` is not visited + // `max_threadblocks_` is not visited } static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock"; @@ -186,61 +114,28 @@ class RewriteUnboundBlockNode : public PostprocNode { bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { using tir::BlockRV; + using tir::ExprRV; using tir::LoopRV; using tir::Schedule; - ICHECK_NE(this->max_num_threads_, -1); + ICHECK_NE(this->max_threads_per_block_, -1); + auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { + return Integer(std::min(t, max_extent)); + }; std::vector> unbound_blocks = tir::UnboundBlockFinder::Find(sch->state()); for (const auto& kv : unbound_blocks) { tir::StmtSRef block_sref = kv.first; String global_var_name = kv.second; - int fuse_first_num = 0; - tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num); - if (bind_type == tir::BindType::kNoBind) { - continue; - } BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); - Array loop_rvs = sch->GetLoops(block_rv); - LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num}); - if (bind_type == tir::BindType::kBindBlock) { - sch->Bind(fused, "blockIdx.x"); - } else if (bind_type == tir::BindType::kBindBlockThread) { - int64_t extent_size = 0; - Array splits; - if (const int64_t* extent_ptr = tir::GetLoopIntExtent(sch->Get(fused).get())) { - extent_size = *extent_ptr; - if (extent_size > max_threadblock_ * max_num_threads_) { - splits = - sch->Split(fused, {NullOpt, Integer(max_threadblock_), Integer(max_num_threads_)}); - ICHECK_EQ(splits.size(), 3); - sch->Reorder({splits[1], splits[2], splits[0]}); - sch->Bind(splits[1], "blockIdx.x"); - sch->Bind(splits[2], "threadIdx.x"); - 
} else { - ICHECK_NE(extent_size, 0); - splits = sch->Split( - fused, - {NullOpt, Integer(std::min(static_cast(max_num_threads_), extent_size))}); - ICHECK_EQ(splits.size(), 2); - sch->Bind(splits[0], "blockIdx.x"); - sch->Bind(splits[1], "threadIdx.x"); - } - } else { - // loop is dynamic, returns nullptr - splits = sch->Split(fused, {NullOpt, Integer(max_num_threads_)}); - ICHECK_EQ(splits.size(), 2); - sch->Bind(splits[0], "blockIdx.x"); - sch->Bind(splits[1], "threadIdx.x"); - } - } + BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); } return true; } -Postproc Postproc::RewriteUnboundBlock(int max_threadblock) { +Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { ObjectPtr n = make_object(); - n->max_threadblock_ = max_threadblock; - n->max_num_threads_ = -1; + n->max_threadblocks_ = max_threadblocks; + n->max_threads_per_block_ = -1; return Postproc(n); } diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc new file mode 100644 index 0000000000000..9c16856557e00 --- /dev/null +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "./auto_bind.h" + +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv, + int64_t max_threadblocks, int64_t max_threads_per_block, + std::function get_factor) { + using namespace tvm::tir; + Array loops = tir::GetLoops(sch->GetSRef(block_rv)); + int n = loops.size(); + if (n == 0) { + return; + } + int i_block_idx = -1; + int i_thread_idx = -1; + int i_multi_child = -1; + int i_spatial_loop = -1; + for (int i = 0; i < n; ++i) { + const StmtSRef& loop_sref = loops[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsBlockIdx(thread_scope)) { + if (i_block_idx == -1) { + i_block_idx = i; + } + } + if (IsThreadIdx(thread_scope)) { + if (i_thread_idx == -1) { + i_thread_idx = i; + } + } + if (loop->kind != ForKind::kSerial) { + if (i_multi_child == -1) { + i_multi_child = i; + } + } + if (!IsSingleStmt(loop->body)) { + if (i_multi_child == -1) { + i_multi_child = i + 1; + } + } + if (GetLoopIterType(loop_sref) == IterVarType::kDataPar) { + if (i_spatial_loop == i - 1) { + ++i_spatial_loop; + } + } + } + if (i_multi_child == -1) { + i_multi_child = n; + } + if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) { + return; + } + if (i_block_idx != -1 && i_thread_idx == -1) { + ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; + throw; + } + LoopRV loop_rv{nullptr}; + if (i_block_idx == -1 && i_thread_idx != -1) { + int num_fuse = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); + Array loop_rvs = sch->GetLoops(block_rv); + loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); + sch->Bind(loop_rv, "blockIdx.x"); + return; + } else { // i_block_idx == -1 && i_thread_idx == -1 + Array loop_rvs = sch->GetLoops(block_rv); + int num_fuse = std::min(i_multi_child, i_spatial_loop + 1); 
+ loop_rv = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + num_fuse}); + } + int64_t extent = -1; + if (const int64_t* e = GetLoopIntExtent(sch->Get(loop_rv).get())) { + extent = *e; + } else { + extent = std::numeric_limits::max(); + } + if (extent <= max_threadblocks * max_threads_per_block) { + ExprRV factor = get_factor(std::min(extent, max_threads_per_block)); + Array splits = sch->Split(loop_rv, {NullOpt, factor}); + ICHECK_EQ(splits.size(), 2); + sch->Bind(splits[0], "blockIdx.x"); + sch->Bind(splits[1], "threadIdx.x"); + } else { + Array splits = sch->Split(loop_rv, {NullOpt, + Integer(max_threadblocks), // + Integer(max_threads_per_block)}); + ICHECK_EQ(splits.size(), 3); + sch->Reorder({splits[1], splits[2], splits[0]}); + sch->Bind(splits[1], "blockIdx.x"); + sch->Bind(splits[2], "threadIdx.x"); + } +} + +std::function MakeFactorSampler(tir::Schedule sch, + Array thread_extents) { + return [sch = std::move(sch), + thread_extents = std::move(thread_extents)](int64_t max_extent) -> tir::ExprRV { + Array extents; + extents.reserve(thread_extents.size()); + for (const Integer extent : thread_extents) { + if (extent->value <= max_extent) { + extents.push_back(extent); + } + } + int n = extents.size(); + if (n == 0) { + return Integer(max_extent); + } + if (n == 1) { + return Integer(extents[0]); + } + Array probs(n, FloatImm(DataType::Float(64), 1.0 / n)); + return sch->SampleCategorical(extents, probs); + }; +} + +class AutoBindNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + CHECK(context->target.defined()) << "ValueError: target is not defined"; + Optional max_threads_per_block = + context->target.value()->GetAttr("max_threads_per_block"); + CHECK(max_threads_per_block.defined()) + << "ValueError: missing attribute `max_threads_per_block` in the target"; + this->max_threads_per_block_ = max_threads_per_block.value(); + } + + // Inherited from 
ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; + + public: + /*! \brief The max number of threads per block from Target */ + int64_t max_threads_per_block_ = -1; + /*! \brief The max number of threadblocks in the cuda device */ + int64_t max_threadblocks_ = -1; + /*! \brief thread_extents Candidates of thread axis extent. */ + Array thread_extents_; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `max_threads_per_block_` is not visited + // `max_threadblocks_` is not visited + // `thread_extents_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.AutoBind"; + TVM_DECLARE_FINAL_OBJECT_INFO(AutoBindNode, ScheduleRuleNode); +}; + +Array AutoBindNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) { + ICHECK_NE(this->max_threads_per_block_, -1); + auto get_factor = MakeFactorSampler(sch, this->thread_extents_); + BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); + return {sch}; +} + +ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_extents) { + ObjectPtr n = make_object(); + n->max_threadblocks_ = max_threadblocks; + n->max_threads_per_block_ = -1; + n->thread_extents_ = std::move(thread_extents); + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(AutoBindNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind").set_body_typed(ScheduleRule::AutoBind); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_bind.h b/src/meta_schedule/schedule_rule/auto_bind.h new file mode 100644 index 0000000000000..42cab104a2ff2 --- /dev/null +++ b/src/meta_schedule/schedule_rule/auto_bind.h @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Bind the given block if it is not bound to blockIdx or threadIdx. + * \param sch The schedule. + * \param block The block to be bound. + * \param max_threadblocks The maximum number of threadblocks allowed. + * \param max_threads The maximum number of threads allowed. + * \param get_factor A function that returns the tiling factor. + */ +void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block, + int64_t max_threadblocks, int64_t max_threads_per_block, + std::function get_factor); + +/*! + * \brief Given candidates of thread_extents, make a sampler that use `sch->SampleCategorical` + * to return a random thread extent. + * \param sch The schedule + * \param thread_extents The candidate thread extents. + * \return A sampler that returns a random thread extent. + */ +std::function MakeFactorSampler(tir::Schedule sch, + Array thread_extents); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/winograd.cc b/src/meta_schedule/schedule_rule/winograd.cc index d8aab3a3f757a..ceec080b00a9f 100644 --- a/src/meta_schedule/schedule_rule/winograd.cc +++ b/src/meta_schedule/schedule_rule/winograd.cc @@ -17,9 +17,12 @@ * under the License. 
*/ #include "../utils.h" +#include "./auto_bind.h" namespace tvm { -namespace tir { +namespace meta_schedule { + +using namespace tvm::tir; TVM_REGISTER_GLOBAL("meta_schedule.compute_inline") .set_body_typed([](Schedule sch, BlockRV block) -> Array { @@ -63,7 +66,7 @@ inline LoopRV ScheduleDataPack(Schedule sch, BlockRV block) { return t1[1]; } -TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse") +TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.llvm") .set_body_typed([](Schedule sch, BlockRV block) -> Array { ScheduleDataPack(sch, block); return {sch}; @@ -81,6 +84,16 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.llvm") return {sch}; }); +TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.cuda") + .set_body_typed([](Schedule sch, BlockRV block) -> Array { + ScheduleDataPack(sch, block); + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); + BindBlockThreadIdx(sch, block, max_threadblocks, max_threads_per_block, get_factor); + return {sch}; + }); + TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cuda") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetOnlyProducer(sch, data_pack); @@ -89,8 +102,12 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cuda") sch->ComputeAt(input_tile, /*loop_rv=*/loop, /*preserve_unit_loops=*/true); sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); sch->ComputeInline(data_pad); + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); + BindBlockThreadIdx(sch, data_pack, max_threadblocks, max_threads_per_block, get_factor); return {sch}; }); -} // namespace tir +} // namespace meta_schedule } // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py 
index afe6548d6fe39..328f98e7f0cb0 100644 --- a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py +++ b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py @@ -44,6 +44,25 @@ def input_tile_data_pad(sch: Schedule): b127 = sch.get_block(name="data_pad") sch.compute_inline(block=b127) + b3 = sch.get_block(name="data_pack") + l25, l26, l27, l28, _, _, _, _ = sch.get_loops(block=b3) + l33 = sch.fuse(l25, l26, l27, l28) + v34 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=2, + ) + l35, l36 = sch.split(loop=l33, factors=[None, v34]) + sch.bind(loop=l35, thread_axis="blockIdx.x") + sch.bind(loop=l36, thread_axis="threadIdx.x") + def data_pack(sch: Schedule): b16 = sch.get_block(name="data_pack") l17, l18, l19, l20, l21, l22 = sch.get_loops(block=b16) @@ -74,6 +93,16 @@ def bgemm(sch: Schedule): ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS", ) + sch.annotate( + block_or_loop=b31, + ann_key="meta_schedule.thread_extent_low_inclusive", + ann_val=32, + ) + sch.annotate( + block_or_loop=b31, + ann_key="meta_schedule.thread_extent_high_inclusive", + ann_val=1024, + ) b32 = sch.cache_write(block=b31, write_buffer_index=0, storage_scope="local") b31, b32 = b32, b31 l33, l34, l35, l36, l37 = sch.get_loops(block=b32) @@ -185,6 +214,57 @@ def inverse(sch: Schedule): sch.unroll(loop=l6) sch.unroll(loop=l7) sch.reorder(l10, l14, l11, l15, l2, l3, l6, l7) + l59 = sch.fuse(l10, l14, l11, l15) + v60 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=2, + ) + l61, l62 = sch.split(loop=l59, factors=[None, v60]) + sch.bind(loop=l61, thread_axis="blockIdx.x") + 
sch.bind(loop=l62, thread_axis="threadIdx.x") + + def conv2d(sch: Schedule): + b7 = sch.get_block(name="conv2d_winograd") + l141, l142, l143, l144 = sch.get_loops(block=b7) + l145 = sch.fuse(l141, l142, l143, l144) + v146 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=2, + ) + l147, l148 = sch.split(loop=l145, factors=[None, v146]) + sch.bind(loop=l147, thread_axis="blockIdx.x") + sch.bind(loop=l148, thread_axis="threadIdx.x") + + def root_anno(sch: Schedule): + b8 = sch.get_block(name="root", func_name="main") + v140 = sch.sample_categorical( + candidates=[0, 16, 64, 512, 1024], + probs=[ + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + 0.20000000000000001, + ], + decision=2, + ) + sch.annotate(block_or_loop=b8, ann_key="meta_schedule.unroll_explicit", ann_val=v140) # pylint: enable=invalid-name @@ -194,6 +274,8 @@ def inverse(sch: Schedule): input_tile_data_pad(sch) bgemm(sch) inverse(sch) + conv2d(sch) + root_anno(sch) return sch.mod @@ -203,23 +285,27 @@ def test_conv2d_winograd_cuda(): mod = IRModule({"main": mod}) context = TuneContext( mod=mod, - target=Target("cuda"), + target=Target("nvidia/geforce-rtx-3090", host="llvm"), task_name="Custom Search Space Task", sch_rules=DefaultCUDA._sch_rules(), # pylint: disable=protected-access ) + for sch_rule in context.sch_rules: + sch_rule.initialize_with_tune_context(context) post_order_apply = PostOrderApply() post_order_apply.initialize_with_tune_context(context) (sch,) = post_order_apply.generate_design_space(mod) decisions = dict( zip( - [i for i in sch.trace.insts[:-2] if i.kind.name.startswith("Sample")], + [i for i in sch.trace.insts if i.kind.name.startswith("Sample")], [ # data_pack [3, 3], [64, 2], + 2, # inverse [3, 3], [2, 64], + 2, # bgemm [1, 1, 1, 1, 6], [1, 1, 1, 3, 2], @@ 
-228,10 +314,14 @@ def test_conv2d_winograd_cuda(): [32, 1, 4], 1, 1, + # root anno + 2, + # conv2d + 2, ], ) ) - trace = Trace(sch.trace.insts[:-2], decisions=decisions) + trace = Trace(sch.trace.insts, decisions=decisions) sch = Schedule(mod=mod) trace.apply_to_schedule(sch, remove_postproc=False) answer = sch.mod diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py new file mode 100644 index 0000000000000..a2e5dcbd1f0a8 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateThreadBinding, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def element_wise(var_A: T.handle, var_B: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + for i, j in T.grid(512, 512): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + +# pylint: enable=invalid-name, no-member + + +def _sch() -> Schedule: + sch = Schedule(element_wise, debug_mask="all") + # pylint: disable=invalid-name + b0 = sch.get_block(name="C", func_name="main") + l1, l2 = sch.get_loops(block=b0) + l3 = sch.fuse(l1, l2) + v4 = sch.sample_categorical( + candidates=[32, 64, 128, 256, 512, 1024], + probs=[ + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + 0.16666666666666666, + ], + decision=3, + ) + l5, l6 = sch.split(loop=l3, factors=[None, v4]) + sch.bind(loop=l5, thread_axis="blockIdx.x") + sch.bind(loop=l6, thread_axis="threadIdx.x") + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateThreadBinding() + mutator.initialize_with_tune_context(TuneContext(mod=element_wise, target=target)) + return mutator + + +def test_mutate_thread_binding(): + mutator = _make_mutator(target=Target("cuda")) + sch = _sch() + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + decision = trace.decisions[trace.insts[-4]] + results.add(decision) + if len(results) == 5: + break + assert len(results) == 5 + assert results == {0, 1, 2, 4, 5} + + +if __name__ == "__main__": + test_mutate_thread_binding() diff --git 
a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py new file mode 100644 index 0000000000000..bd0a24e8b642e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import auto_bind +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.target import Target +from tvm.script import tir as T + + +@T.prim_func +def element_wise(var_A: T.handle, var_B: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + for i, j in T.grid(512, 512): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cuda_element_wise(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2 = sch.get_loops(block=b0)", + "l3 = sch.fuse(l1, l2)", + "v4 = sch.sample_categorical(candidates=[32, 64, 128, 256, 512, 1024], probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666])", + "l5, l6 = sch.split(loop=l3, factors=[None, v4])", + 'sch.bind(loop=l5, thread_axis="blockIdx.x")', + 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + ] + ] + target = Target("nvidia/geforce-rtx-3080", host="llvm") + ctx = _create_context( + element_wise, + target=target, + rule=auto_bind(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cuda_element_wise()