Skip to content

Commit

Permalink
Add support for the quantized TANH operator to relay TFLite frontend (#…
Browse files Browse the repository at this point in the history
…8024)

Change-Id: I70df765e1562fa586ed0ffd0e07b8858f7fbb831
  • Loading branch information
NicolaLancellotti authored May 20, 2021
1 parent 1203d73 commit ec3b160
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
11 changes: 9 additions & 2 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,11 +769,18 @@ def convert_tanh(self, op):
"""Convert TFLite TANH"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
out = _op.tanh(in_expr)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]

if input_tensor.qnn_params:
in_expr = self.dequantize(in_expr, input_tensor)
out = _op.tanh(in_expr)
if output_tensor.qnn_params:
out = self.quantize(out, output_tensor)
return out

def convert_range(self, op):
Expand Down
25 changes: 19 additions & 6 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3255,17 +3255,30 @@ def test_forward_log_softmax():
# ----


def _test_tanh(data):
def _test_tanh(data, quantized=False):
""" One iteration of TANH """
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
out = math_ops.tanh(in_data)
compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
in_data = array_ops.placeholder(shape=data.shape, dtype="float32", name="in_0")

if quantized:
inq_data = tf.quantization.fake_quant_with_min_max_args(
in_data, min=-3, max=3, name="inq_0"
)
input_range = {"inq_0": (-3, 3)}
out = math_ops.tanh(inq_data)
out = tf.quantization.fake_quant_with_min_max_args(out, min=-1, max=1, name="out")
compare_tflite_with_tvm(
data, "inq_0:0", [inq_data], [out], quantized=True, input_range=input_range
)
else:
out = math_ops.tanh(in_data)
compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])


def test_forward_tanh():
""" TANH """
_test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
"""TANH"""
_test_tanh(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=False)
_test_tanh(np.arange(0, 256, 30, dtype=np.uint8), quantized=True)


#######################################################################
Expand Down

0 comments on commit ec3b160

Please sign in to comment.