[MetaSchedule] Developer Ergonomics Enhancement (#11622)
Per discussion with @Kathryn-cat

- [x] Move `initialize_with_tune_context` to a private API `_initialize_with_tune_context`, and
encourage using `TuneContext.initialize` instead
- [x] Instead of using a bunch of import statements, encourage using the `ms.xxx` prefix
(e.g. `ms.database.MemoryDatabase`) to keep things better organized
- [x] Move `DefaultLLVM`, `DefaultCUDA` to a separate file and make them more discoverable
- [x] Move `DummyDatabase` to `tvm.meta_schedule.database.MemoryDatabase`, given that it is genuinely useful beyond testing
- [x] Delegate class members' methods on `TuneContext`, for example, exposing
`TuneContext.generate_design_space`, which forwards to `TuneContext.space_generator.generate_design_space`
(a usage sketch follows this list)
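
A minimal sketch of the intended usage after these changes, assuming a toy TE workload; the exact delegated signature of `generate_design_space` is an assumption based on the checklist above rather than quoted from the diff:

```python
import tvm
import tvm.meta_schedule as ms  # one `ms.` prefix instead of many imports
from tvm import te


def _vector_add(n: int = 1024) -> tvm.IRModule:
    # An illustrative TIR workload; any PrimFunc would do.
    a = te.placeholder((n,), name="a")
    b = te.compute((n,), lambda i: a[i] + 1.0, name="b")
    return tvm.IRModule({"main": te.create_prim_func([a, b])})


def _schedule_fn(sch):
    # No transformations: the design space is just the original schedule.
    pass


ctx = ms.TuneContext(
    mod=_vector_add(),
    target=tvm.target.Target("llvm"),
    space_generator=ms.space_generator.ScheduleFn(_schedule_fn),
    task_name="vector_add",
)
ctx.initialize()                      # instead of per-member `_initialize_with_tune_context`
spaces = ctx.generate_design_space()  # delegated to ctx.space_generator (signature assumed)
database = ms.database.MemoryDatabase()  # relocated from the testing-only DummyDatabase
```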

Next PR:
- Allow using a string `"default"` in `TuneContext` as well as `tune_relay/tir/te` to quickly
specify a set of target-specific rules
- Add `TuneContext.tune` to allow tuning directly without a task scheduler.
- Enhance detection of `ScheduleFn` in `TuneContext` to make it easier for users to quickly try out
template-driven scheduling on TIR.

Co-Authored-By: Kathryn (Jinqi) Chen <[email protected]>
junrushao and Kathryn-cat authored Jun 10, 2022
1 parent ec24ae6 commit 6fca5c6
Showing 52 changed files with 1,111 additions and 886 deletions.
11 changes: 4 additions & 7 deletions include/tvm/meta_schedule/search_strategy.h
@@ -113,12 +113,10 @@ class SearchStrategyNode : public runtime::Object {

/*!
* \brief Update the search strategy with measurement results.
* \param context The tuning context.
* \param measure_candidates The candidates to be measured.
* \param results The measurement results from the runner.
*/
virtual void NotifyRunnerResults(const TuneContext& context,
const Array<MeasureCandidate>& measure_candidates,
virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results) = 0;

static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
@@ -150,8 +148,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
* \brief The function type of `NotifyRunnerResults` method.
* \param results The measurement results from the runner.
*/
using FNotifyRunnerResults = runtime::TypedPackedFunc<void(
const TuneContext&, const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
using FNotifyRunnerResults =
runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;

/*! \brief The packed function to the `InitializeWithTuneContext` method. */
FInitializeWithTuneContext f_initialize_with_tune_context;
@@ -177,8 +175,7 @@ class PySearchStrategyNode : public SearchStrategyNode {
const Optional<CostModel>& cost_model) final;
void PostTuning() final;
Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
void NotifyRunnerResults(const TuneContext& context,
const Array<MeasureCandidate>& measure_candidates,
void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
const Array<RunnerResult>& results);

static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
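With this change the tuning context is bound once at initialization, so runner-result notifications only carry the candidates and their results. A standalone sketch of the narrowed callback shape, using stub classes rather than the real TVM objects (the Python-side method name mirrors the C++ `NotifyRunnerResults` and is an assumption here):

```python
from typing import List, Optional


class MeasureCandidate:  # stand-in for meta_schedule::MeasureCandidate
    pass


class RunnerResult:  # stand-in for meta_schedule::RunnerResult
    def __init__(self, run_secs: Optional[List[float]] = None, error_msg: Optional[str] = None):
        self.run_secs = run_secs
        self.error_msg = error_msg


class EchoStrategy:
    """Counts failed measurements; the context would be captured at initialization, not here."""

    def __init__(self) -> None:
        self.num_failed = 0

    # Before: def notify_runner_results(self, context, measure_candidates, results)
    # After:  the context argument is gone.
    def notify_runner_results(
        self, measure_candidates: List[MeasureCandidate], results: List[RunnerResult]
    ) -> None:
        for result in results:
            if result.run_secs is None:
                self.num_failed += 1
```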
27 changes: 23 additions & 4 deletions include/tvm/meta_schedule/tune_context.h
@@ -42,6 +42,7 @@ namespace tvm {
namespace meta_schedule {

class TaskSchedulerNode;
class MeasureCallback;

/*! \brief The auto tuning context. */
class TuneContextNode : public runtime::Object {
@@ -70,7 +71,7 @@ class TuneContextNode : public runtime::Object {
int num_threads;

/*! \brief Whether the tuning task has been stopped or finished. */
bool is_terminated;
bool is_terminated; // TODO(@junrushao1994): move to TaskScheduler
/*! \brief The measure candidates. */
Optional<Array<MeasureCandidate>> measure_candidates;
/*! \brief The building results. */
@@ -87,18 +88,36 @@ class TuneContextNode : public runtime::Object {
v->Visit("postprocs", &postprocs);
v->Visit("mutator_probs", &mutator_probs);
v->Visit("task_name", &task_name);
// `logging_func` is not visited
v->Visit("rand_state", &rand_state);
v->Visit("num_threads", &num_threads);
v->Visit("is_terminated", &is_terminated);
v->Visit("measure_candidates", &measure_candidates);
v->Visit("builder_results", &builder_results);
v->Visit("runner_futures", &runner_futures);
v->Visit("measure_candidates", &measure_candidates);
// `logging_func` is not visited
}

/*! \brief Initialize members that need initialization with the tune context. */
void Initialize();

/*! \brief Set the measure candidates from the SearchStrategy */
void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates);
/*!
* \brief Send the measure candidates to builder.
* \param builder The builder to send the candidates to.
*/
void _SendToBuilder(const Builder& builder);
/*!
* \brief Send the built measure candidates to runner.
* \param runner The runner to send the candidates to.
*/
void _SendToRunner(const Runner& runner);
/*!
* \brief Join the running tasks.
* \returns The results from the runner
*/
Array<RunnerResult> _Join();
/*! \brief Set `measure_candidates`, `builder_results` and `runner_futures` to null. */
void _ClearMeasureState();
static constexpr const char* _type_key = "meta_schedule.TuneContext";
TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
};
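Together, these helpers make the context own a full measurement round trip (candidates → builder → runner → results). A rough sketch of the sequence they encode; the snake_case names on `ctx` are hypothetical mirrors of the C++ declarations above, not confirmed Python bindings:

```python
def measure_round_trip(ctx, candidates, builder, runner):
    """One measurement round trip as spelled out by the private TuneContext helpers:
    stash the candidates, build them, run them, collect the results, then reset."""
    ctx._set_measure_candidates(candidates)  # _SetMeasureCandidates
    ctx._send_to_builder(builder)            # _SendToBuilder
    ctx._send_to_runner(runner)              # _SendToRunner
    results = ctx._join()                    # _Join
    ctx._clear_measure_state()               # _ClearMeasureState
    return results
```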
5 changes: 4 additions & 1 deletion python/tvm/meta_schedule/__init__.py
@@ -20,7 +20,9 @@
builder,
cost_model,
database,
default_config,
feature_extractor,
measure_callback,
mutator,
postproc,
runner,
@@ -32,5 +34,6 @@
from .extracted_task import ExtractedTask
from .relay_integration import extract_task_from_relay
from .search_strategy import MeasureCandidate
from .tune import TuneConfig, tune_relay, tune_te, tune_tir
from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir
from .tune_context import TuneContext
from .utils import derived_object
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/database/__init__.py
@@ -20,3 +20,4 @@
"""
from .database import Database, PyDatabase, TuningRecord, Workload
from .json_database import JSONDatabase
from .memory_database import MemoryDatabase
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/database/database.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tuning record database"""
"""TuningRecord database"""
from typing import Any, Callable, List, Optional

from tvm._ffi import register_object
63 changes: 63 additions & 0 deletions python/tvm/meta_schedule/database/memory_database.py
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""A database that stores TuningRecords in memory"""
from typing import List

from ...ir import IRModule, structural_equal
from ..utils import derived_object
from .database import PyDatabase, TuningRecord, Workload


@derived_object
class MemoryDatabase(PyDatabase):
"""An in-memory database based on python list for testing."""

def __init__(self):
super().__init__()
self.records = []
self.workload_reg = []

def has_workload(self, mod: IRModule) -> bool:
for workload in self.workload_reg:
if structural_equal(workload.mod, mod):
return True
return False

def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)

def commit_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if structural_equal(workload.mod, mod):
return workload
workload = Workload(mod)
self.workload_reg.append(workload)
return workload

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
return list(
filter(
lambda x: x.workload == workload,
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
)
)[: int(top_k)]

def __len__(self) -> int:
return len(self.records)

def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))