refactored gelu #841

Closed · wants to merge 1 commit
6 changes: 4 additions & 2 deletions ivy/functional/backends/jax/activations.py
@@ -21,8 +21,10 @@ def leaky_relu(x: JaxArray, alpha: Optional[float] = 0.2)\
         -> JaxArray:
     return jnp.where(x > 0, x, x * alpha)


-gelu = jax.nn.gelu
+def gelu(x: JaxArray, approximate: bool = True)\
+        -> JaxArray:
+    return jax.nn.gelu(x, approximate)

 tanh = jnp.tanh
 sigmoid = lambda x: 1 / (1 + jnp.exp(-x))
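A short, illustrative check (not part of the PR) of why this change is behaviour-preserving: jax.nn.gelu defaults to approximate=True, so the new wrapper with its default reproduces the old gelu = jax.nn.gelu alias.

# Illustrative sketch only: old alias vs. new wrapper agree at the default.
import jax
import jax.numpy as jnp

def gelu_new(x, approximate=True):
    return jax.nn.gelu(x, approximate)

x = jnp.array([-1.0, 0.0, 1.0])
print(jnp.allclose(gelu_new(x), jax.nn.gelu(x)))  # True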
3 changes: 2 additions & 1 deletion ivy/functional/backends/mxnet/activations.py
@@ -19,7 +19,8 @@ def leaky_relu(x: _mx.nd.NDArray, alpha: Optional[float] = 0.2)\
     return _mx.nd.LeakyReLU(x, slope=alpha)


-def gelu(x, approximate=True):
+def gelu(x: _mx.nd.NDArray, approximate: bool = True)\
+        -> _mx.nd.NDArray:
     if approximate:
         return 0.5 * x * (1 + _mx.nd.tanh(((2 / _np.pi) ** 0.5) * (x + 0.044715 * x ** 3)))
     return _mx.nd.LeakyReLU(x, act_type='gelu')
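For reference, a hedged comparison (assuming mxnet is installed) of the tanh approximation in the approximate=True branch against mxnet's built-in erf-based GELU, which is what the other branch calls via act_type='gelu':

# Illustrative comparison only, not part of the PR.
import math
import mxnet as mx

x = mx.nd.array([-2.0, -1.0, 0.0, 1.0, 2.0])
approx = 0.5 * x * (1 + mx.nd.tanh(((2 / math.pi) ** 0.5) * (x + 0.044715 * x ** 3)))
exact = mx.nd.LeakyReLU(x, act_type='gelu')
print((approx - exact).abs().max().asscalar())  # expected to be ~1e-3 or smaller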
4 changes: 2 additions & 2 deletions ivy/functional/backends/numpy/activations.py
@@ -22,14 +22,14 @@ def leaky_relu(x: np.ndarray, alpha: Optional[float] = 0.2)\
     return np.where(x > 0, x, x * alpha)


-def gelu(x, approximate=True):
+def gelu(x: np.ndarray, approximate: bool = True)\
+        -> np.ndarray:
     if _erf is None:
         raise Exception('scipy must be installed in order to call ivy.gelu with a numpy backend.')
     if approximate:
         return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))
     return 0.5 * x * (1 + _erf(x/np.sqrt(2)))


 tanh = np.tanh
 sigmoid = lambda x: 1 / (1 + np.exp(-x))
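A standalone sketch of the two branches above, assuming scipy is available (_erf in this backend is presumably scipy.special.erf, guarded at import time):

# Illustrative only: exact vs. approximate GELU as in the numpy backend.
import numpy as np
from scipy.special import erf

x = np.array([-1.0, 0.0, 1.0])
exact = 0.5 * x * (1 + erf(x / np.sqrt(2)))
approx = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))
print(exact)   # ~[-0.1587, 0., 0.8413]
print(approx)  # ~[-0.1588, 0., 0.8412]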
6 changes: 5 additions & 1 deletion ivy/functional/backends/tensorflow/activations.py
@@ -6,6 +6,7 @@

 # global
 import tensorflow as tf
+import numpy as np
 from tensorflow.python.types.core import Tensor


@@ -19,7 +20,10 @@ def leaky_relu(x: Tensor, alpha: Optional[float] = 0.2)\
     return tf.nn.leaky_relu(x, alpha)


-gelu = lambda x, approximate=True: tf.nn.gelu(x, approximate)
+def gelu(x: Tensor, approximate: bool = True)\
+        -> Tensor:
+    return tf.nn.gelu(x, approximate)
+
 tanh = tf.nn.tanh
 sigmoid = tf.nn.sigmoid
 softmax = tf.nn.softmax
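A minimal usage sketch (illustrative, and assuming TensorFlow >= 2.4, where tf.nn.gelu is available); as in the JAX backend, the approximate flag is forwarded straight through, here positionally:

# Illustrative only.
import tensorflow as tf

x = tf.constant([-1.0, 0.0, 1.0])
print(tf.nn.gelu(x, True).numpy())   # tanh approximation
print(tf.nn.gelu(x, False).numpy())  # exact erf form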
3 changes: 2 additions & 1 deletion ivy/functional/backends/torch/activations.py
@@ -19,7 +19,8 @@ def leaky_relu(x: torch.Tensor, alpha: Optional[float] = 0.2)\
     return torch.nn.functional.leaky_relu(x, alpha)


-def gelu(x, approximate: bool = True):
+def gelu(x: torch.Tensor, approximate: bool = True)\
+        -> torch.Tensor:
     if approximate:
         return 0.5 * x * (1 + torch.tanh(((2 / np.pi) ** 0.5) * (x + 0.044715 * x ** 3)))
     return torch.nn.functional.gelu(x)
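For intuition, a rough check (illustrative, not in the PR) that the manual tanh approximation used when approximate=True stays close to torch.nn.functional.gelu, the exact form used otherwise:

# Illustrative only: maximum gap between the two branches on a small grid.
import numpy as np
import torch

x = torch.linspace(-4.0, 4.0, steps=81)
approx = 0.5 * x * (1 + torch.tanh(((2 / np.pi) ** 0.5) * (x + 0.044715 * x ** 3)))
exact = torch.nn.functional.gelu(x)
print(torch.max(torch.abs(approx - exact)).item())  # on the order of 1e-3 or smaller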
26 changes: 20 additions & 6 deletions ivy/functional/ivy/activations.py
@@ -63,15 +63,29 @@ def leaky_relu(x: Union[ivy.Array, ivy.NativeArray], alpha: Optional[float] = 0.
     return _cur_framework(x).leaky_relu(x, alpha)


-def gelu(x, approximate=True):
+def gelu(x: Union[ivy.Array, ivy.NativeArray], approximate: bool = True)\
+        -> ivy.Array:
     """
     Applies the Gaussian error linear unit (GELU) activation function.

-    :param x: Input array.
-    :type x: array
-    :param approximate: Whether to approximate, default is True.
-    :type approximate: bool, optional
-    :return: The input array with leaky relu applied element-wise.
+    Parameters
+    ----------
+    x:
+        Input array.
+    approximate:
+        Whether to approximate. Default: True.
+
+    Returns
+    -------
+    out:
+        The input array with gelu applied element-wise on ``x``.
+
+    Examples:
+    >>> x = ivy.array([-1., 0., 1.])
+    >>> y = ivy.gelu(x, True)
+    >>> print(y)
+    [-0.159, 0., 0.841]
+
     """
     return _cur_framework(x).gelu(x, approximate)

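To show where the example values come from, here is the arithmetic with plain numpy (a hedged check, independent of ivy; the exact formatting ivy uses when printing may differ):

# The tanh approximation of GELU at x = [-1, 0, 1].
import numpy as np

x = np.array([-1.0, 0.0, 1.0])
y = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))
print(np.round(y, 3))  # approximately [-0.159, 0., 0.841]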
@@ -73,6 +73,8 @@ def test_gelu(x, approx, dtype, tensor_fn, dev, call):
     assert ret.shape == x.shape
     # value test
     assert np.allclose(call(ivy.gelu, x, approx), ivy.functional.backends.numpy.gelu(ivy.to_numpy(x), approx))
+    # docstrings test
+    helpers.assert_docstring_examples_run(ivy.gelu)


 # tanh
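helpers.assert_docstring_examples_run is ivy's own test helper; its implementation is not shown in this diff. Purely as an illustration of the idea, a docstring-example check could be built on the standard doctest module along these lines:

# Hypothetical sketch only -- NOT the actual ivy helper.
import doctest

def docstring_examples_pass(fn, globs=None):
    # Collect the >>> examples from fn's docstring and run them,
    # returning True when every example's output matches.
    runner = doctest.DocTestRunner(verbose=False)
    for test in doctest.DocTestFinder().find(fn, globs=globs or {}):
        runner.run(test)
    return runner.failures == 0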