Skip to content

Commit

Permalink
[Bug][Meta Schedule] Fix Infinite Loop Caused When Calling Methods No…
Browse files Browse the repository at this point in the history
…t Overrided In PyClass. (apache#9451)

* Fix Infinite Loop Caused When Calling Methods Not Overrided In PyClass.

* Add new line.

* Lint.
  • Loading branch information
zxybazh authored Nov 5, 2021
1 parent 5527cbf commit 048994b
Show file tree
Hide file tree
Showing 17 changed files with 366 additions and 112 deletions.
1 change: 1 addition & 0 deletions include/tvm/meta_schedule/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class PyBuilderNode : public BuilderNode {
}

Array<BuilderResult> Build(const Array<BuilderInput>& build_inputs) final {
ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!";
return f_build(build_inputs);
}

Expand Down
23 changes: 17 additions & 6 deletions include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,29 @@ class PyDatabaseNode : public DatabaseNode {
// `f_size` is not visited
}

static constexpr const char* _type_key = "meta_schedule.PyDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);

Workload CommitWorkload(const IRModule& mod) final { return f_commit_workload(mod); }
Workload CommitWorkload(const IRModule& mod) final {
ICHECK(f_commit_workload != nullptr) << "PyDatabase's CommitWorkload method not implemented!";
return f_commit_workload(mod);
}

void CommitTuningRecord(const TuningRecord& record) final { f_commit_tuning_record(record); }
void CommitTuningRecord(const TuningRecord& record) final {
ICHECK(f_commit_tuning_record != nullptr)
<< "PyDatabase's CommitTuningRecord method not implemented!";
f_commit_tuning_record(record);
}

Array<TuningRecord> GetTopK(const Workload& workload, int top_k) final {
ICHECK(f_get_top_k != nullptr) << "PyDatabase's GetTopK method not implemented!";
return f_get_top_k(workload, top_k);
}

int64_t Size() final { return f_size(); }
int64_t Size() final {
ICHECK(f_size != nullptr) << "PyDatabase's Size method not implemented!";
return f_size();
}

static constexpr const char* _type_key = "meta_schedule.PyDatabase";
TVM_DECLARE_FINAL_OBJECT_INFO(PyDatabaseNode, DatabaseNode);
};

/*!
Expand Down
5 changes: 4 additions & 1 deletion include/tvm/meta_schedule/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ class PyRunnerNode : public RunnerNode {
// `f_run` is not visited
}

Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final { return f_run(runner_inputs); }
Array<RunnerFuture> Run(Array<RunnerInput> runner_inputs) final {
ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!";
return f_run(runner_inputs);
}

static constexpr const char* _type_key = "meta_schedule.PyRunner";
TVM_DECLARE_FINAL_OBJECT_INFO(PyRunnerNode, RunnerNode);
Expand Down
12 changes: 11 additions & 1 deletion include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,30 @@ class PySearchStrategyNode : public SearchStrategyNode {
}

void InitializeWithTuneContext(const TuneContext& context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySearchStrategy's InitializeWithTuneContext method not implemented!";
this->f_initialize_with_tune_context(context);
}

void PreTuning(const Array<tir::Schedule>& design_spaces) final {
ICHECK(f_pre_tuning != nullptr) << "PySearchStrategy's PreTuning method not implemented!";
this->f_pre_tuning(design_spaces);
}

void PostTuning() final { this->f_post_tuning(); }
void PostTuning() final {
ICHECK(f_post_tuning != nullptr) << "PySearchStrategy's PostTuning method not implemented!";
this->f_post_tuning();
}

Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final {
ICHECK(f_generate_measure_candidates != nullptr)
<< "PySearchStrategy's GenerateMeasureCandidates method not implemented!";
return this->f_generate_measure_candidates();
}

void NotifyRunnerResults(const Array<RunnerResult>& results) final {
ICHECK(f_notify_runner_results != nullptr)
<< "PySearchStrategy's NotifyRunnerResults method not implemented!";
this->f_notify_runner_results(results);
}

Expand Down
4 changes: 4 additions & 0 deletions include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,14 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode {
}

void InitializeWithTuneContext(const TuneContext& tune_context) final {
ICHECK(f_initialize_with_tune_context != nullptr)
<< "PySpaceGenerator's InitializeWithTuneContext !";
f_initialize_with_tune_context(tune_context);
}

Array<tir::Schedule> GenerateDesignSpace(const IRModule& mod) final {
ICHECK(f_generate_design_space != nullptr)
<< "PySpaceGenerator's GenerateDesignSpace method not implemented!";
return f_generate_design_space(mod);
}

Expand Down
68 changes: 57 additions & 11 deletions include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ class TaskSchedulerNode : public runtime::Object {
/*! \brief Auto-tuning. */
virtual void Tune();

/*!
* \brief Initialize modules of the given task.
* \param task_id The task id to be initialized.
*/
virtual void InitializeTask(int task_id);

/*!
* \brief Set specific task to be stopped.
* \param task_id The task id to be stopped.
Expand Down Expand Up @@ -116,12 +122,17 @@ class TaskSchedulerNode : public runtime::Object {
TVM_DECLARE_BASE_OBJECT_INFO(TaskSchedulerNode, Object);
};

class TaskScheduler;

/*! \brief The task scheduler with customized methods on the python-side. */
class PyTaskSchedulerNode : public TaskSchedulerNode {
public:
/*! \brief The function type of `Tune` method. */
using FTune = runtime::TypedPackedFunc<void()>;

/*! \brief The function type of `InitializeTask` method. */
using FInitializeTask = runtime::TypedPackedFunc<void(int)>;

/*!
* \brief The function type of `SetTaskStopped` method.
* \param task_id The task id to be stopped.
Expand Down Expand Up @@ -149,6 +160,8 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {

/*! \brief The packed function to the `Tune` funcion. */
FTune f_tune;
/*! \brief The packed function to the `InitializeTask` funcion. */
FInitializeTask f_initialize_task;
/*! \brief The packed function to the `SetTaskStopped` function. */
FSetTaskStopped f_set_task_stopped;
/*! \brief The packed function to the `IsTaskRunning` function. */
Expand All @@ -160,29 +173,55 @@ class PyTaskSchedulerNode : public TaskSchedulerNode {

void VisitAttrs(tvm::AttrVisitor* v) {
// `f_tune` is not visited
// `f_initialize_task` is not visited
// `f_set_task_stopped` is not visited
// `f_is_task_running` is not visited
// `f_join_running_task` is not visited
// `f_next_task_id` is not visited
}

void Tune() final { //
f_tune();
void Tune() final {
if (f_tune == nullptr) {
TaskSchedulerNode::Tune();
} else {
f_tune();
}
}

void InitializeTask(int task_id) final {
if (f_initialize_task == nullptr) {
TaskSchedulerNode::InitializeTask(task_id);
} else {
f_initialize_task(task_id);
}
}

void SetTaskStopped(int task_id) final { //
f_set_task_stopped(task_id);
void SetTaskStopped(int task_id) final {
if (f_set_task_stopped == nullptr) {
TaskSchedulerNode::SetTaskStopped(task_id);
} else {
f_set_task_stopped(task_id);
}
}

bool IsTaskRunning(int task_id) final { //
return f_is_task_running(task_id);
bool IsTaskRunning(int task_id) final {
if (f_is_task_running == nullptr) {
return TaskSchedulerNode::IsTaskRunning(task_id);
} else {
return f_is_task_running(task_id);
}
}

void JoinRunningTask(int task_id) final { //
f_join_running_task(task_id);
void JoinRunningTask(int task_id) final {
if (f_join_running_task == nullptr) {
return TaskSchedulerNode::JoinRunningTask(task_id);
} else {
return f_join_running_task(task_id);
}
}

int NextTaskId() final { //
int NextTaskId() final {
ICHECK(f_next_task_id != nullptr) << "PyTaskScheduler's NextTaskId method not implemented!";
return f_next_task_id();
}

Expand All @@ -203,10 +242,17 @@ class TaskScheduler : public runtime::ObjectRef {
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
*/
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, Builder builder, Runner runner,
Database database);
TVM_DLL static TaskScheduler RoundRobin(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database); //
TVM_DLL static TaskScheduler PyTaskScheduler(
Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
PyTaskSchedulerNode::FTune f_tune, //
PyTaskSchedulerNode::FInitializeTask f_initialize_task, //
PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, //
PyTaskSchedulerNode::FIsTaskRunning f_is_task_running, //
PyTaskSchedulerNode::FJoinRunningTask f_join_running_task, //
Expand Down
5 changes: 2 additions & 3 deletions python/tvm/meta_schedule/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from tvm.target import Target

from .. import _ffi_api
from ..utils import check_override


@register_object("meta_schedule.BuilderInput")
Expand Down Expand Up @@ -119,13 +120,11 @@ class PyBuilder(Builder):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Builder)
def f_build(build_inputs: List[BuilderInput]) -> List[BuilderResult]:
return self.build(build_inputs)

self.__init_handle_by_constructor__(
_ffi_api.BuilderPyBuilder, # type: ignore # pylint: disable=no-member
f_build,
)

def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
raise NotImplementedError
18 changes: 5 additions & 13 deletions python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .. import _ffi_api
from ..arg_info import ArgInfo
from ..utils import _json_de_tvm
from ..utils import _json_de_tvm, check_override


@register_object("meta_schedule.Workload")
Expand Down Expand Up @@ -207,15 +207,19 @@ class PyDatabase(Database):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, Database)
def f_commit_workload(mod: IRModule) -> Workload:
return self.commit_workload(mod)

@check_override(self.__class__, Database)
def f_commit_tuning_record(record: TuningRecord) -> None:
self.commit_tuning_record(record)

@check_override(self.__class__, Database)
def f_get_top_k(workload: Workload, top_k: int) -> List[TuningRecord]:
return self.get_top_k(workload, top_k)

@check_override(self.__class__, Database, func_name="__len__")
def f_size() -> int:
return len(self)

Expand All @@ -226,15 +230,3 @@ def f_size() -> int:
f_get_top_k,
f_size,
)

def commit_workload(self, mod: IRModule) -> Workload:
raise NotImplementedError

def commit_tuning_record(self, record: TuningRecord) -> None:
raise NotImplementedError

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
raise NotImplementedError

def __len__(self) -> int:
raise NotImplementedError
5 changes: 2 additions & 3 deletions python/tvm/meta_schedule/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .. import _ffi_api
from ..arg_info import ArgInfo
from ..utils import check_override


@register_object("meta_schedule.RunnerInput")
Expand Down Expand Up @@ -158,13 +159,11 @@ class PyRunner(Runner):
def __init__(self) -> None:
"""Constructor"""

@check_override(self.__class__, Runner)
def f_run(runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
return self.run(runner_inputs)

self.__init_handle_by_constructor__(
_ffi_api.RunnerPyRunner, # type: ignore # pylint: disable=no-member
f_run,
)

def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
raise NotImplementedError
27 changes: 10 additions & 17 deletions python/tvm/meta_schedule/search_strategy/search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Search Strategy"""

"""
Meta Schedule search strategy that generates the measure
candidates for measurement.
"""
from typing import List, Optional, TYPE_CHECKING

from tvm._ffi import register_object
Expand All @@ -25,6 +27,7 @@
from .. import _ffi_api
from ..arg_info import ArgInfo
from ..runner import RunnerResult
from ..utils import check_override

if TYPE_CHECKING:
from ..tune_context import TuneContext
Expand Down Expand Up @@ -126,18 +129,23 @@ class PySearchStrategy(SearchStrategy):
def __init__(self):
"""Constructor."""

@check_override(self.__class__, SearchStrategy)
def f_initialize_with_tune_context(context: "TuneContext") -> None:
self.initialize_with_tune_context(context)

@check_override(self.__class__, SearchStrategy)
def f_pre_tuning(design_spaces: List[Schedule]) -> None:
self.pre_tuning(design_spaces)

@check_override(self.__class__, SearchStrategy)
def f_post_tuning() -> None:
self.post_tuning()

@check_override(self.__class__, SearchStrategy)
def f_generate_measure_candidates() -> List[MeasureCandidate]:
return self.generate_measure_candidates()

@check_override(self.__class__, SearchStrategy)
def f_notify_runner_results(results: List["RunnerResult"]) -> None:
self.notify_runner_results(results)

Expand All @@ -149,18 +157,3 @@ def f_notify_runner_results(results: List["RunnerResult"]) -> None:
f_generate_measure_candidates,
f_notify_runner_results,
)

def initialize_with_tune_context(self, tune_context: "TuneContext") -> None:
raise NotImplementedError

def pre_tuning(self, design_spaces: List[Schedule]) -> None:
raise NotImplementedError

def post_tuning(self) -> None:
raise NotImplementedError

def generate_measure_candidates(self) -> List[MeasureCandidate]:
raise NotImplementedError

def notify_runner_results(self, results: List["RunnerResult"]) -> None:
raise NotImplementedError
Loading

0 comments on commit 048994b

Please sign in to comment.