Skip to content

Commit

Permalink
[Bugfix][ONNX] Improve broadcast and batch_matmul conversion
Browse files Browse the repository at this point in the history
This commit provides batch_matmul conversions between a 3D or above
matrix and a 1D matrix with proper broadcasting, which improves
the robustness of the ONNX frontend. This issue was captured in apache#16891.
  • Loading branch information
xhmelon committed May 1, 2024
1 parent adc21b2 commit 11f34ce
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,21 @@ def matmul_out_dtype(inputs, out_dtype):
a = flatten_to_nd(inputs[0], a_shape, 2)
b = _op.transpose(inputs[1])
output = _op.nn.dense(a, b, out_dtype=out_dtype)
elif a_rank == 1 or b_rank == 1:
a, b = inputs
_a_shape = tuple(a_shape.data.numpy())
_b_shape = tuple(b_shape.data.numpy())
if a_rank == 1:
axis = -2
a = _op.expand_dims(a, axis=0)
batches = _b_shape[:-2]
a = _op.broadcast_to(a, (*batches, 1, _a_shape[0]))
else:
axis = -1
b = _op.expand_dims(b, axis=-1)
batches = _a_shape[:-2]
b = _op.broadcast_to(b, (*batches, _b_shape[0], 1))
return _op.squeeze(_op.nn.batch_matmul(a, b, transpose_b=False), axis=axis)
else:
a = inputs[0]
b = inputs[1]
Expand Down
2 changes: 2 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,8 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
verify_batch_matmul((2, 4, 3), (3, 4), (2, 4, 4))
verify_batch_matmul((2, 3, 4, 3), (3, 4), (2, 3, 4, 4))
# Test implicit broadcasting.
verify_batch_matmul((5, ), (5, 5, 4), (5, 4))
verify_batch_matmul((5, 4, 5), (5, ), (5, 4))
verify_batch_matmul((4, 3), (2, 3, 4), (2, 4, 4))
verify_batch_matmul((2, 4, 3), (1, 3, 4), (2, 4, 4))
verify_batch_matmul((1, 4, 3), (2, 3, 4), (2, 4, 4))
Expand Down

0 comments on commit 11f34ce

Please sign in to comment.