Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support alternative names for the root node in DataTree.from_dict #9638

Merged
merged 2 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,10 +1104,12 @@ def from_dict(
d : dict-like
A mapping from path names to xarray.Dataset or DataTree objects.

Path names are to be given as unix-like path. If path names containing more than one
part are given, new tree nodes will be constructed as necessary.
Path names are to be given as unix-like path. If path names
containing more than one part are given, new tree nodes will be
constructed as necessary.

To assign data to the root node of the tree use "/" as the path.
To assign data to the root node of the tree use "", ".", "/" or "./"
as the path.
name : Hashable | None, optional
Name for the root node of the tree. Default is None.

Expand All @@ -1119,17 +1121,27 @@ def from_dict(
-----
If your dictionary is nested you will need to flatten it before using this method.
"""

# First create the root node
# Find any values corresponding to the root
d_cast = dict(d)
root_data = d_cast.pop("/", None)
root_data = None
for key in ("", ".", "/", "./"):
if key in d_cast:
if root_data is not None:
raise ValueError(
"multiple entries found corresponding to the root node"
)
root_data = d_cast.pop(key)

# Create the root node
if isinstance(root_data, DataTree):
obj = root_data.copy()
obj.name = name
elif root_data is None or isinstance(root_data, Dataset):
obj = cls(name=name, dataset=root_data, children=None)
else:
raise TypeError(
f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}'
f'root node data (at "", ".", "/" or "./") must be a Dataset '
f"or DataTree, got {type(root_data)}"
)

def depth(item) -> int:
Expand All @@ -1141,11 +1153,10 @@ def depth(item) -> int:
# Sort keys by depth so as to insert nodes from root first (see GH issue #9276)
for path, data in sorted(d_cast.items(), key=depth):
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, dataset=data)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to set name explicitly because it is already guaranteed to be consistent (via NamedNode._post_attach).

new_node = cls(dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
Expand Down Expand Up @@ -1683,7 +1694,7 @@ def reduce(
numeric_only=numeric_only,
**kwargs,
)
path = "/" if node is self else node.relative_to(self)
path = node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

Expand Down Expand Up @@ -1718,7 +1729,7 @@ def _selective_indexing(
# with a scalar) can also create scalar coordinates, which
# need to be explicitly removed.
del node_result.coords[k]
path = "/" if node is self else node.relative_to(self)
path = node.relative_to(self)
result[path] = node_result
return type(self).from_dict(result, name=self.name)

Expand Down
42 changes: 42 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,48 @@ def test_array_values(self) -> None:
with pytest.raises(TypeError):
DataTree.from_dict(data) # type: ignore[arg-type]

def test_relative_paths(self) -> None:
    """Check that relative keys in ``from_dict`` resolve against the root.

    The mapping mixes ``"."``, a bare name, a ``"./"``-prefixed name and a
    two-part path; the resulting subtree must expose absolute paths,
    including the intermediate ``/x`` node.
    """
    mapping = {".": None, "foo": None, "./bar": None, "x/y": None}
    tree = DataTree.from_dict(mapping)

    expected_paths = ["/", "/foo", "/bar", "/x", "/x/y"]
    observed_paths = [node.path for node in tree.subtree]
    assert observed_paths == expected_paths

def test_root_keys(self) -> None:
    """Check every accepted spelling of the root key in ``from_dict``.

    Each of ``""``, ``"."``, ``"/"`` and ``"./"`` must produce a tree
    identical to one built directly from the dataset. Supplying two root
    spellings in the same mapping must raise ``ValueError``.
    """
    ds = Dataset({"x": 1})
    expected = DataTree(dataset=ds)

    # All four aliases address the root node and must behave the same.
    for root_key in ("", ".", "/", "./"):
        actual = DataTree.from_dict({root_key: ds})
        assert_identical(actual, expected)

    # Two aliases for the root in one mapping are ambiguous.
    with pytest.raises(
        ValueError, match="multiple entries found corresponding to the root node"
    ):
        DataTree.from_dict({"": ds, "/": ds})

def test_name(self) -> None:
    """Check that ``name=`` sets the root name whatever the root value is.

    Covers a root given as ``None``, as an unnamed ``DataTree`` and as a
    ``DataTree`` that already has a different name; in every case the
    resulting tree must be named ``"foo"``.
    """
    for root_value in (None, DataTree(), DataTree(name="bar")):
        tree = DataTree.from_dict({"/": root_value}, name="foo")
        assert tree.name == "foo"


class TestDatasetView:
def test_view_contents(self) -> None:
Expand Down
Loading