diff --git a/tests/test_backend.py b/tests/test_backend.py old mode 100644 new mode 100755 index 8bf466ea7..9e00f9e56 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2830,7 +2830,7 @@ def func(x): self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04) @check_opset_min_version(7, "batchnorm") - @check_tf_min_version("2.0", "tf-1.x does not support NDHWC") + @check_tf_min_version("2.4", "tf versions below 2.4 do not support NDHWC") def test_fused_batchnorm_3d(self): x_shape = [1, 28, 28, 2, 2] x_dtype = np.float32 diff --git a/tf2onnx/rewriter/conv2d_with_add_rewriter.py b/tf2onnx/rewriter/conv2d_with_add_rewriter.py index aa941d75b..5fd516935 --- a/tf2onnx/rewriter/conv2d_with_add_rewriter.py +++ b/tf2onnx/rewriter/conv2d_with_add_rewriter.py @@ -13,31 +13,39 @@ # pylint: disable=missing-docstring def rewrite_biasadd_with_conv2d(g, ops): - pattern = \ + pattern1 = \ OpTypePattern('BiasAdd', name='biasadd', inputs=[ OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*']) - matcher = GraphMatcher(pattern) - match_results = list(matcher.match_ops(ops)) - for match in match_results: - biasadd = match.get_op('biasadd') - conv = match.get_op('conv') - - #backup the conv and biasadd values - conv_type = conv.type - conv_input = conv.input - conv_attr = conv.attr - dtype = g.get_dtype(conv.output[0]) - shape = g.get_shape(conv.output[0]) - conv_name = biasadd.name - conv_output = biasadd.output - conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]] - - if len(g.find_output_consumers(conv.output[0])) > 1: - continue - # Remove the Conv and BiasAdd node - g.remove_node(conv.name) - g.remove_node(biasadd.name) - - g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output, - shapes=[shape], dtypes=[dtype], skip_conversion=False) + pattern2 = \ + OpTypePattern('BiasAdd', name='biasadd', inputs=[ + OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*', '*']), '*'],
allow_reorder=True) + + for pattern in [pattern1, pattern2]: + matcher = GraphMatcher(pattern) + match_results = list(matcher.match_ops(ops)) + for match in match_results: + biasadd = match.get_op('biasadd') + conv = match.get_op('conv') + + #backup the conv and biasadd values + conv_type = conv.type + conv_input = conv.input + conv_attr = conv.attr + dtype = g.get_dtype(conv.output[0]) + shape = g.get_shape(conv.output[0]) + conv_name = biasadd.name + conv_output = biasadd.output + if pattern == pattern2: + conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]] + else: + conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]] + + if len(g.find_output_consumers(conv.output[0])) > 1: + continue + # Remove the Conv and BiasAdd node + g.remove_node(conv.name) + g.remove_node(biasadd.name) + + g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output, + shapes=[shape], dtypes=[dtype], skip_conversion=False) return ops