diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 88db2e2277867..1812a524763b0 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -357,6 +357,14 @@ class Database : public runtime::ObjectRef { */ TVM_DLL static Database JSONDatabase(String path_workload, String path_tuning_record, bool allow_missing); + /*! + * \brief Create a database merged from multiple databases. + * \param databases The databases to be merged. + * \param preferred The preferred database. If the preferred database responses to a query, + * all other databases will be ignored. + * \return The merged database. + */ + TVM_DLL static Database MergedDatabase(Array merge, Optional preferred); /*! * \brief Create a database with customized methods on the python-side. * \param f_has_workload The packed function of `HasWorkload`. diff --git a/python/tvm/meta_schedule/database/merged_database.py b/python/tvm/meta_schedule/database/merged_database.py new file mode 100644 index 0000000000000..8d7818567cb5f --- /dev/null +++ b/python/tvm/meta_schedule/database/merged_database.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A database consists of multiple databases.""" +from typing import List, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .database import Database + + +@register_object("meta_schedule.MergedDatabase") +class MergedDatabase(Database): + """A merged database from multiple databases""" + + def __init__( + self, + databases: List[Database], + preferred: Optional[Database] = None, + ) -> None: + """Construct a merged database from multiple databases. + + Parameters + ---------- + databases : List[Database] + The list of databases to merge. + preferred : Optional[Database] = None + preferred The preferred database. If the preferred database responses to a query, + all other databases will be ignored. + """ + self.__init_handle_by_constructor__( + _ffi_api.DatabaseMergedDatabase, # type: ignore # pylint: disable=no-member + databases, + preferred, + ) diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 2e4f852608353..91b96c82479f9 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -25,28 +25,6 @@ namespace tvm { namespace meta_schedule { -/*! \brief The struct defining comparison function of sorting by mean run seconds. */ -struct SortTuningRecordByMeanRunSecs { - static const constexpr double kMaxMeanTime = 1e10; - - static double Mean(const Array& a) { - if (a.empty()) { - return kMaxMeanTime; - } - double sum = 0.0; - for (const FloatImm& i : a) { - sum += i->value; - } - return sum / a.size(); - } - - bool operator()(const TuningRecord& a, const TuningRecord& b) const { - double a_time = Mean(a->run_secs.value_or({})); - double b_time = Mean(b->run_secs.value_or({})); - return a_time < b_time; - } -}; - /*! * \brief Read lines from a json file. * \param path The path to the json file. diff --git a/src/meta_schedule/database/merged_database.cc b/src/meta_schedule/database/merged_database.cc new file mode 100644 index 0000000000000..f3513a409e96f --- /dev/null +++ b/src/meta_schedule/database/merged_database.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class MergedDatabaseNode : public DatabaseNode { + public: + Array databases; + Optional preferred; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("databases", &databases); + v->Visit("preferred", &preferred); + } + + static constexpr const char* _type_key = "meta_schedule.MergedDatabase"; + TVM_DECLARE_FINAL_OBJECT_INFO(MergedDatabaseNode, DatabaseNode); + + public: + Optional QueryTuningRecord(IRModule mod, Target target) final { + if (preferred) { + if (Optional record = preferred.value()->QueryTuningRecord(mod, target)) { + return record; + } + } + std::vector results; + results.reserve(databases.size()); + for (const Database& db : databases) { + if (Optional record = db->QueryTuningRecord(mod, target)) { + results.push_back(record.value()); + } + } + if (results.empty()) { + return NullOpt; + } + std::sort(results.begin(), results.end(), SortTuningRecordByMeanRunSecs()); + return results[0]; + } + + bool HasWorkload(const IRModule& mod) final { + LOG(FATAL) << "NotImplementedError: MergedDatabase.HasWorkload"; + throw; + } + + Workload CommitWorkload(const IRModule& mod) final { + LOG(FATAL) << "NotImplementedError: MergedDatabase.CommitWorkload"; + throw; + } + + void CommitTuningRecord(const TuningRecord& record) final { + LOG(FATAL) << "NotImplementedError: MergedDatabase.CommitTuningRecord"; + throw; + } + + Array GetTopK(const Workload& workload, int top_k) final { + LOG(FATAL) << "NotImplementedError: MergedDatabase.GetTopK"; + throw; + } + + Array GetAllTuningRecords() final { + LOG(FATAL) << "NotImplementedError: MergedDatabase.GetAllTuningRecords"; + throw; + } + + int64_t Size() final { + LOG(FATAL) << "NotImplementedError: MergedDatabase.size"; + throw; + } +}; + +Database Database::MergedDatabase(Array databases, Optional preferred) { + ObjectPtr n = make_object(); + n->databases = std::move(databases); + n->preferred = std::move(preferred); + return Database(n); +} + +TVM_REGISTER_NODE_TYPE(MergedDatabaseNode); +TVM_REGISTER_GLOBAL("meta_schedule.DatabaseMergedDatabase") + .set_body_typed(Database::MergedDatabase); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index db37935ec2063..ad56fa7f6a526 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -404,6 +404,28 @@ inline Array AsIntArray(const ObjectRef& obj) { return results; } +/*! \brief The struct defining comparison function of sorting by mean run seconds. */ +struct SortTuningRecordByMeanRunSecs { + static const constexpr double kMaxMeanTime = 1e10; + + static double Mean(const Array& a) { + if (a.empty()) { + return kMaxMeanTime; + } + double sum = 0.0; + for (const FloatImm& i : a) { + sum += i->value; + } + return sum / a.size(); + } + + bool operator()(const TuningRecord& a, const TuningRecord& b) const { + double a_time = Mean(a->run_secs.value_or({})); + double b_time = Mean(b->run_secs.value_or({})); + return a_time < b_time; + } +}; + } // namespace meta_schedule } // namespace tvm