Transpose optimization for Softmax and LogSoftmax (fixes #1716) (#1964)
* Transpose optimization for Softmax and LogSoftmax (fixes #1716)

In opsets 13 and higher, the axis of the operation is arbitrary and can simply be changed according to the permutation of the Transpose.
In lower opsets, Softmax always coerces its inputs to a 2D tensor, making Transpose operations necessary if the permutation moves axes between the coerced batch and feature dimensions.

Signed-off-by: janbernloehr <[email protected]>

Co-authored-by: fthielke <[email protected]>
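To make the opset-13 claim above concrete, here is a minimal numpy sketch (illustration only, not part of this commit) of why pushing a Transpose past a Softmax only requires remapping the softmax axis through the permutation:

import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.randn(2, 3, 4, 5).astype(np.float32)
perm = [0, 2, 3, 1]
axis = 1  # softmax axis in the transposed layout

# Transpose -> Softmax(axis) ...
a = softmax(np.transpose(x, perm), axis)
# ... equals Softmax(perm[axis]) -> Transpose, which is the rewritten graph.
b = np.transpose(softmax(x, perm[axis]), perm)
assert np.allclose(a, b)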
janbernloehr and fthielke authored Jun 11, 2022
1 parent a8f78ac commit 9cea907
Showing 2 changed files with 148 additions and 0 deletions.
124 changes: 124 additions & 0 deletions tests/test_optimizers.py
@@ -1369,6 +1369,130 @@ def test_transpose_argmax(self):
        self.run_transpose_compare(["res"], {"X": np.random.randn(*input_shape).astype(np.float32)},
                                   model_proto, remaining_transpose_num=0)

    @check_opset_max_version(
        12, "Before opset 13, Softmax coerces its inputs to 2D and can thus only be optimized for certain permutations"
    )
    def test_transpose_softmax_valid_perm(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=1, name="softmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-softmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
        )

    @check_opset_max_version(
        12, "Before opset 13, Softmax coerces its inputs to 2D and can thus only be optimized for certain permutations"
    )
    def test_transpose_softmax_invalid_perm(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=3, name="softmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-softmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=2
        )

    @check_opset_min_version(13, "Softmax can be optimized for all permutations since opset 13")
    def test_transpose_softmax_13(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("Softmax", ["Y"], ["Z"], axis=3, name="softmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-softmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
        )

    @check_opset_max_version(
        12,
        "Before opset 13, LogSoftmax coerces its inputs to 2D and can thus only be optimized for certain permutations",
    )
    def test_transpose_logsoftmax_valid_perm(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=1, name="logsoftmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-logsoftmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
        )

    @check_opset_max_version(
        12,
        "Before opset 13, LogSoftmax coerces its inputs to 2D and can thus only be optimized for certain permutations",
    )
    def test_transpose_logsoftmax_invalid_perm(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=3, name="logsoftmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-logsoftmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=2
        )

    @check_opset_min_version(13, "LogSoftmax can be optimized for all permutations since opset 13")
    def test_transpose_logsoftmax_13(self):
        input_shape = [4, 4, 4, 4]
        node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
        node1 = helper.make_node("LogSoftmax", ["Y"], ["Z"], axis=3, name="logsoftmax")
        node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")

        graph = helper.make_graph(
            [node0, node1, node2],
            "transpose-logsoftmax-test",
            [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
            [helper.make_tensor_value_info("res", TensorProto.FLOAT, input_shape)],
        )

        model_proto = self.make_model(graph, producer_name="onnx-tests")
        self.run_transpose_compare(
            ["res"], {"X": np.random.randn(*input_shape).astype(np.float32)}, model_proto, remaining_transpose_num=0
        )

    def test_transpose_tile(self):
        input_shape = [1, 2, 3, 4]

24 changes: 24 additions & 0 deletions tf2onnx/optimizer/transpose_optimizer.py
@@ -205,6 +205,7 @@ def _initialize_handlers(self):
            "Identity": self._identity_handler,
            "LeakyRelu": self._simple_through_handler,
            "Log": self._simple_through_handler,
            "LogSoftmax": self._softmax_handler,
            "Max": self._maxmin_handler,
            "Min": self._maxmin_handler,
            "Mul": self._mul_handler,
@@ -223,6 +224,7 @@ def _initialize_handlers(self):
            "Relu": self._simple_through_handler,
            "Shape": self._shape_handler,
            "Sigmoid": self._simple_through_handler,
            "Softmax": self._softmax_handler,
            "Sum": self._sum_handler,
            "Slice": self._slice_handler,
            "Split": self._split_handler,
@@ -827,6 +829,28 @@ def permute_pads(pads):
    def _prelu_handler(self, trans, node):
        return self._handle_node_having_branches(trans, node)

    def _softmax_handler(self, trans, node):
        trans_rank = get_transpose_rank(trans)
        perm = trans.get_attr("perm").ints

        if self._g.opset >= 13:
            # Since opset 13, Softmax operates on an arbitrary axis, so the axis
            # only needs to be remapped through the transpose's permutation.
            axis = node.get_attr_value("axis", -1)
            new_axis = perm[axis + trans_rank if axis < 0 else axis]
            if not self._switch_transpose_and_node(node, trans):
                return False
            node.set_attr("axis", new_axis)
            return True

        # For older opsets, the "axis" attribute marks the boundary at which the
        # input tensor is coerced to 2D. We can safely switch transpose and node
        # only if the permutation does not move any axis across that boundary.
        coercion_axis = node.get_attr_value("axis", 1)
        for from_axis, to_axis in enumerate(perm):
            if (from_axis < coercion_axis <= to_axis) or (from_axis >= coercion_axis > to_axis):
                return False

        return self._switch_transpose_and_node(node, trans)

    def _arg_min_max_handler(self, trans, node):
        axis = node.get_attr_value("axis", 0)
        node.set_attr("axes", [axis])
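As an aside (illustration only, not part of the diff): the pre-opset-13 guard in _softmax_handler can be traced by hand against the two permutations exercised by the tests above.

def crosses_coercion_boundary(perm, coercion_axis):
    # True if the permutation moves any axis across the batch/feature
    # boundary at which pre-opset-13 Softmax coerces its input to 2D.
    return any(
        (from_axis < coercion_axis <= to_axis) or (from_axis >= coercion_axis > to_axis)
        for from_axis, to_axis in enumerate(perm)
    )

perm = [0, 2, 3, 1]
print(crosses_coercion_boundary(perm, coercion_axis=1))  # False: the axis=1 test is optimized
print(crosses_coercion_boundary(perm, coercion_axis=3))  # True: the axis=3 test keeps both transposes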
