Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Introduce ScheduleFnDatabase #12626

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,23 +207,29 @@ class DatabaseNode : public runtime::Object {
* \brief Query the best record of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The best record of the given workload; NullOpt if not found.
*/
virtual Optional<TuningRecord> QueryTuningRecord(IRModule mod, Target target);
virtual Optional<TuningRecord> QueryTuningRecord(const IRModule& mod, const Target& target,
const String& workload_name);
/*!
* \brief Query the best schedule of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The schedule in the best schedule of the given workload; NullOpt if not found.
*/
virtual Optional<tir::Schedule> QuerySchedule(IRModule mod, Target target);
virtual Optional<tir::Schedule> QuerySchedule(const IRModule& mod, const Target& target,
const String& workload_name);
/*!
* \brief Query the best IRModule of the given workload from the database.
* \param mod The IRModule to be searched for.
* \param target The target to be searched for.
* \param workload_name The name of the workload to be searched for.
* \return The IRModule in the best IRModule of the given workload; NullOpt if not found.
*/
virtual Optional<IRModule> QueryIRModule(IRModule mod, Target target);
virtual Optional<IRModule> QueryIRModule(const IRModule& mod, const Target& target,
const String& workload_name);

static constexpr const char* _type_key = "meta_schedule.Database";
TVM_DECLARE_BASE_OBJECT_INFO(DatabaseNode, runtime::Object);
Expand Down Expand Up @@ -336,6 +342,13 @@ class Database : public runtime::ObjectRef {
public:
/*! An in-memory database. */
TVM_DLL static Database MemoryDatabase();
/*!
* \brief A database for injecting handcrafted schedule functions.
* \param schedule_fn The function to do scheduling, which takes a TIR schedule,
* and returns a boolean indicating if the schedule is successful.
*/
TVM_DLL static Database ScheduleFnDatabase(
runtime::TypedPackedFunc<bool(tir::Schedule)> schedule_fn);
/*!
* \brief Create a default database that uses JSON file for tuning records.
* \param path_workload The path to the workload table.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
from .memory_database import MemoryDatabase
from .schedule_fn_database import ScheduleFnDatabase
41 changes: 32 additions & 9 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,12 @@ def __len__(self) -> int:
"""
return _ffi_api.DatabaseSize(self) # type: ignore # pylint: disable=no-member

def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningRecord]:
def query_tuning_record(
self,
mod: IRModule,
target: Target,
workload_name: str,
) -> Optional[TuningRecord]:
"""Query the best record of the given workload from the database.

Parameters
Expand All @@ -244,15 +249,22 @@ def query_tuning_record(self, mod: IRModule, target: Target) -> Optional[TuningR
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : str
The name of the workload to be searched for.

Returns
-------
tuning_record : Optional[TuningRecord]
The best record of the given workload; None if not found.
"""
return _ffi_api.DatabaseQueryTuningRecord(self, mod, target) # type: ignore # pylint: disable=no-member
return _ffi_api.DatabaseQueryTuningRecord(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member

def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]:
def query_schedule(
self,
mod: IRModule,
target: Target,
workload_name: str,
) -> Optional[Schedule]:
"""Query the best schedule of the given workload from the database.

Parameters
Expand All @@ -261,15 +273,22 @@ def query_schedule(self, mod: IRModule, target: Target) -> Optional[Schedule]:
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : str
The name of the workload to be searched for.

Returns
-------
schedule : Optional[Schedule]
The best schedule of the given workload; None if not found.
"""
return _ffi_api.DatabaseQuerySchedule(self, mod, target) # type: ignore # pylint: disable=no-member
return _ffi_api.DatabaseQuerySchedule(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member

def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]:
def query_ir_module(
self,
mod: IRModule,
target: Target,
workload_name: str,
) -> Optional[IRModule]:
"""Query the best IRModule of the given workload from the database.

Parameters
Expand All @@ -278,18 +297,22 @@ def query_ir_module(self, mod: IRModule, target: Target) -> Optional[IRModule]:
The IRModule to be searched for.
target : Target
The target to be searched for.
workload_name : str
The name of the workload to be searched for.

Returns
-------
ir_module : Optional[IRModule]
The best IRModule of the given workload; None if not found.
"""
return _ffi_api.DatabaseQueryIRModule(self, mod, target) # type: ignore # pylint: disable=no-member
return _ffi_api.DatabaseQueryIRModule(self, mod, target, workload_name) # type: ignore # pylint: disable=no-member

def query(
self,
mod: IRModule,
target: Target,
*,
workload_name: str = "main",
kind: Union[
Literal["schedule"],
Literal["record"],
Expand All @@ -313,11 +336,11 @@ def query(
The best optimization outcome of the given workload.
"""
if kind == "schedule":
return self.query_schedule(mod, target)
return self.query_schedule(mod, target, workload_name)
if kind == "record":
return self.query_tuning_record(mod, target)
return self.query_tuning_record(mod, target, workload_name)
if kind == "ir_module":
return self.query_ir_module(mod, target)
return self.query_ir_module(mod, target, workload_name)
raise ValueError(f'Unknown kind: {kind}. Candidates are: "schedule", "record", "ir_module"')

def __enter__(self) -> "Database":
Expand Down
38 changes: 38 additions & 0 deletions python/tvm/meta_schedule/database/schedule_fn_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A database for injecting handcrafted schedule functions."""
from typing import Callable

from tvm._ffi import register_object
from tvm.tir import Schedule

from .. import _ffi_api
from .database import Database


@register_object("meta_schedule.ScheduleFnDatabase")
class ScheduleFnDatabase(Database):
"""A database for injecting handcrafted schedule functions."""

def __init__(
self,
schedule_fn: Callable[[Schedule], bool],
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.DatabaseScheduleFnDatabase, # type: ignore # pylint: disable=no-member
schedule_fn,
)
83 changes: 0 additions & 83 deletions python/tvm/meta_schedule/testing/utils.py

This file was deleted.

13 changes: 8 additions & 5 deletions src/meta_schedule/database/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& w

/******** Database ********/

Optional<TuningRecord> DatabaseNode::QueryTuningRecord(IRModule mod, Target target) {
Optional<TuningRecord> DatabaseNode::QueryTuningRecord(const IRModule& mod, const Target& target,
const String& workload_name) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we pass workload_name? The these three functions do not seem to use this: QueryTuningRecord, QuerySchedule, QueryIRModule.

if (!this->HasWorkload(mod)) {
return NullOpt;
}
Expand All @@ -168,8 +169,9 @@ Optional<TuningRecord> DatabaseNode::QueryTuningRecord(IRModule mod, Target targ
return records[0];
}

Optional<tir::Schedule> DatabaseNode::QuerySchedule(IRModule mod, Target target) {
if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target)) {
Optional<tir::Schedule> DatabaseNode::QuerySchedule(const IRModule& mod, const Target& target,
const String& workload_name) {
if (Optional<TuningRecord> opt_record = this->QueryTuningRecord(mod, target, workload_name)) {
TuningRecord record = opt_record.value();
tir::Schedule sch =
tir::Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
Expand All @@ -181,8 +183,9 @@ Optional<tir::Schedule> DatabaseNode::QuerySchedule(IRModule mod, Target target)
}
}

Optional<IRModule> DatabaseNode::QueryIRModule(IRModule mod, Target target) {
if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target)) {
Optional<IRModule> DatabaseNode::QueryIRModule(const IRModule& mod, const Target& target,
const String& workload_name) {
if (Optional<tir::Schedule> opt_sch = this->QuerySchedule(mod, target, workload_name)) {
return opt_sch.value()->mod();
} else {
return NullOpt;
Expand Down
10 changes: 5 additions & 5 deletions src/meta_schedule/database/memory_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class MemoryDatabaseNode : public DatabaseNode {
return false;
}

Workload CommitWorkload(const IRModule& mod) {
Workload CommitWorkload(const IRModule& mod) final {
for (const auto& workload : workloads) {
if (StructuralEqual()(workload->mod, mod)) {
return workload;
Expand All @@ -55,9 +55,9 @@ class MemoryDatabaseNode : public DatabaseNode {
return workload;
}

void CommitTuningRecord(const TuningRecord& record) { records.push_back(record); }
void CommitTuningRecord(const TuningRecord& record) final { records.push_back(record); }

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) {
Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
std::vector<std::pair<double, TuningRecord>> results;
results.reserve(this->records.size());
for (const TuningRecord& record : records) {
Expand Down Expand Up @@ -91,9 +91,9 @@ class MemoryDatabaseNode : public DatabaseNode {
return ret;
}

Array<TuningRecord> GetAllTuningRecords() { return records; }
Array<TuningRecord> GetAllTuningRecords() final { return records; }

int64_t Size() { return records.size(); }
int64_t Size() final { return records.size(); }
};

Database Database::MemoryDatabase() {
Expand Down
Loading