From d3c0f4046e5665298ca4fa26b119903e59e0e38e Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Sat, 29 Jan 2022 04:52:21 +0800 Subject: [PATCH] [MetaSchedule][M4a] Mutator: Mutate Parallel (#10096) --- include/tvm/tir/schedule/instruction.h | 3 + python/tvm/meta_schedule/mutator/__init__.py | 1 + .../meta_schedule/mutator/mutate_parallel.py | 33 ++ src/meta_schedule/mutator/mutate_parallel.cc | 312 ++++++++++++++++++ src/tir/schedule/analysis.h | 20 ++ src/tir/schedule/analysis/analysis.cc | 31 ++ src/tir/schedule/instruction.cc | 5 + ...t_meta_schedule_mutator_mutate_parallel.py | 113 +++++++ 8 files changed, 518 insertions(+) create mode 100644 python/tvm/meta_schedule/mutator/mutate_parallel.py create mode 100644 src/meta_schedule/mutator/mutate_parallel.cc create mode 100644 tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index 5a9e687dc8c7..1af5ab07e67c 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -121,6 +121,9 @@ class InstructionKindNode : public runtime::Object { // not visited: f_attrs_from_json } + /*! \brief Checks if the instruction kind is EnterPostproc */ + bool IsPostproc() const; + static constexpr const char* _type_key = "tir.InstructionKind"; TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object); }; diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py index 85deb7253e86..af3485b679f1 100644 --- a/python/tvm/meta_schedule/mutator/__init__.py +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -21,4 +21,5 @@ """ from .mutator import Mutator, PyMutator from .mutate_compute_location import MutateComputeLocation +from .mutate_parallel import MutateParallel from .mutate_unroll import MutateUnroll diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py new file mode 100644 index 000000000000..c66dddb825f4 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates the parallel extent""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateParallel") +class MutateParallel(Mutator): + """Mutator that mutates the parallel extent""" + + def __init__(self, max_jobs_per_core: int) -> None: + """Mutator that mutates the parallel extent""" + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateParallel, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + ) diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc new file mode 100644 index 000000000000..7c973879f2cc --- /dev/null +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check if the instruction is annotation with `meta_schedule_parallel` + * \param inst The instruction to be checked + * \return Whether the instruction is annotation with `meta_schedule_parallel` + */ +bool IsAnnotateWithParallel(const Instruction& inst) { + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_annotate)) { + return false; + } + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + return ann_key == attr::meta_schedule_parallel; +} + +/*! + * \brief Replace the annotation value + * \param inst The instruction to be replaced + * \param ann_val The new annotation value + * \return The replaced instruction + */ +Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { + ICHECK_EQ(inst->inputs.size(), 2); + return Instruction(/*kind=*/inst->kind, // + /*inputs=*/{inst->inputs[0], Integer(ann_val)}, // + /*attrs=*/inst->attrs, + /*outputs=*/inst->outputs); +} + +/*! + * \brief Get the output of the instruction Get-Block + * \param inst The instruction to be checked + * \return The output of the instruction Get-Block + */ +const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { + static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock"); + if (!inst->kind.same_as(inst_get_block)) { + return nullptr; + } + ICHECK_EQ(inst->outputs.size(), 1); + const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode); + return block; +} + +/*! + * \brief Analyze the parallel structure + * \param self The schedule state + * \param block_name The name of the root block + * \param func_name The name of the PrimFunc + * \param limit The uplimit of the parallelism + * \return The parallel structure + */ +std::vector> AnalyzeParallel(const ScheduleState& self, + const String& block_name, const String& func_name, + int64_t limit) { + Array block_srefs = tir::GetBlocks(self, block_name, func_name); + ICHECK_EQ(block_srefs.size(), 1); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]); + ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); + std::vector> results; + results.reserve(info.realizes.size()); + for (const BlockRealize& realize : info.realizes) { + // Step 1. Extract static loop extents for spatial loops + std::vector loop_extents; + const ForNode* loop = nullptr; + for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent; + (loop = loop_sref->StmtAs()) != nullptr; // + loop_sref = loop_sref->parent) { + int64_t loop_extent = -1; + if (const auto* ext = GetLoopIntExtent(loop)) { + if (!info.non_spatial_vars.count(loop->loop_var.get())) { + loop_extent = *ext; + } + } + if (loop_extent != -1) { + loop_extents.push_back(loop_extent); + } else { + loop_extents.clear(); + } + } + // Step 2. Take the prefix product of loop extents + if (!loop_extents.empty()) { + results.emplace_back(); + std::vector& result = results.back(); + result.reserve(loop_extents.size()); + int64_t prod_extent = 1; + for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) { + result.push_back(prod_extent *= *it); + if (prod_extent >= limit) { + break; + } + } + } + } + return results; +} + +/*! + * \brief Get the number of parallelizable loops for each subtree + * \param loop_extent_prods The parallel structure for each subtree + * \param limit The uplimit of the parallelism + * \return The number of parallelizable loops for each subtree + */ +std::vector GetNumFusedLoops(const std::vector>& loop_extent_prods, + int64_t limit) { + std::vector results; + results.reserve(loop_extent_prods.size()); + for (const std::vector& prods : loop_extent_prods) { + int n = prods.size(); + int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin(); + if (i > 0 && prods[i - 1] == limit) { + --i; + } + if (i != n) { + ++i; + } + results.push_back(i); + } + return results; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::Trace; + +/*! \brief Create a Mutator that mutates the parallel extent */ +class MutateParallelNode : public MutatorNode { + public: + /*! + * \brief The maximum number of jobs to be launched per CPU core. + * It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int64_t max_jobs_per_core; + /*! \brief The number of cores in CPU. */ + int max_parallel_extent_; + /*! \brief JSON representation of the workload */ + std::string json_mod_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + // `max_parallel_extent_` is not visited. + // `json_mod` is not visited. + } + + static constexpr const char* _type_key = "meta_schedule.MutateParallel"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode); + + public: + struct Candidate; + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final { + Target target = context->target.value(); + this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core; + this->json_mod_ = SaveJSON(context->mod.value()); + } + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! \brief The candidate to be mutated */ +struct MutateParallelNode::Candidate { + /*! \brief The annotation instruction */ + Instruction inst; + /*! \brief The current parallel extent */ + int64_t parallel_extent; + /*! \brief The name of the root block */ + String block_name; + /*! \brief The name of the PrimFunc */ + String func_name; +}; + +/*! + * \brief Get an instruction that annotates the maximum parallel extent + * \param trace The trace to be mutated + * \param rand_state The random state + * \param candidate The candidate to be mutated + * \return Whether a decision is found + */ +bool FindParallelDecision(const Trace& trace, TRandState* rand_state, + MutateParallelNode::Candidate* candidate) { + using tir::BlockRVNode; + using tir::InstructionNode; + std::unordered_map get_block_insts; + std::vector ann_insts; + get_block_insts.reserve(trace->insts.size()); + ann_insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (tir::IsAnnotateWithParallel(inst)) { + ann_insts.push_back(inst.get()); + } + if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) { + get_block_insts[block_rv] = inst.get(); + } + } + int n_ann_insts = ann_insts.size(); + if (n_ann_insts == 0) { + return false; + } + const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + ICHECK_EQ(ann_inst->inputs.size(), 2); + const InstructionNode* get_block_inst = + get_block_insts.at(Downcast(ann_inst->inputs[0]).get()); + ICHECK_EQ(get_block_inst->attrs.size(), 2); + candidate->inst = GetRef(ann_inst); + candidate->parallel_extent = Downcast(ann_inst->inputs[1])->value; + candidate->block_name = Downcast(get_block_inst->attrs[0]); + candidate->func_name = Downcast(get_block_inst->attrs[1]); + return true; +} + +Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) { + // Step 1. Find a parallel decision. + Candidate candidate; + if (!FindParallelDecision(trace, rand_state, &candidate)) { + return NullOpt; + } + // Step 2. Replay the instructions to recover loop extents + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/Downcast(LoadJSON(this->json_mod_)), // + /*rand_state=*/ForkSeed(rand_state), // + /*debug_mode=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + // Step 3. Find all possible parallel plans + std::vector> loop_extent_prods = tir::AnalyzeParallel( + sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_); + std::unordered_map> limit2plan; + std::map, int64_t> plan2limit; + for (const std::vector& prods : loop_extent_prods) { + for (int64_t limit : prods) { + if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) { + std::vector plan = tir::GetNumFusedLoops(loop_extent_prods, limit); + limit2plan[limit] = plan; + plan2limit[plan] = limit; + } + } + } + // Step 4. Remove the original plan and remove it + std::vector original_plan = + tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent); + auto it = plan2limit.find(original_plan); + if (it != plan2limit.end()) { + plan2limit.erase(it); + } + // Step 5. Pick a new plan + int n_plans = plan2limit.size(); + if (n_plans == 0) { + return NullOpt; + } + it = plan2limit.begin(); + for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) { + ++it; + } + int64_t limit = it->second; + // Step 6. Assemble a new trace + Array insts; + insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (inst.same_as(candidate.inst)) { + insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit)); + } else if (inst->kind->IsPostproc()) { + break; + } else { + insts.push_back(inst); + } + } + return Trace(insts, trace->decisions); +} + +Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + return Mutator(n); +} + +TVM_REGISTER_NODE_TYPE(MutateParallelNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 591201312cd2..cdbb70bef6dd 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -91,6 +91,26 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline, bool require_subtree_compact_dataflow); +/*! + * \brief The information of a block scope, including the leaf blocks, + * as well as the loop types (spatial, reduction) for each loop in the scope. + */ +struct ScopeBlockLoopInfo { + /*! \brief A list of the leaf blocks, from left to right */ + std::vector realizes; + /*! \brief The loop vars bound to spatial block iters */ + std::unordered_set spatial_vars; + /*! \brief The loop vars bound to non-spatial block iters */ + std::unordered_set non_spatial_vars; +}; + +/*! + * \brief Inspect the scope of the given sref + * \param scope_block The root block of the scope + * \return The information of the scope + */ +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block); + /*! * \brief Checks whether the block is a complete block under the scope * \param self The schedule state diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 1579f9154fe6..afdff9d5f832 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -150,6 +150,37 @@ Definition of a scope that is a stage pipeline: return scope_root_sref; } +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { + struct Collector : public StmtVisitor { + void VisitStmt_(const BlockRealizeNode* realize) final { + result.realizes.push_back(GetRef(realize)); + const Array& iter_vars = realize->block->iter_vars; + const Array& iter_values = realize->iter_values; + ICHECK_EQ(iter_vars.size(), iter_values.size()); + int n = realize->iter_values.size(); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = iter_vars[i]; + const PrimExpr& iter_value = iter_values[i]; + std::unordered_set* vars = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + vars = &result.spatial_vars; + } else { + vars = &result.non_spatial_vars; + } + PostOrderVisit(iter_value, [vars](const ObjectRef& obj) { + if (const VarNode* var = obj.as()) { + vars->insert(var); + } + }); + } + } + + ScopeBlockLoopInfo result; + } visitor; + visitor(scope_block->body); + return std::move(visitor.result); +} + /*! * \brief Check the dominant property of a block: * the block is the only writer of its output, dominating the reader of its output buffers diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index af721767c32f..cedba4b96095 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -21,6 +21,11 @@ namespace tvm { namespace tir { +bool InstructionKindNode::IsPostproc() const { + static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); + return this == inst_enter_postproc.get(); +} + Instruction::Instruction(InstructionKind kind, Array inputs, Array attrs, Array outputs) { ObjectPtr n = make_object(); diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py new file mode 100644 index 000000000000..e263114ef60f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateParallel, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]], ann_val: int) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + d0, d1, d2 = decisions + b0 = sch.get_block(name="C", func_name="main") + root = sch.get_block(name="root", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, + n=4, + max_innermost_factor=64, + decision=d1, + ) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) + v21, v22 = sch.sample_perfect_tile( + loop=l4, + n=2, + max_innermost_factor=64, + decision=d2, + ) + l23, l24 = sch.split(loop=l4, factors=[v21, v22]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True) + sch.annotate(block_or_loop=root, ann_key="meta_schedule.parallel", ann_val=ann_val) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator: + mutator = MutateParallel(max_jobs_per_core) + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_parallel_matmul(): + mutator = _make_mutator( + target=Target("llvm --num-cores=16"), + max_jobs_per_core=256, + ) + sch = _sch( + decisions=[ + [4, 32, 4, 1], + [8, 4, 8, 2], + [512, 1], + ], + ann_val=64, + ) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + ann_val = int(trace.insts[-1].inputs[1]) + results.add(ann_val) + if len(results) == 3: + break + assert len(results) == 3 + assert results == {4, 32, 4096} + + +if __name__ == """__main__""": + test_mutate_parallel_matmul()