Skip to content

Commit

Permalink
Adds a test for fixed behavior of repeat with 0-size arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed May 16, 2024
1 parent 57033b0 commit 4c64cf4
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,3 +1485,35 @@ def test_tile_arg_validation():
x = dpt.empty(())
with pytest.raises(TypeError):
dpt.tile(x, dict())


def test_repeat_0_size():
get_queue_or_skip()

x = dpt.ones((0, 10, 0), dtype="i4")
repetitions = 2
res = dpt.repeat(x, repetitions)
assert res.shape == (0,)
res = dpt.repeat(x, repetitions, axis=2)
assert res.shape == x.shape
res = dpt.repeat(x, repetitions, axis=1)
axis_sz = x.shape[1] * repetitions
assert res.shape == (0, 20, 0)

repetitions = dpt.asarray(2, dtype="i4")
res = dpt.repeat(x, repetitions)
assert res.shape == (0,)
res = dpt.repeat(x, repetitions, axis=2)
assert res.shape == x.shape
res = dpt.repeat(x, repetitions, axis=1)
assert res.shape == (0, 20, 0)

repetitions = dpt.arange(10, dtype="i4")
res = dpt.repeat(x, repetitions, axis=1)
axis_sz = dpt.sum(repetitions)
assert res.shape == (0, axis_sz, 0)

repetitions = (2,) * 10
res = dpt.repeat(x, repetitions, axis=1)
axis_sz = 2 * x.shape[1]
assert res.shape == (0, axis_sz, 0)

0 comments on commit 4c64cf4

Please sign in to comment.