From ade8d71bb544d62eb022646c4386d9165a689e60 Mon Sep 17 00:00:00 2001 From: Alexey Voronov Date: Fri, 14 Jan 2022 11:00:57 +0300 Subject: [PATCH] Restore the use of ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"] (#9925) --- python/tvm/relay/frontend/onnx.py | 10 +++++-- tests/python/frontend/onnx/test_forward.py | 34 ++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 263eb851f867..57d7568a72ef 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -248,8 +248,14 @@ def flatten_to_nd(x, x_shape, nd=3): # Convert a and b into 3 dimensional tensors. a = flatten_to_nd(inputs[0], a_shape, 3) b = flatten_to_nd(inputs[1], b_shape, 3) - # Perform a NN batch matmul. - output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False) + if ONNX_DEFAULT_CONFIGS["use_nt_batch_matmul"]: + # Transpose matrix dimensions of b. + bt = _op.transpose(b, [0, 2, 1]) + # Perform a NT batch matmul. + output = _op.nn.batch_matmul(a, bt, out_dtype=out_dtype) + else: + # Perform a NN batch matmul. + output = _op.nn.batch_matmul(a, b, out_dtype=out_dtype, transpose_b=False) # Determine the output batch dimension. if a_rank > b_rank: out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2]) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 701906e4be40..287fbe41bd77 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1281,6 +1281,39 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, convert_config=None): ) +@tvm.testing.parametrize_targets +def test_use_nt_batch_matmul(target, dev): + a_shape = (2, 3, 4) + b_shape = (2, 4, 3) + out_shape = [2, 3, 3] + a_array = np.random.uniform(size=a_shape).astype("float32") + b_array = np.random.uniform(size=b_shape).astype("float32") + + for use_nt_batch_matmul in [True, False]: + mul_node = helper.make_node("MatMul", ["a", "b"], ["out"]) + + graph = helper.make_graph( + [mul_node], + "matmul_test", + inputs=[ + helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)), + helper.make_tensor_value_info("b", TensorProto.FLOAT, list(b_shape)), + ], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) + + model = helper.make_model(graph, producer_name="matmul_test") + _, shape_dict = get_input_data_shape_dict(model, [a_array, b_array]) + + mod, _ = relay.frontend.from_onnx( + model, shape_dict, convert_config={"use_nt_batch_matmul": use_nt_batch_matmul} + ) + has_transpose_op = "transpose" in str(mod) + # use_nt_batch_matmul implies, TVM converts qualified onnx `matmul` + # to `transpose(weight) + nn.batch_matmul_NT`, otherwise to `nn.batch_matmul` + assert has_transpose_op == use_nt_batch_matmul + + @tvm.testing.parametrize_targets def test_matmulinteger16(target, dev): def verify_matmulinteger16(a_shape, b_shape, out_shape): @@ -6287,6 +6320,7 @@ def verify_scan( test_random_uniform() test_convinteger() test_batch_matmul() + test_use_nt_batch_matmul() test_global_lppool() test_scan() test_random_uniform_like()