Skip to content

Commit

Permalink
expand tests ffftn
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudiaComito committed Oct 11, 2023
1 parent e5e015c commit c76d61c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
11 changes: 11 additions & 0 deletions heat/fft/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray:
local_x = x.larray
except AttributeError:
raise TypeError("x must be a DNDarray, is {}".format(type(x)))

original_split = x.split

# sanitize kwargs
Expand All @@ -126,6 +127,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray:
if repeated_axes:
raise NotImplementedError("Multiple transforms over the same axis not implemented yet.")
s = kwargs.get("s", None)
s = sanitize_axis(x.gshape, s)
norm = kwargs.get("norm", None)

# non-distributed DNDarray
Expand All @@ -142,6 +144,7 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray:
for i, axis in enumerate(axes):
output_shape[axis] = s[i]
else:
axes = tuple(range(x.ndim))
s = tuple(output_shape[axis] for axis in axes)
output_shape = tuple(output_shape)

Expand Down Expand Up @@ -170,6 +173,14 @@ def __fftn_op(x: DNDarray, fftn_op: callable, **kwargs) -> DNDarray:
)
x = x.transpose(transpose_axes)

# original split is 0 and fft is along axis 0
if x.ndim == 1:
_ = x.resplit(axis=None)
result = __fftn_op(_, fftn_op, **kwargs)
del _
result.resplit_(axis=0)
return result

# redistribute x from axis 0 to 1
_ = x.resplit(axis=1)
# FFT along axis 0 (now non-split)
Expand Down
19 changes: 17 additions & 2 deletions heat/fft/tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_fft(self):

def test_ifft(self):
# 1D non-distributed
x = ht.random.randn(6)
x = ht.random.randn(6, dtype=ht.float64)
x_fft = ht.fft.fft(x)
y = ht.fft.ifft(x_fft)
self.assertIsInstance(y, ht.DNDarray)
Expand All @@ -88,7 +88,22 @@ def test_irfft2(self):
pass

def test_fftn(self):
pass
# 1D non-distributed
x = ht.random.randn(6)
y = ht.fft.fftn(x)
np_y = np.fft.fftn(x.numpy())
self.assertIsInstance(y, ht.DNDarray)
self.assertEqual(y.shape, x.shape)
self.assert_array_equal(y, np_y)

# 1D distributed
x = ht.random.randn(6, split=0)
y = ht.fft.fftn(x)
np_y = np.fft.fftn(x.numpy())
self.assertIsInstance(y, ht.DNDarray)
self.assertEqual(y.shape, x.shape)
self.assertTrue(y.split == 0)
self.assert_array_equal(y, np_y)

def test_ifftn(self):
pass
Expand Down

1 comment on commit c76d61c

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: c76d61c Previous: 724a80b Ratio
heat_benchmarks_N4_CPU - ENERGY 983.2209550923338 J (157.54194875762911) 1.3133995955104978 kJ (0.25536059514139076) 748.61

This comment was automatically generated by workflow using github-action-benchmark.

CC: @ClaudiaComito

Please sign in to comment.