Skip to content

Commit

Permalink
BUG: reject ndarrays in binary operators
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Nov 27, 2024
1 parent d086c61 commit 6d67d46
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
11 changes: 11 additions & 0 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def _check_device(self, other):
elif isinstance(other, Array):
if self.device != other.device:
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
else:
raise TypeError(f"Cannot combine an Array with {type(other)}.")

# Helper function to match the type promotion rules in the spec
def _promote_scalar(self, scalar):
Expand Down Expand Up @@ -1066,6 +1068,7 @@ def __imod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __imod__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__imod__")
if other is NotImplemented:
return other
Expand All @@ -1088,6 +1091,7 @@ def __imul__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __imul__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__imul__")
if other is NotImplemented:
return other
Expand All @@ -1110,6 +1114,7 @@ def __ior__(self: Array, other: Union[int, bool, Array], /) -> Array:
"""
Performs the operation __ior__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ior__")
if other is NotImplemented:
return other
Expand All @@ -1132,6 +1137,7 @@ def __ipow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ipow__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__ipow__")
if other is NotImplemented:
return other
Expand All @@ -1144,6 +1150,7 @@ def __rpow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
from ._elementwise_functions import pow

self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__rpow__")
if other is NotImplemented:
return other
Expand All @@ -1155,6 +1162,7 @@ def __irshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __irshift__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "integer", "__irshift__")
if other is NotImplemented:
return other
Expand All @@ -1177,6 +1185,7 @@ def __isub__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __isub__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__isub__")
if other is NotImplemented:
return other
Expand All @@ -1199,6 +1208,7 @@ def __itruediv__(self: Array, other: Union[float, Array], /) -> Array:
"""
Performs the operation __itruediv__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "floating-point", "__itruediv__")
if other is NotImplemented:
return other
Expand All @@ -1221,6 +1231,7 @@ def __ixor__(self: Array, other: Union[int, bool, Array], /) -> Array:
"""
Performs the operation __ixor__.
"""
self._check_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__ixor__")
if other is NotImplemented:
return other
Expand Down
8 changes: 8 additions & 0 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ def _array_vals():
else:
assert_raises(TypeError, lambda: getattr(x, _op)(y))

# finally, test that array op ndarray raises
# XXX: as long as there is __array__, __rop__s still
# return ndarrays
if not _op.startswith("__r"):
with assert_raises(TypeError):
getattr(x, _op)(y._array)


unary_op_dtypes = {
"__abs__": "numeric",
"__invert__": "integer_or_boolean",
Expand Down

0 comments on commit 6d67d46

Please sign in to comment.