Skip to content

Commit

Permalink
Restore the use of ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"] (#9925)
Browse files Browse the repository at this point in the history
  • Loading branch information
Icemist authored Jan 14, 2022
1 parent 220b122 commit 0a159c4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
10 changes: 8 additions & 2 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,14 @@ def flatten_to_nd(x, x_shape, nd=3):
# Convert a and b into 3 dimensional tensors.
a = flatten_to_nd(inputs[0], a_shape, 3)
b = flatten_to_nd(inputs[1], b_shape, 3)
# Perform a NN batch matmul.
output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]:
# Transpose matrix dimensions of b.
bt = _op.transpose(b, [0, 2, 1])
# Perform a NT batch matmul.
output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype)
else:
# Perform a NN batch matmul.
output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False)
# Determine the output batch dimension.
if a_rank > b_rank:
out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
Expand Down
34 changes: 34 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,39 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None):
)


@tvm.testing.parametrize_targets
def test_use_nt_batch_matmul(target, dev):
a_shape = (2, 3, 4)
b_shape = (2, 4, 3)
out_shape = [2, 3, 3]
a_array = np.random.uniform(size=a_shape).astype("float32")
b_array = np.random.uniform(size=b_shape).astype("float32")

for use_nt_batch_matmul in [True, False]:
mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])

graph = helper.make_graph(
[mul_node],
"matmul_test",
inputs=[
helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)),
],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
)

model = helper.make_model(graph, producer_name="matmul_test")
_, shape_dict = get_input_data_shape_dict(model, [a_array, b_array])

mod, _ = relay.frontend.from_onnx(
model, shape_dict, convert_config={"use_nt_batch_matmul": use_nt_batch_matmul}
)
has_transpose_op = "transpose" in str(mod)
# use_nt_batch_matmul implies, TVM converts qualified onnx `matmul`
# to `transpose(weight) + nn.batch_matmul_NT`, otherwise to `nn.batch_matmul`
assert has_transpose_op == use_nt_batch_matmul


@tvm.testing.parametrize_targets
def test_matmulinteger16(target, dev):
def verify_matmulinteger16(a_shape, b_shape, out_shape):
Expand Down Expand Up @@ -6287,6 +6320,7 @@ def verify_scan(
test_random_uniform()
test_convinteger()
test_batch_matmul()
test_use_nt_batch_matmul()
test_global_lppool()
test_scan()
test_random_uniform_like()
Expand Down

0 comments on commit 0a159c4

Please sign in to comment.