diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4b03e8d676e..b0f6a07841b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -35,6 +35,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- :py:func:`xarray.concat` can now concatenate variables present in some datasets but + not others (:issue:`508`, :pull:`7400`). + By `Kai Mühlbauer `_ and `Scott Chamberlin `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 2eea2ecb3ee..95a2199f9c5 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -5,7 +5,7 @@ import pandas as pd from xarray.core import dtypes, utils -from xarray.core.alignment import align +from xarray.core.alignment import align, reindex_variables from xarray.core.duck_array_ops import lazy_array_equiv from xarray.core.indexes import Index, PandasIndex from xarray.core.merge import ( @@ -378,7 +378,9 @@ def process_subset_opt(opt, subset): elif opt == "all": concat_over.update( - set(getattr(datasets[0], subset)) - set(datasets[0].dims) + set().union( + *list(set(getattr(d, subset)) - set(d.dims) for d in datasets) + ) ) elif opt == "minimal": pass @@ -406,19 +408,26 @@ def process_subset_opt(opt, subset): # determine dimensional coordinate names and a dict mapping name to DataArray def _parse_datasets( - datasets: Iterable[T_Dataset], -) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]: - + datasets: list[T_Dataset], +) -> tuple[ + dict[Hashable, Variable], + dict[Hashable, int], + set[Hashable], + set[Hashable], + list[Hashable], +]: dims: set[Hashable] = set() all_coord_names: set[Hashable] = set() data_vars: set[Hashable] = set() # list of data_vars dim_coords: dict[Hashable, Variable] = {} # maps dim name to variable dims_sizes: dict[Hashable, int] = {} # shared dimension sizes to expand variables + variables_order: dict[Hashable, Variable] = {} # variables in order of appearance for ds in datasets: dims_sizes.update(ds.dims) all_coord_names.update(ds.coords) data_vars.update(ds.data_vars) + variables_order.update(ds.variables) # preserves ordering of dimensions for dim in ds.dims: @@ -429,7 +438,7 @@ def _parse_datasets( dim_coords[dim] = ds.coords[dim].variable dims = dims | set(ds.dims) - return dim_coords, dims_sizes, all_coord_names, data_vars + return dim_coords, dims_sizes, all_coord_names, data_vars, list(variables_order) def _dataset_concat( @@ -439,7 +448,7 @@ def _dataset_concat( coords: str | list[str], compat: CompatOptions, positions: Iterable[Iterable[int]] | None, - fill_value: object = dtypes.NA, + fill_value: Any = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", ) -> T_Dataset: @@ -471,7 +480,9 @@ def _dataset_concat( align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value) ) - dim_coords, dims_sizes, coord_names, data_names = _parse_datasets(datasets) + dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets( + datasets + ) dim_names = set(dim_coords) unlabeled_dims = dim_names - coord_names @@ -525,7 +536,7 @@ def _dataset_concat( # we've already verified everything is consistent; now, calculate # shared dimension sizes so we can expand the necessary variables - def ensure_common_dims(vars): + def ensure_common_dims(vars, concat_dim_lengths): # ensure each variable with the given name shares the same # dimensions and the same shape for all of them except along the # concat dimension @@ -553,16 +564,35 @@ def get_indexes(name): data = var.set_dims(dim).values yield PandasIndex(data, dim, coord_dtype=var.dtype) + # create concatenation index, needed for later reindexing + concat_index = list(range(sum(concat_dim_lengths))) + # stack up each variable and/or index to fill-out the dataset (in order) # n.b. this loop preserves variable order, needed for groupby. - for name in datasets[0].variables: + for name in vars_order: if name in concat_over and name not in result_indexes: - try: - vars = ensure_common_dims([ds[name].variable for ds in datasets]) - except KeyError: - raise ValueError(f"{name!r} is not present in all datasets.") - - # Try concatenate the indexes, concatenate the variables when no index + variables = [] + variable_index = [] + var_concat_dim_length = [] + for i, ds in enumerate(datasets): + if name in ds.variables: + variables.append(ds[name].variable) + # add to variable index, needed for reindexing + var_idx = [ + sum(concat_dim_lengths[:i]) + k + for k in range(concat_dim_lengths[i]) + ] + variable_index.extend(var_idx) + var_concat_dim_length.append(len(var_idx)) + else: + # raise if coordinate not in all datasets + if name in coord_names: + raise ValueError( + f"coordinate {name!r} not present in all datasets." + ) + vars = ensure_common_dims(variables, var_concat_dim_length) + + # Try to concatenate the indexes, concatenate the variables when no index # is found on all datasets. indexes: list[Index] = list(get_indexes(name)) if indexes: @@ -589,6 +619,15 @@ def get_indexes(name): combined_var = concat_vars( vars, dim, positions, combine_attrs=combine_attrs ) + # reindex if variable is not present in all datasets + if len(variable_index) < len(concat_index): + combined_var = reindex_variables( + variables={name: combined_var}, + dim_pos_indexers={ + dim: pd.Index(variable_index).get_indexer(concat_index) + }, + fill_value=fill_value, + )[name] result_vars[name] = combined_var elif name in result_vars: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index e0e0038cd89..eb3cd7fccc8 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable import numpy as np import pandas as pd @@ -23,6 +23,89 @@ from xarray.core.types import CombineAttrsOptions, JoinOptions +# helper method to create multiple tests datasets to concat +def create_concat_datasets( + num_datasets: int = 2, seed: int | None = None, include_day: bool = True +) -> list[Dataset]: + rng = np.random.default_rng(seed) + lat = rng.standard_normal(size=(1, 4)) + lon = rng.standard_normal(size=(1, 4)) + result = [] + variables = ["temperature", "pressure", "humidity", "precipitation", "cloud_cover"] + for i in range(num_datasets): + if include_day: + data_tuple = ( + ["x", "y", "day"], + rng.standard_normal(size=(1, 4, 2)), + ) + data_vars = {v: data_tuple for v in variables} + result.append( + Dataset( + data_vars=data_vars, + coords={ + "lat": (["x", "y"], lat), + "lon": (["x", "y"], lon), + "day": ["day" + str(i * 2 + 1), "day" + str(i * 2 + 2)], + }, + ) + ) + else: + data_tuple = ( + ["x", "y"], + rng.standard_normal(size=(1, 4)), + ) + data_vars = {v: data_tuple for v in variables} + result.append( + Dataset( + data_vars=data_vars, + coords={"lat": (["x", "y"], lat), "lon": (["x", "y"], lon)}, + ) + ) + + return result + + +# helper method to create multiple tests datasets to concat with specific types +def create_typed_datasets( + num_datasets: int = 2, seed: int | None = None +) -> list[Dataset]: + var_strings = ["a", "b", "c", "d", "e", "f", "g", "h"] + result = [] + rng = np.random.default_rng(seed) + lat = rng.standard_normal(size=(1, 4)) + lon = rng.standard_normal(size=(1, 4)) + for i in range(num_datasets): + result.append( + Dataset( + data_vars={ + "float": (["x", "y", "day"], rng.standard_normal(size=(1, 4, 2))), + "float2": (["x", "y", "day"], rng.standard_normal(size=(1, 4, 2))), + "string": ( + ["x", "y", "day"], + rng.choice(var_strings, size=(1, 4, 2)), + ), + "int": (["x", "y", "day"], rng.integers(0, 10, size=(1, 4, 2))), + "datetime64": ( + ["x", "y", "day"], + np.arange( + np.datetime64("2017-01-01"), np.datetime64("2017-01-09") + ).reshape(1, 4, 2), + ), + "timedelta64": ( + ["x", "y", "day"], + np.reshape([pd.Timedelta(days=i) for i in range(8)], [1, 4, 2]), + ), + }, + coords={ + "lat": (["x", "y"], lat), + "lon": (["x", "y"], lon), + "day": ["day" + str(i * 2 + 1), "day" + str(i * 2 + 2)], + }, + ) + ) + return result + + def test_concat_compat() -> None: ds1 = Dataset( { @@ -46,14 +129,324 @@ def test_concat_compat() -> None: for var in ["has_x", "no_x_y"]: assert "y" not in result[var].dims and "y" not in result[var].coords - with pytest.raises( - ValueError, match=r"coordinates in some datasets but not others" - ): + with pytest.raises(ValueError, match=r"'q' not present in all datasets"): concat([ds1, ds2], dim="q") - with pytest.raises(ValueError, match=r"'q' is not present in all datasets"): + with pytest.raises(ValueError, match=r"'q' not present in all datasets"): concat([ds2, ds1], dim="q") +def test_concat_missing_var() -> None: + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["humidity", "precipitation", "cloud_cover"] + + expected = expected.drop_vars(vars_to_drop) + expected["pressure"][..., 2:] = np.nan + + datasets[0] = datasets[0].drop_vars(vars_to_drop) + datasets[1] = datasets[1].drop_vars(vars_to_drop + ["pressure"]) + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == ["temperature", "pressure"] + assert_identical(actual, expected) + + +def test_concat_missing_multiple_consecutive_var() -> None: + datasets = create_concat_datasets(3, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["humidity", "pressure"] + + expected["pressure"][..., :4] = np.nan + expected["humidity"][..., :4] = np.nan + + datasets[0] = datasets[0].drop_vars(vars_to_drop) + datasets[1] = datasets[1].drop_vars(vars_to_drop) + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "precipitation", + "cloud_cover", + "pressure", + "humidity", + ] + assert_identical(actual, expected) + + +def test_concat_all_empty() -> None: + ds1 = Dataset() + ds2 = Dataset() + expected = Dataset() + actual = concat([ds1, ds2], dim="new_dim") + + assert_identical(actual, expected) + + +def test_concat_second_empty() -> None: + ds1 = Dataset(data_vars={"a": ("y", [0.1])}, coords={"x": 0.1}) + ds2 = Dataset(coords={"x": 0.1}) + + expected = Dataset(data_vars={"a": ("y", [0.1, np.nan])}, coords={"x": 0.1}) + actual = concat([ds1, ds2], dim="y") + assert_identical(actual, expected) + + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan])}, coords={"x": ("y", [0.1, 0.1])} + ) + actual = concat([ds1, ds2], dim="y", coords="all") + assert_identical(actual, expected) + + # Check concatenating scalar data_var only present in ds1 + ds1["b"] = 0.1 + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan]), "b": ("y", [0.1, np.nan])}, + coords={"x": ("y", [0.1, 0.1])}, + ) + actual = concat([ds1, ds2], dim="y", coords="all", data_vars="all") + assert_identical(actual, expected) + + expected = Dataset( + data_vars={"a": ("y", [0.1, np.nan]), "b": 0.1}, coords={"x": 0.1} + ) + actual = concat([ds1, ds2], dim="y", coords="different", data_vars="different") + assert_identical(actual, expected) + + +def test_concat_multiple_missing_variables() -> None: + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + vars_to_drop = ["pressure", "cloud_cover"] + + expected["pressure"][..., 2:] = np.nan + expected["cloud_cover"][..., 2:] = np.nan + + datasets[1] = datasets[1].drop_vars(vars_to_drop) + actual = concat(datasets, dim="day") + + # check the variables orders are the same + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("include_day", [True, False]) +def test_concat_multiple_datasets_missing_vars(include_day: bool) -> None: + vars_to_drop = [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + + datasets = create_concat_datasets( + len(vars_to_drop), seed=123, include_day=include_day + ) + expected = concat(datasets, dim="day") + + for i, name in enumerate(vars_to_drop): + if include_day: + expected[name][..., i * 2 : (i + 1) * 2] = np.nan + else: + expected[name][i : i + 1, ...] = np.nan + + # set up the test data + datasets = [ds.drop_vars(varname) for ds, varname in zip(datasets, vars_to_drop)] + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "pressure", + "humidity", + "precipitation", + "cloud_cover", + "temperature", + ] + assert_identical(actual, expected) + + +def test_concat_multiple_datasets_with_multiple_missing_variables() -> None: + vars_to_drop_in_first = ["temperature", "pressure"] + vars_to_drop_in_second = ["humidity", "precipitation", "cloud_cover"] + datasets = create_concat_datasets(2, seed=123) + expected = concat(datasets, dim="day") + for name in vars_to_drop_in_first: + expected[name][..., :2] = np.nan + for name in vars_to_drop_in_second: + expected[name][..., 2:] = np.nan + + # set up the test data + datasets[0] = datasets[0].drop_vars(vars_to_drop_in_first) + datasets[1] = datasets[1].drop_vars(vars_to_drop_in_second) + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "humidity", + "precipitation", + "cloud_cover", + "temperature", + "pressure", + ] + assert_identical(actual, expected) + + +def test_concat_type_of_missing_fill() -> None: + datasets = create_typed_datasets(2, seed=123) + expected1 = concat(datasets, dim="day", fill_value=dtypes.NA) + expected2 = concat(datasets[::-1], dim="day", fill_value=dtypes.NA) + vars = ["float", "float2", "string", "int", "datetime64", "timedelta64"] + expected = [expected2, expected1] + for i, exp in enumerate(expected): + sl = slice(i * 2, (i + 1) * 2) + exp["float2"][..., sl] = np.nan + exp["datetime64"][..., sl] = np.nan + exp["timedelta64"][..., sl] = np.nan + var = exp["int"] * 1.0 + var[..., sl] = np.nan + exp["int"] = var + var = exp["string"].astype(object) + var[..., sl] = np.nan + exp["string"] = var + + # set up the test data + datasets[1] = datasets[1].drop_vars(vars[1:]) + + actual = concat(datasets, dim="day", fill_value=dtypes.NA) + + assert_identical(actual, expected[1]) + + # reversed + actual = concat(datasets[::-1], dim="day", fill_value=dtypes.NA) + + assert_identical(actual, expected[0]) + + +def test_concat_order_when_filling_missing() -> None: + vars_to_drop_in_first: list[str] = [] + # drop middle + vars_to_drop_in_second = ["humidity"] + datasets = create_concat_datasets(2, seed=123) + expected1 = concat(datasets, dim="day") + for name in vars_to_drop_in_second: + expected1[name][..., 2:] = np.nan + expected2 = concat(datasets[::-1], dim="day") + for name in vars_to_drop_in_second: + expected2[name][..., :2] = np.nan + + # set up the test data + datasets[0] = datasets[0].drop_vars(vars_to_drop_in_first) + datasets[1] = datasets[1].drop_vars(vars_to_drop_in_second) + + actual = concat(datasets, dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "humidity", + "precipitation", + "cloud_cover", + ] + assert_identical(actual, expected1) + + actual = concat(datasets[::-1], dim="day") + + assert list(actual.data_vars.keys()) == [ + "temperature", + "pressure", + "precipitation", + "cloud_cover", + "humidity", + ] + assert_identical(actual, expected2) + + +@pytest.fixture +def concat_var_names() -> Callable: + # create var names list with one missing value + def get_varnames(var_cnt: int = 10, list_cnt: int = 10) -> list[list[str]]: + orig = [f"d{i:02d}" for i in range(var_cnt)] + var_names = [] + for i in range(0, list_cnt): + l1 = orig.copy() + var_names.append(l1) + return var_names + + return get_varnames + + +@pytest.fixture +def create_concat_ds() -> Callable: + def create_ds( + var_names: list[list[str]], + dim: bool = False, + coord: bool = False, + drop_idx: list[int] | None = None, + ) -> list[Dataset]: + out_ds = [] + ds = Dataset() + ds = ds.assign_coords({"x": np.arange(2)}) + ds = ds.assign_coords({"y": np.arange(3)}) + ds = ds.assign_coords({"z": np.arange(4)}) + for i, dsl in enumerate(var_names): + vlist = dsl.copy() + if drop_idx is not None: + vlist.pop(drop_idx[i]) + foo_data = np.arange(48, dtype=float).reshape(2, 2, 3, 4) + dsi = ds.copy() + if coord: + dsi = ds.assign({"time": (["time"], [i * 2, i * 2 + 1])}) + for k in vlist: + dsi = dsi.assign({k: (["time", "x", "y", "z"], foo_data.copy())}) + if not dim: + dsi = dsi.isel(time=0) + out_ds.append(dsi) + return out_ds + + return create_ds + + +@pytest.mark.parametrize("dim", [True, False]) +@pytest.mark.parametrize("coord", [True, False]) +def test_concat_fill_missing_variables( + concat_var_names, create_concat_ds, dim: bool, coord: bool +) -> None: + var_names = concat_var_names() + drop_idx = [0, 7, 6, 4, 4, 8, 0, 6, 2, 0] + + expected = concat( + create_concat_ds(var_names, dim=dim, coord=coord), dim="time", data_vars="all" + ) + for i, idx in enumerate(drop_idx): + if dim: + expected[var_names[0][idx]][i * 2 : i * 2 + 2] = np.nan + else: + expected[var_names[0][idx]][i] = np.nan + + concat_ds = create_concat_ds(var_names, dim=dim, coord=coord, drop_idx=drop_idx) + actual = concat(concat_ds, dim="time", data_vars="all") + + assert list(actual.data_vars.keys()) == [ + "d01", + "d02", + "d03", + "d04", + "d05", + "d06", + "d07", + "d08", + "d09", + "d00", + ] + assert_identical(actual, expected) + + class TestConcatDataset: @pytest.fixture def data(self) -> Dataset: @@ -86,10 +479,17 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: split_data = [data.isel(dim1=slice(3)), data.isel(dim1=slice(3, None))] data0, data1 = deepcopy(split_data) data1["foo"] = ("bar", np.random.randn(10)) - actual = concat([data0, data1], "dim1") + actual = concat([data0, data1], "dim1", data_vars="minimal") expected = data.copy().assign(foo=data1.foo) assert_identical(expected, actual) + # expand foo + actual = concat([data0, data1], "dim1") + foo = np.ones((8, 10), dtype=data1.foo.dtype) * np.nan + foo[3:] = data1.foo.values[None, ...] + expected = data.copy().assign(foo=(["dim1", "bar"], foo)) + assert_identical(expected, actual) + def test_concat_2(self, data) -> None: dim = "dim2" datasets = [g for _, g in data.groupby(dim, squeeze=True)] @@ -776,7 +1176,7 @@ def test_concat_merge_single_non_dim_coord(): actual = concat([da1, da2], "x", coords=coords) assert_identical(actual, expected) - with pytest.raises(ValueError, match=r"'y' is not present in all datasets."): + with pytest.raises(ValueError, match=r"'y' not present in all datasets."): concat([da1, da2], dim="x", coords="all") da1 = DataArray([1, 2, 3], dims="x", coords={"x": [1, 2, 3], "y": 1}) @@ -784,7 +1184,7 @@ def test_concat_merge_single_non_dim_coord(): da3 = DataArray([7, 8, 9], dims="x", coords={"x": [7, 8, 9], "y": 1}) for coords in ["different", "all"]: with pytest.raises(ValueError, match=r"'y' not present in all datasets"): - concat([da1, da2, da3], dim="x") + concat([da1, da2, da3], dim="x", coords=coords) def test_concat_preserve_coordinate_order() -> None: