From 3350e401861742f94e90ba990cb36064661b3e49 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 15 Feb 2023 20:09:34 +0000 Subject: [PATCH 01/12] initial addition of relu6 in all backends --- ivy/functional/backends/jax/activations.py | 4 ++ ivy/functional/backends/numpy/activations.py | 8 +++ .../backends/tensorflow/activations.py | 5 ++ ivy/functional/backends/torch/activations.py | 5 ++ ivy/functional/ivy/activations.py | 56 +++++++++++++++++++ 5 files changed, 78 insertions(+) diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py index 27b39aedff193..464d19ba31d65 100644 --- a/ivy/functional/backends/jax/activations.py +++ b/ivy/functional/backends/jax/activations.py @@ -35,6 +35,10 @@ def relu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.maximum(x, 0) +def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: + return jax.nn.relu6(x) + + def sigmoid(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return 1 / (1 + jnp.exp(-x)) diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py index 267315763f969..6b478a5e14a17 100644 --- a/ivy/functional/backends/numpy/activations.py +++ b/ivy/functional/backends/numpy/activations.py @@ -19,6 +19,14 @@ def relu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: relu.support_native_out = True +@_scalar_output_to_0d_array +def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: + return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype) + + +relu6.support_native_out = True + + def leaky_relu( x: np.ndarray, /, *, alpha: float = 0.2, out: Optional[np.ndarray] = None ) -> np.ndarray: diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index f53f4ab4e0725..878394173c4d0 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -33,6 +33,11 @@ def relu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: return tf.nn.relu(x) +@with_unsupported_dtypes({"2.9.1 and below": ("complex",)}, backend_version) +def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: + return tf.nn.relu6(x) + + def sigmoid(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: if not ivy.is_array(x): x = float(x) diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index a32eddaede275..162e3c2b857fb 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -19,6 +19,11 @@ def relu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten return torch.relu(x) +@with_unsupported_dtypes({"1.11.0 and below": ("complex", "float16")}, backend_version) +def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.nn.functional.relu6(x) + + @with_unsupported_dtypes({"1.11.0 and below": ("complex", "float16")}, backend_version) def leaky_relu( x: torch.Tensor, diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index 3f56b3c93a5a9..edd6e30bcfd2e 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -593,3 +593,59 @@ def mish( } """ return current_backend(x).mish(x, out=out) + + +@to_native_arrays_and_back +@handle_out_argument +@handle_nestable +@handle_exceptions 
+@handle_array_like_without_promotion +@handle_array_function +def relu6( + x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None +) -> ivy.Array: + """Applies the rectified linear unit 6 function element-wise. + + Parameters + ---------- + x + input array + out + optional output array, for writing the result to. It must have a shape that the + inputs broadcast to. + + Returns + ------- + ret + an array containing the rectified linear unit 6 activation of each element in + ``x``. + + Examples + -------- + With :class:`ivy.Array` input: + + >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.]) + >>> y = ivy.relu6(x) + >>> print(y) + ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) + + >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.]) + >>> y = ivy.zeros(9) + >>> ivy.relu6(x, out = y) + >>> print(y) + ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) + + With :class:`ivy.Container` input: + + >>> x = { + a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) + } + >>> x = ivy.relu6(x, out=x) + >>> print(x) + { + a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.]) + } + """ + return current_backend(x).relu6(x, out=out) From d253cb27527271b873bce837d0c1e90e8428376d Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 15 Feb 2023 20:35:28 +0000 Subject: [PATCH 02/12] switch old relu6 calls with new ivy.relu6 --- .../frontends/jax/nn/non_linear_activations.py | 2 +- ivy/functional/frontends/tensorflow/nn.py | 15 +++++++++++---- ivy/functional/frontends/tensorflow/raw_ops.py | 2 +- .../functional/non_linear_activation_functions.py | 4 ++-- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/ivy/functional/frontends/jax/nn/non_linear_activations.py b/ivy/functional/frontends/jax/nn/non_linear_activations.py index ef76ae286738b..e0f71ff7d6ee6 100644 --- a/ivy/functional/frontends/jax/nn/non_linear_activations.py +++ b/ivy/functional/frontends/jax/nn/non_linear_activations.py @@ -269,7 +269,7 @@ def relu(x): @to_ivy_arrays_and_back def relu6(x): - res = ivy.minimum(ivy.maximum(x, 0.0), 6.0) + res = ivy.relu6(x) return _type_conversion_64(res) diff --git a/ivy/functional/frontends/tensorflow/nn.py b/ivy/functional/frontends/tensorflow/nn.py index 2f96c984fc881..b9f0d5c23733f 100644 --- a/ivy/functional/frontends/tensorflow/nn.py +++ b/ivy/functional/frontends/tensorflow/nn.py @@ -19,13 +19,13 @@ def _reduce_strides_dilations(dim, stride, dilations): @to_ivy_arrays_and_back def atrous_conv2d(value, filters, rate, padding): - return ivy.conv2d(value, filters, 1, padding, dilations=[rate]*2) + return ivy.conv2d(value, filters, 1, padding, dilations=[rate] * 2) @to_ivy_arrays_and_back def atrous_conv2d_transpose(value, filters, output_shape, rate, padding): return ivy.conv2d_transpose( - value, filters, 1, padding, output_shape=output_shape, dilations=[rate]*2 + value, filters, 1, padding, output_shape=output_shape, dilations=[rate] * 2 ) @@ -145,13 +145,15 @@ def depthwise_conv2d( ): strides, dilations = _reduce_strides_dilations(2, strides, dilations) fc = filter.shape[-2] - filter = filter.reshape([*filter.shape[0:2], 1, filter.shape[-2]*filter.shape[-1]]) + filter = filter.reshape( + [*filter.shape[0:2], 1, filter.shape[-2] * filter.shape[-1]] + ) return ivy.conv_general_dilated( input, filter, strides, padding, - data_format='channel_last' if data_format[-1] == 'C' else 'channel_first', + data_format="channel_last" if data_format[-1] == "C" else 
"channel_first", dilations=dilations, feature_group_count=fc, ) @@ -421,6 +423,11 @@ def relu(features, name=None): return ivy.relu(features) +@to_ivy_arrays_and_back +def relu6(features, name=None): + return ivy.relu6(features) + + @to_ivy_arrays_and_back def softmax(logits, axis=None, name=None): return ivy.softmax(logits, axis=axis) diff --git a/ivy/functional/frontends/tensorflow/raw_ops.py b/ivy/functional/frontends/tensorflow/raw_ops.py index b691368a8e2a8..f4562de29fae4 100644 --- a/ivy/functional/frontends/tensorflow/raw_ops.py +++ b/ivy/functional/frontends/tensorflow/raw_ops.py @@ -487,7 +487,7 @@ def Pow(*, x, y, name="Pow"): @to_ivy_arrays_and_back def Relu6(features, name="Relu6"): - return ivy.clip(features, 0, 6) + return ivy.relu6(features) Sigmoid = to_ivy_arrays_and_back( diff --git a/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py b/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py index 6a863f845735a..bf38036a26482 100644 --- a/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/non_linear_activation_functions.py @@ -172,7 +172,7 @@ def threshold_(input, threshold, value): def relu6(input, inplace=False): - ret = ivy.minimum(ivy.maximum(input, 0), 6) + ret = ivy.relu6(input) if inplace: ivy.inplace_update(input, ret) return input @@ -307,7 +307,7 @@ def leaky_relu_(input, negative_slope=0.01): def hardswish(input, inplace=False): - relu6_val = ivy.minimum(ivy.maximum(ivy.add(input, 3), 0), 6) + relu6_val = ivy.relu6(ivy.add(input, 3)) ret = ivy.multiply(input, ivy.divide(relu6_val, 6)) if inplace: ivy.inplace_update(input, ret) From 64f332ed240834af3e43a7df00b8ea805de50bdb Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 15 Feb 2023 20:41:09 +0000 Subject: [PATCH 03/12] added test for tf.nn.relu6 --- .../test_frontends/test_tensorflow/test_nn.py | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py index f09829901ff1e..781363267ce3d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py @@ -49,7 +49,9 @@ def _x_and_filters( dilations = draw( st.one_of( st.integers(dilation_min, dilation_max), - st.lists(st.integers(dilation_min, dilation_max), min_size=dim, max_size=dim), + st.lists( + st.integers(dilation_min, dilation_max), min_size=dim, max_size=dim + ), ) ) if atrous: @@ -58,7 +60,9 @@ def _x_and_filters( stride = draw( st.one_of( st.integers(stride_min, stride_max), - st.lists(st.integers(stride_min, stride_max), min_size=dim, max_size=dim), + st.lists( + st.integers(stride_min, stride_max), min_size=dim, max_size=dim + ), ) ) fstride = [stride] * dim if isinstance(stride, int) else stride @@ -104,7 +108,9 @@ def _x_and_filters( if transpose: output_shape = [ x_shape[0], - _deconv_length(x_w, fstride[0], filter_shape[0], padding, fdilations[0]), + _deconv_length( + x_w, fstride[0], filter_shape[0], padding, fdilations[0] + ), d_in, ] elif dim == 2: @@ -1079,6 +1085,36 @@ def test_tensorflow_relu( ) +# relu6 +@handle_frontend_test( + fn_tree="tensorflow.nn.relu6", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + num_arrays=1, + min_value=-20, + max_value=20, + ), + test_with_out=st.just(False), +) +def test_tensorflow_relu6( + *, + 
dtype_and_x, + test_flags, + frontend, + fn_tree, + on_device, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + features=x[0], + ) + + # softmax @handle_frontend_test( fn_tree="tensorflow.nn.softmax", From ecfb09559ba6832d75383c9854adb881989760ad Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Fri, 17 Feb 2023 08:44:30 +0000 Subject: [PATCH 04/12] moved relu6 in all backends to experimental --- ivy/functional/backends/jax/activations.py | 4 -- .../backends/jax/experimental/activations.py | 5 ++ ivy/functional/backends/numpy/activations.py | 8 --- .../numpy/experimental/activations.py | 8 +++ .../backends/tensorflow/activations.py | 5 -- .../tensorflow/experimental/activations.py | 5 ++ ivy/functional/backends/torch/activations.py | 5 -- .../torch/experimental/activations.py | 5 ++ ivy/functional/ivy/activations.py | 56 ------------------ .../ivy/experimental/activations.py | 57 +++++++++++++++++++ 10 files changed, 80 insertions(+), 78 deletions(-) diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py index 464d19ba31d65..27b39aedff193 100644 --- a/ivy/functional/backends/jax/activations.py +++ b/ivy/functional/backends/jax/activations.py @@ -35,10 +35,6 @@ def relu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.maximum(x, 0) -def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: - return jax.nn.relu6(x) - - def sigmoid(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return 1 / (1 + jnp.exp(-x)) diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index b181680031741..0af3e0dd3dff7 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -1,6 +1,7 @@ from typing import Optional, Union # global +import jax import jax.numpy as jnp from ivy.functional.backends.jax import JaxArray @@ -13,6 +14,10 @@ def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None): return jnp.log(x / (1 - x)) +def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: + return jax.nn.relu6(x) + + def thresholded_relu( x: JaxArray, /, diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py index 6b478a5e14a17..267315763f969 100644 --- a/ivy/functional/backends/numpy/activations.py +++ b/ivy/functional/backends/numpy/activations.py @@ -19,14 +19,6 @@ def relu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: relu.support_native_out = True -@_scalar_output_to_0d_array -def relu6(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: - return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype) - - -relu6.support_native_out = True - - def leaky_relu( x: np.ndarray, /, *, alpha: float = 0.2, out: Optional[np.ndarray] = None ) -> np.ndarray: diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index 33f431c6e3a39..c5b9f51954a69 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -31,3 +31,11 @@ def thresholded_relu( thresholded_relu.support_native_out = True + + +@_scalar_output_to_0d_array +def relu6(x: np.ndarray, /, *, out: 
Optional[np.ndarray] = None) -> np.ndarray: + return np.minimum(np.maximum(x, 0, dtype=x.dtype), 6, out=out, dtype=x.dtype) + + +relu6.support_native_out = True diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index 878394173c4d0..f53f4ab4e0725 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -33,11 +33,6 @@ def relu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: return tf.nn.relu(x) -@with_unsupported_dtypes({"2.9.1 and below": ("complex",)}, backend_version) -def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: - return tf.nn.relu6(x) - - def sigmoid(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: if not ivy.is_array(x): x = float(x) diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index ea35b0d2a959f..98fef2a5d3810 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -30,3 +30,8 @@ def thresholded_relu( out: Optional[Tensor] = None, ) -> Tensor: return tf.where(x > threshold, x, 0) + + +@with_unsupported_dtypes({"2.9.1 and below": ("complex",)}, backend_version) +def relu6(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: + return tf.nn.relu6(x) diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index 162e3c2b857fb..a32eddaede275 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -19,11 +19,6 @@ def relu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Ten return torch.relu(x) -@with_unsupported_dtypes({"1.11.0 and below": ("complex", "float16")}, backend_version) -def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: - return torch.nn.functional.relu6(x) - - @with_unsupported_dtypes({"1.11.0 and below": ("complex", "float16")}, backend_version) def leaky_relu( x: torch.Tensor, diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index ef3ff49ae66ca..5647e3f23dc6d 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -23,3 +23,8 @@ def thresholded_relu( out: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.threshold(x, threshold=threshold, value=0) + + +@with_unsupported_dtypes({"1.11.0 and below": ("complex", "float16")}, backend_version) +def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.nn.functional.relu6(x) diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index edd6e30bcfd2e..3f56b3c93a5a9 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -593,59 +593,3 @@ def mish( } """ return current_backend(x).mish(x, out=out) - - -@to_native_arrays_and_back -@handle_out_argument -@handle_nestable -@handle_exceptions -@handle_array_like_without_promotion -@handle_array_function -def relu6( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None -) -> ivy.Array: - """Applies the rectified linear unit 6 function element-wise. - - Parameters - ---------- - x - input array - out - optional output array, for writing the result to. 
It must have a shape that the - inputs broadcast to. - - Returns - ------- - ret - an array containing the rectified linear unit 6 activation of each element in - ``x``. - - Examples - -------- - With :class:`ivy.Array` input: - - >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.]) - >>> y = ivy.relu6(x) - >>> print(y) - ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) - - >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.]) - >>> y = ivy.zeros(9) - >>> ivy.relu6(x, out = y) - >>> print(y) - ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) - - With :class:`ivy.Container` input: - - >>> x = { - a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]), - b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) - } - >>> x = ivy.relu6(x, out=x) - >>> print(x) - { - a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]), - b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.]) - } - """ - return current_backend(x).relu6(x, out=out) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 45bfe95d94e9a..69b82d4be747e 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -6,6 +6,7 @@ from ivy.backend_handler import current_backend from ivy.exceptions import handle_exceptions from ivy.func_wrapper import ( + handle_array_function, handle_nestable, to_native_arrays_and_back, handle_array_like_without_promotion, @@ -173,3 +174,59 @@ def thresholded_relu( } """ return current_backend(x).thresholded_relu(x, threshold=threshold, out=out) + + +@to_native_arrays_and_back +@handle_out_argument +@handle_nestable +@handle_exceptions +@handle_array_like_without_promotion +@handle_array_function +def relu6( + x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None +) -> ivy.Array: + """Applies the rectified linear unit 6 function element-wise. + + Parameters + ---------- + x + input array + out + optional output array, for writing the result to. It must have a shape that the + inputs broadcast to. + + Returns + ------- + ret + an array containing the rectified linear unit 6 activation of each element in + ``x``. 
+ + Examples + -------- + With :class:`ivy.Array` input: + + >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.]) + >>> y = ivy.relu6(x) + >>> print(y) + ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) + + >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.]) + >>> y = ivy.zeros(9) + >>> ivy.relu6(x, out = y) + >>> print(y) + ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) + + With :class:`ivy.Container` input: + + >>> x = { + a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) + } + >>> x = ivy.relu6(x, out=x) + >>> print(x) + { + a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.]) + } + """ + return current_backend(x).relu6(x, out=out) From aaae6ed8ed7794e84b590e6e9addfb0b442c1392 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Fri, 17 Feb 2023 09:51:29 +0000 Subject: [PATCH 05/12] added backend test, relu6 method for arrays and containers, fixed supported torch dtypes --- ivy/array/experimental/activations.py | 47 +++++++ ivy/container/experimental/activations.py | 128 ++++++++++++++++++ .../torch/experimental/activations.py | 2 +- .../test_nn/test_activations.py | 31 +++++ 4 files changed, 207 insertions(+), 1 deletion(-) diff --git a/ivy/array/experimental/activations.py b/ivy/array/experimental/activations.py index 527fb0af9a1ab..61d8882062515 100644 --- a/ivy/array/experimental/activations.py +++ b/ivy/array/experimental/activations.py @@ -109,3 +109,50 @@ def prelu( ------- """ return ivy.prelu(self._data, slope, out=out) + + def relu6(self, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: + """Applies the rectified linear unit 6 function element-wise. + + Parameters + ---------- + x + input array + out + optional output array, for writing the result to. + It must have a shape that the inputs broadcast to. + + Returns + ------- + ret + an array containing the rectified linear unit 6 activation + of each element in ``x``. + + Examples + -------- + With :class:`ivy.Array` input: + + >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.]) + >>> y = ivy.relu6(x) + >>> print(y) + ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) + + >>> x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.]) + >>> y = ivy.zeros(9) + >>> ivy.relu6(x, out = y) + >>> print(y) + ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.]) + + With :class:`ivy.Container` input: + + >>> x = { + a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) + } + >>> x = ivy.relu6(x, out=x) + >>> print(x) + { + a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.]) + } + """ + return ivy.relu6(self._data, out=out) diff --git a/ivy/container/experimental/activations.py b/ivy/container/experimental/activations.py index d6f7ce2230431..c23a9e526430f 100644 --- a/ivy/container/experimental/activations.py +++ b/ivy/container/experimental/activations.py @@ -323,3 +323,131 @@ def prelu( map_sequences=map_sequences, out=out, ) + + @staticmethod + def static_relu6( + x: Union[ivy.Array, ivy.NativeArray, ivy.Container], + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.relu6. 
+ This method simply wraps the function, and so the docstring + for ivy.relu6 also applies to this method with minimal changes. + + Parameters + ---------- + x + input container. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. + + Returns + ------- + ret + a container with the rectified linear 6 activation unit function + applied element-wise. + + Examples + -------- + >>> x = { + a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) + } + >>> y = ivy.Container.static_relu6(x) + >>> print(y) + { + a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.]) + } + + """ + return ContainerBase.cont_multi_map_in_function( + "relu6", + x, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def relu6( + self: ivy.Container, + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container instance method variant of ivy.relu6. + This method simply wraps the function, and so the docstring + for ivy.relu6 also applies to this method with minimal changes. + + Parameters + ---------- + self + input container. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. + to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. + + Returns + ------- + ret + a container with the rectified linear 6 activation unit function + applied element-wise. 
+ + Examples + -------- + >>> x = { + a: ivy.array([-3., -2., -1., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 7., 8., 9.]) + } + >>> y = x.relu() + >>> print(y) + { + a: ivy.array([0., 0., 0., 0., 1., 2., 3., 4., 5.]), + b: ivy.array([1., 2., 3., 4., 5., 6., 6., 6., 6.]) + } + + """ + return self.static_relu6( + self, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 5647e3f23dc6d..f52ce2d14d2d4 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -25,6 +25,6 @@ def thresholded_relu( return torch.threshold(x, threshold=threshold, value=0) -@with_unsupported_dtypes({"1.11.0 and below": ("complex", "float16")}, backend_version) +@with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, backend_version) def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.nn.functional.relu6(x) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index f49c7f49d40c4..8ce1ba0aee31f 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -106,3 +106,34 @@ def test_prelu( x=x[0], slope=slope, ) + + +# relu +@handle_test( + fn_tree="functional.ivy.experimental.relu6", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + large_abs_safety_factor=8, + small_abs_safety_factor=8, + safety_factor_scale="log", + ), +) +def test_relu6( + *, + dtype_and_x, + test_flags, + backend_fw, + fn_name, + on_device, + ground_truth_backend, +): + dtype, x = dtype_and_x + helpers.test_function( + ground_truth_backend=ground_truth_backend, + input_dtypes=dtype, + fw=backend_fw, + test_flags=test_flags, + fn_name=fn_name, + on_device=on_device, + x=x[0], + ) From 5498daa1c730bd12fb6107f6e4f640f8c12bc184 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf <32404268+MahmoudAshraf97@users.noreply.github.com> Date: Fri, 17 Feb 2023 11:57:33 +0200 Subject: [PATCH 06/12] revert unrelated formatting changes --- ivy/functional/frontends/tensorflow/nn.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ivy/functional/frontends/tensorflow/nn.py b/ivy/functional/frontends/tensorflow/nn.py index b9f0d5c23733f..089caccdb192c 100644 --- a/ivy/functional/frontends/tensorflow/nn.py +++ b/ivy/functional/frontends/tensorflow/nn.py @@ -19,13 +19,13 @@ def _reduce_strides_dilations(dim, stride, dilations): @to_ivy_arrays_and_back def atrous_conv2d(value, filters, rate, padding): - return ivy.conv2d(value, filters, 1, padding, dilations=[rate] * 2) + return ivy.conv2d(value, filters, 1, padding, dilations=[rate]*2) @to_ivy_arrays_and_back def atrous_conv2d_transpose(value, filters, output_shape, rate, padding): return ivy.conv2d_transpose( - value, filters, 1, padding, output_shape=output_shape, dilations=[rate] * 2 + value, filters, 1, padding, output_shape=output_shape, dilations=[rate]*2 ) @@ -145,15 +145,13 @@ def depthwise_conv2d( ): strides, dilations = _reduce_strides_dilations(2, strides, dilations) fc = filter.shape[-2] - filter = filter.reshape( - 
[*filter.shape[0:2], 1, filter.shape[-2] * filter.shape[-1]] - ) + filter = filter.reshape([*filter.shape[0:2], 1, filter.shape[-2]*filter.shape[-1]]) return ivy.conv_general_dilated( input, filter, strides, padding, - data_format="channel_last" if data_format[-1] == "C" else "channel_first", + data_format='channel_last' if data_format[-1] == 'C' else 'channel_first', dilations=dilations, feature_group_count=fc, ) From bacd76fa6df8a770b40e76e0419f4a56b0d21d27 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf <32404268+MahmoudAshraf97@users.noreply.github.com> Date: Fri, 17 Feb 2023 11:59:48 +0200 Subject: [PATCH 07/12] revert unrelated formatting changes --- .../test_frontends/test_tensorflow/test_nn.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py index 781363267ce3d..3487bf7df59b3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_nn.py @@ -49,9 +49,7 @@ def _x_and_filters( dilations = draw( st.one_of( st.integers(dilation_min, dilation_max), - st.lists( - st.integers(dilation_min, dilation_max), min_size=dim, max_size=dim - ), + st.lists(st.integers(dilation_min, dilation_max), min_size=dim, max_size=dim), ) ) if atrous: @@ -60,9 +58,7 @@ def _x_and_filters( stride = draw( st.one_of( st.integers(stride_min, stride_max), - st.lists( - st.integers(stride_min, stride_max), min_size=dim, max_size=dim - ), + st.lists(st.integers(stride_min, stride_max), min_size=dim, max_size=dim), ) ) fstride = [stride] * dim if isinstance(stride, int) else stride @@ -108,9 +104,7 @@ def _x_and_filters( if transpose: output_shape = [ x_shape[0], - _deconv_length( - x_w, fstride[0], filter_shape[0], padding, fdilations[0] - ), + _deconv_length(x_w, fstride[0], filter_shape[0], padding, fdilations[0]), d_in, ] elif dim == 2: From 82e0b796590a375aad77381965471832664acf57 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf <32404268+MahmoudAshraf97@users.noreply.github.com> Date: Tue, 21 Feb 2023 00:05:19 +0200 Subject: [PATCH 08/12] switch Relu6 to tf.nn alias --- ivy/functional/frontends/tensorflow/raw_ops.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ivy/functional/frontends/tensorflow/raw_ops.py b/ivy/functional/frontends/tensorflow/raw_ops.py index f4562de29fae4..4fd24690b4dbb 100644 --- a/ivy/functional/frontends/tensorflow/raw_ops.py +++ b/ivy/functional/frontends/tensorflow/raw_ops.py @@ -485,9 +485,12 @@ def Pow(*, x, y, name="Pow"): return ivy.pow(x, y) -@to_ivy_arrays_and_back -def Relu6(features, name="Relu6"): - return ivy.relu6(features) +Relu6 = to_ivy_arrays_and_back( + map_raw_ops_alias( + tf_frontend.nn.relu6, + kwargs_to_update={"x": "features"}, + ) +) Sigmoid = to_ivy_arrays_and_back( From 6c2d2e76ecdb9281b8e314d21ed2acd111a64a69 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf <32404268+MahmoudAshraf97@users.noreply.github.com> Date: Tue, 21 Feb 2023 11:53:02 +0200 Subject: [PATCH 09/12] add unsupported dtypes for relu6 --- ivy/functional/backends/torch/experimental/activations.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index f52ce2d14d2d4..f43301f45cee9 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -28,3 +28,5 @@ 
def thresholded_relu( @with_unsupported_dtypes({"1.11.0 and below": ("bfloat16", "float16")}, backend_version) def relu6(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.nn.functional.relu6(x) + +relu6.unsupported_dtypes = ("float16", "bfloat16",) From b58a951fd9657244509d2c22fcb508244af313e6 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf <32404268+MahmoudAshraf97@users.noreply.github.com> Date: Tue, 21 Feb 2023 16:27:03 +0200 Subject: [PATCH 10/12] fix jax casting to adhere to superset behaviour --- ivy/functional/backends/jax/experimental/activations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index 0af3e0dd3dff7..e7a5cafa148c2 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -15,7 +15,8 @@ def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None): def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: - return jax.nn.relu6(x) + x_dtype = x.dtype + return jax.nn.relu6(x).astype(x_dtype) def thresholded_relu( From ea79d53acdfdfa086ef61c505ebcd9a09d6abc72 Mon Sep 17 00:00:00 2001 From: CatB1t Date: Thu, 23 Feb 2023 21:44:14 +0200 Subject: [PATCH 11/12] update safety factor, change `get_dtypes` to numeric. --- .../test_experimental/test_nn/test_activations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 8ce1ba0aee31f..c787a80de8643 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -112,9 +112,9 @@ def test_prelu( @handle_test( fn_tree="functional.ivy.experimental.relu6", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - large_abs_safety_factor=8, - small_abs_safety_factor=8, + available_dtypes=helpers.get_dtypes("numeric"), + large_abs_safety_factor=2, + small_abs_safety_factor=2, safety_factor_scale="log", ), ) From 339d8cb504e6595b81fe8c16ad4f4997fbbb232f Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf <32404268+MahmoudAshraf97@users.noreply.github.com> Date: Sat, 25 Feb 2023 21:20:17 +0200 Subject: [PATCH 12/12] update jax.nn.relu6 to fix gradients at boundary conditions --- .../backends/jax/experimental/activations.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index e7a5cafa148c2..5fce86fd64ff2 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -4,6 +4,8 @@ import jax import jax.numpy as jnp from ivy.functional.backends.jax import JaxArray +from jax import lax +import ivy def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None): @@ -15,8 +17,17 @@ def logit(x: JaxArray, /, *, eps: Optional[float] = None, out=None): def relu6(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: - x_dtype = x.dtype - return jax.nn.relu6(x).astype(x_dtype) + relu6_func = jax.nn.relu6 + + # sets gradient at 0 and 6 to 0 instead of 0.5 + # can refactor to jax.nn.relu6 when this PR is merged + # 
https://github.com/google/jax/pull/14682 + def custom_grad_func(x_and_grad, one): return lax.select( + (6 > x_and_grad[0]) & (x_and_grad[0] > 0), one, lax.full_like(one, 0)) + + new_func = ivy.bind_custom_gradient_function(relu6_func, custom_grad_func) + + return new_func(x).astype(x.dtype) def thresholded_relu(
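
Taken together, the twelve patches above move relu6 out of the core activations module and into ivy's experimental API, exposing it at the functional level, as an ivy.Array instance method, and as static/instance ivy.Container methods, with per-backend implementations and tests. A minimal usage sketch is given below for reference; it is not part of the diff, assumes an ivy build that already contains these commits, and uses ivy.set_backend and the ivy.Container constructor from the wider ivy API (neither is touched by this patch series). Expected values follow the docstring examples added in the patches.

import ivy

ivy.set_backend("numpy")  # the jax, tensorflow and torch backends gain the same function

x = ivy.array([-1., 0., 1., 2., 3., 4., 5., 6., 7.])

# functional form: clamps each element to the [0, 6] range
print(ivy.relu6(x))        # ivy.array([0., 0., 1., 2., 3., 4., 5., 6., 6.])

# writing into a preallocated output array, as in the docstring example
y = ivy.zeros(9)
ivy.relu6(x, out=y)

# ivy.Array instance method and ivy.Container variants added in PATCH 05/12
print(x.relu6())
c = ivy.Container(a=ivy.array([-3., 0., 3., 9.]), b=ivy.array([1., 5., 7.]))
print(ivy.Container.static_relu6(c))   # a: [0., 0., 3., 6.], b: [1., 5., 6.]
print(c.relu6())

Note that on the torch backend, PATCH 05/12 and 09/12 mark float16 and bfloat16 as unsupported for relu6, so inputs in those dtypes would need to be cast to a supported float dtype first.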