Skip to content

Commit

Permalink
Declare Array API 2023.12 support (#651)
Browse files Browse the repository at this point in the history
* Declare Array API 2023.12 support

* Add (unused) device parameter to astype

* Fix unstack edge cases

* Bumpy array API tests commit to test against

* Change dtype=None behavior in sum/prod following data-apis/array-api#744

* Fix unstack edge cases

* Update array api test skips file

* Only support integral values for `repeats` in `repeat`
  • Loading branch information
tomwhite authored Dec 21, 2024
1 parent fc2201e commit 1bc5769
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 32 deletions.
14 changes: 9 additions & 5 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
uses: actions/checkout@v3
with:
repository: data-apis/array-api-tests
ref: 'db95e67b29235249e5776ca2b6bb4e77117e0690' # Latest commit as of 2024-08-08
ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04
path: array-api-tests
submodules: "true"
- name: Set up Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -90,8 +90,7 @@ jobs:
array_api_tests/test_has_names.py
# signatures of items not implemented
array_api_tests/test_signatures.py::test_func_signature[std]
array_api_tests/test_signatures.py::test_func_signature[var]
array_api_tests/test_signatures.py::test_func_signature[cumulative_sum]
array_api_tests/test_signatures.py::test_func_signature[unique_all]
array_api_tests/test_signatures.py::test_func_signature[unique_counts]
array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
Expand All @@ -110,13 +109,15 @@ jobs:
array_api_tests/test_linalg.py::test_vecdot
# (getitem with negative step size is not implemented)
array_api_tests/test_array_object.py::test_getitem
# test_searchsorted depends on sort which is not implemented
array_api_tests/test_searching_functions.py::test_searchsorted
# not implemented
array_api_tests/test_array_object.py::test_setitem
array_api_tests/test_array_object.py::test_setitem_masking
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_sorting_functions.py
array_api_tests/test_statistical_functions.py::test_std
array_api_tests/test_statistical_functions.py::test_var
array_api_tests/test_statistical_functions.py::test_cumulative_sum
# finfo(float32).eps returns float32 but should return float
array_api_tests/test_data_type_functions.py::test_finfo[float32]
Expand All @@ -126,6 +127,9 @@ jobs:
# https://github.com/numpy/numpy/issues/18881
array_api_tests/test_creation_functions.py::test_linspace
# https://github.com/numpy/numpy/issues/20870
#array_api_tests/test_data_type_functions.py::test_can_cast
EOF
pytest -v -rxXfEA --hypothesis-max-examples=2 --disable-data-dependent-shapes --disable-extension linalg --hypothesis-disable-deadline
6 changes: 2 additions & 4 deletions api_status.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
## Array API Coverage Implementation Status

Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.

Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438).
Cubed supports version [2023.12](https://data-apis.org/array-api/2023.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2023.12/extensions/fourier_transform_functions.html) are *not* supported.

This table shows which parts of the the [Array API](https://data-apis.org/array-api/latest/API_specification/index.html) have been implemented in Cubed, and which ones are missing. The version column shows the version when the feature was added to the standard, for version 2022.12 or later.

Expand Down Expand Up @@ -61,7 +59,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array-
| | `expand_dims` | :white_check_mark: | | |
| | `flip` | :white_check_mark: | | |
| | `permute_dims` | :white_check_mark: | | |
| | `repeat` | :white_check_mark: | | |
| | `repeat` | :white_check_mark: | 2023.12 | |
| | `reshape` | :white_check_mark: | | Partial implementation |
| | `roll` | :white_check_mark: | | |
| | `squeeze` | :white_check_mark: | | |
Expand Down
2 changes: 1 addition & 1 deletion cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

# Array API

__array_api_version__ = "2022.12"
__array_api_version__ = "2023.12"

from .array_api.inspection import __array_namespace_info__

Expand Down
2 changes: 1 addition & 1 deletion cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = []

__array_api_version__ = "2022.12"
__array_api_version__ = "2023.12"

from .inspection import __array_namespace_info__

Expand Down
6 changes: 5 additions & 1 deletion cubed/array_api/array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,11 @@ def __abs__(self, /):
return elemwise(nxp.abs, self, dtype=dtype)

def __array_namespace__(self, /, *, api_version=None):
if api_version is not None and api_version not in ("2021.12", "2022.12"):
if api_version is not None and api_version not in (
"2021.12",
"2022.12",
"2023.12",
):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
import cubed.array_api as array_api

Expand Down
2 changes: 1 addition & 1 deletion cubed/array_api/data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from cubed.core import CoreArray, map_blocks


def astype(x, dtype, /, *, copy=True):
def astype(x, dtype, /, *, copy=True, device=None):
if not copy and dtype == x.dtype:
return x
return map_blocks(_astype, x, dtype=dtype, astype_dtype=dtype)
Expand Down
9 changes: 7 additions & 2 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ def permute_dims(x, /, axes):


def repeat(x, repeats, /, *, axis=0):
if not isinstance(repeats, int):
raise ValueError("repeat only supports integral values for `repeats`")

if axis is None:
x = flatten(x)
axis = 0
Expand Down Expand Up @@ -599,8 +602,10 @@ def unstack(x, /, *, axis=0):

n_arrays = x.shape[axis]

if n_arrays == 1:
return (x,)
if n_arrays == 0:
return ()
elif n_arrays == 1:
return (squeeze(x, axis=axis),)

shape = x.shape[:axis] + x.shape[axis + 1 :]
dtype = x.dtype
Expand Down
12 changes: 0 additions & 12 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
_real_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
complex64,
complex128,
float32,
float64,
int64,
uint64,
)
Expand Down Expand Up @@ -128,10 +124,6 @@ def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = float64
elif x.dtype == complex64:
dtype = complex128
else:
dtype = x.dtype
extra_func_kwargs = dict(dtype=dtype)
Expand Down Expand Up @@ -169,10 +161,6 @@ def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = float64
elif x.dtype == complex64:
dtype = complex128
else:
dtype = x.dtype
extra_func_kwargs = dict(dtype=dtype)
Expand Down
9 changes: 7 additions & 2 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,15 @@ def test_unstack(spec, executor, chunks):
assert_array_equal(cu, np.full((4, 6), 3))


def test_unstack_noop(spec):
def test_unstack_zero_arrays(spec):
a = xp.full((0, 4, 6), 1, chunks=(1, 2, 3), spec=spec)
assert xp.unstack(a) == ()


def test_unstack_single_array(spec):
a = xp.full((1, 4, 6), 1, chunks=(1, 2, 3), spec=spec)
(b,) = xp.unstack(a)
assert a is b
assert_array_equal(b.compute(), np.full((4, 6), 1))


# Searching functions
Expand Down
4 changes: 1 addition & 3 deletions docs/array-api.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Python Array API

Cubed implements version 2022.12 of the [Python Array API standard](https://data-apis.org/array-api/2022.12/index.html) in `cubed.array_api`, with a few exceptions listed on the [coverage status](https://github.com/cubed-dev/cubed/blob/main/api_status.md) page. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported.

Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438).
Cubed implements version 2023.12 of the [Python Array API standard](https://data-apis.org/array-api/2023.12/index.html) in `cubed.array_api`, with a few exceptions listed on the [coverage status](https://github.com/cubed-dev/cubed/blob/main/api_status.md) page. The [Fourier transform functions](https://data-apis.org/array-api/2023.12/extensions/fourier_transform_functions.html) are *not* supported.

## Differences between Cubed and the standard

Expand Down

0 comments on commit 1bc5769

Please sign in to comment.