From aa0178b77bd7940525e7441c925296a33b6273ff Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 15 Feb 2022 01:02:14 -0800 Subject: [PATCH] Fix implementation & add tests. Fix rebase. Continue to fix rebase. Fix rebase. --- include/tvm/meta_schedule/database.h | 1 - include/tvm/meta_schedule/task_scheduler.h | 32 ++- include/tvm/meta_schedule/tune_context.h | 2 +- include/tvm/tir/schedule/schedule.h | 8 +- include/tvm/tir/stmt.h | 87 -------- include/tvm/tir/transform.h | 14 -- .../tvm/auto_scheduler/relay_integration.py | 12 +- python/tvm/auto_scheduler/search_task.py | 3 +- .../tvm/auto_scheduler/workload_registry.py | 5 +- .../meta_schedule/builder/local_builder.py | 18 +- .../meta_schedule/cost_model/cost_model.py | 8 +- python/tvm/meta_schedule/cost_model/metric.py | 9 +- .../meta_schedule/cost_model/random_model.py | 12 +- .../random_feature_extractor.py | 2 +- .../tvm/meta_schedule/runner/local_runner.py | 2 +- .../space_generator/post_order_apply.py | 2 +- .../task_scheduler/gradient_based.py | 60 +++--- .../task_scheduler/round_robin.py | 20 +- python/tvm/meta_schedule/utils.py | 67 +++--- python/tvm/relay/build_module.py | 8 +- python/tvm/tir/function.py | 24 +-- python/tvm/tir/schedule/schedule.py | 110 +++++----- python/tvm/tir/transform/transform.py | 12 -- src/arith/iter_affine_map.cc | 39 +--- .../search_strategy/replay_trace.cc | 1 - .../task_scheduler/gradient_based.cc | 82 +++++--- src/meta_schedule/utils.h | 5 +- src/tir/schedule/analysis.h | 59 ++---- src/tir/schedule/analysis/analysis.cc | 16 -- src/tir/schedule/concrete_schedule.h | 4 +- src/tir/schedule/instruction_traits.h | 2 +- src/tir/schedule/primitive.h | 2 +- src/tir/schedule/primitive/annotate.cc | 6 +- src/tir/schedule/primitive/compute_at.cc | 88 ++------ src/tir/schedule/primitive/reduction.cc | 10 +- src/tir/schedule/primitive/sampling.cc | 21 +- src/tir/schedule/state.cc | 6 +- .../unittest/test_meta_schedule_byoc.py | 198 ------------------ 
.../unittest/test_meta_schedule_cost_model.py | 15 +- .../test_meta_schedule_post_order_apply.py | 1 - .../test_meta_schedule_space_generator.py | 3 - .../test_meta_schedule_task_scheduler.py | 55 ++++- .../unittest/test_tir_schedule_compute_at.py | 28 +-- .../unittest/test_tir_schedule_sampling.py | 2 +- 44 files changed, 393 insertions(+), 768 deletions(-) delete mode 100644 tests/python/unittest/test_meta_schedule_byoc.py diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 307ec309c009..f07d8e136644 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -237,7 +237,6 @@ class PyDatabaseNode : public DatabaseNode { // PackedFuncs are all not visited, because the reflection system doesn't take care of them, // so it cannot be accessible on the python side. If there is such need from the future, // we can then add corresponding accessor methods to help access on python. - // // `f_has_workload` is not visited // `f_commit_workload` is not visited // `f_commit_tuning_record` is not visited diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 654a8ef20c68..bd6019ca65b6 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -68,7 +68,7 @@ namespace meta_schedule { class TaskSchedulerNode : public runtime::Object { public: /*! \brief The function type of the objective function. */ - using FObjectiveFunc = TypedPackedFunc)>; + using FObjectiveFunc = TypedPackedFunc)>; /*! \brief The function type of the tag genration function. */ using FTagGenerationFunc = TypedPackedFunc; @@ -264,6 +264,36 @@ class TaskScheduler : public runtime::ObjectRef { Database database, // Optional cost_model, // Optional> measure_callbacks); + /*! + * \brief Create a task scheduler that fetches tasks in a gradient based fashion. + * \param tasks The tasks to be tuned. 
+ * \param builder The builder of the scheduler. + * \param runner The runner of the scheduler. + * \param database The database of the scheduler. + * \param alpha The parameter alpha to control gradient computation. + * \param beta The parameter beta to control gradient computation. + * \param backward_window_size The parameter to control backward window size. + * \param seed The random seed. + * \param task_weights The weights of each task. + * \param objective_func_name The name of objective function for gradient optimization. + * \param tag_generation_func_name The name of function to generate similarity tag for workloads. + * \param cost_model The cost model of the scheduler. + * \param measure_callbacks The measure callbacks of the scheduler. + * \return The task scheduler created. + */ + TVM_DLL static TaskScheduler GradientBased(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + double alpha, // + double beta, // + int backward_window_size, // + support::LinearCongruentialEngine::TRandState seed, // + Array task_weights, // + String objective_func_name, // + String tag_generation_func_name, // + Optional cost_model, // + Optional> measure_callbacks); /*! * \brief Create a task scheduler with customized methods on the python-side. * \param tasks The tasks to be tuned. diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index ff3a14c076e4..7a7599b0a4f8 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -53,7 +53,7 @@ class TuneContextNode : public runtime::Object { /*! \brief The probability of using certain mutator. */ Map mutator_probs; /*! \brief The name of the tuning task. */ - String task_name; + Optional task_name; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. 
*/ diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 49555e8e37f1..89871f0d6352 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -500,14 +500,14 @@ class ScheduleNode : public runtime::Object { /******** Schedule: Annotation ********/ /*! * \brief Annotate a loop with a key value pair - * \param loop The loop to be annotated + * \param loop_rv The loop to be annotated * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0; /*! * \brief Annotate a block with a key value pair - * \param loop The block to be annotated + * \param block_rv The block to be annotated * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ @@ -515,13 +515,13 @@ class ScheduleNode : public runtime::Object { const ObjectRef& ann_val) = 0; /*! * \brief Unannotate a loop's annotation with key ann_key - * \param loop The loop to be unannotated + * \param loop_rv The loop to be unannotated * \param ann_key The annotation key */ virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0; /*! * \brief Unannotate a block's annotation with key ann_key - * \param loop The block to be unannotated + * \param block_rv The block to be unannotated * \param ann_key The annotation key */ virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4074f5203857..7b07146f446c 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1442,93 +1442,6 @@ constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline */ constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order"; -/*! 
- * \brief Mark that the block need to add predicate for block var bounds during lowering - */ -constexpr const char* require_block_var_bound_predicate = "require_bound_predicate"; - -/*! - * \brief Mark that the loop should be further skip and bound to environment threads to enable - * cooperative fetching. - */ -constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"; - -/*! - * \brief Mark that the block should be further rewritten using tensorization. - */ -constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"; - -/*! \brief Mark that tensor core is enabled in the PrimExpr */ -constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled"; - -/*! \brief The allowed range of thread extent in thread bindings */ -constexpr const char* meta_schedule_thread_extent_low_inclusive = - "meta_schedule.thread_extent_low_inclusive"; - -/*! \brief The allowed range of thread extent in thread bindings */ -constexpr const char* meta_schedule_thread_extent_high_inclusive = - "meta_schedule.thread_extent_high_inclusive"; - -/*! - * \brief Mark a block as generated by cache_read or cache_write block. - * 0 means cache_read; 1 means cache_write. - * \sa meta_schedule_cache_type_read - * \sa meta_schedule_cache_type_write - */ -constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type"; - -/*! \sa meta_schedule_cache_type */ -constexpr const int meta_schedule_cache_type_read = 0; - -/*! \sa meta_schedule_cache_type */ -constexpr const int meta_schedule_cache_type_write = 1; - -/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ -constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; - -/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */ -constexpr const char* meta_schedule_random_compute_producer = - "meta_schedule.random_compute_producer"; - -/*! 
\brief Mark auto-parallel setting on the block. */ -constexpr const char* meta_schedule_parallel = "meta_schedule.parallel"; - -/*! \brief Mark auto-vectorize setting on the block. */ -constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize"; - -/*! \brief Mark auto-unroll setting on the block. */ -constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit"; - -/*! \brief Mark auto-unroll setting on the block. */ -constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"; - -/*! \brief Pragma: auto-unroll, max_step */ -constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step"; - -/*! \brief Pragma: unroll explicit */ -constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit"; - -/*! \brief Mark the scope of the software pipeline */ -constexpr const char* software_pipeline_scope = "software_pipeline_scope"; - -/*! \brief Mark the stage of a statement in the software pipeline */ -constexpr const char* software_pipeline_stage = "software_pipeline_stage"; - -/*! \brief Mark the order of a statement in the software pipeline */ -constexpr const char* software_pipeline_order = "software_pipeline_order"; - -/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify - * the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the - * prologue, the body, and the epilogue of the software pipeline. - */ -constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline_stage"; - -/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify - * the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the - * prologue, the body, and the epilogue of the software pipeline. - */ -constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order"; - /*! 
* \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 6b8edf29bf2c..4df54f0208d3 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -383,20 +383,6 @@ TVM_DLL Pass LowerInitBlock(); */ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); -/*! - * \brief Narrow the extents of some loops by checking whether some constraints in the block iter - * bound predicates can be directly applied on the loops. - * \return The pass. - */ -TVM_DLL Pass ApplyBlockBoundPredicate(); - -/*! - * \brief Narrow the extents of some loops by checking whether some constraints in the block iter - * bound predicates can be directly applied on the loops. - * \return The pass. - */ -TVM_DLL Pass ApplyBlockBoundPredicate(); - /*! * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the * corresponding iter_values in BlockRealize, for opaque blocks by removing all diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 7ff1840c9123..e9bb68ad8e93 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -329,9 +329,9 @@ def auto_schedule_topi(func_name, outs): """ # pylint: disable=import-outside-toplevel - from tvm.auto_scheduler.measure import ( # lazily import to avoid recursive dependency + from tvm.auto_scheduler.measure import ( prepare_input_map, - ) + ) # lazily import to avoid recursive dependency io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs) if not io_tensors: # The compute includes dynamic shapes which are not supported yet. 
@@ -482,10 +482,4 @@ def is_auto_scheduler_enabled(): enabled: bool Whether the auto-scheduler is enabled """ - return PassContext.current().config.get( - "relay.backend.use_auto_scheduler", - False, - ) or PassContext.current().config.get( - "relay.backend.use_meta_schedule", - False, - ) + return PassContext.current().config.get("relay.backend.use_auto_scheduler", False) diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index 0e9c4abebbe1..f1156998bdac 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -543,8 +543,7 @@ def print_best(self, log_file, print_mode="schedule"): code: str The best schedule code in python API or CUDA source code """ - inp, res = load_best_record(log_file, self.workload_key) - print("Best codes (ms):", [float(c) * 1000.0 for c in res.costs]) + inp, _ = load_best_record(log_file, self.workload_key) if inp is None: raise RuntimeError( "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 75702b0a21af..885eb0d1d0f8 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -194,10 +194,7 @@ def workload_key_to_tensors(workload_key): assert callable(value) args = deserialize_args(workload[1:]) - result = value(*args) - if isinstance(result, tuple): - result = list(result) - return result + return value(*args) def serialize_workload_registry_entry(workload_key): diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index ca38424957db..da7bb515f112 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -22,28 +22,13 @@ from tvm._ffi import register_func from tvm.ir import IRModule -from tvm.runtime import NDArray -from 
tvm.runtime import Module, load_param_dict, save_param_dict +from tvm.runtime import Module, NDArray, load_param_dict, save_param_dict from tvm.target import Target from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind from ..utils import cpu_count, get_global_func_with_default_on_worker from .builder import BuilderInput, BuilderResult, PyBuilder -logger = logging.getLogger(__name__) - - -def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]: - if params is None: - return None - return save_param_dict(params) - - -def _deserialize_params(params: Optional[bytearray]) -> Optional[Dict[str, NDArray]]: - if params is None: - return None - return load_param_dict(params) - logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -142,6 +127,7 @@ def __init__( The initializer to be used for the worker processes. """ super().__init__() + if max_workers is None: max_workers = cpu_count(logical=True) logger.info("LocalBuilder: max_workers = %d", max_workers) diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index 4fdd80b1769b..f794b11471d9 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -15,19 +15,17 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule CostModel.""" - -from typing import List import ctypes +from typing import List -import numpy as np - +import numpy as np # type: ignore from tvm._ffi import register_object from tvm.runtime import Object from .. 
import _ffi_api from ..runner import RunnerResult -from ..tune_context import TuneContext from ..search_strategy import MeasureCandidate +from ..tune_context import TuneContext from ..utils import _get_hex_address, check_override diff --git a/python/tvm/meta_schedule/cost_model/metric.py b/python/tvm/meta_schedule/cost_model/metric.py index 7eb6da6f07d9..efd8dc68ac0d 100644 --- a/python/tvm/meta_schedule/cost_model/metric.py +++ b/python/tvm/meta_schedule/cost_model/metric.py @@ -15,11 +15,10 @@ # specific language governing permissions and limitations # under the License. """Cost model metrics for meta schedule""" -from typing import List -import numpy as np +import numpy as np # type: ignore -def max_curve(trial_scores: np.ndarray) -> List[float]: +def max_curve(trial_scores: np.ndarray) -> np.ndarray: """f(n) = max([s[i] fo i < n]) Parameters @@ -29,8 +28,8 @@ def max_curve(trial_scores: np.ndarray) -> List[float]: Returns ------- - curve : List[float] - function values + curve : np.ndarray + A vector, the max-curve function values """ ret = np.empty(len(trial_scores)) keep = -1e9 diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py index 1bb5fc237ae5..8808476aba15 100644 --- a/python/tvm/meta_schedule/cost_model/random_model.py +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -17,14 +17,14 @@ """ Random cost model """ -from typing import List, Union, Tuple, Optional +from typing import List, Optional, Tuple, Union -import numpy as np +import numpy as np # type: ignore +from ..cost_model import PyCostModel from ..runner import RunnerResult -from ..tune_context import TuneContext from ..search_strategy import MeasureCandidate -from ..cost_model import PyCostModel +from ..tune_context import TuneContext class RandomModel(PyCostModel): @@ -70,7 +70,7 @@ def load(self, path: str) -> None: path : str The file path. 
""" - self.random_state = tuple(np.load(path, allow_pickle=True)) + self.random_state = tuple(np.load(path, allow_pickle=True)) # type: ignore def save(self, path: str) -> None: """Save the cost model to given file location. @@ -116,7 +116,7 @@ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> n The predicted running results. """ np.random.set_state(self.random_state) - # todo(@zxybazh): Use numpy's RandState object: + # TODO(@zxybazh): Use numpy's RandState object: # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState result = np.random.rand(len(candidates)) * self.max_range self.random_state = np.random.get_state() diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py index d52eda3daac1..d805648bfbfd 100644 --- a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py @@ -17,7 +17,7 @@ """Random Feature Extractor.""" from typing import List, Union, Tuple -import numpy as np +import numpy as np # type: ignore from tvm.runtime.ndarray import NDArray, array from ..tune_context import TuneContext diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index 6af403905cb4..b1a9c678c6fc 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -33,7 +33,7 @@ run_evaluator_common, ) -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # pylint: disable=invalid-name class LocalRunnerFuture(RunnerFuture): diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index a9b2d560314a..80f372a448f5 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ 
b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -32,5 +32,5 @@ class PostOrderApply(SpaceGenerator): def __init__(self): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.SpaceGeneratorPostOrderApply, # pylint: disable=no-member + _ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member ) diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py index fa573fdd4a74..c21e2b961f88 100644 --- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -19,9 +19,11 @@ from typing import TYPE_CHECKING, List, Optional, Callable from tvm._ffi import register_object +from tvm._ffi.registry import register_func from tvm.ir import IRModule from tvm.tir import Schedule +from tvm.tir.function import PrimFunc from ..measure_callback import MeasureCallback from ..builder import Builder from ..runner import Runner @@ -35,24 +37,17 @@ from ..tune_context import TuneContext -def derive_similarity_tag(log_base=1.618): - def compute(mod: IRModule): - ret = "" - sch = Schedule(mod) - for func in mod.get_global_vars: - sref = sch.get_sref(sch.get_block(func)) - if ( - sref is not None - and sref.stmt is not None - and "meta_scheduler_task_scheduler_tag" in sref.stmt.annotations - ): - ret += sref.stmt.annotations["meta_scheduler_task_scheduler_tag"] + "_" - if ret: - flop_count = _ffi_api.TaskSchedulerFlopCount(mod) # type: ignore # pylint: disable=no-member - ret += "%d" % int(math.log(flop_count + 1, log_base)) - return ret +@register_func("meta_schedule.task_scheduler.derive_similarity_tag") +def derive_similarity_tag(mod: IRModule, log_base: float = 1.618): + ret = "" + for var in mod.get_global_vars(): - return compute + if "meta_scheduler_task_scheduler_tag" in mod[var].attrs: + ret += mod[var].attrs.meta_scheduler_task_scheduler_tag + "_" + if ret: + flop_count = 
_ffi_api.TaskSchedulerFlopCount(mod) # type: ignore # pylint: disable=no-member + ret += "%d" % int(math.log(flop_count + 1, log_base)) + return ret @register_object("meta_schedule.GradientBased") @@ -71,11 +66,15 @@ def __init__( backward_window_size: int = 3, seed: int = -1, task_weights: List[float] = None, - objective_func: Callable[[List[float]], float] = None, - tag_generation_func: Callable[[IRModule], str] = derive_similarity_tag(), + objective_func_name: str = "meta_schedule.task_scheduler.objective_func", + tag_generation_func_name: str = "meta_schedule.task_scheduler.derive_similarity_tag", cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, ) -> None: + @register_func("meta_schedule.task_scheduler.objective_func") + def weighted_sum(l: List[float]) -> float: + return sum([l[i] * w for i, w in enumerate(self.task_weights)]) + """Constructor. Parameters @@ -98,10 +97,10 @@ def __init__( The random seed. task_weights: Optional[List[float]] The weights of each task. - objective_func: - The objective function for gradient optimization. - tag_generation_func - The function to generate similarity tag for workloads. + objective_func_name: + The name of objective function for gradient optimization. + tag_generation_func_name: + The name of function to generate similarity tag for workloads. cost_model: CostModel The cost model of the scheduler. measure_callbacks: Optional[List[MeasureCallback]] @@ -109,24 +108,25 @@ def __init__( """ if task_weights is None: task_weights = [1.0 for _ in tasks] + self.task_weights = task_weights + assert len(task_weights) == len( tasks ), "The given task weights should be same length as tasks." 
- if objective_func is None: - objective_func = lambda l: sum([l[i] * w for i, w in enumerate(task_weights)]) + self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerGradientBased, # type: ignore # pylint: disable=no-member tasks, builder, runner, database, - cost_model, - measure_callbacks, - task_weights, alpha, beta, backward_window_size, seed, - objective_func, - tag_generation_func, + task_weights, + objective_func_name, + tag_generation_func_name, + cost_model, + measure_callbacks, ) diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 274638955287..a63d9a3f2183 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -19,8 +19,8 @@ from typing import List, Optional, TYPE_CHECKING from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback -from ..measure_callback import MeasureCallback from ..builder import Builder from ..runner import Runner from ..database import Database @@ -35,7 +35,21 @@ @register_object("meta_schedule.RoundRobin") class RoundRobin(TaskScheduler): - """Round Robin Task Scheduler""" + """Round Robin Task Scheduler + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + measure_callbacks: Optional[List[MeasureCallback]] = None + The list of measure callbacks of the scheduler. + """ def __init__( self, @@ -58,8 +72,6 @@ def __init__( The runner. database : Database The database. - cost_model: CostModel - The cost model of the scheduler. measure_callbacks: Optional[List[MeasureCallback]] The list of measure callbacks of the scheduler. 
""" diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 4a63134417e1..b6fe34839264 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -21,7 +21,7 @@ import shutil from typing import Any, Callable, List, Optional, Union -import psutil +import psutil # type: ignore import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError @@ -66,47 +66,29 @@ def _process_error_message(error_msg: str) -> str: def cpu_count(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system + Parameters ---------- logical : bool = True If True, return the number of logical CPUs, otherwise return the number of physical CPUs + Returns ------- cpu_count : int The number of logical or physical CPUs in the system + Note ---- The meta schedule search infra intentionally does not adopt the following convention in TVM: - C++ API `tvm::runtime::threading::MaxConcurrency()` - Environment variable `TVM_NUM_THREADS` or - Environment variable `OMP_NUM_THREADS` + This is because these variables are dedicated to controlling the runtime behavior of generated kernels, instead of the host-side search. Setting these variables may interfere the host-side search with profiling of generated kernels when measuring locally. 
""" - return psutil.cpu_count(logical=logical) or 1 - - -@register_func("meta_schedule._process_error_message") -def _process_error_message(error_msg: str) -> str: - error_msg_lines = str(error_msg).splitlines() - if len(error_msg_lines) >= 50: - return "\n".join(error_msg_lines[:25] + ["..."] + error_msg_lines[-25:]) - return error_msg - - -def cpu_count(logical: bool = True) -> int: - """Return the number of logical or physical CPUs in the system - Parameters - ---------- - logical : bool = True - If True, return the number of logical CPUs, otherwise return the number of physical CPUs - Returns - ------- - cpu_count : int - The number of logical or physical CPUs in the system - """ return _cpu_count_impl(logical) @@ -115,14 +97,17 @@ def get_global_func_with_default_on_worker( default: Callable, ) -> Callable: """Get the registered global function on the worker process. + Parameters ---------- name : Union[None, str, Callable] If given a string, retrieve the function in TVM's global registry; If given a python function, return it as it is; Otherwise, return `default`. + default : Callable The function to be returned if `name` is None. + Returns ------- result : Callable @@ -150,6 +135,7 @@ def get_global_func_on_rpc_session( extra_error_msg: Optional[str] = None, ) -> PackedFunc: """Get a PackedFunc from the global registry from an RPCSession. + Parameters ---------- session : RPCSession @@ -158,6 +144,7 @@ def get_global_func_on_rpc_session( The name of the PackedFunc extra_error_msg : Optional[str] Extra information to provide in the error message + Returns ------- result : PackedFunc @@ -181,10 +168,12 @@ def remove_build_dir(artifact_path: str) -> None: def _json_de_tvm(obj: Any) -> Any: """Unpack a TVM nested container to a JSON object in python. + Parameters ---------- obj : Any The TVM nested container to be unpacked. 
+ Returns ------- result : Any @@ -232,10 +221,12 @@ def batch_json_str2obj(json_strs: List[str]) -> List[Any]: def structural_hash(mod: IRModule) -> str: """Get the structural hash of a module. + Parameters ---------- mod : IRModule The module to be hashed. + Returns ------- result : str @@ -249,24 +240,11 @@ def structural_hash(mod: IRModule) -> str: return str(shash) -def _get_hex_address(handle: ctypes.c_void_p) -> str: - """Get the hexadecimal address of a handle. - Parameters - ---------- - handle : ctypes.c_void_p - The handle to be converted. - Returns - ------- - result : str - The hexadecimal address of the handle. - """ - return hex(ctypes.cast(handle, ctypes.c_void_p).value) - - def check_override( derived_class: Any, base_class: Any, required: bool = True, func_name: str = None ) -> Callable: """Check if the derived class has overridden the base class's method. + Parameters ---------- derived_class : Any @@ -278,6 +256,7 @@ def check_override( func_name : str Name of the method. Default value None, which would be set to substring of the given function, e.g. `f_generate`->`generate`. + Returns ------- func : Callable @@ -299,3 +278,17 @@ def inner(func: Callable): return func return inner + + +def _get_hex_address(handle: ctypes.c_void_p) -> str: + """Get the hexadecimal address of a handle. + Parameters + ---------- + handle : ctypes.c_void_p + The handle to be converted. + Returns + ------- + result : str + The hexadecimal address of the handle. + """ + return hex(ctypes.cast(handle, ctypes.c_void_p).value) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index e9fc10186c87..5cfd3a16c3bc 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -284,17 +284,13 @@ def _module_export(module, file_name): # fcompile, addons, kwargs? 
@register_func("tvm.relay.build") -def _build_module_no_factory_impl(mod, target, target_host, params, mod_name): - target, target_host = Target.check_and_update_host_consist(target, target_host) - return build(mod, target, params=params, mod_name=mod_name).module - - def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"): """A wrapper around build which discards the Python GraphFactoryRuntime. This wrapper is suitable to be used from other programming languages as the runtime::Module can be freely passed between language boundaries. """ - return _build_module_no_factory_impl(mod, target, target_host, params, mod_name) + target, target_host = Target.check_and_update_host_consist(target, target_host) + return build(mod, target, params=params, mod_name=mod_name).module def _reconstruct_from_deprecated_options(deprecated_params_target): diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 475c38b71cb5..42bd52930b1a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -19,16 +19,16 @@ from typing import Callable, List, Mapping, Union import inspect -from tvm._ffi import get_global_func, register_object +import tvm._ffi +import tvm.runtime +from tvm.runtime import Object from tvm.ir import BaseFunc -from tvm.runtime import Object, convert - -from . import _ffi_api from .buffer import Buffer -from .expr import PrimExpr, Var +from .expr import Var, PrimExpr +from . import _ffi_api -@register_object("tir.PrimFunc") +@tvm._ffi.register_object("tir.PrimFunc") class PrimFunc(BaseFunc): """A function declaration expression. 
@@ -57,7 +57,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa param_list = [] buffer_map = {} if buffer_map is None else buffer_map for x in params: - x = convert(x) if not isinstance(x, Object) else x + x = tvm.runtime.convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): var = Var(x.name, dtype="handle") param_list.append(var) @@ -68,13 +68,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, # type: ignore # pylint: disable=no-member - param_list, - body, - ret_type, - buffer_map, - attrs, - span, + _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore ) def with_body(self, new_body, span=None): @@ -148,7 +142,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: func : PrimFunc The new function with parameter specialized """ - return _ffi_api.Specialize(self, param_map) # type: ignore # pylint: disable=no-member + return _ffi_api.Specialize(self, param_map) # type: ignore def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: """Print IRModule into TVMScript diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 596d70c2d342..51cf67f92542 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2014,6 +2014,7 @@ def after_tensorize( ########## Schedule: Annotation ########## + @type_checked def annotate( self, block_or_loop: Union[BlockRV, LoopRV], @@ -2030,6 +2031,45 @@ def annotate( The annotation key ann_val : Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]] The annotation value + + Examples + -------- + + Before annotate, in TensorIR, the IR is: + + .. 
code-block:: python + + @T.prim_func + def before_annotate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do annotate: + + .. code-block:: python + + sch = tir.Schedule(before_annotate) + sch.annotate(sch.get_block("B"), "ann_key", "ann_value") + print(sch.mod["main"].script()) + + After applying annotate, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_annotate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"ann_key", "ann_value"}) + B[vi, vj] = A[vi, vj] * 2.0 + """ if isinstance(ann_val, str): ann_val = String(ann_val) @@ -2037,10 +2077,11 @@ def annotate( ann_val = IntImm("int32", ann_val) elif isinstance(ann_val, float): ann_val = FloatImm("float32", ann_val) - _ffi_api.ScheduleAnnotate( # pylint: disable=no-member + _ffi_api.ScheduleAnnotate( # type: ignore # pylint: disable=no-member self, block_or_loop, ann_key, ann_val ) + @type_checked def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None: """Unannotate a block/loop's annotation with key ann_key @@ -2050,83 +2091,48 @@ def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> Non The block/loop to be unannotated ann_key : str The annotation key - """ - _ffi_api.ScheduleUnannotate(self, block_or_loop, ann_key) # pylint: disable=no-member - - ########## Schedule: Layout transformation ########## - - def transform_layout( - self, - block: BlockRV, - buffer_index: int, - is_write_index: bool, - index_map: Union[IndexMap, Callable], - ) -> None: - """Apply a transformation represented by IndexMap to buffer - - Parameters - ---------- - block_rv : BlockRV - The block that 
accesses the target buffer - buffer_index: int - The index of the buffer in block's read or write region - is_write_index : bool - Whether the buffer_index is the index of the block's write region - index_map : Union[IndexMap, Callable] - The transformation to apply Examples -------- - Before transform_layout, in TensorIR, the IR is: + Before unannotate, in TensorIR, the IR is: .. code-block:: python @T.prim_func - def before_transform_layout(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128), "float32") - B = T.alloc_buffer((128, 128), "float32") - C = T.match_buffer(c, (128, 128), "float32") + def before_unannotate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) + T.block_attr({"ann_key", "ann_value"}) B[vi, vj] = A[vi, vj] * 2.0 - for i, j in T.grid(128, 128): - with T.block("C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi, vj] + 1.0 - Create the schedule and do transform_layout: + Create the schedule and do annotate: .. code-block:: python - sch = tir.Schedule(before_storage_align) - sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True, - index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) + sch = tir.Schedule(before_unannotate) + sch.unannotate(sch.get_block("B"), "ann_key") print(sch.mod["main"].script()) - After applying transform_layout, the IR becomes: + After applying unannotate, the IR becomes: .. 
code-block:: python @T.prim_func - def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (128, 128), "float32") - B = T.alloc_buffer((8, 8, 16, 16), "float32") - C = T.match_buffer(c, (128, 128), "float32") + def after_unannotate(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 - for i, j in T.grid(128, 128): - with T.block("C"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 + B[vi, vj] = A[vi, vj] * 2.0 + """ - if callable(index_map): - index_map = IndexMap.from_func(index_map) - _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member - self, block, buffer_index, is_write_index, index_map + _ffi_api.ScheduleUnannotate( # type: ignore # pylint: disable=no-member + self, block_or_loop, ann_key ) ########## Schedule: Layout transformation ########## diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 49745cd1a91d..0c0ea9cdb3a2 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -636,18 +636,6 @@ def PlanAndUpdateBufferAllocationLocation(): return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore -def ApplyBlockBoundPredicate(): - """Narrow the extents of some loops by checking whether some constraints in the block iter - bound predicates can be directly applied on the loops. 
- - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.ApplyBlockBoundPredicate() # type: ignore - - def ConvertBlocksToOpaque(): """Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 02e940ea79e3..a4de6592ca13 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -530,22 +530,12 @@ class IterMapRewriter : public ExprMutator { if (predicate_induced_max.defined()) { iter_max = min(predicate_induced_max.value(), iter_max); } - if (analyzer_->CanProve(iter_min <= iter_max)) { - if (!is_zero(iter_min)) { - // structured form's offset should be updated - flattened_map_.erase(structured_form); - structured_form.CopyOnWrite()->base = -iter_min; - mark.CopyOnWrite()->source = structured_form; - flattened_map_[structured_form] = flattened_form; - } - mark.CopyOnWrite()->extent = iter_max - iter_min; - sum_fuse_map_[flattened_form] = {mark, iter_min}; - // we need to note down the flattened form of constrained iterators - // to check the validity of constraints, see also CheckConstraints() - constrained_iters_flattened_.push_back(flattened_form); - expr.CopyOnWrite()->args = Array({split}); - expr.CopyOnWrite()->base = base + iter_min; - return expr; + if (!is_zero(iter_min)) { + // structured form's offset should be updated + flattened_map_.erase(structured_form); + structured_form.CopyOnWrite()->base = -iter_min; + mark.CopyOnWrite()->source = structured_form; + flattened_map_[structured_form] = flattened_form; } mark.CopyOnWrite()->extent = iter_max - iter_min; sum_fuse_map_[flattened_form] = {mark, iter_min}; @@ -621,7 +611,7 @@ class IterMapRewriter : public ExprMutator { } } } - if (!base_scale || base_scale.value()->value < 0) { + if (!base_scale) { diag_ctx_.Emit(Diagnostic::Error(expr->span) << "Fuse iters failed, can 
not find a valid base scale"); return NullOpt; @@ -900,20 +890,7 @@ bool MatchBoundConstraints(PrimExpr pred, Map& input_iters, iter = lhs_expr; } } - // If it is a predicate for input iters - if (const auto* var_ptr = iter.as()) { - auto it = input_iters.find(GetRef(var_ptr)); - if (it == input_iters.end()) { - return false; - } - PrimExpr iter_min = (*it).second->min; - PrimExpr iter_max = (*it).second->min + (*it).second->extent; - if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value()); - if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value()); - input_iters.Set(GetRef(var_ptr), Range(iter_min, iter_max)); - } else { - result.emplace_back(iter, lower_bound, upper_bound, 0); - } + result.emplace_back(iter, lower_bound, upper_bound, 0); if (is_finish) { break; } diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 8c9e2d8949e9..1eac10d1ad82 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -17,7 +17,6 @@ * under the License. 
*/ #include "../utils.h" -#include "tvm/tir/schedule/schedule.h" namespace tvm { namespace meta_schedule { diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index 2d9ae10135fa..90c1bca1fe2b 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -55,7 +55,7 @@ class GradientBasedNode final : public TaskSchedulerNode { Array input_latencies; for (double latency : latencies) input_latencies.push_back(FloatImm(DataType::Float(32), latency)); - return objective_func(input_latencies); + return objective_func(input_latencies)->value; } void _adjust_similarity_group(int task_id) { @@ -95,8 +95,10 @@ class GradientBasedNode final : public TaskSchedulerNode { // Calculate gradients if already warmed up double max_gradient = -1e30, min_gradient = 1e30; int arg_min_gradient = -1; + std::vector tasks_alive; for (task_id = 0; task_id < n_tasks; ++task_id) { if (!tasks[task_id]->is_stopped) { + tasks_alive.push_back(task_id); // compute gradient from chain rule : (delta f / delta g_i) // here f is given as objective function, default weighted sum double delta = 1e-4; @@ -120,8 +122,13 @@ class GradientBasedNode final : public TaskSchedulerNode { // compute (g_i(t_i + \Delta t) - g(t_i)) / (\Delta t) // which is approximated by // min( - g_i(t_i) / t_i, \Beta \frac{C_i}{max_{k \in N_i}(V_k)} - g_i(t_i)) - double g_next_1 = - task_best_latencies[task_id] - (task_best_latencies[task_id] / task_cnts[task_id]); + double g_next_1; + if (task_cnts[task_id] > 0) { + g_next_1 = + task_best_latencies[task_id] - (task_best_latencies[task_id] / task_cnts[task_id]); + } else { + g_next_1 = beta * 1e30; + } double g_next_2 = beta * 1e30; int group_id = tag_to_group[task_tag[task_id]]; if (task_groups[group_id].size() > 1) { @@ -135,7 +142,10 @@ class GradientBasedNode final : public TaskSchedulerNode { double forward_grad = g_next - 
task_best_latencies[task_id]; double gradient = chain_grad * (alpha * backward_grad + (1 - alpha) * forward_grad); - ICHECK(gradient <= 0) << "Wrong gradient calculated, should be less than or equal to 0."; + ICHECK(gradient <= 0) + << "Wrong gradient calculated, should be less than or equal to 0. Chain_grad: " + << chain_grad << ", backward_grad: " << backward_grad + << ", forward_grad: " << forward_grad << "."; if (gradient > max_gradient) { max_gradient = gradient; } @@ -145,10 +155,19 @@ class GradientBasedNode final : public TaskSchedulerNode { } } } + // all tasks done + if (tasks_alive.size() == 0) return -1; + // same gradient, sample any task if (std::abs(max_gradient - min_gradient) < 1e-6) { - arg_min_gradient = tir::SampleInt(&rand_state, 0, n_tasks); + task_id = tasks_alive[tir::SampleInt(&rand_state, 0, tasks_alive.size())]; + } else { + task_id = arg_min_gradient; + } + // check if task is running + if (IsTaskRunning(task_id)) { + JoinRunningTask(task_id); } - return arg_min_gradient; + return task_id; } void JoinRunningTask(int task_id) final { @@ -159,7 +178,7 @@ class GradientBasedNode final : public TaskSchedulerNode { Array results; task_cnts[task_id]++; results.reserve(n); - double best_latency = 1e30; + double trial_best_latency = 1e30; for (const RunnerFuture future : task->runner_futures.value()) { RunnerResult result = future->Result(); results.push_back(result); @@ -170,13 +189,15 @@ class GradientBasedNode final : public TaskSchedulerNode { count += 1; sum += run_sec->value; } - best_latency = std::min(best_latency, sum / count); + trial_best_latency = std::min(trial_best_latency, sum / count); } } - task_latency_history[task_id].push_back(best_latency); - if (task_latency_history[task_id].size() == 1 || best_latency < task_best_latencies[task_id]) { - task_best_latencies[task_id] = best_latency; + + if (task_latency_history[task_id].size() == 0 || + trial_best_latency < task_best_latencies[task_id]) { + task_best_latencies[task_id] = 
trial_best_latency; } + task_latency_history[task_id].push_back(task_best_latencies[task_id]); _adjust_similarity_group(task_id); task->search_strategy.value()->NotifyRunnerResults(task, task->measure_candidates.value(), results); @@ -195,20 +216,19 @@ class GradientBasedNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::GradientBased( - Array tasks, // - Builder builder, // - Runner runner, // - Database database, // - double alpha, // - double beta, // - int backward_window_size, // - support::LinearCongruentialEngine::TRandState seed, // - Array task_weights, // - TaskSchedulerNode::FObjectiveFunc objective_func, // - TaskSchedulerNode::FTagGenerationFunc tag_generation_func, // - Optional cost_model, // - Optional> measure_callbacks) { +TaskScheduler TaskScheduler::GradientBased(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + double alpha, // + double beta, // + int backward_window_size, // + support::LinearCongruentialEngine::TRandState seed, // + Array task_weights, // + String objective_func_name, // + String tag_generation_func_name, // + Optional cost_model, // + Optional> measure_callbacks) { ObjectPtr n = make_object(); n->alpha = alpha; n->beta = beta; @@ -231,13 +251,17 @@ TaskScheduler TaskScheduler::GradientBased( n->task_latency_history.assign(n->tasks.size(), std::vector()); n->task_weights.assign(n->tasks.size(), 1); - CHECK(objective_func != nullptr) << "The task objective function is empty!"; - CHECK(tag_generation_func != nullptr) << "The task tag generation function is empty!"; - n->objective_func = objective_func; + const auto* objective_func_ptr = runtime::Registry::Get(objective_func_name); + CHECK(objective_func_ptr) << "The given objective function is undefined!"; + n->objective_func = *objective_func_ptr; + + const auto* tag_generation_func_ptr = runtime::Registry::Get(tag_generation_func_name); + CHECK(tag_generation_func_ptr) << "The given tag generation function is 
undefined!"; + TaskSchedulerNode::FTagGenerationFunc tag_generation_func = *tag_generation_func_ptr; if (task_weights.defined()) { CHECK(task_weights.size() == n->tasks.size()) - << "Given task weights number does not equal to task number!"; + << "The given task weights number does not equal to task number!"; int cnt = 0; for (const FloatImm& weight : task_weights) { n->task_weights[cnt++] = weight->value; diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index e73f7bb63f09..a1f53b1960fe 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -27,8 +27,6 @@ #include #include #include -#include -#include #include #include #include @@ -36,9 +34,8 @@ #include #include #include +#include -#include -#include #include #include diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index e47c33a9d22e..92bd6bd4bf99 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -70,26 +70,6 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl */ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); -/*! - * \brief The information of a block scope, including the leaf blocks, - * as well as the loop types (spatial, reduction) for each loop in the scope. - */ -struct ScopeBlockLoopInfo { - /*! \brief A list of the leaf blocks, from left to right */ - std::vector realizes; - /*! \brief The loop vars bound to spatial block iters */ - std::unordered_set spatial_vars; - /*! \brief The loop vars bound to non-spatial block iters */ - std::unordered_set non_spatial_vars; -}; - -/*! - * \brief Inspect the scope of the given sref - * \param scope_block The root block of the scope - * \return The information of the scope - */ -ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block); - /******** Scope ********/ /*! 
* \brief Checks if scope the specified sref is in is a stage-pipeline and return it @@ -255,15 +235,6 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va */ void CheckAffineBinding(const ScheduleState& self, Block block); -/*! - * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop, - * from outer to inner. - * \param self The schedule state - * \param block_sref The block to be checked - * \return A boolean flag indicating if the block has a trivial binding - */ -bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); - /*! * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path @@ -647,17 +618,27 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref); bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref); /*! - * \brief Provided the access pattern to a buffer, suggest one of the possible layout - * transformation to minimize the locality of the access pattern. - * \param buffer The buffer to be transformed - * \param indices The access pattern to the buffer - * \param loops The loops above the buffer - * \param predicate The predicate of the access - * \param analyzer Arithmetic analyzer + * \brief Checks if a producer block could be successfully computed at the specific loop. 
+ * \param self The schedule state + * \param block_sref The block to be moved + * \param loop_sref The loop where the block to be moved to + * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + * \return A boolean indicating whether the block could be successfully compute at the specific loop */ -Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, - const Array& loops, const PrimExpr& predicate, - arith::Analyzer* analyzer); +bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, + bool preserve_unit_loops); + +/*! + * \brief Checks if a consumer block could be successfully computed at the specific loop. + * \param self The schedule state + * \param block_sref The block to be moved + * \param loop_sref The loop where the block to be moved to + * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 + * \return A boolean indicating whether the block could be successfully reverse compute at the + * specific loop + */ +bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& loop_sref, bool preserve_unit_loops); /*! 
* \brief Provided the access pattern to a buffer, suggest one of the possible layout diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 4f642af8b95b..bdb4295e900b 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -518,22 +518,6 @@ void CheckAffineBinding(const ScheduleState& self, Block block) { } } -bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); - Array loops = GetLoops(block_sref); - Array binds = GetBlockRealize(self, block_sref)->iter_values; - if (loops.size() != binds.size()) { - return false; - } - for (int i = 0, n = loops.size(); i < n; ++i) { - const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); - if (binds[i].get() != loop->loop_var.get()) { - return false; - } - } - return true; -} - Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, const Optional& high_exclusive, const runtime::StorageScope& extra_relax_scope) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 60be2efb5245..3501e7cb723f 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -237,7 +237,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { if (it == this->symbol_table_.end()) { LOG(FATAL) << "IndexError: Cannot find corresponding BlockRV: " << block_rv; } - ObjectRef obj = (*it).second; + const ObjectRef& obj = (*it).second; const auto* sref = obj.as(); if (sref == nullptr) { LOG(FATAL) << "ValueError: BlockRV's corresponding type is invalid: " @@ -256,7 +256,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { if (it == this->symbol_table_.end()) { LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << loop_rv; } - ObjectRef obj = (*it).second; + const ObjectRef& obj = (*it).second; if (obj.same_as(inline_mark)) { return 
inline_mark; } diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 71ee09ab6829..14d05a4a340c 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -43,7 +43,7 @@ namespace tir { * * // Convertible to `InstructionKindNode::FInstructionApply` * static Array ApplyToSchedule( - * const Schedule& sch, + * const tir::Schedule& sch, * const Array& inputs, * const Array& attrs, * const Optional& decision); diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 4ad09ab3dfdf..b445b5a9ded8 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -22,7 +22,6 @@ #include #include -#include #include namespace tvm { @@ -441,6 +440,7 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int * \param ann_key The annotation key */ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); + /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index 4ed40817132d..f5c1978a1b25 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -116,7 +116,8 @@ struct AnnotateTraits : public UnpackedInstTraits { return py.Str(); } - friend struct UnpackedInstTraits; + template + friend struct ::tvm::tir::UnpackedInstTraits; }; struct UnannotateTraits : public UnpackedInstTraits { @@ -147,7 +148,8 @@ struct UnannotateTraits : public UnpackedInstTraits { return py.Str(); } - friend struct UnpackedInstTraits; + template + friend struct ::tvm::tir::UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 1ed5bdc03f51..b811afb23614 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -25,49 +25,6 @@ using 
support::NDIntSet; /******** Error Classes ********/ -/*! - * \brief Represent the iteration domain to fully cover the required region of Intersect(dom, bound) - * The bound region may not get directly intersected with dom region, instead we try to generate - * extra predicates for non-trivial bound. The domain info class can also union with each other. - */ -struct BlockVarDomainInfo { - arith::IntSet dom{arith::IntSet::Nothing()}; // dom is ensured to be bounded - arith::IntSet bound{arith::IntSet::Nothing()}; - - /*! \brief Relaxed union operation */ - void Union(const BlockVarDomainInfo& other) { - // just relax (d0 ^ b0) v (d1 ^ b1) to (d0 v d1) ^ (b0 v b1) - dom = arith::Union({dom, other.dom}); - bound = arith::Union({bound, other.bound}); - } - - /*! \brief Simplify domain info */ - void Simplify(arith::Analyzer* analyzer) { - auto to_simplified = [analyzer](const arith::IntSet& set) { - PrimExpr min = set.HasLowerBound() ? analyzer->Simplify(set.min()) : set.min(); - PrimExpr max = set.HasUpperBound() ? analyzer->Simplify(set.max()) : set.max(); - return arith::IntSet::Interval(min, max); - }; - // if no dom specified, try use bound as dom - if (dom.IsNothing()) { - if (bound.HasLowerBound() && bound.HasUpperBound()) { - bound = to_simplified(bound); - std::swap(dom, bound); - } - return; - } - // simplify intsets - dom = to_simplified(dom); - bound = to_simplified(bound); - // if can proof the dom is within bound, remove bound - auto intersect = to_simplified(arith::Intersect({dom, bound})); - if (analyzer->CanProveEqual(dom.min(), intersect.min()) && - analyzer->CanProveEqual(dom.max(), intersect.max())) { - bound = arith::IntSet::Nothing(); - } - } -}; - /*! * \brief An error raised when not all required blocks are under the given loop. * \tparam is_consumer Indicates if all the required blocks are consumers or producers @@ -360,8 +317,6 @@ class ScopeReconstructor : private StmtMutator { Stmt rm_src_stmt_{nullptr}; /*! 
\brief The plan to remove the given block by replacing to this loop/block in the AST */ Stmt rm_tgt_stmt_{nullptr}; - /*! \brief Bound predicate for the given block to be moved */ - Optional predicate{NullOpt}; }; /*! @@ -592,11 +547,9 @@ void CalculateProvidedRequiredRegions( /******** Main Implementation ********/ template -std::function ComputeAtOrReverseComputeAtImpl(ScheduleState self, - const StmtSRef& block_sref, - const StmtSRef& loop_sref, - bool preserve_unit_loops, - arith::Analyzer* analyzer) { +void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref, + const StmtSRef& loop_sref, bool preserve_unit_loops, + arith::Analyzer* analyzer, bool check_only = false) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); // Step 1. Bunch of checks @@ -651,35 +604,32 @@ std::function ComputeAtOrReverseComputeAtImpl(ScheduleState self, reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms), /*analyzer=*/analyzer, /*preserve_unit_loops=*/preserve_unit_loops); Block new_scope_root = Downcast(reconstructor(scope_root)); - Optional bound_predicate = reconstructor.predicate; - return [=]() -> void { - // Step 7. Do the actual replacement - self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); - // Step 8. Update the cached flags - BlockInfo& block_info = self->block_info[block_sref]; - block_info.affine_binding = IsAffineBinding( - /*realize=*/reconstructor.new_block_realize_, - /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), - /*analyzer=*/analyzer); - // Step 9. Add bound predicate annotation for the block to be moved if needed - if (bound_predicate.defined()) { - Annotate(self, block_sref, attr::require_block_var_bound_predicate, bound_predicate.value()); - } - }; + + // Step 7. 
Do the actual replacement + if (check_only) { + return; + } + self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); + // Step 8. Update the cached flags + BlockInfo& block_info = self->block_info[block_sref]; + block_info.affine_binding = IsAffineBinding( + /*realize=*/reconstructor.new_block_realize_, + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), + /*analyzer=*/analyzer); } void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops) { arith::Analyzer analyzer; ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer)(); + &analyzer); } void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops) { arith::Analyzer analyzer; ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer)(); + &analyzer); } bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, @@ -687,7 +637,7 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S arith::Analyzer analyzer; try { ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer); + &analyzer, true); } catch (const tvm::runtime::Error& e) { return false; } @@ -699,7 +649,7 @@ bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, arith::Analyzer analyzer; try { ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer); + &analyzer, true); } catch (const tvm::runtime::Error& e) { return false; } diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 98106d51f4db..03ffb4fe159e 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -490,6 +490,7 @@ class LoopPropertyError : public ScheduleError { 
CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get())); meet_reduction_loop = true; } + continue; } else if (meet_reduction_loop && !is_one(loop->extent)) { throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop); } @@ -590,8 +591,8 @@ class BaseBlockCreator { } private: - virtual void CreateNormalIters(int idx) = 0; virtual void CreateAdditionalIter() = 0; + virtual void CreateNormalIters(int idx) = 0; virtual void CreateReductionUpdate() = 0; virtual void CreateReadWriteRegions() = 0; @@ -824,13 +825,6 @@ class WriteBackBlockCreator : public BaseBlockCreator { } } - void CreateAdditionalIter() final { - additional_iter_ = IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); - iter_vars_.insert(iter_vars_.end(), additional_iter_); - iter_values_.insert(iter_values_.end(), rf_loop_->loop_var); - var_map_.Set(rf_additional_iter_->var, additional_iter_->var); - } - void CreateReductionUpdate() final { wb_lhs_ = Downcast(Substitute(combiner_lhs_, var_map_)); wb_rhs_ = diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 9e1658a61768..0e767825573f 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -86,7 +86,6 @@ struct PrimeTable { pow_tab.emplace_back(std::move(tab)); } } - /*! 
* \brief Factorize a number n, and return in a cryptic format * \param n The number to be factorized @@ -300,17 +299,27 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS return SamplePerfectTile(rand_state, extent, n_splits); } CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; - while (true) { - std::vector result = SamplePerfectTile(rand_state, extent, n_splits); - if (result.back() <= max_innermost_factor) { - return result; + std::vector innermost_candidates; + innermost_candidates.reserve(max_innermost_factor); + for (int32_t i = 1; i <= max_innermost_factor; ++i) { + if (extent % i == 0) { + innermost_candidates.push_back(i); } } + // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. + // We should do multiple factorization to weight the choices. However, it would lead to slower + // sampling speed. On the other hand, considering potential tricks we might do on the innermost + // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add + // more heuristics in the future + int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; + std::vector result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1); + result.push_back(innermost); + return result; } std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); const int64_t* extent = GetLoopIntExtent(loop); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 2624afa476e0..eb43157d805a 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -339,10 +339,6 @@ class BlockInfoCollector : private StmtVisitor { 
/*dom_low_inclusive=*/parent_sref, /*dom_high_exclusive=*/lca, /*analyzer=*/&analyzer_); - for (size_t i = 0; i < consumed_region.size(); ++i) { - const arith::IntSet consumed_interset = arith::Intersect( - {consumed_region[i], arith::IntSet::FromMinExtent(0, buffer->shape[i])}); - } if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region, &analyzer_)) { region_cover = false; @@ -902,7 +898,7 @@ class ChildReplacer : private StmtMutator { int seq_index_; }; -void ScheduleStateNode::Replace(const StmtSRef& _src_sref, const Stmt& tgt_stmt, +void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, const Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; diff --git a/tests/python/unittest/test_meta_schedule_byoc.py b/tests/python/unittest/test_meta_schedule_byoc.py deleted file mode 100644 index fe50350d5133..000000000000 --- a/tests/python/unittest/test_meta_schedule_byoc.py +++ /dev/null @@ -1,198 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-""" Test Meta Schedule Builder """ -# pylint: disable=missing-docstring - -import sys - -import pytest -import tvm -from tvm import relay -from tvm.meta_schedule.arg_info import TensorInfo -from tvm.meta_schedule.builder import BuilderInput, LocalBuilder -from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner, RunnerInput -from tvm.meta_schedule.testing import get_network -from tvm.meta_schedule.testing.byoc_trt import ( - build_relay, - build_relay_with_tensorrt, - run_with_graph_executor, -) -from tvm.relay import testing -from tvm.relay.op.contrib import tensorrt -from tvm.target import Target -from tvm.tir import FloatImm - -has_tensorrt_codegen = pytest.mark.skipif( - not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" -) -has_tensorrt_runtime = pytest.mark.skipif( - not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" -) - -# conv2d+relu network -def get_conv2d_relu( - data_shape, - out_channels, - kernel_size, - strides, - padding, - dilation, - groups, - data_layout, - kernel_layout, - dtype, -): - - data = relay.var("data", relay.TensorType(data_shape, dtype)) - weight = relay.var("weight") - - net = relay.nn.conv2d( - data=data, - weight=weight, # conv kernel - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - channels=out_channels, - kernel_size=kernel_size, - data_layout=data_layout, - kernel_layout=kernel_layout, - ) - net = relay.add(net, net) - net = relay.nn.relu(net) - - inputs = relay.analysis.free_vars(net) - return relay.Function(inputs, net) - - -def verify_meta_schedule_with_tensorrt( - mod, - params, - data_shape, - use_meta_sched: bool = True, - use_trt: bool = True, - mode: str = "vm", -): - if use_meta_sched: - # With meta_schedule - dev = "nvidia/geforce-rtx-2080" - # Build - builder = LocalBuilder( - f_build=build_relay_with_tensorrt if use_trt else build_relay, - timeout_sec=1000, - ) - builder_input = BuilderInput(mod, 
Target(dev, host="llvm"), params) - builder_result = builder.build([builder_input])[0] - assert builder_result.error_msg is None, builder_result.error_msg - assert builder_result.artifact_path is not None - - # Run - runner_input = RunnerInput( - builder_result.artifact_path, - device_type="cuda", - args_info=[TensorInfo("float32", data_shape)], - ) - runner = LocalRunner( - evaluator_config=EvaluatorConfig( - number=5, - repeat=2, - min_repeat_ms=0, - enable_cpu_cache_flush=False, - ), - f_run_evaluator=run_with_graph_executor, - ) - - # Run the module - runner_future = runner.run([runner_input])[0] - runner_result = runner_future.result() - assert runner_result is not None - assert runner_result.error_msg is None, runner_result.error_msg - assert runner_result.run_secs is not None - - for result in runner_result.run_secs: - if isinstance(result, FloatImm): - result = result.value - assert isinstance(result, float) - assert result >= 0.0 - - else: - # Without meta_schedule - if use_trt: - mod, config = tensorrt.partition_for_tensorrt(mod) - with tvm.transform.PassContext( - opt_level=3, config={"relay.ext.tensorrt.options": config} - ): - _func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda" - ).evaluate() - else: - with tvm.transform.PassContext(opt_level=3): - _func = relay.create_executor( - mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params - ).evaluate() - - -@has_tensorrt_codegen -def test_conv2d_relu(): - data_shape = (1, 1280, 14, 14) - out_channels = 256 - kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 - data_layout, kernel_layout = "NCHW", "OIHW" - dtype = "float32" - - f = get_conv2d_relu( - data_shape, - out_channels, - kernel_size, - strides, - padding, - dilation, - groups, - data_layout, - kernel_layout, - dtype, - ) - - mod, params = testing.create_workload(f) - verify_meta_schedule_with_tensorrt(mod, params, data_shape) - - -@has_tensorrt_codegen 
-@pytest.mark.parametrize( - "model_name", - ["resnet-50", "mobilenet"], -) -@pytest.mark.parametrize("batch_size", [1, 8]) -@pytest.mark.parametrize("use_meta_sched", [True]) -@pytest.mark.parametrize("use_trt", [True, False]) -def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): - mod, params, input_shape, _oshape = get_network( - name=model_name, - batch_size=batch_size, - ) - verify_meta_schedule_with_tensorrt( - mod, - params, - input_shape, - use_meta_sched=use_meta_sched, - use_trt=use_trt, - mode="vm", - ) - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index c939792b55ff..4cb018b29aa4 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -14,13 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from typing import List - -import tempfile +# pylint: disable=missing-docstring import os import re -import sys import shutil +import sys +import tempfile +from typing import List + +import numpy as np import pytest import tvm @@ -32,11 +34,6 @@ from tvm.meta_schedule.tune_context import TuneContext from tvm.script import tir as T from tvm.tir.schedule.schedule import Schedule -from tvm.meta_schedule.search_strategy import MeasureCandidate -from tvm.meta_schedule.runner import RunnerResult -from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor -from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel -from tvm.meta_schedule.tune_context import TuneContext # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @tvm.script.ir_module diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 348339aee2e0..95cf6ebaeb43 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -22,7 +22,6 @@ import pytest import tvm -from tvm.tir.schedule import BlockRV, Schedule from tvm.error import TVMError from tvm.meta_schedule import TuneContext from tvm.meta_schedule.schedule_rule import PyScheduleRule diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 3eb050db3baa..49a3f6309183 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -23,9 +23,6 @@ import pytest import tvm -from tvm._ffi.base import TVMError -from tvm.ir.module import IRModule -from tvm.meta_schedule.space_generator.space_generator import PySpaceGenerator from tvm.script import tir as T from tvm.tir.schedule import Schedule from tvm.meta_schedule.space_generator import ScheduleFn, 
PySpaceGenerator, SpaceGeneratorUnion diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index d3c4dbca826f..8fe3ccb57524 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -29,7 +29,7 @@ from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult from tvm.meta_schedule.search_strategy import ReplayTrace from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin +from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin, GradientBased from tvm.script import tir as T from tvm.tir import Schedule @@ -234,7 +234,6 @@ def test_meta_schedule_task_scheduler_multiple(): ) round_robin.tune() assert len(database) == num_trials_total * len(tasks) - print(database.workload_reg) for task in tasks: assert ( len( @@ -335,5 +334,57 @@ def next_task_id(self) -> int: ) +def test_meta_schedule_task_scheduler_multiple_gradient_based(): + num_trials_per_iter = 6 + num_trials_total = 101 + tasks = [ + TuneContext( + MatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="Matmul", + rand_state=42, + ), + TuneContext( + MatmulReluModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="MatmulRelu", + rand_state=0xDEADBEEF, + ), + TuneContext( + BatchMatmulModule, + target=tvm.target.Target("llvm"), + space_generator=ScheduleFn(sch_fn=_schedule_batch_matmul), + search_strategy=ReplayTrace(num_trials_per_iter, num_trials_total), + task_name="BatchMatmul", + rand_state=0x114514, + ), + ] + database = DummyDatabase() + gradient_based = GradientBased( + tasks, 
+ DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[measure_callback.AddToDatabase()], + seed=0x20220214, + ) + gradient_based.tune() + assert len(database) == num_trials_total * len(tasks) + for task in tasks: + assert ( + len( + database.get_top_k( + database.commit_workload(task.mod), + 100000, + ) + ) + == num_trials_total + ) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index f0f5051f5c33..e1cf399d49a1 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -755,9 +755,8 @@ def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: T.where(j + i < 16) B[v] = A[v] with T.block("C"): - v = T.axis.spatial(16, j) - T.reads(B[v : v + 2]) - T.writes(C[v]) + v = T.axis.S(16, j) + T.reads([B[v : v + 2]]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") @@ -1254,28 +1253,5 @@ def test_fail_all_producers_under_loop(): sch.reverse_compute_at(block, loop) -def test_compute_at_tiled_pooling_cache(): - sch = tir.Schedule(tiled_pooling_cache, debug_mask="all") - compute = sch.get_block("compute") - _, w_o, _, _, _, _ = sch.get_loops(compute) - cache = sch.get_block("cache") - dache = sch.get_block("dache") - sch.compute_at(cache, w_o) - sch.compute_at(dache, w_o) - tvm.ir.assert_structural_equal(tiled_pooling_cache_after_compute_at, sch.mod["main"]) - verify_trace_roundtrip(sch=sch, mod=tiled_pooling_cache) - - -def test_reverse_compute_at_floordiv_and_floormod_indices(): - sch = tir.Schedule(floordiv_and_floormod_indices, debug_mask="all") - A = sch.get_block("A") - B = sch.get_block("B") - sch.reverse_compute_at(B, sch.get_loops(A)[0]) - tvm.ir.assert_structural_equal( - floordiv_and_floormod_indices_after_reverse_compute_at, sch.mod["main"] - ) - verify_trace_roundtrip(sch=sch, 
mod=floordiv_and_floormod_indices) - - if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index cf9621dc1d4c..cc2b114824a5 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -25,7 +25,7 @@ from tvm.tir.schedule.testing import verify_trace_roundtrip -# pylint: disable=no-member,invalid-name,unused-variable,line-too-long +# pylint: disable=no-member,invalid-name,unused-variable @T.prim_func