Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[Relax][MS] Task extraction with proper weights (#129)
Browse files Browse the repository at this point in the history
* [Relax][MS] Task extraction with proper weights (Hzfengsy#32)

* Add a unit test

* Update the deduplication mapping / Update the unit test

* Update test for DummyDB reusing

* Remove unnecessary args

* Remove unused import
  • Loading branch information
MasterJH5574 authored and YuchenJin committed Nov 17, 2022
1 parent badee2a commit 5199a20
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 148 deletions.
86 changes: 12 additions & 74 deletions python/tvm/meta_schedule/relax_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,68 +15,21 @@
# specific language governing permissions and limitations
# under the License.
"""Meta schedule integration with high-level IR"""
from typing import Any, List, Union, Tuple, Dict, Optional
from typing import List, Union, Dict, Optional

import tvm
from tvm.ir import IRModule, structural_hash, structural_equal
from tvm import relax
from tvm._ffi import get_global_func
from tvm.ir import IRModule
from tvm.meta_schedule import ExtractedTask
from tvm.target import Target
from tvm.relax.expr import Function as RelaxFunc
from tvm.relax.utils import tir_partitioner
from tvm.runtime import NDArray


def deduplicate_extracted_tasks(
    mods: List[IRModule],
) -> Tuple[List[IRModule], List[int]]:
    """Collapse structurally identical modules and count how often each occurs.

    Parameters
    ----------
    mods : List[IRModule]
        The list of IRModule.

    Returns
    -------
    tasks : Tuple[List[IRModule], List[int]]
        A tuple containing the deduplicated modules and the count for each module.
    """
    # Each bucket holds ``[module, count]`` entries that share one structural hash.
    # A hash match alone is not trusted: structural equality decides duplication.
    buckets: Dict[int, List[list]] = {}
    for candidate in mods:
        entries = buckets.setdefault(structural_hash(candidate), [])
        for entry in entries:
            if structural_equal(candidate, entry[0]):
                # Duplicate of a module we already keep.
                entry[1] += 1
                break
        else:
            # New hash, or hash conflict with actually different modules.
            entries.append([candidate, 1])

    # Flatten the buckets, keeping first-seen order of hashes and of modules per hash.
    unique_mods: List[IRModule] = []
    occurrences: List[int] = []
    for entries in buckets.values():
        for kept_mod, seen in entries:
            unique_mods.append(kept_mod)
            occurrences.append(seen)
    return unique_mods, occurrences


def extract_task_from_relax(
mod: Union[IRModule, RelaxFunc],
target: Target,
params: Optional[Dict[str, NDArray]] = None,
*,
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
) -> List[ExtractedTask]:
"""Extract tuning tasks from a relax program.
Expand All @@ -92,30 +45,15 @@ def extract_task_from_relax(
tasks: List[ExtractedTask]
The tasks extracted from this module
"""
if isinstance(mod, RelaxFunc):
mod = IRModule.from_expr(mod)
if not isinstance(target, Target):
target = Target(target)

if disabled_pass is None:
disabled_pass = []
if pass_config is None:
pass_config = {}
extract_task_func = get_global_func(
"relax.backend.MetaScheduleExtractTask",
allow_missing=False,
)

if isinstance(mod, RelaxFunc):
mod = IRModule.from_expr(mod)
if params:
mod = tvm.relax.transform.BindParams("main", params)(mod)
mod = relax.transform.BindParams("main", params)(mod)

tir_partitions = tir_partitioner(mod)
tir_mods, tir_counts = deduplicate_extracted_tasks(tir_partitions)
tasks = []
with target, tvm.transform.PassContext(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
):
for i, tir_mod in enumerate(tir_mods):
task_name = tir_mod.get_global_vars()[0].name_hint
# The second arg to ExtractedTask is supposed to be a high-level IRModule,
# passing tir_mod as a workaround.
tasks.append(ExtractedTask(task_name, tir_mod, target, [tir_mod], tir_counts[i]))
return tasks
return list(extract_task_func(mod, target))
24 changes: 0 additions & 24 deletions python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,6 @@
# under the License.
"""Utility functions for Relax"""
from typing import List
from tvm.tir import PrimFunc
from tvm import IRModule


def tir_partitioner(mod: IRModule) -> List[IRModule]:
    """Split the TIR PrimFuncs of a module into one single-function module each.

    Parameters
    ----------
    mod : IRModule
        The input IRModule.

    Returns
    -------
    output : List[IRModule]
        One IRModule per PrimFunc found in ``mod``, keyed by the original global var.
    """
    singles: List[IRModule] = []
    for global_var in mod.get_global_vars():
        func = mod[global_var]
        if not isinstance(func, PrimFunc):
            # Anything that is not a PrimFunc is left out of the result.
            continue
        single = IRModule({})
        single[global_var] = func
        singles.append(single)
    return singles


def metadata_partitioner(rx_txt: str) -> List[str]:
Expand Down
102 changes: 102 additions & 0 deletions src/relax/backend/task_extraction.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/target/target.h>
#include <tvm/tir/function.h>

namespace tvm {
namespace relax {
namespace backend {

using tvm::meta_schedule::ExtractedTask;

/*!
* \brief Extract the Meta-Schedule tuning task from a given IRModule.
* \note
* 1. The task extractor is responsible for task deduplication. The
* deduplication is achieved by comparing structural hashes of PrimFuncs.
* 2. For a PrimFunc, the weight of its corresponding task is the number
* of times it called by op Call-TIR. Say in an IRModule there are three
* PrimFuncs `fn1`, `fn2` and `fn3` sharing the same structural hash.
* Suppose `fn1` is called by 5 Call-TIR ops among all Relax function,
* `fn2` is called by 3 Call-TIR and `fn3` is called by 5 Call-TIR.
* Then we will have a ExtractedTask for all three functions, whose weight
* is 5 + 3 + 2 = 10.
*/
class TaskExtractor : public ExprVisitor {
public:
static Array<ExtractedTask> ExtractTask(IRModule mod, Target target) {
TaskExtractor extracor(mod, target);
// We go through each Relax function in the module.
for (const auto& kv : mod->functions) {
if (const auto* func = kv.second.as<FunctionNode>()) {
extracor(GetRef<Function>(func));
}
}
return std::move(extracor.tasks_);
}

private:
explicit TaskExtractor(IRModule mod, Target target)
: mod_(std::move(mod)), target_(std::move(target)) {}

void VisitExpr_(const CallNode* call) final {
static const Op& call_tir_op = Op::Get("relax.call_tir");
if (!call->op.same_as(call_tir_op)) {
// Since the Relax function is of A-normal form, the arguments of this call cannot be another
// Calls. And hence we do not need to recurse into this Call.
return;
}

const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
const tir::PrimFunc& func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));

auto it = func2task_.find(func);
if (it != func2task_.end()) {
it->second->weight += 1;
return;
}

IRModule tir_mod({{global_var, func}});
ExtractedTask task(/*task_name=*/global_var->name_hint, //
/*mod=*/tir_mod, //
/*target=*/target_, //
/*dispatched=*/{tir_mod}, //
/*weight=*/1);
tasks_.push_back(task);
func2task_.emplace(func, task);
}

IRModule mod_;
Target target_;
Array<ExtractedTask> tasks_;
std::unordered_map<tir::PrimFunc, ExtractedTask, StructuralHash, StructuralEqual> func2task_;
};

TVM_REGISTER_GLOBAL("relax.backend.MetaScheduleExtractTask")
.set_body_typed([](IRModule mod, Target target) {
return TaskExtractor::ExtractTask(std::move(mod), std::move(target));
});

} // namespace backend
} // namespace relax
} // namespace tvm
114 changes: 64 additions & 50 deletions tests/python/relax/test_autotir_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import tvm
from tvm.script import tir as T, relax as R
from tvm import relax

import numpy as np
from tvm.ir.module import IRModule
from tvm.target.target import Target
import pytest
import tempfile
from typing import List
from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
from tvm.meta_schedule.utils import derived_object
import time
import tvm

from tvm import meta_schedule as ms
from tvm import relax
from tvm import transform
import time
import pytest
from tvm.ir.module import IRModule
from tvm.meta_schedule.testing import DummyDatabase
from tvm.script import relax as R, tir as T
from tvm.target.target import Target


# Test case with dynamic shape.
Expand Down Expand Up @@ -78,45 +78,6 @@ def main(x:Tensor((m,n), "float32"), w:Tensor((n,k), "float32")) -> Tensor:
"""


@derived_object
class DummyDatabase(PyDatabase):
    """In-memory database that keeps workloads and tuning records in plain lists."""

    def __init__(self):
        super().__init__()
        # Committed tuning records, in commit order.
        self.records = []
        # Registered workloads; looked up by structural equality of their modules.
        self.workload_reg = []

    def has_workload(self, mod: IRModule) -> bool:
        """Return True if a structurally equal module has already been committed."""
        return any(tvm.ir.structural_equal(workload.mod, mod) for workload in self.workload_reg)

    def commit_tuning_record(self, record: TuningRecord) -> None:
        """Append a tuning record to the database."""
        self.records.append(record)

    def commit_workload(self, mod: IRModule) -> Workload:
        """Return the workload registered for ``mod``, registering a new one if needed."""
        for workload in self.workload_reg:
            if tvm.ir.structural_equal(workload.mod, mod):
                return workload
        workload = Workload(mod)
        self.workload_reg.append(workload)
        return workload

    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
        """Return up to ``top_k`` records of ``workload``, ordered by mean run time."""
        ranked = sorted(self.records, key=lambda rec: sum(rec.run_secs) / len(rec.run_secs))
        matching = [rec for rec in ranked if rec.workload == workload]
        return matching[: int(top_k)]

    def __len__(self) -> int:
        return len(self.records)

    def print_results(self) -> None:
        """Print every committed record, one per line."""
        print("\n".join([str(r) for r in self.records]))


@pytest.mark.parametrize("dev", ["cpu"])
def test_autotir(dev: str):
@tvm.script.ir_module
Expand Down Expand Up @@ -159,7 +120,7 @@ def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tens
return lv1

mod = InputModule
assert isinstance(mod, tvm.IRModule)
assert isinstance(mod, IRModule)

if dev == "cpu":
target = Target("llvm --num-cores=16")
Expand Down Expand Up @@ -213,5 +174,58 @@ def test_autotir_gpu():
test_autotir("cuda")


def test_meta_schedule_extract_task_from_relax():
    """Check that extracted tasks are deduplicated and weighted by call count."""

    @tvm.script.ir_module
    class Module:
        @T.prim_func
        def add1(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
            for i, j in T.grid(128, 128):
                with T.block("add"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] + 1.0

        @T.prim_func
        def add2(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
            for i, j in T.grid(128, 128):
                with T.block("add"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] + 2.0

        # It is intentional that `add3` equals `add1`, in order to test the deduplication
        # correctness.
        @T.prim_func
        def add3(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
            for i, j in T.grid(128, 128):
                with T.block("add"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] + 1.0

        @T.prim_func
        def multiply1(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
            for i, j in T.grid(128, 128):
                with T.block("multiply"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0

        @R.function
        def main(x: Tensor((128, 128), "float32")) -> Tensor(_, "float32"):
            with R.dataflow():
                lv0 = R.call_tir(add1, (x,), (128, 128), dtype="float32")
                lv1 = R.call_tir(multiply1, (lv0,), (128, 128), dtype="float32")
                lv2 = R.call_tir(add2, (lv1,), (128, 128), dtype="float32")
                lv3 = R.call_tir(multiply1, (lv2,), (128, 128), dtype="float32")
                lv4 = R.call_tir(add3, (lv3,), (128, 128), dtype="float32")
                gv = R.call_tir(add1, (lv4,), (128, 128), dtype="float32")
                relax.output(gv)
            return gv

    tasks = ms.relax_integration.extract_task_from_relax(Module, Target("llvm --num-cores=16"))
    # `add1` is called twice and its structural twin `add3` once, hence weight 3.
    expected_weights = {"add1": 3, "add2": 1, "multiply1": 2}
    # The length check rules out a task name being extracted more than once; the mapping
    # comparison then verifies every expected name is present with the right weight.
    assert len(tasks) == len(expected_weights)
    assert {task.task_name: task.weight for task in tasks} == expected_weights


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 5199a20

Please sign in to comment.