Skip to content

Commit

Permalink
[MetaSchedule] Introduce MergedDatabase
Browse files Browse the repository at this point in the history
Following up apache#12520 and apache#12626, this PR introduces `MergedDatabase`,
which allow users to compose multiple databases so that the high-level
IR could select the best tuning records among them.

The `MergedDatabase` also comes with an extra field `preferred` to allow
users to override tuning records from other databases. A classic usecase
of the `preferred` parameter is through handcrafted schedule functions:

```python
def schedule_fn(sch: tir.Schedule) -> bool:
  if "nn_conv2d" in sch.mod.attrs["task_name"]:
    handcrafted_scheduling(sch)
    return True
  return False

with ms.database.MergedDatabase(
  databases=[database],
  preferred=ms.database.ScheduleFn(schedule_fn),
):
  lib = relay.build(...)
```
  • Loading branch information
junrushao committed Aug 29, 2022
1 parent c5c99a4 commit a95b1eb
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 22 deletions.
8 changes: 8 additions & 0 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,14 @@ class Database : public runtime::ObjectRef {
*/
TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record,
bool allow_missing);
/*!
* \brief Create a database merged from multiple databases.
* \param databases The databases to be merged.
* \param preferred The preferred database. If the preferred database responses to a query,
* all other databases will be ignored.
* \return The merged database.
*/
TVM_DLL static Database MergedDatabase(Array<Database, void> merge, Optional<Database> preferred);
/*!
* \brief Create a database with customized methods on the python-side.
* \param f_has_workload The packed function of `HasWorkload`.
Expand Down
49 changes: 49 additions & 0 deletions python/tvm/meta_schedule/database/merged_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A database consists of multiple databases."""
from typing import List, Optional

from tvm._ffi import register_object

from .. import _ffi_api
from .database import Database


@register_object("meta_schedule.MergedDatabase")
class MergedDatabase(Database):
"""A merged database from multiple databases"""

def __init__(
self,
databases: List[Database],
preferred: Optional[Database] = None,
) -> None:
"""Construct a merged database from multiple databases.
Parameters
----------
databases : List[Database]
The list of databases to merge.
preferred : Optional[Database] = None
preferred The preferred database. If the preferred database responses to a query,
all other databases will be ignored.
"""
self.__init_handle_by_constructor__(
_ffi_api.DatabaseMergedDatabase, # type: ignore # pylint: disable=no-member
databases,
preferred,
)
22 changes: 0 additions & 22 deletions src/meta_schedule/database/json_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,6 @@
namespace tvm {
namespace meta_schedule {

/*! \brief The struct defining comparison function of sorting by mean run seconds. */
struct SortTuningRecordByMeanRunSecs {
static const constexpr double kMaxMeanTime = 1e10;

static double Mean(const Array<FloatImm>& a) {
if (a.empty()) {
return kMaxMeanTime;
}
double sum = 0.0;
for (const FloatImm& i : a) {
sum += i->value;
}
return sum / a.size();
}

bool operator()(const TuningRecord& a, const TuningRecord& b) const {
double a_time = Mean(a->run_secs.value_or({}));
double b_time = Mean(b->run_secs.value_or({}));
return a_time < b_time;
}
};

/*!
* \brief Read lines from a json file.
* \param path The path to the json file.
Expand Down
101 changes: 101 additions & 0 deletions src/meta_schedule/database/merged_database.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class MergedDatabaseNode : public DatabaseNode {
public:
Array<Database> databases;
Optional<Database> preferred;

void VisitAttrs(AttrVisitor* v) {
v->Visit("databases", &databases);
v->Visit("preferred", &preferred);
}

static constexpr const char* _type_key = "meta_schedule.MergedDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(MergedDatabaseNode, DatabaseNode);

public:
Optional<TuningRecord> QueryTuningRecord(IRModule mod, Target target) final {
if (preferred) {
if (Optional<TuningRecord> record = preferred.value()->QueryTuningRecord(mod, target)) {
return record;
}
}
std::vector<TuningRecord> results;
results.reserve(databases.size());
for (const Database& db : databases) {
if (Optional<TuningRecord> record = db->QueryTuningRecord(mod, target)) {
results.push_back(record.value());
}
}
if (results.empty()) {
return NullOpt;
}
std::sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs());
return results[0];
}

bool HasWorkload(const IRModule& mod) final {
LOG(FATAL) << "NotImplementedError: MergedDatabase.HasWorkload";
throw;
}

Workload CommitWorkload(const IRModule& mod) final {
LOG(FATAL) << "NotImplementedError: MergedDatabase.CommitWorkload";
throw;
}

void CommitTuningRecord(const TuningRecord& record) final {
LOG(FATAL) << "NotImplementedError: MergedDatabase.CommitTuningRecord";
throw;
}

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
LOG(FATAL) << "NotImplementedError: MergedDatabase.GetTopK";
throw;
}

Array<TuningRecord> GetAllTuningRecords() final {
LOG(FATAL) << "NotImplementedError: MergedDatabase.GetAllTuningRecords";
throw;
}

int64_t Size() final {
LOG(FATAL) << "NotImplementedError: MergedDatabase.size";
throw;
}
};

Database Database::MergedDatabase(Array<Database> databases, Optional<Database> preferred) {
ObjectPtr<MergedDatabaseNode> n = make_object<MergedDatabaseNode>();
n->databases = std::move(databases);
n->preferred = std::move(preferred);
return Database(n);
}

TVM_REGISTER_NODE_TYPE(MergedDatabaseNode);
TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMergedDatabase")
.set_body_typed(Database::MergedDatabase);

} // namespace meta_schedule
} // namespace tvm
22 changes: 22 additions & 0 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,28 @@ inline Array<Integer> AsIntArray(const ObjectRef& obj) {
return results;
}

/*! \brief The struct defining comparison function of sorting by mean run seconds. */
struct SortTuningRecordByMeanRunSecs {
static const constexpr double kMaxMeanTime = 1e10;

static double Mean(const Array<FloatImm>& a) {
if (a.empty()) {
return kMaxMeanTime;
}
double sum = 0.0;
for (const FloatImm& i : a) {
sum += i->value;
}
return sum / a.size();
}

bool operator()(const TuningRecord& a, const TuningRecord& b) const {
double a_time = Mean(a->run_secs.value_or({}));
double b_time = Mean(b->run_secs.value_or({}));
return a_time < b_time;
}
};

} // namespace meta_schedule
} // namespace tvm

Expand Down

0 comments on commit a95b1eb

Please sign in to comment.