Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zip_subtrees for paired iteration over DataTrees #9623

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def check_isomorphic(
Also optionally raised if their structure is isomorphic, but the names of any two
respective nodes are not equal.
"""
# TODO: remove require_names_equal and check_from_root. Instead, check that
# all child nodes match, in any order, which will suffice once
# map_over_datasets switches to use zip_subtrees.

if not isinstance(a, TreeNode):
raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}")
Expand All @@ -68,7 +71,7 @@ def check_isomorphic(

diff = diff_treestructure(a, b, require_names_equal=require_names_equal)

if diff:
if diff is not None:
raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff)


Expand Down
31 changes: 21 additions & 10 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_equiv, astype
from xarray.core.indexing import MemoryCachedArray
from xarray.core.iterators import LevelOrderIter
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.utils import is_duck_array
from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy
Expand Down Expand Up @@ -981,16 +980,28 @@ def diff_array_repr(a, b, compat):
return "\n".join(summary)


def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
def diff_treestructure(
a: DataTree, b: DataTree, require_names_equal: bool
) -> str | None:
"""
Return a summary of why two trees are not isomorphic.
If they are isomorphic return an empty string.
If they are isomorphic return None.
"""
# .subtrees walks nodes in breadth-first-order, in order to produce as
# shallow of a diff as possible

# TODO: switch zip(a.subtree, b.subtree) to zip_subtrees(a, b), and only
# check that child node names match, e.g.,
# for node_a, node_b in zip_subtrees(a, b):
# if node_a.children.keys() != node_b.children.keys():
# diff = dedent(
# f"""\
# Node {node_a.path!r} in the left object has children {list(node_a.children.keys())}
# Node {node_b.path!r} in the right object has children {list(node_b.children.keys())}"""
# )
# return diff

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
# Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree
# (which it is so long as children are stored in a tuple or list rather than in a set).
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b), strict=True):
for node_a, node_b in zip(a.subtree, b.subtree, strict=True):
path_a, path_b = node_a.path, node_b.path

if require_names_equal and node_a.name != node_b.name:
Expand All @@ -1009,7 +1020,7 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
)
return diff

return ""
return None


def diff_dataset_repr(a, b, compat):
Expand Down Expand Up @@ -1063,9 +1074,9 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):

# If the trees structures are different there is no point comparing each node
# TODO we could show any differences in nodes up to the first place that structure differs?
if treestructure_diff or compat == "isomorphic":
if treestructure_diff is not None:
summary.append("\n" + treestructure_diff)
else:
elif compat != "isomorphic":
nodewise_diff = diff_nodewise_summary(a, b, compat)
summary.append("\n" + nodewise_diff)

Expand Down
124 changes: 0 additions & 124 deletions xarray/core/iterators.py

This file was deleted.

53 changes: 49 additions & 4 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import collections
import sys
from collections.abc import Iterator, Mapping
from pathlib import PurePosixPath
Expand Down Expand Up @@ -395,15 +396,18 @@ def subtree(self: Tree) -> Iterator[Tree]:
"""
An iterator over all nodes in this tree, including both self and all descendants.

Iterates depth-first.
Iterates breadth-first.

See Also
--------
DataTree.descendants
"""
from xarray.core.iterators import LevelOrderIter

return LevelOrderIter(self)
# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode
queue = collections.deque([self])
while queue:
node = queue.popleft()
yield node
queue.extend(node.children.values())
Comment on lines +405 to +410
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replacing the entire iterators.py file with 6 lines is so clever it's almost rude 🤣


@property
def descendants(self: Tree) -> tuple[Tree, ...]:
Expand Down Expand Up @@ -765,3 +769,44 @@ def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath:
generation_gap = list(parents_paths).index(ancestor.path)
path_upwards = "../" * generation_gap if generation_gap > 0 else "."
return NodePath(path_upwards)


def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]:
"""Iterate over aligned subtrees in breadth-first order.

Parameters:
-----------
*trees : Tree
Trees to iterate over.

Yields
------
Tuples of matching subtrees.
"""
if not trees:
raise TypeError("must pass at least one tree object")

# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode
queue = collections.deque([trees])

while queue:
active_nodes = queue.popleft()

# yield before raising an error, in case the caller chooses to exit
# iteration early
yield active_nodes

first_node = active_nodes[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, that theoretically it is possible to pass no arguments to this function. Then trees and here active_nodes is an empty tuple.

Maybe add something like this at the start:

if len(trees) < 2:
    yield trees
    return

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I added an error to catch this case.

if any(
sibling.children.keys() != first_node.children.keys()
for sibling in active_nodes[1:]
):
child_summary = " vs ".join(
str(list(node.children)) for node in active_nodes
)
raise ValueError(
f"children at {first_node.path!r} do not match: {child_summary}"
)

for name in first_node.children:
queue.append(tuple(node.children[name] for node in active_nodes))
Loading
Loading