Skip to content

Commit

Permalink
Remove Matmul from broadcast op (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiafatom authored Jun 17, 2020
1 parent 3816b85 commit c5c29c9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion onnxconverter_common/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def find(node):


_broadcast_types = {'Add', 'And', 'Div', 'Equal', 'Greater', 'GreaterOrEqual', 'Less', 'LessOrEqual',
'MatMul', 'Max', 'Mean', 'Min', 'Mod', 'Mul', 'Or', 'Pow', 'PRelu', 'Sub', 'Sum',
'Max', 'Mean', 'Min', 'Mod', 'Mul', 'Or', 'Pow', 'PRelu', 'Sub', 'Sum',
'Where', 'Xor'}
_transpose_pass_type_set = {'Pad', 'Squeeze', 'Unsqueeze', 'Slice'}
_transpose_pass_type_set.update(_broadcast_types)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def test_merge_common(self):
def test_onnx_models(self):
model_names = ['mobile_segnet_no_opt.onnx', 'srgan_no_opt.onnx', 'test_model_0_no_opt.onnx',
'test_model_1_no_opt.onnx']
num_transpose_list = [2, 3, 11, 4]
num_transpose_list = [2, 3, 11, 5]
dir_path = os.path.dirname(os.path.realpath(__file__))
for idx_, model_name_ in enumerate(model_names):
model_dir = dir_path + '/data/' + model_name_
Expand Down

0 comments on commit c5c29c9

Please sign in to comment.