Skip to content

Commit

Permalink
Reformatting RELU function, added docstring examples (#7082)
Browse files Browse the repository at this point in the history
  • Loading branch information
dimikave authored Nov 17, 2022
1 parent 155ad43 commit 0cb8da8
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 35 deletions.
4 changes: 2 additions & 2 deletions .idea/ivy.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion ivy/array/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@


class ArrayWithActivations(abc.ABC):
def relu(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def relu(
self: ivy.Array,
/,
*,
out: Optional[ivy.Array] = None
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.relu. This method simply wraps the
function, and so the docstring for ivy.relu also applies to this method
Expand Down
42 changes: 22 additions & 20 deletions ivy/container/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
class ContainerWithActivations(ContainerBase):
@staticmethod
def static_relu(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.relu.
Expand Down Expand Up @@ -52,11 +52,12 @@ def static_relu(
Examples
--------
>>> x = ivy.Container(a=ivy.array([1.0, 0, 1.0]))
>>> x = ivy.Container(a=ivy.array([1.0, -1.2]), b=ivy.array([0.4, -0.2]))
>>> y = ivy.Container.static_relu(x)
>>> print(y)
{
a: ivy.array([1., 0., 1.])
a: ivy.array([1., 0.]),
b: ivy.array([0.40000001, 0.])
}
"""
Expand All @@ -71,14 +72,14 @@ def static_relu(
)

def relu(
self: ivy.Container,
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
self: ivy.Container,
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.relu.
Expand Down Expand Up @@ -112,11 +113,12 @@ def relu(
Examples
--------
>>> x = ivy.Container(a=ivy.array([1.0, 0, 1.0]))
>>> x = ivy.Container(a=ivy.array([1.0, -1.2]), b=ivy.array([0.4, -0.2]))
>>> y = x.relu()
>>> print(y)
{
a: ivy.array([1., 0., 1.])
a: ivy.array([1., 0.]),
b: ivy.array([0.40000001, 0.])
}
"""
Expand Down
7 changes: 6 additions & 1 deletion ivy/functional/backends/jax/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ def leaky_relu(
return jnp.asarray(jnp.where(x > 0, x, jnp.multiply(x, alpha)), x.dtype)


def relu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
def relu(
x: JaxArray,
/,
*,
out: Optional[JaxArray] = None
) -> JaxArray:
return jnp.maximum(x, 0)


Expand Down
7 changes: 6 additions & 1 deletion ivy/functional/backends/numpy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@


@_scalar_output_to_0d_array
def relu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
def relu(
x: np.ndarray,
/,
*,
out: Optional[np.ndarray] = None
) -> np.ndarray:
return np.maximum(x, 0, out=out, dtype=x.dtype)


Expand Down
7 changes: 6 additions & 1 deletion ivy/functional/backends/tensorflow/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def leaky_relu(
return tf.nn.leaky_relu(x, alpha)


def relu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor:
def relu(
x: Tensor,
/,
*,
out: Optional[Tensor] = None
) -> Tensor:
return tf.nn.relu(x)


Expand Down
7 changes: 6 additions & 1 deletion ivy/functional/backends/torch/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@


@with_unsupported_dtypes({"1.11.0 and below": ("float16",)}, backend_version)
def relu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
def relu(
x: torch.Tensor,
/,
*,
out: Optional[torch.Tensor] = None
) -> torch.Tensor:
return torch.relu(x)


Expand Down
24 changes: 18 additions & 6 deletions ivy/functional/ivy/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
@handle_exceptions
@handle_array_like
def relu(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
out: Optional[ivy.Array] = None
) -> ivy.Array:
"""Applies the rectified linear unit function element-wise.
Expand All @@ -44,8 +47,8 @@ def relu(
an array containing the rectified linear unit activation of each element in
``x``.
Functional Examples
-------------------
Examples
--------
With :class:`ivy.Array` input:
>>> x = ivy.array([-1., 0., 1.])
Expand All @@ -65,6 +68,16 @@ def relu(
>>> y = ivy.relu(x)
>>> print(y)
ivy.array([0., 0., 2.])
With :class:`ivy.Container` input:
>>> x = ivy.Container(a=ivy.array([1.0, -1.2]), b=ivy.array([0.4, -0.2]))
>>> x = ivy.relu(x, out = x)
>>> print(x)
{
a: ivy.array([1., 0.]),
b: ivy.array([0.40000001, 0.])
}
"""
return current_backend(x).relu(x, out=out)

Expand Down Expand Up @@ -98,9 +111,8 @@ def leaky_relu(
ret
The input array with leaky relu applied element-wise.
Functional Examples
-------------------
Examples
--------
With :class:`ivy.Array` input:
>>> x = ivy.array([0.39, -0.85])
Expand Down

0 comments on commit 0cb8da8

Please sign in to comment.