kernel layout rewrite (apache#28)
* kernel layout rewrite

* remove some hacks

* add defuse_ops pass and move kernel_layout_rewrite pass after fuse_ops pass

* set TVM_RELAY_DISABLE_BUILD_CACHE for task extraction and prepare_layout_rewrite
minminsun authored and merrymercy committed Jun 20, 2020
1 parent c7364df commit 36cd9ef
Showing 25 changed files with 787 additions and 353 deletions.
13 changes: 13 additions & 0 deletions include/tvm/relay/attrs/transform.h
@@ -296,6 +296,19 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
}
};

/*! \brief Attributes for KernelLayoutTransform operator */
struct KernelLayoutTransformAttrs : public tvm::AttrsNode<KernelLayoutTransformAttrs> {
std::string src_layout;
std::string dst_layout;

TVM_DECLARE_ATTRS(KernelLayoutTransformAttrs, "relay.attrs.KernelLayoutTransformAttrs") {
TVM_ATTR_FIELD(src_layout)
.describe("The source layout of the tensor. (e.g. 1N32C112H112W)");
TVM_ATTR_FIELD(dst_layout)
.describe("The destination layout of the tensor. (e.g. 1N2C112H112W16c)");
}
};

/*! \brief Attributes for ShapeOf operator */
struct ShapeOfAttrs : public tvm::AttrsNode<ShapeOfAttrs> {
DataType dtype;
14 changes: 14 additions & 0 deletions include/tvm/relay/transform.h
@@ -277,6 +277,20 @@ TVM_DLL Pass CanonicalizeOps();
*/
TVM_DLL Pass AlterOpLayout();

/*!
* \brief Alternate the layouts of kernels.
*
* \return The pass.
*/
TVM_DLL Pass KernelLayoutTransform();

/*!
* \brief The reverse of FuseOps.
*
* \return The pass.
*/
TVM_DLL Pass DeFuseOps();

/*!
* \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
* layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
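The commit message says the kernel layout rewrite pass now runs after fuse_ops, with the new DeFuseOps pass undoing fusion. A minimal Python sketch of one ordering consistent with that description follows; the binding names relay.transform.KernelLayoutTransform and relay.transform.DeFuseOps, and the exact position of the passes, are assumptions for illustration only:

    import tvm
    from tvm import relay

    def rewrite_kernel_layouts(mod):
        # Assumed ordering: fuse, rewrite kernel layouts, de-fuse, then re-fuse.
        seq = tvm.transform.Sequential([
            relay.transform.FuseOps(fuse_opt_level=2),
            relay.transform.KernelLayoutTransform(),  # hypothetical Python binding
            relay.transform.DeFuseOps(),              # hypothetical Python binding
            relay.transform.FuseOps(fuse_opt_level=2),
        ])
        with tvm.transform.PassContext(opt_level=3):
            return seq(mod)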
2 changes: 1 addition & 1 deletion python/tvm/ansor/__init__.py
@@ -44,4 +44,4 @@
FallbackContext, clear_fallback_cache, ApplyGraphBest, BlockingEmptyContext
from .topi_integration import register_topi_schedule, TaskExtractEnv
from .relay_integration import extract_from_program, extract_from_multiple_program, \
finish_layout_rewrite
finish_layout_rewrite, prepare_layout_rewrite
9 changes: 7 additions & 2 deletions python/tvm/ansor/compute_dag.py
@@ -64,12 +64,17 @@ def apply_steps_from_state(self, state, layout_rewrite_level=None):
args : List[Tensor]
"""
if isinstance(state, State):
return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object)
return _ffi_api.ComputeDAGApplyStepsFromState(self, state.state_object,
layout_rewrite_level)
elif isinstance(state, StateObject):
return _ffi_api.ComputeDAGApplyStepsFromState(self, state)
return _ffi_api.ComputeDAGApplyStepsFromState(self, state,
layout_rewrite_level)
else:
raise ValueError("The input must be a State or StateObject")

def rewrite_layout_from_state(self, state: State):
return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state)

def print_python_code_from_state(self, state):
"""
Parameters
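A short usage sketch of the two ComputeDAG methods touched above; `io_tensors` and `best_state` are placeholders, and LayoutRewriteLevel is the enum already referenced by scripts/tune_network.py later in this commit:

    from tvm import ansor

    # `io_tensors` stands for the input/output tensors of a tuned workload (placeholder).
    dag = ansor.ComputeDAG(io_tensors)
    # Apply the steps of a tuned state with kernel layout rewrite enabled
    # (assumed to return the schedule and its tensor arguments).
    sch, args = dag.apply_steps_from_state(
        best_state, layout_rewrite_level=ansor.LayoutRewriteLevel.BOTH_REWRITE)
    # Get the compute DAG with the rewritten (packed) kernel layouts.
    new_dag = dag.rewrite_layout_from_state(best_state)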
1 change: 0 additions & 1 deletion python/tvm/ansor/measure.py
@@ -534,4 +534,3 @@ def timed_func(inp, build_res):
print("")

return measure_results

7 changes: 3 additions & 4 deletions python/tvm/ansor/relay_integration.py
@@ -54,7 +54,7 @@ def _lower(mod,
# If failed to compile, then fallback to use VM compiler.
# TODO: Currently VM compiler is likely to stack overflow for large models.
try:
with relay.build_config(opt_level=3):
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
opt_mod, _ = relay.optimize(mod, target, params)
grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
grc.codegen(opt_mod["main"])
@@ -191,7 +191,7 @@ def prepare_layout_rewrite(mod, params, ops, target):
"""Prepare for kernel layout rewrite. This function will write layout infos to a global static variable,
then these layout info will be used by a relay pass `kernel_layout_transform`.
"""
from .. import relay
from tvm import relay

env = TaskExtractEnv.get(do_layout_rewrite=True)

@@ -203,9 +203,8 @@
else:
warnings.warn("Op %s is not tunable, ignored." % op_name)

env.reset(topi_scheds)
with env:
env.reset(topi_scheds)

# wrap build call in thread to avoid multiprocessing problems
build_thread = threading.Thread(target=_lower,
args=(mod, target, params))
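The prepare_layout_rewrite / finish_layout_rewrite pair exported above is meant to wrap compilation roughly as sketched below; the target string and the way tunable ops are passed are placeholders, only the overall flow mirrors the code in this file:

    from tvm import ansor, relay

    # Record the kernel layouts chosen by the auto-scheduler for the tunable ops
    # (the `ops` value here is illustrative).
    ansor.prepare_layout_rewrite(mod, params, ops=[relay.op.get("nn.conv2d")], target="llvm")

    # Build with AlterOpLayout disabled, as in _lower() above, so the recorded
    # layouts are applied by the kernel_layout_transform pass instead.
    with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
        graph, lib, opt_params = relay.build_module.build(mod, target="llvm", params=params)

    # Clear the global layout-rewrite state.
    ansor.finish_layout_rewrite()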
13 changes: 9 additions & 4 deletions python/tvm/ansor/topi_integration.py
@@ -26,14 +26,17 @@
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
import os
import json
import tvm.te._ffi_api
from tvm import target as _target
from tvm.te import tensor
from tvm.te.tensor import PlaceholderOp, ComputeOp

from .dispatcher import DispatchContext
from .dispatcher import DispatchContext, BlockingEmptyContext
from .workload_registry import register_auto_scheduler_workload_bufs, \
make_workload_key_bufs, compute_dag_hash
from .compute_dag import ComputeDAG

def traverse_to_get_io_tensors(outs):
layout_free_ops = []
@@ -77,11 +80,14 @@ def __init__(self, do_layout_rewrite=False):
def __enter__(self):
self.tracing = True
self.wkl_key_collection = {}
self.relay_disable_build_cache_ = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE", "false")
os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true"

return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.tracing = False
os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache_

def reset(self, wanted_relay_ops=None):
"""Reset task collections
Expand Down Expand Up @@ -144,7 +150,7 @@ def get(do_layout_rewrite=False):
The single instance of TaskExtractEnv
"""
if not TaskExtractEnv.current:
TaskExtractEnv.current = TaskExtractEnv()
TaskExtractEnv.current = TaskExtractEnv(do_layout_rewrite)
else:
TaskExtractEnv.current.do_layout_rewrite = do_layout_rewrite
return TaskExtractEnv.current
@@ -188,7 +194,7 @@ def wrapper(outs, *args, **kwargs):
# Rewrite the dag and update the transform history for
# the new dag in DispatchContext
dispatch_ctx = DispatchContext.current
tgt = _target.current_target()
tgt = _target.Target.current()
state = dispatch_ctx.query(tgt, key)
dag = ComputeDAG(outs)
new_dag = dag.rewrite_layout_from_state(state)
@@ -199,7 +205,6 @@
task_env.layout_rewrite_success_ct += 1

# Call schedule_func under FallbackContext() to avoid layout rewrite
tgt = _target.Target.current()
cfg = BlockingEmptyContext().query(tgt, key)
return topi_schedule(cfg, outs)

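The __enter__/__exit__ change above snapshots TVM_RELAY_DISABLE_BUILD_CACHE, forces it to "true" during task extraction, and restores the old value afterwards. The same save-and-restore idea as a standalone context manager, purely as a sketch and not part of the commit:

    import os
    from contextlib import contextmanager

    @contextmanager
    def env_override(name, value):
        # Remember the previous value (None if the variable was unset).
        old = os.environ.get(name)
        os.environ[name] = value
        try:
            yield
        finally:
            if old is None:
                os.environ.pop(name, None)  # it was unset before, unset it again
            else:
                os.environ[name] = old      # restore the previous value

    # Usage, mirroring what TaskExtractEnv now does around task extraction:
    # with env_override("TVM_RELAY_DISABLE_BUILD_CACHE", "true"):
    #     ...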
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -74,6 +74,8 @@ def compute_strided_set(attrs, inputs, output_type):
# layout_transform
_reg.register_injective_schedule("layout_transform")
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
_reg.register_injective_schedule("kernel_layout_transform")
_reg.register_pattern("kernel_layout_transform", OpPattern.INJECTIVE)

# argwhere
@_reg.register_compute("argwhere")
3 changes: 3 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -261,6 +261,9 @@ class ClipAttrs(Attrs):
class LayoutTransformAttrs(Attrs):
"""Attributes for transform.layout_transform"""

@tvm._ffi.register_object("relay.attrs.KernelLayoutTransformAttrs")
class KernelLayoutTransformAttrs(Attrs):
"""Attributes for transform.kernel_layout_transform"""

@tvm._ffi.register_object("relay.attrs.ShapeOfAttrs")
class ShapeOfAttrs(Attrs):
21 changes: 21 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -815,6 +815,27 @@ def layout_transform(data, src_layout, dst_layout):
"""
return _make.layout_transform(data, src_layout, dst_layout)

def kernel_layout_transform(data, src_layout, dst_layout):
"""Transform the layout of a kernel
Parameters
----------
data : relay.Expr
The source tensor to be transformed
src_layout: str
The source layout. (e.g. 1N32C112H112W)
dst_layout: str
The destination layout. (e.g. 1N2C112H112W16c)
Returns
-------
ret : relay.Expr
The transformed tensor.
"""
return _make.kernel_layout_transform(data, src_layout, dst_layout)


def reverse_reshape(data, newshape):
"""Reshapes the input array where the special values are inferred from
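A small usage sketch of the new operator defined above; the layout strings are the examples from the docstring, and the tensor shape is chosen to match them:

    from tvm import relay

    data = relay.var("data", shape=(1, 32, 112, 112), dtype="float32")
    # Repack from the flat layout into one with a 16-wide inner channel block.
    packed = relay.kernel_layout_transform(data,
                                           src_layout="1N32C112H112W",
                                           dst_layout="1N2C112H112W16c")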
25 changes: 15 additions & 10 deletions python/tvm/relay/testing/dqn.py
@@ -26,27 +26,32 @@
from . import layers
from .init import create_workload

def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"):
def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"):
"""get symbol of nature dqn"""
data_shape = (batch_size,) + image_shape
data = relay.var("data", shape=data_shape, dtype=dtype)

bias_axis = layout.index('C')

conv1_bias = relay.var("conv1_bias")
conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
channels=32, name="conv1")
conv1 = relay.nn.bias_add(conv1, conv1_bias)
channels=32, name="conv1", data_layout=layout,
kernel_layout=layers.conv_kernel_layout(layout))
conv1 = relay.nn.bias_add(conv1, conv1_bias, bias_axis)
relu1 = relay.nn.relu(conv1)

conv2_bias = relay.var("conv2_bias")
conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0),
channels=64, name="conv2")
conv2 = relay.nn.bias_add(conv2, conv2_bias)
channels=64, name="conv2", data_layout=layout,
kernel_layout=layers.conv_kernel_layout(layout))
conv2 = relay.nn.bias_add(conv2, conv2_bias, bias_axis)
relu2 = relay.nn.relu(conv2)

conv3_bias = relay.var("conv3_bias")
conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0),
channels=64, name="conv3")
conv3 = relay.nn.bias_add(conv3, conv3_bias)
channels=64, name="conv3", data_layout=layout,
kernel_layout=layers.conv_kernel_layout(layout))
conv3 = relay.nn.bias_add(conv3, conv3_bias, bias_axis)
relu3 = relay.nn.relu(conv3)

bf1 = relay.nn.batch_flatten(relu3)
@@ -58,7 +63,7 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"
return relay.Function(args, dense2)


def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"):
def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32", layout="NCHW"):
"""Get benchmark workload for a Deep Q Network
Parameters
----------
@@ -72,10 +77,10 @@ def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="flo
The data type
Returns
-------
mod : tvm.IRModule
mod : tvm.relay.Module
The relay module that contains a DQN network.
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype)
net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype, layout=layout)
return create_workload(net)
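With the new layout argument, the DQN test workload can be requested directly in NHWC, matching how scripts/tune_network.py below now calls it:

    from tvm import relay

    # NHWC image shape (84, 84, 4) pairs with layout="NHWC", as in tune_network.py.
    mod, params = relay.testing.dqn.get_workload(
        batch_size=1, image_shape=(84, 84, 4), dtype="float32", layout="NHWC")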
4 changes: 4 additions & 0 deletions python/tvm/relay/testing/resnet.py
@@ -162,6 +162,8 @@ def resnet(units,
data = relay.var("data", shape=data_shape, dtype=dtype)
data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, name='bn_data')
(_, _, height, _) = data_shape
if layout == "NHWC":
(_, height, _, _) = data_shape
if height <= 32: # such as cifar10
body = layers.conv2d(
data=data, channels=filter_list[0], kernel_size=(3, 3),
@@ -209,6 +211,8 @@ def get_net(batch_size,
Original author Wei Wu
"""
(_, height, _) = image_shape
if layout == "NHWC":
(height, _, _) = image_shape
data_shape = (batch_size,) + image_shape
if height <= 28:
num_stages = 3
6 changes: 4 additions & 2 deletions python/tvm/te/tensor.py
@@ -57,8 +57,10 @@ class Tensor(DataProducer, _expr.ExprOp):

def __call__(self, *indices):
ndim = self.ndim
if len(indices) != ndim:
raise ValueError("Need to provide %d index in tensor slice" % ndim)
# After Ansor's kernel layout rewrite, len(indices) can be smaller than ndim,
# and the indices will be modified by Ansor during schedule generation.
# if len(indices) != ndim:
# raise ValueError("Need to provide %d index in tensor slice" % ndim)
indices = convert_to_object(indices)
args = []
for x in indices:
9 changes: 5 additions & 4 deletions scripts/tune_network.py
@@ -49,9 +49,10 @@ def get_network(name, model_path, batch_size, layout):
input_shape = (batch_size, 100)
mod, params = relay.testing.dcgan.get_workload(batch_size=batch_size)
elif name == 'dqn':
image_shape = (4, 84, 84)
layout = "NHWC"
image_shape = (84, 84, 4)
input_shape = (batch_size, *image_shape)
mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype)
mod, params = relay.testing.dqn.get_workload(batch_size=batch_size, image_shape=image_shape, dtype=dtype, layout=layout)
elif name == 'mobilenet':
image_shape = (224, 224, 3) if layout == 'NHWC' else (3, 224, 224)
input_shape = (batch_size, *image_shape)
@@ -229,7 +230,7 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune,
if measure_ctx:
del measure_ctx

kernel_layout_rewrite = False
kernel_layout_rewrite = False

# Compile graph with best states found by auto-scheduler
print("=============== Compile ===============")
@@ -245,7 +246,7 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune,
ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE
ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE

with relay.build_config(opt_level=3):
with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
graph, lib, opt_params = relay.build_module.build(
mod, target=target, params=params)

(The remaining changed files in this commit are not shown here.)