Skip to content

Commit

Permalink
Make xr.corr and xr.map_blocks work without dask (#5731)
Browse files Browse the repository at this point in the history
Co-authored-by: Mathias Hauser <[email protected]>
Co-authored-by: dcherian <[email protected]>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
4 people authored Nov 24, 2021
1 parent 5db4046 commit 21e8484
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 7 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ Deprecations

Bug fixes
~~~~~~~~~
- :py:func:`xr.map_blocks` and :py:func:`xr.corr` now work when dask is not installed (:issue:`3391`, :issue:`5715`, :pull:`5731`).
By `Gijom <https://github.com/Gijom>`_.
- Fix plot.line crash for data of shape ``(1, N)`` in _title_for_slice on format_item (:pull:`5948`).
By `Sebastian Weigand <https://github.com/s-weigand>`_.
- Fix a regression in the removal of duplicate backend entrypoints (:issue:`5944`, :pull:`5959`)
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,8 @@ def _get_valid_values(da, other):
da = da.where(~missing_vals)
return da
else:
return da
# ensure consistent return dtype
return da.astype(float)

da_a = da_a.map_blocks(_get_valid_values, args=[da_b])
da_b = da_b.map_blocks(_get_valid_values, args=[da_a])
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .alignment import align
from .dataarray import DataArray
from .dataset import Dataset
from .pycompat import is_dask_collection

try:
import dask
Expand Down Expand Up @@ -328,13 +329,13 @@ def _wrapper(
raise TypeError("kwargs must be a mapping (for example, a dict)")

for value in kwargs.values():
if dask.is_dask_collection(value):
if is_dask_collection(value):
raise TypeError(
"Cannot pass dask collections in kwargs yet. Please compute or "
"load values before passing to map_blocks."
)

if not dask.is_dask_collection(obj):
if not is_dask_collection(obj):
return func(obj, *args, **kwargs)

all_args = [obj] + list(args)
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,19 @@ def __init__(self, mod):
self.available = duck_array_module is not None


def is_duck_dask_array(x):
def is_dask_collection(x):
if DuckArrayModule("dask").available:
from dask.base import is_dask_collection

return is_duck_array(x) and is_dask_collection(x)
return is_dask_collection(x)
else:
return False


def is_duck_dask_array(x):
    """Report whether ``x`` passes both the duck-array and dask-collection checks.

    The result is whatever ``is_duck_array(x) and is_dask_collection(x)``
    evaluates to: the duck-array check runs first, and the dask-collection
    check is only consulted when the first one is truthy.
    """
    duck = is_duck_array(x)
    # Short-circuit: skip the dask check entirely for non-duck arrays.
    return duck and is_dask_collection(x)


dsk = DuckArrayModule("dask")
dask_version = dsk.version
dask_array_type = dsk.type
Expand Down
25 changes: 23 additions & 2 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

from . import has_dask, raise_if_dask_computes, requires_dask

dask = pytest.importorskip("dask")


def assert_identical(a, b):
"""A version of this function which accepts numpy arrays"""
Expand Down Expand Up @@ -1420,6 +1418,7 @@ def arrays_w_tuples():
],
)
@pytest.mark.parametrize("dim", [None, "x", "time"])
@requires_dask
def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
# GH 5284
from dask import is_dask_collection
Expand Down Expand Up @@ -1554,6 +1553,28 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None:
assert_allclose(actual, expected)


@requires_dask
@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1])
@pytest.mark.parametrize("dim", [None, "time", "x"])
def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None:
    """Check that ``xr.corr`` agrees across eager, chunked and mixed inputs.

    The correlation of the original arrays is used as the reference and is
    compared against the result for two chunked arrays, and against the
    result for one original plus one chunked array.
    """
    chunked_a = da_a.chunk()
    chunked_b = da_b.chunk()

    corr_chunked = xr.corr(chunked_a, chunked_b, dim=dim)
    corr_reference = xr.corr(da_a, da_b, dim=dim)
    corr_mixed = xr.corr(da_a, chunked_b, dim=dim)

    assert_allclose(corr_reference, corr_chunked)
    assert_allclose(corr_reference, corr_mixed)


@requires_dask
def test_corr_dtype_error():
    """Check that ``xr.corr`` gives equal results for chunked and unchunked input.

    Uses an integer-valued array paired with an array containing a NaN, and
    compares the unchunked result with the both-chunked and the mixed
    (only the second argument chunked) results.
    """
    da_a = xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"])
    da_b = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])

    # Both inputs chunked versus neither chunked.
    plain_first = xr.corr(da_a, da_b)
    both_chunked = xr.corr(da_a.chunk(), da_b.chunk())
    xr.testing.assert_equal(plain_first, both_chunked)

    # Only the second input chunked versus neither chunked.
    plain_second = xr.corr(da_a, da_b)
    second_chunked = xr.corr(da_a, da_b.chunk())
    xr.testing.assert_equal(plain_second, second_chunked)


@pytest.mark.parametrize(
"da_a",
arrays_w_tuples()[0],
Expand Down

0 comments on commit 21e8484

Please sign in to comment.