diff --git a/onnxconverter_common/optimizer.py b/onnxconverter_common/optimizer.py index 8cbc18f..8947afb 100644 --- a/onnxconverter_common/optimizer.py +++ b/onnxconverter_common/optimizer.py @@ -1071,7 +1071,7 @@ def _check_transpose_pass_broadcast(node, node_transpose_pass_name, cur_perm_map if prev.origin is not None or len(prev.tensors) == 0: can_process = False break - elif prev.origin.op_type not in _broadcast_flip_whitelist: + else: can_process = False break return can_process @@ -1097,17 +1097,6 @@ def _process_transpose_pass_broadcast(node, node_list, node_transpose_pass_name, if prev.origin is None: init_pred_value = numpy_helper.to_array(prev.tensors[0]) _update_broadcast_from_initializers(node, init_pred_value, cur_perm, add_transpose_idx_) - elif prev.origin.op_type in _broadcast_flip_whitelist: - nnode = LinkedNode( - helper.make_node( - 'Transpose', - ['push_transpose_in_' + str(PushTransposeSolution.transpose_number)], - ['push_transpose_out_' + str(PushTransposeSolution.transpose_number)], - perm=_get_reverse_perm(cur_perm), - name='PushTranspose_' + str(PushTransposeSolution.transpose_number))) - PushTransposeSolution.transpose_number += 1 - node_list = Solution.add_siso_node(node_list, prev, node, list(prev.output.values())[0], nnode) - return node_list, cur_perm_map diff --git a/tests/test_opt.py b/tests/test_opt.py index ea8b249..9418f63 100644 --- a/tests/test_opt.py +++ b/tests/test_opt.py @@ -302,7 +302,7 @@ def test_merge_common(self): def test_onnx_models(self): model_names = ['mobile_segnet_no_opt.onnx', 'srgan_no_opt.onnx', 'test_model_0_no_opt.onnx', 'test_model_1_no_opt.onnx'] - num_transpose_list = [2, 3, 11, 5] + num_transpose_list = [2, 3, 11, 6] dir_path = os.path.dirname(os.path.realpath(__file__)) for idx_, model_name_ in enumerate(model_names): model_dir = dir_path + '/data/' + model_name_