diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 56ed487d5c6..d9c89649358 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -351,33 +351,35 @@ IndexVariable.values - namedarray.core.NamedArray.all - namedarray.core.NamedArray.any - namedarray.core.NamedArray.attrs - namedarray.core.NamedArray.chunks - namedarray.core.NamedArray.chunksizes - namedarray.core.NamedArray.copy - namedarray.core.NamedArray.count - namedarray.core.NamedArray.cumprod - namedarray.core.NamedArray.cumsum - namedarray.core.NamedArray.data - namedarray.core.NamedArray.dims - namedarray.core.NamedArray.dtype - namedarray.core.NamedArray.get_axis_num - namedarray.core.NamedArray.max - namedarray.core.NamedArray.mean - namedarray.core.NamedArray.median - namedarray.core.NamedArray.min - namedarray.core.NamedArray.nbytes - namedarray.core.NamedArray.ndim - namedarray.core.NamedArray.prod - namedarray.core.NamedArray.reduce - namedarray.core.NamedArray.shape - namedarray.core.NamedArray.size - namedarray.core.NamedArray.sizes - namedarray.core.NamedArray.std - namedarray.core.NamedArray.sum - namedarray.core.NamedArray.var + NamedArray.all + NamedArray.any + NamedArray.attrs + NamedArray.broadcast_to + NamedArray.chunks + NamedArray.chunksizes + NamedArray.copy + NamedArray.count + NamedArray.cumprod + NamedArray.cumsum + NamedArray.data + NamedArray.dims + NamedArray.dtype + NamedArray.expand_dims + NamedArray.get_axis_num + NamedArray.max + NamedArray.mean + NamedArray.median + NamedArray.min + NamedArray.nbytes + NamedArray.ndim + NamedArray.prod + NamedArray.reduce + NamedArray.shape + NamedArray.size + NamedArray.sizes + NamedArray.std + NamedArray.sum + NamedArray.var plot.plot diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8865eb98481..a45e528875d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,9 @@ v2024.02.0 (unreleased) New Features ~~~~~~~~~~~~ +- Add :py:meth:`NamedArray.expand_dims`, :py:meth:`NamedArray.permute_dims` and :py:meth:`NamedArray.broadcast_to` + (:pull:`8380`) By `Anderson Banihirwe `_. + - Xarray now defers to flox's `heuristics `_ to set default `method` for groupby problems. This only applies to ``flox>=0.9``. By `Deepak Cherian `_. diff --git a/xarray/__init__.py b/xarray/__init__.py index 1fd3b0c4336..91613e8cbbc 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -41,6 +41,7 @@ from xarray.core.options import get_options, set_options from xarray.core.parallel import map_blocks from xarray.core.variable import IndexVariable, Variable, as_variable +from xarray.namedarray.core import NamedArray from xarray.util.print_versions import show_versions try: @@ -104,6 +105,7 @@ "IndexSelResult", "IndexVariable", "Variable", + "NamedArray", # Exceptions "MergeError", "SerializationWarning", diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ddf2b3e1891..099a94592fa 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -71,6 +71,7 @@ as_compatible_data, as_variable, ) +from xarray.namedarray.utils import infix_dims from xarray.plot.accessor import DataArrayPlotAccessor from xarray.plot.utils import _get_units_from_attrs from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims @@ -3017,7 +3018,7 @@ def transpose( Dataset.transpose """ if dims: - dims = tuple(utils.infix_dims(dims, self.dims, missing_dims)) + dims = tuple(infix_dims(dims, self.dims, missing_dims)) variable = self.variable.transpose(*dims) if transpose_coords: coords: dict[Hashable, Variable] = {} diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ad3260569b1..7fc689eef7b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -114,8 +114,6 @@ drop_dims_from_indexers, either_dict_or_kwargs, emit_user_level_warning, - infix_dims, - is_dict_like, is_scalar, maybe_wrap_array, ) @@ -126,6 +124,7 @@ broadcast_variables, calculate_dimensions, ) +from xarray.namedarray.utils import infix_dims, is_dict_like from xarray.plot.accessor import DatasetPlotAccessor from xarray.util.deprecation_helpers import _deprecate_positional_args diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3a9c3ee6470..93f648277ca 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -5,14 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Hashable, Iterator, Mapping, Sequence from dataclasses import dataclass, field -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Literal, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union import numpy as np import pandas as pd diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 9da937b996e..b514e455230 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -846,36 +846,6 @@ def __len__(self) -> int: return len(self._data) - num_hidden -def infix_dims( - dims_supplied: Collection, - dims_all: Collection, - missing_dims: ErrorOptionsWithWarn = "raise", -) -> Iterator: - """ - Resolves a supplied list containing an ellipsis representing other items, to - a generator with the 'realized' list of all items - """ - if ... in dims_supplied: - if len(set(dims_all)) != len(dims_all): - raise ValueError("Cannot use ellipsis with repeated dims") - if list(dims_supplied).count(...) > 1: - raise ValueError("More than one ellipsis supplied") - other_dims = [d for d in dims_all if d not in dims_supplied] - existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) - for d in existing_dims: - if d is ...: - yield from other_dims - else: - yield d - else: - existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) - if set(existing_dims) ^ set(dims_all): - raise ValueError( - f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included" - ) - yield from existing_dims - - def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: """Get an new dimension name based on new_dim, that is not used in dims. If the same name exists, we add an underscore(s) in the head. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 119495a486a..32323bb2e64 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -42,11 +42,11 @@ drop_dims_from_indexers, either_dict_or_kwargs, ensure_us_time_resolution, - infix_dims, is_duck_array, maybe_coerce_to_str, ) from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions +from xarray.namedarray.utils import infix_dims NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py index b5c320e0b96..2ad539bad18 100644 --- a/xarray/namedarray/_array_api.py +++ b/xarray/namedarray/_array_api.py @@ -9,6 +9,7 @@ from xarray.namedarray._typing import ( Default, _arrayapi, + _Axes, _Axis, _default, _Dim, @@ -196,3 +197,32 @@ def expand_dims( d.insert(axis, dim) out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) return out + + +def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]: + """ + Permutes the dimensions of an array. + + Parameters + ---------- + x : + Array to permute. + axes : + Permutation of the dimensions of x. + + Returns + ------- + out : + An array with permuted dimensions. The returned array must have the same + data type as x. + + """ + + dims = x.dims + new_dims = tuple(dims[i] for i in axes) + if isinstance(x._data, _arrayapi): + xp = _get_data_namespace(x) + out = x._new(dims=new_dims, data=xp.permute_dims(x._data, axes)) + else: + out = x._new(dims=new_dims, data=x._data.transpose(axes)) # type: ignore[attr-defined] + return out diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py index 37832daca58..835f60e682d 100644 --- a/xarray/namedarray/_typing.py +++ b/xarray/namedarray/_typing.py @@ -7,6 +7,7 @@ Any, Callable, Final, + Literal, Protocol, SupportsIndex, TypeVar, @@ -311,8 +312,10 @@ def todense(self) -> np.ndarray[Any, _DType_co]: # NamedArray can most likely use both __array_function__ and __array_namespace__: _sparsearrayfunction_or_api = (_sparsearrayfunction, _sparsearrayapi) - sparseduckarray = Union[ _sparsearrayfunction[_ShapeType_co, _DType_co], _sparsearrayapi[_ShapeType_co, _DType_co], ] + +ErrorOptions = Literal["raise", "ignore"] +ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index b9ad27b6679..ea653c01498 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -22,6 +22,7 @@ from xarray.core import dtypes, formatting, formatting_html from xarray.namedarray._aggregations import NamedArrayAggregations from xarray.namedarray._typing import ( + ErrorOptionsWithWarn, _arrayapi, _arrayfunction_or_api, _chunkedarray, @@ -34,7 +35,12 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.utils import is_duck_dask_array, to_0d_object_array +from xarray.namedarray.utils import ( + either_dict_or_kwargs, + infix_dims, + is_duck_dask_array, + to_0d_object_array, +) if TYPE_CHECKING: from numpy.typing import ArrayLike, NDArray @@ -483,7 +489,7 @@ def _parse_dimensions(self, dims: _DimsLike) -> _Dims: f"number of data dimensions, ndim={self.ndim}" ) if len(set(dims)) < len(dims): - repeated_dims = set([d for d in dims if dims.count(d) > 1]) + repeated_dims = {d for d in dims if dims.count(d) > 1} warnings.warn( f"Duplicate dimension names present: dimensions {repeated_dims} appear more than once in dims={dims}. " "We do not yet support duplicate dimension names, but we do allow initial construction of the object. " @@ -855,6 +861,162 @@ def _to_dense(self) -> NamedArray[Any, _DType_co]: else: raise TypeError("self.data is not a sparse array") + def permute_dims( + self, + *dim: Iterable[_Dim] | ellipsis, + missing_dims: ErrorOptionsWithWarn = "raise", + ) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions. + + Parameters + ---------- + *dim : Hashable, optional + By default, reverse the order of the dimensions. Otherwise, reorder the + dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + NamedArray: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions + + Returns + ------- + NamedArray + The returned NamedArray has permuted dimensions and data with the + same attributes as the original. + + + See Also + -------- + numpy.transpose + """ + + from xarray.namedarray._array_api import permute_dims + + if not dim: + dims = self.dims[::-1] + else: + dims = tuple(infix_dims(dim, self.dims, missing_dims)) # type: ignore + + if len(dims) < 2 or dims == self.dims: + # no need to transpose if only one dimension + # or dims are in same order + return self.copy(deep=False) + + axes_result = self.get_axis_num(dims) + axes = (axes_result,) if isinstance(axes_result, int) else axes_result + + return permute_dims(self, axes) + + @property + def T(self) -> NamedArray[Any, _DType_co]: + """Return a new object with transposed dimensions.""" + if self.ndim != 2: + raise ValueError( + f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions." + ) + + return self.permute_dims() + + def broadcast_to( + self, dim: Mapping[_Dim, int] | None = None, **dim_kwargs: Any + ) -> NamedArray[Any, _DType_co]: + """ + Broadcast the NamedArray to a new shape. New dimensions are not allowed. + + This method allows for the expansion of the array's dimensions to a specified shape. + It handles both positional and keyword arguments for specifying the dimensions to broadcast. + An error is raised if new dimensions are attempted to be added. + + Parameters + ---------- + dim : dict, str, sequence of str, optional + Dimensions to broadcast the array to. If a dict, keys are dimension names and values are the new sizes. + If a string or sequence of strings, existing dimensions are matched with a size of 1. + + **dim_kwargs : Any + Additional dimensions specified as keyword arguments. Each keyword argument specifies the name of an existing dimension and its size. + + Returns + ------- + NamedArray + A new NamedArray with the broadcasted dimensions. + + Examples + -------- + >>> data = np.asarray([[1.0, 2.0], [3.0, 4.0]]) + >>> array = xr.NamedArray(("x", "y"), data) + >>> array.sizes + {'x': 2, 'y': 2} + + >>> broadcasted = array.broadcast_to(x=2, y=2) + >>> broadcasted.sizes + {'x': 2, 'y': 2} + """ + + from xarray.core import duck_array_ops + + combined_dims = either_dict_or_kwargs(dim, dim_kwargs, "broadcast_to") + + # Check that no new dimensions are added + if new_dims := set(combined_dims) - set(self.dims): + raise ValueError( + f"Cannot add new dimensions: {new_dims}. Only existing dimensions are allowed. " + "Use `expand_dims` method to add new dimensions." + ) + + # Create a dictionary of the current dimensions and their sizes + current_shape = self.sizes + + # Update the current shape with the new dimensions, keeping the order of the original dimensions + broadcast_shape = {d: current_shape.get(d, 1) for d in self.dims} + broadcast_shape |= combined_dims + + # Ensure the dimensions are in the correct order + ordered_dims = list(broadcast_shape.keys()) + ordered_shape = tuple(broadcast_shape[d] for d in ordered_dims) + data = duck_array_ops.broadcast_to(self._data, ordered_shape) # type: ignore # TODO: use array-api-compat function + return self._new(data=data, dims=ordered_dims) + + def expand_dims( + self, + dim: _Dim | Default = _default, + ) -> NamedArray[Any, _DType_co]: + """ + Expand the dimensions of the NamedArray. + + This method adds new dimensions to the object. The new dimensions are added at the beginning of the array. + + Parameters + ---------- + dim : Hashable, optional + Dimension name to expand the array to. This dimension will be added at the beginning of the array. + + Returns + ------- + NamedArray + A new NamedArray with expanded dimensions. + + + Examples + -------- + + >>> data = np.asarray([[1.0, 2.0], [3.0, 4.0]]) + >>> array = xr.NamedArray(("x", "y"), data) + + + # expand dimensions by specifying a new dimension name + >>> expanded = array.expand_dims(dim="z") + >>> expanded.dims + ('z', 'x', 'y') + + """ + + from xarray.namedarray._array_api import expand_dims + + return expand_dims(self, dim=dim) + _NamedArray = NamedArray[Any, np.dtype[_ScalarType_co]] @@ -863,7 +1025,7 @@ def _raise_if_any_duplicate_dimensions( dims: _Dims, err_context: str = "This function" ) -> None: if len(set(dims)) < len(dims): - repeated_dims = set([d for d in dims if dims.count(d) > 1]) + repeated_dims = {d for d in dims if dims.count(d) > 1} raise ValueError( f"{err_context} cannot handle duplicate dimensions, but dimensions {repeated_dims} appear more than once on this object's dims: {dims}" ) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index 4bd20931189..3097cd820a0 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -1,11 +1,14 @@ from __future__ import annotations import sys -from collections.abc import Hashable -from typing import TYPE_CHECKING, Any +import warnings +from collections.abc import Hashable, Iterable, Iterator, Mapping +from typing import TYPE_CHECKING, Any, TypeVar, cast import numpy as np +from xarray.namedarray._typing import ErrorOptionsWithWarn, _DimsLike + if TYPE_CHECKING: if sys.version_info >= (3, 10): from typing import TypeGuard @@ -14,9 +17,7 @@ from numpy.typing import NDArray - from xarray.namedarray._typing import ( - duckarray, - ) + from xarray.namedarray._typing import _Dim, duckarray try: from dask.array.core import Array as DaskArray @@ -26,6 +27,11 @@ DaskCollection: Any = NDArray # type: ignore +K = TypeVar("K") +V = TypeVar("V") +T = TypeVar("T") + + def module_available(module: str) -> bool: """Checks whether a module is installed without importing it. @@ -67,6 +73,101 @@ def to_0d_object_array( return result +def is_dict_like(value: Any) -> TypeGuard[Mapping[Any, Any]]: + return hasattr(value, "keys") and hasattr(value, "__getitem__") + + +def drop_missing_dims( + supplied_dims: Iterable[_Dim], + dims: Iterable[_Dim], + missing_dims: ErrorOptionsWithWarn, +) -> _DimsLike: + """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that + are not present in dims. + + Parameters + ---------- + supplied_dims : Iterable of Hashable + dims : Iterable of Hashable + missing_dims : {"raise", "warn", "ignore"} + """ + + if missing_dims == "raise": + supplied_dims_set = {val for val in supplied_dims if val is not ...} + if invalid := supplied_dims_set - set(dims): + raise ValueError( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return supplied_dims + + elif missing_dims == "warn": + if invalid := set(supplied_dims) - set(dims): + warnings.warn( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return [val for val in supplied_dims if val in dims or val is ...] + + elif missing_dims == "ignore": + return [val for val in supplied_dims if val in dims or val is ...] + + else: + raise ValueError( + f"Unrecognised option {missing_dims} for missing_dims argument" + ) + + +def infix_dims( + dims_supplied: Iterable[_Dim], + dims_all: Iterable[_Dim], + missing_dims: ErrorOptionsWithWarn = "raise", +) -> Iterator[_Dim]: + """ + Resolves a supplied list containing an ellipsis representing other items, to + a generator with the 'realized' list of all items + """ + if ... in dims_supplied: + dims_all_list = list(dims_all) + if len(set(dims_all)) != len(dims_all_list): + raise ValueError("Cannot use ellipsis with repeated dims") + if list(dims_supplied).count(...) > 1: + raise ValueError("More than one ellipsis supplied") + other_dims = [d for d in dims_all if d not in dims_supplied] + existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) + for d in existing_dims: + if d is ...: + yield from other_dims + else: + yield d + else: + existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) + if set(existing_dims) ^ set(dims_all): + raise ValueError( + f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included" + ) + yield from existing_dims + + +def either_dict_or_kwargs( + pos_kwargs: Mapping[Any, T] | None, + kw_kwargs: Mapping[str, T], + func_name: str, +) -> Mapping[Hashable, T]: + if pos_kwargs is None or pos_kwargs == {}: + # Need an explicit cast to appease mypy due to invariance; see + # https://github.com/python/mypy/issues/6228 + return cast(Mapping[Hashable, T], kw_kwargs) + + if not is_dict_like(pos_kwargs): + raise ValueError(f"the first argument to .{func_name} must be a dictionary") + if kw_kwargs: + raise ValueError( + f"cannot specify both keyword and positional arguments to .{func_name}" + ) + return pos_kwargs + + class ReprObject: """Object that prints as the given value, for use with sentinel values.""" diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 9aedcbc80d4..20652f4cc3b 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -26,10 +26,13 @@ from xarray.namedarray._typing import ( Default, _AttrsLike, + _Dim, _DimsLike, _DType, _IndexKeyLike, + _IntOrUnknown, _Shape, + _ShapeLike, duckarray, ) @@ -329,7 +332,7 @@ def test_duck_array_class( self, ) -> None: def test_duck_array_typevar( - a: duckarray[Any, _DType] + a: duckarray[Any, _DType], ) -> duckarray[Any, _DType]: # Mypy checks a is valid: b: duckarray[Any, _DType] = a @@ -379,7 +382,6 @@ def test_new_namedarray(self) -> None: narr_int = narr_float._new(("x",), np.array([1, 3], dtype=dtype_int)) assert narr_int.dtype == dtype_int - # Test with a subclass: class Variable( NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] ): @@ -417,9 +419,8 @@ def _new( if data is _default: return type(self)(dims_, copy.copy(self._data), attrs_) - else: - cls_ = cast("type[Variable[Any, _DType]]", type(self)) - return cls_(dims_, data, attrs_) + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) var_float: Variable[Any, np.dtype[np.float32]] var_float = Variable(("x",), np.array([1.5, 3.2], dtype=dtype_float)) @@ -444,7 +445,6 @@ def test_replace_namedarray(self) -> None: narr_float2 = NamedArray(("x",), np_val2) assert narr_float2.dtype == dtype_float - # Test with a subclass: class Variable( NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] ): @@ -482,9 +482,8 @@ def _new( if data is _default: return type(self)(dims_, copy.copy(self._data), attrs_) - else: - cls_ = cast("type[Variable[Any, _DType]]", type(self)) - return cls_(dims_, data, attrs_) + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) var_float: Variable[Any, np.dtype[np.float32]] var_float = Variable(("x",), np_val) @@ -494,6 +493,86 @@ def _new( var_float2 = var_float._replace(("x",), np_val2) assert var_float2.dtype == dtype_float + @pytest.mark.parametrize( + "dim,expected_ndim,expected_shape,expected_dims", + [ + (None, 3, (1, 2, 5), (None, "x", "y")), + (_default, 3, (1, 2, 5), ("dim_2", "x", "y")), + ("z", 3, (1, 2, 5), ("z", "x", "y")), + ], + ) + def test_expand_dims( + self, + target: NamedArray[Any, np.dtype[np.float32]], + dim: _Dim | Default, + expected_ndim: int, + expected_shape: _ShapeLike, + expected_dims: _DimsLike, + ) -> None: + result = target.expand_dims(dim=dim) + assert result.ndim == expected_ndim + assert result.shape == expected_shape + assert result.dims == expected_dims + + @pytest.mark.parametrize( + "dims, expected_sizes", + [ + ((), {"y": 5, "x": 2}), + (["y", "x"], {"y": 5, "x": 2}), + (["y", ...], {"y": 5, "x": 2}), + ], + ) + def test_permute_dims( + self, + target: NamedArray[Any, np.dtype[np.float32]], + dims: _DimsLike, + expected_sizes: dict[_Dim, _IntOrUnknown], + ) -> None: + actual = target.permute_dims(*dims) + assert actual.sizes == expected_sizes + + def test_permute_dims_errors( + self, + target: NamedArray[Any, np.dtype[np.float32]], + ) -> None: + with pytest.raises(ValueError, match=r"'y'.*permuted list"): + dims = ["y"] + target.permute_dims(*dims) + + @pytest.mark.parametrize( + "broadcast_dims,expected_ndim", + [ + ({"x": 2, "y": 5}, 2), + ({"x": 2, "y": 5, "z": 2}, 3), + ({"w": 1, "x": 2, "y": 5}, 3), + ], + ) + def test_broadcast_to( + self, + target: NamedArray[Any, np.dtype[np.float32]], + broadcast_dims: Mapping[_Dim, int], + expected_ndim: int, + ) -> None: + expand_dims = set(broadcast_dims.keys()) - set(target.dims) + # loop over expand_dims and call .expand_dims(dim=dim) in a loop + for dim in expand_dims: + target = target.expand_dims(dim=dim) + result = target.broadcast_to(broadcast_dims) + assert result.ndim == expected_ndim + assert result.sizes == broadcast_dims + + def test_broadcast_to_errors( + self, target: NamedArray[Any, np.dtype[np.float32]] + ) -> None: + with pytest.raises( + ValueError, + match=r"operands could not be broadcast together with remapped shapes", + ): + target.broadcast_to({"x": 2, "y": 2}) + + with pytest.raises(ValueError, match=r"Cannot add new dimensions"): + target.broadcast_to({"x": 2, "y": 2, "z": 2}) + def test_warn_on_repeated_dimension_names(self) -> None: with pytest.warns(UserWarning, match="Duplicate dimension names"): NamedArray(("x", "x"), np.arange(4).reshape(2, 2)) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index b5bd58a78c4..7bdc2c92e44 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -8,6 +8,7 @@ from xarray.core import duck_array_ops, utils from xarray.core.utils import either_dict_or_kwargs, iterate_nested +from xarray.namedarray.utils import infix_dims from xarray.tests import assert_array_equal, requires_dask @@ -239,7 +240,7 @@ def test_either_dict_or_kwargs(): ], ) def test_infix_dims(supplied, all_, expected): - result = list(utils.infix_dims(supplied, all_)) + result = list(infix_dims(supplied, all_)) assert result == expected @@ -248,7 +249,7 @@ def test_infix_dims(supplied, all_, expected): ) def test_infix_dims_errors(supplied, all_): with pytest.raises(ValueError): - list(utils.infix_dims(supplied, all_)) + list(infix_dims(supplied, all_)) @pytest.mark.parametrize(