From 725badba8aa1a11ed997cf62c32a62d103e67b56 Mon Sep 17 00:00:00 2001
From: Hongyi Jin <3231950289@qq.com>
Date: Sat, 29 Jan 2022 04:52:21 +0800
Subject: [PATCH] [MetaSchedule][M4a] Mutator: Mutate Parallel (#10096)

---
 include/tvm/tir/schedule/instruction.h        |   3 +
 python/tvm/meta_schedule/mutator/__init__.py  |   1 +
 .../meta_schedule/mutator/mutate_parallel.py  |  33 ++
 src/meta_schedule/mutator/mutate_parallel.cc  | 312 ++++++++++++++++++
 src/tir/schedule/analysis.h                   |  20 ++
 src/tir/schedule/analysis/analysis.cc         |  31 ++
 src/tir/schedule/instruction.cc               |   5 +
 ...t_meta_schedule_mutator_mutate_parallel.py | 113 +++++++
 8 files changed, 518 insertions(+)
 create mode 100644 python/tvm/meta_schedule/mutator/mutate_parallel.py
 create mode 100644 src/meta_schedule/mutator/mutate_parallel.cc
 create mode 100644 tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py

diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h
index 5a9e687dc8c7..1af5ab07e67c 100644
--- a/include/tvm/tir/schedule/instruction.h
+++ b/include/tvm/tir/schedule/instruction.h
@@ -121,6 +121,9 @@ class InstructionKindNode : public runtime::Object {
     // not visited: f_attrs_from_json
   }
 
+  /*! \brief Checks if the instruction kind is EnterPostproc */
+  bool IsPostproc() const;
+
   static constexpr const char* _type_key = "tir.InstructionKind";
   TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object);
 };
diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py
index 85deb7253e86..af3485b679f1 100644
--- a/python/tvm/meta_schedule/mutator/__init__.py
+++ b/python/tvm/meta_schedule/mutator/__init__.py
@@ -21,4 +21,5 @@
 """
 from .mutator import Mutator, PyMutator
 from .mutate_compute_location import MutateComputeLocation
+from .mutate_parallel import MutateParallel
 from .mutate_unroll import MutateUnroll
diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py
new file mode 100644
index 000000000000..c66dddb825f4
--- /dev/null
+++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Mutator that mutates the parallel extent"""
+from tvm._ffi.registry import register_object
+
+from .. import _ffi_api
+from .mutator import Mutator
+
+
+@register_object("meta_schedule.MutateParallel")
+class MutateParallel(Mutator):
+    """Mutator that mutates the parallel extent"""
+
+    def __init__(self, max_jobs_per_core: int) -> None:
+        """Create the mutator; `max_jobs_per_core` is the max number of jobs per CPU core"""
+        self.__init_handle_by_constructor__(
+            _ffi_api.MutatorMutateParallel,  # type: ignore # pylint: disable=no-member
+            max_jobs_per_core,
+        )
diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc
new file mode 100644
index 000000000000..7c973879f2cc
--- /dev/null
+++ b/src/meta_schedule/mutator/mutate_parallel.cc
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <algorithm>
+#include <unordered_map>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check if the instruction is an `Annotate` whose key is `meta_schedule_parallel`
+ * \param inst The instruction to be checked
+ * \return Whether the instruction annotates with the key `meta_schedule_parallel`
+ */
+bool IsAnnotateWithParallel(const Instruction& inst) {
+  static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
+  if (!inst->kind.same_as(inst_annotate)) {
+    return false;
+  }
+  ICHECK_EQ(inst->attrs.size(), 1);  // the single attr is read below as the annotation key
+  String ann_key = Downcast<String>(inst->attrs[0]);
+  return ann_key == attr::meta_schedule_parallel;
+}
+
+/*!
+ * \brief Create a copy of the annotation instruction with a new annotation value
+ * \param inst The annotation instruction, which must have exactly two inputs
+ * \param ann_val The new annotation value, which becomes the second input
+ * \return The new instruction; kind, first input, attrs and outputs are taken from `inst`
+ */
+Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) {
+  ICHECK_EQ(inst->inputs.size(), 2);
+  return Instruction(/*kind=*/inst->kind,                             //
+                     /*inputs=*/{inst->inputs[0], Integer(ann_val)},  //
+                     /*attrs=*/inst->attrs,
+                     /*outputs=*/inst->outputs);
+}
+
+/*!
+ * \brief Get the output of the instruction Get-Block
+ * \param inst The instruction to be checked
+ * \return The output BlockRV of the instruction, or nullptr if it is not a GetBlock instruction
+ */
+const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) {
+  static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock");
+  if (!inst->kind.same_as(inst_get_block)) {
+    return nullptr;
+  }
+  ICHECK_EQ(inst->outputs.size(), 1);
+  const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode);
+  return block;
+}
+
+/*!
+ * \brief Analyze the parallel structure
+ * \param self The schedule state
+ * \param block_name The name of the root block
+ * \param func_name The name of the PrimFunc
+ * \param limit The upper limit of the parallelism
+ * \return For each subtree that has usable loops, the prefix products of its loop extents
+ */
+std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
+                                                  const String& block_name, const String& func_name,
+                                                  int64_t limit) {
+  Array<StmtSRef> block_srefs = tir::GetBlocks(self, block_name, func_name);
+  ICHECK_EQ(block_srefs.size(), 1);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
+  ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
+  std::vector<std::vector<int64_t>> results;
+  results.reserve(info.realizes.size());
+  for (const BlockRealize& realize : info.realizes) {
+    // Step 1. Extract static loop extents for spatial loops, walking from the innermost loop up
+    std::vector<int64_t> loop_extents;
+    const ForNode* loop = nullptr;
+    for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent;
+         (loop = loop_sref->StmtAs<ForNode>()) != nullptr;  //
+         loop_sref = loop_sref->parent) {
+      int64_t loop_extent = -1;  // -1 marks a loop that cannot be used
+      if (const auto* ext = GetLoopIntExtent(loop)) {
+        if (!info.non_spatial_vars.count(loop->loop_var.get())) {
+          loop_extent = *ext;
+        }
+      }
+      if (loop_extent != -1) {
+        loop_extents.push_back(loop_extent);
+      } else {
+        loop_extents.clear();  // drop the inner loops; only the outermost usable run is kept
+      }
+    }
+    // Step 2. Take the prefix product of loop extents, starting from the outermost loop
+    if (!loop_extents.empty()) {
+      results.emplace_back();
+      std::vector<int64_t>& result = results.back();
+      result.reserve(loop_extents.size());
+      int64_t prod_extent = 1;
+      for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) {
+        result.push_back(prod_extent *= *it);
+        if (prod_extent >= limit) {  // products past the one reaching `limit` are not recorded
+          break;
+        }
+      }
+    }
+  }
+  return results;
+}
+
+/*!
+ * \brief Get the number of parallelizable loops for each subtree
+ * \param loop_extent_prods The parallel structure for each subtree
+ * \param limit The upper limit of the parallelism
+ * \return The number of parallelizable loops per subtree, in the order of `loop_extent_prods`
+ */
+std::vector<int> GetNumFusedLoops(const std::vector<std::vector<int64_t>>& loop_extent_prods,
+                                  int64_t limit) {
+  std::vector<int> results;
+  results.reserve(loop_extent_prods.size());
+  for (const std::vector<int64_t>& prods : loop_extent_prods) {
+    int n = prods.size();
+    int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin();
+    if (i > 0 && prods[i - 1] == limit) {  // step back onto the product equal to `limit`
+      --i;
+    }
+    if (i != n) {  // `i` indexes a product >= `limit`; count that loop too
+      ++i;
+    }
+    results.push_back(i);
+  }
+  return results;
+}
+
+}  // namespace tir
+}  // namespace tvm
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::Instruction;
+using tir::Trace;
+
+/*! \brief A mutator that mutates the parallel extent */
+class MutateParallelNode : public MutatorNode {
+ public:
+  /*!
+   * \brief The maximum number of jobs to be launched per CPU core.
+   * It sets the upper limit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`.
+   * Use -1 to disable parallelism.
+   */
+  int64_t max_jobs_per_core;
+  /*! \brief The maximum parallel extent, i.e. `num_cores * max_jobs_per_core` */
+  int max_parallel_extent_;
+  /*! \brief JSON representation of the workload */
+  std::string json_mod_;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("max_jobs_per_core", &max_jobs_per_core);
+    // `max_parallel_extent_` is not visited.
+    // `json_mod_` is not visited.
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.MutateParallel";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode);
+
+ public:
+  struct Candidate;
+  // Inherit from `MutatorNode`: derive the max parallel extent and cache the module as JSON
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    Target target = context->target.value();
+    this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core;
+    this->json_mod_ = SaveJSON(context->mod.value());
+  }
+  // Inherit from `MutatorNode`
+  Optional<Trace> Apply(const Trace& trace, TRandState* rand_state) final;
+};
+
+/*! \brief The candidate to be mutated */
+struct MutateParallelNode::Candidate {
+  /*! \brief The annotation instruction to be replaced */
+  Instruction inst;
+  /*! \brief The current parallel extent, i.e. the value of the annotation */
+  int64_t parallel_extent;
+  /*! \brief The name of the annotated block, taken from its GetBlock instruction */
+  String block_name;
+  /*! \brief The name of the PrimFunc, taken from the same GetBlock instruction */
+  String func_name;
+};
+
+/*!
+ * \brief Randomly pick an instruction that annotates the maximum parallel extent
+ * \param trace The trace to be mutated
+ * \param rand_state The random state
+ * \param candidate The output candidate to be mutated, filled only when returning true
+ * \return Whether a decision is found
+ */
+bool FindParallelDecision(const Trace& trace, TRandState* rand_state,
+                          MutateParallelNode::Candidate* candidate) {
+  using tir::BlockRVNode;
+  using tir::InstructionNode;
+  std::unordered_map<const BlockRVNode*, const InstructionNode*> get_block_insts;
+  std::vector<const InstructionNode*> ann_insts;
+  get_block_insts.reserve(trace->insts.size());
+  ann_insts.reserve(trace->insts.size());
+  for (const Instruction& inst : trace->insts) {
+    if (tir::IsAnnotateWithParallel(inst)) {
+      ann_insts.push_back(inst.get());
+    }
+    if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) {
+      get_block_insts[block_rv] = inst.get();
+    }
+  }
+  int n_ann_insts = ann_insts.size();
+  if (n_ann_insts == 0) {  // the trace has no parallel annotation to mutate
+    return false;
+  }
+  const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
+  ICHECK_EQ(ann_inst->inputs.size(), 2);
+  const InstructionNode* get_block_inst =  // the GetBlock that produced the annotated block
+      get_block_insts.at(Downcast<tir::BlockRV>(ann_inst->inputs[0]).get());
+  ICHECK_EQ(get_block_inst->attrs.size(), 2);
+  candidate->inst = GetRef<Instruction>(ann_inst);
+  candidate->parallel_extent = Downcast<IntImm>(ann_inst->inputs[1])->value;
+  candidate->block_name = Downcast<String>(get_block_inst->attrs[0]);
+  candidate->func_name = Downcast<String>(get_block_inst->attrs[1]);
+  return true;
+}
+
+Optional<Trace> MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) {
+  // Step 1. Find a parallel decision; give up if the trace has none.
+  Candidate candidate;
+  if (!FindParallelDecision(trace, rand_state, &candidate)) {
+    return NullOpt;
+  }
+  // Step 2. Replay the instructions to recover loop extents
+  tir::Schedule sch = tir::Schedule::Traced(                  //
+      /*mod=*/Downcast<IRModule>(LoadJSON(this->json_mod_)),  //
+      /*rand_state=*/ForkSeed(rand_state),                    //
+      /*debug_mode=*/0,
+      /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
+  trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
+  // Step 3. Find all possible parallel plans, using each feasible prefix product as the limit
+  std::vector<std::vector<int64_t>> loop_extent_prods = tir::AnalyzeParallel(
+      sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_);
+  std::unordered_map<int64_t, std::vector<int>> limit2plan;
+  std::map<std::vector<int>, int64_t> plan2limit;
+  for (const std::vector<int64_t>& prods : loop_extent_prods) {
+    for (int64_t limit : prods) {
+      if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) {
+        std::vector<int> plan = tir::GetNumFusedLoops(loop_extent_prods, limit);
+        limit2plan[limit] = plan;
+        plan2limit[plan] = limit;  // limits that yield the same plan share one entry
+      }
+    }
+  }
+  // Step 4. Find the original plan and remove it, so that the mutation changes the plan
+  std::vector<int> original_plan =
+      tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent);
+  auto it = plan2limit.find(original_plan);
+  if (it != plan2limit.end()) {
+    plan2limit.erase(it);
+  }
+  // Step 5. Pick a new plan at random; give up if no other plan exists
+  int n_plans = plan2limit.size();
+  if (n_plans == 0) {
+    return NullOpt;
+  }
+  it = plan2limit.begin();
+  for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) {
+    ++it;
+  }
+  int64_t limit = it->second;
+  // Step 6. Assemble a new trace, dropping EnterPostproc and everything after it
+  Array<Instruction> insts;
+  insts.reserve(trace->insts.size());
+  for (const Instruction& inst : trace->insts) {
+    if (inst.same_as(candidate.inst)) {
+      insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit));
+    } else if (inst->kind->IsPostproc()) {
+      break;
+    } else {
+      insts.push_back(inst);
+    }
+  }
+  return Trace(insts, trace->decisions);
+}
+
+Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) {
+  ObjectPtr<MutateParallelNode> n = make_object<MutateParallelNode>();
+  n->max_jobs_per_core = max_jobs_per_core;  // other fields: see InitializeWithTuneContext
+  return Mutator(n);
+}
+
+TVM_REGISTER_NODE_TYPE(MutateParallelNode);
+TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 591201312cd2..cdbb70bef6dd 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -91,6 +91,26 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
 StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline,
                       bool require_subtree_compact_dataflow);
 
+/*!
+ * \brief The information of a block scope, including the leaf blocks,
+ * as well as the loop types (spatial, non-spatial) for each loop in the scope.
+ */
+struct ScopeBlockLoopInfo {
+  /*! \brief A list of the leaf blocks, from left to right */
+  std::vector<BlockRealize> realizes;
+  /*! \brief The vars used in the bindings of spatial (data-parallel) block iters */
+  std::unordered_set<const VarNode*> spatial_vars;
+  /*! \brief The vars used in the bindings of non-spatial block iters */
+  std::unordered_set<const VarNode*> non_spatial_vars;
+};
+
+/*!
+ * \brief Inspect the scope rooted at the given block
+ * \param scope_block The root block of the scope
+ * \return The information of the scope
+ */
+ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block);
+
 /*!
  * \brief Checks whether the block is a complete block under the scope
  * \param self The schedule state
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 1579f9154fe6..afdff9d5f832 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -150,6 +150,37 @@ Definition of a scope that is a stage pipeline:
   return scope_root_sref;
 }
 
+ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) {
+  struct Collector : public StmtVisitor {
+    void VisitStmt_(const BlockRealizeNode* realize) final {  // the block body is not visited
+      result.realizes.push_back(GetRef<BlockRealize>(realize));
+      const Array<IterVar>& iter_vars = realize->block->iter_vars;
+      const Array<PrimExpr>& iter_values = realize->iter_values;
+      ICHECK_EQ(iter_vars.size(), iter_values.size());
+      int n = realize->iter_values.size();
+      for (int i = 0; i < n; ++i) {
+        const IterVar& iter_var = iter_vars[i];
+        const PrimExpr& iter_value = iter_values[i];
+        std::unordered_set<const VarNode*>* vars = nullptr;
+        if (iter_var->iter_type == IterVarType::kDataPar) {  // data-parallel means spatial
+          vars = &result.spatial_vars;
+        } else {
+          vars = &result.non_spatial_vars;
+        }
+        PostOrderVisit(iter_value, [vars](const ObjectRef& obj) {
+          if (const VarNode* var = obj.as<VarNode>()) {
+            vars->insert(var);  // every var appearing in the binding is recorded
+          }
+        });
+      }
+    }
+
+    ScopeBlockLoopInfo result;
+  } visitor;
+  visitor(scope_block->body);
+  return std::move(visitor.result);
+}
+
 /*!
  * \brief Check the dominant property of a block:
  * the block is the only writer of its output, dominating the reader of its output buffers
diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc
index af721767c32f..cedba4b96095 100644
--- a/src/tir/schedule/instruction.cc
+++ b/src/tir/schedule/instruction.cc
@@ -21,6 +21,11 @@
 namespace tvm {
 namespace tir {
 
+bool InstructionKindNode::IsPostproc() const {
+  static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc");
+  return this == inst_enter_postproc.get();  // compared by pointer identity
+}
+
 Instruction::Instruction(InstructionKind kind, Array<ObjectRef> inputs, Array<ObjectRef> attrs,
                          Array<ObjectRef> outputs) {
   ObjectPtr<InstructionNode> n = make_object<InstructionNode>();
diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py
new file mode 100644
index 000000000000..e263114ef60f
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+from typing import List
+
+from tvm.meta_schedule import TuneContext
+from tvm.meta_schedule.mutator import MutateParallel, Mutator
+from tvm.script import tir as T
+from tvm.target import Target
+from tvm.tir import Schedule
+
+# pylint: disable=invalid-name, no-member
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [512, 512])
+    B = T.match_buffer(b, [512, 512])
+    C = T.match_buffer(c, [512, 512])
+    for i, j, k in T.grid(512, 512, 512):  # type: ignore
+        with T.block("C"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])  # type: ignore
+            with T.init():
+                C[vi, vj] = 0.0  # type: ignore
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+# pylint: enable=invalid-name, no-member
+
+
+def _sch(decisions: List[List[int]], ann_val: int) -> Schedule:
+    sch = Schedule(matmul, debug_mask="all")
+    # pylint: disable=invalid-name
+    d0, d1, d2 = decisions
+    b0 = sch.get_block(name="C", func_name="main")
+    root = sch.get_block(name="root", func_name="main")
+    sch.get_consumers(block=b0)
+    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
+    l2, l3, l4 = sch.get_loops(block=b0)
+    v5, v6, v7, v8 = sch.sample_perfect_tile(
+        loop=l2,
+        n=4,
+        max_innermost_factor=64,
+        decision=d0,
+    )
+    l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
+    v13, v14, v15, v16 = sch.sample_perfect_tile(
+        loop=l3,
+        n=4,
+        max_innermost_factor=64,
+        decision=d1,
+    )
+    l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])
+    v21, v22 = sch.sample_perfect_tile(
+        loop=l4,
+        n=2,
+        max_innermost_factor=64,
+        decision=d2,
+    )
+    l23, l24 = sch.split(loop=l4, factors=[v21, v22])
+    sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
+    sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)
+    sch.annotate(block_or_loop=root, ann_key="meta_schedule.parallel", ann_val=ann_val)
+    # pylint: enable=invalid-name
+    return sch
+
+
+def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator:
+    mutator = MutateParallel(max_jobs_per_core)
+    mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target))
+    return mutator
+
+
+def test_mutate_parallel_matmul():
+    mutator = _make_mutator(
+        target=Target("llvm --num-cores=16"),
+        max_jobs_per_core=256,
+    )
+    sch = _sch(
+        decisions=[
+            [4, 32, 4, 1],
+            [8, 4, 8, 2],
+            [512, 1],
+        ],
+        ann_val=64,
+    )
+    results = set()
+    for _ in range(100):
+        trace = mutator.apply(sch.trace)
+        ann_val = int(trace.insts[-1].inputs[1])
+        results.add(ann_val)
+        if len(results) == 3:
+            break
+    assert len(results) == 3
+    assert results == {4, 32, 4096}
+
+
+if __name__ == """__main__""":
+    test_mutate_parallel_matmul()