diff --git a/src/cyclebane/graph.py b/src/cyclebane/graph.py index 682e9c6..1a6af40 100644 --- a/src/cyclebane/graph.py +++ b/src/cyclebane/graph.py @@ -430,7 +430,8 @@ def __setitem__(self, branch: Hashable | slice, other: Graph) -> None: intersection_nodes = set(graph.nodes) & set(new_branch.nodes) - {branch} for node in intersection_nodes: - if graph.pred[node] != new_branch.pred[node]: + new_pred = new_branch.pred[node] + if new_pred and graph.pred[node] != new_pred: raise ValueError( f"Node inputs differ for node '{node}':\n" f" {graph.pred[node]}\n" diff --git a/tests/graph_test.py b/tests/graph_test.py index d036679..c204666 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -639,6 +639,49 @@ def test_setitem_preserves_nodes_that_are_ancestors_of_unrelated_node() -> None: nx.utils.graphs_equal(graph.to_networkx(), g) +def test_setitem_inserts_graph_with_missing_parent() -> None: + g1 = nx.DiGraph() + g1.add_edge('a', 'b') + g1.add_edge('n', 'c') + g2 = nx.DiGraph() + g2.add_edge('b', 'n') + + graph = cb.Graph(g1) + graph['n'] = cb.Graph(g2) + expected = g1.copy() + expected.add_edge('b', 'n') + nx.utils.graphs_equal(graph.to_networkx(), expected) + + +def test_setitem_fails_with_partial_parent_mismatch() -> None: + g1 = nx.DiGraph() + g1.add_edge('a', 'b') # only in g1 + g1.add_edge('e', 'b') # in both graphs + g1.add_edge('n', 'c') + g2 = nx.DiGraph() + g2.add_edge('b', 'n') + g2.add_edge('e', 'b') # in both graphs + g2.add_edge('x', 'b') # only in g2 + + graph = cb.Graph(g1) + with pytest.raises(ValueError, match="Node inputs differ for node 'b'"): + graph['n'] = cb.Graph(g2) + + +def test_setitem_fails_when_grandparents_change() -> None: + g1 = nx.DiGraph() + g1.add_edge('a1', 'b') + g1.add_edge('a2', 'b') + g1.add_edge('b', 'c') + g2 = nx.DiGraph() # no a2 -> b edge + g2.add_edge('a1', 'b') + g2.add_edge('b', 'c') + + graph = cb.Graph(g1) + with pytest.raises(ValueError, match="Node inputs differ for node 'b'"): + graph['c'] = cb.Graph(g2) + + def test_getitem_returns_graph_containing_only_key_and_ancestors() -> None: g = nx.DiGraph() g.add_edge('a', 'b')