Skip to content

Commit

Permalink
Fix implementation & add tests.
Browse files Browse the repository at this point in the history
Fix rebase.

Continue to fix rebase.

Fix rebase.
  • Loading branch information
zxybazh committed Feb 22, 2022
1 parent 66d926c commit aa0178b
Show file tree
Hide file tree
Showing 44 changed files with 393 additions and 768 deletions.
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ class PyDatabaseNode : public DatabaseNode {
// PackedFuncs are all not visited, because the reflection system doesn't take care of them,
// so it cannot be accessible on the python side. If there is such need from the future,
// we can then add corresponding accessor methods to help access on python.
//
// `f_has_workload` is not visited
// `f_commit_workload` is not visited
// `f_commit_tuning_record` is not visited
Expand Down
32 changes: 31 additions & 1 deletion include/tvm/meta_schedule/task_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ namespace meta_schedule {
class TaskSchedulerNode : public runtime::Object {
public:
/*! \brief The function type of the objective function. */
using FObjectiveFunc = TypedPackedFunc<double(Array<FloatImm>)>;
using FObjectiveFunc = TypedPackedFunc<FloatImm(Array<FloatImm>)>;
/*! \brief The function type of the tag genration function. */
using FTagGenerationFunc = TypedPackedFunc<String(const IRModule&)>;

Expand Down Expand Up @@ -264,6 +264,36 @@ class TaskScheduler : public runtime::ObjectRef {
Database database, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
/*!
* \brief Create a task scheduler that fetches tasks in a gradient based fashion.
* \param tasks The tasks to be tuned.
* \param builder The builder of the scheduler.
* \param runner The runner of the scheduler.
* \param database The database of the scheduler.
* \param alpha The parameter alpha to control gradient computation.
* \param beta The parameter beta to control gradient computation.
* \param backward_window_size The parameter to control backward window size.
* \param seed The random seed.
* \param task_weights The weights of each task.
* \param objective_fun_namec The name of objective function for gradient optimization.
* \param tag_generation_func_name The name of function to generate similarity tag for workloads.
* \param cost_model The cost model of the scheduler.
* \param measure_callbacks The measure callbacks of the scheduler.
* \return The task scheduler created.
*/
TVM_DLL static TaskScheduler GradientBased(Array<TuneContext> tasks, //
Builder builder, //
Runner runner, //
Database database, //
double alpha, //
double beta, //
int backward_window_size, //
support::LinearCongruentialEngine::TRandState seed, //
Array<FloatImm> task_weights, //
String objective_func_name, //
String tag_generation_func_name, //
Optional<CostModel> cost_model, //
Optional<Array<MeasureCallback>> measure_callbacks);
/*!
* \brief Create a task scheduler with customized methods on the python-side.
* \param tasks The tasks to be tuned.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/tune_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class TuneContextNode : public runtime::Object {
/*! \brief The probability of using certain mutator. */
Map<Mutator, FloatImm> mutator_probs;
/*! \brief The name of the tuning task. */
String task_name;
Optional<String> task_name;
/*! \brief The random state. */
support::LinearCongruentialEngine::TRandState rand_state;
/*! \brief The number of threads to be used. */
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,28 +500,28 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: Annotation ********/
/*!
* \brief Annotate a loop with a key value pair
* \param loop The loop to be annotated
* \param loop_rv The loop to be annotated
* \param ann_key The annotation key
* \param ann_val The annotation value, a string or a ExprRV
*/
virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0;
/*!
* \brief Annotate a block with a key value pair
* \param loop The block to be annotated
* \param block_rv The block to be annotated
* \param ann_key The annotation key
* \param ann_val The annotation value, a string or a ExprRV
*/
virtual void Annotate(const BlockRV& block_rv, const String& ann_key,
const ObjectRef& ann_val) = 0;
/*!
* \brief Unannotate a loop's annotation with key ann_key
* \param loop The loop to be unannotated
* \param loop_rv The loop to be unannotated
* \param ann_key The annotation key
*/
virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0;
/*!
* \brief Unannotate a block's annotation with key ann_key
* \param loop The block to be unannotated
* \param block_rv The block to be unannotated
* \param ann_key The annotation key
*/
virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;
Expand Down
87 changes: 0 additions & 87 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1442,93 +1442,6 @@ constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline
*/
constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order";

/*!
* \brief Mark that the block need to add predicate for block var bounds during lowering
*/
constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";

/*!
* \brief Mark that the loop should be further skip and bound to environment threads to enable
* cooperative fetching.
*/
constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";

/*!
* \brief Mark that the block should be further rewritten using tensorization.
*/
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";

/*! \brief Mark that tensor core is enabled in the PrimExpr */
constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";

/*! \brief The allowed range of thread extent in thread bindings */
constexpr const char* meta_schedule_thread_extent_low_inclusive =
"meta_schedule.thread_extent_low_inclusive";

/*! \brief The allowed range of thread extent in thread bindings */
constexpr const char* meta_schedule_thread_extent_high_inclusive =
"meta_schedule.thread_extent_high_inclusive";

/*!
* \brief Mark a block as generated by cache_read or cache_write block.
* 0 means cache_read; 1 means cache_write.
* \sa meta_schedule_cache_type_read
* \sa meta_schedule_cache_type_write
*/
constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";

/*! \sa meta_schedule_cache_type */
constexpr const int meta_schedule_cache_type_read = 0;

/*! \sa meta_schedule_cache_type */
constexpr const int meta_schedule_cache_type_write = 1;

/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";

/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */
constexpr const char* meta_schedule_random_compute_producer =
"meta_schedule.random_compute_producer";

/*! \brief Mark auto-parallel setting on the block. */
constexpr const char* meta_schedule_parallel = "meta_schedule.parallel";

/*! \brief Mark auto-vectorize setting on the block. */
constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize";

/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";

/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";

/*! \brief Pragma: auto-unroll, max_step */
constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step";

/*! \brief Pragma: unroll explicit */
constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit";

/*! \brief Mark the scope of the software pipeline */
constexpr const char* software_pipeline_scope = "software_pipeline_scope";

/*! \brief Mark the stage of a statement in the software pipeline */
constexpr const char* software_pipeline_stage = "software_pipeline_stage";

/*! \brief Mark the order of a statement in the software pipeline */
constexpr const char* software_pipeline_order = "software_pipeline_order";

/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify
* the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the
* prologue, the body, and the epilogue of the software pipeline.
*/
constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline_stage";

/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify
* the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the
* prologue, the body, and the epilogue of the software pipeline.
*/
constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
14 changes: 0 additions & 14 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,20 +383,6 @@ TVM_DLL Pass LowerInitBlock();
*/
TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();

/*!
* \brief Narrow the extents of some loops by checking whether some constraints in the block iter
* bound predicates can be directly applied on the loops.
* \return The pass.
*/
TVM_DLL Pass ApplyBlockBoundPredicate();

/*!
* \brief Narrow the extents of some loops by checking whether some constraints in the block iter
* bound predicates can be directly applied on the loops.
* \return The pass.
*/
TVM_DLL Pass ApplyBlockBoundPredicate();

/*!
* \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the
* corresponding iter_values in BlockRealize, for opaque blocks by removing all
Expand Down
12 changes: 3 additions & 9 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,9 @@ def auto_schedule_topi(func_name, outs):
"""

# pylint: disable=import-outside-toplevel
from tvm.auto_scheduler.measure import ( # lazily import to avoid recursive dependency
from tvm.auto_scheduler.measure import (
prepare_input_map,
)
) # lazily import to avoid recursive dependency

io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
if not io_tensors: # The compute includes dynamic shapes which are not supported yet.
Expand Down Expand Up @@ -482,10 +482,4 @@ def is_auto_scheduler_enabled():
enabled: bool
Whether the auto-scheduler is enabled
"""
return PassContext.current().config.get(
"relay.backend.use_auto_scheduler",
False,
) or PassContext.current().config.get(
"relay.backend.use_meta_schedule",
False,
)
return PassContext.current().config.get("relay.backend.use_auto_scheduler", False)
3 changes: 1 addition & 2 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,7 @@ def print_best(self, log_file, print_mode="schedule"):
code: str
The best schedule code in python API or CUDA source code
"""
inp, res = load_best_record(log_file, self.workload_key)
print("Best codes (ms):", [float(c) * 1000.0 for c in res.costs])
inp, _ = load_best_record(log_file, self.workload_key)
if inp is None:
raise RuntimeError(
"Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
Expand Down
5 changes: 1 addition & 4 deletions python/tvm/auto_scheduler/workload_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,7 @@ def workload_key_to_tensors(workload_key):
assert callable(value)

args = deserialize_args(workload[1:])
result = value(*args)
if isinstance(result, tuple):
result = list(result)
return result
return value(*args)


def serialize_workload_registry_entry(workload_key):
Expand Down
18 changes: 2 additions & 16 deletions python/tvm/meta_schedule/builder/local_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,13 @@

from tvm._ffi import register_func
from tvm.ir import IRModule
from tvm.runtime import NDArray
from tvm.runtime import Module, load_param_dict, save_param_dict
from tvm.runtime import Module, NDArray, load_param_dict, save_param_dict
from tvm.target import Target

from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind
from ..utils import cpu_count, get_global_func_with_default_on_worker
from .builder import BuilderInput, BuilderResult, PyBuilder

logger = logging.getLogger(__name__)


def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]:
if params is None:
return None
return save_param_dict(params)


def _deserialize_params(params: Optional[bytearray]) -> Optional[Dict[str, NDArray]]:
if params is None:
return None
return load_param_dict(params)


logger = logging.getLogger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -142,6 +127,7 @@ def __init__(
The initializer to be used for the worker processes.
"""
super().__init__()

if max_workers is None:
max_workers = cpu_count(logical=True)
logger.info("LocalBuilder: max_workers = %d", max_workers)
Expand Down
8 changes: 3 additions & 5 deletions python/tvm/meta_schedule/cost_model/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""Meta Schedule CostModel."""

from typing import List
import ctypes
from typing import List

import numpy as np

import numpy as np # type: ignore
from tvm._ffi import register_object
from tvm.runtime import Object

from .. import _ffi_api
from ..runner import RunnerResult
from ..tune_context import TuneContext
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
from ..utils import _get_hex_address, check_override


Expand Down
9 changes: 4 additions & 5 deletions python/tvm/meta_schedule/cost_model/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Cost model metrics for meta schedule"""
from typing import List
import numpy as np
import numpy as np # type: ignore


def max_curve(trial_scores: np.ndarray) -> List[float]:
def max_curve(trial_scores: np.ndarray) -> np.ndarray:
"""f(n) = max([s[i] fo i < n])
Parameters
Expand All @@ -29,8 +28,8 @@ def max_curve(trial_scores: np.ndarray) -> List[float]:
Returns
-------
curve : List[float]
function values
curve : np.ndarray
A vector, the max-curve function values
"""
ret = np.empty(len(trial_scores))
keep = -1e9
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/meta_schedule/cost_model/random_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
"""
Random cost model
"""
from typing import List, Union, Tuple, Optional
from typing import List, Optional, Tuple, Union

import numpy as np
import numpy as np # type: ignore

from ..cost_model import PyCostModel
from ..runner import RunnerResult
from ..tune_context import TuneContext
from ..search_strategy import MeasureCandidate
from ..cost_model import PyCostModel
from ..tune_context import TuneContext


class RandomModel(PyCostModel):
Expand Down Expand Up @@ -70,7 +70,7 @@ def load(self, path: str) -> None:
path : str
The file path.
"""
self.random_state = tuple(np.load(path, allow_pickle=True))
self.random_state = tuple(np.load(path, allow_pickle=True)) # type: ignore

def save(self, path: str) -> None:
"""Save the cost model to given file location.
Expand Down Expand Up @@ -116,7 +116,7 @@ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> n
The predicted running results.
"""
np.random.set_state(self.random_state)
# todo(@zxybazh): Use numpy's RandState object:
# TODO(@zxybazh): Use numpy's RandState object:
# https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState
result = np.random.rand(len(candidates)) * self.max_range
self.random_state = np.random.get_state()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Random Feature Extractor."""
from typing import List, Union, Tuple

import numpy as np
import numpy as np # type: ignore
from tvm.runtime.ndarray import NDArray, array

from ..tune_context import TuneContext
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/runner/local_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
run_evaluator_common,
)

logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class LocalRunnerFuture(RunnerFuture):
Expand Down
Loading

0 comments on commit aa0178b

Please sign in to comment.