diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py
index 5d6081fa81e4a..bafdd521bffbb 100644
--- a/python/tvm/meta_schedule/testing/__init__.py
+++ b/python/tvm/meta_schedule/testing/__init__.py
@@ -15,3 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 """Testing utilities in meta schedule"""
+from .utils import DummyDatabase, DummyBuilder, DummyRunner, DummyRunnerFuture, DummyMutator
diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py
new file mode 100644
index 0000000000000..b7ef349140894
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/utils.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Testing utility functions in meta schedule"""
+from typing import List, Optional
+import random
+
+import tvm
+
+from tvm.meta_schedule import TuneContext  # pylint: disable=unused-import
+from tvm.meta_schedule.utils import derived_object
+from tvm.meta_schedule.mutator.mutator import PyMutator
+from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
+from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
+from tvm.meta_schedule.runner import (
+    RunnerInput,
+    RunnerResult,
+    RunnerFuture,
+    PyRunnerFuture,
+    PyRunner,
+)
+from tvm.ir import IRModule
+from tvm.tir.schedule import Trace
+
+
+@derived_object
+class DummyDatabase(PyDatabase):
+    """
+    An in-memory database backed by a Python list, for testing.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.records = []
+        self.workload_reg = []
+
+    def has_workload(self, mod: IRModule) -> bool:
+        for workload in self.workload_reg:
+            if tvm.ir.structural_equal(workload.mod, mod):
+                return True
+        return False
+
+    def commit_tuning_record(self, record: TuningRecord) -> None:
+        self.records.append(record)
+
+    def commit_workload(self, mod: IRModule) -> Workload:
+        for workload in self.workload_reg:
+            if tvm.ir.structural_equal(workload.mod, mod):
+                return workload
+        workload = Workload(mod)
+        self.workload_reg.append(workload)
+        return workload
+
+    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+        return list(
+            filter(
+                lambda x: x.workload == workload,
+                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
+            )
+        )[: int(top_k)]
+
+    def __len__(self) -> int:
+        return len(self.records)
+
+    def print_results(self) -> None:
+        print("\n".join([str(r) for r in self.records]))
+
+
+@derived_object
+class DummyRunnerFuture(PyRunnerFuture):
+    def done(self) -> bool:
+        return True
+
+    def result(self) -> RunnerResult:
+        run_secs = [random.uniform(5, 30) for _ in range(random.randint(1, 10))]
+        return RunnerResult(run_secs, None)
+
+
+@derived_object
+class DummyBuilder(PyBuilder):
+    def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
+        return [BuilderResult("test_path", None) for _ in build_inputs]
+
+
+@derived_object
+class DummyRunner(PyRunner):
+    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
+        return [DummyRunnerFuture() for _ in runner_inputs]  # type: ignore
+
+
+@derived_object
+class DummyMutator(PyMutator):
+    """Dummy Mutator for testing"""
+
+    def initialize_with_tune_context(self, context: "TuneContext") -> None:
+        pass
+
+    def apply(self, trace: Trace, _) -> Optional[Trace]:
+        return Trace(trace.insts, {})
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index a713fa0fee69c..68ee840d15eac 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -32,6 +32,7 @@
 from tvm.script import tir as T
 from tvm.target import Target
 from tvm.tir import Schedule
+from tvm.meta_schedule.testing import DummyDatabase
 from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
 from tvm.meta_schedule.tune import extract_task_from_relay, Parse
 
@@ -106,44 +107,6 @@ def test_meta_schedule_integration_extract_from_resnet():
 
 @requires_torch
 def test_meta_schedule_integration_apply_history_best():
-    @derived_object
-    class DummyDatabase(PyDatabase):
-        def __init__(self):
-            super().__init__()
-            self.records = []
-            self.workload_reg = []
-
-        def has_workload(self, mod: IRModule) -> Workload:
-            for workload in self.workload_reg:
-                if tvm.ir.structural_equal(workload.mod, mod):
-                    return True
-            return False
-
-        def commit_tuning_record(self, record: TuningRecord) -> None:
-            self.records.append(record)
-
-        def commit_workload(self, mod: IRModule) -> Workload:
-            for workload in self.workload_reg:
-                if tvm.ir.structural_equal(workload.mod, mod):
-                    return workload
-            workload = Workload(mod)
-            self.workload_reg.append(workload)
-            return workload
-
-        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
-            return list(
-                filter(
-                    lambda x: x.workload == workload,
-                    sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
-                )
-            )[: int(top_k)]
-
-        def __len__(self) -> int:
-            return len(self.records)
-
-        def print_results(self) -> None:
-            print("\n".join([str(r) for r in self.records]))
-
     mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
     database = DummyDatabase()
     env = ApplyHistoryBest(database)
diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py
index 73d1e5752f3de..73640bdf74f69 100644
--- a/tests/python/unittest/test_meta_schedule_measure_callback.py
+++ b/tests/python/unittest/test_meta_schedule_measure_callback.py
@@ -24,15 +24,9 @@
 from tvm.ir import IRModule, assert_structural_equal
 from tvm.meta_schedule.builder import BuilderResult
 from tvm.meta_schedule.measure_callback import PyMeasureCallback
-from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
-from tvm.meta_schedule.runner import (
-    RunnerInput,
-    RunnerResult,
-    RunnerFuture,
-    PyRunnerFuture,
-    PyRunner,
-)
-from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
+from tvm.meta_schedule.builder import BuilderResult
+from tvm.meta_schedule.runner import RunnerResult
+from tvm.meta_schedule.testing import DummyDatabase, DummyRunner, DummyBuilder
 from tvm.meta_schedule.search_strategy import MeasureCandidate
 from tvm.meta_schedule.task_scheduler import RoundRobin, TaskScheduler
 from tvm.meta_schedule.utils import derived_object
@@ -61,66 +55,6 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
 # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
 
 
-@derived_object
-class DummyRunnerFuture(PyRunnerFuture):
-    def done(self) -> bool:
-        return True
-
-    def result(self) -> RunnerResult:
-        return RunnerResult([random.uniform(5, 30) for _ in range(random.randint(1, 10))], None)
-
-
-@derived_object
-class DummyBuilder(PyBuilder):
-    def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
-        return [BuilderResult("test_path", None) for _ in build_inputs]
-
-
-@derived_object
-class DummyRunner(PyRunner):
-    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
-        return [DummyRunnerFuture() for _ in runner_inputs]
-
-
-@derived_object
-class DummyDatabase(PyDatabase):
-    def __init__(self):
-        super().__init__()
-        self.records = []
-        self.workload_reg = []
-
-    def has_workload(self, mod: IRModule) -> Workload:
-        for workload in self.workload_reg:
-            if tvm.ir.structural_equal(workload.mod, mod):
-                return True
-        return False
-
-    def commit_tuning_record(self, record: TuningRecord) -> None:
-        self.records.append(record)
-
-    def commit_workload(self, mod: IRModule) -> Workload:
-        for workload in self.workload_reg:
-            if tvm.ir.structural_equal(workload.mod, mod):
-                return workload
-        workload = Workload(mod)
-        self.workload_reg.append(workload)
-        return workload
-
-    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
-        return list(
-            filter(
-                lambda x: x.workload == workload,
-                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
-            )
-        )[: int(top_k)]
-
-    def __len__(self) -> int:
-        return len(self.records)
-
-    def print_results(self) -> None:
-        print("\n".join([str(r) for r in self.records]))
-
-
 def test_meta_schedule_measure_callback():
     @derived_object
     class FancyMeasureCallback(PyMeasureCallback):
diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py
index 6bf4599ebdc5b..80d645a5ce939 100644
--- a/tests/python/unittest/test_meta_schedule_search_strategy.py
+++ b/tests/python/unittest/test_meta_schedule_search_strategy.py
@@ -26,8 +26,6 @@
 from tvm.meta_schedule import TuneContext
 from tvm.meta_schedule.builder import LocalBuilder
 from tvm.meta_schedule.cost_model import RandomModel
-from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
-from tvm.meta_schedule.mutator.mutator import PyMutator
 from tvm.meta_schedule.runner import LocalRunner, RunnerResult
 from tvm.meta_schedule.search_strategy import (
     EvolutionarySearch,
@@ -38,6 +36,7 @@
 from tvm.meta_schedule.space_generator import ScheduleFn
 from tvm.meta_schedule.task_scheduler import RoundRobin
 from tvm.meta_schedule.utils import derived_object
+from tvm.meta_schedule.testing import DummyDatabase, DummyMutator
 from tvm.script import tir as T
 from tvm.tir.schedule import Schedule, Trace
 
@@ -117,56 +116,6 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy):  # pylint: disabl
 
 
 def test_meta_schedule_evolutionary_search():  # pylint: disable = invalid-name]
-    @derived_object
-    class DummyMutator(PyMutator):
-        """Dummy Mutator for testing"""
-
-        def initialize_with_tune_context(self, context: "TuneContext") -> None:
-            pass
-
-        def apply(self, trace: Trace, _) -> Optional[Trace]:
-            return Trace(trace.insts, {})
-
-    @derived_object
-    class DummyDatabase(PyDatabase):
-        """Dummy Database for testing"""
-
-        def __init__(self):
-            super().__init__()
-            self.records = []
-            self.workload_reg = []
-
-        def has_workload(self, mod: IRModule) -> bool:
-            for workload in self.workload_reg:
-                if tvm.ir.structural_equal(workload.mod, mod):
-                    return True
-            return False
-
-        def commit_tuning_record(self, record: TuningRecord) -> None:
-            self.records.append(record)
-
-        def commit_workload(self, mod: IRModule) -> Workload:
-            for workload in self.workload_reg:
-                if tvm.ir.structural_equal(workload.mod, mod):
-                    return workload
-            workload = Workload(mod)
-            self.workload_reg.append(workload)
-            return workload
-
-        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
-            return list(
-                filter(
-                    lambda x: x.workload == workload,
-                    sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
-                )
-            )[: int(top_k)]
-
-        def __len__(self) -> int:
-            return len(self.records)
-
-        def print_results(self) -> None:
-            print("\n".join([str(r) for r in self.records]))
-
     num_trials_per_iter = 10
     num_trials_total = 100
 
diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py
index 3936803aab625..e49c35fa445ca 100644
--- a/tests/python/unittest/test_meta_schedule_task_scheduler.py
+++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py
@@ -26,19 +26,11 @@
 from tvm._ffi.base import TVMError
 from tvm.ir import IRModule
 from tvm.meta_schedule import TuneContext, measure_callback
-from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder
-from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
-from tvm.meta_schedule.runner import (
-    PyRunner,
-    RunnerFuture,
-    RunnerInput,
-    RunnerResult,
-    PyRunnerFuture,
-)
 from tvm.meta_schedule.search_strategy import ReplayTrace
 from tvm.meta_schedule.space_generator import ScheduleFn
 from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin
 from tvm.meta_schedule.utils import derived_object
+from tvm.meta_schedule.testing import DummyDatabase, DummyBuilder, DummyRunner, DummyRunnerFuture
 from tvm.script import tir as T
 from tvm.tir import Schedule
 
@@ -123,66 +115,6 @@ def _schedule_batch_matmul(sch: Schedule):
     sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, t_0, t_1)
 
 
-@derived_object
-class DummyRunnerFuture(PyRunnerFuture):
-    def done(self) -> bool:
-        return True
-
-    def result(self) -> RunnerResult:
-        return RunnerResult([random.uniform(5, 30) for _ in range(random.randint(1, 10))], None)
-
-
-@derived_object
-class DummyBuilder(PyBuilder):
-    def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
-        return [BuilderResult("test_path", None) for _ in build_inputs]
-
-
-@derived_object
-class DummyRunner(PyRunner):
-    def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
-        return [DummyRunnerFuture() for _ in runner_inputs]
-
-
-@derived_object
-class DummyDatabase(PyDatabase):
-    def __init__(self):
-        super().__init__()
-        self.records = []
-        self.workload_reg = []
-
-    def has_workload(self, mod: IRModule) -> Workload:
-        for workload in self.workload_reg:
-            if tvm.ir.structural_equal(workload.mod, mod):
-                return True
-        return False
-
-    def commit_tuning_record(self, record: TuningRecord) -> None:
-        self.records.append(record)
-
-    def commit_workload(self, mod: IRModule) -> Workload:
-        for workload in self.workload_reg:
-            if tvm.ir.structural_equal(workload.mod, mod):
-                return workload
-        workload = Workload(mod)
-        self.workload_reg.append(workload)
-        return workload
-
-    def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
-        return list(
-            filter(
-                lambda x: x.workload == workload,
-                sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
-            )
-        )[: int(top_k)]
-
-    def __len__(self) -> int:
-        return len(self.records)
-
-    def print_results(self) -> None:
-        print("\n".join([str(r) for r in self.records]))
-
-
 @derived_object
 class MyTaskScheduler(PyTaskScheduler):
     done = set()