From e8a3c358bb9e607f7de98f7e7b47a1dda0be3c21 Mon Sep 17 00:00:00 2001
From: Stephan Hoyer <shoyer@google.com>
Date: Sat, 28 Sep 2024 15:12:30 -0700
Subject: [PATCH] Stop inheriting non-indexed coordinates for DataTree

This is option (4) from https://github.com/pydata/xarray/issues/9475#issuecomment-2357004264
---
 xarray/core/datatree.py       |  8 +++++++-
 xarray/tests/test_datatree.py | 36 ++++++++++++++++++-----------------
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py
index 52d44bec96f..6d9a5a6d8d0 100644
--- a/xarray/core/datatree.py
+++ b/xarray/core/datatree.py
@@ -438,6 +438,7 @@ class DataTree(
     _cache: dict[str, Any]  # used by _CachedAccessor
     _data_variables: dict[Hashable, Variable]
     _node_coord_variables: dict[Hashable, Variable]
+    _node_indexed_coord_variables: dict[Hashable, Variable]
     _node_dims: dict[Hashable, int]
     _node_indexes: dict[Hashable, Index]
     _attrs: dict[Hashable, Any] | None
@@ -451,6 +452,7 @@ class DataTree(
         "_cache",  # used by _CachedAccessor
         "_data_variables",
         "_node_coord_variables",
+        "_node_indexed_coord_variables",
         "_node_dims",
         "_node_indexes",
         "_attrs",
@@ -498,6 +500,9 @@ def _set_node_data(self, dataset: Dataset):
         data_vars, coord_vars = _collect_data_and_coord_variables(dataset)
         self._data_variables = data_vars
         self._node_coord_variables = coord_vars
+        self._node_indexed_coord_variables = {
+            k: coord_vars[k] for k in dataset._indexes
+        }
         self._node_dims = dataset._dims
         self._node_indexes = dataset._indexes
         self._encoding = dataset._encoding
@@ -519,7 +524,8 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
     @property
     def _coord_variables(self) -> ChainMap[Hashable, Variable]:
         return ChainMap(
-            self._node_coord_variables, *(p._node_coord_variables for p in self.parents)
+            self._node_coord_variables,
+            *(p._node_indexed_coord_variables for p in self.parents),
         )
 
     @property
diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py
index 30934f83c63..3365e493090 100644
--- a/xarray/tests/test_datatree.py
+++ b/xarray/tests/test_datatree.py
@@ -216,16 +216,16 @@ def test_is_hollow(self):
 
 
 class TestToDataset:
-    def test_to_dataset(self):
-        base = xr.Dataset(coords={"a": 1})
-        sub = xr.Dataset(coords={"b": 2})
+    def test_to_dataset_inherited(self):
+        base = xr.Dataset(coords={"a": [1], "b": 2})
+        sub = xr.Dataset(coords={"c": [3]})
         tree = DataTree.from_dict({"/": base, "/sub": sub})
         subtree = typing.cast(DataTree, tree["sub"])
 
         assert_identical(tree.to_dataset(inherited=False), base)
         assert_identical(subtree.to_dataset(inherited=False), sub)
 
-        sub_and_base = xr.Dataset(coords={"a": 1, "b": 2})
+        sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]})  # no "b"
         assert_identical(tree.to_dataset(inherited=True), base)
         assert_identical(subtree.to_dataset(inherited=True), sub_and_base)
 
@@ -714,7 +714,8 @@ def test_inherited(self):
         dt["child"] = DataTree()
         child = dt["child"]
 
-        assert set(child.coords) == {"x", "y", "a", "b"}
+        assert set(dt.coords) == {"x", "y", "a", "b"}
+        assert set(child.coords) == {"x", "y"}
 
         actual = child.copy(deep=True)
         actual.coords["x"] = ("x", ["a", "b"])
@@ -729,7 +730,7 @@ def test_inherited(self):
 
         with pytest.raises(KeyError):
             # cannot delete inherited coordinate from child node
-            del child["b"]
+            del child["x"]
 
         # TODO requires a fix for #9472
         # actual = child.copy(deep=True)
@@ -1278,22 +1279,23 @@ def test_inherited_coords_index(self):
         assert "x" in dt["/b"].coords
         xr.testing.assert_identical(dt["/x"], dt["/b/x"])
 
-    def test_inherited_coords_override(self):
+    def test_inherit_only_index_coords(self):
         dt = DataTree.from_dict(
             {
-                "/": xr.Dataset(coords={"x": 1, "y": 2}),
-                "/b": xr.Dataset(coords={"x": 4, "z": 3}),
+                "/": xr.Dataset(coords={"x": [1], "y": 2}),
+                "/b": xr.Dataset(coords={"z": 3}),
             }
         )
         assert dt.coords.keys() == {"x", "y"}
-        root_coords = {"x": 1, "y": 2}
-        sub_coords = {"x": 4, "y": 2, "z": 3}
-        xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords))
-        xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords))
-        assert dt["/b"].coords.keys() == {"x", "y", "z"}
-        xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords))
-        xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords))
-        xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))
+        xr.testing.assert_equal(
+            dt["/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "y": 2})
+        )
+        xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords={"y": 2}))
+        assert dt["/b"].coords.keys() == {"x", "z"}
+        xr.testing.assert_equal(
+            dt["/b/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "z": 3})
+        )
+        xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords={"z": 3}))
 
     def test_inherited_coords_with_index_are_deduplicated(self):
         dt = DataTree.from_dict(