diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index c6079b4535c4..642e680f47f7 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2439,6 +2439,7 @@ def _impl(inputs, attr, params, mod): "ResizeNearestNeighbor": _resize("nearest_neighbor"), "ReverseV2": _reverse_v2(), "RightShift": AttrCvt("right_shift"), + "Rint": AttrCvt("round"), "Round": AttrCvt("round"), "Rsqrt": _rsqrt(), "Select": _where(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 5f849ac9ac93..93bfd0cbaf83 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3602,6 +3602,23 @@ def _test_forward_softsign(shape): _test_forward_softsign([2, 5, 2, 5]) +def test_forward_rint(): + """test operator rint """ + + def _test_forward_rint(shape): + tf.disable_eager_execution() + np_data = np.random.uniform(-100, 100, size=shape).astype(np.float32) + tf.reset_default_graph() + in_data = tf.placeholder(tf.float32, shape, name="in_data") + tf.math.rint(in_data, name="rint") + compare_tf_with_tvm([np_data], ["in_data:0"], "rint:0") + + _test_forward_rint([100]) + _test_forward_rint([1, 100]) + _test_forward_rint([1, 10, 10]) + _test_forward_rint([2, 5, 2, 5]) + + def test_forward_negative(): """test tf operator Neg """ np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32)