From ba73b32a987e51fe01d6478d649938c2b72834ba Mon Sep 17 00:00:00 2001 From: Jaron Lee Date: Wed, 8 Feb 2023 11:31:13 -0500 Subject: [PATCH] [BUG, ENH] Implement BFS m-separation algorithm (#48) * implement breadth first search first BayesBall algorithm for m-separation Signed-off-by: Jaron Lee Co-authored-by: Adam Li --- docs/whats_new/v0.1.rst | 1 + .../algorithms/causal/m_separation.py | 193 +++++++++--------- .../causal/tests/test_m_separation.py | 84 +++++++- 3 files changed, 182 insertions(+), 96 deletions(-) diff --git a/docs/whats_new/v0.1.rst b/docs/whats_new/v0.1.rst index 554da5c59..9661c99e6 100644 --- a/docs/whats_new/v0.1.rst +++ b/docs/whats_new/v0.1.rst @@ -26,6 +26,7 @@ Version 0.1 Changelog --------- +- |Feature| Implement m-separation :func:`pywhy_graphs.networkx.m_separated` with the BallTree approach, by `Jaron Lee`_ (:pr:`48`) - |Feature| Add support for undirected edges in m-separation :func:`pywhy_graphs.networkx.m_separated`, by `Jaron Lee`_ (:pr:`46`) - |Feature| Implement uncovered circle path finding inside the :func:`pywhy_graphs.algorithms.uncovered_pd_path`, by `Jaron Lee`_ (:pr:`42`) - |Feature| Implement and test the :class:`pywhy_graphs.CPDAG` for CPDAGs, by `Adam Li`_ (:pr:`6`) diff --git a/pywhy_graphs/networkx/algorithms/causal/m_separation.py b/pywhy_graphs/networkx/algorithms/causal/m_separation.py index bfcd6b95a..12958e2b7 100644 --- a/pywhy_graphs/networkx/algorithms/causal/m_separation.py +++ b/pywhy_graphs/networkx/algorithms/causal/m_separation.py @@ -1,13 +1,14 @@ +import logging from collections import deque -from copy import deepcopy import networkx as nx -from networkx.utils import UnionFind import pywhy_graphs.networkx as pywhy_nx __all__ = ["m_separated"] +logger = logging.getLogger(__name__) + def m_separated( G, @@ -21,24 +22,13 @@ def m_separated( """Check m-separation among 'x' and 'y' given 'z' in mixed-edge causal graph G, which may contain directed, bidirected, and undirected edges. - This algorithm adapts the linear time algorithm presented in [1]_ currently implemented in - `networkx.algorithms.d_separation` to work for mixed-edge causal graphs, using m-separation - logic detailed in [2]_. - - This algorithm works by retaining select edges in each of the directed, bidirected, and - undirected edge subgraphs (if supplied). Then, an undirected graph is created from - the union of allsuch retained edges (without direction information), and then - m-separation of x and y givenz is determined if x is disconnected from y in this graph. - - In the directed edge subgraph, nodes and associated edges are removed if they are childless - and not in x | y | z; this process is repeated until no such nodes remain. Then, outgoing - edges from z are removed. The remaining edges are retained. + This implements the m-separation algorithm TESTSEP presented in [1]_ for ancestral mixed + graphs, which is itself adapted from [2]_. Further checks have ensure that it works + for non-ancestral mixed graphs (e.g. ADMGs). The algorithm performs a breadth-first search + over m-connecting paths between 'x' and 'y' (i.e. a path on which every node that is a + collider is in 'z', and every node that is not a collider is not in 'z'). The algorithm + has runtime ``O(|E| + |V|)`` for number of edges ``|E|`` and number of vertices ``|V|``. - In the bidirected edge subgraph, nodes and associated edges are removed if they are not - in x | y | z. The remaining edges are retained. - - In the undirected edge subgraph, all edges involving z are removed. The remaining edges are - retained. Parameters ---------- @@ -64,12 +54,12 @@ def m_separated( References ---------- - .. [1] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks. - Cambridge: Cambridge University Press. - .. [2] Spirtes, P. and Richardson, T.S.. (1997). A Polynomial Time Algorithm - for Determining DAG Equivalence in the Presence of Latent Variables and Selection - Bias. Proceedings of the Sixth International Workshop on Artificial Intelligence and - Statistics, in Proceedings of Machine Learning Research + .. [1] B. van der Zander, M. Liśkiewicz, and J. Textor, “Separators and Adjustment + Sets in Causal Graphs: Complete Criteria and an Algorithmic Framework,” Artificial + Intelligence, vol. 270, pp. 1–40, May 2019, doi: 10.1016/j.artint.2018.12.006. + + .. [2] + See Also -------- @@ -100,77 +90,92 @@ def m_separated( if not nx.is_directed_acyclic_graph(G.get_graphs(directed_edge_name)): raise nx.NetworkXError("directed edge graph should be directed acyclic") - union_xyz = x.union(y).union(z) + # contains -> and <-> edges from starting node T + forward_deque = deque([]) + forward_visited = set() - # get directed edges - has_directed = False - if directed_edge_name in G.edge_types: - has_directed = True - G_directed = nx.DiGraph() - G_directed.add_nodes_from((n, deepcopy(d)) for n, d in G.nodes.items()) - G_directed.add_edges_from(G.get_graphs(edge_type=directed_edge_name).edges) - - # get bidirected edges subgraph - has_bidirected = False - if bidirected_edge_name in G.edge_types: - has_bidirected = True - G_bidirected = nx.Graph() - G_bidirected.add_nodes_from((n, deepcopy(d)) for n, d in G.nodes.items()) - G_bidirected.add_edges_from(G.get_graphs(edge_type=bidirected_edge_name).edges) - - # get undirected edges subgraph - has_undirected = False - if undirected_edge_name in G.edge_types: - has_undirected = True - G_undirected = nx.Graph() - G_undirected.add_nodes_from((n, deepcopy(d)) for n, d in G.nodes.items()) - G_undirected.add_edges_from(G.get_graphs(edge_type=undirected_edge_name).edges) - - # get ancestral subgraph of x | y | z by removing leaves in directed graph that are not - # in x | y | z until no more leaves can be removed. - if has_directed: - leaves = deque([n for n in G_directed.nodes if G_directed.out_degree[n] == 0]) - while len(leaves) > 0: - leaf = leaves.popleft() - if leaf not in union_xyz: - for p in G_directed.predecessors(leaf): - if G_directed.out_degree[p] == 1: - leaves.append(p) - G_directed.remove_node(leaf) - - # remove outgoing directed edges in z - edges_to_remove = list(G_directed.out_edges(z)) - G_directed.remove_edges_from(edges_to_remove) - - # remove nodes in bidirected graph that are not in x | y | z (since they will be - # independent due to colliders) - if has_bidirected: - nodes = [n for n in G_bidirected.nodes] - for node in nodes: - if node not in union_xyz: - G_bidirected.remove_node(node) - - # remove nodes in undirected graph that are in z to block m-connecting paths + # contains <- and - edges from starting node T + backward_deque = deque(x) + backward_visited = set() + has_undirected = undirected_edge_name in G.edge_types if has_undirected: - edges_to_remove = list(G_undirected.edges(z)) - G_undirected.remove_edges_from(edges_to_remove) + G_undirected = G.get_graphs(edge_type=undirected_edge_name) + has_directed = directed_edge_name in G.edge_types - # make new undirected graph from remaining directed, bidirected, and undirected edges - G_final = nx.Graph() + an_z = z if has_directed: - G_final.add_edges_from(G_directed.edges) + G_directed = G.get_graphs(edge_type=directed_edge_name) + an_z = set().union(*[nx.ancestors(G_directed, x) for x in z]).union(z) + + has_bidirected = bidirected_edge_name in G.edge_types if has_bidirected: - G_final.add_edges_from(G_bidirected.edges) - if has_undirected: - G_final.add_edges_from(G_undirected.edges) - - disjoint_set = UnionFind(G_final.nodes()) - for component in nx.connected_components(G_final): - disjoint_set.union(*component) - disjoint_set.union(*x) - disjoint_set.union(*y) - - if x and y and disjoint_set[next(iter(x))] == disjoint_set[next(iter(y))]: - return False - else: - return True + G_bidirected = G.get_graphs(edge_type=bidirected_edge_name) + + while forward_deque or backward_deque: + + if backward_deque: + node = backward_deque.popleft() + backward_visited.add(node) + if node in y: + return False + if node in z: + continue + + # add - edges to forward deque + if has_undirected: + for nbr in G_undirected.neighbors(node): + if nbr not in backward_visited: + backward_deque.append(nbr) + + if has_directed: + # add <- edges to backward deque + for x, _ in G_directed.in_edges(nbunch=node): + if x not in backward_visited: + backward_deque.append(x) + + # add -> edges to forward deque + for _, x in G_directed.out_edges(nbunch=node): + if x not in forward_visited: + forward_deque.append(x) + + # add <-> edge to forward deque + if has_bidirected: + for nbr in G_bidirected.neighbors(node): + if nbr not in forward_visited: + forward_deque.append(nbr) + + if forward_deque: + node = forward_deque.popleft() + forward_visited.add(node) + if node in y: + return False + + # Consider if *-> node <-* is opened due to conditioning on collider, + # or descendant of collider + if node in an_z: + + if has_directed: + # add <- edges to backward deque + for x, _ in G_directed.in_edges(nbunch=node): + if x not in backward_visited: + backward_deque.append(x) + + # add <-> edge to backward deque + if has_bidirected: + for nbr in G_bidirected.neighbors(node): + if nbr not in forward_visited: + forward_deque.append(nbr) + + if node not in z: + if has_undirected: + for nbr in G_undirected.neighbors(node): + if nbr not in backward_visited: + backward_deque.append(nbr) + + if has_directed: + # add -> edges to forward deque + for _, x in G_directed.out_edges(nbunch=node): + if x not in forward_visited: + forward_deque.append(x) + + return True diff --git a/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py b/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py index 2f2c0397a..5a5b9bc5a 100644 --- a/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py +++ b/pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py @@ -1,3 +1,5 @@ +import logging + import networkx as nx import pytest from networkx.exception import NetworkXError @@ -6,6 +8,8 @@ def test_m_separation(): + logging.getLogger().setLevel(logging.DEBUG) + digraph = nx.path_graph(4, create_using=nx.DiGraph) digraph.add_edge(2, 4) bigraph = nx.Graph([(2, 3)]) @@ -51,7 +55,8 @@ def test_m_separation(): assert pywhy_nx.m_separated(G, {1}, {3}, set()) assert not pywhy_nx.m_separated(G, {1}, {3}, {2}) - # check that 1 _|_ 5 in graph 1 - 2 -> 3 <- 4 <-> 5 + # check that m-sep in graph with all kinds of edges + # e.g. 1 _|_ 5 in graph 1 - 2 -> 3 <- 4 <-> 5 digraph = nx.DiGraph() digraph.add_nodes_from([1, 2, 3, 4, 5]) digraph.add_edge(2, 3) @@ -66,7 +71,26 @@ def test_m_separation(): assert pywhy_nx.m_separated(G, {1}, {4}, set()) - # check that 1 not _|_ 3 in graph 1 - 2 - 3 + # e.g. 1 _|_ 5 | 7 in 1 - 2 -> 3 <-> 4 - 5, 3 -> 6, 2 - 7 <-> 5 + digraph = nx.DiGraph() + digraph.add_nodes_from([1, 2, 3, 4, 5, 6, 7]) + digraph.add_edges_from([(2, 3), (3, 6)]) + bigraph = nx.Graph([(3, 4), (7, 5)]) + bigraph.add_nodes_from(digraph) + ungraph = nx.Graph([(1, 2), (4, 5), (2, 7)]) + ungraph.add_nodes_from(digraph) + G = pywhy_nx.MixedEdgeGraph( + [digraph, bigraph, ungraph], ["directed", "bidirected", "undirected"] + ) + + assert pywhy_nx.m_separated(G, {1}, {5}, {7}) + assert not pywhy_nx.m_separated(G, {1}, {5}, set()) + assert not pywhy_nx.m_separated(G, {1}, {5}, {6}) + print(G.edges()) + assert not pywhy_nx.m_separated(G, {1}, {5}, {6, 7}) + + # check m-sep works in undirected graphs: + # e.g. that 1 not _|_ 3 in graph 1 - 2 - 3 ungraph = nx.Graph([(1, 2), (2, 3)]) G = pywhy_nx.MixedEdgeGraph([ungraph], ["undirected"]) @@ -91,3 +115,59 @@ def test_m_separation(): assert pywhy_nx.m_separated(G, {1}, {3}, {2}) assert pywhy_nx.m_separated(G, {1}, {2}, set()) + + # check fig 6 of Zhang 2008 + + digraph = nx.DiGraph() + digraph.add_nodes_from(["A", "B", "C", "D"]) + digraph.add_edge("A", "C") + digraph.add_edge("C", "D") + digraph.add_edge("B", "D") + bigraph = nx.Graph() + bigraph.add_edge("A", "B") + G = pywhy_nx.MixedEdgeGraph([digraph, bigraph], ["directed", "bidirected"]) + assert not pywhy_nx.m_separated(G, {"A"}, {"D"}, {"C"}) + assert pywhy_nx.m_separated(G, {"A"}, {"D"}, {"B", "C"}) + + assert pywhy_nx.m_separated(G, {"B"}, {"C"}, {"A"}) + assert not pywhy_nx.m_separated(G, {"B"}, {"C"}, {"A", "D"}) + assert not pywhy_nx.m_separated(G, {"B"}, {"C"}, set()) + + # check more complicated ADMGs + + # check inducing paths behave correctly + digraph = nx.DiGraph() + digraph.add_nodes_from(["A", "B", "C", "D"]) + digraph.add_edge("B", "C") + digraph.add_edge("C", "D") + bigraph = nx.Graph() + bigraph.add_edge("A", "B") + bigraph.add_edge("B", "C") + G = pywhy_nx.MixedEdgeGraph([digraph, bigraph], ["directed", "bidirected"]) + + assert not pywhy_nx.m_separated(G, {"A"}, {"C"}, {"B"}) + assert not pywhy_nx.m_separated(G, {"A"}, {"C"}, set()) + assert not pywhy_nx.m_separated(G, {"A"}, {"D"}, set()) + + # check conditioning on collider of descendant in bidirected graph works + digraph = nx.DiGraph() + digraph.add_nodes_from(["A", "B", "C", "D"]) + digraph.add_edge("B", "D") + digraph.add_edge("A", "B") + digraph.add_edge("C", "B") + + G = pywhy_nx.MixedEdgeGraph([digraph], ["directed"]) + + assert not pywhy_nx.m_separated(G, {"A"}, {"C"}, {"D"}) + assert pywhy_nx.m_separated(G, {"A"}, {"C"}, set()) + + digraph = nx.DiGraph() + digraph.add_nodes_from(["A", "B", "C", "D"]) + digraph.add_edge("B", "D") + digraph.add_edge("A", "B") + bigraph = nx.Graph() + bigraph.add_edge("B", "C") + G = pywhy_nx.MixedEdgeGraph([digraph, bigraph], ["directed", "bidirected"]) + + assert not pywhy_nx.m_separated(G, {"A"}, {"C"}, {"D"}) + assert pywhy_nx.m_separated(G, {"A"}, {"C"}, set())