From e099356bc8b7ceedec0e99a101de199a5deccfe4 Mon Sep 17 00:00:00 2001 From: Jay Zhang <36183870+fatcat-z@users.noreply.github.com> Date: Wed, 25 May 2022 17:10:46 +0800 Subject: [PATCH] Update the way to check input_signature in from_function(). (#1947) * Update the way to check input_signature in from_function(). Signed-off-by: Jay Zhang --- tests/test_api.py | 17 +++++++++++++++++ tf2onnx/convert.py | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_api.py b/tests/test_api.py index 3bf170f03..e81d73946 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -173,6 +173,23 @@ def func(foo, a, x, b, w): res_onnx = self.run_onnxruntime(output_path, {"x": x, "w": w}, output_names) self.assertAllClose(res_tf, res_onnx[0], rtol=1e-5, atol=1e-5) + @check_tf_min_version("2.0") + def test_function_nparray(self): + @tf.function + def func(x): + return tf.math.sqrt(x) + + output_path = os.path.join(self.test_data_directory, "model.onnx") + x = np.asarray([1.0, 2.0]) + + res_tf = func(x) + spec = np.asarray([[1.0, 2.0]]) + model_proto, _ = tf2onnx.convert.from_function(func, input_signature=spec, + opset=self.config.opset, output_path=output_path) + output_names = [n.name for n in model_proto.graph.output] + res_onnx = self.run_onnxruntime(output_path, {'x': x}, output_names) + self.assertAllClose(res_tf, res_onnx[0], rtol=1e-5, atol=1e-5) + @check_tf_min_version("1.15") def _test_graphdef(self): def func(x, y): diff --git a/tf2onnx/convert.py b/tf2onnx/convert.py index bdf7df58f..455224a12 100644 --- a/tf2onnx/convert.py +++ b/tf2onnx/convert.py @@ -535,7 +535,7 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c if LooseVersion(tf.__version__) < "2.0": raise NotImplementedError("from_function requires tf-2.0 or newer") - if not input_signature: + if input_signature is None: raise ValueError("from_function requires input_signature") concrete_func = function.get_concrete_function(*input_signature)