From 4ef737973444831f99155247473c6cdf075d9fd6 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 21 Apr 2022 22:20:33 +0800 Subject: [PATCH] [Relax][MS] Task extraction with proper weights (#129) * [Relax][MS] Task extraction with proper weights (hzfengsy#32) * Add a unit test * Update the deduplication mapping / Update the unit test * Update test for DummyDB reusing * Remove unnecessary args * Remove unused import --- python/tvm/meta_schedule/relax_integration.py | 86 ++----------- python/tvm/relax/utils.py | 24 ---- src/relax/backend/task_extraction.cc | 102 ++++++++++++++++ .../python/relax/test_autotir_integration.py | 114 ++++++++++-------- 4 files changed, 178 insertions(+), 148 deletions(-) create mode 100644 src/relax/backend/task_extraction.cc diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index 3c7c28a35b..29d32c8aa4 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -15,68 +15,21 @@ # specific language governing permissions and limitations # under the License. """Meta schedule integration with high-level IR""" -from typing import Any, List, Union, Tuple, Dict, Optional +from typing import List, Union, Dict, Optional -import tvm -from tvm.ir import IRModule, structural_hash, structural_equal +from tvm import relax +from tvm._ffi import get_global_func +from tvm.ir import IRModule from tvm.meta_schedule import ExtractedTask from tvm.target import Target from tvm.relax.expr import Function as RelaxFunc -from tvm.relax.utils import tir_partitioner from tvm.runtime import NDArray -def deduplicate_extracted_tasks( - mods: List[IRModule], -) -> Tuple[List[IRModule], List[int]]: - """Remove duplicate modules. - Parameters - ---------- - mods : List[IRModule] - The list of IRModule. - Returns - ------- - tasks : Tuple[List[IRModule], List[int]] - A tuple containing the deduplicated modules and the count for each module. - """ - hash2modules: Dict[int, List[IRModule]] = {} - hash2counts: Dict[int, List[int]] = {} - for mod in mods: - shash = structural_hash(mod) - if shash in hash2modules: - is_dup = False - for i, relax_mod in enumerate(hash2modules[shash]): - # duplicate module was found - if structural_equal(mod, relax_mod): - hash2counts[shash][i] += 1 - is_dup = True - break - if is_dup is False: - # hash conflict but actually different modules - hash2modules[shash].append(mod) - hash2counts[shash].append(1) - - else: - hash2modules[shash] = [mod] - hash2counts[shash] = [1] - - dedup: List[IRModule] = [] - count: List[int] = [] - for shash, relax_mods in hash2modules.items(): - for i, mod in enumerate(relax_mods): - dedup.append(mod) - count.append(hash2counts[shash][i]) - return dedup, count - - def extract_task_from_relax( mod: Union[IRModule, RelaxFunc], target: Target, params: Optional[Dict[str, NDArray]] = None, - *, - opt_level: int = 3, - pass_config: Optional[Dict[str, Any]] = None, - disabled_pass: Optional[List[str]] = None, ) -> List[ExtractedTask]: """Extract tuning tasks from a relax program. 
@@ -92,30 +45,15 @@ def extract_task_from_relax(
     tasks: List[ExtractedTask]
         The tasks extracted from this module
     """
-    if isinstance(mod, RelaxFunc):
-        mod = IRModule.from_expr(mod)
-    if not isinstance(target, Target):
-        target = Target(target)
-    if disabled_pass is None:
-        disabled_pass = []
-    if pass_config is None:
-        pass_config = {}
+    extract_task_func = get_global_func(
+        "relax.backend.MetaScheduleExtractTask",
+        allow_missing=False,
+    )
+    if isinstance(mod, RelaxFunc):
+        mod = IRModule.from_expr(mod)
     if params:
-        mod = tvm.relax.transform.BindParams("main", params)(mod)
+        mod = relax.transform.BindParams("main", params)(mod)
 
-    tir_partitions = tir_partitioner(mod)
-    tir_mods, tir_counts = deduplicate_extracted_tasks(tir_partitions)
-    tasks = []
-    with target, tvm.transform.PassContext(
-        opt_level=opt_level,
-        config=pass_config,
-        disabled_pass=disabled_pass,
-    ):
-        for i, tir_mod in enumerate(tir_mods):
-            task_name = tir_mod.get_global_vars()[0].name_hint
-            # The second arg to ExtractedTask is supposed to be a high-level IRModule,
-            # passing tir_mod as a workaround.
-            tasks.append(ExtractedTask(task_name, tir_mod, target, [tir_mod], tir_counts[i]))
-    return tasks
+    return list(extract_task_func(mod, target))
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 8d0f61b39e..bc8d41774b 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -16,30 +16,6 @@
 # under the License.
 """Utility functions for Relax"""
 from typing import List
-from tvm.tir import PrimFunc
-from tvm import IRModule
-
-
-def tir_partitioner(mod: IRModule) -> List[IRModule]:
-    """Extracts tir PrimFuncs from the input IRModule.
-
-    Parameters
-    ----------
-    mod : IRModule
-        The input IRModule.
-
-    Returns
-    -------
-    output : List[IRModule]
-        The result tir PrimFuncs.
-    """
-    partitions = []
-    for gvar in mod.get_global_vars():
-        if isinstance(mod[gvar], PrimFunc):
-            tir_mod = IRModule({})
-            tir_mod[gvar] = mod[gvar]
-            partitions.append(tir_mod)
-    return partitions
 
 
 def metadata_partitioner(rx_txt: str) -> List[str]:
diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc
new file mode 100644
index 0000000000..dd8d2f60cb
--- /dev/null
+++ b/src/relax/backend/task_extraction.cc
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/meta_schedule/extracted_task.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/target/target.h>
+#include <tvm/tir/function.h>
+
+namespace tvm {
+namespace relax {
+namespace backend {
+
+using tvm::meta_schedule::ExtractedTask;
+
+/*!
+ * \brief Extract the Meta-Schedule tuning tasks from a given IRModule.
+ * \note
+ * 1. The task extractor is responsible for task deduplication. The
+ * deduplication is achieved by comparing structural hashes of PrimFuncs.
+ * 2. 
For a PrimFunc, the weight of its corresponding task is the number
+ * of times it is called by call_tir ops. Say in an IRModule there are three
+ * PrimFuncs `fn1`, `fn2` and `fn3` sharing the same structural hash.
+ * Suppose `fn1` is called by 5 call_tir ops among all Relax functions,
+ * `fn2` is called by 3 call_tir ops and `fn3` by 2 call_tir ops.
+ * Then we will have one ExtractedTask for all three functions, whose weight
+ * is 5 + 3 + 2 = 10.
+ */
+class TaskExtractor : public ExprVisitor {
+ public:
+  static Array<ExtractedTask> ExtractTask(IRModule mod, Target target) {
+    TaskExtractor extractor(mod, target);
+    // We go through each Relax function in the module.
+    for (const auto& kv : mod->functions) {
+      if (const auto* func = kv.second.as<FunctionNode>()) {
+        extractor(GetRef<Function>(func));
+      }
+    }
+    return std::move(extractor.tasks_);
+  }
+
+ private:
+  explicit TaskExtractor(IRModule mod, Target target)
+      : mod_(std::move(mod)), target_(std::move(target)) {}
+
+  void VisitExpr_(const CallNode* call) final {
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+    if (!call->op.same_as(call_tir_op)) {
+      // Since the Relax function is in A-normal form, the arguments of this call cannot be
+      // other calls, and hence we do not need to recurse into this call.
+      return;
+    }
+
+    const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+    const tir::PrimFunc& func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
+
+    auto it = func2task_.find(func);
+    if (it != func2task_.end()) {
+      it->second->weight += 1;
+      return;
+    }
+
+    IRModule tir_mod({{global_var, func}});
+    ExtractedTask task(/*task_name=*/global_var->name_hint,  //
+                       /*mod=*/tir_mod,                      //
+                       /*target=*/target_,                   //
+                       /*dispatched=*/{tir_mod},             //
+                       /*weight=*/1);
+    tasks_.push_back(task);
+    func2task_.emplace(func, task);
+  }
+
+  IRModule mod_;
+  Target target_;
+  Array<ExtractedTask> tasks_;
+  std::unordered_map<tir::PrimFunc, ExtractedTask, StructuralHash, StructuralEqual> func2task_;
+};
+
+TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask")
+    .set_body_typed([](IRModule mod, Target target) {
+      return TaskExtractor::ExtractTask(std::move(mod), std::move(target));
+    });
+
+}  // namespace backend
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_autotir_integration.py b/tests/python/relax/test_autotir_integration.py
index d3e8db2820..a69abcdff8 100644
--- a/tests/python/relax/test_autotir_integration.py
+++ b/tests/python/relax/test_autotir_integration.py
@@ -15,20 +15,20 @@
 # specific language governing permissions and limitations
 # under the License.
 from __future__ import annotations
-import tvm
-from tvm.script import tir as T, relax as R
-from tvm import relax
+
 import numpy as np
-from tvm.ir.module import IRModule
-from tvm.target.target import Target
+import pytest
 import tempfile
-from typing import List
-from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
-from tvm.meta_schedule.utils import derived_object
+import time
+import tvm
+
 from tvm import meta_schedule as ms
+from tvm import relax
 from tvm import transform
-import time
-import pytest
+from tvm.ir.module import IRModule
+from tvm.meta_schedule.testing import DummyDatabase
+from tvm.script import relax as R, tir as T
+from tvm.target.target import Target
 
 # Test case with dynamic shape. 
@@ -78,45 +78,6 @@ def main(x:Tensor((m,n), "float32"), w:Tensor((n,k), "float32")) -> Tensor: """ -@derived_object -class DummyDatabase(PyDatabase): - def __init__(self): - super().__init__() - self.records = [] - self.workload_reg = [] - - def has_workload(self, mod: IRModule) -> Workload: - for workload in self.workload_reg: - if tvm.ir.structural_equal(workload.mod, mod): - return True - return False - - def commit_tuning_record(self, record: TuningRecord) -> None: - self.records.append(record) - - def commit_workload(self, mod: IRModule) -> Workload: - for workload in self.workload_reg: - if tvm.ir.structural_equal(workload.mod, mod): - return workload - workload = Workload(mod) - self.workload_reg.append(workload) - return workload - - def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: - return list( - filter( - lambda x: x.workload == workload, - sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), - ) - )[: int(top_k)] - - def __len__(self) -> int: - return len(self.records) - - def print_results(self) -> None: - print("\n".join([str(r) for r in self.records])) - - @pytest.mark.parametrize("dev", ["cpu"]) def test_autotir(dev: str): @tvm.script.ir_module @@ -159,7 +120,7 @@ def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tens return lv1 mod = InputModule - assert isinstance(mod, tvm.IRModule) + assert isinstance(mod, IRModule) if dev == "cpu": target = Target("llvm --num-cores=16") @@ -213,5 +174,58 @@ def test_autotir_gpu(): test_autotir("cuda") +def test_meta_schedule_extract_task_from_relax(): + @tvm.script.ir_module + class Module: + @T.prim_func + def add1(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + @T.prim_func + def add2(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 2.0 + + # It is intentional that `add3` equals `add1`, in order to test the deduplication + # correctness. + @T.prim_func + def add3(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + @T.prim_func + def multiply1(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]): + for i, j in T.grid(128, 128): + with T.block("multiply"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + @R.function + def main(x: Tensor((128, 128), "float32")) -> Tensor(_, "float32"): + with R.dataflow(): + lv0 = R.call_tir(add1, (x,), (128, 128), dtype="float32") + lv1 = R.call_tir(multiply1, (lv0,), (128, 128), dtype="float32") + lv2 = R.call_tir(add2, (lv1,), (128, 128), dtype="float32") + lv3 = R.call_tir(multiply1, (lv2,), (128, 128), dtype="float32") + lv4 = R.call_tir(add3, (lv3,), (128, 128), dtype="float32") + gv = R.call_tir(add1, (lv4,), (128, 128), dtype="float32") + relax.output(gv) + return gv + + tasks = ms.relax_integration.extract_task_from_relax(Module, Target("llvm --num-cores=16")) + expected_weights = {"add1": 3, "add2": 1, "multiply1": 2} + assert len(tasks) == len(expected_weights) + for task in tasks: + assert task.task_name in expected_weights + assert expected_weights[task.task_name] == task.weight + + if __name__ == "__main__": pytest.main([__file__])
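A minimal usage sketch of the new extraction entry point, mirroring the test above. Here `Module` stands for any Relax IRModule (such as the one defined in `test_meta_schedule_extract_task_from_relax`), and the target string is illustrative:

    from tvm import meta_schedule as ms
    from tvm.target import Target

    # One ExtractedTask is produced per structurally-unique PrimFunc; its
    # `weight` counts how many call_tir sites were folded into that task.
    # Optionally pass params= to bind constant parameters before extraction.
    tasks = ms.relax_integration.extract_task_from_relax(Module, Target("llvm --num-cores=16"))
    for task in tasks:
        print(task.task_name, task.weight)

The per-task weights can then be handed to the MetaSchedule task scheduler, so that PrimFuncs with more call sites receive proportionally more tuning trials.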