diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8d460e492c6..95d33d2ab65 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3592,7 +3592,7 @@ def interpolate_na( """ from xarray.core.missing import interp_na - return interp_na( + return interp_na( # type: ignore[return-value] self, dim=dim, method=method, @@ -3685,7 +3685,7 @@ def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """ from xarray.core.missing import ffill - return ffill(self, dim, limit=limit) + return ffill(self, dim, limit=limit) # type: ignore[return-value] def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward @@ -3769,7 +3769,7 @@ def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """ from xarray.core.missing import bfill - return bfill(self, dim, limit=limit) + return bfill(self, dim, limit=limit) # type: ignore[return-value] def combine_first(self, other: Self) -> Self: """Combine two DataArray objects, with union of coordinates. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a7dedd2ed07..dd129c501ce 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4077,7 +4077,7 @@ def _validate_interp_indexer(x, new_x): # optimization: subset to coordinate range of the target index if method in ["linear", "nearest"]: for k, v in validated_indexers.items(): - obj, newidx = missing._localize(obj, {k: v}) + obj, newidx = missing._localize(obj, {k: v}) # type: ignore[assignment, arg-type] validated_indexers[k] = newidx[k] # optimization: create dask coordinate arrays once per Dataset @@ -6758,7 +6758,7 @@ def interpolate_na( max_gap=max_gap, **kwargs, ) - return new + return new # type: ignore[return-value] def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward @@ -6822,7 +6822,7 @@ def ffill(self, dim: Hashable, limit: int | None = None) -> Self: from xarray.core.missing import _apply_over_vars_with_dim, ffill new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) - return new + return new # type: ignore[return-value] def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward @@ -6887,7 +6887,7 @@ def bfill(self, dim: Hashable, limit: int | None = None) -> Self: from xarray.core.missing import _apply_over_vars_with_dim, bfill new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) - return new + return new # type: ignore[return-value] def combine_first(self, other: Self) -> Self: """Combine two Datasets, default to data_vars of self. diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 4523e4f8232..30a95edf80e 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -2,13 +2,14 @@ import datetime as dt import warnings -from collections.abc import Callable, Hashable, Sequence +from collections.abc import Callable, Generator, Hashable, Mapping, Sequence from functools import partial from numbers import Number from typing import TYPE_CHECKING, Any, get_args import numpy as np import pandas as pd +from numpy.typing import ArrayLike from xarray.core import utils from xarray.core.common import _contains_datetime_like_objects, ones_like @@ -20,20 +21,31 @@ timedelta_to_numeric, ) from xarray.core.options import _get_keep_attrs -from xarray.core.types import Interp1dOptions, InterpOptions +from xarray.core.types import ( + Interp1dOptions, + InterpOptions, + ScalarOrArray, + T_DuckArray, + T_Xarray, +) from xarray.core.utils import OrderedSet, is_scalar from xarray.core.variable import Variable, broadcast_variables from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import is_chunked_array if TYPE_CHECKING: + from typing import TypeVar + from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset + # T_DuckArray can't be a function parameter because covariant + XarrayLike = TypeVar("XarrayLike", Variable, DataArray, Dataset) + def _get_nan_block_lengths( - obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable -): + obj: Dataset | DataArray, dim: Hashable, index: T_DuckArray | ArrayLike +) -> Any: """ Return an object where each NaN element in 'obj' is replaced by the length of the gap the element is in. @@ -66,12 +78,12 @@ class BaseInterpolator: cons_kwargs: dict[str, Any] call_kwargs: dict[str, Any] f: Callable - method: str + method: str | int - def __call__(self, x): + def __call__(self, x: ScalarOrArray) -> ArrayLike: # dask array will return np return self.f(x, **self.call_kwargs) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}: method={self.method}" @@ -83,7 +95,14 @@ class NumpyInterpolator(BaseInterpolator): numpy.interp """ - def __init__(self, xi, yi, method="linear", fill_value=None, period=None): + def __init__( + self, + xi: T_DuckArray, # passed to np.asarray + yi: T_DuckArray, # passed to np.asarray + requires dtype attribute + method: str | None = "linear", + fill_value: float | complex | None = None, + period: float | None = None, + ): if method != "linear": raise ValueError("only method `linear` is valid for the NumpyInterpolator") @@ -104,8 +123,8 @@ def __init__(self, xi, yi, method="linear", fill_value=None, period=None): self._left = fill_value[0] self._right = fill_value[1] elif is_scalar(fill_value): - self._left = fill_value - self._right = fill_value + self._left = fill_value # type: ignore[assignment] + self._right = fill_value # type: ignore[assignment] else: raise ValueError(f"{fill_value} is not a valid fill_value") @@ -130,15 +149,15 @@ class ScipyInterpolator(BaseInterpolator): def __init__( self, - xi, - yi, - method=None, - fill_value=None, - assume_sorted=True, - copy=False, - bounds_error=False, - order=None, - axis=-1, + xi: T_DuckArray, + yi: T_DuckArray, + method: str | int | None = None, + fill_value: float | complex | None = None, + assume_sorted: bool = True, + copy: bool = False, + bounds_error: bool = False, + order: int | None = None, + axis: int = -1, **kwargs, ): from scipy.interpolate import interp1d @@ -154,18 +173,11 @@ def __init__( raise ValueError("order is required when method=polynomial") method = order - self.method = method + self.method: str | int = method self.cons_kwargs = kwargs self.call_kwargs = {} - nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j - - if fill_value is None and method == "linear": - fill_value = nan, nan - elif fill_value is None: - fill_value = nan - self.f = interp1d( xi, yi, @@ -189,13 +201,13 @@ class SplineInterpolator(BaseInterpolator): def __init__( self, - xi, - yi, - method="spline", - fill_value=None, - order=3, - nu=0, - ext=None, + xi: T_DuckArray, + yi: T_DuckArray, + method: str | int | None = "spline", + fill_value: float | complex | None = None, + order: int = 3, + nu: float | None = 0, + ext: int | str | None = None, **kwargs, ): from scipy.interpolate import UnivariateSpline @@ -213,7 +225,9 @@ def __init__( self.f = UnivariateSpline(xi, yi, k=order, **self.cons_kwargs) -def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): +def _apply_over_vars_with_dim( + func: Callable, self: Dataset, dim: Hashable | None = None, **kwargs +) -> Dataset: """Wrapper for datasets""" ds = type(self)(coords=self.coords, attrs=self.attrs) @@ -227,13 +241,16 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): def get_clean_interp_index( - arr, dim: Hashable, use_coordinate: Hashable | bool = True, strict: bool = True -): + arr: T_Xarray, + dim: Hashable, + use_coordinate: Hashable | bool = True, + strict: bool = True, +) -> np.ndarray[Any, np.dtype[np.float64]]: """Return index to use for x values in interpolation or curve fitting. Parameters ---------- - arr : DataArray + arr : DataArray or Dataset Array to interpolate or fit to a curve. dim : str Name of dimension along which to fit. @@ -266,13 +283,13 @@ def get_clean_interp_index( index = arr.get_index(dim) else: # string - index = arr.coords[use_coordinate] + index = arr.coords[use_coordinate] # type: ignore[assignment] if index.ndim != 1: raise ValueError( f"Coordinates used for interpolation must be 1D, " f"{use_coordinate} is {index.ndim}D." ) - index = index.to_index() + index = index.to_index() # type: ignore[attr-defined] # TODO: index.name is None for multiindexes # set name for nice error messages below @@ -291,15 +308,16 @@ def get_clean_interp_index( if isinstance(index, CFTimeIndex | pd.DatetimeIndex): offset = type(index[0])(1970, 1, 1) if isinstance(index, CFTimeIndex): - index = index.values - index = Variable( + index = index.values # type: ignore[assignment] + index = Variable( # type: ignore[assignment] data=datetime_to_numeric(index, offset=offset, datetime_unit="ns"), dims=(dim,), ) # raise if index cannot be cast to a float (e.g. MultiIndex) try: - index = index.values.astype(np.float64) + # this step ensures output is ndarray + index = index.values.astype(np.float64) # type: ignore[assignment] except (TypeError, ValueError) as err: # pandas raises a TypeError # xarray/numpy raise a ValueError @@ -308,11 +326,11 @@ def get_clean_interp_index( f"interpolation or curve fitting, got {type(index).__name__}." ) from err - return index + return index # type: ignore[return-value] def interp_na( - self, + self: T_Xarray, dim: Hashable | None = None, use_coordinate: bool | str = True, method: InterpOptions = "linear", @@ -322,7 +340,7 @@ def interp_na( ) = None, keep_attrs: bool | None = None, **kwargs, -): +) -> T_Xarray: """Interpolate values according to different methods.""" from xarray.coding.cftimeindex import CFTimeIndex @@ -355,6 +373,7 @@ def interp_na( # method index = get_clean_interp_index(self, dim, use_coordinate=use_coordinate) + interp_class: type[BaseInterpolator] interp_class, kwargs = _get_interpolator(method, **kwargs) interpolator = partial(func_interpolate_na, interp_class, **kwargs) @@ -390,14 +409,16 @@ def interp_na( return arr -def func_interpolate_na(interpolator, y, x, **kwargs): +def func_interpolate_na( + interpolator: Callable, y: T_Xarray, x: np.ndarray, **kwargs +) -> T_Xarray: """helper function to apply interpolation along 1 dimension""" # reversed arguments are so that attrs are preserved from da, not index # it would be nice if this wasn't necessary, works around: # "ValueError: assignment destination is read-only" in assignment below out = y.copy() - nans = pd.isnull(y) + nans = pd.isnull(y) # type: ignore[call-overload] nonans = ~nans # fast track for no-nans, all nan but one, and all-nans cases @@ -421,7 +442,9 @@ def _bfill(arr, n=None, axis=-1): return np.flip(arr, axis=axis) -def ffill(arr, dim=None, limit=None): +def ffill( + arr: T_Xarray, dim: Hashable | None = None, limit: int | None = None +) -> T_Xarray: """forward fill missing values""" axis = arr.get_axis_num(dim) @@ -439,7 +462,9 @@ def ffill(arr, dim=None, limit=None): ).transpose(*arr.dims) -def bfill(arr, dim=None, limit=None): +def bfill( + arr: T_Xarray, dim: Hashable | None = None, limit: int | None = None +) -> T_Xarray: """backfill missing values""" axis = arr.get_axis_num(dim) @@ -457,7 +482,7 @@ def bfill(arr, dim=None, limit=None): ).transpose(*arr.dims) -def _import_interpolant(interpolant, method): +def _import_interpolant(interpolant: str, method: InterpOptions): """Import interpolant from scipy.interpolate.""" try: from scipy import interpolate @@ -469,18 +494,16 @@ def _import_interpolant(interpolant, method): def _get_interpolator( method: InterpOptions, vectorizeable_only: bool = False, **kwargs -): +) -> tuple[type[BaseInterpolator], dict[str, Any]]: """helper function to select the appropriate interpolator class returns interpolator class and keyword arguments for the class """ - interp_class: ( - type[NumpyInterpolator] | type[ScipyInterpolator] | type[SplineInterpolator] - ) - interp1d_methods = get_args(Interp1dOptions) valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v)) + interp_class: type[BaseInterpolator] + # prefer numpy.interp for 1d linear interpolation. This function cannot # take higher dimensional data but scipy.interp1d can. if ( @@ -531,7 +554,9 @@ def _get_interpolator( return interp_class, kwargs -def _get_interpolator_nd(method, **kwargs): +def _get_interpolator_nd( + method: InterpOptions, **kwargs +) -> tuple[type[BaseInterpolator], dict[str, Any]]: """helper function to select the appropriate interpolator class returns interpolator class and keyword arguments for the class @@ -551,7 +576,7 @@ def _get_interpolator_nd(method, **kwargs): return interp_class, kwargs -def _get_valid_fill_mask(arr, dim, limit): +def _get_valid_fill_mask(arr: T_Xarray, dim: Hashable, limit: int) -> T_Xarray: """helper function to determine values that can be filled when limit is not None""" kw = {dim: limit + 1} @@ -559,13 +584,21 @@ def _get_valid_fill_mask(arr, dim, limit): new_dim = utils.get_temp_dimname(arr.dims, "_window") return ( arr.isnull() - .rolling(min_periods=1, **kw) + .rolling( + kw, + min_periods=1, + ) .construct(new_dim, fill_value=False) - .sum(new_dim, skipna=False) + .sum(dim=new_dim, skipna=False) # type: ignore[arg-type] ) <= limit -def _localize(var, indexes_coords): +def _localize( + var: Variable, + indexes_coords: dict[ + Any, tuple[Variable, Variable] + ], # indexes_coords altered so can't use mapping type +) -> tuple[Variable | T_Xarray, dict[Any, tuple[Variable, Variable]]]: """Speed up for linear and nearest neighbor method. Only consider a subspace that is needed for the interpolation """ @@ -577,17 +610,20 @@ def _localize(var, indexes_coords): index = x.to_index() imin, imax = index.get_indexer([minval, maxval], method="nearest") indexes[dim] = slice(max(imin - 2, 0), imax + 2) - indexes_coords[dim] = (x[indexes[dim]], new_x) - return var.isel(**indexes), indexes_coords + indexes_coords[dim] = ( + x[indexes[dim]], + new_x, + ) # probably the in-place modification is unintentional here? + return var.isel(indexers=indexes), indexes_coords -def _floatize_x(x, new_x): +def _floatize_x(x: tuple[Variable, ...], new_x: tuple[Variable, ...]): """Make x and new_x float. This is particularly useful for datetime dtype. x, new_x: tuple of np.ndarray """ - x = list(x) - new_x = list(new_x) + x = list(x) # type: ignore[assignment] + new_x = list(new_x) # type: ignore[assignment] for i in range(len(x)): if _contains_datetime_like_objects(x[i]): # Scipy casts coordinates to np.float64, which is not accurate @@ -596,12 +632,17 @@ def _floatize_x(x, new_x): # offset (min(x)) and the variation (x - min(x)) can be # represented by float. xmin = x[i].values.min() - x[i] = x[i]._to_numeric(offset=xmin, dtype=np.float64) - new_x[i] = new_x[i]._to_numeric(offset=xmin, dtype=np.float64) + x[i] = x[i]._to_numeric(offset=xmin, dtype=np.float64) # type: ignore[index] + new_x[i] = new_x[i]._to_numeric(offset=xmin, dtype=np.float64) # type: ignore[index] return x, new_x -def interp(var, indexes_coords, method: InterpOptions, **kwargs): +def interp( + var: Variable, + indexes_coords: Mapping[Any, tuple[Any, Any]], + method: InterpOptions, + **kwargs, +) -> Variable: """Make an interpolation of Variable Parameters @@ -662,20 +703,23 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs): return result -def interp_func(var, x, new_x, method: InterpOptions, kwargs): +def interp_func( + var: type[T_DuckArray], + x: tuple[Variable, ...], + new_x: tuple[Variable, ...], + method: InterpOptions, + kwargs: dict[str, Any], +) -> ArrayLike: """ multi-dimensional interpolation for array-like. Interpolated axes should be located in the last position. Parameters ---------- - var : np.ndarray or dask.array.Array - Array to be interpolated. The final dimension is interpolated. - x : a list of 1d array. - Original coordinates. Should not contain NaN. - new_x : a list of 1d array - New coordinates. Should not contain NaN. - method : string + var : Array to be interpolated. The final dimension is interpolated. + x : Original coordinates. Should not contain NaN. + new_x : New coordinates. Should not contain NaN. + method : {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'pchip', 'akima', 'makima', 'barycentric', 'krogh'} for 1-dimensional interpolation. {'linear', 'nearest'} for multidimensional interpolation @@ -696,15 +740,16 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): scipy.interpolate.interp1d """ if not x: - return var.copy() + return var.data.copy() if len(x) == 1: func, kwargs = _get_interpolator(method, vectorizeable_only=True, **kwargs) else: func, kwargs = _get_interpolator_nd(method, **kwargs) - if is_chunked_array(var): - chunkmanager = get_chunked_array_type(var) + # is_chunked_array typed using named_array, unsure of relationship to duck arrays + if is_chunked_array(var): # type: ignore[arg-type] + chunkmanager = get_chunked_array_type(var) # duck compatible ndim = var.ndim nconst = ndim - len(x) @@ -713,11 +758,11 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): # blockwise args format x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)] - x_arginds = [item for pair in x_arginds for item in pair] + x_arginds = [item for pair in x_arginds for item in pair] # type: ignore[misc] new_x_arginds = [ [_x, [ndim + index for index in range(_x.ndim)]] for _x in new_x ] - new_x_arginds = [item for pair in new_x_arginds for item in pair] + new_x_arginds = [item for pair in new_x_arginds for item in pair] # type: ignore[misc] args = (var, range(ndim), *x_arginds, *new_x_arginds) @@ -727,13 +772,13 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): elem for pair in zip(rechunked, args[1::2], strict=True) for elem in pair ) - new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] + new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] # type: ignore[assignment] new_x0_chunks = new_x[0].chunks new_x0_shape = new_x[0].shape new_x0_chunks_is_not_none = new_x0_chunks is not None new_axes = { - ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i] + ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i] # type: ignore[index] for i in range(new_x[0].ndim) } @@ -743,10 +788,12 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): # scipy.interpolate.interp1d always forces to float. # Use the same check for blockwise as well: if not issubclass(var.dtype.type, np.inexact): - dtype = float + dtype = np.dtype(float) else: dtype = var.dtype + # mypy may flag this in the future since _meta is not a property + # of duck arrays--only inside this conditional if a dask array meta = var._meta return chunkmanager.blockwise( @@ -758,7 +805,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): localize=localize, concatenate=True, dtype=dtype, - new_axes=new_axes, + new_axes=new_axes, # type: ignore[arg-type] meta=meta, align_arrays=False, ) @@ -766,9 +813,14 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs): return _interpnd(var, x, new_x, func, kwargs) -def _interp1d(var, x, new_x, func, kwargs): +def _interp1d( + var: XarrayLike, + x: Variable, + new_x: Variable, + func: Callable, + kwargs: dict[str, Any], +) -> ArrayLike: # x, new_x are tuples of size 1. - x, new_x = x[0], new_x[0] rslt = func(x, var, **kwargs)(np.ravel(new_x)) if new_x.ndim > 1: return reshape(rslt, (var.shape[:-1] + new_x.shape)) @@ -777,11 +829,17 @@ def _interp1d(var, x, new_x, func, kwargs): return rslt -def _interpnd(var, x, new_x, func, kwargs): +def _interpnd( + var: XarrayLike, + x: tuple[Variable, ...], + new_x: tuple[Variable, ...], + func: Callable, + kwargs: dict[str, Any], +) -> ArrayLike: x, new_x = _floatize_x(x, new_x) if len(x) == 1: - return _interp1d(var, x, new_x, func, kwargs) + return _interp1d(var, x[0], new_x[0], func, kwargs) # move the interpolation axes to the start position var = var.transpose(range(-len(x), var.ndim - len(x))) @@ -793,7 +851,13 @@ def _interpnd(var, x, new_x, func, kwargs): return reshape(rslt, rslt.shape[:-1] + new_x[0].shape) -def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): +def _chunked_aware_interpnd( + var: Variable, + *coords, + interp_func: Callable, + interp_kwargs: dict[str, Any], + localize: bool = True, +) -> ArrayLike: """Wrapper for `_interpnd` through `blockwise` for chunked arrays. The first half arrays in `coords` are original coordinates, @@ -803,11 +867,13 @@ def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=T nconst = len(var.shape) - n_x # _interpnd expect coords to be Variables - x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] - new_x = [ + x = tuple( + Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x]) + ) + new_x = tuple( Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x) for _x in coords[n_x:] - ] + ) if localize: # _localize expect var to be a Variable @@ -818,7 +884,7 @@ def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=T } # simple speed up for the local interpolation - var, indexes_coords = _localize(var, indexes_coords) + var, indexes_coords = _localize(var, indexes_coords) # type: ignore[assignment] x, new_x = zip(*[indexes_coords[d] for d in indexes_coords], strict=True) # put var back as a ndarray @@ -827,7 +893,9 @@ def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=T return _interpnd(var, x, new_x, interp_func, interp_kwargs) -def decompose_interp(indexes_coords): +def decompose_interp( + indexes_coords: Mapping[Any, tuple[Any, Any]] +) -> Generator[Mapping[Any, tuple[Any, Any]]]: """Decompose the interpolation into a succession of independent interpolation keeping the order""" dest_dims = [ @@ -835,7 +903,7 @@ def decompose_interp(indexes_coords): for dim, dest in indexes_coords.items() ] partial_dest_dims = [] - partial_indexes_coords = {} + partial_indexes_coords: dict[Any, tuple[Any, Any]] = {} for i, index_coords in enumerate(indexes_coords.items()): partial_indexes_coords.update([index_coords])