Skip to content

Commit

Permalink
[Meta Schedule] Minor Fixes (apache#507)
Browse files Browse the repository at this point in the history
* Fix sttr func & schedule naming.

* Remoove empty list as default value in Taskcheduler & TuneContext.

* Fix schedule -> sch.

Co-authored-by: Junru Shao <[email protected]>
  • Loading branch information
zxybazh and junrushao authored Nov 9, 2021
1 parent ed4a8cc commit ad7adb3
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 43 deletions.
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class PostprocNode : public runtime::Object {
* \param sch The schedule to be post processed.
* \return Whether the post processing was successfully applied.
*/
virtual bool Apply(const tir::Schedule& schedule) = 0;
virtual bool Apply(const tir::Schedule& sch) = 0;

static constexpr const char* _type_key = "meta_schedule.Postproc";
TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class TaskSchedulerNode : public runtime::Object {
/*! \brief The database of the scheduler. */
Database database{nullptr};
/*! \brief The list of measure callbacks of the scheduler. */
Array<MeasureCallback> measure_callbacks;
Optional<Array<MeasureCallback>> measure_callbacks;

/*! \brief The default desctructor. */
virtual ~TaskSchedulerNode() = default;
Expand Down Expand Up @@ -250,13 +250,13 @@ class TaskScheduler : public runtime::ObjectRef {
Builder builder, //
Runner runner, //
Database database, //
Array<MeasureCallback> measure_callbacks);
Optional<Array<MeasureCallback>> measure_callbacks);
TVM_DLL static TaskScheduler PyTaskScheduler(
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
Array<MeasureCallback> measure_callbacks, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ class TuneContextNode : public runtime::Object {
/*! \brief The search strategy. */
Optional<SearchStrategy> search_strategy;
/*! \brief The schedule rules. */
Array<ScheduleRule> sch_rules;
Optional<Array<ScheduleRule>> sch_rules;
/*! \brief The post processings. */
Array<Postproc> postprocs;
Optional<Array<Postproc>> postprocs;
/*! \brief The mutators. */
Array<Mutator> mutators;
Optional<Array<Mutator>> mutators;
/*! \brief The name of the tuning task. */
Optional<String> task_name;
/*! \brief The random state. */
Expand Down Expand Up @@ -105,9 +105,9 @@ class TuneContext : public runtime::ObjectRef {
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Array<ScheduleRule> sch_rules, //
Array<Postproc> postprocs, //
Array<Mutator> mutators, //
Optional<Array<ScheduleRule>> sch_rules, //
Optional<Array<Postproc>> postprocs, //
Optional<Array<Mutator>> mutators, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads);
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/mutator/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ def f_as_string() -> str:
)

def __str__(self) -> str:
return f"PyMutator({_get_hex_address(self.handle)})"
return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/postproc/postproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ def f_as_string() -> str:
)

def __str__(self) -> str:
return f"PyPostproc({_get_hex_address(self.handle)})"
return f"{self.__class__.__name__}({_get_hex_address(self.handle)})"
22 changes: 18 additions & 4 deletions python/tvm/meta_schedule/task_scheduler/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Round Robin Task Scheduler"""

from typing import List, TYPE_CHECKING
from typing import List, Optional, TYPE_CHECKING

from tvm._ffi import register_object
from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback
Expand All @@ -34,15 +34,29 @@

@register_object("meta_schedule.RoundRobin")
class RoundRobin(TaskScheduler):
"""Round Robin Task Scheduler"""
"""Round Robin Task Scheduler
Parameters
----------
tasks: List[TuneContext]
The list of tune context to process.
builder: Builder
The builder of the scheduler.
runner: Runner
The runner of the scheduler.
database: Database
The database of the scheduler.
measure_callbacks: Optional[List[MeasureCallback]] = None
The list of measure callbacks of the scheduler.
"""

def __init__(
self,
tasks: List["TuneContext"],
builder: Builder,
runner: Runner,
database: Database,
measure_callbacks: List[MeasureCallback] = [],
measure_callbacks: Optional[List[MeasureCallback]] = None,
) -> None:
"""Constructor.
Expand All @@ -56,7 +70,7 @@ def __init__(
The runner.
database : Database
The database.
measure_callbacks: List[MeasureCallback]
measure_callbacks: Optional[List[MeasureCallback]]
The list of measure callbacks of the scheduler.
"""
self.__init_handle_by_constructor__(
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/task_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""Auto-tuning Task Scheduler"""

from typing import List
from typing import List, Optional

from tvm._ffi import register_object
from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback
Expand Down Expand Up @@ -44,7 +44,7 @@ class TaskScheduler(Object):
The runner of the scheduler.
database: Database
The database of the scheduler.
measure_callbacks: List[MeasureCallback]
measure_callbacks: List[MeasureCallback] = None
The list of measure callbacks of the scheduler.
"""

Expand Down Expand Up @@ -124,7 +124,7 @@ def __init__(
builder: Builder,
runner: Runner,
database: Database,
measure_callbacks: List[MeasureCallback] = [],
measure_callbacks: Optional[List[MeasureCallback]] = None,
):
"""Constructor.
Expand Down
18 changes: 9 additions & 9 deletions python/tvm/meta_schedule/tune_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ class TuneContext(Object):
The design space generator.
search_strategy : Optional[SearchStrategy] = None
The search strategy.
sch_rules : List[ScheduleRule] = []
sch_rules: Optional[List[ScheduleRule]] = None,
The schedule rules.
postproc : List[Postproc] = []
postproc: Optional[List[Postproc"]] = None,
The post processings.
mutator : List[Mutator] = []
mutator: Optional[List[Mutator]] = None,
The mutators.
task_name : Optional[str] = None
The name of the tuning task.
Expand All @@ -80,9 +80,9 @@ class TuneContext(Object):
target: Optional[Target]
space_generator: Optional["SpaceGenerator"]
search_strategy: Optional["SearchStrategy"]
sch_rules: List["ScheduleRule"]
postproc: List["Postproc"]
mutator: List["Mutator"]
sch_rules: Optional[List["ScheduleRule"]]
postproc: Optional[List["Postproc"]]
mutator: Optional[List["Mutator"]]
task_name: Optional[str]
rand_state: int
num_threads: int
Expand All @@ -93,9 +93,9 @@ def __init__(
target: Optional[Target] = None,
space_generator: Optional["SpaceGenerator"] = None,
search_strategy: Optional["SearchStrategy"] = None,
sch_rules: List["ScheduleRule"] = [],
postproc: List["Postproc"] = [],
mutator: List["Mutator"] = [],
sch_rules: Optional[List["ScheduleRule"]] = None,
postproc: Optional[List["Postproc"]] = None,
mutator: Optional[List["Mutator"]] = None,
task_name: Optional[str] = None,
rand_state: int = -1,
num_threads: Optional[int] = None,
Expand Down
4 changes: 3 additions & 1 deletion src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {

void InitializeWithTuneContext(const TuneContext& tune_context) final {
this->rand_state_ = ForkSeed(&tune_context->rand_state);
this->sch_rules_ = tune_context->sch_rules;
CHECK(tune_context->sch_rules.defined())
<< "ValueError: Schedules rules not given in PostOrderApply!";
this->sch_rules_ = tune_context->sch_rules.value();
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod_) final {
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/task_scheduler/round_robin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TaskScheduler TaskScheduler::RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
Array<MeasureCallback> measure_callbacks) {
Optional<Array<MeasureCallback>> measure_callbacks) {
ObjectPtr<RoundRobinNode> n = make_object<RoundRobinNode>();
n->tasks = tasks;
n->builder = builder;
Expand Down
23 changes: 13 additions & 10 deletions src/meta_schedule/task_scheduler/task_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,18 @@ void TaskSchedulerNode::InitializeTask(int task_id) {
space->InitializeWithTuneContext(task);
strategy->InitializeWithTuneContext(task);
// Initialize the rules.
for (const ScheduleRule& sch_rule : task->sch_rules) {
sch_rule->InitializeWithTuneContext(task);
}
for (const Mutator& mutator : task->mutators) {
mutator->InitializeWithTuneContext(task);
}
for (const Postproc& postproc : task->postprocs) {
postproc->InitializeWithTuneContext(task);
}
if (task->sch_rules.defined())
for (const ScheduleRule& sch_rule : task->sch_rules.value()) {
sch_rule->InitializeWithTuneContext(task);
}
if (task->mutators.defined())
for (const Mutator& mutator : task->mutators.value()) {
mutator->InitializeWithTuneContext(task);
}
if (task->postprocs.defined())
for (const Postproc& postproc : task->postprocs.value()) {
postproc->InitializeWithTuneContext(task);
}
}

void TaskSchedulerNode::Tune() {
Expand Down Expand Up @@ -211,7 +214,7 @@ TaskScheduler TaskScheduler::PyTaskScheduler(
Builder builder, //
Runner runner, //
Database database, //
Array<MeasureCallback> measure_callbacks, //
Optional<Array<MeasureCallback>> measure_callbacks, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
Expand Down
6 changes: 3 additions & 3 deletions src/meta_schedule/tune_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ TuneContext::TuneContext(Optional<IRModule> mod,
Optional<Target> target, //
Optional<SpaceGenerator> space_generator, //
Optional<SearchStrategy> search_strategy, //
Array<ScheduleRule> sch_rules, //
Array<Postproc> postprocs, //
Array<Mutator> mutators, //
Optional<Array<ScheduleRule>> sch_rules, //
Optional<Array<Postproc>> postprocs, //
Optional<Array<Mutator>> mutators, //
Optional<String> task_name, //
support::LinearCongruentialEngine::TRandState rand_state, //
int num_threads) {
Expand Down

0 comments on commit ad7adb3

Please sign in to comment.