From 4a92032953069fd9b98e955fdf51997b1b514ae4 Mon Sep 17 00:00:00 2001 From: Thierry Moreau Date: Thu, 23 May 2019 18:54:02 -0700 Subject: [PATCH] autoTVM task extraction for VTA (nnvm for now) --- nnvm/python/nnvm/top/nn.py | 10 +- python/tvm/autotvm/task/nnvm_integration.py | 71 +++--- python/tvm/autotvm/task/topi_integration.py | 69 ++++-- vta/scripts/tune_resnet.py | 231 ++++++++++++++++++++ 4 files changed, 321 insertions(+), 60 deletions(-) create mode 100644 vta/scripts/tune_resnet.py diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 13964f4e25f69..128f985bd6d2f 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -114,25 +114,25 @@ def compute_conv2d(attrs, inputs, _): if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8': # pylint: disable=assignment-from-no-return out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding, - dilation, layout, out_dtype=out_dtype) + dilation, layout, out_dtype) # pylint: enable=assignment-from-no-return elif groups == 1: out = topi.nn.conv2d( - inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype) + inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype) elif layout == "NCHW" and \ groups == get_const_int(inputs[0].shape[1]) and \ groups == channels: out = topi.nn.depthwise_conv2d_nchw( - inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) + inputs[0], inputs[1], strides, padding, dilation, out_dtype) elif layout in ["NCHW", "NCHW4c"]: out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups, - out_dtype=out_dtype) + out_dtype) elif layout == "NHWC" and \ kernel_layout == "HWOI" and \ groups == get_const_int(inputs[0].shape[3]) and \ groups == channels: out = topi.nn.depthwise_conv2d_nhwc( - inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) + inputs[0], inputs[1], strides, padding, dilation, out_dtype) else: raise ValueError("not support arbitrary group number for now") diff --git a/python/tvm/autotvm/task/nnvm_integration.py b/python/tvm/autotvm/task/nnvm_integration.py index dbcee0e516e11..e4d2b3fb80235 100644 --- a/python/tvm/autotvm/task/nnvm_integration.py +++ b/python/tvm/autotvm/task/nnvm_integration.py @@ -27,15 +27,16 @@ from .task import create from .topi_integration import TaskExtractEnv +from .dispatcher import ApplyHistoryBest logger = logging.getLogger('autotvm') -def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): +def extract_from_graph(graph, shape, dtype, target, symbols, params, target_host=None): """ Extract tuning tasks from a nnvm graph. This function collects tuning tasks by building the graph - with a "tracing" target and tracing all the calls to topi. + and trace all the calls to topi. Parameters ---------- @@ -49,6 +50,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): The compilation target symbols : Array of nnvm.symbol Array of nnvm symbols want to be tuned + params : dict of str to NDArray + The parameter dictionary. target_host: tvm.target.Target The host compilation target @@ -78,32 +81,35 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): topi_funcs.extend(SYMBOL2TOPI[sym_name]) else: warnings.warn("Symbol %s is not tunable, ignored" % sym_name) - - # run compiler to collect all TOPI calls during compilation env.reset(topi_funcs) - # disable logger temporarily - old_state = logger.disabled - logger.disabled = True + with env: + # disable logger temporarily + old_state = logger.disabled + logger.disabled = True - # use a "tracing" target to do a fake compile for collecting topi calls - tracing_target = _target.create("llvm -device=tracing") - nnvm.compiler.engine.clear_cache() - nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype) + # run compiler to collect all TOPI calls during compilation + nnvm.compiler.engine.clear_cache() + nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype, + target_host=target_host, params=params) - logger.disabled = old_state + logger.disabled = old_state # create tasks for target tasks = [] for task_name, args in env.get_tasks(): - tasks.append(create(task_name, args, - target=target, target_host=target_host, - template_key='direct')) + try: + tsk = create(task_name, args, + target=target, target_host=target_host, + template_key='direct') + tasks.append(tsk) + except topi.InvalidShapeError: + print("[Warning] Invalid Shape during AutoTVM Task Creation") return tasks -def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None): +def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, params, target_host=None): """ Extract tuning tasks from multiple nnvm graphs. This function is the multiple graph version of extract_from_graph @@ -120,6 +126,8 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_ The compilation target symbols : Array of nnvm.symbol Array of nnvm symbols want to be tuned + params : dict of str to NDArray + The parameter dictionary. target_host: tvm.target.Target The host compilation target @@ -149,28 +157,29 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_ topi_funcs.extend(SYMBOL2TOPI[sym_name]) else: warnings.warn("Symbol %s is not tunable, ignored" % sym_name) - - # run compiler to collect all TOPI calls during compilation env.reset(topi_funcs) - # disable logger temporarily - old_state = logger.disabled - logger.disabled = True + with env: + # disable logger temporarily + old_state = logger.disabled + logger.disabled = True - # use a "tracing" target to do a fake compile for collecting topi calls - tracing_target = _target.create("llvm -device=tracing") + nnvm.compiler.engine.clear_cache() + for graph, shape, dtype in zip(graphs, shapes, dtypes): + nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype) - nnvm.compiler.engine.clear_cache() - for graph, shape, dtype in zip(graphs, shapes, dtypes): - nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype) - - logger.disabled = old_state + logger.disabled = old_state # create tasks for target tasks = [] for task_name, args in env.get_tasks(): - tasks.append(create(task_name, args, - target=target, target_host=target_host, - template_key='direct')) + try: + tsk = create(task_name, args, + target=target, target_host=target_host, + template_key='direct') + tasks.append(tsk) + except topi.InvalidShapeError: + print("[Warning] Invalid Shape during AutoTVM Task Creation") return tasks + diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 3c983768ab3ea..58e78cf53c021 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -27,6 +27,9 @@ See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. """ +import warnings +import sys + from ... import _api_internal, tensor, placeholder, create_schedule from .task import args_to_workload, dispatcher, register @@ -73,6 +76,7 @@ def deserialize_args(args): class TaskExtractEnv: """Global environment for extracting tuning tasks from nnvm graph""" current = None + registered = None def __init__(self): import topi @@ -106,47 +110,64 @@ def __init__(self): topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw], } - self._register_tracing() + # support reflection for tracing + self.func_to_reflection = { + topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x), + topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x), + topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x), + topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x), + topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x), + topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x), + topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x), + topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x), + topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x), + topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x), + } + self._register_topi_task() self.task_collection = [] self.wanted_topi_funcs = list(self.topi_to_task.keys()) + self.modified_funcs = [] + + def __enter__(self): + self.task_collection = [] + self.modified_funcs = [] - def _register_tracing(self): - """Register tracing function to track the topi function call""" - # register topi compute for "tracing" target - for topi_compute in self.topi_to_task: + for topi_compute in self.wanted_topi_funcs: def _local_scope(compute_func): """start a scope to hold the local function in for loop""" - @compute_func.register("tracing", ) - def _tracing_topi_compute(*args, **kwargs): - assert not kwargs, "Do not support extracting tuning tasks when" \ - "kwargs is used in TOPI function call." \ + def _tracing_wrapper(*args, **kwargs): + assert not kwargs, "Do not support extracting tuning tasks when " \ + "kwargs is used in TOPI function call. " \ "Please modify it to use only positional args." + key = (self.topi_to_task[compute_func], serialize_args(args)) + if key not in self.task_collection: + self.task_collection.append(key) + + return compute_func(*args, **kwargs) + + self.func_to_reflection[topi_compute](_tracing_wrapper) + self.modified_funcs.append(topi_compute) - if compute_func in self.wanted_topi_funcs: # record this call - key = (self.topi_to_task[compute_func], serialize_args(args)) - if key not in self.task_collection: - self.task_collection.append(key) - return compute_func.fdefault(*args) _local_scope(topi_compute) - # register topi schedule for "tracing" target - for topi_compute in self.topi_to_task: - for topi_schedule in self.topi_to_schedule[topi_compute]: - def _local_scope_(schedule_func): - """start a scope to hold the local function in for loop""" + return self - @schedule_func.register("tracing", ) - def _tracing_topi_compute(outs): - outs = [outs] if isinstance(outs, tensor.Tensor) else outs - return create_schedule([x.op for x in outs]) - _local_scope_(topi_schedule) + def __exit__(self, exc_type, exc_val, exc_tb): + # revert modification + for func in self.modified_funcs: + self.func_to_reflection[func](func) def _register_topi_task(self): """register tuning wrapper for topi function""" import topi + # Avoid double registration for certain targets + if TaskExtractEnv.registered: + return + TaskExtractEnv.registered = True + # Tuning wrapper for topi functions @register("topi_nn_conv2d") def _topi_nn_conv2d(*args, **kwargs): diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py new file mode 100644 index 0000000000000..b22a63e09df83 --- /dev/null +++ b/vta/scripts/tune_resnet.py @@ -0,0 +1,231 @@ +import argparse +import os +import time +import numpy as np + +import tvm +from tvm import rpc, autotvm +from tvm.autotvm.measure.measure_methods import request_remote +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.contrib import graph_runtime, util +from tvm.contrib.download import download + +import topi +import nnvm.compiler +import vta +import vta.testing + +env = vta.get_env() + +def register_vta_tuning_tasks(): + from tvm.autotvm.task.topi_integration import TaskExtractEnv, deserialize_args + + @tvm.tag_scope(tag=topi.tag.ELEMWISE) + def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + # init autotvm env to register VTA operator + TaskExtractEnv() + + @autotvm.task.register("topi_nn_conv2d", override=True) + def _topi_nn_conv2d(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.conv2d(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = tvm.create_schedule([res.op]) + return s, [A, W, res] + + + +def generate_graph(sym, params, target, target_host): + # Populate the shape and data type dictionary + shape_dict = {"data": (1, 3, 224, 224)} + dtype_dict = {"data": 'float32'} + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Apply NNVM graph optimization passes + sym = vta.graph.clean_cast(sym) + sym = vta.graph.clean_conv_fuse(sym) + assert env.BLOCK_IN == env.BLOCK_OUT + sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT) + + # Compile NNVM graph + with nnvm.compiler.build_config(opt_level=3): + with vta.build_config(): + graph, lib, params = nnvm.compiler.build( + sym, target, shape_dict, dtype_dict, + params=params, target_host=target_host) + + return graph, lib, params + + +def extract_tasks(sym, params, target, target_host): + # Populate the shape and data type dictionary + shape_dict = {"data": (1, 3, 224, 224)} + dtype_dict = {"data": 'float32'} + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Apply NNVM graph optimization passes + sym = vta.graph.clean_cast(sym) + sym = vta.graph.clean_conv_fuse(sym) + assert env.BLOCK_IN == env.BLOCK_OUT + sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT) + + with vta.build_config(): + tasks = autotvm.task.extract_from_graph(graph=sym, shape=shape_dict, dtype=dtype_dict, target=target, + params=params, symbols=(nnvm.sym.conv2d,), target_host=target_host) + return tasks + + +def download_model(): + url = "https://github.com/uwsaml/web-data/raw/master/vta/models/" + categ_fn = 'synset.txt' + graph_fn = 'resnet18_qt8.json' + params_fn = 'resnet18_qt8.params' + data_dir = '_data' + if not os.path.exists(data_dir): + os.makedirs(data_dir) + + for file in [categ_fn, graph_fn, params_fn]: + if not os.path.isfile(file): + download(os.path.join(url, file), os.path.join(data_dir, file)) + + sym = nnvm.graph.load_json(open(os.path.join(data_dir, graph_fn)).read()) + params = nnvm.compiler.load_param_dict(open(os.path.join(data_dir, params_fn), 'rb').read()) + + return sym, params + + +def tune_tasks(tasks, + measure_option, + tuner='xgb', + n_trial=1000, + early_stopping=None, + log_filename='tuning.log', + use_transfer_learning=True, + try_winograd=True): + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) + + # create tuner + if tuner == 'xgb' or tuner == 'xgb-rank': + tuner_obj = XGBTuner(tsk, loss_type='rank') + elif tuner == 'ga': + tuner_obj = GATuner(tsk, pop_size=50) + elif tuner == 'random': + tuner_obj = RandomTuner(tsk) + elif tuner == 'gridsearch': + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + n_trial_ = min(n_trial, len(tsk.config_space)) + tuner_obj.tune(n_trial_, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(n_trial_, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file)]) + + # pick best records to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + +if __name__ == '__main__': + + # Get tracker info from env + tracket_host = os.environ.get("TVM_TRACKER_HOST", None) + tracket_port = int(os.environ.get("TVM_TRACKER_PORT", None)) + if not tracket_host or not tracket_port: + print("Set your AutoTVM tracker node host and port variables to run the autotuner") + exit() + + tuning_opt = { + 'log_filename': 'resnet-18.log', + + 'tuner': 'random', + 'n_trial': 1e9, + 'early_stopping': None, + + 'measure_option': autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), + runner=autotvm.RPCRunner(env.TARGET, tracket_host, tracket_port, + number=4, repeat=3, timeout=60, + check_correctness=True)) + } + + # download model + sym, params = download_model() + + # register VTA tuning tasks + register_vta_tuning_tasks() + + # extract tasks + print("Extract tasks...") + target = tvm.target.vta() + target_host = env.target_host + tasks = extract_tasks(sym, params, target, target_host) + + print("Tuning...") + tune_tasks(tasks, **tuning_opt) + + # compile kernels with history best records + with autotvm.tophub.context(target, extra_files=[tuning_opt['log_filename']]): + print("Compile...") + graph, lib, params = generate_graph(sym, params, target, target_host) + input_shape = (1, 3, 224, 224) + dtype = 'float32' + + # export library + tmp = util.tempdir() + filename = "net.tar" + lib.export_library(tmp.relpath(filename)) + + # upload module to device + print("Upload...") + remote = autotvm.measure.request_remote(env.TARGET, tracket_host, tracket_port, timeout=10000) + remote.upload(tmp.relpath(filename)) + rlib = remote.load_module(filename) + + # upload parameters to device + ctx = remote.context(str(target), 0) + rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()} + data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype)) + module = graph_runtime.create(graph, rlib, ctx) + module.set_input('data', data_tvm) + module.set_input(**rparams) + + # evaluate + print("Evaluate inference time cost...") + ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=3) + prof_res = np.array(ftimer().results) * 1000 # convert to millisecond + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res), np.std(prof_res))) +