Skip to content

Commit

Permalink
Update the way to check input_signature in from_function(). (#1947)
Browse files Browse the repository at this point in the history
* Update the way to check input_signature in from_function().

Signed-off-by: Jay Zhang <[email protected]>
  • Loading branch information
fatcat-z authored May 25, 2022
1 parent 880754e commit e099356
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
17 changes: 17 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,23 @@ def func(foo, a, x, b, w):
res_onnx = self.run_onnxruntime(output_path, {"x": x, "w": w}, output_names)
self.assertAllClose(res_tf, res_onnx[0], rtol=1e-5, atol=1e-5)

@check_tf_min_version("2.0")
def test_function_nparray(self):
@tf.function
def func(x):
return tf.math.sqrt(x)

output_path = os.path.join(self.test_data_directory, "model.onnx")
x = np.asarray([1.0, 2.0])

res_tf = func(x)
spec = np.asarray([[1.0, 2.0]])
model_proto, _ = tf2onnx.convert.from_function(func, input_signature=spec,
opset=self.config.opset, output_path=output_path)
output_names = [n.name for n in model_proto.graph.output]
res_onnx = self.run_onnxruntime(output_path, {'x': x}, output_names)
self.assertAllClose(res_tf, res_onnx[0], rtol=1e-5, atol=1e-5)

@check_tf_min_version("1.15")
def _test_graphdef(self):
def func(x, y):
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c
if LooseVersion(tf.__version__) < "2.0":
raise NotImplementedError("from_function requires tf-2.0 or newer")

if not input_signature:
if input_signature is None:
raise ValueError("from_function requires input_signature")

concrete_func = function.get_concrete_function(*input_signature)
Expand Down

0 comments on commit e099356

Please sign in to comment.