forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MetaSchedule][M4a] Mutator: Mutate Parallel (apache#10096)
- Loading branch information
1 parent
91abbf8
commit d3c0f40
Showing
8 changed files
with
518 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Mutator that mutates the parallel extent""" | ||
from tvm._ffi.registry import register_object | ||
|
||
from .. import _ffi_api | ||
from .mutator import Mutator | ||
|
||
|
||
@register_object("meta_schedule.MutateParallel")
class MutateParallel(Mutator):
    """Mutator that mutates the parallel extent.

    Parameters
    ----------
    max_jobs_per_core : int
        Value forwarded unchanged to the `MutatorMutateParallel` FFI constructor.
    """

    def __init__(self, max_jobs_per_core: int) -> None:
        """Create the underlying object through the FFI constructor."""
        ffi_constructor = _ffi_api.MutatorMutateParallel  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(ffi_constructor, max_jobs_per_core)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#include <algorithm> | ||
#include <unordered_map> | ||
|
||
#include "../utils.h" | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
/*! | ||
* \brief Check if the instruction is annotation with `meta_schedule_parallel` | ||
* \param inst The instruction to be checked | ||
* \return Whether the instruction is annotation with `meta_schedule_parallel` | ||
*/ | ||
bool IsAnnotateWithParallel(const Instruction& inst) { | ||
static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); | ||
if (!inst->kind.same_as(inst_annotate)) { | ||
return false; | ||
} | ||
ICHECK_EQ(inst->attrs.size(), 1); | ||
String ann_key = Downcast<String>(inst->attrs[0]); | ||
return ann_key == attr::meta_schedule_parallel; | ||
} | ||
|
||
/*! | ||
* \brief Replace the annotation value | ||
* \param inst The instruction to be replaced | ||
* \param ann_val The new annotation value | ||
* \return The replaced instruction | ||
*/ | ||
Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { | ||
ICHECK_EQ(inst->inputs.size(), 2); | ||
return Instruction(/*kind=*/inst->kind, // | ||
/*inputs=*/{inst->inputs[0], Integer(ann_val)}, // | ||
/*attrs=*/inst->attrs, | ||
/*outputs=*/inst->outputs); | ||
} | ||
|
||
/*! | ||
* \brief Get the output of the instruction Get-Block | ||
* \param inst The instruction to be checked | ||
* \return The output of the instruction Get-Block | ||
*/ | ||
const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { | ||
static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock"); | ||
if (!inst->kind.same_as(inst_get_block)) { | ||
return nullptr; | ||
} | ||
ICHECK_EQ(inst->outputs.size(), 1); | ||
const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode); | ||
return block; | ||
} | ||
|
||
/*! | ||
* \brief Analyze the parallel structure | ||
* \param self The schedule state | ||
* \param block_name The name of the root block | ||
* \param func_name The name of the PrimFunc | ||
* \param limit The uplimit of the parallelism | ||
* \return The parallel structure | ||
*/ | ||
std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self, | ||
const String& block_name, const String& func_name, | ||
int64_t limit) { | ||
Array<StmtSRef> block_srefs = tir::GetBlocks(self, block_name, func_name); | ||
ICHECK_EQ(block_srefs.size(), 1); | ||
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]); | ||
ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block)); | ||
std::vector<std::vector<int64_t>> results; | ||
results.reserve(info.realizes.size()); | ||
for (const BlockRealize& realize : info.realizes) { | ||
// Step 1. Extract static loop extents for spatial loops | ||
std::vector<int64_t> loop_extents; | ||
const ForNode* loop = nullptr; | ||
for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent; | ||
(loop = loop_sref->StmtAs<ForNode>()) != nullptr; // | ||
loop_sref = loop_sref->parent) { | ||
int64_t loop_extent = -1; | ||
if (const auto* ext = GetLoopIntExtent(loop)) { | ||
if (!info.non_spatial_vars.count(loop->loop_var.get())) { | ||
loop_extent = *ext; | ||
} | ||
} | ||
if (loop_extent != -1) { | ||
loop_extents.push_back(loop_extent); | ||
} else { | ||
loop_extents.clear(); | ||
} | ||
} | ||
// Step 2. Take the prefix product of loop extents | ||
if (!loop_extents.empty()) { | ||
results.emplace_back(); | ||
std::vector<int64_t>& result = results.back(); | ||
result.reserve(loop_extents.size()); | ||
int64_t prod_extent = 1; | ||
for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) { | ||
result.push_back(prod_extent *= *it); | ||
if (prod_extent >= limit) { | ||
break; | ||
} | ||
} | ||
} | ||
} | ||
return results; | ||
} | ||
|
||
/*!
 * \brief Get the number of parallelizable loops for each subtree
 * \param loop_extent_prods The parallel structure for each subtree; each inner vector is
 *        searched with `std::upper_bound`, so it is expected to be sorted in ascending order
 * \param limit The uplimit of the parallelism
 * \return The number of parallelizable loops for each subtree, in the same order as the input
 */
std::vector<int> GetNumFusedLoops(const std::vector<std::vector<int64_t>>& loop_extent_prods,
                                  int64_t limit) {
  std::vector<int> results;
  results.reserve(loop_extent_prods.size());
  for (const std::vector<int64_t>& prods : loop_extent_prods) {
    const int total = prods.size();
    // Index of the first product strictly greater than `limit`
    int pos = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin();
    // A product exactly equal to `limit` right before that index moves the position back by one
    if (pos > 0 && prods[pos - 1] == limit) {
      --pos;
    }
    // Count the loop at `pos` as well, unless `pos` is already past the last loop
    results.push_back(pos == total ? pos : pos + 1);
  }
  return results;
}
|
||
} // namespace tir | ||
} // namespace tvm | ||
|
||
namespace tvm { | ||
namespace meta_schedule { | ||
|
||
using tir::Instruction; | ||
using tir::Trace; | ||
|
||
/*! \brief Create a Mutator that mutates the parallel extent */ | ||
class MutateParallelNode : public MutatorNode { | ||
public: | ||
/*! | ||
* \brief The maximum number of jobs to be launched per CPU core. | ||
* It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. | ||
* Use -1 to disable parallelism. | ||
*/ | ||
int64_t max_jobs_per_core; | ||
/*! \brief The number of cores in CPU. */ | ||
int max_parallel_extent_; | ||
/*! \brief JSON representation of the workload */ | ||
std::string json_mod_; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("max_jobs_per_core", &max_jobs_per_core); | ||
// `max_parallel_extent_` is not visited. | ||
// `json_mod` is not visited. | ||
} | ||
|
||
static constexpr const char* _type_key = "meta_schedule.MutateParallel"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode); | ||
|
||
public: | ||
struct Candidate; | ||
// Inherit from `MutatorNode` | ||
void InitializeWithTuneContext(const TuneContext& context) final { | ||
Target target = context->target.value(); | ||
this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core; | ||
this->json_mod_ = SaveJSON(context->mod.value()); | ||
} | ||
// Inherit from `MutatorNode` | ||
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final; | ||
}; | ||
|
||
/*! \brief The candidate to be mutated */ | ||
struct MutateParallelNode::Candidate { | ||
/*! \brief The annotation instruction */ | ||
Instruction inst; | ||
/*! \brief The current parallel extent */ | ||
int64_t parallel_extent; | ||
/*! \brief The name of the root block */ | ||
String block_name; | ||
/*! \brief The name of the PrimFunc */ | ||
String func_name; | ||
}; | ||
|
||
/*! | ||
* \brief Get an instruction that annotates the maximum parallel extent | ||
* \param trace The trace to be mutated | ||
* \param rand_state The random state | ||
* \param candidate The candidate to be mutated | ||
* \return Whether a decision is found | ||
*/ | ||
bool FindParallelDecision(const Trace& trace, TRandState* rand_state, | ||
MutateParallelNode::Candidate* candidate) { | ||
using tir::BlockRVNode; | ||
using tir::InstructionNode; | ||
std::unordered_map<const BlockRVNode*, const InstructionNode*> get_block_insts; | ||
std::vector<const InstructionNode*> ann_insts; | ||
get_block_insts.reserve(trace->insts.size()); | ||
ann_insts.reserve(trace->insts.size()); | ||
for (const Instruction& inst : trace->insts) { | ||
if (tir::IsAnnotateWithParallel(inst)) { | ||
ann_insts.push_back(inst.get()); | ||
} | ||
if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) { | ||
get_block_insts[block_rv] = inst.get(); | ||
} | ||
} | ||
int n_ann_insts = ann_insts.size(); | ||
if (n_ann_insts == 0) { | ||
return false; | ||
} | ||
const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; | ||
ICHECK_EQ(ann_inst->inputs.size(), 2); | ||
const InstructionNode* get_block_inst = | ||
get_block_insts.at(Downcast<tir::BlockRV>(ann_inst->inputs[0]).get()); | ||
ICHECK_EQ(get_block_inst->attrs.size(), 2); | ||
candidate->inst = GetRef<Instruction>(ann_inst); | ||
candidate->parallel_extent = Downcast<IntImm>(ann_inst->inputs[1])->value; | ||
candidate->block_name = Downcast<String>(get_block_inst->attrs[0]); | ||
candidate->func_name = Downcast<String>(get_block_inst->attrs[1]); | ||
return true; | ||
} | ||
|
||
/*!
 * \brief Replace the value of one randomly chosen parallel annotation with a different limit
 * \param trace The trace to be mutated
 * \param rand_state The random state
 * \return The mutated trace, with instructions from the first postprocessor onwards dropped,
 *         or NullOpt when the trace has no parallel annotation or no alternative plan exists
 */
Optional<Trace> MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) {
  // Step 1. Find a parallel decision.
  Candidate candidate;
  const bool found = FindParallelDecision(trace, rand_state, &candidate);
  if (!found) {
    return NullOpt;
  }
  // Step 2. Replay the instructions to recover loop extents
  tir::Schedule sch = tir::Schedule::Traced(                  //
      /*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)),  //
      /*rand_state=*/ForkSeed(rand_state),                    //
      /*debug_mode=*/0,                                       //
      /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
  trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
  // Step 3. Find all possible parallel plans
  const int64_t max_extent = this->max_parallel_extent_;
  std::vector<std::vector<int64_t>> loop_extent_prods =
      tir::AnalyzeParallel(sch->state(), candidate.block_name, candidate.func_name, max_extent);
  std::unordered_map<int64_t, std::vector<int>> limit2plan;
  std::map<std::vector<int>, int64_t> plan2limit;
  for (const std::vector<int64_t>& prods : loop_extent_prods) {
    for (int64_t limit : prods) {
      // Skip limits above the uplimit and limits that were already turned into a plan
      if (limit > max_extent || limit2plan.count(limit)) {
        continue;
      }
      std::vector<int> plan = tir::GetNumFusedLoops(loop_extent_prods, limit);
      limit2plan[limit] = plan;
      plan2limit[plan] = limit;
    }
  }
  // Step 4. Find the original plan and remove it, so the mutation always changes the plan
  plan2limit.erase(tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent));
  // Step 5. Pick a new plan
  if (plan2limit.empty()) {
    return NullOpt;
  }
  const int n_plans = plan2limit.size();
  auto chosen = plan2limit.begin();
  for (int skip = tir::SampleInt(rand_state, 0, n_plans); skip > 0; --skip) {
    ++chosen;
  }
  const int64_t new_limit = chosen->second;
  // Step 6. Assemble a new trace
  Array<Instruction> insts;
  insts.reserve(trace->insts.size());
  for (const Instruction& inst : trace->insts) {
    if (inst.same_as(candidate.inst)) {
      insts.push_back(tir::ReplaceAnnValue(candidate.inst, new_limit));
      continue;
    }
    if (inst->kind->IsPostproc()) {
      // Everything from the first postprocessor onwards is left out of the new trace
      break;
    }
    insts.push_back(inst);
  }
  return Trace(insts, trace->decisions);
}
|
||
Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { | ||
ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>(); | ||
n->max_jobs_per_core = max_jobs_per_core; | ||
return Mutator(n); | ||
} | ||
|
||
TVM_REGISTER_NODE_TYPE(MutateParallelNode); | ||
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel); | ||
|
||
} // namespace meta_schedule | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.