Skip to content

Commit

Permalink
autoTVM task extraction for VTA (nnvm for now)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmoreau89 committed May 28, 2019
1 parent 7ec3e64 commit 4a92032
Show file tree
Hide file tree
Showing 4 changed files with 321 additions and 60 deletions.
10 changes: 5 additions & 5 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,25 @@ def compute_conv2d(attrs, inputs, _):
if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
dilation, layout, out_dtype)
# pylint: enable=assignment-from-no-return
elif groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)
elif layout == "NCHW" and \
groups == get_const_int(inputs[0].shape[1]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
elif layout in ["NCHW", "NCHW4c"]:
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
out_dtype=out_dtype)
out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and \
groups == get_const_int(inputs[0].shape[3]) and \
groups == channels:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
else:
raise ValueError("not support arbitrary group number for now")

Expand Down
71 changes: 40 additions & 31 deletions python/tvm/autotvm/task/nnvm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@

from .task import create
from .topi_integration import TaskExtractEnv
from .dispatcher import ApplyHistoryBest

logger = logging.getLogger('autotvm')


def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
def extract_from_graph(graph, shape, dtype, target, symbols, params, target_host=None):
""" Extract tuning tasks from a nnvm graph.
This function collects tuning tasks by building the graph
with a "tracing" target and tracing all the calls to topi.
and trace all the calls to topi.
Parameters
----------
Expand All @@ -49,6 +50,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols want to be tuned
params : dict of str to NDArray
The parameter dictionary.
target_host: tvm.target.Target
The host compilation target
Expand Down Expand Up @@ -78,32 +81,35 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
topi_funcs.extend(SYMBOL2TOPI[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)

# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True

# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
# run compiler to collect all TOPI calls during compilation
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype,
target_host=target_host, params=params)

logger.disabled = old_state
logger.disabled = old_state

# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid Shape during AutoTVM Task Creation")

return tasks


def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None):
def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, params, target_host=None):
""" Extract tuning tasks from multiple nnvm graphs.
This function is the multiple graph version of extract_from_graph
Expand All @@ -120,6 +126,8 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
The compilation target
symbols : Array of nnvm.symbol
Array of nnvm symbols want to be tuned
params : dict of str to NDArray
The parameter dictionary.
target_host: tvm.target.Target
The host compilation target
Expand Down Expand Up @@ -149,28 +157,29 @@ def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_
topi_funcs.extend(SYMBOL2TOPI[sym_name])
else:
warnings.warn("Symbol %s is not tunable, ignored" % sym_name)

# run compiler to collect all TOPI calls during compilation
env.reset(topi_funcs)

# disable logger temporarily
old_state = logger.disabled
logger.disabled = True
with env:
# disable logger temporarily
old_state = logger.disabled
logger.disabled = True

# use a "tracing" target to do a fake compile for collecting topi calls
tracing_target = _target.create("llvm -device=tracing")
nnvm.compiler.engine.clear_cache()
for graph, shape, dtype in zip(graphs, shapes, dtypes):
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)

nnvm.compiler.engine.clear_cache()
for graph, shape, dtype in zip(graphs, shapes, dtypes):
nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)

logger.disabled = old_state
logger.disabled = old_state

# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
tasks.append(create(task_name, args,
target=target, target_host=target_host,
template_key='direct'))
try:
tsk = create(task_name, args,
target=target, target_host=target_host,
template_key='direct')
tasks.append(tsk)
except topi.InvalidShapeError:
print("[Warning] Invalid Shape during AutoTVM Task Creation")

return tasks

69 changes: 45 additions & 24 deletions python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""

import warnings
import sys

from ... import _api_internal, tensor, placeholder, create_schedule

from .task import args_to_workload, dispatcher, register
Expand Down Expand Up @@ -73,6 +76,7 @@ def deserialize_args(args):
class TaskExtractEnv:
"""Global environment for extracting tuning tasks from nnvm graph"""
current = None
registered = None

def __init__(self):
import topi
Expand Down Expand Up @@ -106,47 +110,64 @@ def __init__(self):
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}

self._register_tracing()
# support reflection for tracing
self.func_to_reflection = {
topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x),
topi.nn.conv2d_NCHWc: lambda x: setattr(topi.nn, 'conv2d_NCHWc', x),
topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x),
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
topi.nn.deformable_conv2d_nchw: lambda x: setattr(topi.nn, 'deformable_conv2d_nchw', x),
}

self._register_topi_task()
self.task_collection = []
self.wanted_topi_funcs = list(self.topi_to_task.keys())
self.modified_funcs = []

def __enter__(self):
self.task_collection = []
self.modified_funcs = []

def _register_tracing(self):
"""Register tracing function to track the topi function call"""
# register topi compute for "tracing" target
for topi_compute in self.topi_to_task:
for topi_compute in self.wanted_topi_funcs:
def _local_scope(compute_func):
"""start a scope to hold the local function in for loop"""

@compute_func.register("tracing", )
def _tracing_topi_compute(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when" \
"kwargs is used in TOPI function call." \
def _tracing_wrapper(*args, **kwargs):
assert not kwargs, "Do not support extracting tuning tasks when " \
"kwargs is used in TOPI function call. " \
"Please modify it to use only positional args."
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
self.task_collection.append(key)

return compute_func(*args, **kwargs)

self.func_to_reflection[topi_compute](_tracing_wrapper)
self.modified_funcs.append(topi_compute)

if compute_func in self.wanted_topi_funcs: # record this call
key = (self.topi_to_task[compute_func], serialize_args(args))
if key not in self.task_collection:
self.task_collection.append(key)
return compute_func.fdefault(*args)
_local_scope(topi_compute)

# register topi schedule for "tracing" target
for topi_compute in self.topi_to_task:
for topi_schedule in self.topi_to_schedule[topi_compute]:
def _local_scope_(schedule_func):
"""start a scope to hold the local function in for loop"""
return self

@schedule_func.register("tracing", )
def _tracing_topi_compute(outs):
outs = [outs] if isinstance(outs, tensor.Tensor) else outs
return create_schedule([x.op for x in outs])
_local_scope_(topi_schedule)
def __exit__(self, exc_type, exc_val, exc_tb):
# revert modification
for func in self.modified_funcs:
self.func_to_reflection[func](func)

def _register_topi_task(self):
"""register tuning wrapper for topi function"""
import topi

# Avoid double registration for certain targets
if TaskExtractEnv.registered:
return
TaskExtractEnv.registered = True

# Tuning wrapper for topi functions
@register("topi_nn_conv2d")
def _topi_nn_conv2d(*args, **kwargs):
Expand Down
Loading

0 comments on commit 4a92032

Please sign in to comment.