Skip to content

Commit

Permalink
flox: Properly propagate multiindex (#9649)
Browse files Browse the repository at this point in the history
* flox: Properly propagate multiindex

Closes #9648

* skip test on old pandas

* small optimization

* fix
  • Loading branch information
dcherian authored Oct 21, 2024
1 parent cfaa72f commit 01831a4
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 37 deletions.
2 changes: 1 addition & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Bug fixes
the non-missing times could in theory be encoded with integers
(:issue:`9488`, :pull:`9497`). By `Spencer Clark
<https://github.com/spencerkclark>`_.
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`).
- Fix a few bugs affecting groupby reductions with `flox`. (:issue:`8090`, :issue:`9398`, :issue:`9648`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Fix the safe_chunks validation option on the to_zarr method
(:issue:`5511`, :pull:`9559`). By `Joseph Nowak
Expand Down
11 changes: 11 additions & 0 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,3 +1116,14 @@ def create_coords_with_default_indexes(
new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes)

return new_coords


def _coordinates_from_variable(variable: Variable) -> Coordinates:
from xarray.core.indexes import create_default_index_implicit

(name,) = variable.dims
new_index, index_vars = create_default_index_implicit(variable)
indexes = {k: new_index for k in index_vars}
new_vars = new_index.create_variables()
new_vars[name].attrs = variable.attrs
return Coordinates(new_vars, indexes)
41 changes: 18 additions & 23 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.concat import concat
from xarray.core.coordinates import Coordinates
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.formatting import format_array_flat
from xarray.core.indexes import (
PandasIndex,
PandasMultiIndex,
filter_indexes_from_coords,
)
from xarray.core.merge import merge_coords
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import (
Dims,
Expand Down Expand Up @@ -851,7 +851,6 @@ def _flox_reduce(
from flox.xarray import xarray_reduce

from xarray.core.dataset import Dataset
from xarray.groupers import BinGrouper

obj = self._original_obj
variables = (
Expand Down Expand Up @@ -901,13 +900,6 @@ def _flox_reduce(
# set explicitly to avoid unnecessarily accumulating count
kwargs["min_count"] = 0

unindexed_dims: tuple[Hashable, ...] = tuple(
grouper.name
for grouper in self.groupers
if isinstance(grouper.group, _DummyGroup)
and not isinstance(grouper.grouper, BinGrouper)
)

parsed_dim: tuple[Hashable, ...]
if isinstance(dim, str):
parsed_dim = (dim,)
Expand Down Expand Up @@ -963,26 +955,29 @@ def _flox_reduce(
# we did end up reducing over dimension(s) that are
# in the grouped variable
group_dims = set(grouper.group.dims)
new_coords = {}
new_coords = []
to_drop = []
if group_dims.issubset(set(parsed_dim)):
new_indexes = {}
for grouper in self.groupers:
output_index = grouper.full_index
if isinstance(output_index, pd.RangeIndex):
# flox always assigns an index so we must drop it here if we don't need it.
to_drop.append(grouper.name)
continue
name = grouper.name
new_coords[name] = IndexVariable(
dims=name, data=np.array(output_index), attrs=grouper.codes.attrs
)
index_cls = (
PandasIndex
if not isinstance(output_index, pd.MultiIndex)
else PandasMultiIndex
new_coords.append(
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
# all associated levels properly.
_coordinates_from_variable(
IndexVariable(
dims=grouper.name,
data=output_index,
attrs=grouper.codes.attrs,
)
)
)
new_indexes[name] = index_cls(output_index, dim=name)
result = result.assign_coords(
Coordinates(new_coords, new_indexes)
).drop_vars(unindexed_dims)
Coordinates._construct_direct(*merge_coords(new_coords))
).drop_vars(to_drop)

# broadcast any non-dim coord variables that don't
# share all dimensions with the grouper
Expand Down
13 changes: 1 addition & 12 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.core import duck_array_ops
from xarray.core.coordinates import Coordinates
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
from xarray.core.dataarray import DataArray
from xarray.core.groupby import T_Group, _DummyGroup
from xarray.core.indexes import safe_cast_to_index
Expand All @@ -42,17 +42,6 @@
RESAMPLE_DIM = "__resample_dim__"


def _coordinates_from_variable(variable: Variable) -> Coordinates:
from xarray.core.indexes import create_default_index_implicit

(name,) = variable.dims
new_index, index_vars = create_default_index_implicit(variable)
indexes = {k: new_index for k in index_vars}
new_vars = new_index.create_variables()
new_vars[name].attrs = variable.attrs
return Coordinates(new_vars, indexes)


@dataclass(init=False)
class EncodedGroups:
"""
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _importorskip(
has_pint, requires_pint = _importorskip("pint")
has_numexpr, requires_numexpr = _importorskip("numexpr")
has_flox, requires_flox = _importorskip("flox")
has_pandas_ge_2_2, __ = _importorskip("pandas", "2.2")
has_pandas_ge_2_2, requires_pandas_ge_2_2 = _importorskip("pandas", "2.2")
has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0")


Expand Down
21 changes: 21 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
requires_dask,
requires_flox,
requires_flox_0_9_12,
requires_pandas_ge_2_2,
requires_scipy,
)

Expand Down Expand Up @@ -145,6 +146,26 @@ def test_multi_index_groupby_sum() -> None:
assert_equal(expected, actual)


@requires_pandas_ge_2_2
def test_multi_index_propagation():
# regression test for GH9648
times = pd.date_range("2023-01-01", periods=4)
locations = ["A", "B"]
data = [[0.5, 0.7], [0.6, 0.5], [0.4, 0.6], [0.4, 0.9]]

da = xr.DataArray(
data, dims=["time", "location"], coords={"time": times, "location": locations}
)
da = da.stack(multiindex=["time", "location"])
grouped = da.groupby("multiindex")

with xr.set_options(use_flox=True):
actual = grouped.sum()
with xr.set_options(use_flox=False):
expected = grouped.first()
assert_identical(actual, expected)


def test_groupby_da_datetime() -> None:
# test groupby with a DataArray of dtype datetime for GH1132
# create test data
Expand Down

0 comments on commit 01831a4

Please sign in to comment.