[AutoSchedule] Sparse dense tuning support with custom sketch rule (apache#7313)

* Add sparse dense tuning tutorial

* Add sparse input fusion

* Update the dag to support output fusion

* Update

* Add task input to search_task

* Update

* Add search_inputs to measure

* Lint fix

* Lint fix

* Update

* Update

* Update

* Update

* Add file save load support

* Update

* Update

* Update

* Remove add_task_inputs API

* Update

* Update

* Update

* Lint fix

* Lint fix

* Lint fix

* Lint fix

* Update

* Add example ci_log

* Update

* retrigger ci

* Update

* Update

* Update

* Lint fix

* Lint fix

* Lint fix
jcf94 authored and trevor-m committed May 11, 2021
1 parent f6006fc commit 32cc782
Showing 15 changed files with 1,109 additions and 26 deletions.
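In summary, this commit threads user-provided task inputs through the auto-scheduler: SearchTask gains a task_input_names field, the local and RPC measurement paths learn to feed registered input buffers instead of random data, and a check-function registry lets workloads such as sparse dense mark which placeholders carry special inputs. The four most relevant files are shown below.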
2 changes: 1 addition & 1 deletion include/tvm/auto_scheduler/measure_record.h
@@ -34,7 +34,7 @@
namespace tvm {
namespace auto_scheduler {

-const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.5";  // NOLINT(*)
+const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.6";  // NOLINT(*)

/*! \brief Callback for logging the input and results of measurements to file */
class RecordToFileNode : public MeasureCallbackNode {
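(The log version bump is required because measure records serialize the SearchTask, which now carries the task_input_names field added below; records written as v0.5 predate that field.)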
8 changes: 7 additions & 1 deletion include/tvm/auto_scheduler/search_task.h
@@ -26,6 +26,7 @@
#define TVM_AUTO_SCHEDULER_SEARCH_TASK_H_

#include <tvm/auto_scheduler/compute_dag.h>
+#include <tvm/runtime/ndarray.h>
#include <tvm/target/target.h>

namespace tvm {
@@ -120,6 +121,8 @@ class SearchTaskNode : public Object {
HardwareParams hardware_params;
/*! \brief The layout rewrite option used for measuring programs. */
LayoutRewriteOption layout_rewrite_option;
+  /*! \brief Names of some user-defined input data used in program measurement. */
+  Array<String> task_input_names;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("compute_dag", &compute_dag);
@@ -128,6 +131,7 @@
v->Visit("target_host", &target_host);
v->Visit("hardware_params", &hardware_params);
v->Visit("layout_rewrite_option", &layout_rewrite_option);
v->Visit("task_input_names", &task_input_names);
}

static constexpr const char* _type_key = "auto_scheduler.SearchTask";
@@ -148,9 +152,11 @@ class SearchTask : public ObjectRef {
* \param target_host The target host device of this search task.
* \param hardware_params Hardware parameters used in this search task.
* \param layout_rewrite_option The layout rewrite option used for measuring programs.
+   * \param task_input_names Names of some user-defined input data used in program measurement.
*/
SearchTask(ComputeDAG compute_dag, String workload_key, Target target, Target target_host,
-             Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option);
+             Optional<HardwareParams> hardware_params, LayoutRewriteOption layout_rewrite_option,
+             Array<String> task_input_names);

TVM_DEFINE_OBJECT_REF_METHODS(SearchTask, ObjectRef, SearchTaskNode);
};
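For context, a minimal sketch of the corresponding Python-side usage. It assumes the Python SearchTask wrapper accepts the task_inputs and task_inputs_save_to_file arguments introduced by this commit; the workload and buffer names are illustrative, not part of this diff.

import numpy as np
import tvm
from tvm import te, auto_scheduler, runtime

@auto_scheduler.register_workload
def add_with_weight(n):
    a = te.placeholder((n,), name="placeholder")  # default name: filled with random data
    w = te.placeholder((n,), name="W_data")       # custom name: treated as a task input
    out = te.compute((n,), lambda i: a[i] + w[i], name="out")
    return [a, w, out]

w_np = np.arange(128, dtype="float32")
task = auto_scheduler.SearchTask(
    func=add_with_weight,
    args=(128,),
    target=tvm.target.Target("llvm"),
    # Keys must match the names resolved by _prepare_input_map (see measure.py below).
    task_inputs={"W_data": runtime.ndarray.array(w_np)},
    task_inputs_save_to_file=True,  # persist buffers so logs can be replayed later
)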
1 change: 1 addition & 0 deletions python/tvm/auto_scheduler/__init__.py
@@ -41,6 +41,7 @@
LocalRunner,
RPCRunner,
LocalRPCMeasureContext,
+    register_task_input_check_func,
)
from .measure_record import RecordToFile, RecordReader, load_best_record, load_records, save_records
from .relay_integration import (
166 changes: 154 additions & 12 deletions python/tvm/auto_scheduler/measure.py
@@ -36,6 +36,7 @@
import shutil
import tempfile
import multiprocessing
+import logging

import tvm._ffi
from tvm.runtime import Object, module, ndarray
@@ -50,6 +51,7 @@
call_func_with_timeout,
check_remote,
get_const_tuple,
+    get_func_name,
make_traceback_info,
request_remote,
)
Expand All @@ -58,6 +60,8 @@
deserialize_workload_registry_entry,
)

+# pylint: disable=invalid-name
+logger = logging.getLogger("auto_scheduler")

# The time cost for measurements with errors
# We use 1e10 instead of sys.float_info.max for better readability in log
@@ -223,6 +227,7 @@ def recover_measure_input(inp, rebuild_state=False):
target_host=task.target_host,
hardware_params=task.hardware_params,
layout_rewrite_option=task.layout_rewrite_option,
+        task_inputs=list(task.task_input_names),
)

if rebuild_state:
@@ -719,6 +724,97 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
return results


+TASK_INPUT_CHECK_FUNC_REGISTRY = {}
+
+
+def register_task_input_check_func(func_name, f=None, override=False):
+    """Register a function that checks the input buffer map.
+
+    The input function should take a list of Tensors which indicate the input/output tensors of
+    a TVM subgraph and return a Map from the input Tensor to its buffer name.
+
+    Parameters
+    ----------
+    func_name : Union[Function, str]
+        The check function to be registered, or its function name.
+    f : Optional[Function]
+        The check function to be registered.
+    override : boolean = False
+        Whether to override existing entry.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        @auto_scheduler.register_task_input_check_func
+        def check_task_input_by_placeholder_name(args : List[Tensor]):
+            tensor_input_map = {}
+            for arg in args:
+                if isinstance(arg.op, tvm.te.PlaceholderOp):
+                    if arg.op.name != "placeholder":
+                        tensor_input_map[arg] = arg.op.name
+            return tensor_input_map
+    """
+    global TASK_INPUT_CHECK_FUNC_REGISTRY
+
+    if callable(func_name):
+        f = func_name
+        func_name = get_func_name(f)
+    if not isinstance(func_name, str):
+        raise ValueError("expect string function name")
+
+    def register(myf):
+        """internal register function"""
+        if func_name in TASK_INPUT_CHECK_FUNC_REGISTRY and not override:
+            raise RuntimeError("%s has been registered already" % func_name)
+        TASK_INPUT_CHECK_FUNC_REGISTRY[func_name] = myf
+        return myf
+
+    if f:
+        return register(f)
+    return register


+def _prepare_input_map(args):
+    """This function deals with special task inputs. Map the input Tensors of a TVM subgraph
+    to specific buffer names in the global buffer map.
+
+    Parameters
+    ----------
+    args : List[Tensor]
+        Input/output Tensors of a TVM subgraph.
+
+    Returns
+    -------
+    Dict[Tensor, str] :
+        Map from the input Tensor to its buffer name.
+
+    Notes
+    -----
+    The buffer names are specially designed, and these buffers should be provided in
+    `SearchTask(..., task_inputs={...})`.
+    """
+    # pylint: disable=import-outside-toplevel
+
+    global TASK_INPUT_CHECK_FUNC_REGISTRY
+
+    # A dict that maps the input tensor arg to a buffer name
+    tensor_input_map = {}
+
+    # Case 0: Check placeholder name
+    for arg in args:
+        if isinstance(arg.op, tvm.te.PlaceholderOp):
+            if arg.op.name != "placeholder":
+                tensor_input_map[arg] = arg.op.name
+
+    # Case 1: Check specific tensor inputs
+    for func_name in TASK_INPUT_CHECK_FUNC_REGISTRY:
+        func = TASK_INPUT_CHECK_FUNC_REGISTRY[func_name]
+        tensor_input_map.update(func(args))
+
+    return tensor_input_map
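A short illustration of the mapping rule above (tensor names are hypothetical): only placeholders with a non-default name end up in the map; everything else is filled with random data at measure time.

import tvm
from tvm import te

w = te.placeholder((128,), name="W_data")        # tracked: custom name
a = te.placeholder((128,), name="placeholder")   # skipped: default name
out = te.compute((128,), lambda i: a[i] + w[i], name="out")

# _prepare_input_map([a, w, out]) would return {w: "W_data"},
# plus whatever any registered check functions contribute.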


def _timed_eval_func(
inp_serialized,
build_res,
@@ -729,7 +825,11 @@
enable_cpu_cache_flush,
verbose,
):
+    # pylint: disable=import-outside-toplevel
+    from .search_task import get_task_input_buffer  # lazily import to avoid recursive dependency
+
    inp = MeasureInput.deserialize(inp_serialized)
+    task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -758,11 +858,31 @@

if error_no == 0:
try:
-            args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill", True)
assert random_fill, "Please make sure USE_RANDOM is ON in the config.cmake"
-            for arg in args:
-                random_fill(arg)
+
+            tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+            args = []
+            task_inputs_count = 0
+            for arg in build_res.args:
+                if arg in tensor_input_map:
+                    tensor_name = tensor_input_map[arg]
+                    if tensor_name in task_input_names:
+                        args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
+                        task_inputs_count += 1
+                    else:
+                        raise ValueError(
+                            "%s not found in task_inputs; " % (tensor_name)
+                            + "it should be provided via `SearchTask(..., task_inputs={...})`"
+                        )
+                else:
+                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+                    random_fill(empty_array)
+                    args.append(empty_array)
+            if task_inputs_count != len(task_input_names):
+                logger.warning(
+                    "task_inputs not fully matched; check for unexpected errors"
+                )
ctx.sync()
costs = time_f(*args).results
# pylint: disable=broad-except
@@ -911,7 +1031,11 @@ def _timed_rpc_run(
enable_cpu_cache_flush,
verbose,
):
+    # pylint: disable=import-outside-toplevel
+    from .search_task import get_task_input_buffer  # lazily import to avoid recursive dependency
+
    inp = MeasureInput.deserialize(inp_serialized)
+    task_input_names = inp.task.task_input_names
tic = time.time()
error_no = 0
error_msg = None
@@ -943,18 +1067,36 @@

if error_no == 0:
try:
-            args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in build_res.args]
-            try:
-                random_fill = remote.get_function("tvm.contrib.random.random_fill")
-            except AttributeError:
-                raise AttributeError(
-                    "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
-                )
+            random_fill = remote.get_function("tvm.contrib.random.random_fill")
+            assert (
+                random_fill
+            ), "Please make sure USE_RANDOM is ON in the config.cmake on the remote devices"
+
+            tensor_input_map = _prepare_input_map(build_res.args) if task_input_names else {}
+            args = []
+            task_inputs_count = 0
+            for arg in build_res.args:
+                if arg in tensor_input_map:
+                    tensor_name = tensor_input_map[arg]
+                    if tensor_name in task_input_names:
+                        args.append(get_task_input_buffer(inp.task.workload_key, tensor_name))
+                        task_inputs_count += 1
+                    else:
+                        raise ValueError(
+                            "%s not found in task_inputs; " % (tensor_name)
+                            + "it should be provided via `SearchTask(..., task_inputs={...})`"
+                        )
+                else:
+                    empty_array = ndarray.empty(get_const_tuple(arg.shape), arg.dtype, ctx)
+                    random_fill(empty_array)
+                    args.append(empty_array)
+            if task_inputs_count != len(task_input_names):
+                logger.warning(
+                    "task_inputs not fully matched; check for unexpected errors"
+                )
-            for arg in args:
-                random_fill(arg)
ctx.sync()

costs = time_f(*args).results

# clean up remote files
remote.remove(build_res.filename)
remote.remove(os.path.splitext(build_res.filename)[0] + ".so")
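Finally, a hedged end-to-end sketch of how these pieces are exercised (option values illustrative, mirroring the tutorial this commit adds): once a task is constructed with task_inputs, both the local and RPC runners substitute the registered buffers for the matching arguments during measurement.

from tvm import auto_scheduler

tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile("sparse_dense.json")],
    verbose=2,
)
task.tune(tune_option)  # assumes `task` from the SearchTask sketch above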
(Diff for the remaining 11 changed files not shown.)
