Skip to content

Commit

Permalink
Improvement: generalize CausalGraph.from_adjacency_matrix (#53)
Browse files Browse the repository at this point in the history
Description and Motivation
The cai_causal_graph.causal_graph.CausalGraph.from_adjacency_matrix is now a class method (rather than being a static method) which has been generalized to return an instance of the class on which it has been called (e.g. enabling returning instances of classes inheriting from cai_causal_graph.causal_graph.CausalGraph.

Additional Context
This is to simplify extending CausalGraph class and to bring from_adjancency_matrix in line with other class construction methods such as from_dict.

This change is completely backwards compatible since the API of the method is unchanged.

How Has This Been Tested?
Existing tests.
  • Loading branch information
ilya-causalens authored Jan 26, 2024
1 parent 81ea49a commit eb8963f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cai_causal_graph/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,9 +1884,9 @@ def from_gml_string(gml: str) -> CausalGraph:
g = networkx.parse_gml(gml)
return CausalGraph.from_networkx(g)

@staticmethod
@classmethod
def from_adjacency_matrix(
adjacency: numpy.ndarray, node_names: Optional[List[Union[NodeLike, int]]] = None
cls, adjacency: numpy.ndarray, node_names: Optional[List[Union[NodeLike, int]]] = None
) -> CausalGraph:
"""
Construct a `cai_causal_graph.causal_graph.CausalGraph` instance from an adjacency matrix and optionally a list
Expand Down Expand Up @@ -1927,7 +1927,7 @@ def from_adjacency_matrix(
nodes = [CausalGraph.coerce_to_nodelike(node) for node in node_names] # type: ignore

# Add edges. Any conversion from BasicFeature or BasicTarget is handled by the add_edge method.
graph = CausalGraph()
graph = cls()
graph.add_nodes_from(nodes)
for i, j in itertools.combinations(range(len(nodes)), 2):
if adjacency[i, j] != 0 and adjacency[j, i] == 0:
Expand Down
3 changes: 3 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
- Fixed a bug where `cai_causal_graph.causal_graph.CausalGraph.delete_edge` would not support passing source and destination
as `cai_causal_graph.graph_components.Node`.
- Added support for passing source and destination as `cai_causal_graph.graph_components.Node` to `cai_causal_graph.causal_graph.CausalGraph.remove_edge`.
- The `cai_causal_graph.causal_graph.CausalGraph.from_adjacency_matrix` is now a class method (rather than being a static method)
which has been generalized to return an instance of the class on which it has been called (e.g. enabling returning instances of
classes inheriting from `cai_causal_graph.causal_graph.CausalGraph`.

## 0.3.14

Expand Down

0 comments on commit eb8963f

Please sign in to comment.