fix lint
crazydemo committed Mar 2, 2022
1 parent e0c78cc commit 56ef1a1
Showing 3 changed files with 54 additions and 47 deletions.
82 changes: 45 additions & 37 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -38,6 +38,7 @@
from tvm import relay
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing.temp_op_attr import TempOpAttr

from ... import _ffi_api
from ...dataflow_pattern import wildcard, is_op
@@ -211,14 +212,16 @@ def pattern_table():
return dnnl_patterns


def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
out_shape, paddings, strides, dilates, groups):
def get_optimal_layout_for_conv(
data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
):
"""Get the optimal layout of dnnl, given shape of conv2d.
Parameters
----------
data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups :String
Input argument.
data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups
: String
Input argument.
Returns
-------
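For orientation, a minimal usage sketch of the query helper above (not part of the diff; every argument value below is illustrative, and the comma-separated string convention is the one validated by check_shapes() in query_layout.cc further down):

# Hypothetical call: shapes, paddings, strides, dilates and groups are passed as
# comma-separated integer strings, and the result is a string of three comma-separated
# DNNL format tags that alter_conv() later splits and maps back to Relay layouts.
res = get_optimal_layout_for_conv(
    "NCHW", "OIHW",                # data and kernel layouts
    "16,3,3,3", "1,16,224,224",    # weight shape and output shape
    "1,1,1,1", "1,1", "1,1", "1",  # paddings, strides, dilates, groups
)
data_tag, weight_tag, out_tag = res.split(",")
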
@@ -238,13 +241,22 @@ def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,


def get_optimal_layout_for_conv_transpose(
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
data_layout,
kernel_layout,
weight_shape,
out_shape,
paddings,
output_paddings,
strides,
dilates,
groups,
):
"""Get the optimal layout of dnnl, given shape of tranposed conv2d.
Parameters
----------
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
dilates, groups
: Int, String
Input argument.
@@ -255,7 +267,7 @@ def get_optimal_layout_for_conv_transpose(
"""
return _ffi_api.get_optimal_layout_for_conv_transpose(
data_layout,
kernel_layout,
kernel_layout,
weight_shape,
out_shape,
paddings,
@@ -270,16 +282,15 @@ def get_shape(tensor):
"""Get tensor's shape."""
if isinstance(tensor, relay.expr.Var):
return tensor.type_annotation.concrete_shape
elif isinstance(tensor, relay.expr.Constant):
if isinstance(tensor, relay.expr.Constant):
return tensor.data.shape
elif isinstance(tensor, tvm.ir.tensor_type.TensorType):
if isinstance(tensor, tvm.ir.tensor_type.TensorType):
return tensor.concrete_shape
elif isinstance(tensor, tvm.ir.container.Array):
if isinstance(tensor, tvm.ir.container.Array):
return tensor[-1].shape
elif isinstance(tensor, relay.expr.Call):
if isinstance(tensor, relay.expr.Call):
return tensor.checked_type.shape
else:
raise TypeError("Unsupport data type: %s" % type(tensor))
raise TypeError("Unsupport data type: %s" % type(tensor))


def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
@@ -318,18 +329,19 @@ def legalize_group_conv(attrs, inputs, types):
"""Legalize group conv / conv_transpose calculation.
Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
groups = attrs.groups
if groups == 1:
return
data, weight = inputs
if groups == 1:
if "Transpose" not in type(attrs).__name__:
return relay.nn.conv2d(data, weight, **attrs)
return relay.nn.conv2d_transpose(data, weight, **attrs)
OC, IC, H, W = get_shape(weight)
new_attrs = dict(attrs)
weight = relay.reshape(weight, (groups, OC // groups, IC, H, W))
if "Transpose" not in type(attrs).__name__:
new_attrs["kernel_layout"] = "GOIHW"
return relay.nn.conv2d(data, weight, **new_attrs)
else:
new_attrs["kernel_layout"] = "GIOHW"
return relay.nn.conv2d_transpose(data, weight, **new_attrs)
new_attrs["kernel_layout"] = "GIOHW"
return relay.nn.conv2d_transpose(data, weight, **new_attrs)


def alter_conv(attrs, inputs, tinfos, out_type):
@@ -346,22 +358,25 @@ def alter_conv(attrs, inputs, tinfos, out_type):
conv_type = type(attrs).__name__.split("Attrs")[0]

res = get_optimal_layout_for_conv(
attrs["data_layout"], attrs["kernel_layout"], weight_shape, out_shape, paddings,
strides, dilates, groups,
attrs["data_layout"],
attrs["kernel_layout"],
weight_shape,
out_shape,
paddings,
strides,
dilates,
groups,
)
src_df, weight_df, dst_df = res.split(",")
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
new_attrs["kernel_layout"] = tag2layout(
weight_df, is_weight=True, conv_type=conv_type
)
new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)

if conv_type == "Conv1D":
return relay.nn.conv1d(data, weight, **new_attrs)
elif conv_type == "Conv2D":
if conv_type == "Conv2D":
return relay.nn.conv2d(data, weight, **new_attrs)
elif conv_type == "Conv3D":
return relay.nn.conv3d(data, weight, **new_attrs)
return relay.nn.conv3d(data, weight, **new_attrs)
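
The string contract assumed here, shown with made-up tag values (not part of the diff): the FFI query returns a single comma-separated triple of DNNL format tags, and tag2layout converts each tag into a Relay layout string.

res = "nChw16c,OIhw16i16o,nChw16c"          # hypothetical query result
src_df, weight_df, dst_df = res.split(",")  # data, weight and output format tags
# tag2layout would then map e.g. "nChw16c" to a blocked Relay layout such as "NCHW16c"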


def alter_conv_transpose(attrs, inputs, tinfos, out_type):
@@ -380,7 +395,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):

res = get_optimal_layout_for_conv_transpose(
attrs["data_layout"],
attrs["kernel_layout"],
attrs["kernel_layout"],
weight_shape,
out_shape,
paddings,
@@ -391,17 +406,14 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
)
src_df, weight_df, dst_df = res.split(",")
new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
new_attrs["kernel_layout"] = tag2layout(
weight_df, is_weight=True, conv_type=conv_type
)
new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)

if conv_type == "Conv1DTranspose":
return relay.nn.conv1d_transpose(data, weight, **new_attrs)
elif conv_type == "Conv2DTranspose":
if conv_type == "Conv2DTranspose":
return relay.nn.conv2d_transpose(data, weight, **new_attrs)
elif conv_type == "Conv3DTranspose":
return relay.nn.conv3d_transpose(data, weight, **new_attrs)
return relay.nn.conv3d_transpose(data, weight, **new_attrs)


def partition_for_dnnl(mod, params=None, alter_layout=True):
@@ -418,10 +430,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
mod : Module
Annotated and partitioned module.
"""

if params:
mod["main"] = bind_params_by_name(mod["main"], params)
from tvm.relay.testing.temp_op_attr import TempOpAttr

with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_group_conv):
with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", legalize_group_conv):
@@ -443,8 +453,6 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
if alter_layout:
from tvm.relay.testing.temp_op_attr import TempOpAttr

with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", alter_conv):
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", alter_conv):
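End to end, the partitioning entry point above is typically used like this (a minimal sketch; mod and params stand for whatever Relay module and parameter dict you are compiling, and the import path is assumed from this file):

from tvm.relay.op.contrib.dnnl import partition_for_dnnl  # assumed import path

# mod, params: a Relay IRModule and its parameter dict, e.g. from a relay.frontend importer
mod = partition_for_dnnl(mod, params, alter_layout=True)
# Grouped conv weights are legalized, optimal DNNL layouts are queried and applied,
# and supported subgraphs are annotated and partitioned for the "dnnl" codegen.
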
16 changes: 8 additions & 8 deletions src/relay/backend/contrib/dnnl/query_layout.cc
@@ -146,8 +146,7 @@ std::string md2fmt_tag_str(const dnnl::memory::desc* md) {
return s;
}

dnnl::memory::dims str2dims(const std::string& str_shape,
bool dilates = false,
dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false,
std::string interval = ",") {
// Split strings
std::vector<std::string> str_dims;
@@ -164,21 +163,21 @@ dnnl::memory::dims str2dims(const std::string& str_shape,
dnnl::memory::dims out_dims;
if (dilates) {
std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims),
[](const std::string& str) { return std::stoi(str) - 1; });
[](const std::string& str) { return std::stoi(str) - 1; });
} else {
std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims),
[](const std::string& str) { return std::stoi(str); });
[](const std::string& str) { return std::stoi(str); });
}
return out_dims;
}

void check_shapes(const std::vector<std::string> shapes) {
std::regex valid_pat("(\\d*)(,(\\d*))*");
bool checked = std::regex_match(shapes[0], valid_pat);
for (size_t i = 1; i < shapes.size()-1; i++) {
for (size_t i = 1; i < shapes.size() - 1; i++) {
checked &= std::regex_match(shapes[i], valid_pat);
}
checked &= std::regex_match(shapes[shapes.size()-1], std::regex("\\d*"));
checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*"));
if (!checked) {
LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
}
@@ -193,7 +192,7 @@ void check_layout(bool var, bool ref) {
std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout,
std::string weight_shape, std::string out_shape,
std::string paddings, std::string strides,
std::string dilates, std::string G) {
std::string dilates, std::string G) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});
@@ -272,7 +271,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout,
return res;
}

std::string get_optimal_layout_for_conv_transpose(std::string data_layout, std::string kernel_layout,
std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
std::string kernel_layout,
std::string weight_shape, std::string out_shape,
std::string paddings, std::string output_paddings,
std::string strides, std::string dilates,
3 changes: 1 addition & 2 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -505,8 +505,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
padding_dims_l, padding_dims_r);

// Enable elementwise post-ops.
auto deconv_prim_desc =
dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);
auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);

// Push to the network.
auto deconv = dnnl::deconvolution_forward(deconv_prim_desc);