Skip to content

Commit

Permalink
Merge branch 'main' into workflows/backport-merged-pull-request
Browse files Browse the repository at this point in the history
  • Loading branch information
mtar authored Feb 1, 2024
2 parents bc7eb09 + 31d0752 commit 7c7bcb9
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions heat/core/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,25 +1181,37 @@ def test_percentile(self):
# test list q and writing to output buffer
q = [0.1, 2.3, 15.9, 50.0, 84.1, 97.7, 99.9]
axis = 2
p_np = np.percentile(x_np, q, axis=axis, method="lower", keepdims=True)
try:
p_np = np.percentile(x_np, q, axis=axis, method="lower", keepdims=True)
except TypeError:
p_np = np.percentile(x_np, q, axis=axis, interpolation="lower", keepdims=True)
p_ht = ht.percentile(x_ht, q, axis=axis, interpolation="lower", keepdims=True)
out = ht.empty(p_np.shape, dtype=ht.float64, split=None, device=x_ht.device)
ht.percentile(x_ht, q, axis=axis, out=out, interpolation="lower", keepdims=True)
self.assertEqual(p_ht.numpy()[5].all(), p_np[5].all())
self.assertEqual(out.numpy()[2].all(), p_np[2].all())
self.assertTrue(p_ht.shape == p_np.shape)
axis = None
p_np = np.percentile(x_np, q, axis=axis, method="higher")
try:
p_np = np.percentile(x_np, q, axis=axis, method="higher")
except TypeError:
p_np = np.percentile(x_np, q, axis=axis, interpolation="higher")
p_ht = ht.percentile(x_ht, q, axis=axis, interpolation="higher")
self.assertEqual(p_ht.numpy()[6], p_np[6])
self.assertTrue(p_ht.shape == p_np.shape)
p_np = np.percentile(x_np, q, axis=axis, method="nearest")
try:
p_np = np.percentile(x_np, q, axis=axis, method="nearest")
except TypeError:
p_np = np.percentile(x_np, q, axis=axis, interpolation="nearest")
p_ht = ht.percentile(x_ht, q, axis=axis, interpolation="nearest")
self.assertEqual(p_ht.numpy()[2], p_np[2])

# test split q
q_ht = ht.array(q, split=0, comm=x_ht.comm)
p_np = np.percentile(x_np, q, axis=axis, method="midpoint")
try:
p_np = np.percentile(x_np, q, axis=axis, method="midpoint")
except TypeError:
p_np = np.percentile(x_np, q, axis=axis, interpolation="midpoint")
p_ht = ht.percentile(x_ht, q_ht, axis=axis, interpolation="midpoint")
self.assertEqual(p_ht.numpy()[4], p_np[4])

Expand Down

0 comments on commit 7c7bcb9

Please sign in to comment.