Skip to content

Commit

Permalink
Fix wrong shapes in loop body inputs if shape invariances are set in …
Browse files Browse the repository at this point in the history
…TF (#2203)

* fix wrong shapes in loop body inputs if shape invariances are set in TF
* fix and enable test for TF2

---------

Signed-off-by: f-salvetti <[email protected]>
  • Loading branch information
f-salvetti authored Jul 26, 2023
1 parent d0ba20e commit 6dda2bb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 3 additions & 5 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tensorflow as tf

from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main, check_tf_min_version, check_tf_max_version, \
from common import unittest_main, check_tf_min_version, \
check_onnxruntime_min_version, check_tfjs_max_version, skip_tflite
from tf2onnx.tf_loader import is_tf2

Expand Down Expand Up @@ -286,15 +286,13 @@ def func(x, y):
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5)

@check_tf_min_version("1.9")
@check_tf_max_version("1.15")
@skip_tflite("infinite loop with tflite")
def test_simple_while_loop_var_shape(self):
# test for while_loop with variant shape variables
# may not meet ONNX Loop spec
# Note: this is not working on tf2 itself.
def func(i):
const = tf.constant(np.array([2], dtype=np.int32))
c = lambda i: tf.reduce_all(tf.shape(i) < 10)
b = lambda i: tf.concat([i, const], 0)
b = lambda i: [tf.concat([i, const], 0)]
r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])])
return tf.identity(r, name="output")
input_names_with_port = ["input_1:0"]
Expand Down
3 changes: 2 additions & 1 deletion tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,8 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64)
g.inputs = [g.get_node_by_output(inp) for inp in func_inputs]

for p, c in zip(loop_node.input, func_inputs):
# we should use outputs shape, not inputs, since there may be shape invariants
for p, c in zip(loop_node.output, func_inputs[2:]):
g.copy_shape(p, c)

for i, node in enumerate(g.inputs):
Expand Down

0 comments on commit 6dda2bb

Please sign in to comment.