Skip to content

Commit

Permalink
Feature: get_nodes_between method (#87)
Browse files Browse the repository at this point in the history
* Deprecation: Deprecate kwargs in CausalGraph construction

* Point to method

* Feature: get_nodes_between

* Merge main

* Remove cl-dev-tools

* Comments
  • Loading branch information
maxelliott-causalens authored Oct 10, 2024
1 parent d233702 commit 60eb3a0
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 2 deletions.
66 changes: 66 additions & 0 deletions cai_causal_graph/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,72 @@ def get_common_descendants(self, node_1: NodeLike, node_2: NodeLike) -> Set[str]

return self.get_descendants(node_1).intersection(self.get_descendants(node_2))

def get_nodes_between(self, node_1: NodeLike, node_2: NodeLike) -> Set[Node]:
    """
    Get the set of nodes that are on directed causal paths between two nodes.

    The two nodes themselves are included in the returned set. If there is no causal path from `node_1` to
    `node_2`, an empty set is returned. If `node_1` and `node_2` are the same node, the set contains just that
    node.

    This method only collects the nodes lying on the causal paths, rather than enumerating each individual path
    as `cai_causal_graph.causal_graph.CausalGraph.get_all_causal_paths` does. Every node reachable from the
    source is resolved at most once and every outbound edge of those nodes is followed at most once, so the
    traversal itself is `O(n + e)` in time, where `n` and `e` are the number of nodes and edges in the graph
    respectively.

    The traversal uses an explicit stack rather than recursion, so very long paths do not hit the interpreter's
    recursion limit.

    :param node_1: a single node-identifier coercible object, representing the source node.
    :param node_2: a single node-identifier coercible object, representing the destination node.
    :return: a set of nodes that are on causal paths between the two provided nodes.
    :raises AssertionError: if the current graph is not a DAG.
    """
    # NOTE: assertions are skipped when Python runs with -O; the traversal below relies on the graph being
    # acyclic (a cycle would make it loop forever).
    assert self.is_dag(), 'This method only works for DAGs but the current graph is not a DAG.'

    # Cache: for each visited node identifier, whether there is a causal path from that node to the destination.
    seen_nodes: dict[str, bool] = {}

    start = self.get_node(node_1)
    end = self.get_node(node_2)

    def _resolve_without_descent(node: Node) -> bool:
        """
        Try to decide the cached answer for `node` without looking at its children.

        The answer is known immediately if the node is already cached, if it is the destination (True), or if
        it is a sink node (False).

        :return: True if `seen_nodes` now holds an answer for `node`, False if its children must be visited.
        """
        if node.identifier in seen_nodes:
            return True
        if node == end:
            seen_nodes[node.identifier] = True
            return True
        if node.is_sink_node():
            seen_nodes[node.identifier] = False
            return True
        return False

    if not _resolve_without_descent(start):
        # Iterative post-order depth-first search. Each stack frame holds a node together with an iterator over
        # the outbound edges that have not been followed yet. Because the graph is a DAG, a node can never be
        # pushed while it is still on the stack.
        stack = [(start, iter(start._outbound_edges))]
        while stack:
            node, pending_edges = stack[-1]
            for edge in pending_edges:
                child = edge.destination
                if not _resolve_without_descent(child):
                    # Descend into the child first; the iterator remembers where to resume for this node.
                    stack.append((child, iter(child._outbound_edges)))
                    break
            else:
                # Every child is resolved: the node reaches the destination iff any of its children does.
                # All children are visited (no short-circuit on the first success) so that every node lying
                # on some causal path ends up in the cache.
                stack.pop()
                seen_nodes[node.identifier] = any(
                    seen_nodes[edge.destination.identifier] for edge in node._outbound_edges
                )

    # Can return an empty set if there is no causal path from the start node to the end node
    if not seen_nodes[start.identifier]:
        return set()

    # Otherwise return all nodes that have a causal path to the destination (including the source and destination)
    return {self.get_node(node) for node, has_path in seen_nodes.items() if has_path}

def get_d_separation_set(self, node_1: NodeLike, node_2: NodeLike) -> Set[str]:
"""
Return a minimal d-separation set between two nodes.
Expand Down
9 changes: 7 additions & 2 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Changelog

## NEXT

- Added the `cai_causal_graph.causal_graph.CausalGraph.get_nodes_between` method, which returns the set of all nodes
that are on a directed causal path between two nodes in the graph.

## 0.5.5

- Fixed a bug in the `cai_causal_graph.causal_graph.CausalGraph.add_edge` method that would allow the addition of an
edge between `b` and `a` when the reversed edge, i.e. between `a` and `b`, was already specified.
- Added the `CausalGraphError.ReverseEdgeExistsError` exception, which distinguishes errors arising from the
introduction of cycles or reverse edges.
- Added the `CausalGraphError.ReverseEdgeExistsError` exception, which distinguishes errors arising from the
introduction of cycles or reverse edges.

## 0.5.4

Expand Down
68 changes: 68 additions & 0 deletions tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,71 @@ def test_with_target_and_features(self):
g.add_edge('a', output_node)

self.assertTrue(g.directed_path_exists(input_node, output_node))


class TestGetNodesBetween(unittest.TestCase):
    """Tests for `CausalGraph.get_nodes_between` on a small fixed DAG."""

    def setUp(self):
        # Every test uses the same graph:
        #   a -> b, a -> c, c -> b, b -> d, c -> d, c -> e, e -> f
        self.cg = CausalGraph()
        for source, destination in (
            ('a', 'b'),
            ('a', 'c'),
            ('c', 'b'),
            ('b', 'd'),
            ('c', 'd'),
            ('c', 'e'),
            ('e', 'f'),
        ):
            self.cg.add_edge(source, destination)

    def _assert_nodes_between(self, source, destination, expected_identifiers):
        """Assert that the nodes between `source` and `destination` are exactly the expected ones."""
        found = self.cg.get_nodes_between(source, destination)
        self.assertSetEqual(found, set(self.cg.get_nodes(expected_identifiers)))

    def test_get_nodes_between(self):
        # Source to sink
        self._assert_nodes_between('a', 'd', ['a', 'b', 'c', 'd'])

        # Source to other sink
        self._assert_nodes_between('a', 'f', ['a', 'c', 'e', 'f'])

        # Non-source to sink
        self._assert_nodes_between('c', 'd', ['b', 'c', 'd'])

        # Non-source to other sink
        self._assert_nodes_between('c', 'f', ['c', 'e', 'f'])

        # Source to non-sink
        self._assert_nodes_between('a', 'b', ['a', 'c', 'b'])
        self._assert_nodes_between('a', 'e', ['a', 'c', 'e'])

        # Non-source to non-sink
        self._assert_nodes_between('c', 'b', ['c', 'b'])

    def test_get_nodes_between_equal_source_and_destination(self):
        """Equal source and destination returns the source/destination node."""
        self._assert_nodes_between('a', 'a', ['a'])

    def test_get_nodes_between_no_nodes(self):
        """No nodes between the source and destination returns an empty set."""
        found = self.cg.get_nodes_between('b', 'a')
        self.assertSetEqual(found, set())

0 comments on commit 60eb3a0

Please sign in to comment.