Skip to content

Commit

Permalink
Bug: is_minimal_graph with floating nodes (#68)
Browse files Browse the repository at this point in the history
Description and Motivation
Fixed a bug where metadata on floating nodes would not be correctly carried over to the minimal graph when calling
cai_causal_graph.time_series_causal_graph.TimeSeriesCausalGraph.get_minimal_graph. This also fixes an issue where
cai_causal_graph.time_series_causal_graph.TimeSeriesCausalGraph.is_minimal_graph could return False on minimal
graphs that contain floating nodes.

---------

Co-authored-by: Andrew Lawrence <[email protected]>
  • Loading branch information
maxelliott-causalens and andrew-causalens authored May 3, 2024
1 parent b03b099 commit 67e2da4
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 1 deletion.
11 changes: 10 additions & 1 deletion cai_causal_graph/time_series_causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,16 @@ def get_minimal_graph(self) -> TimeSeriesCausalGraph:
if (
minimal_cg.variables is None or variable not in minimal_cg.variables
) and not minimal_cg.node_exists(variable):
minimal_cg.add_node(variable)
# Guaranteed there is at least one node, take the first for simplicity.
# This aligns with how edges are chosen for minimal graph.
# Only issue is if they have different meta. TODO: CAUSALAI-4784
original_floating_node = self.get_nodes_for_variable_name(variable)[0]
new_floating_node = self._NodeCls(
identifier=variable,
meta=original_floating_node.meta,
variable_type=original_floating_node.variable_type,
)
minimal_cg.add_node(node=new_floating_node)

return minimal_cg

Expand Down
7 changes: 7 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## NEXT

- Fixed a bug where metadata on floating nodes would not be correctly carried over to the minimal graph when calling
`cai_causal_graph.time_series_causal_graph.TimeSeriesCausalGraph.get_minimal_graph`. This also fixes an issue where
`cai_causal_graph.time_series_causal_graph.TimeSeriesCausalGraph.is_minimal_graph` could return `False` on minimal
graphs that contain floating nodes.

## 0.4.8

- Upgraded `docs-builder` dependency to `"~0.2.1"` in the Makefile and updated syntax to support newer `poetry`.
Expand Down
70 changes: 70 additions & 0 deletions tests/test_time_series_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1931,3 +1931,73 @@ def test_remove_non_existent_node_raises(self):
ts_graph._remove_node_from_cache(ts_graph.get_node(ts_graph._NodeCls.identifier_from('B')))
with self.assertRaises(ValueError):
ts_graph.delete_node('B')

def test_floating_nodes_correctly_added_to_minimal_graph(self):
    """Check that an edge-less node at lag 0 keeps its meta and variable type in the minimal graph."""
    ts_cg = TimeSeriesCausalGraph()
    # The edges reaching back to lag -1 are what make the starting graph non-minimal.
    time_edges = (
        ('a', -1, 'a', 0),
        ('a', 0, 'b', 0),
        ('a', -1, 'b', -1),
    )
    for source, source_lag, destination, destination_lag in time_edges:
        ts_cg.add_time_edge(source, source_lag, destination, destination_lag)

    # A node with no edges attached, carrying metadata and a non-default variable type.
    ts_cg.add_node(
        variable_name='floating', time_lag=0, meta={'some': 'metadata'}, variable_type=NodeVariableType.BINARY
    )

    self.assertFalse(ts_cg.is_minimal_graph())

    reduced = ts_cg.get_minimal_graph()
    self.assertTrue(reduced.is_minimal_graph())

    # The floating node must come through with its metadata, type and lag intact.
    floating = reduced.get_node('floating')
    self.assertEqual(floating.meta['some'], 'metadata')
    self.assertEqual(floating.variable_type, NodeVariableType.BINARY)
    self.assertEqual(floating.time_lag, 0)

def test_lagged_floating_nodes_correctly_added_to_minimal_graph(self):
    """Check that edge-less nodes defined at lag -1 are placed at lag 0 in the minimal graph with meta and type."""
    ts_cg = TimeSeriesCausalGraph()
    time_edges = (
        ('a', -1, 'a', 0),
        ('a', 0, 'b', 0),
        ('a', -1, 'b', -1),
    )
    for source, source_lag, destination, destination_lag in time_edges:
        ts_cg.add_time_edge(source, source_lag, destination, destination_lag)

    # The floating node exists only at lag -1 at this stage.
    ts_cg.add_node(
        variable_name='floating', time_lag=-1, meta={'some': 'metadata'}, variable_type=NodeVariableType.BINARY
    )

    self.assertFalse(ts_cg.is_minimal_graph())

    reduced = ts_cg.get_minimal_graph()
    self.assertTrue(reduced.is_minimal_graph())

    floating = reduced.get_node('floating')
    self.assertEqual(floating.meta['some'], 'metadata')
    self.assertEqual(floating.variable_type, NodeVariableType.BINARY)
    self.assertEqual(floating.time_lag, 0)

    # Give the same variable a lag-0 node too; the minimal graph should still hold a single 'floating' node.
    ts_cg.add_node(
        variable_name='floating', time_lag=0, meta={'some more': 'metadata'}, variable_type=NodeVariableType.BINARY
    )

    # A second floating variable, so the minimal graph should end up with two floating nodes at lag 0.
    ts_cg.add_node(
        variable_name='floating_cont',
        time_lag=-1,
        meta={'some cont': 'metadata'},
        variable_type=NodeVariableType.CONTINUOUS,
    )

    self.assertFalse(ts_cg.is_minimal_graph())

    reduced = ts_cg.get_minimal_graph()
    self.assertTrue(reduced.is_minimal_graph())

    # For 'floating' the meta is expected to come from the lag -1 node, as it is the first one for that variable.
    expectations = (
        ('floating', 'some', NodeVariableType.BINARY),
        ('floating_cont', 'some cont', NodeVariableType.CONTINUOUS),
    )
    for identifier, meta_key, expected_type in expectations:
        node = reduced.get_node(identifier)
        self.assertEqual(node.meta[meta_key], 'metadata')
        self.assertEqual(node.variable_type, expected_type)
        self.assertEqual(node.time_lag, 0)

0 comments on commit 67e2da4

Please sign in to comment.