Skip to content

Commit

Permalink
Another groupby.reduce bugfix. (#3403)
Browse files Browse the repository at this point in the history
* Another groupby.reduce bugfix.

Fixes #3402

* Add whats-new.

* Use is_scalar instead

* bugfix

* fix whats-new

* Update xarray/core/groupby.py

Co-Authored-By: Maximilian Roos <[email protected]>
  • Loading branch information
dcherian and max-sixty authored Oct 25, 2019
1 parent 63cc857 commit fb0cf7b
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 35 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ Bug fixes
- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
By `Anderson Banihirwe <https://github.com/andersy005>`_.

- Fix :py:meth:`xarray.core.groupby.DataArrayGroupBy.reduce` and
:py:meth:`xarray.core.groupby.DatasetGroupBy.reduce` when reducing over multiple dimensions.
(:issue:`3402`). By `Deepak Cherian <https://github.com/dcherian/>`_


Documentation
~~~~~~~~~~~~~
Expand Down
27 changes: 16 additions & 11 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,26 @@
from .utils import (
either_dict_or_kwargs,
hashable,
is_scalar,
maybe_wrap_array,
peek_at,
safe_cast_to_index,
)
from .variable import IndexVariable, Variable, as_variable


def check_reduce_dims(reduce_dims, dimensions):

if reduce_dims is not ...:
if is_scalar(reduce_dims):
reduce_dims = [reduce_dims]
if any([dim not in dimensions for dim in reduce_dims]):
raise ValueError(
"cannot reduce over dimensions %r. expected either '...' to reduce over all dimensions or one or more of %r."
% (reduce_dims, dimensions)
)


def unique_value_groups(ar, sort=True):
"""Group an array by its unique values.
Expand Down Expand Up @@ -794,15 +807,11 @@ def reduce(
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

if dim is not ... and dim not in self.dims:
raise ValueError(
"cannot reduce over dimension %r. expected either '...' to reduce over all dimensions or one or more of %r."
% (dim, self.dims)
)

def reduce_array(ar):
return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs)

check_reduce_dims(dim, self.dims)

return self.apply(reduce_array, shortcut=shortcut)


Expand Down Expand Up @@ -895,11 +904,7 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
def reduce_dataset(ds):
return ds.reduce(func, dim, keep_attrs, **kwargs)

if dim is not ... and dim not in self.dims:
raise ValueError(
"cannot reduce over dimension %r. expected either '...' to reduce over all dimensions or one or more of %r."
% (dim, self.dims)
)
check_reduce_dims(dim, self.dims)

return self.apply(reduce_dataset)

Expand Down
9 changes: 0 additions & 9 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2560,15 +2560,6 @@ def change_metadata(x):
expected = change_metadata(expected)
assert_equal(expected, actual)

def test_groupby_reduce_dimension_error(self):
array = self.make_groupby_example_array()
grouped = array.groupby("y")
with raises_regex(ValueError, "cannot reduce over dimension 'y'"):
grouped.mean()

grouped = array.groupby("y", squeeze=False)
assert_identical(array, grouped.mean())

def test_groupby_math(self):
array = self.make_groupby_example_array()
for squeeze in [True, False]:
Expand Down
56 changes: 41 additions & 15 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,23 @@
import xarray as xr
from xarray.core.groupby import _consolidate_slices

from . import assert_identical, raises_regex
from . import assert_allclose, assert_identical, raises_regex


@pytest.fixture
def dataset():
ds = xr.Dataset(
{"foo": (("x", "y", "z"), np.random.randn(3, 4, 2))},
{"x": ["a", "b", "c"], "y": [1, 2, 3, 4], "z": [1, 2]},
)
ds["boo"] = (("z", "y"), [["f", "g", "h", "j"]] * 2)

return ds


@pytest.fixture
def array(dataset):
return dataset["foo"]


def test_consolidate_slices():
Expand All @@ -21,25 +37,17 @@ def test_consolidate_slices():
_consolidate_slices([slice(3), 4])


def test_groupby_dims_property():
ds = xr.Dataset(
{"foo": (("x", "y", "z"), np.random.randn(3, 4, 2))},
{"x": ["a", "bcd", "c"], "y": [1, 2, 3, 4], "z": [1, 2]},
)
def test_groupby_dims_property(dataset):
assert dataset.groupby("x").dims == dataset.isel(x=1).dims
assert dataset.groupby("y").dims == dataset.isel(y=1).dims

assert ds.groupby("x").dims == ds.isel(x=1).dims
assert ds.groupby("y").dims == ds.isel(y=1).dims

stacked = ds.stack({"xy": ("x", "y")})
stacked = dataset.stack({"xy": ("x", "y")})
assert stacked.groupby("xy").dims == stacked.isel(xy=0).dims


def test_multi_index_groupby_apply():
def test_multi_index_groupby_apply(dataset):
# regression test for GH873
ds = xr.Dataset(
{"foo": (("x", "y"), np.random.randn(3, 4))},
{"x": ["a", "b", "c"], "y": [1, 2, 3, 4]},
)
ds = dataset.isel(z=1, drop=True)[["foo"]]
doubled = 2 * ds
group_doubled = (
ds.stack(space=["x", "y"])
Expand Down Expand Up @@ -276,6 +284,24 @@ def test_groupby_grouping_errors():
dataset.to_array().groupby(dataset.foo * np.nan)


def test_groupby_reduce_dimension_error(array):
grouped = array.groupby("y")
with raises_regex(ValueError, "cannot reduce over dimensions"):
grouped.mean()

with raises_regex(ValueError, "cannot reduce over dimensions"):
grouped.mean("huh")

with raises_regex(ValueError, "cannot reduce over dimensions"):
grouped.mean(("x", "y", "asd"))

grouped = array.groupby("y", squeeze=False)
assert_identical(array, grouped.mean())

assert_identical(array.mean("x"), grouped.reduce(np.mean, "x"))
assert_allclose(array.mean(["x", "z"]), grouped.reduce(np.mean, ["x", "z"]))


def test_groupby_bins_timeseries():
ds = xr.Dataset()
ds["time"] = xr.DataArray(
Expand Down

0 comments on commit fb0cf7b

Please sign in to comment.