Skip to content

Commit

Permalink
Implement roll (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored May 19, 2024
1 parent fb12b09 commit 447421e
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 6 deletions.
1 change: 0 additions & 1 deletion .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ jobs:
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_array_object.py::test_setitem_masking
array_api_tests/test_manipulation_functions.py::test_flip
array_api_tests/test_manipulation_functions.py::test_roll
array_api_tests/test_sorting_functions.py
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var
Expand Down
10 changes: 5 additions & 5 deletions api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `zeros` | :white_check_mark: | | |
| | `zeros_like` | :white_check_mark: | | |
| Data Type Functions | `astype` | :white_check_mark: | | |
| | `can_cast` | :white_check_mark: | | Same as `numpy.array_api` |
| | `finfo` | :white_check_mark: | | Same as `numpy.array_api` |
| | `iinfo` | :white_check_mark: | | Same as `numpy.array_api` |
| | `result_type` | :white_check_mark: | | Same as `numpy.array_api` |
| | `can_cast` | :white_check_mark: | | |
| | `finfo` | :white_check_mark: | | |
| | `iinfo` | :white_check_mark: | | |
| | `result_type` | :white_check_mark: | | |
| Data Types | `bool`, `int8`, ... | :white_check_mark: | | |
| Elementwise Functions | `add` | :white_check_mark: | | Example of a binary function |
| | `negative` | :white_check_mark: | | Example of a unary function |
Expand All @@ -52,7 +52,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `flip` | :x: | 2 | Needs indexing with step=-1, [#114](https://github.com/cubed-dev/cubed/issues/114) |
| | `permute_dims` | :white_check_mark: | | |
| | `reshape` | :white_check_mark: | | Partial implementation |
| | `roll` | :x: | 2 | Use `concat` and `reshape`, [#115](https://github.com/cubed-dev/cubed/issues/115) |
| | `roll` | :white_check_mark: | | |
| | `squeeze` | :white_check_mark: | | |
| | `stack` | :white_check_mark: | | |
| Searching Functions | `argmax` | :white_check_mark: | | |
Expand Down
2 changes: 2 additions & 0 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@
moveaxis,
permute_dims,
reshape,
roll,
squeeze,
stack,
)
Expand All @@ -236,6 +237,7 @@
"moveaxis",
"permute_dims",
"reshape",
"roll",
"squeeze",
"stack",
]
Expand Down
41 changes: 41 additions & 0 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,47 @@ def _reshape_chunk(x, template):
return nxp.reshape(x, template.shape)


def roll(x, /, shift, *, axis=None):
# based on dask roll
result = x

if axis is None:
result = flatten(result)

if not isinstance(shift, int):
raise TypeError("Expect `shift` to be an int when `axis` is None.")

shift = (shift,)
axis = (0,)
else:
if not isinstance(shift, tuple):
shift = (shift,)
if not isinstance(axis, tuple):
axis = (axis,)

if len(shift) != len(axis):
raise ValueError("Must have the same number of shifts as axes.")

for i, s in zip(axis, shift):
shape = result.shape[i]
s = 0 if shape == 0 else -s % shape

sl1 = result.ndim * [slice(None)]
sl2 = result.ndim * [slice(None)]

sl1[i] = slice(s, None)
sl2[i] = slice(None, s)

sl1 = tuple(sl1)
sl2 = tuple(sl2)

# note we want the concatenated array to have the same chunking as input,
# not the chunking of result[sl1], which may be different
result = concat([result[sl1], result[sl2]], axis=i, chunks=result.chunks)

return reshape(result, x.shape)


def stack(arrays, /, *, axis=0):
if not arrays:
raise ValueError("Need array(s) to stack")
Expand Down
31 changes: 31 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,37 @@ def test_reshape_chunks_with_smaller_end_chunk(spec, executor):
)


def _maybe_len(a):
try:
return len(a)
except TypeError:
return 0


@pytest.mark.parametrize(
"chunks, shift, axis",
[
((2, 6), 3, None),
((2, 6), 3, 0),
((2, 6), (3, 9), (0, 1)),
((2, 6), (3, 9), None),
((2, 6), (3, 9), 1),
],
)
def test_roll(spec, executor, chunks, shift, axis):
x = np.arange(4 * 6).reshape((4, 6))
a = cubed.from_array(x, chunks=chunks, spec=spec)

if _maybe_len(shift) != _maybe_len(axis):
with pytest.raises(TypeError if axis is None else ValueError):
xp.roll(a, shift, axis=axis)
else:
assert_array_equal(
xp.roll(a, shift, axis=axis).compute(executor=executor),
np.roll(x, shift, axis),
)


def test_squeeze_1d(spec, executor):
a = xp.asarray([[1, 2, 3]], chunks=(1, 2), spec=spec)
b = xp.squeeze(a, 0)
Expand Down

0 comments on commit 447421e

Please sign in to comment.