Skip to content

Commit

Permalink
Improvement: Add is dag to repr for CausalGraph (#75)
Browse files Browse the repository at this point in the history
Description and Motivation
Add is dag to graph repr. This is to help with the package analytics.

How Has This Been Tested?
Extended unit tests.
  • Loading branch information
andrew-causalens authored Jun 10, 2024
1 parent 72e22d3 commit 672c481
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
3 changes: 2 additions & 1 deletion cai_causal_graph/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2114,7 +2114,8 @@ def copy(self, include_meta: bool = True) -> CausalGraph:
def __repr__(self) -> str:
"""Return a string description of the `cai_causal_graph.causal_graph.CausalGraph` instance."""
return (
f'{self.__class__.__name__}(num_nodes={len(self.nodes)}, num_edges={len(self.edges)}, id={self.__hash__()})'
f'{self.__class__.__name__}(num_nodes={len(self.nodes)}, num_edges={len(self.edges)}, '
f'id={self.__hash__()}, is_dag={self.is_dag()})'
f'\n'
f'Nodes: {self.nodes}\nEdges: {self.edges}'
)
Expand Down
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
- Generalized `cai_causal_graph.causal_graph.CausalGraph.__eq__` to check for the class of the instance itself, enabling
to reuse this method by extending classes.
- Added `cai_causal_graph.causal_graph.CausalGraph.has_non_serializable_metadata` method, which returns `False` by default.
- Extended the string representation (`repr`) of `cai_causal_graph.causal_graph.CausalGraph` to include whether the
graph instance is a directed acyclic graph (DAG).
- Dropped support for `python` `3.8` as it is approaching end of life.

## 0.4.10
Expand Down
62 changes: 60 additions & 2 deletions tests/test_causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import numpy
import pandas

from cai_causal_graph import CausalGraph, EdgeType, NodeVariableType
from cai_causal_graph import CausalGraph, EdgeType, NodeVariableType, TimeSeriesCausalGraph
from cai_causal_graph import __version__ as VERSION
from cai_causal_graph.exceptions import CausalGraphErrors
from cai_causal_graph.graph_components import Edge, Node
from cai_causal_graph.graph_components import Edge, Node, TimeSeriesEdge, TimeSeriesNode


class TestCausalGraph(unittest.TestCase):
Expand Down Expand Up @@ -1163,13 +1163,39 @@ def test_default_nodes_and_edges(self):
self.assertTrue(n.__repr__().startswith('Node'))
self.assertTrue(e.__repr__().startswith('Edge'))
self.assertTrue(cg.__repr__().startswith('CausalGraph'))
self.assertIn('is_dag=True', cg.__repr__())

self.assertIsInstance(n.details(), str)
self.assertIsInstance(e.details(), str)
self.assertIsInstance(cg.details(), str)
self.assertTrue(n.details().startswith('Node'))
self.assertTrue(e.details().startswith('Edge'))
self.assertTrue(cg.details().startswith('CausalGraph'))
self.assertIn('is_dag=True', cg.details())

tscg = TimeSeriesCausalGraph.from_causal_graph(cg)
n = tscg.get_node('a')
e = tscg.get_edge('a', 'b')

self.assertIsInstance(n.__hash__(), int)
self.assertIsInstance(e.__hash__(), int)
self.assertIsInstance(tscg.__hash__(), int)

self.assertIsInstance(n.__repr__(), str)
self.assertIsInstance(e.__repr__(), str)
self.assertIsInstance(tscg.__repr__(), str)
self.assertTrue(n.__repr__().startswith('TimeSeriesNode'))
self.assertTrue(e.__repr__().startswith('TimeSeriesEdge'))
self.assertTrue(tscg.__repr__().startswith('TimeSeriesCausalGraph'))
self.assertIn('is_dag=True', tscg.__repr__())

self.assertIsInstance(n.details(), str)
self.assertIsInstance(e.details(), str)
self.assertIsInstance(cg.details(), str)
self.assertTrue(n.details().startswith('TimeSeriesNode'))
self.assertTrue(e.details().startswith('TimeSeriesEdge'))
self.assertTrue(tscg.details().startswith('TimeSeriesCausalGraph'))
self.assertIn('is_dag=True', tscg.details())

def test_complex_nodes_and_edges(self):
cg = CausalGraph()
Expand All @@ -1187,13 +1213,39 @@ def test_complex_nodes_and_edges(self):
self.assertTrue(n.__repr__().startswith('Node'))
self.assertTrue(e.__repr__().startswith('Edge'))
self.assertTrue(cg.__repr__().startswith('CausalGraph'))
self.assertIn('is_dag=False', cg.__repr__())

self.assertIsInstance(n.details(), str)
self.assertIsInstance(e.details(), str)
self.assertIsInstance(cg.details(), str)
self.assertTrue(n.details().startswith('Node'))
self.assertTrue(e.details().startswith('Edge'))
self.assertTrue(cg.details().startswith('CausalGraph'))
self.assertIn('is_dag=False', cg.details())

tscg = TimeSeriesCausalGraph.from_causal_graph(cg)
n = tscg.get_node('a')
e = tscg.get_edge('a', 'b')

self.assertIsInstance(n.__hash__(), int)
self.assertIsInstance(e.__hash__(), int)
self.assertIsInstance(tscg.__hash__(), int)

self.assertIsInstance(n.__repr__(), str)
self.assertIsInstance(e.__repr__(), str)
self.assertIsInstance(tscg.__repr__(), str)
self.assertTrue(n.__repr__().startswith('TimeSeriesNode'))
self.assertTrue(e.__repr__().startswith('TimeSeriesEdge'))
self.assertTrue(tscg.__repr__().startswith('TimeSeriesCausalGraph'))
self.assertIn('is_dag=False', tscg.__repr__())

self.assertIsInstance(n.details(), str)
self.assertIsInstance(e.details(), str)
self.assertIsInstance(cg.details(), str)
self.assertTrue(n.details().startswith('TimeSeriesNode'))
self.assertTrue(e.details().startswith('TimeSeriesEdge'))
self.assertTrue(tscg.details().startswith('TimeSeriesCausalGraph'))
self.assertIn('is_dag=False', tscg.details())

def test_add_node_from_node(self):
causal_graph = CausalGraph()
Expand Down Expand Up @@ -1276,3 +1328,9 @@ def test_node_repr(self):
self.assertEqual(cg['b'], cg.get_node('b'))
self.assertEqual(repr(cg['a']), 'Node("a")')
self.assertEqual(repr(cg['b']), 'Node("b", type="continuous")')

tscg = TimeSeriesCausalGraph.from_causal_graph(cg)
self.assertEqual(tscg['a'], tscg.get_node('a'))
self.assertEqual(tscg['b'], tscg.get_node('b'))
self.assertEqual(repr(tscg['a']), 'TimeSeriesNode("a")')
self.assertEqual(repr(tscg['b']), 'TimeSeriesNode("b", type="continuous")')

0 comments on commit 672c481

Please sign in to comment.