diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 5952e177e..d1698e969 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -71,6 +71,7 @@ from .rustworkx import number_connected_components as number_connected_component from .rustworkx import number_weakly_connected_components as number_weakly_connected_components from .rustworkx import node_connected_component as node_connected_component from .rustworkx import strongly_connected_components as strongly_connected_components +from .rustworkx import condensation as condensation from .rustworkx import weakly_connected_components as weakly_connected_components from .rustworkx import digraph_adjacency_matrix as digraph_adjacency_matrix from .rustworkx import graph_adjacency_matrix as graph_adjacency_matrix diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 4a18edd61..ebbfc223a 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -192,6 +192,7 @@ def number_connected_components(graph: PyGraph, /) -> int: ... def number_weakly_connected_components(graph: PyDiGraph, /) -> bool: ... def node_connected_component(graph: PyGraph, node: int, /) -> set[int]: ... def strongly_connected_components(graph: PyDiGraph, /) -> list[list[int]]: ... +def condensation(graph: PyDiGraph, /, sccs=None) -> PyDiGraph: ... def weakly_connected_components(graph: PyDiGraph, /) -> list[set[int]]: ... def digraph_adjacency_matrix( graph: PyDiGraph[_S, _T], diff --git a/src/connectivity/mod.rs b/src/connectivity/mod.rs index ecd2858be..059d66dfc 100644 --- a/src/connectivity/mod.rs +++ b/src/connectivity/mod.rs @@ -21,12 +21,11 @@ use super::{ }; use hashbrown::{HashMap, HashSet}; -use petgraph::algo; -use petgraph::algo::condensation; -use petgraph::graph::DiGraph; +use petgraph::graph::{DiGraph, IndexType}; use petgraph::stable_graph::NodeIndex; use petgraph::unionfind::UnionFind; use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeCount, NodeIndexable, Visitable}; +use petgraph::{algo, Graph}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -35,6 +34,7 @@ use rayon::prelude::*; use ndarray::prelude::*; use numpy::IntoPyArray; +use petgraph::prelude::StableGraph; use crate::iterators::{ AllPairsMultiplePathMapping, BiconnectedComponents, Chains, EdgeList, NodeIndices, @@ -114,6 +114,79 @@ pub fn strongly_connected_components(graph: &digraph::PyDiGraph) -> Vec( + py: &Python, + g: Graph, + make_acyclic: bool, + sccs: Option>>, +) -> StableGraph +where + Ty: EdgeType, + Ix: IndexType, + N: ToPyObject, + E: ToPyObject, +{ + // Don't use into_iter to avoid extra allocations + let sccs = if let Some(sccs) = sccs { + sccs.iter() + .map(|row| row.iter().map(|x| NodeIndex::new(*x)).collect()) + .collect() + } else { + algo::kosaraju_scc(&g) + }; + + let mut condensed: StableGraph, E, Ty, Ix> = + StableGraph::with_capacity(sccs.len(), g.edge_count()); + + // Build a map from old indices to new ones. + let mut node_map = vec![NodeIndex::end(); g.node_count()]; + for comp in sccs { + let new_nix = condensed.add_node(Vec::new()); + for nix in comp { + node_map[nix.index()] = new_nix; + } + } + + // Consume nodes and edges of the old graph and insert them into the new one. + let (nodes, edges) = g.into_nodes_edges(); + for (nix, node) in nodes.into_iter().enumerate() { + condensed[node_map[nix]].push(node.weight); + } + for edge in edges { + let source = node_map[edge.source().index()]; + let target = node_map[edge.target().index()]; + if make_acyclic { + if source != target { + condensed.update_edge(source, target, edge.weight); + } + } else { + condensed.add_edge(source, target, edge.weight); + } + } + condensed.map(|_, w| w.to_object(*py), |_, w| w.to_object(*py)) +} + +#[pyfunction] +#[pyo3(text_signature = "(graph, /, sccs=None)", signature=(graph, sccs=None))] +pub fn condensation( + py: Python, + graph: &digraph::PyDiGraph, + sccs: Option>>, +) -> digraph::PyDiGraph { + let g = graph.graph.clone(); + + let condensed = condensation_inner(&py, g.into(), true, sccs); + + digraph::PyDiGraph { + graph: condensed, + cycle_state: algo::DfsSpace::default(), + check_cycle: false, + node_removed: false, + multigraph: true, + attrs: py.None(), + } +} + /// Return the first cycle encountered during DFS of a given PyDiGraph, /// empty list is returned if no cycle is found /// @@ -295,7 +368,7 @@ pub fn is_semi_connected(graph: &digraph::PyDiGraph) -> PyResult { temp_graph.add_edge(node_map[source.index()], node_map[target.index()], ()); } - let condensed = condensation(temp_graph, true); + let condensed = algo::condensation(temp_graph, true); let n = condensed.node_count(); let weight_fn = |_: petgraph::graph::EdgeReference<()>| Ok::(1usize); diff --git a/src/lib.rs b/src/lib.rs index 4ee4189a7..21661b4e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -569,6 +569,7 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(cycle_basis))?; m.add_wrapped(wrap_pyfunction!(simple_cycles))?; m.add_wrapped(wrap_pyfunction!(strongly_connected_components))?; + m.add_wrapped(wrap_pyfunction!(condensation))?; m.add_wrapped(wrap_pyfunction!(digraph_dfs_edges))?; m.add_wrapped(wrap_pyfunction!(graph_dfs_edges))?; m.add_wrapped(wrap_pyfunction!(digraph_find_cycle))?; diff --git a/tests/digraph/test_strongly_connected.py b/tests/digraph/test_strongly_connected.py index 4a49d76dd..22c352d3b 100644 --- a/tests/digraph/test_strongly_connected.py +++ b/tests/digraph/test_strongly_connected.py @@ -65,3 +65,55 @@ def test_number_strongly_connected_big(self): node = G.add_node(i) G.add_child(node, str(i), {}) self.assertEqual(len(rustworkx.strongly_connected_components(G)), 200000) + + +class TestCondensation(unittest.TestCase): + def setUp(self): + # グラフをセットアップ + self.graph = rustworkx.PyDiGraph() + self.node_a = self.graph.add_node("a") + self.node_b = self.graph.add_node("b") + self.node_c = self.graph.add_node("c") + self.node_d = self.graph.add_node("d") + self.node_e = self.graph.add_node("e") + self.node_f = self.graph.add_node("f") + self.node_g = self.graph.add_node("g") + self.node_h = self.graph.add_node("h") + + # エッジを追加 + self.graph.add_edge(self.node_a, self.node_b, "a->b") + self.graph.add_edge(self.node_b, self.node_c, "b->c") + self.graph.add_edge(self.node_c, self.node_d, "c->d") + self.graph.add_edge(self.node_d, self.node_a, "d->a") # サイクル: a -> b -> c -> d -> a + + self.graph.add_edge(self.node_b, self.node_e, "b->e") + + self.graph.add_edge(self.node_e, self.node_f, "e->f") + self.graph.add_edge(self.node_f, self.node_g, "f->g") + self.graph.add_edge(self.node_g, self.node_h, "g->h") + self.graph.add_edge(self.node_h, self.node_e, "h->e") # サイクル: e -> f -> g -> h -> e + + def test_condensation(self): + # condensation関数を呼び出し + condensed_graph = rustworkx.condensation(self.graph) + + # ノード数を確認(2つのサイクルが1つずつのノードに縮約される) + self.assertEqual( + len(condensed_graph.node_indices()), 2 + ) # [SCC(a, b, c, d), SCC(e, f, g, h)] + + # エッジ数を確認 + self.assertEqual( + len(condensed_graph.edge_indices()), 1 + ) # Edge: [SCC(a, b, c, d)] -> [SCC(e, f, g, h)] + + # 縮約されたノードの内容を確認 + nodes = list(condensed_graph.nodes()) + scc1 = nodes[0] + scc2 = nodes[1] + self.assertTrue(set(scc1) == {"a", "b", "c", "d"} or set(scc2) == {"a", "b", "c", "d"}) + self.assertTrue(set(scc1) == {"e", "f", "g", "h"} or set(scc2) == {"e", "f", "g", "h"}) + + # エッジの内容を確認 + weight = condensed_graph.edges()[0] + self.assertIn("b->e", weight) # 縮約後のグラフにおいて、正しいエッジが残っていることを確認