Skip to content

Commit

Permalink
Fill missing data_vars during concat by reindexing (#7400)
Browse files Browse the repository at this point in the history
* Fill missing data variables during concat by reindexing

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* FIX: use `Any` for type of `fill_value` as this seems consistent with other places

* ENH: add tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typing

Co-authored-by: Illviljan <[email protected]>

* typing

Co-authored-by: Illviljan <[email protected]>

* typing

Co-authored-by: Illviljan <[email protected]>

* use None instead of False

Co-authored-by: Illviljan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* always concatenate a variable if it has the concat dimension

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add tests from @scottcha #3545

* typing

* fix typing

* fix tests, finalize typing

* add whats-new.rst entry

* Update xarray/tests/test_concat.py

Co-authored-by: Illviljan <[email protected]>

* Update xarray/tests/test_concat.py

Co-authored-by: Illviljan <[email protected]>

* add TODO, fix numpy.random.default_rng

* change np.random to use Generator

* move code for variable order into dedicated function, merge with _parse_datasets, provide fast lane for variable order estimation

* fix comment

* Use order from first dataset, append missing variables to the end

* ensure fill_value is dict

* ensure fill_value in align

* simplify combined_var, fix test

* revert fill_value for alignment.py

* derive variable order in order of appearance as suggested per review

* remove unneeded enumerate

* Use alignment.reindex_variables instead.

This also removes the need to handle fill_value

* small cleanup

* Update doc/whats-new.rst

Co-authored-by: Deepak Cherian <[email protected]>

* adapt tests as per review request, fix ensure_common_dims

* adapt tests as per review request

* fix whats-new.rst

* add whats-new.rst entry

* Add additional test with scalar data_var

* remove erroneous content from whats-new.rst

Co-authored-by: Scott Chamberlin <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Illviljan <[email protected]>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
5 people authored Jan 20, 2023
1 parent b21f62e commit b4e3cbc
Show file tree
Hide file tree
Showing 3 changed files with 466 additions and 24 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Deprecations
Bug fixes
~~~~~~~~~

- :py:func:`xarray.concat` can now concatenate variables present in some datasets but
not others (:issue:`508`, :pull:`7400`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_ and `Scott Chamberlin <https://github.com/scottcha>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
71 changes: 55 additions & 16 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd

from xarray.core import dtypes, utils
from xarray.core.alignment import align
from xarray.core.alignment import align, reindex_variables
from xarray.core.duck_array_ops import lazy_array_equiv
from xarray.core.indexes import Index, PandasIndex
from xarray.core.merge import (
Expand Down Expand Up @@ -378,7 +378,9 @@ def process_subset_opt(opt, subset):

elif opt == "all":
concat_over.update(
set(getattr(datasets[0], subset)) - set(datasets[0].dims)
set().union(
*list(set(getattr(d, subset)) - set(d.dims) for d in datasets)
)
)
elif opt == "minimal":
pass
Expand Down Expand Up @@ -406,19 +408,26 @@ def process_subset_opt(opt, subset):

# determine dimensional coordinate names and a dict mapping name to DataArray
def _parse_datasets(
datasets: Iterable[T_Dataset],
) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:

datasets: list[T_Dataset],
) -> tuple[
dict[Hashable, Variable],
dict[Hashable, int],
set[Hashable],
set[Hashable],
list[Hashable],
]:
dims: set[Hashable] = set()
all_coord_names: set[Hashable] = set()
data_vars: set[Hashable] = set() # list of data_vars
dim_coords: dict[Hashable, Variable] = {} # maps dim name to variable
dims_sizes: dict[Hashable, int] = {} # shared dimension sizes to expand variables
variables_order: dict[Hashable, Variable] = {} # variables in order of appearance

for ds in datasets:
dims_sizes.update(ds.dims)
all_coord_names.update(ds.coords)
data_vars.update(ds.data_vars)
variables_order.update(ds.variables)

# preserves ordering of dimensions
for dim in ds.dims:
Expand All @@ -429,7 +438,7 @@ def _parse_datasets(
dim_coords[dim] = ds.coords[dim].variable
dims = dims | set(ds.dims)

return dim_coords, dims_sizes, all_coord_names, data_vars
return dim_coords, dims_sizes, all_coord_names, data_vars, list(variables_order)


def _dataset_concat(
Expand All @@ -439,7 +448,7 @@ def _dataset_concat(
coords: str | list[str],
compat: CompatOptions,
positions: Iterable[Iterable[int]] | None,
fill_value: object = dtypes.NA,
fill_value: Any = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
) -> T_Dataset:
Expand Down Expand Up @@ -471,7 +480,9 @@ def _dataset_concat(
align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value)
)

dim_coords, dims_sizes, coord_names, data_names = _parse_datasets(datasets)
dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets(
datasets
)
dim_names = set(dim_coords)
unlabeled_dims = dim_names - coord_names

Expand Down Expand Up @@ -525,7 +536,7 @@ def _dataset_concat(

# we've already verified everything is consistent; now, calculate
# shared dimension sizes so we can expand the necessary variables
def ensure_common_dims(vars):
def ensure_common_dims(vars, concat_dim_lengths):
# ensure each variable with the given name shares the same
# dimensions and the same shape for all of them except along the
# concat dimension
Expand Down Expand Up @@ -553,16 +564,35 @@ def get_indexes(name):
data = var.set_dims(dim).values
yield PandasIndex(data, dim, coord_dtype=var.dtype)

# create concatenation index, needed for later reindexing
concat_index = list(range(sum(concat_dim_lengths)))

# stack up each variable and/or index to fill-out the dataset (in order)
# n.b. this loop preserves variable order, needed for groupby.
for name in datasets[0].variables:
for name in vars_order:
if name in concat_over and name not in result_indexes:
try:
vars = ensure_common_dims([ds[name].variable for ds in datasets])
except KeyError:
raise ValueError(f"{name!r} is not present in all datasets.")

# Try concatenate the indexes, concatenate the variables when no index
variables = []
variable_index = []
var_concat_dim_length = []
for i, ds in enumerate(datasets):
if name in ds.variables:
variables.append(ds[name].variable)
# add to variable index, needed for reindexing
var_idx = [
sum(concat_dim_lengths[:i]) + k
for k in range(concat_dim_lengths[i])
]
variable_index.extend(var_idx)
var_concat_dim_length.append(len(var_idx))
else:
# raise if coordinate not in all datasets
if name in coord_names:
raise ValueError(
f"coordinate {name!r} not present in all datasets."
)
vars = ensure_common_dims(variables, var_concat_dim_length)

# Try to concatenate the indexes, concatenate the variables when no index
# is found on all datasets.
indexes: list[Index] = list(get_indexes(name))
if indexes:
Expand All @@ -589,6 +619,15 @@ def get_indexes(name):
combined_var = concat_vars(
vars, dim, positions, combine_attrs=combine_attrs
)
# reindex if variable is not present in all datasets
if len(variable_index) < len(concat_index):
combined_var = reindex_variables(
variables={name: combined_var},
dim_pos_indexers={
dim: pd.Index(variable_index).get_indexer(concat_index)
},
fill_value=fill_value,
)[name]
result_vars[name] = combined_var

elif name in result_vars:
Expand Down
Loading

0 comments on commit b4e3cbc

Please sign in to comment.