[TFLite] Support for BATCH_MATMUL tflite operator #14423

Merged (4 commits) on Mar 30, 2023
Changes from 1 commit
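
For context, a minimal sketch (not taken from this PR) of how a TFLite model containing a BATCH_MATMUL operator could be imported through the Relay frontend once this support is in place. The model file name and the input names and shapes are assumptions chosen purely for illustration.

# Illustration only, not part of the diff: importing a TFLite model that
# contains BATCH_MATMUL. The file name and the input names "A"/"B" are hypothetical.
import tflite
from tvm import relay

with open("model_with_batch_matmul.tflite", "rb") as f:
    tflite_model = tflite.Model.GetRootAsModel(f.read(), 0)

mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"A": (3, 5, 4), "B": (3, 4, 5)},  # assumed input shapes
    dtype_dict={"A": "float32", "B": "float32"},
)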
150 changes: 150 additions & 0 deletions python/tvm/relay/frontend/tflite.py
@@ -32,7 +32,9 @@
from .. import op as _op
from .. import qnn as _qnn
from .common import ExprTable
from .common import fold_constant as _fold_constant
from .common import infer_shape as _infer_shape
from .common import infer_type as _infer_type
from .common import lstm_cell, to_int_list, shape_of, try_infer_value
from .common import set_span
from .tflite_flexbuffer import FlexBufferDecoder
@@ -80,6 +82,7 @@ def __init__(self, model, subgraph, exp_tab):
"ARG_MIN": self.convert_arg_min,
"AVERAGE_POOL_2D": self.convert_average_pool2d,
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
"BATCH_MATMUL": self.convert_batch_matmul,
"CAST": self.convert_cast,
"CEIL": self.convert_ceil,
"CONCATENATION": self.convert_concatenation,
@@ -492,6 +495,21 @@ def get_tensor_type_str(self, tensor_type):
"Tensor type {} is currently not supported".format(str(tensor_type))
)

def flatten_to_nd(self, x, x_shape, nd=3):
"""Flatten input tensor to nd rank"""
ndims = _infer_shape(x_shape)[0]
if ndims == nd:
return x
newshape = _op.concatenate(
[
_expr.const([-1], dtype=_infer_type(x_shape).checked_type.dtype),
_op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
],
0,
)
out = _op.reshape(x, _fold_constant(newshape))
return out
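
Side note on the helper above: flatten_to_nd collapses all leading batch axes of x into a single axis while keeping the trailing nd - 1 axes unchanged. A NumPy sketch of the equivalent reshape (illustration only, not part of the diff):

import numpy as np

x = np.zeros((2, 3, 5, 4))                # rank-4 input
nd = 3
new_shape = (-1,) + x.shape[-(nd - 1):]   # (-1, 5, 4): collapse the leading axes
print(np.reshape(x, new_shape).shape)     # prints (6, 5, 4)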

def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
lhs_scale = lhs_tensor.qnn_params["scale"]
rhs_scale = rhs_tensor.qnn_params["scale"]
@@ -2959,6 +2977,138 @@ def convert_batch_to_space_nd(self, op):

return out

def convert_batch_matmul(self, op):
"""batch_matmul implementation."""
try:
from tflite.BatchMatMulOptions import BatchMatMulOptions
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
output_tensor = self.get_output_tensors(op)

assert len(input_tensors) == 2, "two input tensor arguments expected"

batch_matmul_options = BatchMatMulOptions()
op_options = op.BuiltinOptions()
batch_matmul_options.Init(op_options.Bytes, op_options.Pos)

input_a = self.get_expr(input_tensors[0].tensor_idx)
input_b = self.get_expr(input_tensors[1].tensor_idx)

shape_a = shape_of(input_a)
shape_b = shape_of(input_b)
rank_a = _infer_shape(shape_a)[0]
rank_b = _infer_shape(shape_b)[0]

if rank_a > 2 or rank_b > 2:
# Determine the output batch dimension
new_a_shape = shape_a
new_b_shape = shape_b
if rank_a > rank_b:
rank_diff = rank_a - rank_b
new_b_shape = _op.concatenate(
[
_expr.const([1] * rank_diff, dtype=_infer_type(b_shape).checked_type.dtype),

the "b_shape" here is a mistake? should be "shape_b"?

shape_b,
],
0,
)
elif rank_a < rank_b:
rank_diff = rank_b - rank_a
new_a_shape = _op.concatenate(
[
_expr.const([1] * rank_diff, dtype=_infer_type(a_shape).checked_type.dtype),

the "a_shape" here is a mistake? should be "shape_a"?

shape_a,
],
0,
)
else:
pass

out_batch = _op.concatenate(
[
_op.maximum(
_op.strided_slice(new_b_shape, [i], [i + 1]),
_op.strided_slice(new_a_shape, [i], [i + 1]),
)
for i in range(max(rank_a, rank_b) - 2)
],
0,
)

out_batch_shape = _fold_constant(out_batch)

a_broadcasted_shape = _fold_constant(
_op.concatenate(
[
out_batch,
_op.strided_slice(shape_a, [rank_a - 2], [rank_a]),
],
0,
)
)
b_broadcasted_shape = _fold_constant(
_op.concatenate(
[
out_batch,
_op.strided_slice(shape_b, [rank_b - 2], [rank_b]),
],
0,
)
)
if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
input_a = _op.transform.broadcast_to(a, a_broadcasted_shape)

"a" is not define?

if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
input_b = _op.transform.broadcast_to(b, b_broadcasted_shape)

"b" is not define


input_a = self.flatten_to_nd(input_a, shape_a, 3)
input_b = self.flatten_to_nd(input_b, shape_b, 3)

if batch_matmul_options.AdjX():
input_a = _op.transpose(input_a, [0, 2, 1])
if batch_matmul_options.AdjY() == False:
input_b = _op.transpose(input_b, [0, 2, 1])

if self.is_quantized(op):
output = _qnn.op.batch_matmul(
input_a,
input_b,
relay.const(0, "int32"),
relay.const(0, "int32"),
relay.const(1.0, "float32"),
relay.const(1.0, "float32"),
)
else:
output = _op.nn.batch_matmul(input_a, input_b)

# Reshape output to original dimensions.
output_shape = shape_of(output)

rank_out = _infer_shape(output_shape)[0]

final_shape = _op.concatenate(
[
_op.strided_slice(shape_a, [0], [rank_a - 2]),
_op.strided_slice(output_shape, [rank_out - 2], [rank_out]),
],
0,
)

reshape = _op.reshape(output, _fold_constant(final_shape))
# qnn batch matmul returns an int32 tensor so we need to requantize
if self.is_quantized(op):
return _qnn.op.requantize(
reshape,
relay.const(1.0, "float32"),
relay.const(0, "int32"),
relay.const(1.0, "float32"),
relay.const(0, "int32"),
out_dtype="int8",
)
else:
return reshape
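
For intuition on the rank alignment and batch broadcasting implemented above: the converter assumes BATCH_MATMUL broadcasts the leading batch dimensions of its two operands, which mirrors numpy.matmul semantics. A small NumPy illustration (not part of the diff):

import numpy as np

a = np.random.uniform(size=(2, 1, 5, 4)).astype("float32")
b = np.random.uniform(size=(3, 4, 6)).astype("float32")
# b is rank-aligned to (1, 3, 4, 6); the batch dims broadcast to (2, 3)
# and the matrix dims contract over the shared axis of length 4.
print(np.matmul(a, b).shape)  # prints (2, 3, 5, 6)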

def convert_space_to_batch_nd(self, op):
"""space_to_batch_nd implementation."""
input_tensors = self.get_input_tensors(op)
71 changes: 62 additions & 9 deletions tests/python/frontend/tflite/test_forward.py
@@ -61,6 +61,7 @@
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
from tensorflow import raw_ops

try:
from tensorflow import lite as interpreter_wrapper
@@ -319,6 +320,12 @@ def compare_tflite_with_tvm(
sess.run(variables.global_variables_initializer())
# convert to tflite model
converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)

if len(input_tensors[0].shape) <= 4 and len(input_tensors[1].shape) <= 4:
converter._experimental_disable_batchmatmul_unfold = True
else:
converter._experimental_disable_batchmatmul_unfold = False

converter.experimental_new_converter = experimental_new_converter
if quantized:
if int_quant_dtype == tf.int16:
@@ -734,24 +741,70 @@ def test_forward_cast():
#######################################################################
# Batch Mat Mul
# ----
def _test_batch_matmul(a_shape, b_shape, dtype, adjoint_a=False, adjoint_b=False):
def _test_batch_matmul(
a_shape, b_shape, dtype, out_dtype, adjoint_a=False, adjoint_b=False, quantized=False
):
with tf.Graph().as_default():
a = array_ops.placeholder(shape=a_shape, dtype=dtype, name="A")
b = array_ops.placeholder(shape=b_shape, dtype=dtype, name="B")
result = math_ops.matmul(a, b, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul")
print(tf.__version__)

result = raw_ops.BatchMatMulV3(
x=a, y=b, Tout=out_dtype, adj_x=adjoint_a, adj_y=adjoint_b, name="batchmatmul"
)
input_range = {"A": (-100, 100), "B": (-100, 100)} if quantized else None

a_np = np.random.uniform(high=5.0, size=a_shape).astype(dtype)
b_np = np.random.uniform(high=5.0, size=b_shape).astype(dtype)
compare_tflite_with_tvm([a_np, b_np], [a.name, b.name], [a, b], [result])
compare_tflite_with_tvm(
[a_np, b_np],
[a.name, b.name],
[a, b],
[result],
experimental_new_converter=True,
quantized=quantized,
input_range=input_range,
)


def test_forward_batch_matmul():
@pytest.mark.parametrize("config", [("int8", "int32", True), ("float32", "float32", False)])
def test_forward_batch_matmul(config):
"""BATCH_MAT_MUL"""
_test_batch_matmul((3, 5, 4), (3, 4, 5), "float32")
_test_batch_matmul((3, 5, 4), (3, 4, 5), "float32", True, True)
_test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", True, False)
_test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", False, True)
_test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), "float32")
_test_batch_matmul(
(3, 5, 4), (3, 4, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
)
_test_batch_matmul(
(3, 5, 4),
(3, 4, 5),
dtype=config[0],
out_dtype=config[1],
adjoint_a=True,
adjoint_b=True,
quantized=config[2],
)
_test_batch_matmul(
(3, 5, 4),
(3, 5, 4),
dtype=config[0],
out_dtype=config[1],
adjoint_a=True,
adjoint_b=False,
quantized=config[2],
)
_test_batch_matmul(
(3, 5, 4),
(3, 5, 4),
dtype=config[0],
out_dtype=config[1],
adjoint_a=False,
adjoint_b=True,
quantized=config[2],
)
_test_batch_matmul(
(3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
)
# BatchMatMul doesn't support larger than 4D tensors
# _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2])


#######################################################################