From 1017e75154663cf64a1fdf965b84d3b961d114e2 Mon Sep 17 00:00:00 2001 From: Ayomide <120118911+Ayo-folashade@users.noreply.github.com> Date: Sat, 7 Oct 2023 07:04:56 +0000 Subject: [PATCH] feat: add FFT3D function to tensorflow.raw_ops --- .../frontends/tensorflow/raw_ops.py | 8 ++++ .../test_tensorflow/test_raw_ops.py | 39 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/ivy/functional/frontends/tensorflow/raw_ops.py b/ivy/functional/frontends/tensorflow/raw_ops.py index 6eb795c355b33..2b8e507fe1078 100644 --- a/ivy/functional/frontends/tensorflow/raw_ops.py +++ b/ivy/functional/frontends/tensorflow/raw_ops.py @@ -569,6 +569,14 @@ def FFT2D(*, input, name="FFT2D"): return ivy.astype(ivy.fft2(input, dim=(-2, -1)), input.dtype) +@to_ivy_arrays_and_back +def FFT3D(*, input, name="FFT3D"): + fft_result = ivy.fft(input, -1) + fft_result = ivy.fft(fft_result, -2) + fft_result = ivy.fft(fft_result, -3) + return ivy.astype(fft_result, input.dtype) + + @to_ivy_arrays_and_back def Fill(*, dims, value, name="Full"): return ivy.full(dims, value) diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py index 6bbbf9eb8b009..d2222f6487b73 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops.py @@ -1885,6 +1885,45 @@ def test_tensorflow_FFT2D( ) +# FFT3D +@handle_frontend_test( + fn_tree="tensorflow.raw_ops.FFT3D", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("complex"), + min_value=-1e5, + max_value=1e5, + min_num_dims=3, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ), +) +def test_tensorflow_FFT3D( + *, + dtype_and_x, + frontend, + test_flags, + fn_tree, + backend_fw, + on_device, +): + dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + rtol=1e-02, + atol=1e-02, + ) + + # fill @handle_frontend_test( fn_tree="tensorflow.raw_ops.Fill",