From 4c64cf442e92bed9c8671b31d669121d1c6f75e2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 16 May 2024 12:09:22 -0700 Subject: [PATCH] Adds a test for fixed behavior of `repeat` with 0-size arrays --- dpctl/tests/test_usm_ndarray_manipulation.py | 32 ++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index d75a3aa182..1e9c30d9cf 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -1485,3 +1485,35 @@ def test_tile_arg_validation(): x = dpt.empty(()) with pytest.raises(TypeError): dpt.tile(x, dict()) + + +def test_repeat_0_size(): + get_queue_or_skip() + + x = dpt.ones((0, 10, 0), dtype="i4") + repetitions = 2 + res = dpt.repeat(x, repetitions) + assert res.shape == (0,) + res = dpt.repeat(x, repetitions, axis=2) + assert res.shape == x.shape + res = dpt.repeat(x, repetitions, axis=1) + axis_sz = x.shape[1] * repetitions + assert res.shape == (0, 20, 0) + + repetitions = dpt.asarray(2, dtype="i4") + res = dpt.repeat(x, repetitions) + assert res.shape == (0,) + res = dpt.repeat(x, repetitions, axis=2) + assert res.shape == x.shape + res = dpt.repeat(x, repetitions, axis=1) + assert res.shape == (0, 20, 0) + + repetitions = dpt.arange(10, dtype="i4") + res = dpt.repeat(x, repetitions, axis=1) + axis_sz = dpt.sum(repetitions) + assert res.shape == (0, axis_sz, 0) + + repetitions = (2,) * 10 + res = dpt.repeat(x, repetitions, axis=1) + axis_sz = 2 * x.shape[1] + assert res.shape == (0, axis_sz, 0)