[AutoSchedule] Sparse dense tuning support with custom sketch rule #7313

Merged: 36 commits, Mar 6, 2021
2 changes: 1 addition & 1 deletion include/tvm/auto_scheduler/measure_record.h
@@ -34,7 +34,7 @@
namespace tvm {
namespace auto_scheduler {

const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5"; // NOLINT(*)
const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.6"; // NOLINT(*)

/*! \brief Callback for logging the input and results of measurements to file */
class RecordToFileNode : public MeasureCallbackNode {
8 changes: 7 additions & 1 deletion include/tvm/auto_scheduler/search_task.h
@@ -26,6 +26,7 @@
#define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_

#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/runtime/ndarray.h>
#include <tvm/target/target.h>

namespace tvm {
@@ -120,6 +121,8 @@ class SearchTaskNode : public Object {
HardwareParams hardware_params;
/*! \brief The layout rewrite option used for measuring programs. */
LayoutRewriteOption layout_rewrite_option;
+/*! \brief Names of some user defined input data used in program measuring. */
+Array<String> task_input_names;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("compute_dag", &compute_dag);
@@ -128,6 +131,7 @@
v->Visit("target_host", &target_host);
v->Visit("hardware_params", &hardware_params);
v->Visit("layout_rewrite_option", &layout_rewrite_option);
v->Visit("task_input_names", &task_input_names);
}

static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -148,9 +152,11 @@ class SearchTask : public ObjectRef {
* \param target_host The target host device of this search task.
* \param hardware_params Hardware parameters used in this search task.
* \param layout_rewrite_option The layout rewrite option used for measuring programs.
+ * \param task_input_names Names of some user defined input data used in program measuring.
*/
SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
-           Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);
+           Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option,
+           Array<String> task_input_names);

TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
};
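For context, this constructor change is what lets a Python-side caller attach concrete input buffers to a task. A minimal usage sketch follows; the workload name "sparse_dense", the shapes, and the W_* buffer names are illustrative assumptions, not part of this diff:

    import numpy as np
    import tvm
    from tvm import auto_scheduler

    # "sparse_dense" stands for a workload assumed to be registered with
    # @auto_scheduler.register_workload; the W_* names are assumed to match
    # the placeholder names that the input-check function reports.
    w_data_np = np.random.rand(512, 16, 1).astype("float32")
    w_indices_np = np.random.randint(0, 48, size=(512,)).astype("int32")
    w_indptr_np = np.arange(193, dtype="int32")

    task = auto_scheduler.SearchTask(
        func="sparse_dense",
        args=(128, 3072, 768, w_data_np.shape, w_indices_np.shape, w_indptr_np.shape),
        target="llvm",
        task_inputs={
            # buffer name -> concrete content used during measurement
            "W_data": tvm.nd.array(w_data_np),
            "W_indices": tvm.nd.array(w_indices_np),
            "W_indptr": tvm.nd.array(w_indptr_np),
        },
    )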
1 change: 1 addition & 0 deletions python/tvm/auto_scheduler/__init__.py
@@ -41,6 +41,7 @@
LocalRunner,
RPCRunner,
LocalRPCMeasureContext,
+register_task_input_check_func,
)
from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records
from .relay_integration import (
166 changes: 154 additions & 12 deletions python/tvm/auto_scheduler/measure.py
@@ -36,6 +36,7 @@
import shutil
import tempfile
import multiprocessing
+import logging

import tvm._ffi
from tvm.runtime import Object, module, ndarray
@@ -50,6 +51,7 @@
call_func_with_timeout,
check_remote,
get_const_tuple,
+get_func_name,
make_traceback_info,
request_remote,
)
@@ -58,6 +60,8 @@
deserialize_workload_registry_entry,
)

+# pylint: disable=invalid-name
+logger = logging.getLogger("auto_scheduler")

# The time cost for measurements with errors
# We use 1e10 instead of sys.float_info.max for better readability in log
@@ -223,6 +227,7 @@ def recover_measure_input(inp, rebuild_state=False):
target_host=task.target_host,
hardware_params=task.hardware_params,
layout_rewrite_option=task.layout_rewrite_option,
+task_inputs=list(task.task_input_names),
)

if rebuild_state:
@@ -719,6 +724,97 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbose=1):
return results


+TASK_INPUT_CHECK_FUNC_REGISTRY = {}
+
+
+def register_task_input_check_func(func_name, f=None, override=False):
+    """Register a function that checks the input buffer map.
+
+    The registered function should take a list of Tensors which indicate the
+    input/output tensors of a TVM subgraph and return a map from the input Tensor
+    to its buffer name.
+
+    Parameters
+    ----------
+    func_name : Union[Function, str]
+        The check function to be registered, or its function name.
+    f : Optional[Function]
+        The check function to be registered.
+    override : bool = False
+        Whether to override an existing entry.
+
+    Examples
+    --------
+    .. code-block:: python
+
+      @auto_scheduler.register_task_input_check_func
+      def check_task_input_by_placeholder_name(args : List[Tensor]):
+          tensor_input_map = {}
+          for arg in args:
+              if isinstance(arg.op, tvm.te.PlaceholderOp):
+                  if arg.op.name != "placeholder":
+                      tensor_input_map[arg] = arg.op.name
+          return tensor_input_map
+    """
+    global TASK_INPUT_CHECK_FUNC_REGISTRY
+
+    if callable(func_name):
+        f = func_name
+        func_name = get_func_name(f)
+    if not isinstance(func_name, str):
+        raise ValueError("expect string function name")
+
+    def register(myf):
+        """internal register function"""
+        if func_name in TASK_INPUT_CHECK_FUNC_REGISTRY and not override:
+            raise RuntimeError("%s has been registered already" % func_name)
+        TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] = myf
+        return myf
+
+    if f:
+        return register(f)
+    return register
+
+
+def _prepare_input_map(args):
+    """This function deals with special task inputs. It maps the input Tensors of a
+    TVM subgraph to their buffer names in the global buffer map.
+
+    Parameters
+    ----------
+    args : List[Tensor]
+        Input/output Tensors of a TVM subgraph.
+
+    Returns
+    -------
+    Dict[Tensor, str] :
+        Map from the input Tensor to its buffer name.
+
+    Notes
+    -----
+    The buffer names are specially designed, and these buffers should be provided in
+    `SearchTask(..., task_inputs={...})`.
+    """
+    # pylint: disable=import-outside-toplevel
+
+    global TASK_INPUT_CHECK_FUNC_REGISTRY
+
+    # A dict that maps the input tensor arg to a buffer name
+    tensor_input_map = {}
+
+    # Case 0: Check placeholder name
+    for arg in args:
+        if isinstance(arg.op, tvm.te.PlaceholderOp):
+            if arg.op.name != "placeholder":
+                tensor_input_map[arg] = arg.op.name
+
+    # Case 1: Check specific tensor inputs
+    for func_name in TASK_INPUT_CHECK_FUNC_REGISTRY:
+        func = TASK_INPUT_CHECK_FUNC_REGISTRY[func_name]
+        tensor_input_map.update(func(args))
+
+    return tensor_input_map
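To illustrate the two cases above, here is a self-contained sketch (the placeholder names are arbitrary):

    import tvm
    from tvm import te
    from tvm.auto_scheduler.measure import _prepare_input_map

    # Placeholders keeping the default name "placeholder" are ordinary inputs;
    # any other placeholder name marks the tensor as a special task input.
    data = te.placeholder((128, 256), name="placeholder")  # ordinary, skipped
    w_data = te.placeholder((512, 16, 1), name="W_data")   # special, mapped
    out = te.compute((128,), lambda i: data[i, 0] + w_data[0, 0, 0], name="out")

    # With no extra check functions registered, only Case 0 applies:
    print(_prepare_input_map([data, w_data, out]))  # {w_data: "W_data"}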

def _timed_eval_func(
inp_serialized,
build_res,
@@ -729,7 +825,11 @@ def _timed_eval_func(
enable_cpu_cache_flush,
verbose,
):
+# pylint: disable=import-outside-toplevel
+from .search_task import get_task_input_buffer  # lazily import to avoid recursive dependency
+
inp = MeasureInput.deserialize(inp_serialized)
+task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -758,11 +858,31 @@

if error_no == 0:
try:
-args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
-for arg in args:
-    random_fill(arg)
+
+tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+args = []
+task_inputs_count = 0
+for arg in build_res.args:
+    if arg in tensor_input_map:
+        tensor_name = tensor_input_map[arg]
+        if tensor_name in task_input_names:
+            args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
+            task_inputs_count += 1
+        else:
+            raise ValueError(
+                "%s not found in task_inputs, " % (tensor_name)
+                + "should be provided via `SearchTask(..., task_inputs={...})`"
+            )
+    else:
+        empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+        random_fill(empty_array)
+        args.append(empty_array)
+if task_inputs_count != len(task_input_names):
+    logger.warning(
+        "task_inputs not fully matched, check if there's any unexpected error"
+    )
ctx.sync()
costs = time_f(*args).results
# pylint: disable=broad-except
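get_task_input_buffer is the lazy import above; this PR implements it in search_task.py. As an assumption-level sketch of the contract the measurement code relies on (the real implementation also handles saving and loading buffers, so this is not the PR's code):

    # Hypothetical sketch: a process-global table keyed by workload_key,
    # then by input name, filled in by SearchTask(..., task_inputs={...}).
    TASK_INPUT_BUFFER_TABLE = {}

    def get_task_input_buffer(workload_key, input_name):
        """Return the tvm.nd.NDArray registered for (workload_key, input_name)."""
        input_table = TASK_INPUT_BUFFER_TABLE.get(workload_key)
        if input_table is None or input_name not in input_table:
            raise ValueError(
                "%s not found for workload %s, " % (input_name, workload_key)
                + "should be provided via `SearchTask(..., task_inputs={...})`"
            )
        return input_table[input_name]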
@@ -911,7 +1031,11 @@ def _timed_rpc_run(
enable_cpu_cache_flush,
verbose,
):
+# pylint: disable=import-outside-toplevel
+from .search_task import get_task_input_buffer  # lazily import to avoid recursive dependency
+
inp = MeasureInput.deserialize(inp_serialized)
+task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -943,18 +1067,36 @@

if error_no == 0:
try:
-args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
-try:
-    random_fill = remote.get_function("tvm.contrib.random.random_fill")
-except AttributeError:
-    raise AttributeError(
-        "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
-    )
+random_fill = remote.get_function("tvm.contrib.random.random_fill")
+assert (
+    random_fill
+), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"
+
+tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+args = []
+task_inputs_count = 0
+for arg in build_res.args:
+    if arg in tensor_input_map:
+        tensor_name = tensor_input_map[arg]
+        if tensor_name in task_input_names:
+            args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
+            task_inputs_count += 1
+        else:
+            raise ValueError(
+                "%s not found in task_inputs, " % (tensor_name)
+                + "should be provided via `SearchTask(..., task_inputs={...})`"
+            )
+    else:
+        empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+        random_fill(empty_array)
+        args.append(empty_array)
+if task_inputs_count != len(task_input_names):
+    logger.warning(
+        "task_inputs not fully matched, check if there's any unexpected error"
+    )
-for arg in args:
-    random_fill(arg)
ctx.sync()

costs = time_f(*args).results

# clean up remote files
remote.remove(build_res.filename)
remote.remove(os.path.splitext(build_res.filename)[0] + ".so")