Skip to content

Commit

Permalink
tensorflow frontend ApproximateEqual (#12008)
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyjain26 authored Mar 12, 2023
1 parent 4e6ffb2 commit 1dee784
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .idea/ivy.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions ivy/functional/frontends/tensorflow/raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,29 @@ def Acosh(*, x, name="Acosh"):
AddV2 = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.add))


@with_unsupported_dtypes(
{
"2.10.0 and below": (
"float16",
"bool",
"bfloat16",
)
},
"tensorflow",
)
@to_ivy_arrays_and_back
def ApproximateEqual(
*,
x,
y,
tolerance=1e-05,
name="ApproximateEqual",
):
x, y = check_tensorflow_casting(x, y)
ret = ivy.abs(x - y)
return ret < tolerance


@to_ivy_arrays_and_back
def ArgMin(*, input, dimension, output_type=None, name=None):
output_type = to_ivy_dtype(output_type)
Expand Down
33 changes: 33 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,39 @@ def test_tensorflow_Acosh( # NOQA
)


# ApproximateEqual
@handle_frontend_test(
fn_tree="tensorflow.raw_ops.ApproximateEqual",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
num_arrays=2,
shared_dtype=True,
),
tol=st.floats(1e-05, 1e-03),
test_with_out=st.just(False),
)
def test_tensorflow_ApproximateEqual( # NOQA
*,
dtype_and_x,
tol,
frontend,
test_flags,
fn_tree,
on_device,
):
input_dtype, xs = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
x=xs[0],
y=xs[1],
tolerance=tol,
)


# AddV2
@handle_frontend_test(
fn_tree="tensorflow.raw_ops.AddV2",
Expand Down

0 comments on commit 1dee784

Please sign in to comment.