-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add zip_subtrees for paired iteration over DataTrees #9623
Changes from all commits
bde843c
23da8ca
4480e11
c22ff76
373886a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
import collections | ||
import sys | ||
from collections.abc import Iterator, Mapping | ||
from pathlib import PurePosixPath | ||
|
@@ -395,15 +396,18 @@ def subtree(self: Tree) -> Iterator[Tree]: | |
""" | ||
An iterator over all nodes in this tree, including both self and all descendants. | ||
|
||
Iterates depth-first. | ||
Iterates breadth-first. | ||
|
||
See Also | ||
-------- | ||
DataTree.descendants | ||
""" | ||
from xarray.core.iterators import LevelOrderIter | ||
|
||
return LevelOrderIter(self) | ||
# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode | ||
queue = collections.deque([self]) | ||
while queue: | ||
node = queue.popleft() | ||
yield node | ||
queue.extend(node.children.values()) | ||
|
||
@property | ||
def descendants(self: Tree) -> tuple[Tree, ...]: | ||
|
@@ -765,3 +769,44 @@ def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath: | |
generation_gap = list(parents_paths).index(ancestor.path) | ||
path_upwards = "../" * generation_gap if generation_gap > 0 else "." | ||
return NodePath(path_upwards) | ||
|
||
|
||
def zip_subtrees(*trees: AnyNamedNode) -> Iterator[tuple[AnyNamedNode, ...]]: | ||
"""Iterate over aligned subtrees in breadth-first order. | ||
|
||
Parameters: | ||
----------- | ||
*trees : Tree | ||
Trees to iterate over. | ||
|
||
Yields | ||
------ | ||
Tuples of matching subtrees. | ||
""" | ||
if not trees: | ||
raise TypeError("must pass at least one tree object") | ||
|
||
# https://en.wikipedia.org/wiki/Breadth-first_search#Pseudocode | ||
queue = collections.deque([trees]) | ||
|
||
while queue: | ||
active_nodes = queue.popleft() | ||
|
||
# yield before raising an error, in case the caller chooses to exit | ||
# iteration early | ||
yield active_nodes | ||
|
||
first_node = active_nodes[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note, that theoretically it is possible to pass no arguments to this function. Then Maybe add something like this at the start: if len(trees) < 2:
yield trees
return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, I added an error to catch this case. |
||
if any( | ||
sibling.children.keys() != first_node.children.keys() | ||
for sibling in active_nodes[1:] | ||
): | ||
child_summary = " vs ".join( | ||
str(list(node.children)) for node in active_nodes | ||
) | ||
raise ValueError( | ||
f"children at {first_node.path!r} do not match: {child_summary}" | ||
) | ||
|
||
for name in first_node.children: | ||
queue.append(tuple(node.children[name] for node in active_nodes)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replacing the entire
iterators.py
file with 6 lines is so clever it's almost rude 🤣