Skip to content

Commit

Permalink
Merge pull request #87 from asmeurer/require-numpy-2
Browse files Browse the repository at this point in the history
Require NumPy >= 2.1
  • Loading branch information
asmeurer authored Nov 8, 2024
2 parents f8e7c84 + 54ce945 commit a711897
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 98 deletions.
9 changes: 3 additions & 6 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['1.26', 'dev']
exclude:
- python-version: '3.8'
numpy-version: 'dev'
python-version: ['3.10', '3.11', '3.12']
numpy-version: ['2.1', 'dev']

steps:
- name: Checkout array-api-strict
Expand All @@ -38,7 +35,7 @@ jobs:
if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then
python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy;
else
python -m pip install 'numpy>=1.26,<2.0';
python -m pip install 'numpy>=${{ matrix.numpy-version }},<${{ matrix.numpy-version }}.99';
fi
python -m pip install ${GITHUB_WORKSPACE}/array-api-strict
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
Expand Down
9 changes: 3 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['1.26', 'dev']
exclude:
- python-version: '3.8'
numpy-version: 'dev'
python-version: ['3.10', '3.11', '3.12']
numpy-version: ['2.1', 'dev']
fail-fast: true
steps:
- uses: actions/checkout@v4
Expand All @@ -22,7 +19,7 @@ jobs:
if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then
python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy;
else
python -m pip install 'numpy>=1.26,<2.0';
python -m pip install 'numpy>=${{ matrix.numpy-version }},<${{ matrix.numpy-version }}.99';
fi
python -m pip install -r requirements-dev.txt
- name: Run Tests
Expand Down
6 changes: 6 additions & 0 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
"""

import numpy as np
from numpy.lib import NumpyVersion

if NumpyVersion(np.__version__) < NumpyVersion('2.1.0'):
raise ImportError("array-api-strict requires NumPy >= 2.1.0")

__all__ = []

# Warning: __array_api_version__ could change globally with
Expand Down
40 changes: 9 additions & 31 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
if _allow_array:
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
# copy keyword is new in 2.0.0; for older versions don't use it
# retry without that keyword.
if np.__version__[0] < '2':
return np.asarray(self._array, dtype=dtype)
elif np.__version__.startswith('2.0.0-dev0'):
# Handle dev version for which we can't know based on version
# number whether or not the copy keyword is supported.
try:
return np.asarray(self._array, dtype=dtype, copy=copy)
except TypeError:
return np.asarray(self._array, dtype=dtype)
else:
return np.asarray(self._array, dtype=dtype, copy=copy)
return np.asarray(self._array, dtype=dtype, copy=copy)
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")

# These are various helper functions to make the array behavior match the
Expand Down Expand Up @@ -586,24 +574,14 @@ def __dlpack__(
if copy is not _default:
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")

if np.__version__[0] < '2.1':
if max_version not in [_default, None]:
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
if dl_device not in [_default, None]:
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
if copy not in [_default, None]:
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")

return self._array.__dlpack__(stream=stream)
else:
kwargs = {'stream': stream}
if max_version is not _default:
kwargs['max_version'] = max_version
if dl_device is not _default:
kwargs['dl_device'] = dl_device
if copy is not _default:
kwargs['copy'] = copy
return self._array.__dlpack__(**kwargs)
kwargs = {'stream': stream}
if max_version is not _default:
kwargs['max_version'] = max_version
if dl_device is not _default:
kwargs['dl_device'] = dl_device
if copy is not _default:
kwargs['copy'] = copy
return self._array.__dlpack__(**kwargs)

def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
"""
Expand Down
23 changes: 0 additions & 23 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,6 @@ def asarray(
if isinstance(obj, Array) and device is None:
device = obj.device

if np.__version__[0] < '2':
if copy is False:
# Note: copy=False is not yet implemented in np.asarray for
# NumPy 1

# Work around it by creating the new array and seeing if NumPy
# copies it.
if isinstance(obj, Array):
new_array = np.array(obj._array, copy=copy, dtype=_np_dtype)
if new_array is not obj._array:
raise ValueError("Unable to avoid copy while creating an array from given array.")
return Array._new(new_array, device=device)
elif _supports_buffer_protocol(obj):
# Buffer protocol will always support no-copy
return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device)
else:
# No-copy is unsupported for Python built-in types.
raise ValueError("Unable to avoid copy while creating an array from given object.")

if copy is None:
# NumPy 1 treats copy=False the same as the standard copy=None
copy = False

if isinstance(obj, Array):
return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device)
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
Expand Down
42 changes: 15 additions & 27 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,31 +456,19 @@ def dlpack_2023_12(api_version):
set_array_api_strict_flags(api_version=api_version)

a = asarray([1, 2, 3], dtype=int8)
# Never an error
a.__dlpack__()


if np.__version__ < '2.1':
exception = NotImplementedError if api_version >= '2023.12' else ValueError
pytest.raises(exception, lambda:
a.__dlpack__(dl_device=CPU_DEVICE))
pytest.raises(exception, lambda:
a.__dlpack__(dl_device=None))
pytest.raises(exception, lambda:
a.__dlpack__(max_version=(1, 0)))
pytest.raises(exception, lambda:
a.__dlpack__(max_version=None))
pytest.raises(exception, lambda:
a.__dlpack__(copy=False))
pytest.raises(exception, lambda:
a.__dlpack__(copy=True))
pytest.raises(exception, lambda:
a.__dlpack__(copy=None))
else:
a.__dlpack__(dl_device=CPU_DEVICE)
a.__dlpack__(dl_device=None)
a.__dlpack__(max_version=(1, 0))
a.__dlpack__(max_version=None)
a.__dlpack__(copy=False)
a.__dlpack__(copy=True)
a.__dlpack__(copy=None)
# Do not error
a.__dlpack__()
a.__dlpack__(dl_device=CPU_DEVICE)
a.__dlpack__(dl_device=None)
a.__dlpack__(max_version=(1, 0))
a.__dlpack__(max_version=None)
a.__dlpack__(copy=False)
a.__dlpack__(copy=True)
a.__dlpack__(copy=None)

x = np.from_dlpack(a)
assert isinstance(x, np.ndarray)
assert x.dtype == np.int8
assert x.shape == (3,)
assert np.all(x == np.asarray([1, 2, 3]))
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pytest
hypothesis
numpy
numpy>=2.1
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
numpy
numpy>=2.1
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
long_description_content_type="text/markdown",
url="https://data-apis.org/array-api-strict/",
license="MIT",
python_requires=">=3.9",
install_requires=["numpy"],
python_requires=">=3.10",
install_requires=["numpy>=2.1"],
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
],
Expand Down

0 comments on commit a711897

Please sign in to comment.