Skip to content

Commit

Permalink
add fftshift to jax frontend (#20857)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnhparitosh authored Aug 3, 2023
1 parent a4d86c4 commit 775ed27
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
19 changes: 19 additions & 0 deletions ivy/functional/frontends/jax/numpy/fft.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
# local
import ivy
from ivy.functional.frontends.jax.func_wrapper import to_ivy_arrays_and_back
from ivy.func_wrapper import with_unsupported_dtypes


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"2.4.2 and below": ("float16", "bfloat16")}, "paddle")
def fftshift(x, axes=None, name=None):
shape = x.shape

if axes is None:
axes = tuple(range(x.ndim))
shifts = [(dim // 2) for dim in shape]
elif isinstance(axes, int):
shifts = shape[axes] // 2
else:
shifts = [shape[ax] // 2 for ax in axes]

roll = ivy.roll(x, shifts, axis=axes)

return roll


@to_ivy_arrays_and_back
Expand Down
24 changes: 24 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,27 @@ def test_jax_numpy_fft(
atol=1e-02,
rtol=1e-02,
)


# fftshift
@handle_frontend_test(
fn_tree="jax.numpy.fft.fftshift",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"), shape=(4,), array_api_dtypes=True
),
)
def test_jax_numpy_fftshift(
dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device
):
input_dtype, arr = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
test_values=True,
x=arr[0],
axes=None,
)

0 comments on commit 775ed27

Please sign in to comment.