Skip to content

Commit

Permalink
[MetaSchedule][M4a] Mutator: Mutate Parallel (apache#10096)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinhongyii authored and sunggg committed Jan 28, 2022
1 parent 91abbf8 commit d3c0f40
Show file tree
Hide file tree
Showing 8 changed files with 518 additions and 0 deletions.
3 changes: 3 additions & 0 deletions include/tvm/tir/schedule/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class InstructionKindNode : public runtime::Object {
// not visited: f_attrs_from_json
}

/*! \brief Checks if the instruction kind is EnterPostproc */
bool IsPostproc() const;

static constexpr const char* _type_key = "tir.InstructionKind";
TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object);
};
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/mutator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
"""
from .mutator import Mutator, PyMutator
from .mutate_compute_location import MutateComputeLocation
from .mutate_parallel import MutateParallel
from .mutate_unroll import MutateUnroll
33 changes: 33 additions & 0 deletions python/tvm/meta_schedule/mutator/mutate_parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Mutator that mutates the parallel extent"""
from tvm._ffi.registry import register_object

from .. import _ffi_api
from .mutator import Mutator


@register_object("meta_schedule.MutateParallel")
class MutateParallel(Mutator):
"""Mutator that mutates the parallel extent"""

def __init__(self, max_jobs_per_core: int) -> None:
"""Mutator that mutates the parallel extent"""
self.__init_handle_by_constructor__(
_ffi_api.MutatorMutateParallel, # type: ignore # pylint: disable=no-member
max_jobs_per_core,
)
312 changes: 312 additions & 0 deletions src/meta_schedule/mutator/mutate_parallel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <algorithm>
#include <unordered_map>

#include "../utils.h"

namespace tvm {
namespace tir {

/*!
* \brief Check if the instruction is annotation with `meta_schedule_parallel`
* \param inst The instruction to be checked
* \return Whether the instruction is annotation with `meta_schedule_parallel`
*/
bool IsAnnotateWithParallel(const Instruction& inst) {
static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
if (!inst->kind.same_as(inst_annotate)) {
return false;
}
ICHECK_EQ(inst->attrs.size(), 1);
String ann_key = Downcast<String>(inst->attrs[0]);
return ann_key == attr::meta_schedule_parallel;
}

/*!
* \brief Replace the annotation value
* \param inst The instruction to be replaced
* \param ann_val The new annotation value
* \return The replaced instruction
*/
Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) {
ICHECK_EQ(inst->inputs.size(), 2);
return Instruction(/*kind=*/inst->kind, //
/*inputs=*/{inst->inputs[0], Integer(ann_val)}, //
/*attrs=*/inst->attrs,
/*outputs=*/inst->outputs);
}

/*!
* \brief Get the output of the instruction Get-Block
* \param inst The instruction to be checked
* \return The output of the instruction Get-Block
*/
const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) {
static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock");
if (!inst->kind.same_as(inst_get_block)) {
return nullptr;
}
ICHECK_EQ(inst->outputs.size(), 1);
const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode);
return block;
}

/*!
* \brief Analyze the parallel structure
* \param self The schedule state
* \param block_name The name of the root block
* \param func_name The name of the PrimFunc
* \param limit The uplimit of the parallelism
* \return The parallel structure
*/
std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
const String& block_name, const String& func_name,
int64_t limit) {
Array<StmtSRef> block_srefs = tir::GetBlocks(self, block_name, func_name);
ICHECK_EQ(block_srefs.size(), 1);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
std::vector<std::vector<int64_t>> results;
results.reserve(info.realizes.size());
for (const BlockRealize& realize : info.realizes) {
// Step 1. Extract static loop extents for spatial loops
std::vector<int64_t> loop_extents;
const ForNode* loop = nullptr;
for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent;
(loop = loop_sref->StmtAs<ForNode>()) != nullptr; //
loop_sref = loop_sref->parent) {
int64_t loop_extent = -1;
if (const auto* ext = GetLoopIntExtent(loop)) {
if (!info.non_spatial_vars.count(loop->loop_var.get())) {
loop_extent = *ext;
}
}
if (loop_extent != -1) {
loop_extents.push_back(loop_extent);
} else {
loop_extents.clear();
}
}
// Step 2. Take the prefix product of loop extents
if (!loop_extents.empty()) {
results.emplace_back();
std::vector<int64_t>& result = results.back();
result.reserve(loop_extents.size());
int64_t prod_extent = 1;
for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) {
result.push_back(prod_extent *= *it);
if (prod_extent >= limit) {
break;
}
}
}
}
return results;
}

/*!
* \brief Get the number of parallelizable loops for each subtree
* \param loop_extent_prods The parallel structure for each subtree
* \param limit The uplimit of the parallelism
* \return The number of parallelizable loops for each subtree
*/
std::vector<int> GetNumFusedLoops(const std::vector<std::vector<int64_t>>& loop_extent_prods,
int64_t limit) {
std::vector<int> results;
results.reserve(loop_extent_prods.size());
for (const std::vector<int64_t>& prods : loop_extent_prods) {
int n = prods.size();
int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin();
if (i > 0 && prods[i - 1] == limit) {
--i;
}
if (i != n) {
++i;
}
results.push_back(i);
}
return results;
}

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

using tir::Instruction;
using tir::Trace;

/*! \brief Create a Mutator that mutates the parallel extent */
class MutateParallelNode : public MutatorNode {
public:
/*!
* \brief The maximum number of jobs to be launched per CPU core.
* It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`.
* Use -1 to disable parallelism.
*/
int64_t max_jobs_per_core;
/*! \brief The number of cores in CPU. */
int max_parallel_extent_;
/*! \brief JSON representation of the workload */
std::string json_mod_;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("max_jobs_per_core", &max_jobs_per_core);
// `max_parallel_extent_` is not visited.
// `json_mod` is not visited.
}

static constexpr const char* _type_key = "meta_schedule.MutateParallel";
TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode);

public:
struct Candidate;
// Inherit from `MutatorNode`
void InitializeWithTuneContext(const TuneContext& context) final {
Target target = context->target.value();
this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core;
this->json_mod_ = SaveJSON(context->mod.value());
}
// Inherit from `MutatorNode`
Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
};

/*! \brief The candidate to be mutated */
struct MutateParallelNode::Candidate {
/*! \brief The annotation instruction */
Instruction inst;
/*! \brief The current parallel extent */
int64_t parallel_extent;
/*! \brief The name of the root block */
String block_name;
/*! \brief The name of the PrimFunc */
String func_name;
};

/*!
* \brief Get an instruction that annotates the maximum parallel extent
* \param trace The trace to be mutated
* \param rand_state The random state
* \param candidate The candidate to be mutated
* \return Whether a decision is found
*/
bool FindParallelDecision(const Trace& trace, TRandState* rand_state,
MutateParallelNode::Candidate* candidate) {
using tir::BlockRVNode;
using tir::InstructionNode;
std::unordered_map<const BlockRVNode*, const InstructionNode*> get_block_insts;
std::vector<const InstructionNode*> ann_insts;
get_block_insts.reserve(trace->insts.size());
ann_insts.reserve(trace->insts.size());
for (const Instruction& inst : trace->insts) {
if (tir::IsAnnotateWithParallel(inst)) {
ann_insts.push_back(inst.get());
}
if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) {
get_block_insts[block_rv] = inst.get();
}
}
int n_ann_insts = ann_insts.size();
if (n_ann_insts == 0) {
return false;
}
const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
ICHECK_EQ(ann_inst->inputs.size(), 2);
const InstructionNode* get_block_inst =
get_block_insts.at(Downcast<tir::BlockRV>(ann_inst->inputs[0]).get());
ICHECK_EQ(get_block_inst->attrs.size(), 2);
candidate->inst = GetRef<Instruction>(ann_inst);
candidate->parallel_extent = Downcast<IntImm>(ann_inst->inputs[1])->value;
candidate->block_name = Downcast<String>(get_block_inst->attrs[0]);
candidate->func_name = Downcast<String>(get_block_inst->attrs[1]);
return true;
}

Optional<Trace> MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) {
// Step 1. Find a parallel decision.
Candidate candidate;
if (!FindParallelDecision(trace, rand_state, &candidate)) {
return NullOpt;
}
// Step 2. Replay the instructions to recover loop extents
tir::Schedule sch = tir::Schedule::Traced( //
/*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)), //
/*rand_state=*/ForkSeed(rand_state), //
/*debug_mode=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
// Step 3. Find all possible parallel plans
std::vector<std::vector<int64_t>> loop_extent_prods = tir::AnalyzeParallel(
sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_);
std::unordered_map<int64_t, std::vector<int>> limit2plan;
std::map<std::vector<int>, int64_t> plan2limit;
for (const std::vector<int64_t>& prods : loop_extent_prods) {
for (int64_t limit : prods) {
if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) {
std::vector<int> plan = tir::GetNumFusedLoops(loop_extent_prods, limit);
limit2plan[limit] = plan;
plan2limit[plan] = limit;
}
}
}
// Step 4. Remove the original plan and remove it
std::vector<int> original_plan =
tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent);
auto it = plan2limit.find(original_plan);
if (it != plan2limit.end()) {
plan2limit.erase(it);
}
// Step 5. Pick a new plan
int n_plans = plan2limit.size();
if (n_plans == 0) {
return NullOpt;
}
it = plan2limit.begin();
for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) {
++it;
}
int64_t limit = it->second;
// Step 6. Assemble a new trace
Array<Instruction> insts;
insts.reserve(trace->insts.size());
for (const Instruction& inst : trace->insts) {
if (inst.same_as(candidate.inst)) {
insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit));
} else if (inst->kind->IsPostproc()) {
break;
} else {
insts.push_back(inst);
}
}
return Trace(insts, trace->decisions);
}

Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) {
ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>();
n->max_jobs_per_core = max_jobs_per_core;
return Mutator(n);
}

TVM_REGISTER_NODE_TYPE(MutateParallelNode);
TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel);

} // namespace meta_schedule
} // namespace tvm
20 changes: 20 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,26 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline,
bool require_subtree_compact_dataflow);

/*!
* \brief The information of a block scope, including the leaf blocks,
* as well as the loop types (spatial, reduction) for each loop in the scope.
*/
struct ScopeBlockLoopInfo {
/*! \brief A list of the leaf blocks, from left to right */
std::vector<BlockRealize> realizes;
/*! \brief The loop vars bound to spatial block iters */
std::unordered_set<const VarNode*> spatial_vars;
/*! \brief The loop vars bound to non-spatial block iters */
std::unordered_set<const VarNode*> non_spatial_vars;
};

/*!
* \brief Inspect the scope of the given sref
* \param scope_block The root block of the scope
* \return The information of the scope
*/
ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block);

/*!
* \brief Checks whether the block is a complete block under the scope
* \param self The schedule state
Expand Down
Loading

0 comments on commit d3c0f40

Please sign in to comment.