diff --git a/cai_causal_graph/time_series_causal_graph.py b/cai_causal_graph/time_series_causal_graph.py index d7989b0..76d33ae 100644 --- a/cai_causal_graph/time_series_causal_graph.py +++ b/cai_causal_graph/time_series_causal_graph.py @@ -306,7 +306,16 @@ def get_minimal_graph(self) -> TimeSeriesCausalGraph: if ( minimal_cg.variables is None or variable not in minimal_cg.variables ) and not minimal_cg.node_exists(variable): - minimal_cg.add_node(variable) + # Guaranteed there is at least one node, take the first for simplicity. + # This aligns with how edges are chosen for minimal graph. + # Only issue is if they have different meta. TODO: CAUSALAI-4784 + original_floating_node = self.get_nodes_for_variable_name(variable)[0] + new_floating_node = self._NodeCls( + identifier=variable, + meta=original_floating_node.meta, + variable_type=original_floating_node.variable_type, + ) + minimal_cg.add_node(node=new_floating_node) return minimal_cg diff --git a/changelog.md b/changelog.md index 2580732..f8339e8 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,12 @@ # Changelog +## NEXT + +- Fixed a bug where metadata on floating nodes would not be correctly carried over to the minimal graph when calling + `cai_causal_graph.time_series_causal_graph.TimeSeriesCausalGraph.get_minimal_graph`. This also fixes an issue where + `cai_causal_graph.time_series_causal_graph.TimeSeriesCausalGraph.is_minimal_graph` could return `False` on minimal + graphs that contain floating nodes. + ## 0.4.8 - Upgraded `docs-builder` dependency to `"~0.2.1"` in the Makefile and updated syntax to support newer `poetry`. diff --git a/tests/test_time_series_graph.py b/tests/test_time_series_graph.py index a8141bd..52af9fc 100644 --- a/tests/test_time_series_graph.py +++ b/tests/test_time_series_graph.py @@ -1931,3 +1931,73 @@ def test_remove_non_existent_node_raises(self): ts_graph._remove_node_from_cache(ts_graph.get_node(ts_graph._NodeCls.identifier_from('B'))) with self.assertRaises(ValueError): ts_graph.delete_node('B') + + def test_floating_nodes_correctly_added_to_minimal_graph(self): + cg = TimeSeriesCausalGraph() + cg.add_time_edge('a', -1, 'a', 0) + cg.add_time_edge('a', 0, 'b', 0) + cg.add_time_edge('a', -1, 'b', -1) + cg.add_node( + variable_name='floating', time_lag=0, meta={'some': 'metadata'}, variable_type=NodeVariableType.BINARY + ) + + self.assertFalse(cg.is_minimal_graph()) + + minimal_graph = cg.get_minimal_graph() + + self.assertTrue(minimal_graph.is_minimal_graph()) + + minimal_node = minimal_graph.get_node('floating') + self.assertEqual(minimal_node.meta['some'], 'metadata') + self.assertEqual(minimal_node.variable_type, NodeVariableType.BINARY) + self.assertEqual(minimal_node.time_lag, 0) + + def test_lagged_floating_nodes_correctly_added_to_minimal_graph(self): + """Same as previous test, but floating node is lagged. It should still be at lag=0 in the minimal graph.""" + cg = TimeSeriesCausalGraph() + cg.add_time_edge('a', -1, 'a', 0) + cg.add_time_edge('a', 0, 'b', 0) + cg.add_time_edge('a', -1, 'b', -1) + cg.add_node( + variable_name='floating', time_lag=-1, meta={'some': 'metadata'}, variable_type=NodeVariableType.BINARY + ) + + self.assertFalse(cg.is_minimal_graph()) + + minimal_graph = cg.get_minimal_graph() + + self.assertTrue(minimal_graph.is_minimal_graph()) + + minimal_node = minimal_graph.get_node('floating') + self.assertEqual(minimal_node.meta['some'], 'metadata') + self.assertEqual(minimal_node.variable_type, NodeVariableType.BINARY) + self.assertEqual(minimal_node.time_lag, 0) + + # Add td=0 floating as well to ensure we end up with 1. + cg.add_node( + variable_name='floating', time_lag=0, meta={'some more': 'metadata'}, variable_type=NodeVariableType.BINARY + ) + + # Add another floating node so minimal graph should have 2 floating nodes at td=0. + cg.add_node( + variable_name='floating_cont', + time_lag=-1, + meta={'some cont': 'metadata'}, + variable_type=NodeVariableType.CONTINUOUS, + ) + + self.assertFalse(cg.is_minimal_graph()) + + minimal_graph = cg.get_minimal_graph() + + self.assertTrue(minimal_graph.is_minimal_graph()) + + floating_node = minimal_graph.get_node('floating') + self.assertEqual(floating_node.meta['some'], 'metadata') # it takes meta from td=-1 as it is first + self.assertEqual(floating_node.variable_type, NodeVariableType.BINARY) + self.assertEqual(floating_node.time_lag, 0) + + floating_cont_node = minimal_graph.get_node('floating_cont') + self.assertEqual(floating_cont_node.meta['some cont'], 'metadata') + self.assertEqual(floating_cont_node.variable_type, NodeVariableType.CONTINUOUS) + self.assertEqual(floating_cont_node.time_lag, 0)