From c8abc8db03aa7453a79568439d10c3f970bc4f5a Mon Sep 17 00:00:00 2001 From: Deyu Huang Date: Fri, 29 Apr 2022 11:53:32 +0800 Subject: [PATCH] add dynamic split optimizer test Signed-off-by: Deyu Huang --- tests/test_optimizers.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index d419bfc70..acbc748af 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -133,15 +133,34 @@ def test_transpose_with_split(self, input_shape, perm, inner_perm): graph = helper.make_graph( [node1, node2, node3], "test_transpose_with_split", - [helper.make_tensor_value_info("X", TensorProto.INT64, input_shape_with_trans)], - [helper.make_tensor_value_info("res", TensorProto.INT64, output_shape)], + [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape_with_trans)], + [helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)], ) model_proto = self.make_model(graph, producer_name="onnx-tests") - - feed_dict = {"X": np.random.randn(*input_shape_with_trans).astype(np.int64)} + feed_dict = {"X": np.random.randn(*input_shape_with_trans).astype(np.float32)} self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=0) + @parameterized.expand([ + ((1, -1), (1, 1710), (1710,), [1, 0]), + ((3, 1, 1, 5, -1), (3, 1, 1, 5, 6), (3, 5, 6), [0, 2, 3, 4, 1]), + ]) + def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, output_shape, perm): + node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") + node2 = helper.make_node("Split", ["Y"], ["Z"], axis=1, split=[1], name="split") + node3 = helper.make_node("Squeeze", ["Z"], ["B"], name="squeeze") + + graph = helper.make_graph( + [node1, node2, node3], + "test_transpose_with_split_dynamic_shape", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)], + [helper.make_tensor_value_info("B", TensorProto.FLOAT, output_shape)], + ) + + model_proto = self.make_model(graph, producer_name="onnx-tests") + self.run_transpose_compare(["B"], {"X": np.random.randn(*specific_input).astype(np.float32)}, + model_proto, remaining_transpose_num=0) + @parameterized.expand([ ((2, 3, 4), [2, 0, 1], [1, 2, 0]), ((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),