From c66705e6441f584b1a62c91ce42be416c9087c8a Mon Sep 17 00:00:00 2001 From: Wooseok Date: Mon, 26 Jun 2023 17:52:20 -0500 Subject: [PATCH] [FRONTEND][TFLITE][BugFix] Fix int16 transpose conv loading Loading int16 conv transpose op in tflite model currently fails because output type is not int64. This patch adjusts output type to int64 for int16 quantized transpose convolution operation. In addition, one typo in QnnConv2DTransposeRel is fixed. Test script is also included to evaluate the loading of int16 quantized transpose convolution op. --- python/tvm/relay/frontend/tflite.py | 3 +- src/relay/qnn/op/convolution_transpose.cc | 2 +- tests/python/frontend/tflite/test_forward.py | 80 ++++++++++++++++++++ 3 files changed, 83 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 9e2e244cb146..9e88a85e035d 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -3300,6 +3300,7 @@ def convert_transpose_conv(self, op): kernel_zero_point = weights_tensor.qnn_params["zero_point"] input_scale = input_tensor.qnn_params["scale"] kernel_scale = weights_tensor.qnn_params["scale"] + out_dtype = "int64" if output_tensor_type_str == "int16" else "int32" out = _qnn.op.conv2d_transpose( in_expr, weight_expr_iohw, @@ -3313,7 +3314,7 @@ def convert_transpose_conv(self, op): kernel_size=(int(kernel_h), int(kernel_w)), data_layout="NHWC", kernel_layout="IOHW", - out_dtype="int32", + out_dtype=out_dtype, ) else: out = _op.nn.conv2d_transpose( diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc index 951c1bdfb051..0b24ae71ca8c 100644 --- a/src/relay/qnn/op/convolution_transpose.cc +++ b/src/relay/qnn/op/convolution_transpose.cc @@ -99,7 +99,7 @@ bool QnnConv2DTransposeRel(const Array& types, int num_inputs, const Attrs ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8)) << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype; ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) || - data->dtype == DataType::Int(64)) + param->out_dtype == DataType::Int(64)) << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype; ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0."; diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 3b3dcc59f057..c65e48b40288 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -1648,6 +1648,86 @@ def test_forward_transpose_conv(): ) +def _test_tflite2_quantized_transpose_conv( + input_shape, + kernel_shape, + filters, + padding="valid", + strides=(1, 1), + data_format=None, + int_quant_dtype=tf.int8, +): + """One iteration of TFLite2 quantized tranpose conv with given shapes and attributes""" + data_format = "channels_last" if data_format == "NHWC" else "channels_first" + data = np.random.uniform(0, 1, input_shape).astype("float32") + _ = np.random.uniform(0, 1, kernel_shape).astype("float32") + + data_in = tf.keras.layers.Input(shape=data.shape[1:], batch_size=1) + transpose_conv = tf.keras.layers.Conv2DTranspose( + filters=filters, + kernel_size=(kernel_shape[0], kernel_shape[1]), + padding=padding, + strides=strides, + use_bias=True, + )(data_in) + keras_model = tf.keras.models.Model(data_in, transpose_conv) + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for _ in range(1): + yield [data] + + tflite_model_quant = _quantize_keras_model( + keras_model, + representative_data_gen, + is_float_input=True, + is_float_output=True, + int_quant_dtype=int_quant_dtype, + ) + + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_quant, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_quant, 0) + except ImportError as exc: + raise ImportError("The tflite package must be installed") from exc + + subgraph = tflite_model.Subgraphs(0) + model_input = subgraph.InputsAsNumpy() + input_node = subgraph.Tensors(model_input).Name().decode("utf-8") + + tflite_output = run_tflite_graph(tflite_model_quant, data) + + if tf.__version__ < LooseVersion("2.9"): + input_node = data_in.name.replace(":0", "") + else: + input_node = "serving_default_" + data_in.name + ":0" + + tvm_output = run_tvm_graph(tflite_model_quant, data, input_node) + tvm.testing.assert_allclose( + np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2 + ) + + +def test_forward_quantized_transpose_conv(): + """Quantized convolution""" + for int_quant_dtype in [tf.int8, tf.int16]: + _test_tflite2_quantized_transpose_conv( + (1, 1, 5, 64), + (3, 3), + 64, + padding="same", + strides=(1, 2), + data_format="NHWC", + int_quant_dtype=int_quant_dtype, + ) + + ####################################################################### # Reshape # -------