diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index a38603666dd5..2bd486cb8671 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -126,13 +126,16 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {"GIODHW", tag::giodhw}, // Blocking layout. + {"NCW8c", tag::nCw8c}, {"NCW16c", tag::nCw16c}, {"OIW16i16o", tag::OIw16i16o}, + {"OWI8o", tag::Owi8o}, {"OWI16o", tag::Owi16o}, {"NCHW4c", tag::nChw4c}, {"NCHW8c", tag::nChw8c}, {"NCHW16c", tag::nChw16c}, {"OIHW8i8o", tag::OIhw8i8o}, + {"IOHW8i8o", tag::any}, {"OIHW16i16o", tag::OIhw16i16o}, {"IOHW16i16o", tag::IOhw16i16o}, {"GOIHW4i4o", tag::gOIhw4i4o}, @@ -145,9 +148,12 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {"OHWI64o", tag::Ohwi64o}, {"GOIHW8g", tag::Goihw8g}, {"GOIHW16g", tag::Goihw16g}, + {"NCDHW8c", tag::nCdhw8c}, {"NCDHW16c", tag::nCdhw16c}, {"OIDHW16i16o", tag::OIdhw16i16o}, {"IODHW16i16o", tag::IOdhw16i16o}, + {"OIDHW8i8o", tag::OIdhw8i8o}, + {"IODHW8i8o", tag::any}, {"ODHWI16o", tag::Odhwi16o}, }; @@ -391,7 +397,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto conv_src_memory = BindDNNLMemory(data_entry, conv_src_md); // Weight memory. - auto conv_weights_memory = BindDNNLMemory(weight_entry, conv_weights_md); + auto conv_weights_memory = BindDNNLMemory(weight_entry, conv_prim_desc.weights_desc()); // Output memory. auto conv_dst_memory = BindDNNLMemory(out_entry, conv_prim_desc.dst_desc()); @@ -448,7 +454,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { // Check layout. if (layout_dict.find(data_layout) == layout_dict.end() || layout_dict.find(kernel_layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported layout: " << data_layout << " " << kernel_layout; + LOG(FATAL) << "Unsupported layout for deconv: " << data_layout << " " << kernel_layout; } // Memory shapes. @@ -514,7 +520,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto deconv_src_memory = BindDNNLMemory(data_entry, deconv_src_md); // Weight memory. - auto deconv_weights_memory = BindDNNLMemory(weight_entry, deconv_weights_md); + auto deconv_weights_memory = BindDNNLMemory(weight_entry, deconv_prim_desc.weights_desc()); // Output memory. auto deconv_dst_memory = BindDNNLMemory(out_entry, deconv_prim_desc.dst_desc());