Skip to content

Commit

Permalink
[Meta Schedule] Refactor meta schedule testing utils (apache#10648)
Browse files Browse the repository at this point in the history
This PR moves some utility testing classes into `meta_schedule/testing/utils` and updated the following tests involved:

- test_meta_schedule_integration.py
- test_meta_schedule_measure_callback.py
- test_meta_schedule_search_strategy.py
- test_meta_schedule_task_scheduler.py
  • Loading branch information
Yuanjing Shi authored and pfk-beta committed Apr 11, 2022
1 parent 9d527d9 commit c62cc34
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 228 deletions.
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
from .utils import DummyDatabase, DummyBuilder, DummyRunner, DummyRunnerFuture, DummyMutator
112 changes: 112 additions & 0 deletions python/tvm/meta_schedule/testing/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Testing utilitiy functions in meta schedule"""
from typing import List, Optional
import random

import tvm

from tvm.meta_schedule import TuneContext # pylint: disable=unused-import
from tvm.meta_schedule.utils import derived_object
from tvm.meta_schedule.mutator.mutator import PyMutator
from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
from tvm.meta_schedule.runner import (
RunnerInput,
RunnerResult,
RunnerFuture,
PyRunnerFuture,
PyRunner,
)
from tvm.ir import IRModule
from tvm.tir.schedule import Trace


@derived_object
class DummyDatabase(PyDatabase):
"""
An in-memory database based on python list for testing.
"""

def __init__(self):
super().__init__()
self.records = []
self.workload_reg = []

def has_workload(self, mod: IRModule) -> bool:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return True
return False

def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)

def commit_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return workload
workload = Workload(mod)
self.workload_reg.append(workload)
return workload

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
return list(
filter(
lambda x: x.workload == workload,
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
)
)[: int(top_k)]

def __len__(self) -> int:
return len(self.records)

def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))


@derived_object
class DummyRunnerFuture(PyRunnerFuture):
def done(self) -> bool:
return True

def result(self) -> RunnerResult:
run_secs = [random.uniform(5, 30) for _ in range(random.randint(1, 10))]
return RunnerResult(run_secs, None)


@derived_object
class DummyBuilder(PyBuilder):
def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
return [BuilderResult("test_path", None) for _ in build_inputs]


@derived_object
class DummyRunner(PyRunner):
def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
return [DummyRunnerFuture() for _ in runner_inputs] # type: ignore


@derived_object
class DummyMutator(PyMutator):
"""Dummy Mutator for testing"""

def initialize_with_tune_context(self, context: "TuneContext") -> None:
pass

def apply(self, trace: Trace, _) -> Optional[Trace]:
return Trace(trace.insts, {})
39 changes: 1 addition & 38 deletions tests/python/unittest/test_meta_schedule_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule
from tvm.meta_schedule.testing import DummyDatabase
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
from tvm.meta_schedule.tune import extract_task_from_relay, Parse

Expand Down Expand Up @@ -106,44 +107,6 @@ def test_meta_schedule_integration_extract_from_resnet():

@requires_torch
def test_meta_schedule_integration_apply_history_best():
@derived_object
class DummyDatabase(PyDatabase):
def __init__(self):
super().__init__()
self.records = []
self.workload_reg = []

def has_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return True
return False

def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)

def commit_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return workload
workload = Workload(mod)
self.workload_reg.append(workload)
return workload

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
return list(
filter(
lambda x: x.workload == workload,
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
)
)[: int(top_k)]

def __len__(self) -> int:
return len(self.records)

def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))

mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
database = DummyDatabase()
env = ApplyHistoryBest(database)
Expand Down
72 changes: 3 additions & 69 deletions tests/python/unittest/test_meta_schedule_measure_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,9 @@
from tvm.ir import IRModule, assert_structural_equal
from tvm.meta_schedule.builder import BuilderResult
from tvm.meta_schedule.measure_callback import PyMeasureCallback
from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
from tvm.meta_schedule.runner import (
RunnerInput,
RunnerResult,
RunnerFuture,
PyRunnerFuture,
PyRunner,
)
from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
from tvm.meta_schedule.builder import BuilderResult
from tvm.meta_schedule.runner import RunnerResult
from tvm.meta_schedule.testing import DummyDatabase, DummyRunner, DummyBuilder
from tvm.meta_schedule.search_strategy import MeasureCandidate
from tvm.meta_schedule.task_scheduler import RoundRobin, TaskScheduler
from tvm.meta_schedule.utils import derived_object
Expand Down Expand Up @@ -61,66 +55,6 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument


@derived_object
class DummyRunnerFuture(PyRunnerFuture):
def done(self) -> bool:
return True

def result(self) -> RunnerResult:
return RunnerResult([random.uniform(5, 30) for _ in range(random.randint(1, 10))], None)


@derived_object
class DummyBuilder(PyBuilder):
def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
return [BuilderResult("test_path", None) for _ in build_inputs]


@derived_object
class DummyRunner(PyRunner):
def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
return [DummyRunnerFuture() for _ in runner_inputs]


@derived_object
class DummyDatabase(PyDatabase):
def __init__(self):
super().__init__()
self.records = []
self.workload_reg = []

def has_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return True
return False

def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)

def commit_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return workload
workload = Workload(mod)
self.workload_reg.append(workload)
return workload

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
return list(
filter(
lambda x: x.workload == workload,
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
)
)[: int(top_k)]

def __len__(self) -> int:
return len(self.records)

def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))


def test_meta_schedule_measure_callback():
@derived_object
class FancyMeasureCallback(PyMeasureCallback):
Expand Down
53 changes: 1 addition & 52 deletions tests/python/unittest/test_meta_schedule_search_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.builder import LocalBuilder
from tvm.meta_schedule.cost_model import RandomModel
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
from tvm.meta_schedule.mutator.mutator import PyMutator
from tvm.meta_schedule.runner import LocalRunner, RunnerResult
from tvm.meta_schedule.search_strategy import (
EvolutionarySearch,
Expand All @@ -38,6 +36,7 @@
from tvm.meta_schedule.space_generator import ScheduleFn
from tvm.meta_schedule.task_scheduler import RoundRobin
from tvm.meta_schedule.utils import derived_object
from tvm.meta_schedule.testing import DummyDatabase, DummyMutator
from tvm.script import tir as T
from tvm.tir.schedule import Schedule, Trace

Expand Down Expand Up @@ -117,56 +116,6 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl


def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name]
@derived_object
class DummyMutator(PyMutator):
"""Dummy Mutator for testing"""

def initialize_with_tune_context(self, context: "TuneContext") -> None:
pass

def apply(self, trace: Trace, _) -> Optional[Trace]:
return Trace(trace.insts, {})

@derived_object
class DummyDatabase(PyDatabase):
"""Dummy Database for testing"""

def __init__(self):
super().__init__()
self.records = []
self.workload_reg = []

def has_workload(self, mod: IRModule) -> bool:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return True
return False

def commit_tuning_record(self, record: TuningRecord) -> None:
self.records.append(record)

def commit_workload(self, mod: IRModule) -> Workload:
for workload in self.workload_reg:
if tvm.ir.structural_equal(workload.mod, mod):
return workload
workload = Workload(mod)
self.workload_reg.append(workload)
return workload

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
return list(
filter(
lambda x: x.workload == workload,
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
)
)[: int(top_k)]

def __len__(self) -> int:
return len(self.records)

def print_results(self) -> None:
print("\n".join([str(r) for r in self.records]))

num_trials_per_iter = 10
num_trials_total = 100

Expand Down
Loading

0 comments on commit c62cc34

Please sign in to comment.