Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add .oindex and .vindex to BackendArray #8885

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,24 @@ def get_duck_array(self, dtype: np.typing.DTypeLike = None):
key = indexing.BasicIndexer((slice(None),) * self.ndim)
return self[key] # type: ignore [index]

def _oindex_get(self, key: indexing.OuterIndexer):
raise NotImplementedError(
f"{self.__class__.__name__}._oindex_get method should be overridden"
)

def _vindex_get(self, key: indexing.VectorizedIndexer):
raise NotImplementedError(
f"{self.__class__.__name__}._vindex_get method should be overridden"
)

@property
def oindex(self) -> indexing.IndexCallable:
    """Return an ``indexing.IndexCallable`` wrapping ``self._oindex_get``."""
    getter = self._oindex_get
    return indexing.IndexCallable(getter)

@property
def vindex(self) -> indexing.IndexCallable:
    """Return an ``indexing.IndexCallable`` wrapping ``self._vindex_get``."""
    getter = self._vindex_get
    return indexing.IndexCallable(getter)


class AbstractDataStore:
__slots__ = ()
Expand Down
12 changes: 11 additions & 1 deletion xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,17 @@ def get_array(self, needs_lock=True):
ds = self.datastore._acquire(needs_lock)
return ds.variables[self.variable_name]

def __getitem__(self, key):
def _oindex_get(self, key: indexing.OuterIndexer):
    """Resolve an outer indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.OUTER_1VECTOR`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Resolve a vectorized indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.OUTER_1VECTOR`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

def __getitem__(self, key: indexing.BasicIndexer):
    """Resolve a basic indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.OUTER_1VECTOR`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
Expand Down
12 changes: 11 additions & 1 deletion xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,17 @@ def get_array(self, needs_lock=True):
variable.set_auto_chartostring(False)
return variable

def __getitem__(self, key):
def _oindex_get(self, key: indexing.OuterIndexer):
    """Resolve an outer indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.OUTER`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Resolve a vectorized indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.OUTER`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

def __getitem__(self, key: indexing.BasicIndexer):
    """Resolve a basic indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.OUTER`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.OUTER
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
Expand Down
12 changes: 11 additions & 1 deletion xarray/backends/pydap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,17 @@ def shape(self) -> tuple[int, ...]:
def dtype(self):
return self.array.dtype

def __getitem__(self, key):
def _oindex_get(self, key: indexing.OuterIndexer):
    """Resolve an outer indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.BASIC`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Resolve a vectorized indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.BASIC`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

def __getitem__(self, key: indexing.BasicIndexer):
    """Resolve a basic indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.BASIC`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
Expand Down
12 changes: 11 additions & 1 deletion xarray/backends/pynio_.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,17 @@ def get_array(self, needs_lock=True):
ds = self.datastore._manager.acquire(needs_lock)
return ds.variables[self.variable_name]

def __getitem__(self, key):
def _oindex_get(self, key: indexing.OuterIndexer):
    """Resolve an outer indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.BASIC`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Resolve a vectorized indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.BASIC`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)

def __getitem__(self, key: indexing.BasicIndexer):
    """Resolve a basic indexer through ``indexing.explicit_indexing_adapter``.

    The adapter receives ``self.shape``, ``IndexingSupport.BASIC`` and
    ``self._getitem`` as the raw indexing method.
    """
    support = indexing.IndexingSupport.BASIC
    return indexing.explicit_indexing_adapter(key, self.shape, support, self._getitem)
Expand Down
33 changes: 24 additions & 9 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,7 @@ def get_variable(self, needs_lock=True):
ds = self.datastore._manager.acquire(needs_lock)
return ds.variables[self.variable_name]

def _getitem(self, key):
with self.datastore.lock:
data = self.get_variable(needs_lock=False).data
return data[key]

def __getitem__(self, key):
data = indexing.explicit_indexing_adapter(
key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
)
def _finalize_result(self, data):
# Copy data if the source file is mmapped. This makes things consistent
# with the netCDF4 library by ensuring we can safely read arrays even
# after closing associated files.
Expand All @@ -88,6 +80,29 @@ def __getitem__(self, key):

return np.array(data, dtype=self.dtype, copy=copy)

# Raw indexing method handed to indexing.explicit_indexing_adapter by the
# indexing entry points of this class: fetches the variable's ``.data`` while
# holding ``self.datastore.lock`` and indexes it with ``key``.
def _getitem(self, key):
with self.datastore.lock:
# NOTE(review): needs_lock=False presumably because the datastore lock is
# already held by the enclosing ``with`` -- confirm against get_variable.
data = self.get_variable(needs_lock=False).data
return data[key]

def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Index with a vectorized indexer and post-process the result.

    The key is resolved by ``indexing.explicit_indexing_adapter`` (with
    ``IndexingSupport.OUTER_1VECTOR`` and ``self._getitem``); the outcome is
    then passed through ``self._finalize_result``.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    selected = indexing.explicit_indexing_adapter(
        key, self.shape, support, self._getitem
    )
    return self._finalize_result(selected)

def _oindex_get(self, key: indexing.OuterIndexer):
    """Index with an outer indexer and post-process the result.

    The key is resolved by ``indexing.explicit_indexing_adapter`` (with
    ``IndexingSupport.OUTER_1VECTOR`` and ``self._getitem``); the outcome is
    then passed through ``self._finalize_result``.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    selected = indexing.explicit_indexing_adapter(
        key, self.shape, support, self._getitem
    )
    return self._finalize_result(selected)

def __getitem__(self, key):
    """Index with ``key`` and post-process the result.

    The key is resolved by ``indexing.explicit_indexing_adapter`` (with
    ``IndexingSupport.OUTER_1VECTOR`` and ``self._getitem``); the outcome is
    then passed through ``self._finalize_result``.
    """
    support = indexing.IndexingSupport.OUTER_1VECTOR
    selected = indexing.explicit_indexing_adapter(
        key, self.shape, support, self._getitem
    )
    return self._finalize_result(selected)

def __setitem__(self, key, value):
with self.datastore.lock:
data = self.get_variable(needs_lock=False)
Expand Down
49 changes: 31 additions & 18 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,38 @@ def __init__(self, zarr_array):
def get_array(self):
    """Return the wrapped array object stored on ``self._array``."""
    wrapped = self._array
    return wrapped

def _oindex(self, key):
return self._array.oindex[key]

def _vindex(self, key):
return self._array.vindex[key]

def _getitem(self, key):
return self._array[key]

def __getitem__(self, key):
array = self._array
if isinstance(key, indexing.BasicIndexer):
method = self._getitem
elif isinstance(key, indexing.VectorizedIndexer):
method = self._vindex
elif isinstance(key, indexing.OuterIndexer):
method = self._oindex
def _oindex_get(self, key: indexing.OuterIndexer):
    """Resolve an outer indexer against the wrapped array.

    ``indexing.explicit_indexing_adapter`` is called with the wrapped
    array's shape, ``IndexingSupport.VECTORIZED`` and a raw indexing method
    that goes through the wrapped array's ``.oindex`` accessor.
    """

    def raw_indexing_method(key):
        # Delegate to the wrapped array's own outer-indexing accessor.
        return self._array.oindex[key]

    return indexing.explicit_indexing_adapter(
        key,
        self._array.shape,
        indexing.IndexingSupport.VECTORIZED,
        raw_indexing_method,
    )

def _vindex_get(self, key: indexing.VectorizedIndexer):
    """Resolve a vectorized indexer against the wrapped array.

    ``indexing.explicit_indexing_adapter`` is called with the wrapped
    array's shape, ``IndexingSupport.VECTORIZED`` and a raw indexing method
    that goes through the wrapped array's ``.vindex`` accessor.
    """
    return indexing.explicit_indexing_adapter(
        key,
        self._array.shape,
        indexing.IndexingSupport.VECTORIZED,
        lambda raw_key: self._array.vindex[raw_key],
    )

def __getitem__(self, key: indexing.BasicIndexer):
    """Resolve a basic indexer against the wrapped array.

    ``indexing.explicit_indexing_adapter`` is called with the wrapped
    array's shape, ``IndexingSupport.VECTORIZED`` and a raw indexing method
    that uses plain ``[]`` indexing on the wrapped array.
    """

    def raw_indexing_method(key):
        # Plain item access on the wrapped array.
        return self._array[key]

    return indexing.explicit_indexing_adapter(
        key,
        self._array.shape,
        indexing.IndexingSupport.VECTORIZED,
        raw_indexing_method,
    )

# if self.ndim == 0:
Expand Down
36 changes: 30 additions & 6 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import functools
import operator
import warnings
from collections import Counter, defaultdict
from collections.abc import Hashable, Iterable, Mapping
from contextlib import suppress
Expand Down Expand Up @@ -564,6 +565,14 @@ def __getitem__(self, key: Any):
return result


# Template for the warning issued when indexing falls back to an array's
# __getitem__ because .oindex/.vindex are not supported. ``{0}`` and ``{1}``
# are str.format placeholders; callers fill both with the array's class name.
BackendArray_fallback_warning_message = (
"The array `{0}` does not support indexing using the .vindex and .oindex properties. "
"The __getitem__ method is being used instead. This fallback behavior will be "
"removed in a future version. Please ensure that the backend array `{1}` implements "
"support for the .vindex and .oindex properties to avoid potential issues."
)


class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array to make basic and outer indexing lazy."""

Expand Down Expand Up @@ -615,11 +624,18 @@ def shape(self) -> _Shape:
return tuple(shape)

def get_duck_array(self):
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
try:
array = apply_indexer(self.array, self.key)
else:
except NotImplementedError as _:
# apply_indexer raised NotImplementedError: self.array is probably a
# BackendArray subclass that does not implement .oindex/.vindex,
# so warn and fall back to its __getitem__.
# Use PendingDeprecationWarning so this fallback reports the same category
# as the identical fallback in the other get_duck_array in this module.
warnings.warn(
    BackendArray_fallback_warning_message.format(
        self.array.__class__.__name__, self.array.__class__.__name__
    ),
    category=PendingDeprecationWarning,
    stacklevel=2,
)
array = self.array[self.key]

# self.array[self.key] is now a numpy array when
Expand Down Expand Up @@ -691,12 +707,20 @@ def shape(self) -> _Shape:
return np.broadcast(*self.key.tuple).shape

def get_duck_array(self):
if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
try:
array = apply_indexer(self.array, self.key)
else:
except NotImplementedError as _:
# apply_indexer raised NotImplementedError: self.array is probably a
# BackendArray subclass that does not implement .oindex/.vindex,
# so warn and fall back to its __getitem__.
warnings.warn(
BackendArray_fallback_warning_message.format(
self.array.__class__.__name__, self.array.__class__.__name__
),
category=PendingDeprecationWarning,
stacklevel=2,
)
array = self.array[self.key]

# self.array[self.key] is now a numpy array when
# self.array is a BackendArray subclass
# and self.key is BasicIndexer((slice(None, None, None),))
Expand Down
46 changes: 46 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5828,3 +5828,49 @@ def test_zarr_region_chunk_partial_offset(tmp_path):
# This write is unsafe, and should raise an error, but does not.
# with pytest.raises(ValueError):
# da.isel(x=slice(5, 25)).chunk(x=(10, 10)).to_zarr(store, region="auto")


def test_backend_array_deprecation_warning(capsys):
    """A BackendArray that only defines ``__getitem__`` triggers the fallback warning."""

    class CustomBackendArray(xr.backends.common.BackendArray):
        def __init__(self):
            backing = self.get_array()
            self.shape = backing.shape
            self.dtype = backing.dtype

        def get_array(self):
            return np.arange(10)

        def __getitem__(self, key):
            return xr.core.indexing.explicit_indexing_adapter(
                key, self.shape, xr.core.indexing.IndexingSupport.BASIC, self._getitem
            )

        def _getitem(self, key):
            return self.get_array()[key]

    cba = CustomBackendArray()
    indexer = xr.core.indexing.VectorizedIndexer(key=(np.array([0]),))
    lazy = xr.core.indexing.LazilyIndexedArray(cba, indexer)

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, including categories that are ignored by default.
        warnings.simplefilter("always")
        lazy.vindex[indexer].get_duck_array()

    captured = capsys.readouterr()
    assert len(caught) == 1
    last = caught[-1]
    assert issubclass(last.category, PendingDeprecationWarning)

    text = str(last.message)
    assert (
        "The array `CustomBackendArray` does not support indexing using the .vindex and .oindex properties."
        in text
    )
    assert "The __getitem__ method is being used instead." in text
    assert "This fallback behavior will be removed in a future version." in text
    assert (
        "Please ensure that the backend array `CustomBackendArray` implements support for the .vindex and .oindex properties to avoid potential issues."
        in text
    )
    # The fallback must only warn, never print to stdout.
    assert captured.out == ""
Loading