diff --git a/ivy/functional/frontends/tensorflow/nn.py b/ivy/functional/frontends/tensorflow/nn.py
index e71beac65b4bd..a66d249c40378 100644
--- a/ivy/functional/frontends/tensorflow/nn.py
+++ b/ivy/functional/frontends/tensorflow/nn.py
@@ -489,6 +489,12 @@ def softmax(logits, axis=None, name=None):
     return ivy.softmax(logits, axis=axis)
 
 
+@with_unsupported_dtypes({"2.9.0 and below": "float16"}, "tensorflow")
+@to_ivy_arrays_and_back
+def leaky_relu(features, alpha, name=None):
+    return ivy.leaky_relu(features, alpha=alpha)
+
+
 @to_ivy_arrays_and_back
 def crelu(features, axis=-1, name=None):
     c = ivy.concat([features, -features], axis=axis)
diff --git a/ivy/functional/frontends/tensorflow/raw_ops.py b/ivy/functional/frontends/tensorflow/raw_ops.py
index 242ca1a621075..925304999619b 100644
--- a/ivy/functional/frontends/tensorflow/raw_ops.py
+++ b/ivy/functional/frontends/tensorflow/raw_ops.py
@@ -132,11 +132,13 @@ def Concat(*, concat_dim, values, name="Concat"):
 
 Cos = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.math.cos))
 
+
 @to_ivy_arrays_and_back
 def Cross(*, a, b, name='Cross'):
     a, b = check_tensorflow_casting(a, b)
     return ivy.cross(a, b)
 
+
 @to_ivy_arrays_and_back
 def Cosh(*, x, name="Cosh"):
     return ivy.cosh(x)
@@ -750,6 +752,34 @@ def BatchMatMulV3(x, y, Tout=ivy.Dtype, adj_x=False, adj_y=False, name="BatchMat
 
 Slice = to_ivy_arrays_and_back(map_raw_ops_alias(tf_frontend.slice))
 
 
+LeakyRelu = to_ivy_arrays_and_back(
+    map_raw_ops_alias(
+        tf_frontend.nn.leaky_relu,
+    ))
+
+LeakyRelu.supported_dtypes = {
+    "numpy": (
+        "float32",
+        "float64",
+    ),
+    "tensorflow": (
+        "bfloat16",
+        "float16",
+        "float32",
+        "float64",
+    ),
+    "torch": (
+        "float32",
+        "float64",
+    ),
+    "jax": (
+        "bfloat16",
+        "float16",
+        "float32",
+        "float64",
+    ),
+}
+
 @to_ivy_arrays_and_back
 def Prod(*, input, axis, keep_dims=False, name="Prod"):
diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
index b045389be3483..bcc7d5171d03a 100644
--- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
+++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py
@@ -15,6 +15,36 @@
 )
 
 
+@handle_frontend_test(
+    fn_tree="tensorflow.nn.leaky_relu",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_num_dims=1,
+    ),
+    test_with_out=st.just(False),
+    alpha=helpers.floats(min_value=0, max_value=1)
+)
+def test_tensorflow_leaky_relu(
+    *,
+    dtype_and_x,
+    alpha,
+    frontend,
+    test_flags,
+    fn_tree,
+    on_device,
+):
+    dtype, x = dtype_and_x
+    return helpers.test_frontend_function(
+        input_dtypes=dtype,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        features=x[0],
+        alpha=alpha
+    )
+
+
 @st.composite
 def _x_and_filters(
     draw,
diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py
index a7009ce50b9e2..4142191801af6 100644
--- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py
+++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py
@@ -314,8 +314,8 @@ def test_tensorflow_Cos(  # NOQA
 
 # Cross
 @handle_frontend_test(
-    fn_tree='tensorflow.raw_ops.Cross',
-    dtype_and_x=helpers.dtype_and_values(
+    fn_tree='tensorflow.raw_ops.Cross',
+    dtype_and_x=helpers.dtype_and_values(
         available_dtypes=helpers.get_dtypes("float"),
         min_num_dims=1,
         max_num_dims=5,
@@ -324,11 +324,11 @@ def test_tensorflow_Cos(  # NOQA
         safety_factor_scale="log",
         num_arrays=2,
         shared_dtype=True,
-    ),
-    test_with_out=st.just(False),
+    ),
+    test_with_out=st.just(False),
 )
 def test_tensorflow_Cross(  # NOQA
-    *,
+    *,
     dtype_and_x,
     frontend,
     test_flags,
@@ -346,6 +346,7 @@ def test_tensorflow_Cross(  # NOQA
         b=xs[1],
     )
 
+
 # Rsqrt
 @handle_frontend_test(
     fn_tree="tensorflow.raw_ops.Rsqrt",
@@ -381,7 +382,7 @@ def test_tensorflow_Rsqrt(
     ),
     test_with_out=st.just(False),
 )
-def test_tensorflow_Cosh(  # NOQA
+def test_tensorflow_Cosh(
     *,
     dtype_and_x,
     frontend,
@@ -3691,3 +3692,33 @@ def test_tensorflow_Prod(  # NOQA
         axis=axis,
         keep_dims=keep_dims,
     )
+
+
+@handle_frontend_test(
+    fn_tree="tensorflow.raw_ops.LeakyRelu",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_num_dims=1,
+    ),
+    test_with_out=st.just(False),
+    alpha=helpers.floats(min_value=0, max_value=1)
+)
+def test_tensorflow_LeakyReLU(
+    *,
+    dtype_and_x,
+    alpha,
+    frontend,
+    test_flags,
+    fn_tree,
+    on_device,
+):
+    dtype, x = dtype_and_x
+    return helpers.test_frontend_function(
+        input_dtypes=dtype,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        features=x[0],
+        alpha=alpha
+    )
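
For reference, a minimal usage sketch of the new frontend function and its raw-op alias (not part of the patch; it assumes the numpy backend is installed, and the input values and alpha are purely illustrative):

    import ivy
    import ivy.functional.frontends.tensorflow as tf_frontend

    # Any installed backend would do; numpy is assumed here for illustration only.
    ivy.set_backend("numpy")

    x = ivy.array([-1.0, 0.0, 2.0])

    # Leaky ReLU passes positive values through unchanged and scales negative
    # values by alpha, so with alpha=0.2 the expected result is about [-0.2, 0.0, 2.0].
    print(tf_frontend.nn.leaky_relu(x, alpha=0.2))
    print(tf_frontend.raw_ops.LeakyRelu(features=x, alpha=0.2))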