Skip to content

Commit

Permalink
Update _utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Sep 1, 2024
1 parent 9221751 commit 60aa969
Showing 1 changed file with 50 additions and 34 deletions.
84 changes: 50 additions & 34 deletions xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from collections.abc import Iterable
from itertools import zip_longest
import math
from types import ModuleType
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -68,8 +70,21 @@ def _infer_dims(
shape: _Shape,
dims: _DimsLike | Default = _default,
) -> _DimsLike:
"""
Create default dim names if no dims were supplied.
Examples
--------
>>> _infer_dims(())
()
>>> _infer_dims((1,))
('dim_0',)
>>> _infer_dims((3, 1))
('dim_1', 'dim_0')
"""
if dims is _default:
return tuple(f"dim_{n}" for n in range(len(shape)))
ndim = len(shape)
return tuple(f"dim_{ndim - 1 - n}" for n in range(ndim))
else:
return dims

Expand Down Expand Up @@ -199,6 +214,11 @@ def _raise_if_any_duplicate_dimensions(
)


def _isnone(shape: _Shape) -> tuple[bool, ...]:
# TODO: math.isnan should not be needed for array api, but dask still uses np.nan:
return tuple(v is None and math.isnan(v) for v in shape)


def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]:
"""
Get the expected broadcasted dims.
Expand All @@ -209,6 +229,8 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]:
>>> b = NamedArray(("y", "z"), np.zeros((3, 4)))
>>> _get_broadcasted_dims(a, b)
(('x', 'y', 'z'), (5, 3, 4))
>>> _get_broadcasted_dims(b, a)
(('x', 'y', 'z'), (5, 3, 4))
>>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4)))
>>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4)))
Expand All @@ -230,37 +252,31 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]:
>>> _get_broadcasted_dims(a, b)
Traceback (most recent call last):
...
ValueError: operands cannot be broadcast together with mismatched lengths for dimension 'x': (5, 2)
ValueError: operands could not be broadcast together with dims = (('x', 'y', 'z'), ('x', 'y', 'z')) and shapes = ((5, 3, 4), (2, 3, 4))
"""

def broadcastable(e1: int, e2: int) -> bool:
# out = e1 > 1 and e2 <= 1
# out |= e2 > 1 and e1 <= 1

# out = e1 >= 0 and e2 <= 1
# out |= e2 >= 0 and e1 <= 1

out = e1 <= 1 or e2 <= 1

return out

# validate dimensions
all_dims = {}
for x in arrays:
_dims = x.dims
_raise_if_any_duplicate_dimensions(_dims, err_context="Broadcasting")

for d, s in zip(_dims, x.shape):
if d not in all_dims:
all_dims[d] = s
elif all_dims[d] != s:
if broadcastable(all_dims[d], s):
max(all_dims[d], s)
else:
raise ValueError(
"operands cannot be broadcast together "
f"with mismatched lengths for dimension {d!r}: {(all_dims[d], s)}"
)

# TODO: Return flag whether broadcasting is needed?
return tuple(all_dims.keys()), tuple(all_dims.values())
dims = tuple(a.dims for a in arrays)
shapes = tuple(a.shape for a in arrays)

if len(shapes) == 1:
return shapes[0]

out_dims = []
out_shape = []
for d, sizes in zip(
zip_longest(*map(reversed, dims), fillvalue=_default),
zip_longest(*map(reversed, shapes), fillvalue=-1),
):
_d = dict.fromkeys(d)
_d.pop(_default, None)
_d = list(_d)

dim = None if any(_isnone(sizes)) else max(sizes)

if any(i not in [-1, 0, 1, dim] for i in sizes) or len(_d) != 1:
raise ValueError(
f"operands could not be broadcast together with {dims = } and {shapes = }"
)

out_dims.append(_d[0])
out_shape.append(dim)
return tuple(reversed(out_dims)), tuple(reversed(out_shape))

0 comments on commit 60aa969

Please sign in to comment.