Skip to content

Commit

Permalink
[BYOC-DNNL] Support DNNL optimal layout (#10421)
Browse files Browse the repository at this point in the history
* enable dnnl optimal layout for supported ops

* verfied cv models with onednnv1.7

* rebase to the latest main branch

* fix format related comments

* remove unnecessary layout transformation

* change deconv into conv_transpose

* rename some variables and functions

* simplify query_layout

* add checkes for query_layout

* fix lint

* move partition_for_dnnl from dnnl.py to test_dnnl.py

* remove unnecessary model test

* add more dnnl layout

* rename flag in convolution.cc

* enhance dnnl layout
  • Loading branch information
crazydemo authored Mar 7, 2022
1 parent 174d09e commit 12f213a
Show file tree
Hide file tree
Showing 7 changed files with 990 additions and 174 deletions.
246 changes: 207 additions & 39 deletions python/tvm/relay/op/contrib/dnnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import logging

import tvm.ir
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm import relay

from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table

Expand Down Expand Up @@ -94,12 +94,12 @@ def _func_wrapper(expr):


def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None):
"""Create patterns related to conv and deconv.
"""Create patterns related to conv and conv_transpose.
Parameters
----------
with_bias : bool
Whether attach `bias_add` to `conv / deconv`.
Whether attach `bias_add` to `conv / conv_transpose`.
with_eltwise : str
The attached elementwise post-op name.
Returns
Expand Down Expand Up @@ -147,12 +147,12 @@ def make_dense_pattern(with_bias=True, with_eltwise=None):
return dense_out


def make_dnnl_pattern(op, with_bias, with_eltwise):
def make_dnnl_pattern(op_name, with_bias, with_eltwise):
"""Create dnnl patterns.
Parameters
----------
op : str
op_name : str
The first call node's op name.
with_bias : bool
Whether attach `bias_add` to `nn.dense`.
Expand All @@ -163,18 +163,20 @@ def make_dnnl_pattern(op, with_bias, with_eltwise):
pattern : Tuple(pattern_name, CallPattern)
Created pattern name, along with its CallPattern.
"""
pat_name = op.replace("nn", "dnnl")
pat_name = op_name.replace("nn", "dnnl")
if "_transpose" in op_name:
pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::]
pat_name += "_bias" if with_bias else ""
pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else ""
if "conv" in op:
dnnl_pattern = (pat_name, make_conv_pattern(op, with_bias, with_eltwise))
elif op == "nn.dense":
if "conv" in op_name:
dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise))
elif op_name == "nn.dense":
dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise))
else:
logger.warning(
"Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and "
"dense op are supported, but got %s.",
op,
op_name,
)
dnnl_pattern = ()
return dnnl_pattern
Expand Down Expand Up @@ -207,39 +209,205 @@ def pattern_table():
return dnnl_patterns


def partition_for_dnnl(mod, params=None):
"""Partition the graph greedily offloading supported operators to DNNL.
def get_optimal_layout_for_conv(
data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
):
"""Get the optimal layout of dnnl, given shape of conv2d.
Parameters
----------
mod : Module
The module to run passes on.
params : Optional[Dict[str, NDArray]]
Constant input parameters.
data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups
: String
Input argument.
Returns
-------
mod : Module
Annotated and partitioned module.
layouts : string
The result.
"""
return _ffi_api.get_optimal_layout_for_conv(
data_layout,
kernel_layout,
weight_shape,
out_shape,
paddings,
strides,
dilates,
groups,
)


def get_optimal_layout_for_conv_transpose(
data_layout,
kernel_layout,
weight_shape,
out_shape,
paddings,
output_paddings,
strides,
dilates,
groups,
):
"""Get the optimal layout of dnnl, given shape of tranposed conv2d.
Parameters
----------
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
dilates, groups
: Int, String
Input argument.
Returns
-------
layouts : string
The result.
"""
return _ffi_api.get_optimal_layout_for_conv_transpose(
data_layout,
kernel_layout,
weight_shape,
out_shape,
paddings,
output_paddings,
strides,
dilates,
groups,
)


def get_shape(tensor):
"""Get tensor's shape."""
if isinstance(tensor, relay.expr.Var):
return tensor.type_annotation.concrete_shape
if isinstance(tensor, relay.expr.Constant):
return tensor.data.shape
if isinstance(tensor, tvm.ir.tensor_type.TensorType):
return tensor.concrete_shape
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].shape
if isinstance(tensor, relay.expr.Call):
return tensor.checked_type.shape
raise TypeError("Unsupport data type: %s" % type(tensor))


if params:
mod["main"] = bind_params_by_name(mod["main"], params)
seq = tvm.transform.Sequential(
[
transform.CanonicalizeOps(),
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
# fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu`
transform.SimplifyExpr(),
transform.FoldConstant(),
transform.MergeComposite(pattern_table()),
transform.AnnotateTarget("dnnl"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
]
def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
"""Transfer layout, denoted with `a, b, c, d, e`,
into valid layout (NCHW / OIHW) of TVM."""
if "Conv1D" in conv_type:
data_dic = {"a": "N", "b": "C", "c": "W"}
weight_dic = {"a": "O", "b": "I", "c": "W", "d": "G"}
elif "Conv2D" in conv_type:
data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"}
weight_dic = {"a": "O", "b": "I", "c": "H", "d": "W"}
if "e" in input_data:
weight_dic = {"a": "G", "b": "O", "c": "I", "d": "H", "e": "W"}
elif "Conv3D" in conv_type:
data_dic = {"a": "N", "b": "C", "c": "D", "d": "H", "e": "W"}
weight_dic = {"a": "O", "b": "I", "c": "D", "d": "H", "e": "W", "f": "G"}

dic = weight_dic if is_weight else data_dic
res = ""

for i in input_data:
if i.isupper():
i = i.lower()
res += dic[i]
dic[i] = dic[i].lower()
elif i.islower():
res += dic[i]
elif i.isdigit():
res += i
else:
raise ValueError("Unsupport layout format: %s" % input_data)
return res


def legalize_group_conv(attrs, inputs, types):
"""Legalize group conv / conv_transpose calculation.
Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
groups = attrs.groups
data, weight = inputs
if groups == 1:
if "Transpose" not in type(attrs).__name__:
return relay.nn.conv2d(data, weight, **attrs)
return relay.nn.conv2d_transpose(data, weight, **attrs)
OC, IC, H, W = get_shape(weight)
new_attrs = dict(attrs)
weight = relay.reshape(weight, (groups, OC // groups, IC, H, W))
if "Transpose" not in type(attrs).__name__:
new_attrs["kernel_layout"] = "GOIHW"
return relay.nn.conv2d(data, weight, **new_attrs)
new_attrs["kernel_layout"] = "GIOHW"
return relay.nn.conv2d_transpose(data, weight, **new_attrs)


def alter_conv(attrs, inputs, tinfos, out_type):
"""The convolution's layout auto-query func for dnnl."""

data, weight = inputs
groups = str(attrs.groups)
weight_shape = ",".join([str(x) for x in get_shape(weight)])
out_shape = ",".join([str(x) for x in get_shape(out_type)])
paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")])
strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
new_attrs = dict(attrs)
conv_type = type(attrs).__name__.split("Attrs")[0]

res = get_optimal_layout_for_conv(
attrs["data_layout"],
attrs["kernel_layout"],
weight_shape,
out_shape,
paddings,
strides,
dilates,
groups,
)
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
return mod
src_df, weight_df, dst_df = res.split(",")
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)

if conv_type == "Conv1D":
return relay.nn.conv1d(data, weight, **new_attrs)
if conv_type == "Conv2D":
return relay.nn.conv2d(data, weight, **new_attrs)
return relay.nn.conv3d(data, weight, **new_attrs)


def alter_conv_transpose(attrs, inputs, tinfos, out_type):
"""The transposed convolution's layout auto-query func for dnnl."""

data, weight = inputs
weight_shape = ",".join([str(x) for x in get_shape(weight)])
out_shape = ",".join([str(x) for x in get_shape(out_type)])
paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")])
output_paddings = ",".join([str(x) for x in attrs.get_int_tuple("output_padding")])
strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")])
dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")])
groups = str(attrs.groups)
new_attrs = dict(attrs)
conv_type = type(attrs).__name__.split("Attrs")[0]

res = get_optimal_layout_for_conv_transpose(
attrs["data_layout"],
attrs["kernel_layout"],
weight_shape,
out_shape,
paddings,
output_paddings,
strides,
dilates,
groups,
)
src_df, weight_df, dst_df = res.split(",")
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)

if conv_type == "Conv1DTranspose":
return relay.nn.conv1d_transpose(data, weight, **new_attrs)
if conv_type == "Conv2DTranspose":
return relay.nn.conv2d_transpose(data, weight, **new_attrs)
return relay.nn.conv3d_transpose(data, weight, **new_attrs)
42 changes: 29 additions & 13 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,14 +445,30 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
{"relu", "nn.relu"},
{"tanh", "tanh"},
{"sigmoid", "sigmoid"},
{"nn.deconv2d", "nn.conv2d_transpose"},
{"nn.deconv3d", "nn.conv3d_transpose"},
};

std::vector<std::string> ParsingOpList(std::string op, std::string pattern_name) {
std::vector<std::string> op_list = {"nn." + op};
for (auto& t : op_map) {
if (pattern_name.find(t.first) != std::string::npos) {
op_list.push_back(t.second);
std::vector<std::string> ParsingOpList(const std::string& pattern_name,
std::string interval = "_") {
ICHECK_NE(pattern_name, "");
std::vector<std::string> op_list;
size_t pos = 0, start = 0;
while ((pos = pattern_name.find(interval, start)) != std::string::npos) {
std::string op_name = pattern_name.substr(start, pos - start);
if (op_name.find("dnnl") != std::string::npos) {
op_name.replace(op_name.find("dnnl"), 4, "nn");
if (op_name.find("deconv") != std::string::npos) {
op_name = op_map[op_name];
}
} else {
op_name = op_map[op_name];
}
if (pos > start) op_list.push_back(op_name);
start = pos + interval.size();
}
if (pattern_name.size() > start) {
op_list.push_back(op_map[pattern_name.substr(start)]);
}
return op_list;
}
Expand All @@ -471,28 +487,28 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer {
ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions.";
name = comp.value();

if (name.find("dnnl.conv2d_transpose") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList("conv2d_transpose", name);
if (name.find("dnnl.deconv2d") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList(name);
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.conv3d_transpose") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList("conv3d_transpose", name);
} else if (name.find("dnnl.deconv3d") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList(name);
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.conv1d") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList("conv1d", name);
std::vector<std::string> op_list = ParsingOpList(name);
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.conv2d") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList("conv2d", name);
std::vector<std::string> op_list = ParsingOpList(name);
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.conv3d") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList("conv3d", name);
std::vector<std::string> op_list = ParsingOpList(name);
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else if (name.find("dnnl.dense") != std::string::npos) {
std::vector<std::string> op_list = ParsingOpList("dense", name);
std::vector<std::string> op_list = ParsingOpList(name);
call = GetRootCall(fn->body.as<CallNode>(), op_list.size() - 1, op_list);
ICHECK(call->op.as<OpNode>()) << "Not op node";
} else {
Expand Down
Loading

0 comments on commit 12f213a

Please sign in to comment.