From 2c3a56ff1053131efdee74cafb3289cce0f8dd81 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 26 Apr 2024 15:20:59 +0000 Subject: [PATCH] [TFLite] Add support for GELU conversion This commit adds support for converting a TFLite fp32 GELU operation to Relay. Also includes some neighbouring cleanup of version checks to silence warnings. Change-Id: Ic43b1525c4b80cf7f47281c52bb9a8f2643c4073 --- python/tvm/relay/frontend/tflite.py | 21 ++++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 19 +++++++++++++++--- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 364886423928..e939895adeae 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -109,6 +109,7 @@ def __init__(self, model, subgraph, exp_tab): "GATHER_ND": self.convert_gather_nd, "GREATER_EQUAL": self.convert_greater_equal, "GREATER": self.convert_greater, + "GELU": self.convert_gelu, "HARD_SWISH": self.convert_hard_swish, "L2_NORMALIZATION": self.convert_l2_normalization, "L2_POOL_2D": self.convert_l2_pool2d, @@ -1287,6 +1288,26 @@ def convert_elu(self, op): return out + def convert_gelu(self, op): + """Convert TFLite GELU""" + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + "The TFLite to Relay converter does not support quantized GELU operator yet." 
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        in_type = self.get_tensor_type_str(input_tensor.tensor.Type())
+
+        return in_expr * (
+            _expr.const(0.5, dtype=in_type)
+            + _op.erf(in_expr * _expr.const(0.5**0.5, dtype=in_type))
+            * _expr.const(0.5, dtype=in_type)
+        )
+
     def convert_square(self, op):
         """Convert TFLite SQUARE"""
         input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 7f65cfbc8556..ebf7bce250b1 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -2150,7 +2150,9 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=(-6, 6), int_quan
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in")
         out = math_op(in_data)
-        compare_tflite_with_tvm(data, ["in:0"], [in_data], [out])
+        compare_tflite_with_tvm(
+            data, ["in:0"], [in_data], [out], experimental_new_converter=True
+        )
 
 
 def _unary_elewise_create_model(math_op, data, offset=0, int_quant_dtype=tf.int8):
@@ -2400,6 +2402,16 @@ def _test_elu(data, quantized, int_quant_dtype=tf.int8):
     return _test_unary_elemwise(nn_ops.elu, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
+#######################################################################
+# Gelu
+# ---
+
+
+def _test_gelu(data, quantized, int_quant_dtype=tf.int8):
+    """One iteration of gelu"""
+    return _test_unary_elemwise(nn_ops.gelu, data, quantized, int_quant_dtype=int_quant_dtype)
+
+
 def _test_forward_unary_elemwise(test_op, int_quant_dtype=None, quantized=True, negative=True):
     # input data
     in_data, inq_data = [], []
@@ -2439,15 +2451,16 @@ def test_all_unary_elemwise():
     _test_forward_unary_elemwise(_test_sin)
     _test_forward_unary_elemwise(_test_neg)
_test_forward_unary_elemwise(_test_sqrt, negative=False) + _test_forward_unary_elemwise(_test_gelu, quantized=False) # tensorflow version upgrade support - if tf.__version__ < LooseVersion("2.6.1"): + if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"): _test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.uint8) else: _test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.int8) # ceil and cos come with TFLite 1.14.0.post1 fbs schema if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"): _test_forward_unary_elemwise(_test_ceil) - if tf.__version__ < LooseVersion("2.6.1"): + if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"): _test_forward_unary_elemwise(_test_cos, quantized=False) else: _test_forward_unary_elemwise(_test_cos, int_quant_dtype=tf.int8)