Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require NumPy >= 2.1 #87

Merged
merged 3 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['1.26', 'dev']
exclude:
- python-version: '3.8'
numpy-version: 'dev'
python-version: ['3.10', '3.11', '3.12']
numpy-version: ['2.1', 'dev']

steps:
- name: Checkout array-api-strict
Expand All @@ -38,7 +35,7 @@ jobs:
if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then
python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy;
else
python -m pip install 'numpy>=1.26,<2.0';
python -m pip install 'numpy>=${{ matrix.numpy-version }},<${{ matrix.numpy-version }}.99';
fi
python -m pip install ${GITHUB_WORKSPACE}/array-api-strict
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
Expand Down
9 changes: 3 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['1.26', 'dev']
exclude:
- python-version: '3.8'
numpy-version: 'dev'
python-version: ['3.10', '3.11', '3.12']
numpy-version: ['2.1', 'dev']
fail-fast: true
steps:
- uses: actions/checkout@v4
Expand All @@ -22,7 +19,7 @@ jobs:
if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then
python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy;
else
python -m pip install 'numpy>=1.26,<2.0';
python -m pip install 'numpy>=${{ matrix.numpy-version }},<${{ matrix.numpy-version }}.99';
fi
python -m pip install -r requirements-dev.txt
- name: Run Tests
Expand Down
6 changes: 6 additions & 0 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@

"""

import numpy as np
from numpy.lib import NumpyVersion

if NumpyVersion(np.__version__) < NumpyVersion('2.1.0'):
raise ImportError("array-api-strict requires NumPy >= 2.1.0")

__all__ = []

# Warning: __array_api_version__ could change globally with
Expand Down
40 changes: 9 additions & 31 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,7 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
if _allow_array:
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
# copy keyword is new in 2.0.0; for older versions don't use it
# retry without that keyword.
if np.__version__[0] < '2':
return np.asarray(self._array, dtype=dtype)
elif np.__version__.startswith('2.0.0-dev0'):
# Handle dev version for which we can't know based on version
# number whether or not the copy keyword is supported.
try:
return np.asarray(self._array, dtype=dtype, copy=copy)
except TypeError:
return np.asarray(self._array, dtype=dtype)
else:
return np.asarray(self._array, dtype=dtype, copy=copy)
return np.asarray(self._array, dtype=dtype, copy=copy)
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")

# These are various helper functions to make the array behavior match the
Expand Down Expand Up @@ -586,24 +574,14 @@ def __dlpack__(
if copy is not _default:
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")

if np.__version__[0] < '2.1':
if max_version not in [_default, None]:
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
if dl_device not in [_default, None]:
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
if copy not in [_default, None]:
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")

return self._array.__dlpack__(stream=stream)
else:
kwargs = {'stream': stream}
if max_version is not _default:
kwargs['max_version'] = max_version
if dl_device is not _default:
kwargs['dl_device'] = dl_device
if copy is not _default:
kwargs['copy'] = copy
return self._array.__dlpack__(**kwargs)
kwargs = {'stream': stream}
if max_version is not _default:
kwargs['max_version'] = max_version
if dl_device is not _default:
kwargs['dl_device'] = dl_device
if copy is not _default:
kwargs['copy'] = copy
return self._array.__dlpack__(**kwargs)

def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
"""
Expand Down
23 changes: 0 additions & 23 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,6 @@ def asarray(
if isinstance(obj, Array) and device is None:
device = obj.device

if np.__version__[0] < '2':
if copy is False:
# Note: copy=False is not yet implemented in np.asarray for
# NumPy 1

# Work around it by creating the new array and seeing if NumPy
# copies it.
if isinstance(obj, Array):
new_array = np.array(obj._array, copy=copy, dtype=_np_dtype)
if new_array is not obj._array:
raise ValueError("Unable to avoid copy while creating an array from given array.")
return Array._new(new_array, device=device)
elif _supports_buffer_protocol(obj):
# Buffer protocol will always support no-copy
return Array._new(np.array(obj, copy=copy, dtype=_np_dtype), device=device)
else:
# No-copy is unsupported for Python built-in types.
raise ValueError("Unable to avoid copy while creating an array from given object.")

if copy is None:
# NumPy 1 treats copy=False the same as the standard copy=None
copy = False

if isinstance(obj, Array):
return Array._new(np.array(obj._array, copy=copy, dtype=_np_dtype), device=device)
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
Expand Down
42 changes: 15 additions & 27 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,31 +456,19 @@ def dlpack_2023_12(api_version):
set_array_api_strict_flags(api_version=api_version)

a = asarray([1, 2, 3], dtype=int8)
# Never an error
a.__dlpack__()


if np.__version__ < '2.1':
exception = NotImplementedError if api_version >= '2023.12' else ValueError
pytest.raises(exception, lambda:
a.__dlpack__(dl_device=CPU_DEVICE))
pytest.raises(exception, lambda:
a.__dlpack__(dl_device=None))
pytest.raises(exception, lambda:
a.__dlpack__(max_version=(1, 0)))
pytest.raises(exception, lambda:
a.__dlpack__(max_version=None))
pytest.raises(exception, lambda:
a.__dlpack__(copy=False))
pytest.raises(exception, lambda:
a.__dlpack__(copy=True))
pytest.raises(exception, lambda:
a.__dlpack__(copy=None))
else:
a.__dlpack__(dl_device=CPU_DEVICE)
a.__dlpack__(dl_device=None)
a.__dlpack__(max_version=(1, 0))
a.__dlpack__(max_version=None)
a.__dlpack__(copy=False)
a.__dlpack__(copy=True)
a.__dlpack__(copy=None)
# Do not error
a.__dlpack__()
a.__dlpack__(dl_device=CPU_DEVICE)
a.__dlpack__(dl_device=None)
a.__dlpack__(max_version=(1, 0))
a.__dlpack__(max_version=None)
a.__dlpack__(copy=False)
a.__dlpack__(copy=True)
a.__dlpack__(copy=None)

x = np.from_dlpack(a)
assert isinstance(x, np.ndarray)
assert x.dtype == np.int8
assert x.shape == (3,)
assert np.all(x == np.asarray([1, 2, 3]))
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pytest
hypothesis
numpy
numpy>=2.1
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
numpy
numpy>=2.1
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
long_description_content_type="text/markdown",
url="https://data-apis.org/array-api-strict/",
license="MIT",
python_requires=">=3.9",
install_requires=["numpy"],
python_requires=">=3.10",
install_requires=["numpy>=2.1"],
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
],
Expand Down
Loading