Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Evo Independence from TaskScheduler #11590

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,16 @@ class SearchStrategyNode : public runtime::Object {

/*!
* \brief Pre-tuning for the search strategy.
* \param design_spaces The design spaces for pre-tuning.
* \param design_spaces The design spaces used during tuning process.
* \param database The database used during tuning process.
* \param cost_model The cost model used during tuning process.
* \note Pre-tuning is supposed to be called before the tuning process and after the
* initialization. Because the search strategy is stateful, we can always call pretuning
* and reset the search strategy.
*/
virtual void PreTuning(const Array<tir::Schedule>& design_spaces) = 0;
virtual void PreTuning(const Array<tir::Schedule>& design_spaces,
const Optional<Database>& database,
const Optional<CostModel>& cost_model) = 0;

/*!
* \brief Post-tuning for the search strategy.
Expand Down Expand Up @@ -159,7 +163,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
* \brief The function type of `PreTuning` method.
 * \param design_spaces The design spaces used during tuning process.
 * \param database The database used during tuning process.
 * \param cost_model The cost model used during tuning process.
*/
using FPreTuning = runtime::TypedPackedFunc<void(const Array<tir::Schedule>&)>;
using FPreTuning = runtime::TypedPackedFunc<void(
const Array<tir::Schedule>&, const Optional<Database>&, const Optional<CostModel>&)>;
/*! \brief The function type of `PostTuning` method. */
using FPostTuning = runtime::TypedPackedFunc<void()>;
/*!
Expand Down Expand Up @@ -199,10 +204,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
this->f_initialize_with_tune_context(context);
}

void PreTuning(const Array<tir::Schedule>& design_spaces) final {
ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
this->f_pre_tuning(design_spaces);
}
void PreTuning(const Array<tir::Schedule>& design_spaces, const Optional<Database>& database,
const Optional<CostModel>& cost_model) final;

void PostTuning() final {
ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!";
Expand Down
20 changes: 10 additions & 10 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ class TaskSchedulerNode : public runtime::Object {
/*! \brief The runner of the scheduler. */
Runner runner{nullptr};
/*! \brief The database of the scheduler. */
Database database{nullptr};
/*! \brief The maximum number of trials allowed. */
int max_trials;
Optional<Database> database;
/*! \brief The cost model of the scheduler. */
Optional<CostModel> cost_model;
/*! \brief The list of measure callbacks of the scheduler. */
Array<MeasureCallback> measure_callbacks;
/*! \brief The maximum number of trials allowed. */
int max_trials;
/*! \brief The number of trials already conducted. */
int num_trials_already;
/*! \brief The tuning task's logging function. */
Expand All @@ -94,9 +94,9 @@ class TaskSchedulerNode : public runtime::Object {
v->Visit("builder", &builder);
v->Visit("runner", &runner);
v->Visit("database", &database);
v->Visit("max_trials", &max_trials);
v->Visit("cost_model", &cost_model);
v->Visit("measure_callbacks", &measure_callbacks);
v->Visit("max_trials", &max_trials);
v->Visit("num_trials_already", &num_trials_already);
// `logging_func` is not visited
}
Expand Down Expand Up @@ -243,10 +243,10 @@ class TaskScheduler : public runtime::ObjectRef {
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
int max_trials, //
PackedFunc logging_func);
/*!
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
Expand All @@ -268,10 +268,10 @@ class TaskScheduler : public runtime::ObjectRef {
Array<FloatImm> task_weights, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
int max_trials, //
PackedFunc logging_func, //
double alpha, //
int window_size, //
Expand All @@ -297,10 +297,10 @@ class TaskScheduler : public runtime::ObjectRef {
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
int max_trials, //
Optional<Database> database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks, //
int max_trials, //
PackedFunc logging_func, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
Expand Down
2 changes: 0 additions & 2 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ class TuneContextNode : public runtime::Object {
/*! \brief The number of threads to be used. */
int num_threads;

/*! \brief The task scheduler that owns the tune context */
const TaskSchedulerNode* task_scheduler;
/*! \brief Whether the tuning task has been stopped or finished. */
bool is_terminated;
/*! \brief The measure candidates. */
Expand Down
24 changes: 20 additions & 4 deletions python/tvm/meta_schedule/search_strategy/search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Meta Schedule search strategy that generates the measure
candidates for measurement.
"""
from typing import Callable, List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, List, Optional

from tvm._ffi import register_object
from tvm.runtime import Object
Expand All @@ -29,6 +29,8 @@
from ..runner import RunnerResult

if TYPE_CHECKING:
from ..cost_model import CostModel
from ..database import Database
from ..tune_context import TuneContext


Expand Down Expand Up @@ -87,15 +89,29 @@ def initialize_with_tune_context(self, context: "TuneContext") -> None:
self, context
)

def pre_tuning(self, design_spaces: List[Schedule]) -> None:
def pre_tuning(
self,
design_spaces: List[Schedule],
database: Optional["Database"] = None,
cost_model: Optional["CostModel"] = None,
) -> None:
"""Pre-tuning for the search strategy.

Parameters
----------
design_spaces : List[Schedule]
The design spaces for pre-tuning.
The design spaces used during tuning process.
database : Optional[Database] = None
The database used during tuning process.
cost_model : Optional[CostModel] = None
The cost model used during tuning process.
"""
_ffi_api.SearchStrategyPreTuning(self, design_spaces) # type: ignore # pylint: disable=no-member
_ffi_api.SearchStrategyPreTuning( # type: ignore # pylint: disable=no-member
self,
design_spaces,
database,
cost_model,
)

def post_tuning(self) -> None:
"""Post-tuning for the search strategy."""
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/meta_schedule/task_scheduler/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def __init__(
task_weights: List[float],
builder: Builder,
runner: Runner,
database: Database,
max_trials: int,
*,
database: Database,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
max_trials: int,
alpha: float = 0.2,
window_size: int = 3,
seed: int = -1,
Expand All @@ -68,12 +68,12 @@ def __init__(
The runner.
database : Database
The database.
max_trials : int
The maximum number of trials to run.
cost_model : CostModel, default None.
The cost model of the scheduler.
measure_callbacks : Optional[List[MeasureCallback]] = None
The list of measure callbacks of the scheduler.
max_trials : int
The maximum number of trials to run.
alpha : float = 0.2
The parameter alpha in gradient computation.
window_size : int = 3
Expand All @@ -88,9 +88,9 @@ def __init__(
builder,
runner,
database,
max_trials,
cost_model,
measure_callbacks,
max_trials,
make_logging_func(logger),
alpha,
window_size,
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/meta_schedule/task_scheduler/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ def __init__(
task_weights: List[float],
builder: Builder,
runner: Runner,
database: Database,
max_trials: int,
*,
database: Database,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
max_trials: int,
) -> None:
"""Constructor.

Expand All @@ -80,12 +80,12 @@ def __init__(
The runner.
database : Database
The database.
max_trials : int
The maximum number of trials.
cost_model : Optional[CostModel]
The cost model.
measure_callbacks: Optional[List[MeasureCallback]]
The list of measure callbacks of the scheduler.
max_trials : int
The maximum number of trials.
"""
del task_weights
self.__init_handle_by_constructor__(
Expand All @@ -94,8 +94,8 @@ def __init__(
builder,
runner,
database,
max_trials,
cost_model,
measure_callbacks,
max_trials,
make_logging_func(logger),
)
10 changes: 5 additions & 5 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from ..tune_context import TuneContext
from ..utils import make_logging_func


logger = logging.getLogger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -177,9 +176,9 @@ class PyTaskScheduler:
"builder",
"runner",
"database",
"max_trials",
"cost_model",
"measure_callbacks",
"max_trials",
],
"methods": [
"tune",
Expand All @@ -195,18 +194,19 @@ def __init__(
tasks: List[TuneContext],
builder: Builder,
runner: Runner,
database: Database,
max_trials: int,
*,
database: Optional[Database] = None,
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
max_trials: int,
):
self.tasks = tasks
self.builder = builder
self.runner = runner
self.database = database
self.max_trials = max_trials
self.cost_model = cost_model
self.measure_callbacks = measure_callbacks
self.max_trials = max_trials

def tune(self) -> None:
"""Auto-tuning."""
Expand Down
5 changes: 4 additions & 1 deletion src/meta_schedule/measure_callback/add_to_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ class AddToDatabaseNode : public MeasureCallbackNode {
const Array<MeasureCandidate>& measure_candidates,
const Array<BuilderResult>& builder_results,
const Array<RunnerResult>& runner_results) final {
if (!task_scheduler->database.defined()) {
return;
}
TuneContext task = task_scheduler->tasks[task_id];
Database database = task_scheduler->database;
Database database = task_scheduler->database.value();
Workload workload = database->CommitWorkload(task->mod.value());
Target target = task->target.value();
ICHECK_EQ(runner_results.size(), measure_candidates.size());
Expand Down
Loading