Skip to content

Commit

Permalink
Support multiple output for first Transpose in MergeCommonSequence, i…
Browse files Browse the repository at this point in the history
…mprove TransposeFanOut (#81)
  • Loading branch information
jiafatom authored and wenbingl committed May 14, 2020
1 parent fec314c commit 9100554
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions onnxconverter_common/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def apply(self, node_list):
FanOutSolution.number = FanOutSolution.number + 1
node_list = Solution.add_siso_node(node_list, self.end_p, suc, list(self.end_p.output.values())[0], nnode)

node_list = Solution.delete_node_1ton(node_list, self.begin, self.begin_n, self.end_p)
node_list = Solution.delete_node_1ton(node_list, self.begin, self.begin_n, self.end)
return node_list, True


Expand Down Expand Up @@ -799,22 +799,16 @@ def find(node):
else:
delta_node = delta_node + 1
if delta_node <= 0:
solution = FanOutSolution(node.get_precedence_by_idx(0), node, next_node, None)
solution = FanOutSolution(node.get_precedence_by_idx(0), node, next_node, next_node)
return solution
else: # simo Transpose op
simo_transpose_case = True
cur_perm = None
for succ_ in node.successor:
if not succ_.is_transpose:
simo_transpose_case = False
break
if not cur_perm:
cur_perm = Solution.get_perm(succ_.origin)
elif cur_perm != Solution.get_perm(succ_.origin):
simo_transpose_case = False
break
if simo_transpose_case and match_perm(perm, cur_perm):
solution = TransposeFanOutSolution(node.get_precedence_by_idx(0), node, None, None)
if simo_transpose_case:
solution = FanOutSolution(node.get_precedence_by_idx(0), node, node, node.successor)
return solution
elif node.is_transpose_switchable_mi:
branch_perm = []
Expand Down Expand Up @@ -998,7 +992,7 @@ def _update_broadcast_from_initializers(node, init_pred_value, cur_perm, init_id
return node


_broadcast_flip_whitelist = {'Transpose', 'Conv', 'BatchNormalization', 'Resize', 'Relu', 'Reshape', 'Add'}
_broadcast_flip_whitelist = {'Transpose', 'Conv', 'BatchNormalization', 'Resize', 'Relu', 'Reshape', 'Add', 'Mul'}


def _get_broadcast_info(node, node_transpose_pass_name, cur_perm_map):
Expand Down Expand Up @@ -1192,7 +1186,7 @@ def apply(self, node_list):
candidate_queue = list()
visited = set()
for successor_ in self.begin_n.successor:
candidate_queue.append((successor_, self.begin))
candidate_queue.append((successor_, self.begin_n))
node_transpose_no_pass = list()
node_transpose_pass = list()
node_transpose_pass_name = {self.begin_n.unique_name}
Expand Down Expand Up @@ -1221,11 +1215,11 @@ def apply(self, node_list):
node_list, cur_perm_map = _process_transpose_pass_node(node, node_list, node_transpose_pass_name, cur_perm_map)

# add transpose
for node_pair_ in node_transpose_no_pass:
(node, prev) = node_pair_
if prev.unique_name == self.begin.unique_name:
return None, False
cur_perm = cur_perm_map[prev.unique_name]
if len(self.begin_n.successor) == 1:
for node_pair_ in node_transpose_no_pass:
(node, prev) = node_pair_
if prev.unique_name == self.begin_n.unique_name:
return None, False

for node_pair_ in node_transpose_no_pass:
node = node_pair_[0]
Expand Down Expand Up @@ -1257,7 +1251,7 @@ def apply(self, node_list):
PushTransposeSolution.transpose_number += 1
node_list = Solution.add_siso_node(node_list, prev, node, list(prev.output.values())[0], nnode)

node_list = Solution.delete_node_nto1(node_list, self.begin, self.begin_n, self.end_p)
node_list = Solution.delete_node_1ton(node_list, self.begin, self.begin_n, self.end_p)
return node_list, True


Expand All @@ -1278,8 +1272,8 @@ def find(node):
break
if pred_nchw or node.origin.op_type in _nchw_input_node_type:
next = node.successor[0]
if next.origin is not None and next.origin.op_type == 'Transpose' and len(next.successor) == 1:
solution = PushTransposeSolution(node, next, next.successor[0], None)
if next.origin is not None and next.origin.op_type == 'Transpose':
solution = PushTransposeSolution(node, next, next.successor, None)
return solution

return None
Expand Down

0 comments on commit 9100554

Please sign in to comment.