diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index 5e249850f5d5..a75a4cd8ae86 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -113,12 +113,10 @@ class SearchStrategyNode : public runtime::Object {
   /*!
    * \brief Update the search strategy with measurement results.
-   * \param context The tuning context.
    * \param measure_candidates The candidates to be measured.
    * \param results The measurement results from the runner.
    */
-  virtual void NotifyRunnerResults(const TuneContext& context,
-                                   const Array<MeasureCandidate>& measure_candidates,
+  virtual void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
                                    const Array<RunnerResult>& results) = 0;

   static constexpr const char* _type_key = "meta_schedule.SearchStrategy";
@@ -150,8 +148,8 @@ class PySearchStrategyNode : public SearchStrategyNode {
    * \brief The function type of `NotifyRunnerResults` method.
    * \param results The measurement results from the runner.
    */
-  using FNotifyRunnerResults = runtime::TypedPackedFunc<void(
-      const TuneContext&, const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;
+  using FNotifyRunnerResults =
+      runtime::TypedPackedFunc<void(const Array<MeasureCandidate>&, const Array<RunnerResult>&)>;

   /*! \brief The packed function to the `InitializeWithTuneContext` method. */
   FInitializeWithTuneContext f_initialize_with_tune_context;
@@ -177,8 +175,7 @@ class PySearchStrategyNode : public SearchStrategyNode {
                  const Optional<CostModel>& cost_model) final;
   void PostTuning() final;
   Optional<Array<MeasureCandidate>> GenerateMeasureCandidates() final;
-  void NotifyRunnerResults(const TuneContext& context,
-                           const Array<MeasureCandidate>& measure_candidates,
+  void NotifyRunnerResults(const Array<MeasureCandidate>& measure_candidates,
                            const Array<RunnerResult>& results);

   static constexpr const char* _type_key = "meta_schedule.PySearchStrategy";
diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h
index ee09099d1a92..3d732e7fbd99 100644
--- a/include/tvm/meta_schedule/tune_context.h
+++ b/include/tvm/meta_schedule/tune_context.h
@@ -42,6 +42,7 @@ namespace tvm {
 namespace meta_schedule {

 class TaskSchedulerNode;
+class MeasureCallback;

 /*! \brief The auto tuning context. */
 class TuneContextNode : public runtime::Object {
@@ -70,7 +71,7 @@ class TuneContextNode : public runtime::Object {
   int num_threads;
   /*! \brief Whether the tuning task has been stopped or finished. */
-  bool is_terminated;
+  bool is_terminated;  // TODO(@junrushao1994): move to TaskScheduler
   /*! \brief The measure candidates. */
   Optional<Array<MeasureCandidate>> measure_candidates;
   /*! \brief The building results. */
@@ -87,18 +88,36 @@ class TuneContextNode : public runtime::Object {
     v->Visit("postprocs", &postprocs);
     v->Visit("mutator_probs", &mutator_probs);
     v->Visit("task_name", &task_name);
+    // `logging_func` is not visited
     v->Visit("rand_state", &rand_state);
     v->Visit("num_threads", &num_threads);
     v->Visit("is_terminated", &is_terminated);
+    v->Visit("measure_candidates", &measure_candidates);
     v->Visit("builder_results", &builder_results);
     v->Visit("runner_futures", &runner_futures);
-    v->Visit("measure_candidates", &measure_candidates);
-    // `logging_func` is not visited
   }

   /*! \brief Initialize members that needs initialization with tune context. */
   void Initialize();
-
+  /*! \brief Set the measure candidates from the SearchStrategy */
+  void _SetMeasureCandidates(const Array<MeasureCandidate>& candidates);
+  /*!
+   * \brief Send the measure candidates to the builder.
+   * \param builder The builder to send the candidates to.
+   */
+  void _SendToBuilder(const Builder& builder);
+  /*!
+   * \brief Send the built measure candidates to the runner.
+   * \param runner The runner to send the candidates to.
+   */
+  void _SendToRunner(const Runner& runner);
+  /*!
+   * \brief Join the running tasks.
+   * \return The results from the runner.
+   */
+  Array<RunnerResult> _Join();
+  /*! \brief Set `measure_candidates`, `builder_results` and `runner_futures` to null. */
+  void _ClearMeasureState();

   static constexpr const char* _type_key = "meta_schedule.TuneContext";
   TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object);
 };
diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py
index 76eebbdf23f1..0028fbdf4faa 100644
--- a/python/tvm/meta_schedule/__init__.py
+++ b/python/tvm/meta_schedule/__init__.py
@@ -20,7 +20,9 @@
     builder,
     cost_model,
     database,
+    default_config,
     feature_extractor,
+    measure_callback,
     mutator,
     postproc,
     runner,
@@ -32,5 +34,6 @@
 from .extracted_task import ExtractedTask
 from .relay_integration import extract_task_from_relay
 from .search_strategy import MeasureCandidate
-from .tune import TuneConfig, tune_relay, tune_te, tune_tir
+from .tune import TuneConfig, tune_extracted_tasks, tune_relay, tune_te, tune_tir
 from .tune_context import TuneContext
+from .utils import derived_object
diff --git a/python/tvm/meta_schedule/database/__init__.py b/python/tvm/meta_schedule/database/__init__.py
index 320647b0e31b..2a87eea147d9 100644
--- a/python/tvm/meta_schedule/database/__init__.py
+++ b/python/tvm/meta_schedule/database/__init__.py
@@ -20,3 +20,4 @@
 """
 from .database import Database, PyDatabase, TuningRecord, Workload
 from .json_database import JSONDatabase
+from .memory_database import MemoryDatabase
diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py
index 8e0c80541020..802a739e6958 100644
--- a/python/tvm/meta_schedule/database/database.py
+++ b/python/tvm/meta_schedule/database/database.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Tuning record database"""
+"""TuningRecord database"""
 from typing import Any, Callable, List, Optional

 from tvm._ffi import register_object
diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py
new file mode 100644
index 000000000000..6d10e4b5272a
--- /dev/null
+++ b/python/tvm/meta_schedule/database/memory_database.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
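+# A hedged usage sketch (illustrative only; `mod`, `sch`, and `target` are
+# assumed to be an IRModule, a tir.Schedule, and a Target, mirroring how
+# `apply_fixed_schedules` below uses this class):
+#
+#   db = MemoryDatabase()
+#   workload = db.commit_workload(mod)
+#   db.commit_tuning_record(TuningRecord(sch.trace, workload, [0.42], target, []))
+#   best = db.get_top_k(workload, top_k=1)  # records sorted by mean run_secs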
+"""A database that stores TuningRecords in memory""" +from typing import List + +from ...ir import IRModule, structural_equal +from ..utils import derived_object +from .database import PyDatabase, TuningRecord, Workload + + +@derived_object +class MemoryDatabase(PyDatabase): + """An in-memory database based on python list for testing.""" + + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workload_reg: + if structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py new file mode 100644 index 000000000000..34411bde057b --- /dev/null +++ b/python/tvm/meta_schedule/default_config.py @@ -0,0 +1,346 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=import-outside-toplevel
+"""Pre-configured defaults for MetaSchedule search rules"""
+import logging
+from os import path as osp
+from typing import Callable, Dict, List, Optional, Union
+
+from tvm._ffi.registry import register_func
+from tvm.ir import IRModule
+from tvm.target import Target
+from tvm.tir import PrimFunc
+
+from .builder import Builder, LocalBuilder
+from .cost_model import CostModel, XGBModel
+from .database import Database, JSONDatabase
+from .feature_extractor import PerStoreFeature
+from .measure_callback import MeasureCallback
+from .mutator import Mutator
+from .postproc import Postproc
+from .runner import LocalRunner, Runner
+from .schedule_rule import ScheduleRule
+from .space_generator import PostOrderApply, SpaceGenerator
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+FnSpaceGenerator = Callable[[], SpaceGenerator]
+FnScheduleRule = Callable[[], List[ScheduleRule]]
+FnPostproc = Callable[[], List[Postproc]]
+FnMutatorProb = Callable[[], Dict[Mutator, float]]
+
+
+@register_func("tvm.meta_schedule.tune.parse_mod")  # for use in ApplyHistoryBest
+def mod(mod: Union[PrimFunc, IRModule]) -> IRModule:  # pylint: disable=redefined-outer-name
+    """Normalize the input to an IRModule"""
+    if isinstance(mod, PrimFunc):
+        mod = mod.with_attr("global_symbol", "main")
+        mod = mod.with_attr("tir.noalias", True)
+        mod = IRModule({"main": mod})
+    if not isinstance(mod, IRModule):
+        raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
+    # Rename the sole function to "main": ApplyHistoryBest looks modules up by
+    # structural equality, and a different function name makes them unequal.
+    func_names = mod.get_global_vars()
+    if len(func_names) == 1:
+        (func_name,) = func_names
+        if func_name != "main":
+            mod = IRModule({"main": mod[func_name]})
+    return mod
+
+
+def target(target: Union[str, Target]) -> Target:  # pylint: disable=redefined-outer-name
+    """Normalize the input to tvm.target.Target"""
+    if isinstance(target, str):
+        target = Target(target)
+    if not isinstance(target, Target):
+        raise TypeError(f"Expected `target` to be str or Target, but gets: {target}")
+    return target
+
+
+def builder(builder: Optional[Builder]) -> Builder:  # pylint: disable=redefined-outer-name
+    """Normalize the input to tvm.meta_schedule.Builder"""
+    if builder is None:
+        builder = LocalBuilder()  # type: ignore
+    if not isinstance(builder, Builder):
+        raise TypeError(f"Expected `builder` to be Builder, but gets: {builder}")
+    return builder
+
+
+def runner(runner: Optional[Runner]) -> Runner:  # pylint: disable=redefined-outer-name
+    """Normalize the input to tvm.meta_schedule.Runner"""
+    if runner is None:
+        runner = LocalRunner()  # type: ignore
+    if not isinstance(runner, Runner):
+        raise TypeError(f"Expected `runner` to be Runner, but gets: {runner}")
+    return runner
+
+
+def database(
+    database: Union[None, Database],  # pylint: disable=redefined-outer-name
+    path: str,
+) -> Database:
+    """Normalize the input to tvm.meta_schedule.Database"""
+    if database is None:
+        path_workload = osp.join(path, "database_workload.json")
+        path_tuning_record = osp.join(path, "database_tuning_record.json")
+        logger.info(
+            "Creating JSONDatabase. Workload at: %s. Tuning records at: %s",
+            path_workload,
+            path_tuning_record,
+        )
+        database = JSONDatabase(
+            path_workload=path_workload,
+            path_tuning_record=path_tuning_record,
+        )
+    if not isinstance(database, Database):
+        raise TypeError(f"Expected `database` to be Database, but gets: {database}")
+    return database
+
+
+def callbacks(  # pylint: disable=redefined-outer-name
+    measure_callbacks: Optional[List[MeasureCallback]],
+) -> List[MeasureCallback]:
+    """Normalize the input to List[tvm.meta_schedule.MeasureCallback]"""
+    if measure_callbacks is None:
+        from tvm.meta_schedule import measure_callback as M
+
+        return [
+            M.AddToDatabase(),
+            M.RemoveBuildArtifact(),
+            M.EchoStatistics(),
+            M.UpdateCostModel(),
+        ]
+    if not isinstance(measure_callbacks, (list, tuple)):
+        raise TypeError(
+            f"Expected `measure_callbacks` to be List[MeasureCallback], "
+            f"but gets: {measure_callbacks}"
+        )
+    measure_callbacks = list(measure_callbacks)
+    for i, callback in enumerate(measure_callbacks):
+        if not isinstance(callback, MeasureCallback):
+            raise TypeError(
+                f"Expected `measure_callbacks` to be List[MeasureCallback], "
+                f"but measure_callbacks[{i}] is: {callback}"
+            )
+    return measure_callbacks
+
+
+def cost_model(
+    cost_model: Optional[CostModel],  # pylint: disable=redefined-outer-name
+) -> CostModel:
+    """Normalize the input to tvm.meta_schedule.CostModel"""
+    if cost_model is None:
+        return XGBModel(extractor=PerStoreFeature())  # type: ignore
+    if not isinstance(cost_model, CostModel):
+        raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}")
+    return cost_model
+
+
+def space_generator(
+    space_generator: Optional[FnSpaceGenerator],  # pylint: disable=redefined-outer-name
+) -> SpaceGenerator:
+    """Normalize the input to tvm.meta_schedule.SpaceGenerator"""
+    if space_generator is None:
+        return PostOrderApply()
+    if callable(space_generator):
+        space_generator = space_generator()
+    if not isinstance(space_generator, SpaceGenerator):
+        raise TypeError(
+            f"Expected `space_generator` to return SpaceGenerator, "
+            f"but gets: {space_generator}"
+        )
+    return space_generator
+
+
+def schedule_rules(  # pylint: disable=redefined-outer-name
+    sch_rules: Optional[FnScheduleRule],
+    target: Target,
+) -> List[ScheduleRule]:
+    """Normalize the input to List[tvm.meta_schedule.ScheduleRule]"""
+    if callable(sch_rules):
+        return sch_rules()
+    if sch_rules is not None:
+        raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}")
+    if target.kind.name == "llvm":
+        return _DefaultLLVM.schedule_rules()
+    if target.kind.name in ["cuda", "rocm", "vulkan"]:
+        return _DefaultCUDA.schedule_rules()
+    raise ValueError(f"Unsupported target: {target}")
+
+
+def postproc(  # pylint: disable=redefined-outer-name
+    postproc: Optional[FnPostproc],
+    target: Target,
+) -> List[Postproc]:
+    """Normalize the input to List[tvm.meta_schedule.Postproc]"""
+    if callable(postproc):
+        return postproc()
+    if postproc is not None:
+        raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}")
+    if target.kind.name == "llvm":
+        return _DefaultLLVM.postprocs()
+    if target.kind.name in ["cuda", "rocm", "vulkan"]:
+        return _DefaultCUDA.postprocs()
+    raise ValueError(f"Unsupported target: {target}")
+
+
+def mutator_probs(  # pylint: disable=redefined-outer-name
+    mutator_probs: Optional[FnMutatorProb],
+    target: Target,
+) -> Dict[Mutator, float]:
+    """Normalize the input to Dict[tvm.meta_schedule.Mutator, float]"""
+    if callable(mutator_probs):
+        return mutator_probs()
+ if mutator_probs is not None: + raise TypeError( + f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}" + ) + if target.kind.name == "llvm": + return _DefaultLLVM.mutator_probs() + if target.kind.name in ["cuda", "rocm", "vulkan"]: + return _DefaultCUDA.mutator_probs() + raise ValueError(f"Unsupported target: {target}") + + +class _DefaultLLVM: + """Default tuning configuration for LLVM.""" + + @staticmethod + def schedule_rules() -> List[ScheduleRule]: + from tvm.meta_schedule import schedule_rule as M + + return [ + M.AutoInline( + into_producer=False, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ), + M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), + M.MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=M.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + M.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=64, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + M.RandomComputeLocation(), + ] + + @staticmethod + def postprocs() -> List[Postproc]: + from tvm.meta_schedule import postproc as M + + return [ + M.DisallowDynamicLoop(), + M.RewriteParallelVectorizeUnroll(), + M.RewriteReductionBlock(), + ] + + @staticmethod + def mutator_probs() -> Dict[Mutator, float]: + from tvm.meta_schedule import mutator as M + + return { + M.MutateTileSize(): 0.9, + M.MutateComputeLocation(): 0.05, + M.MutateUnroll(): 0.03, + M.MutateParallel(max_jobs_per_core=16): 0.02, + } + + +class _DefaultCUDA: + """Default tuning configuration for CUDA.""" + + @staticmethod + def schedule_rules() -> List[ScheduleRule]: + from tvm.meta_schedule import schedule_rule as M + + return [ + M.MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=M.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=M.ReuseType( + req="must", + levels=[3], + scope="local", + ), + ), + M.AutoInline( + into_producer=True, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]), + M.ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, # disable parallelize + max_vectorize_extent=-1, # disable vectorize + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ), + M.AutoBind( + max_threadblocks=256, + thread_extents=[32, 64, 128, 256, 512, 1024], + ), + ] + + @staticmethod + def postprocs() -> List[Postproc]: + from tvm.meta_schedule import postproc as M + + return [ + M.DisallowDynamicLoop(), + M.RewriteCooperativeFetch(), + M.RewriteUnboundBlock(), + M.RewriteParallelVectorizeUnroll(), + M.RewriteReductionBlock(), + M.VerifyGPUCode(), + ] + + @staticmethod + def mutator_probs() -> Dict[Mutator, float]: + from tvm.meta_schedule import mutator as M + + return { + M.MutateTileSize(): 0.9, + M.MutateUnroll(): 0.08, + M.MutateThreadBinding(): 0.02, + } diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index 2b066f49bd91..0c8de9668034 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -15,7 +15,7 @@ # specific 
language governing permissions and limitations # under the License. """Meta Schedule Mutator.""" -from typing import Callable, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Optional from tvm._ffi import register_object from tvm.runtime import Object @@ -31,7 +31,7 @@ class Mutator(Object): """Mutator is designed to mutate the trace to explore the design space.""" - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the mutator with a tune context. Parameters @@ -94,10 +94,10 @@ class PyMutator: _tvm_metadata = { "cls": _PyMutator, - "methods": ["initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "__str__"], } - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the mutator with a tune context. Parameters diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index 1706aae40614..e37666bd1ce0 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -33,7 +33,7 @@ class Postproc(Object): """Rules to apply a postprocessor to a schedule.""" - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the postprocessor with a tune context. Parameters @@ -96,10 +96,10 @@ class PyPostproc: _tvm_metadata = { "cls": _PyPostproc, - "methods": ["initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "__str__"], } - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the postprocessor with a tune context. Parameters diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index e3ffdb0f4f8e..481444341b86 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -22,10 +22,10 @@ from tvm._ffi import register_object from tvm.runtime import Object -from tvm.tir.schedule import Schedule, BlockRV +from tvm.tir.schedule import BlockRV, Schedule -from ..utils import _get_default_str from .. import _ffi_api +from ..utils import _get_default_str if TYPE_CHECKING: from ..tune_context import TuneContext @@ -35,7 +35,7 @@ class ScheduleRule(Object): """Rules to modify a block in a schedule.""" - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the schedule rule with a tune context. Parameters @@ -102,10 +102,10 @@ class PyScheduleRule: _tvm_metadata = { "cls": _PyScheduleRule, - "methods": ["initialize_with_tune_context", "apply", "__str__"], + "methods": ["_initialize_with_tune_context", "apply", "__str__"], } - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the schedule rule with a tune context. 
Parameters diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 14b46a0785f1..1cd8a448fe8e 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -77,7 +77,7 @@ class SearchStrategy(Object): before usage and post-tuned after usage. """ - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the search strategy with tuning context. Parameters @@ -129,7 +129,6 @@ def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: def notify_runner_results( self, - context: "TuneContext", measure_candidates: List[MeasureCandidate], results: List[RunnerResult], ) -> None: @@ -137,8 +136,6 @@ def notify_runner_results( Parameters ---------- - context : TuneContext - The tuning context for update. measure_candidates : List[MeasureCandidate] The measure candidates for update. results : List[RunnerResult] @@ -146,7 +143,6 @@ def notify_runner_results( """ _ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member self, - context, measure_candidates, results, ) @@ -192,7 +188,7 @@ class PySearchStrategy: _tvm_metadata = { "cls": _PySearchStrategy, "methods": [ - "initialize_with_tune_context", + "_initialize_with_tune_context", "pre_tuning", "post_tuning", "generate_measure_candidates", @@ -200,7 +196,7 @@ class PySearchStrategy: ], } - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the search strategy with tuning context. Parameters @@ -236,7 +232,6 @@ def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: def notify_runner_results( self, - context: "TuneContext", measure_candidates: List[MeasureCandidate], results: List[RunnerResult], ) -> None: @@ -244,8 +239,6 @@ def notify_runner_results( Parameters ---------- - context : TuneContext - The tuning context for update. measure_candidates : List[MeasureCandidate] The measure candidates for update. results : List[RunnerResult] diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py index 6763d9f9d56c..ffc13eecca26 100644 --- a/python/tvm/meta_schedule/space_generator/schedule_fn.py +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -53,7 +53,7 @@ def __init__(self, sch_fn: SCH_FN_TYPE): super().__init__() self.sch_fn = sch_fn - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the design space generator with tuning context. Parameters diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 4b7fff0283e0..eb999de49585 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -35,7 +35,7 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the design space generator with tuning context. 
Parameters @@ -96,10 +96,10 @@ class PySpaceGenerator: _tvm_metadata = { "cls": _PySpaceGenerator, - "methods": ["initialize_with_tune_context", "generate_design_space"], + "methods": ["_initialize_with_tune_context", "generate_design_space"], } - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the design space generator with tuning context. Parameters diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 24e57928778d..5d6081fa81e4 100644 --- a/python/tvm/meta_schedule/testing/__init__.py +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -15,11 +15,3 @@ # specific language governing permissions and limitations # under the License. """Testing utilities in meta schedule""" -from .utils import ( - DummyDatabase, - DummyBuilder, - DummyRunner, - DummyRunnerFuture, - DummyMutator, - apply_fixed_schedules, -) diff --git a/python/tvm/meta_schedule/testing/dummy_object.py b/python/tvm/meta_schedule/testing/dummy_object.py new file mode 100644 index 000000000000..50ae974df5d8 --- /dev/null +++ b/python/tvm/meta_schedule/testing/dummy_object.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
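+# These dummies satisfy the Builder/Runner/Mutator interfaces without
+# compiling or executing anything: build() returns fake artifact paths and
+# run() returns futures whose result() is a random run time, so the tuning
+# loop can be exercised end-to-end in unit tests. A hedged wiring sketch
+# (other arguments elided):
+#
+#   builder, runner = DummyBuilder(), DummyRunner()
+#   # pass `builder` and `runner` wherever a real LocalBuilder/LocalRunner
+#   # would go, e.g. into a TaskScheduler or tune_extracted_tasks(...)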
+"""Dummy objects for testing.""" +import random +from typing import List, Optional + +from tvm.tir.schedule import Trace + +from ..builder import BuilderInput, BuilderResult, PyBuilder +from ..mutator import PyMutator +from ..runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput, RunnerResult +from ..tune_context import TuneContext # pylint: disable=unused-import +from ..utils import derived_object + + +@derived_object +class DummyRunnerFuture(PyRunnerFuture): + def done(self) -> bool: + return True + + def result(self) -> RunnerResult: + run_secs = [random.uniform(5, 30) for _ in range(random.randint(1, 10))] + return RunnerResult(run_secs, None) + + +@derived_object +class DummyBuilder(PyBuilder): + def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: + return [BuilderResult("test_path", None) for _ in build_inputs] + + +@derived_object +class DummyRunner(PyRunner): + def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: + return [DummyRunnerFuture() for _ in runner_inputs] # type: ignore + + +@derived_object +class DummyMutator(PyMutator): + """Dummy Mutator for testing""" + + def _initialize_with_tune_context(self, context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace, _) -> Optional[Trace]: + return Trace(trace.insts, {}) diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index 62950fdd0bb4..f353d401a10c 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -15,114 +15,21 @@ # specific language governing permissions and limitations # under the License. """Testing utility functions in meta schedule""" -import random -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, Optional, Union -import tvm +from tvm import meta_schedule as ms from tvm.ir import IRModule -from tvm.meta_schedule import TuneContext # pylint: disable=unused-import -from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder -from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.extracted_task import ExtractedTask -from tvm.meta_schedule.mutator.mutator import PyMutator -from tvm.meta_schedule.relay_integration import extract_task_from_relay -from tvm.meta_schedule.runner import ( - PyRunner, - PyRunnerFuture, - RunnerFuture, - RunnerInput, - RunnerResult, -) -from tvm.meta_schedule.tune import Parse -from tvm.meta_schedule.utils import derived_object from tvm.relay import Function as RelayFunc from tvm.runtime import NDArray from tvm.target import Target from tvm.tir import Schedule -from tvm.tir.schedule import Trace - - -@derived_object -class DummyDatabase(PyDatabase): - """ - An in-memory database based on python list for testing. 
- """ - - def __init__(self): - super().__init__() - self.records = [] - self.workload_reg = [] - - def has_workload(self, mod: IRModule) -> bool: - for workload in self.workload_reg: - if tvm.ir.structural_equal(workload.mod, mod): - return True - return False - - def commit_tuning_record(self, record: TuningRecord) -> None: - self.records.append(record) - - def commit_workload(self, mod: IRModule) -> Workload: - for workload in self.workload_reg: - if tvm.ir.structural_equal(workload.mod, mod): - return workload - workload = Workload(mod) - self.workload_reg.append(workload) - return workload - - def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: - return list( - filter( - lambda x: x.workload == workload, - sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), - ) - )[: int(top_k)] - - def __len__(self) -> int: - return len(self.records) - - def print_results(self) -> None: - print("\n".join([str(r) for r in self.records])) - - -@derived_object -class DummyRunnerFuture(PyRunnerFuture): - def done(self) -> bool: - return True - - def result(self) -> RunnerResult: - run_secs = [random.uniform(5, 30) for _ in range(random.randint(1, 10))] - return RunnerResult(run_secs, None) - - -@derived_object -class DummyBuilder(PyBuilder): - def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: - return [BuilderResult("test_path", None) for _ in build_inputs] - - -@derived_object -class DummyRunner(PyRunner): - def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: - return [DummyRunnerFuture() for _ in runner_inputs] # type: ignore - - -@derived_object -class DummyMutator(PyMutator): - """Dummy Mutator for testing""" - - def initialize_with_tune_context(self, context: "TuneContext") -> None: - pass - - def apply(self, trace: Trace, _) -> Optional[Trace]: - return Trace(trace.insts, {}) def apply_fixed_schedules( relay_mod: Union[RelayFunc, IRModule], target: Union[str, Target], params: Optional[Dict[str, NDArray]], - schedule_fn: Callable[[ExtractedTask, Schedule], bool], + schedule_fn: Callable[[ms.ExtractedTask, Schedule], bool], ): """Apply fixed schedules (manually written, without any tunable knobs) as specified by schedule_fn to extracted tasks, and return a database that can be passed to ApplyHistoryBest. @@ -145,17 +52,15 @@ def apply_fixed_schedules( The database containing dummy tuning records for manually scheduled traces. 
""" target = Target(target) if isinstance(target, str) else target - extracted_tasks = extract_task_from_relay(relay_mod, target, params) - - database = DummyDatabase() - + extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params) + database = ms.database.MemoryDatabase() for task in extracted_tasks: - mod = Parse._mod(task.dispatched[0]) + mod = ms.default_config.mod(task.dispatched[0]) sch = Schedule(mod) if schedule_fn(task, sch): workload = database.commit_workload(mod) - tune_rec = TuningRecord(sch.trace, workload, [0.0], target, []) + tune_rec = ms.database.TuningRecord(sch.trace, workload, [0.0], target, []) database.commit_tuning_record(tune_rec) return database diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 9af237b3b7b8..cc7c4cbc9356 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -22,7 +22,6 @@ from os import path as osp from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union -from tvm._ffi.registry import register_func from tvm.ir import IRModule from tvm.ir.transform import PassContext from tvm.runtime import Module, NDArray @@ -30,19 +29,19 @@ from tvm.te import Tensor, create_prim_func from tvm.tir import PrimFunc, Schedule +from . import default_config from .apply_history_best import ApplyHistoryBest -from .builder import Builder, LocalBuilder -from .cost_model import CostModel, XGBModel -from .database import Database, JSONDatabase, TuningRecord +from .builder import Builder +from .cost_model import CostModel +from .database import Database, TuningRecord from .extracted_task import ExtractedTask -from .feature_extractor import PerStoreFeature from .measure_callback import MeasureCallback from .mutator import Mutator from .postproc import Postproc -from .runner import LocalRunner, Runner +from .runner import Runner from .schedule_rule import ScheduleRule from .search_strategy import EvolutionarySearch, ReplayFunc, ReplayTrace -from .space_generator import PostOrderApply, SpaceGenerator +from .space_generator import SpaceGenerator from .task_scheduler import GradientBased, RoundRobin from .tune_context import TuneContext from .utils import autotvm_silencer, batch_parameterize_config @@ -55,295 +54,6 @@ FnMutatorProb = Callable[[], Dict[Mutator, float]] -class DefaultLLVM: - """Default tuning configuration for LLVM.""" - - @staticmethod - def _sch_rules() -> List[ScheduleRule]: - from tvm.meta_schedule import schedule_rule as M - - return [ - M.AutoInline( - into_producer=False, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=True, - require_injective=True, - require_ordered=True, - disallow_op=["tir.exp"], - ), - M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), - M.MultiLevelTiling( - structure="SSRSRS", - tile_binds=None, - max_innermost_factor=64, - vector_load_lens=None, - reuse_read=None, - reuse_write=M.ReuseType( - req="may", - levels=[1, 2], - scope="global", - ), - ), - M.ParallelizeVectorizeUnroll( - max_jobs_per_core=16, - max_vectorize_extent=64, - unroll_max_steps=[0, 16, 64, 512], - unroll_explicit=True, - ), - M.RandomComputeLocation(), - ] - - @staticmethod - def _postproc() -> List[Postproc]: - from tvm.meta_schedule import postproc as M - - return [ - M.DisallowDynamicLoop(), - M.RewriteParallelVectorizeUnroll(), - M.RewriteReductionBlock(), - ] - - @staticmethod - def _mutator_probs() -> Dict[Mutator, float]: - from tvm.meta_schedule import mutator as M - - return { - M.MutateTileSize(): 0.9, - 
M.MutateComputeLocation(): 0.05, - M.MutateUnroll(): 0.03, - M.MutateParallel(max_jobs_per_core=16): 0.02, - } - - -class DefaultCUDA: - """Default tuning configuration for CUDA.""" - - @staticmethod - def _sch_rules() -> List[ScheduleRule]: - from tvm.meta_schedule import schedule_rule as M - - return [ - M.MultiLevelTiling( - structure="SSSRRSRS", - tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], - max_innermost_factor=64, - vector_load_lens=[1, 2, 3, 4], - reuse_read=M.ReuseType( - req="must", - levels=[4], - scope="shared", - ), - reuse_write=M.ReuseType( - req="must", - levels=[3], - scope="local", - ), - ), - M.AutoInline( - into_producer=True, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=False, - require_injective=False, - require_ordered=False, - disallow_op=None, - ), - M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]), - M.ParallelizeVectorizeUnroll( - max_jobs_per_core=-1, # disable parallelize - max_vectorize_extent=-1, # disable vectorize - unroll_max_steps=[0, 16, 64, 512, 1024], - unroll_explicit=True, - ), - M.AutoBind( - max_threadblocks=256, - thread_extents=[32, 64, 128, 256, 512, 1024], - ), - ] - - @staticmethod - def _postproc() -> List[Postproc]: - from tvm.meta_schedule import postproc as M - - return [ - M.DisallowDynamicLoop(), - M.RewriteCooperativeFetch(), - M.RewriteUnboundBlock(), - M.RewriteParallelVectorizeUnroll(), - M.RewriteReductionBlock(), - M.VerifyGPUCode(), - ] - - @staticmethod - def _mutator_probs() -> Dict[Mutator, float]: - from tvm.meta_schedule import mutator as M - - return { - M.MutateTileSize(): 0.9, - M.MutateUnroll(): 0.08, - M.MutateThreadBinding(): 0.02, - } - - -class Parse: - """Parse tuning configuration from user inputs.""" - - @staticmethod - @register_func("tvm.meta_schedule.tune.parse_mod") # for use in ApplyHistoryBest - def _mod(mod: Union[PrimFunc, IRModule]) -> IRModule: - if isinstance(mod, PrimFunc): - mod = mod.with_attr("global_symbol", "main") - mod = mod.with_attr("tir.noalias", True) - mod = IRModule({"main": mod}) - if not isinstance(mod, IRModule): - raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") - # in order to make sure the mod can be found in ApplyHistoryBest - # different func name can cause structural unequal - func_names = mod.get_global_vars() - (func_name,) = func_names - if len(func_names) == 1 and func_name != "main": - mod = IRModule({"main": mod[func_name]}) - return mod - - @staticmethod - def _target(target: Union[str, Target]) -> Target: - if isinstance(target, str): - target = Target(target) - if not isinstance(target, Target): - raise TypeError(f"Expected `target` to be str or Target, but gets: {target}") - return target - - @staticmethod - def _builder(builder: Optional[Builder]) -> Builder: - if builder is None: - builder = LocalBuilder() # type: ignore - if not isinstance(builder, Builder): - raise TypeError(f"Expected `builder` to be Builder, but gets: {builder}") - return builder - - @staticmethod - def _runner(runner: Optional[Runner]) -> Runner: - if runner is None: - runner = LocalRunner() # type: ignore - if not isinstance(runner, Runner): - raise TypeError(f"Expected `runner` to be Runner, but gets: {runner}") - return runner - - @staticmethod - def _database(database: Union[None, Database], path: str) -> Database: - if database is None: - path_workload = osp.join(path, "database_workload.json") - path_tuning_record = osp.join(path, "database_tuning_record.json") - logger.info( - "Creating 
JSONDatabase. Workload at: %s. Tuning records at: %s", - path_workload, - path_tuning_record, - ) - database = JSONDatabase( - path_workload=path_workload, - path_tuning_record=path_tuning_record, - ) - if not isinstance(database, Database): - raise TypeError(f"Expected `database` to be Database, but gets: {database}") - return database - - @staticmethod - def _callbacks( - measure_callbacks: Optional[List[MeasureCallback]], - ) -> List[MeasureCallback]: - if measure_callbacks is None: - from tvm.meta_schedule import measure_callback as M - - return [ - M.AddToDatabase(), - M.RemoveBuildArtifact(), - M.EchoStatistics(), - M.UpdateCostModel(), - ] - if not isinstance(measure_callbacks, (list, tuple)): - raise TypeError( - f"Expected `measure_callbacks` to be List[MeasureCallback], " - f"but gets: {measure_callbacks}" - ) - measure_callbacks = list(measure_callbacks) - for i, callback in enumerate(measure_callbacks): - if not isinstance(callback, MeasureCallback): - raise TypeError( - f"Expected `measure_callbacks` to be List[MeasureCallback], " - f"but measure_callbacks[{i}] is: {callback}" - ) - return measure_callbacks - - @staticmethod - def _cost_model(cost_model: Optional[CostModel]) -> CostModel: - if cost_model is None: - return XGBModel(extractor=PerStoreFeature()) # type: ignore - if not isinstance(cost_model, CostModel): - raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}") - return cost_model - - @staticmethod - def _space_generator(space_generator: Optional[FnSpaceGenerator]) -> SpaceGenerator: - if space_generator is None: - return PostOrderApply() - if callable(space_generator): - space_generator = space_generator() - if not isinstance(space_generator, SpaceGenerator): - raise TypeError( - f"Expected `space_generator` to return SpaceGenerator, " - f"but gets: {space_generator}" - ) - return space_generator - - @staticmethod - def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[ScheduleRule]: - if callable(sch_rules): - return sch_rules() - if sch_rules is not None: - raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}") - # pylint: disable=protected-access - if target.kind.name == "llvm": - return DefaultLLVM._sch_rules() - if target.kind.name in ["cuda", "rocm", "vulkan"]: - return DefaultCUDA._sch_rules() - # pylint: enable=protected-access - raise ValueError(f"Unsupported target: {target}") - - @staticmethod - def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]: - if callable(postproc): - return postproc() - if postproc is not None: - raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}") - # pylint: disable=protected-access - if target.kind.name == "llvm": - return DefaultLLVM._postproc() - if target.kind.name in ["cuda", "rocm", "vulkan"]: - return DefaultCUDA._postproc() - # pylint: enable=protected-access - raise ValueError(f"Unsupported target: {target}") - - @staticmethod - def _mutator_probs( - mutator_probs: Optional[FnMutatorProb], - target: Target, - ) -> Dict[Mutator, float]: - if callable(mutator_probs): - return mutator_probs() - if mutator_probs is not None: - raise TypeError( - f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}" - ) - # pylint: disable=protected-access - if target.kind.name == "llvm": - return DefaultLLVM._mutator_probs() - if target.kind.name in ["cuda", "rocm", "vulkan"]: - return DefaultCUDA._mutator_probs() - # pylint: enable=protected-access - raise 
ValueError(f"Unsupported target: {target}") - - class TuneConfig(NamedTuple): """Configuration for tuning @@ -544,7 +254,7 @@ def tune_extracted_tasks( Parameters ---------- extracted_tasks : List[ExtractedTask] - The list of extraced tasks. + The list of extracted tasks. config : TuneConfig The search strategy config. work_dir : Optional[str] @@ -597,24 +307,24 @@ def tune_extracted_tasks( ) logger.info("Working directory: %s", work_dir) - database = Parse._database(database, work_dir) - builder = Parse._builder(builder) - runner = Parse._runner(runner) - cost_model = Parse._cost_model(cost_model) - measure_callbacks = Parse._callbacks(measure_callbacks) + database = default_config.database(database, work_dir) + builder = default_config.builder(builder) + runner = default_config.runner(runner) + cost_model = default_config.cost_model(cost_model) + measure_callbacks = default_config.callbacks(measure_callbacks) # parse the tuning contexts tune_contexts = [] for i, task in enumerate(extracted_tasks): assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" tune_contexts.append( TuneContext( - mod=Parse._mod(task.dispatched[0]), + mod=default_config.mod(task.dispatched[0]), target=task.target, - space_generator=Parse._space_generator(space), + space_generator=default_config.space_generator(space), search_strategy=config.create_strategy(), - sch_rules=Parse._sch_rules(sch_rules, task.target), - postprocs=Parse._postproc(postprocs, task.target), - mutator_probs=Parse._mutator_probs(mutator_probs, task.target), + sch_rules=default_config.schedule_rules(sch_rules, task.target), + postprocs=default_config.postproc(postprocs, task.target), + mutator_probs=default_config.mutator_probs(mutator_probs, task.target), task_name=task.task_name, logger=logging.getLogger( logger_name_pattern.format(task_id=i, task_name=task.task_name) @@ -694,8 +404,7 @@ def tune_tir( ) # pylint: disable=protected-access - mod = Parse._mod(mod) - target = Parse._target(target) + target = default_config.target(target) # pylint: enable=protected-access database = tune_extracted_tasks( extracted_tasks=[ @@ -851,7 +560,7 @@ def tune_relay( from .relay_integration import extract_task_from_relay # pylint: disable=protected-access, enable=import-outside-toplevel - target = Parse._target(target) + target = default_config.target(target) # pylint: enable=protected-access, # parse the tuning contexts extracted_tasks = extract_task_from_relay(mod, target, params) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 19ab0a40cf61..78fd3d659faf 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -17,23 +17,26 @@ """Meta Schedule tuning context.""" import logging -from typing import Optional, List, Dict, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List, Optional from tvm import IRModule from tvm._ffi import register_object from tvm.meta_schedule.utils import cpu_count, make_logging_func from tvm.runtime import Object from tvm.target import Target -from tvm.tir import PrimFunc +from tvm.tir import PrimFunc, Schedule from . 
import _ffi_api

 if TYPE_CHECKING:
-    from .space_generator import SpaceGenerator
-    from .search_strategy import SearchStrategy
-    from .schedule_rule import ScheduleRule
-    from .postproc import Postproc
+    from .cost_model import CostModel
+    from .database import Database
     from .mutator import Mutator
+    from .postproc import Postproc
+    from .runner import RunnerResult
+    from .schedule_rule import ScheduleRule
+    from .search_strategy import MeasureCandidate, SearchStrategy
+    from .space_generator import SpaceGenerator


 @register_object("meta_schedule.TuneContext")
@@ -114,7 +117,6 @@ def __init__(
             self.logger = logging.getLogger(__name__)
         else:
             self.logger = None
-
         self.__init_handle_by_constructor__(
             _ffi_api.TuneContext,  # type: ignore # pylint: disable=no-member
             mod,
@@ -132,5 +134,105 @@ def __init__(

     def initialize(self):
         """Initialize the tuning context"""
-
         _ffi_api.TuneContextInitialize(self)  # type: ignore # pylint: disable=no-member
+
+    def generate_design_space(self) -> List[Schedule]:
+        """Generate design spaces given a module.
+
+        Delegated to self.space_generator.generate_design_space with self.mod
+
+        Returns
+        -------
+        design_spaces : List[Schedule]
+            The generated design spaces, i.e., schedules.
+        """
+        if self.mod is None:
+            raise ValueError("`mod` is not provided. Please construct TuneContext with `mod`")
+        if self.space_generator is None:
+            raise ValueError(
+                "space_generator is not provided. "
+                "Please construct TuneContext with space_generator"
+            )
+        return self.space_generator.generate_design_space(self.mod)
+
+    def pre_tuning(
+        self,
+        design_spaces: List[Schedule],
+        database: Optional["Database"] = None,
+        cost_model: Optional["CostModel"] = None,
+    ) -> None:
+        """A method to be called for SearchStrategy to do necessary preparation before tuning.
+
+        Delegated to self.search_strategy.pre_tuning.
+
+        Parameters
+        ----------
+        design_spaces : List[Schedule]
+            The design spaces used during tuning process.
+        database : Optional[Database] = None
+            The database used during tuning process.
+        cost_model : Optional[CostModel] = None
+            The cost model used during tuning process.
+        """
+        if self.search_strategy is None:
+            raise ValueError(
+                "search_strategy is not provided. "
+                "Please construct TuneContext with search_strategy"
+            )
+        return self.search_strategy.pre_tuning(design_spaces, database, cost_model)
+
+    def post_tuning(self) -> None:
+        """A method to be called for SearchStrategy to do necessary cleanup after tuning.
+
+        Delegated to self.search_strategy.post_tuning.
+        """
+        if self.search_strategy is None:
+            raise ValueError(
+                "search_strategy is not provided. "
+                "Please construct TuneContext with search_strategy"
+            )
+        _ffi_api.SearchStrategyPostTuning(self)  # type: ignore # pylint: disable=no-member
+
+    def generate_measure_candidates(self) -> Optional[List["MeasureCandidate"]]:
+        """Generate a batch of measure candidates from design spaces for measurement.
+
+        Delegated to self.search_strategy.generate_measure_candidates.
+
+        Returns
+        -------
+        measure_candidates : Optional[List[MeasureCandidate]]
+            The measure candidates generated, None if search is finished.
+        """
+        if self.search_strategy is None:
+            raise ValueError(
+                "search_strategy is not provided. "
+ "Please construct TuneContext with search_strategy" + ) + return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member + + def notify_runner_results( + self, + measure_candidates: List["MeasureCandidate"], + results: List["RunnerResult"], + ) -> None: + """Update the state in SearchStrategy with profiling results. + + Delegated to self.search_strategy.notify_runner_results. + + Parameters + ---------- + measure_candidates : List[MeasureCandidate] + The measure candidates for update. + results : List[RunnerResult] + The profiling results from the runner. + """ + if self.search_strategy is None: + raise ValueError( + "search_strategy is not provided." + "Please construct TuneContext with search_strategy" + ) + _ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member + self, + measure_candidates, + results, + ) diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 8b36a9521704..7714af3fec74 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -314,8 +314,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ inline Optional> GenerateMeasureCandidates(); /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ - inline void NotifyRunnerResults(const TuneContext& context, - const Array& measure_candidates, + inline void NotifyRunnerResults(const Array& measure_candidates, const Array& results); }; @@ -399,7 +398,7 @@ class EvolutionarySearchNode : public SearchStrategyNode { << "ValueError: Database is not supplied in PreTuning. Evolutionary" "search algorithm requires a database to be present, so that it " "could sample from previously-explored population. If you do not " - "intent to store data on disk, please use `tvm.meta_schedule.testing.DummyDatabase`"; + "intent to store data on disk, please use `tvm.meta_schedule.database.MemoryDatabase`"; CHECK(cost_model.defined()) << "ValueError: CostModel is not supplied in PreTuning. Evolutionary search " "algorithm expects a cost model to filter out potentially less efficient kernels. 
If " @@ -430,11 +429,10 @@ class EvolutionarySearchNode : public SearchStrategyNode { return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const TuneContext& context, - const Array& measure_candidates, + void NotifyRunnerResults(const Array& measure_candidates, const Array& results) final { ICHECK(this->state_ != nullptr); - this->state_->NotifyRunnerResults(context, measure_candidates, results); + this->state_->NotifyRunnerResults(measure_candidates, results); } }; @@ -681,8 +679,7 @@ Optional> EvolutionarySearchNode::State::GenerateMeasure } void EvolutionarySearchNode::State::NotifyRunnerResults( - const TuneContext& context, const Array& measure_candidates, - const Array& results) { + const Array& measure_candidates, const Array& results) { st += results.size(); ed += results.size(); } diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 1aaaaa09e8ab..24bc38ae80f5 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -98,8 +98,7 @@ class ReplayFuncNode : public SearchStrategyNode { return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const TuneContext& context, - const Array& measure_candidates, + void NotifyRunnerResults(const Array& measure_candidates, const Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 355f71455d91..b4b5ef8b3154 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -116,8 +116,7 @@ class ReplayTraceNode : public SearchStrategyNode { return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const TuneContext& context, - const Array& measure_candidates, + void NotifyRunnerResults(const Array& measure_candidates, const Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index f4c392ca2f1a..5865fc842248 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -52,12 +52,11 @@ Optional> PySearchStrategyNode::GenerateMeasureCandidate return f_generate_measure_candidates(); } -void PySearchStrategyNode::NotifyRunnerResults(const TuneContext& context, - const Array& measure_candidates, +void PySearchStrategyNode::NotifyRunnerResults(const Array& measure_candidates, const Array& results) { ICHECK(f_notify_runner_results != nullptr) << "PySearchStrategy's NotifyRunnerResults method not implemented!"; - f_notify_runner_results(context, measure_candidates, results); + f_notify_runner_results(measure_candidates, results); } SearchStrategy SearchStrategy::PySearchStrategy( diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 5d41f2edfb26..9c1f451414e3 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -21,77 +21,6 @@ namespace tvm { namespace meta_schedule { -/*! - * \brief Send the measure candidates to builder. - * \param builder The builder to send the candidates to. - * \param context The tuning context. - * \param candidates The measure candidates. 
- */
-void SendToBuilder(const Builder& builder, const TuneContext& context, PackedFunc logging_func) {
-  Array<MeasureCandidate> candidates = context->measure_candidates.value();
-  TVM_PY_LOG(INFO, logging_func) << "Sending " << candidates.size() << " sample(s) to builder";
-  Target target = context->target.value();
-  Array<BuilderInput> inputs;
-  inputs.reserve(candidates.size());
-  for (const MeasureCandidate& candidate : candidates) {
-    ICHECK(candidate.defined()) << "Undefined MeasureCandidate found";
-    inputs.push_back(BuilderInput(candidate->sch->mod(), target));
-  }
-  context->builder_results = builder->Build(inputs);
-}
-
-/*!
- * \brief Send the built measure candidates to runner.
- * \param runner The runner to send the candidates to.
- * \param context The tuning context.
- * \param candidates The measure candidates.
- * \param builder_results The builder results.
- * \return An array of the runner results.
- */
-void SendToRunner(const Runner& runner, const TuneContext& context, PackedFunc logging_func) {
-  Array<MeasureCandidate> candidates = context->measure_candidates.value();
-  Array<BuilderResult> builder_results = context->builder_results.value();
-  TVM_PY_LOG(INFO, logging_func) << "Sending " << candidates.size() << " sample(s) to runner";
-  Target target = context->target.value();
-  ICHECK_EQ(candidates.size(), builder_results.size());
-  int n = candidates.size();
-  int n_build_errors = 0;
-  Array<RunnerInput> inputs;
-  inputs.reserve(n);
-  for (int i = 0; i < n; ++i) {
-    const MeasureCandidate& candidate = candidates[i];
-    const BuilderResult& builder_result = builder_results[i];
-    if (builder_result->error_msg.defined()) {
-      ++n_build_errors;
-      continue;
-    }
-    inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(),
-                                 /*device_type=*/target->kind->name,
-                                 /*args_info=*/candidate->args_info));
-  }
-  Array<RunnerFuture> futures = runner->Run(inputs);
-  if (n_build_errors == 0) {
-    context->runner_futures = futures;
-    return;
-  }
-  Array<RunnerFuture> results;
-  results.reserve(n);
-  for (int i = 0, j = 0; i < n; ++i) {
-    const BuilderResult& builder_result = builder_results[i];
-    if (builder_result->error_msg.defined()) {
-      results.push_back(RunnerFuture(
-          /*f_done=*/[]() -> bool { return true; },
-          /*f_result=*/
-          [msg = builder_result->error_msg]() -> RunnerResult {
-            return RunnerResult(NullOpt, msg);
-          }));
-    } else {
-      results.push_back(futures[j++]);
-    }
-  }
-  context->runner_futures = results;
-}
-
 void TaskSchedulerNode::InitializeTask(int task_id) {
   TuneContext task = this->tasks[task_id];
   TVM_PY_LOG(INFO, this->logging_func)
@@ -132,11 +61,17 @@ void TaskSchedulerNode::Tune() {
       TuneContext task = tasks[task_id];
       ICHECK(!task->is_terminated);
       ICHECK(!task->runner_futures.defined());
-      SearchStrategy strategy = task->search_strategy.value();
-      if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) {
-        num_trials_already += task->measure_candidates.value().size();
-        SendToBuilder(this->builder, task, this->logging_func);
-        SendToRunner(this->runner, task, this->logging_func);
+      if (Optional<Array<MeasureCandidate>> candidates =
+              task->search_strategy.value()->GenerateMeasureCandidates()) {
+        int num_candidates = candidates.value().size();
+        task->_SetMeasureCandidates(candidates.value());
+        num_trials_already += num_candidates;
+        TVM_PY_LOG(INFO, this->logging_func)
+            << "Sending " << num_candidates << " sample(s) to builder";
+        task->_SendToBuilder(this->builder);
+        TVM_PY_LOG(INFO, this->logging_func)
+            << "Sending " << num_candidates << " sample(s) to runner";
+        task->_SendToRunner(this->runner);
       } else {
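+        // NullOpt from GenerateMeasureCandidates() means the strategy has
+        // exhausted its trials; mark the task as finished.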
ICHECK(!task->is_terminated); task->is_terminated = true; @@ -174,28 +109,12 @@ void TaskSchedulerNode::TouchTask(int task_id) { Array TaskSchedulerNode::JoinRunningTask(int task_id) { TuneContext task = tasks[task_id]; - ICHECK(task->runner_futures.defined()); - Array futures = task->runner_futures.value(); - int n = futures.size(); - Array results; - results.reserve(n); - for (RunnerFuture future : futures) { - results.push_back(future->Result()); - } - task->search_strategy.value()->NotifyRunnerResults(task, task->measure_candidates.value(), - results); - // Invoke the callbacks - ICHECK(task->measure_candidates.defined()); - ICHECK(task->builder_results.defined()); - ICHECK_EQ(results.size(), task->measure_candidates.value().size()); - ICHECK_EQ(results.size(), task->builder_results.value().size()); + Array results = task->_Join(); for (const MeasureCallback& callback : this->measure_callbacks) { callback->Apply(GetRef(this), task_id, task->measure_candidates.value(), task->builder_results.value(), results); } - task->measure_candidates = NullOpt; - task->builder_results = NullOpt; - task->runner_futures = NullOpt; + task->_ClearMeasureState(); return results; } diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 3607e3050803..362db0a38097 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -70,6 +70,87 @@ void TuneContextNode::Initialize() { } } +void TuneContextNode::_SetMeasureCandidates(const Array& candidates) { + this->measure_candidates = candidates; +} + +void TuneContextNode::_SendToBuilder(const Builder& builder) { + Array candidates = this->measure_candidates.value(); + Target target = this->target.value(); + Array inputs; + inputs.reserve(candidates.size()); + for (const MeasureCandidate& candidate : candidates) { + inputs.push_back(BuilderInput(candidate->sch->mod(), target)); + } + this->builder_results = builder->Build(inputs); +} + +void TuneContextNode::_SendToRunner(const Runner& runner) { + Array candidates = this->measure_candidates.value(); + Array builder_results = this->builder_results.value(); + Target target = this->target.value(); + ICHECK_EQ(candidates.size(), builder_results.size()); + int n = candidates.size(); + int n_build_errors = 0; + Array inputs; + inputs.reserve(n); + for (int i = 0; i < n; ++i) { + const MeasureCandidate& candidate = candidates[i]; + const BuilderResult& builder_result = builder_results[i]; + if (builder_result->error_msg.defined()) { + ++n_build_errors; + continue; + } + inputs.push_back(RunnerInput(/*artifact_path=*/builder_result->artifact_path.value(), + /*device_type=*/target->kind->name, + /*args_info=*/candidate->args_info)); + } + Array futures = runner->Run(inputs); + if (n_build_errors == 0) { + this->runner_futures = futures; + return; + } + Array results; + results.reserve(n); + for (int i = 0, j = 0; i < n; ++i) { + const BuilderResult& builder_result = builder_results[i]; + if (builder_result->error_msg.defined()) { + results.push_back(RunnerFuture( + /*f_done=*/[]() -> bool { return true; }, + /*f_result=*/ + [msg = builder_result->error_msg]() -> RunnerResult { + return RunnerResult(NullOpt, msg); + })); + } else { + results.push_back(futures[j++]); + } + } + this->runner_futures = results; +} + +Array TuneContextNode::_Join() { + ICHECK(this->runner_futures.defined()); + Array futures = this->runner_futures.value(); + int n = futures.size(); + Array results; + results.reserve(n); + for (RunnerFuture future : futures) { + 
results.push_back(future->Result()); + } + this->search_strategy.value()->NotifyRunnerResults(this->measure_candidates.value(), results); + ICHECK(this->measure_candidates.defined()); + ICHECK(this->builder_results.defined()); + ICHECK_EQ(results.size(), this->measure_candidates.value().size()); + ICHECK_EQ(results.size(), this->builder_results.value().size()); + return results; +} + +void TuneContextNode::_ClearMeasureState() { + this->measure_candidates = NullOpt; + this->builder_results = NullOpt; + this->runner_futures = NullOpt; +} + TVM_REGISTER_NODE_TYPE(TuneContextNode); TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py index 04dcf957780c..31b8b8182995 100644 --- a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py +++ b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py @@ -17,11 +17,9 @@ # pylint: disable=missing-docstring import tvm +from tvm import meta_schedule as ms from tvm.ir import IRModule -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.space_generator import PostOrderApply from tvm.meta_schedule.testing.conv2d_winograd_cpu import conv2d_winograd_cpu -from tvm.meta_schedule.tune import DefaultLLVM from tvm.target import Target from tvm.tir.schedule import Schedule, Trace @@ -164,16 +162,20 @@ def inverse(sch: Schedule): def test_conv2d_winograd_cpu(): mod = conv2d_winograd_cpu mod = IRModule({"main": mod}) - context = TuneContext( + target = Target("llvm --num-cores=16") + context = ms.TuneContext( mod=mod, - target=Target("llvm"), + target=target, task_name="Custom Search Space Task", - sch_rules=DefaultLLVM._sch_rules(), # pylint: disable=protected-access + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=ms.default_config.schedule_rules( + None, + target, + ), ) - post_order_apply = PostOrderApply() - post_order_apply.initialize_with_tune_context(context) + context.initialize() + post_order_apply = context.space_generator (sch,) = post_order_apply.generate_design_space(mod) - decisions = dict( zip( [i for i in sch.trace.insts[:-4] if i.kind.name.startswith("Sample")], diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py index 328f98e7f0cb..f8fdb79a1ded 100644 --- a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py +++ b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py @@ -17,11 +17,9 @@ # pylint: disable=missing-docstring import tvm +from tvm import meta_schedule as ms from tvm.ir import IRModule -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.space_generator import PostOrderApply from tvm.meta_schedule.testing.conv2d_winograd_cuda import conv2d_winograd_cuda -from tvm.meta_schedule.tune import DefaultCUDA from tvm.target import Target from tvm.tir.schedule import Schedule, Trace @@ -283,16 +281,17 @@ def root_anno(sch: Schedule): def test_conv2d_winograd_cuda(): mod = conv2d_winograd_cuda mod = IRModule({"main": mod}) - context = TuneContext( + context = ms.TuneContext( mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), task_name="Custom Search Space Task", - sch_rules=DefaultCUDA._sch_rules(), # pylint: disable=protected-access + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=ms.default_config.schedule_rules( # pylint: disable=protected-access + None, 
Target("cuda") + ), ) - for sch_rule in context.sch_rules: - sch_rule.initialize_with_tune_context(context) - post_order_apply = PostOrderApply() - post_order_apply.initialize_with_tune_context(context) + context.initialize() + post_order_apply = context.space_generator (sch,) = post_order_apply.generate_design_space(mod) decisions = dict( zip( diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 3b33039bd287..155d6aa235fd 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -14,21 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys - +"""Integration test for MetaSchedule""" import numpy as np import pytest import tvm import tvm.testing from tvm import meta_schedule as ms -from tvm import relay -from tvm.meta_schedule import ApplyHistoryBest -from tvm.meta_schedule.database import TuningRecord -from tvm.meta_schedule.relay_integration import extract_task_from_relay -from tvm.meta_schedule.testing import DummyDatabase +from tvm import relay, te, tir from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base -from tvm.meta_schedule.tune import Parse from tvm.script import tir as T from tvm.target import Target from tvm.tir import Schedule @@ -63,7 +57,7 @@ def _has_torch(): def test_meta_schedule_apply_history_best_no_current(): - assert ApplyHistoryBest.current() is None + assert ms.ApplyHistoryBest.current() is None @requires_torch @@ -199,7 +193,6 @@ def test_meta_schedule_integration_extract_from_bert_base(): @requires_torch def test_meta_schedule_integration_extract_from_resnet_with_filter_func(): def filter_func(args) -> bool: - from tvm import te, tir has_complex_op = False visited = set() @@ -262,14 +255,25 @@ def traverse(t): @requires_torch def test_meta_schedule_integration_apply_history_best(): mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - database = DummyDatabase() - env = ApplyHistoryBest(database) + database = ms.database.MemoryDatabase() + env = ms.ApplyHistoryBest(database) target = Target("llvm") workload = database.commit_workload(MockModule) database.commit_tuning_record( - TuningRecord(Schedule(MockModule).trace, workload, [1.0], target, []) + ms.database.TuningRecord( + trace=Schedule(MockModule).trace, + workload=workload, + run_secs=[1.0], + target=target, + args_info=[], + ) + ) + mod = env.query( + task_name="mock-task", + mod=mod, + target=target, + dispatched=[MockModule], ) - mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule]) assert tvm.ir.structural_equal(mod, workload.mod) @@ -277,7 +281,7 @@ def test_meta_schedule_integration_apply_history_best(): def extract_task_qbert(): mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128) target = "llvm -mcpu=cascadelake" - extracted_tasks = extract_task_from_relay(mod, target, params) + extracted_tasks = ms.extract_task_from_relay(mod, target, params) tune_tasks = list( filter( lambda task: "dense" in task.task_name or "batch_matmul" in task.task_name, @@ -294,7 +298,7 @@ def extract_task_qbert(): if out_type.dtype == "float32": continue - mod = Parse._mod(task.dispatched[0]) + mod = ms.default_config.mod(task.dispatched[0]) sch = tvm.tir.Schedule(mod) block = sch.get_block("compute") annotations = 
sch.get(block).annotations @@ -331,7 +335,7 @@ def test_extract_task_arm_conv2d_nchwc(): params = {"weight": weight_np, "bias": bias_np} target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon" - extracted_tasks = extract_task_from_relay(relay_mod, target, params) + extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params) tune_tasks = list( filter( lambda task: "conv2d" in task.task_name, diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py index 298b51e0158e..fba8c883e501 100644 --- a/tests/python/unittest/test_meta_schedule_measure_callback.py +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -20,13 +20,8 @@ import pytest import tvm -from tvm.meta_schedule.builder import BuilderResult -from tvm.meta_schedule.measure_callback import PyMeasureCallback -from tvm.meta_schedule.runner import RunnerResult -from tvm.meta_schedule.search_strategy import MeasureCandidate -from tvm.meta_schedule.task_scheduler import RoundRobin, TaskScheduler -from tvm.meta_schedule.testing import DummyBuilder, DummyDatabase, DummyRunner -from tvm.meta_schedule.utils import derived_object +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.dummy_object import DummyBuilder, DummyRunner from tvm.script import tir as T from tvm.tir.schedule import Schedule @@ -53,85 +48,87 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: def test_meta_schedule_measure_callback(): - @derived_object - class FancyMeasureCallback(PyMeasureCallback): + @ms.derived_object + class FancyMeasureCallback(ms.measure_callback.PyMeasureCallback): def apply( self, - task_scheduler: TaskScheduler, + task_scheduler: ms.task_scheduler.TaskScheduler, task_id: int, - measure_candidates: List[MeasureCandidate], - builds: List[BuilderResult], - results: List[RunnerResult], + measure_candidates: List[ms.MeasureCandidate], + builder_results: List[ms.builder.BuilderResult], + runner_results: List[ms.runner.RunnerResult], ) -> None: assert len(measure_candidates) == 1 tvm.ir.assert_structural_equal(measure_candidates[0].sch.mod, Matmul) assert ( - len(builds) == 1 - and builds[0].error_msg is None - and builds[0].artifact_path == "test_build" + len(builder_results) == 1 + and builder_results[0].error_msg is None + and builder_results[0].artifact_path == "test_build" ) assert ( - len(results) == 1 and results[0].error_msg is None and len(results[0].run_secs) == 2 + len(runner_results) == 1 + and runner_results[0].error_msg is None + and len(runner_results[0].run_secs) == 2 ) measure_callback = FancyMeasureCallback() measure_callback.apply( - RoundRobin( + ms.task_scheduler.RoundRobin( tasks=[], task_weights=[], builder=DummyBuilder(), runner=DummyRunner(), - database=DummyDatabase(), + database=ms.database.MemoryDatabase(), max_trials=1, ), 0, - [MeasureCandidate(Schedule(Matmul), None)], - [BuilderResult("test_build", None)], - [RunnerResult([1.0, 2.1], None)], + [ms.MeasureCandidate(Schedule(Matmul), None)], + [ms.builder.BuilderResult("test_build", None)], + [ms.runner.RunnerResult([1.0, 2.1], None)], ) def test_meta_schedule_measure_callback_fail(): - @derived_object - class FailingMeasureCallback(PyMeasureCallback): + @ms.derived_object + class FailingMeasureCallback(ms.measure_callback.PyMeasureCallback): def apply( self, - task_scheduler: TaskScheduler, + task_scheduler: ms.task_scheduler.TaskScheduler, task_id: int, - measure_candidates: List[MeasureCandidate], - builds: List[BuilderResult], 
- results: List[RunnerResult], + measure_candidates: List[ms.MeasureCandidate], + builder_results: List[ms.builder.BuilderResult], + runner_results: List[ms.runner.RunnerResult], ) -> None: raise ValueError("test") measure_callback = FailingMeasureCallback() with pytest.raises(ValueError, match="test"): measure_callback.apply( - RoundRobin( + ms.task_scheduler.RoundRobin( tasks=[], task_weights=[], builder=DummyBuilder(), runner=DummyRunner(), - database=DummyDatabase(), + database=ms.database.MemoryDatabase(), max_trials=1, ), 0, - [MeasureCandidate(Schedule(Matmul), None)], - [BuilderResult("test_build", None)], - [RunnerResult([1.0, 2.1], None)], + [ms.MeasureCandidate(Schedule(Matmul), None)], + [ms.builder.BuilderResult("test_build", None)], + [ms.runner.RunnerResult([1.0, 2.1], None)], ) def test_meta_schedule_measure_callback_as_string(): - @derived_object - class NotSoFancyMeasureCallback(PyMeasureCallback): + @ms.derived_object + class NotSoFancyMeasureCallback(ms.measure_callback.PyMeasureCallback): def apply( self, - task_scheduler: "TaskScheduler", + task_scheduler: ms.task_scheduler.TaskScheduler, task_id: int, - measure_candidates: List[MeasureCandidate], - builds: List[BuilderResult], - results: List[RunnerResult], + measure_candidates: List[ms.MeasureCandidate], + builder_results: List[ms.builder.BuilderResult], + runner_results: List[ms.runner.RunnerResult], ) -> None: pass diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py index 0b8af9c14550..b7d012ca04d6 100644 --- a/tests/python/unittest/test_meta_schedule_multi_anchor.py +++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py @@ -17,9 +17,9 @@ import numpy as np import tvm import tvm.testing +from tvm import meta_schedule as ms from tvm import relay -from tvm.meta_schedule import ApplyHistoryBest -from tvm.meta_schedule.testing import apply_fixed_schedules +from tvm.meta_schedule.testing.utils import apply_fixed_schedules def get_dense_dense(data_shape, weight_shape): @@ -27,10 +27,8 @@ def multi_dense(): p_data = relay.var("p_data", shape=data_shape, dtype="float32") p_weight1 = relay.var("p_weight1", shape=weight_shape, dtype="float32") p_weight2 = relay.var("p_weight2", shape=weight_shape, dtype="float32") - dense1 = relay.nn.dense(p_data, p_weight1) dense2 = relay.nn.dense(dense1, p_weight2) - f = relay.Function([p_data, p_weight1, p_weight2], dense2) f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) return f @@ -38,7 +36,6 @@ def multi_dense(): data = relay.var("data", shape=data_shape, dtype="float32") weight1 = relay.var("weight1", shape=weight_shape, dtype="float32") weight2 = relay.var("weight2", shape=weight_shape, dtype="float32") - out = relay.Call(multi_dense(), [data, weight1, weight2]) return relay.Function([data, weight1, weight2], out) @@ -51,26 +48,18 @@ def get_ref(data_np, weight1_np, weight2_np): def schedule_dense_dense(sch): dense1 = sch.get_block("T_matmul_NT") dense2 = sch.get_block("T_matmul_NT_1") - - y1, x1, k1 = sch.get_loops(dense1) - y2, x2, k2 = sch.get_loops(dense2) - - # ... 
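
The parameter renames above (`builds` to `builder_results`, `results` to `runner_results`) apply to every `PyMeasureCallback`. A minimal sketch of a callback written against the new names (the class and its print statement are illustrative only):

```python
from typing import List

from tvm import meta_schedule as ms


@ms.derived_object
class CountingCallback(ms.measure_callback.PyMeasureCallback):
    """Hypothetical callback counting how many candidates ran cleanly."""

    def apply(
        self,
        task_scheduler: ms.task_scheduler.TaskScheduler,
        task_id: int,
        measure_candidates: List[ms.MeasureCandidate],
        builder_results: List[ms.builder.BuilderResult],
        runner_results: List[ms.runner.RunnerResult],
    ) -> None:
        ok = sum(1 for r in runner_results if r.error_msg is None)
        print(f"task {task_id}: {ok}/{len(measure_candidates)} candidates ran")
```
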
+ _y1, _x1, _k1 = sch.get_loops(dense1) + _y2, _x2, _k2 = sch.get_loops(dense2) def test_dense_dense(): M, N, K = 128, 128, 128 data_shape = (M, K) weight_shape = (N, K) - relay_mod = tvm.IRModule.from_expr(get_dense_dense(data_shape, weight_shape)) - - # print(relay.transform.InferType()(relay_mod)) - data_np = np.random.randn(*data_shape).astype("float32") weight1_np = np.random.randn(*weight_shape).astype("float32") weight2_np = np.random.randn(*weight_shape).astype("float32") - target = "llvm" params = {"weight1": weight1_np, "weight2": weight2_np} @@ -81,8 +70,7 @@ def schedule_fn(task, sch): return False database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) - - with ApplyHistoryBest(database): + with ms.ApplyHistoryBest(database): with tvm.transform.PassContext( opt_level=3, config={"relay.backend.use_meta_schedule": True}, @@ -90,16 +78,11 @@ def schedule_fn(task, sch): lib = relay.build(relay_mod, target=target, params=params) dev = tvm.device(target, 0) - runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - runtime.set_input("data", data_np) runtime.run() - out = runtime.get_output(0).numpy() - ref = get_ref(data_np, weight1_np, weight2_np) - tvm.testing.assert_allclose(out, ref, atol=1e-4, rtol=1e-4) diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py index 20a977189da5..882655c17f5a 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py @@ -62,9 +62,15 @@ def _sch(decision: int) -> Schedule: def _make_mutator(target: Target) -> Mutator: - mutator = MutateComputeLocation() - mutator.initialize_with_tune_context(TuneContext(mod=add, target=target)) - return mutator + ctx = TuneContext( + mod=add, + target=target, + mutator_probs={ + MutateComputeLocation(): 1.0, + }, + ) + ctx.initialize() + return list(ctx.mutator_probs.keys())[0] def test_mutate_compute_location_add(): diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py index e263114ef60f..42e8ffd678f5 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py @@ -80,9 +80,15 @@ def _sch(decisions: List[List[int]], ann_val: int) -> Schedule: def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator: - mutator = MutateParallel(max_jobs_per_core) - mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) - return mutator + ctx = TuneContext( + mod=matmul, + target=target, + mutator_probs={ + MutateParallel(max_jobs_per_core): 1.0, + }, + ) + ctx.initialize() + return list(ctx.mutator_probs.keys())[0] def test_mutate_parallel_matmul(): diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py index a2e5dcbd1f0a..10bbdb366c8f 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_thread_binding.py @@ -63,9 +63,15 @@ def _sch() -> Schedule: def _make_mutator(target: Target) -> Mutator: - mutator = MutateThreadBinding() - mutator.initialize_with_tune_context(TuneContext(mod=element_wise, target=target)) - return mutator + ctx = TuneContext( + 
mod=element_wise, + target=target, + mutator_probs={ + MutateThreadBinding(): 1.0, + }, + ) + ctx.initialize() + return list(ctx.mutator_probs.keys())[0] def test_mutate_thread_binding(): diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py index 4a3b1f8e943a..47b386447b02 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py @@ -68,9 +68,13 @@ def _sch(decisions: List[List[int]]) -> Schedule: def _make_mutator(target: Target) -> Mutator: - mutator = MutateTileSize() - mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) - return mutator + ctx = TuneContext( + mod=matmul, + target=target, + mutator_probs={MutateTileSize(): 1.0}, + ) + ctx.initialize() + return list(ctx.mutator_probs.keys())[0] def test_mutate_tile_size_matmul(): diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py index 3f3fbcafc0db..dece8a8bc1ec 100644 --- a/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py @@ -85,9 +85,15 @@ def _sch(decisions: List[List[int]]) -> Schedule: def _make_mutator(target: Target) -> Mutator: - mutator = MutateUnroll() - mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) - return mutator + ctx = TuneContext( + mod=matmul, + target=target, + mutator_probs={ + MutateUnroll(): 1.0, + }, + ) + ctx.initialize() + return list(ctx.mutator_probs.keys())[0] def test_mutate_unroll_matmul(): diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index c5b6adb466e2..4300e66aa567 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -155,7 +155,7 @@ def _check_correct(schedule: Schedule): @derived_object class WowSoFancyScheduleRule(PyScheduleRule): - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -172,7 +172,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @derived_object class DoubleScheduleRule(PyScheduleRule): - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -197,7 +197,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @derived_object class ReorderScheduleRule(PyScheduleRule): - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -220,10 +220,11 @@ def test_meta_schedule_post_order_apply(): mod=mod, target=Target("llvm"), task_name="Test Task", + space_generator=PostOrderApply(), sch_rules=[WowSoFancyScheduleRule()], ) - post_order_apply = PostOrderApply() - post_order_apply.initialize_with_tune_context(context) + context.initialize() + post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) assert len(schs) == 1 assert 
not tvm.ir.structural_equal(schs[0].mod, mod) @@ -236,10 +237,11 @@ def test_meta_schedule_post_order_apply_double(): mod=mod, target=Target("llvm"), task_name="Double Rules Task", + space_generator=PostOrderApply(), sch_rules=[DoubleScheduleRule()], ) - post_order_apply = PostOrderApply() - post_order_apply.initialize_with_tune_context(context) + context.initialize() + post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) assert len(schs) == 2 for sch in schs: @@ -253,10 +255,11 @@ def test_meta_schedule_post_order_apply_multiple(): mod=mod, target=Target("llvm"), task_name="Double Rules Task", + space_generator=PostOrderApply(), sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()], ) - post_order_apply = PostOrderApply() - post_order_apply.initialize_with_tune_context(context) + context.initialize() + post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) assert len(schs) == 4 for sch in schs: @@ -270,10 +273,11 @@ def test_meta_schedule_post_order_apply_duplicate_matmul(): mod=mod, target=Target("llvm"), task_name="Duplicate Matmul Task", + space_generator=PostOrderApply(), sch_rules=[WowSoFancyScheduleRule()], ) - post_order_apply = PostOrderApply() - post_order_apply.initialize_with_tune_context(context) + context.initialize() + post_order_apply = context.space_generator with pytest.raises( TVMError, match=r".*TVMError: Check failed: \(block_names_.count\(block->name_hint\) == 0\)" @@ -285,7 +289,7 @@ def test_meta_schedule_post_order_apply_duplicate_matmul(): def test_meta_schedule_post_order_apply_remove_block(): @derived_object class TrinityDouble(PyScheduleRule): - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -307,7 +311,7 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @derived_object class RemoveBlock(PyScheduleRule): - def initialize_with_tune_context(self, context: "TuneContext") -> None: + def _initialize_with_tune_context(self, context: "TuneContext") -> None: pass def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: @@ -341,10 +345,11 @@ def correct_trace(a, b, c, d): mod=mod, target=Target("llvm"), task_name="Remove Block Task", + space_generator=PostOrderApply(), sch_rules=[RemoveBlock(), TrinityDouble()], ) - post_order_apply = PostOrderApply() - post_order_apply.initialize_with_tune_context(context) + context.initialize() + post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) assert len(schs) == 4 for sch in schs: @@ -368,13 +373,12 @@ def test_meta_schedule_custom_search_space(): mod=mod, target=Target("llvm"), task_name="Custom Search Space Task", + space_generator=PostOrderApply(), sch_rules=[], ) - post_order_apply = PostOrderApply() - post_order_apply.initialize_with_tune_context(context) - + context.initialize() + post_order_apply = context.space_generator post_order_apply.generate_design_space(mod) - called = False def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]: @@ -383,7 +387,6 @@ def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]: return [sch] register_func("tvm.meta_schedule.test.custom_search_space", custom_search_space_func) - post_order_apply.generate_design_space(mod) assert called diff --git a/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py 
b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py index d27e3e61084f..906519cd36eb 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py +++ b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py @@ -37,8 +37,7 @@ def _create_context(mod, target) -> TuneContext: ], task_name="test", ) - for rule in ctx.postprocs: - rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py index aa1d219d1c65..e31e912ae4a9 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -39,8 +39,7 @@ def _create_context(mod, target) -> TuneContext: ], task_name="test", ) - for rule in ctx.postprocs: - rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py index 263448aa1be6..c7b6e89727a1 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -37,8 +37,7 @@ def _create_context(mod, target) -> TuneContext: ], task_name="test", ) - for rule in ctx.postprocs: - rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py index bc84fb1ad0b2..51bf2226d3e1 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensorize.py @@ -17,9 +17,8 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm import tvm.tir.tensor_intrin +from tvm.meta_schedule import TuneContext, postproc from tvm.script import tir as T -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule import postproc @tvm.script.ir_module @@ -458,8 +457,7 @@ def _create_context(mod, target, postprocs): postprocs=postprocs, task_name="test", ) - for rule in ctx.postprocs: - rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py index 61bd0e349fcf..d797bc9d154d 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -38,8 +38,7 @@ def _create_context(mod, target) -> TuneContext: ], task_name="test", ) - for rule in ctx.postprocs: - rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py index a1d2bcfcde08..c91f7bfd1dae 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -41,8 +41,7 @@ def _create_context(mod, target) -> TuneContext: ], task_name="test", ) - for rule in ctx.postprocs: - 
rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py index 5a8031220354..7f7f52d1f8a2 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py @@ -33,9 +33,7 @@ def _create_context(mod, target, rule) -> TuneContext: sch_rules=[rule], task_name="test", ) - ctx.space_generator.initialize_with_tune_context(ctx) - for sch_rule in ctx.sch_rules: - sch_rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py index aa7cb09265e9..2cedd2051dc8 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py @@ -68,9 +68,7 @@ def _create_context(mod, target, rule) -> TuneContext: sch_rules=[rule], task_name="test", ) - ctx.space_generator.initialize_with_tune_context(ctx) - for sch_rule in ctx.sch_rules: - sch_rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index e206fcc4502c..5e6690d88e83 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -252,9 +252,7 @@ def _create_context(mod, target, rule): sch_rules=[rule], task_name="test", ) - ctx.space_generator.initialize_with_tune_context(ctx) - for sch_rule in ctx.sch_rules: - sch_rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 47f405842c98..79d53cebe45f 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -16,17 +16,16 @@ # under the License. 
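
Every test fixture in this stretch of the patch collapses to the same shape: build a fully populated `TuneContext` and call `initialize()` once, instead of initializing each sub-component by hand. A condensed sketch of the recurring pattern (the schedule rule chosen here is a placeholder):

```python
from tvm.meta_schedule import TuneContext
from tvm.meta_schedule.schedule_rule import RandomComputeLocation
from tvm.meta_schedule.space_generator import PostOrderApply
from tvm.target import Target


def make_initialized_context(mod) -> TuneContext:
    """One ctx.initialize() replaces the per-component initialization loops."""
    ctx = TuneContext(
        mod=mod,
        target=Target("llvm"),
        space_generator=PostOrderApply(),
        sch_rules=[RandomComputeLocation()],  # placeholder schedule rule
        task_name="test",
    )
    ctx.initialize()  # initializes generator, rules, postprocs and mutators
    return ctx
```
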
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing import te_workload from tvm.meta_schedule.testing.schedule_rule import cross_thread_reduction from tvm.meta_schedule.testing.space_generation import check_trace from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T from tvm.target import Target from tvm.te.operation import create_prim_func -import tvm -from tvm.script import tir as T - @tvm.script.ir_module class Softmax_mn_after_inline: @@ -68,9 +67,7 @@ def _create_context(mod, target, rule) -> TuneContext: sch_rules=[rule], task_name="test", ) - ctx.space_generator.initialize_with_tune_context(ctx) - for sch_rule in ctx.sch_rules: - sch_rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index 43ce9969be84..029dbc52efd1 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -17,18 +17,17 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm from tvm import te +from tvm.meta_schedule import schedule_rule from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.schedule_rule import ( - multi_level_tiling, -) +from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling from tvm.meta_schedule.testing.space_generation import check_trace from tvm.meta_schedule.tune_context import TuneContext -from tvm.meta_schedule import schedule_rule from tvm.script import tir as T -from tvm.te import create_prim_func from tvm.target import Target -from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN, DP4A_INTRIN +from tvm.te import create_prim_func +from tvm.tir.tensor_intrin import DP4A_INTRIN +from tvm.tir.tensor_intrin import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN def _create_context(mod, target, rule) -> TuneContext: @@ -39,9 +38,7 @@ def _create_context(mod, target, rule) -> TuneContext: sch_rules=[rule], task_name="test", ) - ctx.space_generator.initialize_with_tune_context(ctx) - for sch_rule in ctx.sch_rules: - sch_rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py index 85aa80eb3c82..752bf5e04c4e 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -232,9 +232,7 @@ def _create_context(mod, target, rule): sch_rules=[rule], task_name="test", ) - ctx.space_generator.initialize_with_tune_context(ctx) - for sch_rule in ctx.sch_rules: - sch_rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py index 18db006c6ca8..379fb4675aa5 100644 --- 
a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -16,8 +16,8 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.schedule_rule import RandomComputeLocation +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing.space_generation import check_trace from tvm.meta_schedule.tune_context import TuneContext from tvm.script import tir as T @@ -63,9 +63,7 @@ def _create_context(mod, target, rule): sch_rules=[rule], task_name="test", ) - ctx.space_generator.initialize_with_tune_context(ctx) - for sch_rule in ctx.sch_rules: - sch_rule.initialize_with_tune_context(ctx) + ctx.initialize() return ctx diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 4eb8aac5a331..fd8c023b5e4e 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -16,25 +16,13 @@ # under the License. """ Test Meta Schedule SearchStrategy """ # pylint: disable=missing-function-docstring -import sys from typing import List import pytest import tvm import tvm.testing from tvm import meta_schedule as ms -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.runner import RunnerResult -from tvm.meta_schedule.search_strategy import ( - EvolutionarySearch, - ReplayFunc, - ReplayTrace, - SearchStrategy, -) -from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.task_scheduler import RoundRobin -from tvm.meta_schedule.testing import DummyMutator -from tvm.meta_schedule.testing.utils import DummyDatabase +from tvm.meta_schedule.testing.dummy_object import DummyMutator from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -81,34 +69,51 @@ def _schedule_matmul(sch: Schedule): sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) -@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace]) -def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name +@pytest.mark.parametrize( + "TestClass", + [ + ms.search_strategy.ReplayFunc, + ms.search_strategy.ReplayTrace, + ], +) +def test_meta_schedule_replay_func( + TestClass: ms.search_strategy.SearchStrategy, +): # pylint: disable = invalid-name num_trials_per_iter = 7 max_trials_per_task = 20 - strategy = TestClass( - num_trials_per_iter=num_trials_per_iter, max_trials_per_task=max_trials_per_task + context = ms.TuneContext( + mod=Matmul, + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=TestClass( + num_trials_per_iter=num_trials_per_iter, max_trials_per_task=max_trials_per_task + ), ) - context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) - context.space_generator.initialize_with_tune_context(context) + context.initialize() + strategy = context.search_strategy spaces = context.space_generator.generate_design_space(context.mod) - - strategy.initialize_with_tune_context(context) strategy.pre_tuning(spaces) - (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + (correct_sch,) = ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul).generate_design_space( + Matmul 
+ ) num_trials_each_iter: List[int] = [] candidates = strategy.generate_measure_candidates() while candidates is not None: num_trials_each_iter.append(len(candidates)) - runner_results: List[RunnerResult] = [] + runner_results: List[ms.runner.RunnerResult] = [] for candidate in candidates: _is_trace_equal( candidate.sch, correct_sch, - remove_decisions=(isinstance(strategy, ReplayTrace)), + remove_decisions=(isinstance(strategy, ms.search_strategy.ReplayTrace)), + ) + runner_results.append( + ms.runner.RunnerResult( + run_secs=[0.11, 0.41, 0.54], + error_msg=None, + ) ) - runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(context, candidates, runner_results) + strategy.notify_runner_results(candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert num_trials_each_iter == [7, 7, 6] @@ -123,14 +128,16 @@ def _schedule_matmul_small(sch: Schedule): num_trials_per_iter = 10 max_trials_per_task = 2000 - (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + (correct_sch,) = ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul).generate_design_space( + Matmul + ) - context = TuneContext( + context = ms.TuneContext( mod=Matmul, - space_generator=ScheduleFn( + space_generator=ms.space_generator.ScheduleFn( sch_fn=_schedule_matmul_small, ), - search_strategy=EvolutionarySearch( + search_strategy=ms.search_strategy.EvolutionarySearch( num_trials_per_iter=num_trials_per_iter, max_trials_per_task=max_trials_per_task, population_size=5, @@ -151,22 +158,27 @@ def _schedule_matmul_small(sch: Schedule): strategy = context.search_strategy strategy.pre_tuning( context.space_generator.generate_design_space(context.mod), - database=DummyDatabase(), + database=ms.database.MemoryDatabase(), cost_model=ms.cost_model.RandomModel(), ) num_trials_each_iter: List[int] = [] candidates = strategy.generate_measure_candidates() while candidates is not None: num_trials_each_iter.append(len(candidates)) - runner_results: List[RunnerResult] = [] + runner_results: List[ms.runner.RunnerResult] = [] for candidate in candidates: _is_trace_equal( candidate.sch, correct_sch, - remove_decisions=(isinstance(strategy, ReplayTrace)), + remove_decisions=(isinstance(strategy, ms.search_strategy.ReplayTrace)), ) - runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(context, candidates, runner_results) + runner_results.append( + ms.runner.RunnerResult( + run_secs=[0.11, 0.41, 0.54], + error_msg=None, + ) + ) + strategy.notify_runner_results(candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert sum(num_trials_each_iter) == 25 @@ -177,14 +189,16 @@ def test_meta_schedule_evolutionary_search_early_stop(): # pylint: disable = in def _schedule_matmul_empty(sch: Schedule): return sch - (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + (correct_sch,) = ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul).generate_design_space( + Matmul + ) num_trials_per_iter = 10 max_trials_per_task = 100 - context = TuneContext( + context = ms.TuneContext( mod=Matmul, - search_strategy=EvolutionarySearch( + search_strategy=ms.search_strategy.EvolutionarySearch( num_trials_per_iter=num_trials_per_iter, max_trials_per_task=max_trials_per_task, population_size=5, @@ -195,7 +209,7 @@ def _schedule_matmul_empty(sch: Schedule): genetic_max_fail_count=10, eps_greedy=0.9, 
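
The loop these tests keep repeating — generate candidates, feed back canned timings, repeat until the strategy stops — can be factored as below. This is a sketch only; `exhaust_strategy` and the fixed timings are illustrative, not part of this patch:

```python
from tvm import meta_schedule as ms


def exhaust_strategy(strategy, fake_secs=(0.1, 0.2)):
    """Drive a search strategy to completion with canned RunnerResults."""
    total = 0
    candidates = strategy.generate_measure_candidates()
    while candidates is not None:
        results = [
            ms.runner.RunnerResult(run_secs=list(fake_secs), error_msg=None)
            for _ in candidates
        ]
        # New two-argument signature: the TuneContext is no longer passed.
        strategy.notify_runner_results(candidates, results)
        total += len(candidates)
        candidates = strategy.generate_measure_candidates()
    return total
```
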
), - space_generator=ScheduleFn( + space_generator=ms.space_generator.ScheduleFn( sch_fn=_schedule_matmul_empty, ), mutator_probs={ @@ -208,22 +222,27 @@ def _schedule_matmul_empty(sch: Schedule): strategy = context.search_strategy strategy.pre_tuning( context.space_generator.generate_design_space(context.mod), - database=DummyDatabase(), + database=ms.database.MemoryDatabase(), cost_model=ms.cost_model.RandomModel(), ) num_trials_each_iter: List[int] = [] candidates = strategy.generate_measure_candidates() while candidates is not None: num_trials_each_iter.append(len(candidates)) - runner_results: List[RunnerResult] = [] + runner_results: List[ms.runner.RunnerResult] = [] for candidate in candidates: _is_trace_equal( candidate.sch, correct_sch, - remove_decisions=(isinstance(strategy, ReplayTrace)), + remove_decisions=(isinstance(strategy, ms.search_strategy.ReplayTrace)), + ) + runner_results.append( + ms.runner.RunnerResult( + run_secs=[0.11, 0.41, 0.54], + error_msg=None, + ), ) - runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) - strategy.notify_runner_results(context, candidates, runner_results) + strategy.notify_runner_results(candidates, runner_results) candidates = strategy.generate_measure_candidates() strategy.post_tuning() assert num_trials_each_iter == [1, 0, 0, 0, 0] diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 84104c8bcff2..9201fe16e849 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -17,21 +17,23 @@ """ Test Meta Schedule SpaceGenerator """ # pylint: disable=missing-function-docstring -import sys import math +import sys import pytest - import tvm import tvm.testing -from tvm.meta_schedule.utils import derived_object -from tvm.meta_schedule.space_generator import ScheduleFn, PySpaceGenerator, SpaceGeneratorUnion -from tvm.meta_schedule.tune_context import TuneContext from tvm._ffi.base import TVMError +from tvm.meta_schedule.space_generator import ( + PySpaceGenerator, + ScheduleFn, + SpaceGeneratorUnion, +) +from tvm.meta_schedule.tune_context import TuneContext +from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.tir.schedule import Schedule - # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off @@ -98,7 +100,7 @@ class TestPySpaceGenerator(PySpaceGenerator): TVMError, match="PySpaceGenerator's InitializeWithTuneContext method not implemented!" 
): generator = TestPySpaceGenerator() -        generator.initialize_with_tune_context(TuneContext()) +        generator._initialize_with_tune_context(TuneContext()) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index f24dc5fbbc1f..fc2497f05303 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -23,16 +23,12 @@ import pytest import tvm import tvm.testing +from tvm import meta_schedule as ms from tvm._ffi.base import TVMError -from tvm.meta_schedule import TuneContext, measure_callback -from tvm.meta_schedule.search_strategy import ReplayTrace -from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.task_scheduler import GradientBased, PyTaskScheduler, RoundRobin -from tvm.meta_schedule.testing import DummyBuilder, DummyDatabase, DummyRunner -from tvm.meta_schedule.utils import derived_object +from tvm.meta_schedule.testing.dummy_object import DummyBuilder, DummyRunner from tvm.script import tir as T from tvm.tir import Schedule # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -123,8 +125,8 @@ def _schedule_batch_matmul(sch: Schedule): sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, t_0, t_1) -@derived_object -class MyTaskScheduler(PyTaskScheduler): +@ms.derived_object +class MyTaskScheduler(ms.task_scheduler.PyTaskScheduler): done: Set = set() def next_task_id(self) -> int: @@ -153,14 +155,17 @@ def next_task_id(self) -> int: def test_meta_schedule_task_scheduler_single(): num_trials_per_iter = 3 max_trials_per_task = 10 -    database = DummyDatabase() -    round_robin = RoundRobin( +    database = ms.database.MemoryDatabase() +    round_robin = ms.task_scheduler.RoundRobin( [ -            TuneContext( +            ms.TuneContext( MatmulModule, target=tvm.target.Target("llvm"), -            space_generator=ScheduleFn(sch_fn=_schedule_matmul), -            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), +            space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), +            search_strategy=ms.search_strategy.ReplayTrace( +                num_trials_per_iter, +                max_trials_per_task, +            ), task_name="Test", rand_state=42, ) @@ -169,7 +174,7 @@ def test_meta_schedule_task_scheduler_single(): builder=DummyBuilder(), runner=DummyRunner(), database=database, -        measure_callbacks=[measure_callback.AddToDatabase()], +        measure_callbacks=[ms.measure_callback.AddToDatabase()], max_trials=max_trials_per_task, ) round_robin.tune() @@ -180,39 +185,48 @@ def test_meta_schedule_task_scheduler_multiple(): num_trials_per_iter = 6 max_trials_per_task = 101 tasks = [ -        TuneContext( +        ms.TuneContext( MatmulModule, target=tvm.target.Target("llvm"), -            space_generator=ScheduleFn(sch_fn=_schedule_matmul), -            search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), +            space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), +            search_strategy=ms.search_strategy.ReplayTrace( +                num_trials_per_iter, +                max_trials_per_task, +            ), task_name="Matmul", rand_state=42, ), -        TuneContext( +        ms.TuneContext( MatmulReluModule,
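
`ms.database.MemoryDatabase` replaces the test-only `DummyDatabase` throughout; it is a complete in-memory `Database`, so it can also be exercised standalone. A small sketch (the `mod` argument stands for any IRModule under test, and the helper name is illustrative):

```python
from tvm import meta_schedule as ms
from tvm.target import Target
from tvm.tir import Schedule


def record_and_fetch_best(mod):
    """Commit one tuning record into the in-memory DB and read it back."""
    db = ms.database.MemoryDatabase()
    workload = db.commit_workload(mod)
    db.commit_tuning_record(
        ms.database.TuningRecord(
            trace=Schedule(mod).trace,
            workload=workload,
            run_secs=[1.0],
            target=Target("llvm"),
            args_info=[],
        )
    )
    return db.get_top_k(workload, 1)
```
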
target=tvm.target.Target("llvm"), - space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ms.search_strategy.ReplayTrace( + num_trials_per_iter, + max_trials_per_task, + ), task_name="MatmulRelu", rand_state=0xDEADBEEF, ), - TuneContext( + ms.TuneContext( BatchMatmulModule, target=tvm.target.Target("llvm"), - space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_batch_matmul), + search_strategy=ms.search_strategy.ReplayTrace( + num_trials_per_iter, + max_trials_per_task, + ), task_name="BatchMatmul", rand_state=0x114514, ), ] - database = DummyDatabase() - round_robin = RoundRobin( + database = ms.database.MemoryDatabase() + round_robin = ms.task_scheduler.RoundRobin( tasks, [1.0, 1.0, 1.0], builder=DummyBuilder(), runner=DummyRunner(), database=database, - measure_callbacks=[measure_callback.AddToDatabase()], + measure_callbacks=[ms.measure_callback.AddToDatabase()], max_trials=max_trials_per_task * len(tasks), ) round_robin.tune() @@ -230,8 +244,8 @@ def test_meta_schedule_task_scheduler_multiple(): def test_meta_schedule_task_scheduler_NIE(): # pylint: disable=invalid-name - @derived_object - class NIETaskScheduler(PyTaskScheduler): + @ms.derived_object + class NIETaskScheduler(ms.task_scheduler.PyTaskScheduler): pass with pytest.raises(TVMError, match="PyTaskScheduler's NextTaskId method not implemented!"): @@ -239,21 +253,21 @@ class NIETaskScheduler(PyTaskScheduler): tasks=[], builder=DummyBuilder(), runner=DummyRunner(), - database=DummyDatabase(), + database=ms.database.MemoryDatabase(), max_trials=1, ) scheduler.next_task_id() def test_meta_schedule_task_scheduler_avoid_cyclic(): # pylint: disable=invalid-name - database = DummyDatabase() + database = ms.database.MemoryDatabase() scheduler = MyTaskScheduler( [], builder=DummyBuilder(), runner=DummyRunner(), database=database, measure_callbacks=[ - measure_callback.AddToDatabase(), + ms.measure_callback.AddToDatabase(), ], max_trials=10, ) @@ -266,40 +280,47 @@ def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: d num_trials_per_iter = 6 max_trials_per_task = 101 tasks = [ - TuneContext( + ms.TuneContext( MatmulModule, target=tvm.target.Target("llvm"), - space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ms.search_strategy.ReplayTrace( + num_trials_per_iter, + max_trials_per_task, + ), task_name="Matmul", rand_state=42, ), - TuneContext( + ms.TuneContext( MatmulReluModule, target=tvm.target.Target("llvm"), - space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ms.search_strategy.ReplayTrace( + num_trials_per_iter, + max_trials_per_task, + ), task_name="MatmulRelu", rand_state=0xDEADBEEF, ), - TuneContext( + ms.TuneContext( BatchMatmulModule, target=tvm.target.Target("llvm"), - space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + 
space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_batch_matmul), + search_strategy=ms.search_strategy.ReplayTrace( + num_trials_per_iter, + max_trials_per_task, + ), task_name="BatchMatmul", rand_state=0x114514, ), ] - database = DummyDatabase() + database = ms.database.MemoryDatabase() scheduler = MyTaskScheduler( tasks, builder=DummyBuilder(), runner=DummyRunner(), database=database, - measure_callbacks=[ - measure_callback.AddToDatabase(), - ], + measure_callbacks=[ms.measure_callback.AddToDatabase()], max_trials=max_trials_per_task * len(tasks), ) scheduler.tune() @@ -320,39 +341,48 @@ def test_meta_schedule_task_scheduler_multiple_gradient_based(): num_trials_per_iter = 6 max_trials_per_task = 101 tasks = [ - TuneContext( + ms.TuneContext( MatmulModule, target=tvm.target.Target("llvm"), - space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ms.search_strategy.ReplayTrace( + num_trials_per_iter, + max_trials_per_task, + ), task_name="Matmul", rand_state=42, ), - TuneContext( + ms.TuneContext( MatmulReluModule, target=tvm.target.Target("llvm"), - space_generator=ScheduleFn(sch_fn=_schedule_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ms.search_strategy.ReplayTrace( + num_trials_per_iter, + max_trials_per_task, + ), task_name="MatmulRelu", rand_state=0xDEADBEEF, ), - TuneContext( + ms.TuneContext( BatchMatmulModule, target=tvm.target.Target("llvm"), - space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), - search_strategy=ReplayTrace(num_trials_per_iter, max_trials_per_task), + space_generator=ms.space_generator.ScheduleFn(sch_fn=_schedule_batch_matmul), + search_strategy=ms.search_strategy.ReplayTrace( + num_trials_per_iter, + max_trials_per_task, + ), task_name="BatchMatmul", rand_state=0x114514, ), ] - database = DummyDatabase() - gradient_based = GradientBased( + database = ms.database.MemoryDatabase() + gradient_based = ms.task_scheduler.GradientBased( tasks, task_weights=[1.0, 1.0, 1.0], builder=DummyBuilder(), runner=DummyRunner(), database=database, - measure_callbacks=[measure_callback.AddToDatabase()], + measure_callbacks=[ms.measure_callback.AddToDatabase()], seed=0x20220214, max_trials=max_trials_per_task * len(tasks), ) diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py index e0883dbd227e..c2baf8d2b921 100644 --- a/tests/python/unittest/test_meta_schedule_tune_relay.py +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -23,17 +23,13 @@ import numpy as np # type: ignore import pytest import tvm +from tvm import meta_schedule as ms from tvm import relay from tvm._ffi import register_func from tvm.contrib import graph_executor from tvm.ir import IRModule -from tvm.meta_schedule import ApplyHistoryBest, TuneConfig -from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.relay_integration import extract_task_from_relay -from tvm.meta_schedule.testing import apply_fixed_schedules from tvm.meta_schedule.testing.relay_workload import get_network -from tvm.meta_schedule.tune import tune_extracted_tasks, tune_relay -from tvm.meta_schedule.utils import derived_object +from tvm.meta_schedule.testing.utils import apply_fixed_schedules from 
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index e0883dbd227e..c2baf8d2b921 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -23,17 +23,13 @@
 import numpy as np  # type: ignore
 import pytest
 import tvm
+from tvm import meta_schedule as ms
 from tvm import relay
 from tvm._ffi import register_func
 from tvm.contrib import graph_executor
 from tvm.ir import IRModule
-from tvm.meta_schedule import ApplyHistoryBest, TuneConfig
-from tvm.meta_schedule.database import JSONDatabase, PyDatabase, TuningRecord, Workload
-from tvm.meta_schedule.relay_integration import extract_task_from_relay
-from tvm.meta_schedule.testing import apply_fixed_schedules
 from tvm.meta_schedule.testing.relay_workload import get_network
-from tvm.meta_schedule.tune import tune_extracted_tasks, tune_relay
-from tvm.meta_schedule.utils import derived_object
+from tvm.meta_schedule.testing.utils import apply_fixed_schedules
 from tvm.script import tir as T
 from tvm.target.target import Target
 from tvm.tir.schedule import BlockRV, Schedule
@@ -142,11 +138,11 @@ def test_meta_schedule_tune_relay(
     mod, params, (input_name, _, _) = get_network(name=model_name, input_shape=input_shape)
     target = Target(target)
     with tempfile.TemporaryDirectory() as work_dir:
-        rt_mod1: tvm.runtime.Module = tune_relay(
+        rt_mod1: tvm.runtime.Module = ms.tune_relay(
            mod=mod,
            params=params,
            target=target,
-            config=TuneConfig(
+            config=ms.TuneConfig(
                strategy="evolutionary",
                num_trials_per_iter=32,
                max_trials_per_task=20000,
@@ -156,7 +152,7 @@ def test_meta_schedule_tune_relay(
                 },
             ),
             work_dir=work_dir,
-            database=JSONDatabase(
+            database=ms.database.JSONDatabase(
                 osp.join(work_dir, "workload.json"),
                 osp.join(work_dir, "records.json"),
             ),
@@ -178,14 +174,14 @@ def get_output(data, lib):
 
 
 def test_meta_schedule_te2primfunc_argument_order():
-    @derived_object
-    class TestDummyDatabase(PyDatabase):
+    @ms.derived_object
+    class TestDummyDatabase(ms.database.PyDatabase):
         def __init__(self):
             super().__init__()
             self.records = []
             self.workload_reg = []
 
-        def has_workload(self, mod: IRModule) -> Workload:
+        def has_workload(self, mod: IRModule) -> ms.database.Workload:
             for workload in self.workload_reg:
                 if tvm.ir.structural_equal(workload.mod, mod):
                     return True
@@ -195,18 +191,22 @@ def has_workload(self, mod: IRModule) -> Workload:
                 + " Incorrect TIR was generated from TE subgraph."
             )
 
-        def commit_tuning_record(self, record: TuningRecord) -> None:
+        def commit_tuning_record(self, record: ms.database.TuningRecord) -> None:
             self.records.append(record)
 
-        def commit_workload(self, mod: IRModule) -> Workload:
+        def commit_workload(self, mod: IRModule) -> ms.database.Workload:
             for workload in self.workload_reg:
                 if tvm.ir.structural_equal(workload.mod, mod):
                     return workload
-            workload = Workload(mod)
+            workload = ms.database.Workload(mod)
             self.workload_reg.append(workload)
             return workload
 
-        def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
+        def get_top_k(
+            self,
+            workload: ms.database.Workload,
+            top_k: int,
+        ) -> List[ms.database.TuningRecord]:
             return list(
                 filter(
                     lambda x: x.workload == workload,
@@ -250,7 +250,7 @@ def print_results(self) -> None:
     database.commit_workload(tvmgen_default_fused_layout_transform_1)
     database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc)
 
-    with ApplyHistoryBest(database):
+    with ms.ApplyHistoryBest(database):
         with tvm.transform.PassContext(
             opt_level=3,
             config={"relay.backend.use_meta_schedule": True},
@@ -300,12 +300,11 @@ def test_meta_schedule_relay_lowering():
     data = tvm.nd.array(data_sample, dev)
 
     with tempfile.TemporaryDirectory() as work_dir:
-        database = JSONDatabase(
+        database = ms.database.JSONDatabase(
             osp.join(work_dir, "workload.json"), osp.join(work_dir, "records.json")
         )
-
         database.commit_tuning_record(
-            TuningRecord(
+            ms.database.TuningRecord(
                 Trace([], {}),
                 database.commit_workload(tvmgen_default_fused_nn_contrib_conv2d_NCHWc),
                 [0.0],
@@ -313,8 +312,7 @@ def test_meta_schedule_relay_lowering():
                 args_info=[],
             )
         )
-
-        with ApplyHistoryBest(database):
+        with ms.ApplyHistoryBest(database):
             with tvm.transform.PassContext(
                 opt_level=3,
                 config={"relay.backend.use_meta_schedule": True},
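
Note: after the rename, the headline entry point in this file is `ms.tune_relay`, as the first hunk shows. Stripped to its essentials, the call looks roughly like the sketch below (`mod`/`params` stand for any Relay workload; the explicit `JSONDatabase` mirrors what the test passes):

    import os.path as osp
    import tempfile
    from tvm import meta_schedule as ms
    from tvm.target import Target

    with tempfile.TemporaryDirectory() as work_dir:
        rt_mod = ms.tune_relay(
            mod=mod,
            params=params,
            target=Target("llvm"),
            config=ms.TuneConfig(
                strategy="evolutionary",
                num_trials_per_iter=32,
                max_trials_per_task=2000,
            ),
            work_dir=work_dir,
            database=ms.database.JSONDatabase(
                osp.join(work_dir, "workload.json"),
                osp.join(work_dir, "records.json"),
            ),
        )
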
@@ -435,8 +433,7 @@ def manual_tir_common(do_tune=False):
     params = {"weight": weight_np, "bias": bias_np}
 
     if do_tune:
-        extracted_tasks = extract_task_from_relay(relay_mod, target, params)
-
+        extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params)
         # Filter out tasks that we don't intend to schedule / tune with TIR.
         tune_tasks = list(
             filter(
@@ -444,7 +441,7 @@ def manual_tir_common(do_tune=False):
                 extracted_tasks,
             )
         )
-        config = TuneConfig(
+        config = ms.TuneConfig(
             strategy="replay_trace",
             num_trials_per_iter=64,
             max_trials_per_task=20000,
@@ -454,7 +451,7 @@ def manual_tir_common(do_tune=False):
         with tempfile.TemporaryDirectory() as work_dir:
             # postprocs=lambda: [] is important to prevent default post processors from
             # tampering with the manual schedule.
-            database = tune_extracted_tasks(
+            database = ms.tune_extracted_tasks(
                 tune_tasks,
                 config,
                 work_dir=work_dir,
@@ -480,7 +477,7 @@ def schedule_fn(task, sch):
 
     database = apply_fixed_schedules(relay_mod, target, params, schedule_fn)
 
-    with ApplyHistoryBest(database):
+    with ms.ApplyHistoryBest(database):
         with tvm.transform.PassContext(
             opt_level=3,
             config={"relay.backend.use_meta_schedule": True},
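
Note: the `do_tune` branch above is the pattern worth remembering for task-level tuning: extract tasks, tune only the ones you care about, then build under `ApplyHistoryBest`. A condensed sketch (assuming a Relay module `relay_mod` with `params` and a `target`, as in `manual_tir_common`):

    import tempfile
    import tvm
    from tvm import meta_schedule as ms
    from tvm import relay

    extracted_tasks = ms.extract_task_from_relay(relay_mod, target, params)
    config = ms.TuneConfig(
        strategy="replay_trace",
        num_trials_per_iter=64,
        max_trials_per_task=20000,
    )
    with tempfile.TemporaryDirectory() as work_dir:
        # postprocs=lambda: [] keeps the default postprocessors from
        # tampering with manually written schedules
        database = ms.tune_extracted_tasks(
            extracted_tasks,
            config,
            work_dir=work_dir,
            postprocs=lambda: [],
        )
        # compile with the tuned records applied
        with ms.ApplyHistoryBest(database):
            with tvm.transform.PassContext(
                opt_level=3,
                config={"relay.backend.use_meta_schedule": True},
            ):
                lib = relay.build(relay_mod, target=target, params=params)
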