# Patch summary (commentary placed before the first "diff --git" header).
#
# * xarray/core/indexing.py: in the patched __getitem__, when the key is not a
#   VectorizedIndexer, every iterable entry of key.tuple that is equal to
#   np.arange(<size of that dimension>) is replaced by slice(None) and the
#   indexer is rebuilt with the same indexer type.  Per the inline comment,
#   the goal is to preserve the dask name so lazy array equivalence checks
#   pass.  "Iterable" is added to the typing import for the isinstance check.
# * xarray/core/variable.py: the ValueError raised by isel for unknown
#   dimensions now also reports self.dims.
# * xarray/tests/test_dask.py: adds two tests that run assert_equal on an
#   object and a transformed copy of it inside raise_if_dask_computes().
#
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 4e58be1ad2f..ab049a0a4b4 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -4,7 +4,7 @@
 from collections import defaultdict
 from contextlib import suppress
 from datetime import timedelta
-from typing import Any, Callable, Sequence, Tuple, Union
+from typing import Any, Callable, Iterable, Sequence, Tuple, Union
 
 import numpy as np
 import pandas as pd
@@ -1314,6 +1314,24 @@ def __init__(self, array):
         self.array = array
 
     def __getitem__(self, key):
+
+        if not isinstance(key, VectorizedIndexer):
+            # if possible, short-circuit when keys are effectively slice(None)
+            # This preserves dask name and passes lazy array equivalence checks
+            # (see duck_array_ops.lazy_array_equiv)
+            rewritten_indexer = False
+            new_indexer = []
+            for idim, k in enumerate(key.tuple):
+                if isinstance(k, Iterable) and duck_array_ops.array_equiv(
+                    k, np.arange(self.array.shape[idim])
+                ):
+                    new_indexer.append(slice(None))
+                    rewritten_indexer = True
+                else:
+                    new_indexer.append(k)
+            if rewritten_indexer:
+                key = type(key)(tuple(new_indexer))
+
         if isinstance(key, BasicIndexer):
             return self.array[key.tuple]
         elif isinstance(key, VectorizedIndexer):
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 058b7bf52d4..daa8678157b 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1057,7 +1057,9 @@ def isel(
 
         invalid = indexers.keys() - set(self.dims)
         if invalid:
-            raise ValueError("dimensions %r do not exist" % invalid)
+            raise ValueError(
+                f"dimensions {invalid} do not exist. Expected one or more of {self.dims}"
+            )
 
         key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
         return self[key]
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index cc554850839..8fb54c4ee84 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -1390,3 +1390,58 @@ def test_lazy_array_equiv_merge(compat):
         xr.merge([da1, da3], compat=compat)
     with raise_if_dask_computes(max_computes=2):
         xr.merge([da1, da2 / 2], compat=compat)
+
+
+@pytest.mark.filterwarnings("ignore::FutureWarning")  # transpose_coords
+@pytest.mark.parametrize("obj", [make_da(), make_ds()])
+@pytest.mark.parametrize(
+    "transform",
+    [
+        lambda a: a.assign_attrs(new_attr="anew"),
+        lambda a: a.assign_coords(cxy=a.cxy),
+        lambda a: a.copy(),
+        lambda a: a.isel(x=np.arange(a.sizes["x"])),
+        lambda a: a.isel(x=slice(None)),
+        lambda a: a.loc[dict(x=slice(None))],
+        lambda a: a.loc[dict(x=np.arange(a.sizes["x"]))],
+        lambda a: a.loc[dict(x=a.x)],
+        lambda a: a.sel(x=a.x),
+        lambda a: a.sel(x=a.x.values),
+        lambda a: a.transpose(...),
+        lambda a: a.squeeze(),  # no dimensions to squeeze
+        lambda a: a.sortby("x"),  # "x" is already sorted
+        lambda a: a.reindex(x=a.x),
+        lambda a: a.reindex_like(a),
+        lambda a: a.rename({"cxy": "cnew"}).rename({"cnew": "cxy"}),
+        lambda a: a.pipe(lambda x: x),
+        lambda a: xr.align(a, xr.zeros_like(a))[0],
+        # assign
+        # swap_dims
+        # set_index / reset_index
+    ],
+)
+def test_transforms_pass_lazy_array_equiv(obj, transform):
+    with raise_if_dask_computes():
+        assert_equal(obj, transform(obj))
+
+
+def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
+    with raise_if_dask_computes():
+        assert_equal(map_ds.cxy.broadcast_like(map_ds.cxy), map_ds.cxy)
+        assert_equal(xr.broadcast(map_ds.cxy, map_ds.cxy)[0], map_ds.cxy)
+        assert_equal(map_ds.map(lambda x: x), map_ds)
+        assert_equal(map_ds.set_coords("a").reset_coords("a"), map_ds)
+        assert_equal(map_ds.update({"a": map_ds.a}), map_ds)
+
+        # fails because of index error
+        # assert_equal(
+        #     map_ds.rename_dims({"x": "xnew"}).rename_dims({"xnew": "x"}), map_ds
+        # )
+
+        assert_equal(
+            map_ds.rename_vars({"cxy": "cnew"}).rename_vars({"cnew": "cxy"}), map_ds
+        )
+
+        assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
+        assert_equal(map_da.astype(map_da.dtype), map_da)
+        assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)