Skip to content

Commit

Permalink
ENH: allow python scalars in binary elementwise functions
Browse files Browse the repository at this point in the history
Allow func(array, scalar) and func(scalar, array), raise on
func(scalar, scalar)

cross-ref data-apis/array-api#807
  • Loading branch information
ev-br committed Nov 22, 2024
1 parent d086c61 commit 866cedb
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 0 deletions.
57 changes: 57 additions & 0 deletions array_api_strict/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._flags import requires_api_version
from ._creation_functions import asarray
from ._data_type_functions import broadcast_to, iinfo
from ._helpers import _maybe_normalize_py_scalars

from typing import Optional, Union

Expand Down Expand Up @@ -62,6 +63,8 @@ def add(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")

Expand Down Expand Up @@ -116,6 +119,8 @@ def atan2(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
Expand Down Expand Up @@ -144,6 +149,8 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")

Expand All @@ -165,6 +172,8 @@ def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")

Expand Down Expand Up @@ -197,6 +206,8 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")

Expand All @@ -218,6 +229,8 @@ def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")

Expand All @@ -238,6 +251,8 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")

Expand Down Expand Up @@ -389,6 +404,8 @@ def copysign(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")

Expand Down Expand Up @@ -427,6 +444,8 @@ def divide(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
Expand All @@ -443,6 +462,8 @@ def equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
# Call result type here just to raise on disallowed type combinations
Expand Down Expand Up @@ -493,6 +514,8 @@ def floor_divide(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
Expand All @@ -509,6 +532,8 @@ def greater(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
Expand All @@ -525,6 +550,8 @@ def greater_equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
Expand All @@ -541,6 +568,8 @@ def hypot(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
Expand Down Expand Up @@ -600,6 +629,8 @@ def less(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
Expand All @@ -616,6 +647,8 @@ def less_equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
Expand Down Expand Up @@ -676,6 +709,8 @@ def logaddexp(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
Expand All @@ -692,6 +727,8 @@ def logical_and(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
Expand Down Expand Up @@ -719,6 +756,8 @@ def logical_or(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
Expand All @@ -735,6 +774,8 @@ def logical_xor(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _boolean_dtypes or x2.dtype not in _boolean_dtypes:
Expand All @@ -751,6 +792,8 @@ def maximum(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
Expand All @@ -769,6 +812,8 @@ def minimum(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
Expand All @@ -784,6 +829,8 @@ def multiply(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
Expand Down Expand Up @@ -812,6 +859,8 @@ def nextafter(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
Expand All @@ -825,6 +874,8 @@ def not_equal(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
# Call result type here just to raise on disallowed type combinations
Expand All @@ -851,6 +902,8 @@ def pow(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
Expand Down Expand Up @@ -889,6 +942,8 @@ def remainder(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
Expand Down Expand Up @@ -985,6 +1040,8 @@ def subtract(x1: Array, x2: Array, /) -> Array:
See its docstring for more information.
"""
x1, x2 = _maybe_normalize_py_scalars(x1, x2)

if x1.device != x2.device:
raise ValueError(f"Arrays from two different devices ({x1.device} and {x2.device}) can not be combined.")
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
Expand Down
19 changes: 19 additions & 0 deletions array_api_strict/_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Private helper routines.
"""
import numpy as np

_py_scalars = (bool, int, float, complex)

def _maybe_normalize_py_scalars(x1, x2):
from ._array_object import Array

if isinstance(x1, _py_scalars):
if isinstance(x2, _py_scalars):
raise TypeError(f"Two scalars not allowed, {type(x1) = } and {type(x2) =}")
x1 = Array._new(np.asarray(x1, dtype=x2.dtype._np_dtype), device=x2.device)
elif isinstance(x2, _py_scalars):
x2 = Array._new(np.asarray(x2, dtype=x1.dtype._np_dtype), device=x1.device)
else:
# nothing to do
pass
return x1, x2
30 changes: 30 additions & 0 deletions array_api_strict/tests/test_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,33 @@ def test_bitwise_shift_error():
assert_raises(
ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))
)


def test_scalars():
# Test that binary functions accept (array, scalar) and (scalar, array) arguments
# and reject (scalar, scalar) arguments

def _sample_scalar(category):
if 'boolean' in category:
return True
elif 'floating-point' in category:
return 1.0
elif 'numeric' in category or 'integer' in category or 'all' in category:
return 1
else:
raise ValueError(f'Unknown {category = }')

for func_name, types in elementwise_function_input_types.items():
dtypes = _dtype_categories[types]
func = getattr(_elementwise_functions, func_name)
if nargs(func) == 2:
scalar = _sample_scalar(types)
for dt in dtypes:
array = asarray(scalar, dtype=dt)
conv_scalar = asarray(scalar, dtype=array.dtype)
assert func(scalar, array) == func(conv_scalar, array)
assert func(array, scalar) == func(array, conv_scalar)

with assert_raises(TypeError):
func(scalar, scalar)

0 comments on commit 866cedb

Please sign in to comment.