diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py
index 06dd3c0310644..c51de23c19aa5 100644
--- a/python/tvm/relay/op/contrib/dnnl.py
+++ b/python/tvm/relay/op/contrib/dnnl.py
@@ -38,6 +38,7 @@
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.testing.temp_op_attr import TempOpAttr
 from ... import _ffi_api
 from ...dataflow_pattern import wildcard, is_op
 
@@ -211,14 +212,16 @@ def pattern_table():
     return dnnl_patterns
 
 
-def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
-                                out_shape, paddings, strides, dilates, groups):
+def get_optimal_layout_for_conv(
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups
+):
     """Get the optimal layout of dnnl, given shape of conv2d.
 
     Parameters
     ----------
-    data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups :String
-        Input argument.
+    data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups
+    : String
+        Input argument.
 
     Returns
     -------
@@ -238,13 +241,22 @@ def get_optimal_layout_for_conv(data_layout, kernel_layout, weight_shape,
     )
 
 
 def get_optimal_layout_for_conv_transpose(
-    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
+    data_layout,
+    kernel_layout,
+    weight_shape,
+    out_shape,
+    paddings,
+    output_paddings,
+    strides,
+    dilates,
+    groups,
 ):
     """Get the optimal layout of dnnl, given shape of tranposed conv2d.
 
     Parameters
     ----------
-    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, dilates, groups
+    data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides,
+    dilates, groups
     : Int, String
         Input argument.
@@ -255,7 +267,7 @@ def get_optimal_layout_for_conv_transpose(
     """
     return _ffi_api.get_optimal_layout_for_conv_transpose(
         data_layout,
-        kernel_layout, 
+        kernel_layout,
         weight_shape,
         out_shape,
         paddings,
@@ -270,16 +282,15 @@ def get_shape(tensor):
     """Get tensor's shape."""
     if isinstance(tensor, relay.expr.Var):
         return tensor.type_annotation.concrete_shape
-    elif isinstance(tensor, relay.expr.Constant):
+    if isinstance(tensor, relay.expr.Constant):
         return tensor.data.shape
-    elif isinstance(tensor, tvm.ir.tensor_type.TensorType):
+    if isinstance(tensor, tvm.ir.tensor_type.TensorType):
         return tensor.concrete_shape
-    elif isinstance(tensor, tvm.ir.container.Array):
+    if isinstance(tensor, tvm.ir.container.Array):
         return tensor[-1].shape
-    elif isinstance(tensor, relay.expr.Call):
+    if isinstance(tensor, relay.expr.Call):
         return tensor.checked_type.shape
-    else:
-        raise TypeError("Unsupport data type: %s" % type(tensor))
+    raise TypeError("Unsupport data type: %s" % type(tensor))
 
 
 def tag2layout(input_data, is_weight=False, conv_type="Conv1D"):
@@ -318,18 +329,19 @@ def legalize_group_conv(attrs, inputs, types):
     """Legalize group conv / conv_transpose calculation.
     Alter weight layout from OIHW to GOIHW / IOHW to GIOHW"""
     groups = attrs.groups
-    if groups == 1:
-        return
     data, weight = inputs
+    if groups == 1:
+        if "Transpose" not in type(attrs).__name__:
+            return relay.nn.conv2d(data, weight, **attrs)
+        return relay.nn.conv2d_transpose(data, weight, **attrs)
     OC, IC, H, W = get_shape(weight)
     new_attrs = dict(attrs)
     weight = relay.reshape(weight, (groups, OC // groups, IC, H, W))
     if "Transpose" not in type(attrs).__name__:
         new_attrs["kernel_layout"] = "GOIHW"
         return relay.nn.conv2d(data, weight, **new_attrs)
-    else:
-        new_attrs["kernel_layout"] = "GIOHW"
-        return relay.nn.conv2d_transpose(data, weight, **new_attrs)
+    new_attrs["kernel_layout"] = "GIOHW"
+    return relay.nn.conv2d_transpose(data, weight, **new_attrs)
 
 
 def alter_conv(attrs, inputs, tinfos, out_type):
@@ -346,22 +358,25 @@ def alter_conv(attrs, inputs, tinfos, out_type):
     conv_type = type(attrs).__name__.split("Attrs")[0]
 
     res = get_optimal_layout_for_conv(
-        attrs["data_layout"], attrs["kernel_layout"], weight_shape, out_shape, paddings,
-        strides, dilates, groups,
+        attrs["data_layout"],
+        attrs["kernel_layout"],
+        weight_shape,
+        out_shape,
+        paddings,
+        strides,
+        dilates,
+        groups,
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
-    new_attrs["kernel_layout"] = tag2layout(
-        weight_df, is_weight=True, conv_type=conv_type
-    )
+    new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
     new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)
 
     if conv_type == "Conv1D":
         return relay.nn.conv1d(data, weight, **new_attrs)
-    elif conv_type == "Conv2D":
+    if conv_type == "Conv2D":
         return relay.nn.conv2d(data, weight, **new_attrs)
-    elif conv_type == "Conv3D":
-        return relay.nn.conv3d(data, weight, **new_attrs)
+    return relay.nn.conv3d(data, weight, **new_attrs)
 
 
 def alter_conv_transpose(attrs, inputs, tinfos, out_type):
@@ -380,7 +395,7 @@ def alter_conv_transpose(attrs, inputs, tinfos, out_type):
 
     res = get_optimal_layout_for_conv_transpose(
         attrs["data_layout"],
-        attrs["kernel_layout"], 
+        attrs["kernel_layout"],
         weight_shape,
         out_shape,
         paddings,
@@ -391,17 +406,14 @@
     )
     src_df, weight_df, dst_df = res.split(",")
     new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type)
-    new_attrs["kernel_layout"] = tag2layout(
-        weight_df, is_weight=True, conv_type=conv_type
-    )
+    new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type)
     new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type)
 
     if conv_type == "Conv1DTranspose":
         return relay.nn.conv1d_transpose(data, weight, **new_attrs)
-    elif conv_type == "Conv2DTranspose":
+    if conv_type == "Conv2DTranspose":
         return relay.nn.conv2d_transpose(data, weight, **new_attrs)
-    elif conv_type == "Conv3DTranspose":
-        return relay.nn.conv3d_transpose(data, weight, **new_attrs)
+    return relay.nn.conv3d_transpose(data, weight, **new_attrs)
 
 
 def partition_for_dnnl(mod, params=None, alter_layout=True):
@@ -418,10 +430,8 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
     mod : Module
         Annotated and partitioned module.
     """
-
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
-    from tvm.relay.testing.temp_op_attr import TempOpAttr
 
     with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_group_conv):
         with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", legalize_group_conv):
@@ -443,8 +453,6 @@ def partition_for_dnnl(mod, params=None, alter_layout=True):
     with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     if alter_layout:
-        from tvm.relay.testing.temp_op_attr import TempOpAttr
-
         with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", alter_conv):
             with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
                 with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", alter_conv):
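Note on the `legalize_group_conv` hunk above: this is the one behavioral change among the Python fixes. Previously `if groups == 1: return` handed back `None` (i.e. "leave the call untouched") before `data, weight = inputs` ever ran; the hook now rebuilds an explicit `conv2d`/`conv2d_transpose` for the non-grouped case, so both paths go through the same reconstruction. A minimal sketch of driving this file end to end — assuming a TVM build with the DNNL BYOC backend registered; the network and shapes are illustrative, not taken from this patch:

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

# Depthwise conv2d: groups == channels, so legalize_group_conv reshapes the
# OIHW weight to the GOIHW layout that grouped DNNL kernels expect.
data = relay.var("data", shape=(1, 32, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(32, 1, 3, 3), dtype="float32")
conv = relay.nn.conv2d(
    data, weight, channels=32, groups=32, kernel_size=(3, 3), padding=(1, 1)
)
mod = tvm.IRModule.from_expr(conv)

# partition_for_dnnl installs legalize_group_conv / alter_conv /
# alter_conv_transpose via the TempOpAttr hooks shown above, then carves
# out DNNL subgraphs.
mod = partition_for_dnnl(mod, alter_layout=True)
print(mod)
```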
""" - if params: mod["main"] = bind_params_by_name(mod["main"], params) - from tvm.relay.testing.temp_op_attr import TempOpAttr with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_group_conv): with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", legalize_group_conv): @@ -443,8 +453,6 @@ def partition_for_dnnl(mod, params=None, alter_layout=True): with tvm.transform.PassContext(opt_level=3): mod = seq(mod) if alter_layout: - from tvm.relay.testing.temp_op_attr import TempOpAttr - with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", alter_conv): with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv): with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", alter_conv): diff --git a/src/relay/backend/contrib/dnnl/query_layout.cc b/src/relay/backend/contrib/dnnl/query_layout.cc index b1a392bfca1fa..eaea79c62d7a4 100755 --- a/src/relay/backend/contrib/dnnl/query_layout.cc +++ b/src/relay/backend/contrib/dnnl/query_layout.cc @@ -146,8 +146,7 @@ std::string md2fmt_tag_str(const dnnl::memory::desc* md) { return s; } -dnnl::memory::dims str2dims(const std::string& str_shape, - bool dilates = false, +dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false, std::string interval = ",") { // Split strings std::vector str_dims; @@ -164,10 +163,10 @@ dnnl::memory::dims str2dims(const std::string& str_shape, dnnl::memory::dims out_dims; if (dilates) { std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims), - [](const std::string& str) { return std::stoi(str) - 1; }); + [](const std::string& str) { return std::stoi(str) - 1; }); } else { std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims), - [](const std::string& str) { return std::stoi(str); }); + [](const std::string& str) { return std::stoi(str); }); } return out_dims; } @@ -175,10 +174,10 @@ dnnl::memory::dims str2dims(const std::string& str_shape, void check_shapes(const std::vector shapes) { std::regex valid_pat("(\\d*)(,(\\d*))*"); bool checked = std::regex_match(shapes[0], valid_pat); - for (size_t i = 1; i < shapes.size()-1; i++) { + for (size_t i = 1; i < shapes.size() - 1; i++) { checked &= std::regex_match(shapes[i], valid_pat); } - checked &= std::regex_match(shapes[shapes.size()-1], std::regex("\\d*")); + checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*")); if (!checked) { LOG(FATAL) << "Invalid input args for query dnnl optimal layout."; } @@ -193,7 +192,7 @@ void check_layout(bool var, bool ref) { std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout, std::string weight_shape, std::string out_shape, std::string paddings, std::string strides, - std::string dilates, std::string G) { + std::string dilates, std::string G) { check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true); check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true); check_shapes({weight_shape, out_shape, paddings, strides, dilates, G}); @@ -272,7 +271,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker return res; } -std::string get_optimal_layout_for_conv_transpose(std::string data_layout, std::string kernel_layout, +std::string get_optimal_layout_for_conv_transpose(std::string data_layout, + std::string kernel_layout, std::string weight_shape, std::string out_shape, std::string paddings, std::string output_paddings, std::string strides, std::string dilates, diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc 
diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
index 26144c4f8f8fc..951f8ffdb52e4 100644
--- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
+++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc
@@ -505,8 +505,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
                                              padding_dims_l, padding_dims_r);
 
     // Enable elementwise post-ops.
-    auto deconv_prim_desc =
-        dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);
+    auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_);
 
     // Push to the network.
     auto deconv = dnnl::deconvolution_forward(deconv_prim_desc);
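The runtime hunk is formatting only, but it sits on the deconvolution path that the partitioning above feeds, where `attr` carries the fused elementwise post-ops. A sketch that exercises that path, assuming TVM was built with `USE_DNNL`; the shapes, and the explicit IOHW kernel layout echoing the legalization docstring above, are illustrative:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from tvm.relay.op.contrib.dnnl import partition_for_dnnl

# conv2d_transpose + relu: DNNL folds the ReLU into the deconvolution as an
# elementwise post-op, which lands in the primitive_desc built above.
data = relay.var("data", shape=(1, 16, 14, 14), dtype="float32")
weight = relay.var("weight", shape=(16, 8, 3, 3), dtype="float32")
body = relay.nn.relu(
    relay.nn.conv2d_transpose(
        data, weight, channels=8, kernel_size=(3, 3), strides=(2, 2), kernel_layout="IOHW"
    )
)
mod = partition_for_dnnl(tvm.IRModule.from_expr(body))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")

rt = graph_executor.GraphModule(lib["default"](tvm.cpu()))
rt.set_input("data", np.random.rand(1, 16, 14, 14).astype("float32"))
rt.set_input("weight", np.random.rand(16, 8, 3, 3).astype("float32"))
rt.run()
out = rt.get_output(0)
```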