Skip to content

Commit

Permalink
[BUG, ENH] Implement BFS m-separation algorithm (#48)
Browse files Browse the repository at this point in the history
* implement breadth-first search BayesBall algorithm for m-separation

Signed-off-by: Jaron Lee <[email protected]>
Co-authored-by: Adam Li <[email protected]>
  • Loading branch information
jaron-lee and adam2392 authored Feb 8, 2023
1 parent 35e860c commit ba73b32
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 96 deletions.
1 change: 1 addition & 0 deletions docs/whats_new/v0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Version 0.1
Changelog
---------

- |Feature| Implement m-separation :func:`pywhy_graphs.networkx.m_separated` with the BayesBall approach, by `Jaron Lee`_ (:pr:`48`)
- |Feature| Add support for undirected edges in m-separation :func:`pywhy_graphs.networkx.m_separated`, by `Jaron Lee`_ (:pr:`46`)
- |Feature| Implement uncovered circle path finding inside the :func:`pywhy_graphs.algorithms.uncovered_pd_path`, by `Jaron Lee`_ (:pr:`42`)
- |Feature| Implement and test the :class:`pywhy_graphs.CPDAG` for CPDAGs, by `Adam Li`_ (:pr:`6`)
Expand Down
193 changes: 99 additions & 94 deletions pywhy_graphs/networkx/algorithms/causal/m_separation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from collections import deque
from copy import deepcopy

import networkx as nx
from networkx.utils import UnionFind

import pywhy_graphs.networkx as pywhy_nx

__all__ = ["m_separated"]

logger = logging.getLogger(__name__)


def m_separated(
G,
Expand All @@ -21,24 +22,13 @@ def m_separated(
"""Check m-separation among 'x' and 'y' given 'z' in mixed-edge causal graph G, which may
contain directed, bidirected, and undirected edges.
This algorithm adapts the linear time algorithm presented in [1]_ currently implemented in
`networkx.algorithms.d_separation` to work for mixed-edge causal graphs, using m-separation
logic detailed in [2]_.
This algorithm works by retaining select edges in each of the directed, bidirected, and
undirected edge subgraphs (if supplied). Then, an undirected graph is created from
the union of all such retained edges (without direction information), and then
m-separation of x and y given z is determined if x is disconnected from y in this graph.
In the directed edge subgraph, nodes and associated edges are removed if they are childless
and not in x | y | z; this process is repeated until no such nodes remain. Then, outgoing
edges from z are removed. The remaining edges are retained.
This implements the m-separation algorithm TESTSEP presented in [1]_ for ancestral mixed
graphs, which is itself adapted from [2]_. Further checks ensure that it works
for non-ancestral mixed graphs (e.g. ADMGs). The algorithm performs a breadth-first search
over m-connecting paths between 'x' and 'y' (i.e. a path on which every node that is a
collider is in 'z', and every node that is not a collider is not in 'z'). The algorithm
has runtime ``O(|E| + |V|)`` for number of edges ``|E|`` and number of vertices ``|V|``.
In the bidirected edge subgraph, nodes and associated edges are removed if they are not
in x | y | z. The remaining edges are retained.
In the undirected edge subgraph, all edges involving z are removed. The remaining edges are
retained.
Parameters
----------
Expand All @@ -64,12 +54,12 @@ def m_separated(
References
----------
.. [1] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks.
Cambridge: Cambridge University Press.
.. [2] Spirtes, P. and Richardson, T.S.. (1997). A Polynomial Time Algorithm
for Determining DAG Equivalence in the Presence of Latent Variables and Selection
Bias. Proceedings of the Sixth International Workshop on Artificial Intelligence and
Statistics, in Proceedings of Machine Learning Research
.. [1] B. van der Zander, M. Liśkiewicz, and J. Textor, “Separators and Adjustment
Sets in Causal Graphs: Complete Criteria and an Algorithmic Framework,” Artificial
Intelligence, vol. 270, pp. 1–40, May 2019, doi: 10.1016/j.artint.2018.12.006.
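.. [2] R. D. Shachter, “Bayes-Ball: The Rational Pastime (for Determining Irrelevance
and Requisite Information in Belief Networks and Influence Diagrams),” in Proceedings
of the Fourteenth Conference on Uncertainty in Artificial Intelligence, 1998, pp. 480–487.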
See Also
--------
Expand Down Expand Up @@ -100,77 +90,92 @@ def m_separated(
if not nx.is_directed_acyclic_graph(G.get_graphs(directed_edge_name)):
raise nx.NetworkXError("directed edge graph should be directed acyclic")

union_xyz = x.union(y).union(z)
# contains -> and <-> edges from starting node T
forward_deque = deque([])
forward_visited = set()

# get directed edges
has_directed = False
if directed_edge_name in G.edge_types:
has_directed = True
G_directed = nx.DiGraph()
G_directed.add_nodes_from((n, deepcopy(d)) for n, d in G.nodes.items())
G_directed.add_edges_from(G.get_graphs(edge_type=directed_edge_name).edges)

# get bidirected edges subgraph
has_bidirected = False
if bidirected_edge_name in G.edge_types:
has_bidirected = True
G_bidirected = nx.Graph()
G_bidirected.add_nodes_from((n, deepcopy(d)) for n, d in G.nodes.items())
G_bidirected.add_edges_from(G.get_graphs(edge_type=bidirected_edge_name).edges)

# get undirected edges subgraph
has_undirected = False
if undirected_edge_name in G.edge_types:
has_undirected = True
G_undirected = nx.Graph()
G_undirected.add_nodes_from((n, deepcopy(d)) for n, d in G.nodes.items())
G_undirected.add_edges_from(G.get_graphs(edge_type=undirected_edge_name).edges)

# get ancestral subgraph of x | y | z by removing leaves in directed graph that are not
# in x | y | z until no more leaves can be removed.
if has_directed:
leaves = deque([n for n in G_directed.nodes if G_directed.out_degree[n] == 0])
while len(leaves) > 0:
leaf = leaves.popleft()
if leaf not in union_xyz:
for p in G_directed.predecessors(leaf):
if G_directed.out_degree[p] == 1:
leaves.append(p)
G_directed.remove_node(leaf)

# remove outgoing directed edges in z
edges_to_remove = list(G_directed.out_edges(z))
G_directed.remove_edges_from(edges_to_remove)

# remove nodes in bidirected graph that are not in x | y | z (since they will be
# independent due to colliders)
if has_bidirected:
nodes = [n for n in G_bidirected.nodes]
for node in nodes:
if node not in union_xyz:
G_bidirected.remove_node(node)

# remove nodes in undirected graph that are in z to block m-connecting paths
# contains <- and - edges from starting node T
backward_deque = deque(x)
backward_visited = set()
has_undirected = undirected_edge_name in G.edge_types
if has_undirected:
edges_to_remove = list(G_undirected.edges(z))
G_undirected.remove_edges_from(edges_to_remove)
G_undirected = G.get_graphs(edge_type=undirected_edge_name)
has_directed = directed_edge_name in G.edge_types

# make new undirected graph from remaining directed, bidirected, and undirected edges
G_final = nx.Graph()
an_z = z
if has_directed:
G_final.add_edges_from(G_directed.edges)
G_directed = G.get_graphs(edge_type=directed_edge_name)
an_z = set().union(*[nx.ancestors(G_directed, x) for x in z]).union(z)

has_bidirected = bidirected_edge_name in G.edge_types
if has_bidirected:
G_final.add_edges_from(G_bidirected.edges)
if has_undirected:
G_final.add_edges_from(G_undirected.edges)

disjoint_set = UnionFind(G_final.nodes())
for component in nx.connected_components(G_final):
disjoint_set.union(*component)
disjoint_set.union(*x)
disjoint_set.union(*y)

if x and y and disjoint_set[next(iter(x))] == disjoint_set[next(iter(y))]:
return False
else:
return True
G_bidirected = G.get_graphs(edge_type=bidirected_edge_name)

while forward_deque or backward_deque:

if backward_deque:
node = backward_deque.popleft()
backward_visited.add(node)
if node in y:
return False
if node in z:
continue

# add - edges to backward deque
if has_undirected:
for nbr in G_undirected.neighbors(node):
if nbr not in backward_visited:
backward_deque.append(nbr)

if has_directed:
# add <- edges to backward deque
for x, _ in G_directed.in_edges(nbunch=node):
if x not in backward_visited:
backward_deque.append(x)

# add -> edges to forward deque
for _, x in G_directed.out_edges(nbunch=node):
if x not in forward_visited:
forward_deque.append(x)

# add <-> edge to forward deque
if has_bidirected:
for nbr in G_bidirected.neighbors(node):
if nbr not in forward_visited:
forward_deque.append(nbr)

if forward_deque:
node = forward_deque.popleft()
forward_visited.add(node)
if node in y:
return False

# Consider if *-> node <-* is opened due to conditioning on collider,
# or descendant of collider
if node in an_z:

if has_directed:
# add <- edges to backward deque
for x, _ in G_directed.in_edges(nbunch=node):
if x not in backward_visited:
backward_deque.append(x)

# add <-> edge to forward deque
if has_bidirected:
for nbr in G_bidirected.neighbors(node):
if nbr not in forward_visited:
forward_deque.append(nbr)

if node not in z:
if has_undirected:
for nbr in G_undirected.neighbors(node):
if nbr not in backward_visited:
backward_deque.append(nbr)

if has_directed:
# add -> edges to forward deque
for _, x in G_directed.out_edges(nbunch=node):
if x not in forward_visited:
forward_deque.append(x)

return True
84 changes: 82 additions & 2 deletions pywhy_graphs/networkx/algorithms/causal/tests/test_m_separation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import networkx as nx
import pytest
from networkx.exception import NetworkXError
Expand All @@ -6,6 +8,8 @@


def test_m_separation():
logging.getLogger().setLevel(logging.DEBUG)

digraph = nx.path_graph(4, create_using=nx.DiGraph)
digraph.add_edge(2, 4)
bigraph = nx.Graph([(2, 3)])
Expand Down Expand Up @@ -51,7 +55,8 @@ def test_m_separation():
assert pywhy_nx.m_separated(G, {1}, {3}, set())
assert not pywhy_nx.m_separated(G, {1}, {3}, {2})

# check that 1 _|_ 5 in graph 1 - 2 -> 3 <- 4 <-> 5
# check that m-sep works in a graph with all kinds of edges
# e.g. 1 _|_ 5 in graph 1 - 2 -> 3 <- 4 <-> 5
digraph = nx.DiGraph()
digraph.add_nodes_from([1, 2, 3, 4, 5])
digraph.add_edge(2, 3)
Expand All @@ -66,7 +71,26 @@ def test_m_separation():

assert pywhy_nx.m_separated(G, {1}, {4}, set())

# check that 1 not _|_ 3 in graph 1 - 2 - 3
# e.g. 1 _|_ 5 | 7 in 1 - 2 -> 3 <-> 4 - 5, 3 -> 6, 2 - 7 <-> 5
digraph = nx.DiGraph()
digraph.add_nodes_from([1, 2, 3, 4, 5, 6, 7])
digraph.add_edges_from([(2, 3), (3, 6)])
bigraph = nx.Graph([(3, 4), (7, 5)])
bigraph.add_nodes_from(digraph)
ungraph = nx.Graph([(1, 2), (4, 5), (2, 7)])
ungraph.add_nodes_from(digraph)
G = pywhy_nx.MixedEdgeGraph(
[digraph, bigraph, ungraph], ["directed", "bidirected", "undirected"]
)

assert pywhy_nx.m_separated(G, {1}, {5}, {7})
assert not pywhy_nx.m_separated(G, {1}, {5}, set())
assert not pywhy_nx.m_separated(G, {1}, {5}, {6})
print(G.edges())
assert not pywhy_nx.m_separated(G, {1}, {5}, {6, 7})

# check m-sep works in undirected graphs:
# e.g. that 1 not _|_ 3 in graph 1 - 2 - 3
ungraph = nx.Graph([(1, 2), (2, 3)])
G = pywhy_nx.MixedEdgeGraph([ungraph], ["undirected"])

Expand All @@ -91,3 +115,59 @@ def test_m_separation():

assert pywhy_nx.m_separated(G, {1}, {3}, {2})
assert pywhy_nx.m_separated(G, {1}, {2}, set())

# check fig 6 of Zhang 2008

digraph = nx.DiGraph()
digraph.add_nodes_from(["A", "B", "C", "D"])
digraph.add_edge("A", "C")
digraph.add_edge("C", "D")
digraph.add_edge("B", "D")
bigraph = nx.Graph()
bigraph.add_edge("A", "B")
G = pywhy_nx.MixedEdgeGraph([digraph, bigraph], ["directed", "bidirected"])
assert not pywhy_nx.m_separated(G, {"A"}, {"D"}, {"C"})
assert pywhy_nx.m_separated(G, {"A"}, {"D"}, {"B", "C"})

assert pywhy_nx.m_separated(G, {"B"}, {"C"}, {"A"})
assert not pywhy_nx.m_separated(G, {"B"}, {"C"}, {"A", "D"})
assert not pywhy_nx.m_separated(G, {"B"}, {"C"}, set())

# check more complicated ADMGs

# check inducing paths behave correctly
digraph = nx.DiGraph()
digraph.add_nodes_from(["A", "B", "C", "D"])
digraph.add_edge("B", "C")
digraph.add_edge("C", "D")
bigraph = nx.Graph()
bigraph.add_edge("A", "B")
bigraph.add_edge("B", "C")
G = pywhy_nx.MixedEdgeGraph([digraph, bigraph], ["directed", "bidirected"])

assert not pywhy_nx.m_separated(G, {"A"}, {"C"}, {"B"})
assert not pywhy_nx.m_separated(G, {"A"}, {"C"}, set())
assert not pywhy_nx.m_separated(G, {"A"}, {"D"}, set())

# check that conditioning on a descendant of a collider works (directed and bidirected cases)
digraph = nx.DiGraph()
digraph.add_nodes_from(["A", "B", "C", "D"])
digraph.add_edge("B", "D")
digraph.add_edge("A", "B")
digraph.add_edge("C", "B")

G = pywhy_nx.MixedEdgeGraph([digraph], ["directed"])

assert not pywhy_nx.m_separated(G, {"A"}, {"C"}, {"D"})
assert pywhy_nx.m_separated(G, {"A"}, {"C"}, set())

digraph = nx.DiGraph()
digraph.add_nodes_from(["A", "B", "C", "D"])
digraph.add_edge("B", "D")
digraph.add_edge("A", "B")
bigraph = nx.Graph()
bigraph.add_edge("B", "C")
G = pywhy_nx.MixedEdgeGraph([digraph, bigraph], ["directed", "bidirected"])

assert not pywhy_nx.m_separated(G, {"A"}, {"C"}, {"D"})
assert pywhy_nx.m_separated(G, {"A"}, {"C"}, set())

0 comments on commit ba73b32

Please sign in to comment.