From 987b10fc52a24748c8c3c9c8759387b51a5d618d Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 27 Aug 2022 16:12:57 -0700 Subject: [PATCH] [MetaSchedule] Introduce `ScheduleFnDatabase` Following #12520, this PR introduces `ScheduleFnDatabase`, a mocked database to allow injecting handcrafted schedules provided by a schedule function. The schedule function comes with the following signature: ```python def schedule_fn( sch: tir.Schedule, ) -> bool: task_name = sch.mod.attrs["task_name"] # ^^^ provides an optional name of the task queried ... ``` This mocked database helps incorporate the existing testing utility `apply_fixed_schedule` more formally into the MetaSchedule-Relay build pipeline, and allows further extension to Relax with the same interface. Next as another follow-up, we will introduce ConcatDatabase that allows mixing multiple databases, including the mocked and ones from JSON files. --- include/tvm/meta_schedule/database.h | 19 +++- python/tvm/meta_schedule/database/__init__.py | 1 + python/tvm/meta_schedule/database/database.py | 41 +++++-- .../database/schedule_fn_database.py | 38 +++++++ python/tvm/meta_schedule/testing/utils.py | 83 -------------- src/meta_schedule/database/database.cc | 13 ++- src/meta_schedule/database/memory_database.cc | 10 +- .../database/schedule_fn_database.cc | 103 ++++++++++++++++++ src/relay/backend/te_compiler_cache.cc | 5 +- tests/python/unittest/test_link_params.py | 15 ++- .../test_meta_schedule_multi_anchor.py | 8 +- .../test_meta_schedule_relay_tir_compute.py | 18 +-- .../unittest/test_meta_schedule_tune_relay.py | 7 +- 13 files changed, 226 insertions(+), 135 deletions(-) create mode 100644 python/tvm/meta_schedule/database/schedule_fn_database.py delete mode 100644 python/tvm/meta_schedule/testing/utils.py create mode 100644 src/meta_schedule/database/schedule_fn_database.cc diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 0e7f45d39332..88db2e227786 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -207,23 +207,29 @@ class DatabaseNode : public runtime::Object { * \brief Query the best record of the given workload from the database. * \param mod The IRModule to be searched for. * \param target The target to be searched for. + * \param workload_name The name of the workload to be searched for. * \return The best record of the given workload; NullOpt if not found. */ - virtual Optional QueryTuningRecord(IRModule mod, Target target); + virtual Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const String& workload_name); /*! * \brief Query the best schedule of the given workload from the database. * \param mod The IRModule to be searched for. * \param target The target to be searched for. + * \param workload_name The name of the workload to be searched for. * \return The schedule in the best schedule of the given workload; NullOpt if not found. */ - virtual Optional QuerySchedule(IRModule mod, Target target); + virtual Optional QuerySchedule(const IRModule& mod, const Target& target, + const String& workload_name); /*! * \brief Query the best IRModule of the given workload from the database. * \param mod The IRModule to be searched for. * \param target The target to be searched for. + * \param workload_name The name of the workload to be searched for. * \return The IRModule in the best IRModule of the given workload; NullOpt if not found. */ - virtual Optional QueryIRModule(IRModule mod, Target target); + virtual Optional QueryIRModule(const IRModule& mod, const Target& target, + const String& workload_name); static constexpr const char* _type_key = "meta_schedule.Database"; TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object); @@ -336,6 +342,13 @@ class Database : public runtime::ObjectRef { public: /*! An in-memory database. */ TVM_DLL static Database MemoryDatabase(); + /*! + * \brief A database for injecting handcrafted schedule functions. + * \param schedule_fn The function to do scheduling, which takes a TIR schedule, + * and returns a boolean indicating if the schedule is successful. + */ + TVM_DLL static Database ScheduleFnDatabase( + runtime::TypedPackedFunc schedule_fn); /*! * \brief Create a default database that uses JSON file for tuning records. * \param path_workload The path to the workload table. diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py index 2a87eea147d9..7726daf6eb63 100644 --- a/python/tvm/meta_schedule/database/__init__.py +++ b/python/tvm/meta_schedule/database/__init__.py @@ -21,3 +21,4 @@ from .database import Database, PyDatabase, TuningRecord, Workload from .json_database import JSONDatabase from .memory_database import MemoryDatabase +from .schedule_fn_database import ScheduleFnDatabase diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 68283b4554e5..aa509b715132 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -235,7 +235,12 @@ def __len__(self) -> int: """ return _ffi_api.DatabaseSize(self) # type: ignore # pylint: disable=no-member - def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningRecord]: + def query_tuning_record( + self, + mod: IRModule, + target: Target, + workload_name: str, + ) -> Optional[TuningRecord]: """Query the best record of the given workload from the database. Parameters @@ -244,15 +249,22 @@ def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningR The IRModule to be searched for. target : Target The target to be searched for. + workload_name : str + The name of the workload to be searched for. Returns ------- tuning_record : Optional[TuningRecord] The best record of the given workload; None if not found. """ - return _ffi_api.DatabaseQueryTuningRecord(self, mod, target) # type: ignore # pylint: disable=no-member + return _ffi_api.DatabaseQueryTuningRecord(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member - def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]: + def query_schedule( + self, + mod: IRModule, + target: Target, + workload_name: str, + ) -> Optional[Schedule]: """Query the best schedule of the given workload from the database. Parameters @@ -261,15 +273,22 @@ def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]: The IRModule to be searched for. target : Target The target to be searched for. + workload_name : str + The name of the workload to be searched for. Returns ------- schedule : Optional[Schedule] The best schedule of the given workload; None if not found. """ - return _ffi_api.DatabaseQuerySchedule(self, mod, target) # type: ignore # pylint: disable=no-member + return _ffi_api.DatabaseQuerySchedule(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member - def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]: + def query_ir_module( + self, + mod: IRModule, + target: Target, + workload_name: str, + ) -> Optional[IRModule]: """Query the best IRModule of the given workload from the database. Parameters @@ -278,18 +297,22 @@ def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]: The IRModule to be searched for. target : Target The target to be searched for. + workload_name : str + The name of the workload to be searched for. Returns ------- ir_module : Optional[IRModule] The best IRModule of the given workload; None if not found. """ - return _ffi_api.DatabaseQueryIRModule(self, mod, target) # type: ignore # pylint: disable=no-member + return _ffi_api.DatabaseQueryIRModule(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member def query( self, mod: IRModule, target: Target, + *, + workload_name: str = "main", kind: Union[ Literal["schedule"], Literal["record"], @@ -313,11 +336,11 @@ def query( The best optimization outcome of the given workload. """ if kind == "schedule": - return self.query_schedule(mod, target) + return self.query_schedule(mod, target, workload_name) if kind == "record": - return self.query_tuning_record(mod, target) + return self.query_tuning_record(mod, target, workload_name) if kind == "ir_module": - return self.query_ir_module(mod, target) + return self.query_ir_module(mod, target, workload_name) raise ValueError(f'Unknown kind: {kind}. Candidates are: "schedule", "record", "ir_module"') def __enter__(self) -> "Database": diff --git a/python/tvm/meta_schedule/database/schedule_fn_database.py b/python/tvm/meta_schedule/database/schedule_fn_database.py new file mode 100644 index 000000000000..2918f05799dc --- /dev/null +++ b/python/tvm/meta_schedule/database/schedule_fn_database.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A database for injecting handcrafted schedule functions.""" +from typing import Callable + +from tvm._ffi import register_object +from tvm.tir import Schedule + +from .. import _ffi_api +from .database import Database + + +@register_object("meta_schedule.ScheduleFnDatabase") +class ScheduleFnDatabase(Database): + """A database for injecting handcrafted schedule functions.""" + + def __init__( + self, + schedule_fn: Callable[[Schedule], bool], + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.DatabaseScheduleFnDatabase, # type: ignore # pylint: disable=no-member + schedule_fn, + ) diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py deleted file mode 100644 index 5919fb47c809..000000000000 --- a/python/tvm/meta_schedule/testing/utils.py +++ /dev/null @@ -1,83 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Testing utility functions in meta schedule""" -from typing import Callable, Dict, Optional, Union - -from tvm import meta_schedule as ms -from tvm.ir import IRModule, transform -from tvm.relay import Function as RelayFunc -from tvm.runtime import NDArray -from tvm.target import Target -from tvm.tir import Schedule - - -def apply_fixed_schedules( - relay_mod: Union[RelayFunc, IRModule], - target: Union[str, Target], - params: Optional[Dict[str, NDArray]], - schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool], - tir_converter: str = "default", -): - """Apply fixed schedules (manually written, without any tunable knobs) as specified by - schedule_fn to extracted tasks, and return a database that can be passed to compilation. - - Parameters - ---------- - mod : Union[RelayFunc, IRModule] - The Relay module to apply fixed schedules. - target : Union[str, Target] - The target used to extract tasks. - params : Optional[Dict[str, tvm.runtime.NDArray]] - The associated parameters of the module. - schedule_fn : Callable[[ExtractedTask, Schedule], bool] - A callable that is applied for each extracted task and the corresponding default schedule. - Returns True if the given schedule should be committed to the database, False otherwise. - tir_converter : str - The filter function to filter out the extracted tasks. Builtin filters: - - "default" - - "allow_extern" - The converter is a PackedFunc registered as f"relay.backend.tir_converter.{tir_converter}", - with the signature below: - (args: List[te.Tensor], constants: List[NDArray]) -> Optional[tir.PrimFunc] - - Returns - ------- - database : Database - The database containing dummy tuning records for manually scheduled traces. - """ - target = Target(target) if isinstance(target, str) else target - config = {"relay.backend.use_meta_schedule": True} - for k, v in transform.PassContext.current().config.items(): - config[k] = v - - extracted_tasks = ms.extract_task_from_relay( - relay_mod, - target, - params, - tir_converter=tir_converter, - ) - database = ms.database.MemoryDatabase() - for task in extracted_tasks: - mod = ms.default_config.mod(task.dispatched[0]) - sch = Schedule(mod) - - if schedule_fn(task, sch): - workload = database.commit_workload(mod) - tune_rec = ms.database.TuningRecord(sch.trace, workload, [0.0], target, []) - database.commit_tuning_record(tune_rec) - - return database diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index fedd2aa35278..d082ff7a3901 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -156,7 +156,8 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w /******** Database ********/ -Optional DatabaseNode::QueryTuningRecord(IRModule mod, Target target) { +Optional DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target, + const String& workload_name) { if (!this->HasWorkload(mod)) { return NullOpt; } @@ -168,8 +169,9 @@ Optional DatabaseNode::QueryTuningRecord(IRModule mod, Target targ return records[0]; } -Optional DatabaseNode::QuerySchedule(IRModule mod, Target target) { - if (Optional opt_record = this->QueryTuningRecord(mod, target)) { +Optional DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target, + const String& workload_name) { + if (Optional opt_record = this->QueryTuningRecord(mod, target, workload_name)) { TuningRecord record = opt_record.value(); tir::Schedule sch = tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, @@ -181,8 +183,9 @@ Optional DatabaseNode::QuerySchedule(IRModule mod, Target target) } } -Optional DatabaseNode::QueryIRModule(IRModule mod, Target target) { - if (Optional opt_sch = this->QuerySchedule(mod, target)) { +Optional DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target, + const String& workload_name) { + if (Optional opt_sch = this->QuerySchedule(mod, target, workload_name)) { return opt_sch.value()->mod(); } else { return NullOpt; diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index a00d5501ad1d..b6c635555152 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -44,7 +44,7 @@ class MemoryDatabaseNode : public DatabaseNode { return false; } - Workload CommitWorkload(const IRModule& mod) { + Workload CommitWorkload(const IRModule& mod) final { for (const auto& workload : workloads) { if (StructuralEqual()(workload->mod, mod)) { return workload; @@ -55,9 +55,9 @@ class MemoryDatabaseNode : public DatabaseNode { return workload; } - void CommitTuningRecord(const TuningRecord& record) { records.push_back(record); } + void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); } - Array GetTopK(const Workload& workload, int top_k) { + Array GetTopK(const Workload& workload, int top_k) final { std::vector> results; results.reserve(this->records.size()); for (const TuningRecord& record : records) { @@ -91,9 +91,9 @@ class MemoryDatabaseNode : public DatabaseNode { return ret; } - Array GetAllTuningRecords() { return records; } + Array GetAllTuningRecords() final { return records; } - int64_t Size() { return records.size(); } + int64_t Size() final { return records.size(); } }; Database Database::MemoryDatabase() { diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc new file mode 100644 index 000000000000..751721fe52d4 --- /dev/null +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class ScheduleFnDatabaseNode : public DatabaseNode { + public: + runtime::TypedPackedFunc schedule_fn; + + void VisitAttrs(AttrVisitor* v) { + // `schedule_fn` is not visited. + } + + static constexpr const char* _type_key = "meta_schedule.ScheduleFnDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnDatabaseNode, DatabaseNode); + + public: + Optional QueryTuningRecord(const IRModule& mod, const Target& target, + const String& workload_name) final { + if (Optional sch = this->QuerySchedule(mod, target, workload_name)) { + return TuningRecord(sch.value()->trace().value(), + /*workload=*/Workload(mod, 0), // + /*run_secs=*/NullOpt, // + /*target=*/target, // + /*arg_info=*/NullOpt); + } + return NullOpt; + } + + Optional QuerySchedule(const IRModule& mod, const Target& target, + const String& workload_name) final { + tir::Schedule sch = + tir::Schedule::Traced(WithAttr(mod, "task_name", workload_name), + /*rand_state=*/-1, + /*debug_mode=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); + if (!schedule_fn(sch)) { + return NullOpt; + } + return sch; + } + + bool HasWorkload(const IRModule& mod) final { + LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.HasWorkload"; + throw; + } + + Workload CommitWorkload(const IRModule& mod) final { + LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitWorkload"; + throw; + } + + void CommitTuningRecord(const TuningRecord& record) final { + LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.CommitTuningRecord"; + throw; + } + + Array GetTopK(const Workload& workload, int top_k) final { + LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetTopK"; + throw; + } + + Array GetAllTuningRecords() final { + LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.GetAllTuningRecords"; + throw; + } + + int64_t Size() final { + LOG(FATAL) << "NotImplementedError: ScheduleFnDatabase.size"; + throw; + } +}; + +Database Database::ScheduleFnDatabase(runtime::TypedPackedFunc schedule_fn) { + ObjectPtr n = make_object(); + n->schedule_fn = std::move(schedule_fn); + return Database(n); +} + +TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase") + .set_body_typed(Database::ScheduleFnDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 0e2a3e270257..1d7566ebe2bd 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -367,7 +367,8 @@ class ScheduleBuilder : public ExprVisitor { if (Optional f = tir_converter(te_args, constants)) { if (Optional opt_record = database_.value()->QueryTuningRecord( /*mod=*/backend::PrimFuncToIRModule(f.value()), - /*target=*/target_)) { + /*target=*/target_, + /*workload_name=*/prim_fn_var->name_hint)) { static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout"); TuningRecord record = opt_record.value(); for (const Instruction& inst : record->trace->insts) { @@ -383,6 +384,8 @@ class ScheduleBuilder : public ExprVisitor { ICHECK_EQ(mod->functions.size(), 1); mod = tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod)); prim_func = Downcast(mod->Lookup("main")); + } else { + LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint; } } } diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index c741ecb59ae0..b14c18e55f4b 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -29,7 +29,6 @@ from tvm import meta_schedule as ms from tvm import relay from tvm.contrib import utils -from tvm.meta_schedule.testing.utils import apply_fixed_schedules from tvm.relay.backend import Executor, Runtime INPUT_SHAPE = (1, 3, 16, 16) @@ -407,21 +406,21 @@ def schedule_dense(sch): target = "llvm" params = {"weight": weight_np} - def schedule_fn(task, sch): - if "nn_dense" in task.task_name: + def schedule_fn(sch): + if "nn_dense" in sch.mod.attrs["task_name"]: schedule_dense(sch) return True return False link_params = True - with tvm.transform.PassContext(config={"relay.FuseOps.link_params": link_params}): - database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) - with StringIO() as stderr_buf, redirect_stderr(stderr_buf): - with database, tvm.transform.PassContext( + with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext( opt_level=3, - config={"relay.backend.use_meta_schedule": True}, + config={ + "relay.backend.use_meta_schedule": True, + "relay.FuseOps.link_params": link_params, + }, ): executor = Executor("graph", {"link-params": link_params}) lib = relay.build(relay_mod, target=target, executor=executor) diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py index 177001781179..cb6f59c6e5d5 100644 --- a/tests/python/unittest/test_meta_schedule_multi_anchor.py +++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py @@ -19,7 +19,6 @@ import tvm.testing from tvm import meta_schedule as ms from tvm import relay -from tvm.meta_schedule.testing.utils import apply_fixed_schedules def get_dense_dense(data_shape, weight_shape): @@ -63,14 +62,13 @@ def test_dense_dense(): target = "llvm" params = {"weight1": weight1_np, "weight2": weight2_np} - def schedule_fn(task, sch): - if "nn_dense_nn_dense" in task.task_name: + def schedule_fn(sch): + if "nn_dense_nn_dense" in sch.mod.attrs["task_name"]: schedule_dense_dense(sch) return True return False - database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) - with database: + with ms.database.ScheduleFnDatabase(schedule_fn): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True}, diff --git a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py index 939851a65731..b37333803603 100644 --- a/tests/python/unittest/test_meta_schedule_relay_tir_compute.py +++ b/tests/python/unittest/test_meta_schedule_relay_tir_compute.py @@ -18,8 +18,9 @@ import tvm import tvm.testing import tvm.topi.testing -from tvm import autotvm, relay, te -from tvm.meta_schedule.testing.utils import apply_fixed_schedules +from tvm import autotvm +from tvm import meta_schedule as ms +from tvm import relay, te from tvm.relay.testing.temp_op_attr import TempOpAttr from tvm.script import tir as T @@ -139,21 +140,14 @@ def test_conv2d(): target = "llvm" params = {"weight": weight_np} - def schedule_fn(task, sch): - if "nn_conv2d" in task.task_name: + def schedule_fn(sch): + if "nn_conv2d" in sch.mod.attrs["task_name"]: schedule_tir_conv2d_nchw_oihw(sch) return True return False with TempOpAttr("nn.conv2d", "FTVMStrategy", _tmp_strategy): - database = apply_fixed_schedules( - relay_mod, - target, - params, - schedule_fn, - tir_converter="allow_extern", - ) - with database, tvm.transform.PassContext( + with ms.database.ScheduleFnDatabase(schedule_fn), tvm.transform.PassContext( opt_level=3, config={ "relay.backend.use_meta_schedule": True, diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index bc37fed7d691..b05b57feaf4c 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -29,7 +29,6 @@ from tvm.contrib import graph_executor from tvm.ir import IRModule from tvm.meta_schedule.testing.relay_workload import get_network -from tvm.meta_schedule.testing.utils import apply_fixed_schedules from tvm.script import tir as T from tvm.target.target import Target from tvm.tir.schedule import BlockRV, Schedule @@ -452,8 +451,8 @@ def manual_tir_common(do_tune=False): ) else: - def schedule_fn(task, sch): - if "dense" not in task.task_name: + def schedule_fn(sch) -> bool: + if "dense" not in sch.mod.attrs["task_name"]: return False block = sch.get_block("compute") @@ -468,7 +467,7 @@ def schedule_fn(task, sch): return True - database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) + database = ms.database.ScheduleFnDatabase(schedule_fn) with database, tvm.transform.PassContext( opt_level=3,