Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: patch NDArrayOperatorsMixin #2534

Merged
merged 8 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ mccabe.max-complexity = 100
"src/awkward/__init__.py" = ["E402", "F401", "F403", "I001"]
"src/awkward/operations/__init__.py" = ["F403"]
"src/awkward/_nplikes/*" = ["TID251"]
"src/awkward/_operators.py" = ["TID251"]
"tests*/*" = ["T20", "TID251"]

[tool.ruff.flake8-tidy-imports.banned-api]
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/_connect/numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import sys
import warnings

from packaging.version import parse as parse_version
jpivarski marked this conversation as resolved.
Show resolved Hide resolved

import awkward as ak
from awkward._behavior import behavior_of
from awkward._layout import wrap_layout
Expand All @@ -25,9 +27,7 @@ def _import_numexpr():
) from err
else:
if not _has_checked_version:
if ak._util.parse_version(numexpr.__version__) < ak._util.parse_version(
"2.7.1"
):
if parse_version(numexpr.__version__) < parse_version("2.7.1"):
warnings.warn(
"Awkward Array is only known to work with numexpr 2.7.1 or later"
"(you have version {})".format(numexpr.__version__),
Expand Down
5 changes: 3 additions & 2 deletions src/awkward/_connect/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import chain

import numpy
from packaging.version import parse as parse_version

import awkward as ak
from awkward._backends.backend import Backend
Expand All @@ -20,13 +21,13 @@
from awkward._nplikes import to_nplike
from awkward._regularize import is_non_string_like_iterable
from awkward._typing import Iterator
from awkward._util import Sentinel, numpy_at_least
from awkward._util import Sentinel
from awkward.contents.numpyarray import NumpyArray

# NumPy 1.13.1 introduced NEP13, without which Awkward ufuncs won't work, which
# would be worse than lacking a feature: it would cause unexpected output.
# NumPy 1.17.0 introduced NEP18, which is optional (use ak.* instead of np.*).
if not numpy_at_least("1.13.1"):
if parse_version(numpy.__version__) < parse_version("1.13.1"):
jpivarski marked this conversation as resolved.
Show resolved Hide resolved
raise ImportError("NumPy 1.13.1 or later required")


Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import json
from collections.abc import Iterable, Sized

from packaging.version import parse as parse_version
jpivarski marked this conversation as resolved.
Show resolved Hide resolved

import awkward as ak
from awkward._backends.numpy import NumpyBackend
from awkward._nplikes.numpy import Numpy
Expand All @@ -29,7 +31,7 @@
"""

else:
if ak._util.parse_version(pyarrow.__version__) < ak._util.parse_version("7.0.0"):
if parse_version(pyarrow.__version__) < parse_version("7.0.0"):
pyarrow = None
error_message = "pyarrow 7.0.0 or later required for {0}"

Expand Down
6 changes: 3 additions & 3 deletions src/awkward/_nplikes/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from __future__ import annotations

import numpy
from packaging.version import parse as parse_version
jpivarski marked this conversation as resolved.
Show resolved Hide resolved

import awkward as ak
from awkward._nplikes.array_module import ArrayModuleNumpyLike
from awkward._nplikes.dispatch import register_nplike
from awkward._nplikes.numpylike import ArrayLike, NumpyMetadata
Expand Down Expand Up @@ -57,7 +57,7 @@ def packbits(
bitorder: Literal["big", "little"] = "big",
):
assert not isinstance(x, PlaceholderArray)
if ak._util.numpy_at_least("1.17.0"):
if parse_version(numpy.__version__) >= parse_version("1.17.0"):
return numpy.packbits(x, axis=axis, bitorder=bitorder)
else:
assert axis is None, "unsupported argument value for axis given"
Expand Down Expand Up @@ -85,7 +85,7 @@ def unpackbits(
bitorder: Literal["big", "little"] = "big",
):
assert not isinstance(x, PlaceholderArray)
if ak._util.numpy_at_least("1.17.0"):
if parse_version(numpy.__version__) >= parse_version("1.17.0"):
return numpy.unpackbits(x, axis=axis, count=count, bitorder=bitorder)
else:
assert axis is None, "unsupported argument value for axis given"
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from numbers import Number

import numpy
from numpy.lib.mixins import NDArrayOperatorsMixin

import awkward as ak
from awkward._nplikes.dispatch import register_nplike
Expand All @@ -17,6 +16,7 @@
)
from awkward._nplikes.placeholder import PlaceholderArray
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._operators import NDArrayOperatorsMixin
from awkward._regularize import is_integer, is_non_string_like_sequence
from awkward._typing import (
Any,
Expand Down
214 changes: 214 additions & 0 deletions src/awkward/_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
"""
Copyright (c) 2005-2023, NumPy Developers.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.

* Neither the name of the NumPy Developers nor the names of any
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from numpy.core import umath as um


def _disables_array_ufunc(obj):
"""True when __array_ufunc__ is set to None."""
try:
return obj.__array_ufunc__ is None
except AttributeError:
return False


def _binary_method(ufunc, name):
"""Implement a forward binary method with a ufunc, e.g., __add__."""

def func(self, other):
if _disables_array_ufunc(other):
return NotImplemented
return ufunc(self, other)

func.__name__ = f"__{name}__"
return func


def _reflected_binary_method(ufunc, name):
"""Implement a reflected binary method with a ufunc, e.g., __radd__."""

def func(self, other):
if _disables_array_ufunc(other):
return NotImplemented
return ufunc(other, self)

func.__name__ = f"__r{name}__"
return func


def _inplace_binary_method(ufunc, name):
"""Implement an in-place binary method with a ufunc, e.g., __iadd__."""

def func(self, other):
return ufunc(self, other, out=(self,))

func.__name__ = f"__i{name}__"
return func


def _numeric_methods(ufunc, name):
"""Implement forward, reflected and inplace binary methods with a ufunc."""
return (
_binary_method(ufunc, name),
_reflected_binary_method(ufunc, name),
_inplace_binary_method(ufunc, name),
)


def _unary_method(ufunc, name):
"""Implement a unary special method with a ufunc."""

def func(self):
return ufunc(self)

func.__name__ = f"__{name}__"
return func


class NDArrayOperatorsMixin:
"""Mixin defining all operator special methods using __array_ufunc__.

This class implements the special methods for almost all of Python's
builtin operators defined in the `operator` module, including comparisons
(``==``, ``>``, etc.) and arithmetic (``+``, ``*``, ``-``, etc.), by
deferring to the ``__array_ufunc__`` method, which subclasses must
implement.

It is useful for writing classes that do not inherit from `numpy.ndarray`,
but that should support arithmetic and numpy universal functions like
arrays as described in `A Mechanism for Overriding Ufuncs
<https://numpy.org/neps/nep-0013-ufunc-overrides.html>`_.

As an trivial example, consider this implementation of an ``ArrayLike``
class that simply wraps a NumPy array and ensures that the result of any
arithmetic operation is also an ``ArrayLike`` object::

class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
def __init__(self, value):
self.value = np.asarray(value)

# One might also consider adding the built-in list type to this
# list, to support operations like np.add(array_like, list)
_HANDLED_TYPES = (np.ndarray, numbers.Number)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
out = kwargs.get('out', ())
for x in inputs + out:
# Only support operations with instances of _HANDLED_TYPES.
# Use ArrayLike instead of type(self) for isinstance to
# allow subclasses that don't override __array_ufunc__ to
# handle ArrayLike objects.
if not isinstance(x, self._HANDLED_TYPES + (ArrayLike,)):
return NotImplemented

# Defer to the implementation of the ufunc on unwrapped values.
inputs = tuple(x.value if isinstance(x, ArrayLike) else x
for x in inputs)
if out:
kwargs['out'] = tuple(
x.value if isinstance(x, ArrayLike) else x
for x in out)
result = getattr(ufunc, method)(*inputs, **kwargs)

if type(result) is tuple:
# multiple return values
return tuple(type(self)(x) for x in result)
elif method == 'at':
# no return value
return None
else:
# one return value
return type(self)(result)

def __repr__(self):
return '%s(%r)' % (type(self).__name__, self.value)

In interactions between ``ArrayLike`` objects and numbers or numpy arrays,
the result is always another ``ArrayLike``:

>>> x = ArrayLike([1, 2, 3])
>>> x - 1
ArrayLike(array([0, 1, 2]))
>>> 1 - x
ArrayLike(array([ 0, -1, -2]))
>>> np.arange(3) - x
ArrayLike(array([-1, -1, -1]))
>>> x - np.arange(3)
ArrayLike(array([1, 1, 1]))

Note that unlike ``numpy.ndarray``, ``ArrayLike`` does not allow operations
with arbitrary, unrecognized types. This ensures that interactions with
ArrayLike preserve a well-defined casting hierarchy.

.. versionadded:: 1.13
"""

# Like np.ndarray, this mixin class implements "Option 1" from the ufunc
# overrides NEP.

# comparisons don't have reflected and in-place versions
__lt__ = _binary_method(um.less, "lt")
__le__ = _binary_method(um.less_equal, "le")
__eq__ = _binary_method(um.equal, "eq")
__ne__ = _binary_method(um.not_equal, "ne")
__gt__ = _binary_method(um.greater, "gt")
__ge__ = _binary_method(um.greater_equal, "ge")

# numeric methods
__add__, __radd__, __iadd__ = _numeric_methods(um.add, "add")
__sub__, __rsub__, __isub__ = _numeric_methods(um.subtract, "sub")
__mul__, __rmul__, __imul__ = _numeric_methods(um.multiply, "mul")
__matmul__, __rmatmul__, __imatmul__ = _numeric_methods(um.matmul, "matmul")
# Python 3 does not use __div__, __rdiv__, or __idiv__
__truediv__, __rtruediv__, __itruediv__ = _numeric_methods(
um.true_divide, "truediv"
)
__floordiv__, __rfloordiv__, __ifloordiv__ = _numeric_methods(
um.floor_divide, "floordiv"
)
__mod__, __rmod__, __imod__ = _numeric_methods(um.remainder, "mod")
__divmod__ = _binary_method(um.divmod, "divmod")
__rdivmod__ = _reflected_binary_method(um.divmod, "divmod")
# __idivmod__ does not exist
# TODO: handle the optional third argument for __pow__?
__pow__, __rpow__, __ipow__ = _numeric_methods(um.power, "pow")
__lshift__, __rlshift__, __ilshift__ = _numeric_methods(um.left_shift, "lshift")
__rshift__, __rrshift__, __irshift__ = _numeric_methods(um.right_shift, "rshift")
__and__, __rand__, __iand__ = _numeric_methods(um.bitwise_and, "and")
__xor__, __rxor__, __ixor__ = _numeric_methods(um.bitwise_xor, "xor")
__or__, __ror__, __ior__ = _numeric_methods(um.bitwise_or, "or")

# unary methods
__neg__ = _unary_method(um.negative, "neg")
__pos__ = _unary_method(um.positive, "pos")
__abs__ = _unary_method(um.absolute, "abs")
__invert__ = _unary_method(um.invert, "invert")
12 changes: 0 additions & 12 deletions src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import sys
from collections.abc import Collection

import packaging.version

from awkward._typing import TypeVar

win = os.name == "nt"
Expand All @@ -24,16 +22,6 @@
kMaxLevels = 48


def parse_version(version):
return packaging.version.parse(version)


def numpy_at_least(version):
import numpy # noqa: TID251

return parse_version(numpy.__version__) >= parse_version(version)
jpivarski marked this conversation as resolved.
Show resolved Hide resolved


def in_module(obj, modulename: str) -> bool:
m = type(obj).__module__
return m == modulename or m.startswith(modulename + ".")
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from collections.abc import Iterable, Mapping, Sized

from awkward_cpp.lib import _ext
from numpy.lib.mixins import NDArrayOperatorsMixin # noqa: TID251

import awkward as ak
import awkward._connect.hist
Expand All @@ -21,6 +20,7 @@
from awkward._layout import wrap_layout
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._operators import NDArrayOperatorsMixin
from awkward._regularize import is_non_string_like_iterable

np = NumpyMetadata.instance()
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math

import numpy # noqa: TID251
from packaging.version import parse as parse_version

import awkward as ak

Expand All @@ -27,7 +28,7 @@ def register_and_check():
) from err

if not _has_checked_version:
if ak._util.parse_version(numba.__version__) < ak._util.parse_version("0.50"):
if parse_version(numba.__version__) < parse_version("0.50"):
jpivarski marked this conversation as resolved.
Show resolved Hide resolved
raise ImportError(
"Awkward Array can only work with numba 0.50 or later "
"(you have version {})".format(numba.__version__)
Expand Down
Loading