Skip to content

Commit

Permalink
add split transpose optimizer tests
Browse files Browse the repository at this point in the history
Signed-off-by: Deyu Huang <[email protected]>
Co-authored-by: Jay Zhang <[email protected]>
  • Loading branch information
hwangdeyu and fatcat-z committed Apr 25, 2022
1 parent 310949e commit 9eb4694
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,32 @@ def test_transpose_with_concat(self, input_shape, perm, inner_perm):
}
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1)

@parameterized.expand([
((2, 3, 4, 5), [0, 3, 1, 2], [0, 2, 3, 1]),
((2, 3, 4, 5, 6), [0, 4, 1, 2, 3], [0, 2, 3, 4, 1]),
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
])
def test_transpose_with_split(self, input_shape, perm, inner_perm):
input_shape_with_trans = [input_shape[i] for i in perm]
for axis in range(len(input_shape)):
output_before_trans = list(input_shape)
output_shape = [output_before_trans[i] for i in perm]
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=inner_perm, name="trans1")
node2 = helper.make_node("Split", ["Y"], ["Z"], axis=axis, name="split")
node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=perm, name="trans2")

graph = helper.make_graph(
[node1, node2, node3],
"test_transpose_with_split",
[helper.make_tensor_value_info("X", TensorProto.INT64, input_shape_with_trans)],
[helper.make_tensor_value_info("res", TensorProto.INT64, output_shape)],
)

model_proto = self.make_model(graph, producer_name="onnx-tests")

feed_dict = {"X": np.random.randn(*input_shape_with_trans).astype(np.int64)}
self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=0)

@parameterized.expand([
((2, 3, 4), [2, 0, 1], [1, 2, 0]),
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
Expand Down

0 comments on commit 9eb4694

Please sign in to comment.