Skip to content

Commit

Permalink
fix nightly CI failed (#1784)
Browse files Browse the repository at this point in the history
* Fix onnxruntime-nightly-unittest-matrix CI failed.

Co-authored-by: hwangdeyu [email protected]
Signed-off-by: hwangdeyu <[email protected]>
  • Loading branch information
hwangdeyu committed Nov 24, 2021
1 parent 65aaa2c commit 6b8411b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 26 deletions.
2 changes: 1 addition & 1 deletion tests/test_backend.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2830,7 +2830,7 @@ def func(x):
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)

@check_opset_min_version(7, "batchnorm")
@check_tf_min_version("2.0", "tf-1.x does not support NDHWC")
@check_tf_min_version("2.4", "tf version above 2.4 supports NDHWC")
def test_fused_batchnorm_3d(self):
x_shape = [1, 28, 28, 2, 2]
x_dtype = np.float32
Expand Down
58 changes: 33 additions & 25 deletions tf2onnx/rewriter/conv2d_with_add_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,39 @@
# pylint: disable=missing-docstring

def rewrite_biasadd_with_conv2d(g, ops):
pattern = \
pattern1 = \
OpTypePattern('BiasAdd', name='biasadd', inputs=[
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*']), '*'])
matcher = GraphMatcher(pattern)
match_results = list(matcher.match_ops(ops))
for match in match_results:
biasadd = match.get_op('biasadd')
conv = match.get_op('conv')

#backup the conv and biasadd values
conv_type = conv.type
conv_input = conv.input
conv_attr = conv.attr
dtype = g.get_dtype(conv.output[0])
shape = g.get_shape(conv.output[0])
conv_name = biasadd.name
conv_output = biasadd.output
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]

if len(g.find_output_consumers(conv.output[0])) > 1:
continue
# Remove the Conv and BiasAdd node
g.remove_node(conv.name)
g.remove_node(biasadd.name)

g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
shapes=[shape], dtypes=[dtype], skip_conversion=False)
pattern2 = \
OpTypePattern('BiasAdd', name='biasadd', inputs=[
OpTypePattern('Conv2D|Conv2DBackpropInput', name='conv', inputs=['*', '*', '*']), '*'], allow_reorder = True)

for pattern in [pattern1, pattern2]:
matcher = GraphMatcher(pattern)
match_results = list(matcher.match_ops(ops))
for match in match_results:
biasadd = match.get_op('biasadd')
conv = match.get_op('conv')

#backup the conv and biasadd values
conv_type = conv.type
conv_input = conv.input
conv_attr = conv.attr
dtype = g.get_dtype(conv.output[0])
shape = g.get_shape(conv.output[0])
conv_name = biasadd.name
conv_output = biasadd.output
if pattern == pattern2:
conv_inputs = [conv_input[0], conv_input[1], conv_input[2], biasadd.input[1]]
else:
conv_inputs = [conv_input[0], conv_input[1], biasadd.input[1]]

if len(g.find_output_consumers(conv.output[0])) > 1:
continue
# Remove the Conv and BiasAdd node
g.remove_node(conv.name)
g.remove_node(biasadd.name)

g.make_node(conv_type, conv_inputs, attr=conv_attr, name=conv_name, outputs=conv_output,
shapes=[shape], dtypes=[dtype], skip_conversion=False)
return ops

0 comments on commit 6b8411b

Please sign in to comment.