From 05ae8b31ac1229f3edf89620bc25e2f2991f7739 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Fri, 13 Aug 2021 09:40:30 -0400 Subject: [PATCH] Add function to approximate minimum Steiner tree (#392) * Add steiner_tree function This commit adds a function to find an approximation of the minimum Steiner tree using the algorithm described in "A fast algorithm for Steiner tree" Kou, Markowsky & Berman (1981). https://link.springer.com/article/10.1007/BF00288961 This algorithm produces a tree whose weight is within a (2 - (2 / t)) factor of the weight of the optimal Steiner tree where t is the number of terminal nodes. Closes #389 * Add more tests * Switch to iterating over node indices in metric_closure In #390 one of the last changes made to the PR before merging was: https://github.com/Qiskit/retworkx/pull/390/commits/3353b4b4cc3a7f859812145e25eb6290dc446ec4 which changed the final loop constructing the metric closure edges from iterating over node indices and pulling the path from the hashmap for that index to just iterating over the hashmap. That however had the unintended side effect of introducing non-determinism to the output as the iteration order over a hashmap isn't guaranteed. This was causing failures in the steiner tree function as the metric closure edge order can affect the computed tree. This commit fixes this by switching back to using the node id order for the final output generation and adds a comment as to why. * Fix release notes * Apply suggestions from code review Co-authored-by: georgios-ts <45130028+georgios-ts@users.noreply.github.com> * Split out edge deduplication to separate function * Attempt to avoid cycles by using src and target in sort Co-authored-by: georgios-ts <45130028+georgios-ts@users.noreply.github.com> * Remove input graph usage in edge deduplication Co-authored-by: georgios-ts <45130028+georgios-ts@users.noreply.github.com> * Move steiner_tree() to Tree docs section * Add test case with equal distances This adds a test case with an input graph that has equal distances. It is a test based on a review comment [1] that was previously producing an incorrect output (that wasn't even a tree). After fixing the underlying issue it would be good to have this tested to ensure we don't regress on it in the future. [1] https://github.com/Qiskit/retworkx/pull/392#discussion_r687646066 Co-authored-by: georgios-ts <45130028+georgios-ts@users.noreply.github.com> --- docs/source/api.rst | 1 + .../notes/steiner_tree-3e5282c65095868a.yaml | 12 +- src/generators.rs | 2 +- src/lib.rs | 1 + src/steiner_tree.rs | 177 +++++++++++++++++- tests/graph/test_steiner_tree.py | 80 +++++++- 6 files changed, 264 insertions(+), 9 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 15de84e17d..a25c5d5e0c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -78,6 +78,7 @@ Tree retworkx.minimum_spanning_edges retworkx.minimum_spanning_tree + retworkx.steiner_tree .. _isomorphism: diff --git a/releasenotes/notes/steiner_tree-3e5282c65095868a.yaml b/releasenotes/notes/steiner_tree-3e5282c65095868a.yaml index 614c75f428..26bf4d6a0c 100644 --- a/releasenotes/notes/steiner_tree-3e5282c65095868a.yaml +++ b/releasenotes/notes/steiner_tree-3e5282c65095868a.yaml @@ -3,7 +3,13 @@ features: - | Added a new function, :func:`~retworkx.metric_closure`, which is used to generate the metric closure of a given graph. This function is used in - the computation of for calculating the Steiner Tree using the algorithm + the computation for calculating the Steiner Tree using the algorithm from - `Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees". - `__ + Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees". + https://link.springer.com/article/10.1007/BF00288961 + - | + Added a new function, :func:`~retworkx.steiner_tree`, which is used to + generate an approximation of the minimum Steiner tree using the + algorithm from: + Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees". + https://link.springer.com/article/10.1007/BF00288961 diff --git a/src/generators.rs b/src/generators.rs index 9fa7edad3a..811bf4b7d3 100644 --- a/src/generators.rs +++ b/src/generators.rs @@ -25,7 +25,7 @@ use pyo3::Python; use super::digraph; use super::graph; -fn pairwise(right: I) -> impl Iterator, I::Item)> +pub fn pairwise(right: I) -> impl Iterator, I::Item)> where I: IntoIterator + Clone, { diff --git a/src/lib.rs b/src/lib.rs index 431df6705f..9e3051edb4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -233,6 +233,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(digraph_num_shortest_paths_unweighted))?; m.add_wrapped(wrap_pyfunction!(graph_num_shortest_paths_unweighted))?; m.add_wrapped(wrap_pyfunction!(metric_closure))?; + m.add_wrapped(wrap_pyfunction!(steiner_tree))?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/steiner_tree.rs b/src/steiner_tree.rs index d47eeb5820..51b2104e97 100644 --- a/src/steiner_tree.rs +++ b/src/steiner_tree.rs @@ -10,14 +10,20 @@ // License for the specific language governing permissions and limitations // under the License. +use std::cmp::Ordering; + use hashbrown::{HashMap, HashSet}; +use rayon::prelude::*; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::Python; -use petgraph::graph::NodeIndex; +use petgraph::graph::{EdgeIndex, NodeIndex}; +use petgraph::unionfind::UnionFind; +use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable}; +use crate::generators::pairwise; use crate::graph; use crate::shortest_path::all_pairs_dijkstra::all_pairs_dijkstra_shortest_paths; @@ -97,8 +103,9 @@ fn _metric_closure_edges( not defined for a graph with unconnected nodes", )); } - for (node, path) in paths { - let path_map = path.paths; + // Iterate over node indices for a deterministic order + for node in graph.graph.node_indices().map(|x| x.index()) { + let path_map = &paths[&node].paths; nodes.remove(&node); let distance = &distances[&node]; for v in &nodes { @@ -113,3 +120,167 @@ fn _metric_closure_edges( } Ok(out_vec) } + +/// Return an approximation to the minimum Steiner tree of a graph. +/// +/// The minimum tree of ``graph`` with regard to a set of ``terminal_nodes`` +/// is a tree within ``graph`` that spans those nodes and has a minimum size +/// (measured as the sum of edge weights) amoung all such trees. +/// +/// The minimum steiner tree can be approximated by computing the minimum +/// spanning tree of the subgraph of the metric closure of ``graph`` induced +/// by the terminal nodes, where the metric closure of ``graph`` is the +/// complete graph in which each edge is weighted by the shortest path distance +/// between nodes in ``graph``. +/// +/// This algorithm [1]_ produces a tree whose weight is within a +/// :math:`(2 - (2 / t))` factor of the weight of the optimal Steiner tree +/// where :math:`t` is the number of terminal nodes. +/// +/// :param PyGraph graph: The graph to compute the minimum Steiner tree for +/// :param list terminal_nodes: The list of node indices for which the Steiner +/// tree is to be computed for. +/// :param weight_fn: A callable object that will be passed an edge's +/// weight/data payload and expected to return a ``float``. For example, +/// you can use ``weight_fn=float`` to cast every weight as a float. +/// +/// :returns: An approximation to the minimal steiner tree of ``graph`` induced +/// by ``terminal_nodes``. +/// :rtype: PyGraph +/// +/// .. [1] Kou, Markowsky & Berman, +/// "A fast algorithm for Steiner trees" +/// Acta Informatica 15, 141–145 (1981). +/// https://link.springer.com/article/10.1007/BF00288961 +#[pyfunction] +#[pyo3(text_signature = "(graph, terminal_nodes, weight_fn, /)")] +pub fn steiner_tree( + py: Python, + graph: &graph::PyGraph, + terminal_nodes: Vec, + weight_fn: PyObject, +) -> PyResult { + let terminal_node_set: HashSet = + terminal_nodes.into_iter().collect(); + let metric_edges = + _metric_closure_edges(py, graph, weight_fn.clone_ref(py))?; + // Calculate mst edges from metric closure edge list: + let mut subgraphs = UnionFind::::new(graph.graph.node_bound()); + let mut edge_list: Vec = + Vec::with_capacity(metric_edges.len()); + for edge in metric_edges { + if !terminal_node_set.contains(&edge.source) + || !terminal_node_set.contains(&edge.target) + { + continue; + } + let weight = edge.distance; + if weight.is_nan() { + return Err(PyValueError::new_err("NaN found as an edge weight")); + } + edge_list.push(edge); + } + edge_list.par_sort_unstable_by(|a, b| { + let weight_a = (a.distance, a.source, a.target); + let weight_b = (b.distance, b.source, b.target); + weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) + }); + let mut mst_edges: Vec = Vec::new(); + for float_edge_pair in edge_list { + let u = float_edge_pair.source; + let v = float_edge_pair.target; + if subgraphs.union(u, v) { + mst_edges.push(float_edge_pair); + } + } + // Generate the output graph from the MST of the metric closure + let out_edge_list: Vec<[usize; 2]> = mst_edges + .iter() + .map(|edge| pairwise(edge.path.clone())) + .flatten() + .filter_map(|x| x.0.map(|a| [a, x.1])) + .collect(); + let out_edges: HashSet<(usize, usize)> = + out_edge_list.iter().map(|x| (x[0], x[1])).collect(); + let mut out_graph = graph.clone(); + let out_nodes: HashSet = out_edge_list + .iter() + .map(|x| x.iter()) + .flatten() + .copied() + .map(NodeIndex::new) + .collect(); + for node in graph + .graph + .node_indices() + .filter(|node| !out_nodes.contains(node)) + { + out_graph.graph.remove_node(node); + out_graph.node_removed = true; + } + for edge in graph.graph.edge_references().filter(|edge| { + let source = edge.source().index(); + let target = edge.target().index(); + !out_edges.contains(&(source, target)) + && !out_edges.contains(&(target, source)) + }) { + out_graph.graph.remove_edge(edge.id()); + } + // Deduplicate potential duplicate edges + deduplicate_edges(py, &mut out_graph, &weight_fn)?; + Ok(out_graph) +} + +fn deduplicate_edges( + py: Python, + out_graph: &mut graph::PyGraph, + weight_fn: &PyObject, +) -> PyResult<()> { + if out_graph.multigraph { + // Find all edges between nodes + let mut duplicate_map: HashMap< + [NodeIndex; 2], + Vec<(EdgeIndex, PyObject)>, + > = HashMap::new(); + for edge in out_graph.graph.edge_references() { + if duplicate_map.contains_key(&[edge.source(), edge.target()]) { + duplicate_map + .get_mut(&[edge.source(), edge.target()]) + .unwrap() + .push((edge.id(), edge.weight().clone_ref(py))); + } else if duplicate_map + .contains_key(&[edge.target(), edge.source()]) + { + duplicate_map + .get_mut(&[edge.target(), edge.source()]) + .unwrap() + .push((edge.id(), edge.weight().clone_ref(py))); + } else { + duplicate_map.insert( + [edge.source(), edge.target()], + vec![(edge.id(), edge.weight().clone_ref(py))], + ); + } + } + // For a node pair with > 1 edge find minimum edge and remove others + for edges_raw in duplicate_map.values().filter(|x| x.len() > 1) { + let mut edges: Vec<(EdgeIndex, f64)> = + Vec::with_capacity(edges_raw.len()); + for edge in edges_raw { + let res = weight_fn.call1(py, (&edge.1,))?; + let raw = res.to_object(py); + let weight = raw.extract(py)?; + edges.push((edge.0, weight)); + } + edges.par_sort_unstable_by(|a, b| { + let weight_a = a.1; + let weight_b = b.1; + weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less) + }); + edges[1..].iter().for_each(|x| { + out_graph.graph.remove_edge(x.0); + }); + } + } + Ok(()) +} diff --git a/tests/graph/test_steiner_tree.py b/tests/graph/test_steiner_tree.py index 2ac7bdef0e..04012ddcf0 100644 --- a/tests/graph/test_steiner_tree.py +++ b/tests/graph/test_steiner_tree.py @@ -16,9 +16,9 @@ import retworkx -class TestMetricClosure(unittest.TestCase): +class TestSteinerTree(unittest.TestCase): def setUp(self): - self.graph = retworkx.PyGraph() + self.graph = retworkx.PyGraph(multigraph=False) self.graph.add_node(None) self.graph.extend_from_weighted_edge_list( [ @@ -112,3 +112,79 @@ def test_metric_closure_empty_graph(self): graph = retworkx.PyGraph() closure = retworkx.metric_closure(graph, weight_fn=float) self.assertEqual([], closure.weighted_edge_list()) + + def test_steiner_graph(self): + steiner_tree = retworkx.steiner_tree( + self.graph, [1, 2, 3, 4, 5], weight_fn=float + ) + expected_steiner_tree = [ + (1, 2, 10), + (2, 3, 10), + (2, 7, 1), + (3, 4, 10), + (7, 5, 1), + ] + steiner_tree_edge_list = steiner_tree.weighted_edge_list() + for edge in expected_steiner_tree: + self.assertIn(edge, steiner_tree_edge_list) + + def test_steiner_graph_multigraph(self): + edge_list = [ + (1, 2, 1), + (2, 3, 999), + (2, 3, 1), + (3, 4, 1), + (3, 5, 1), + ] + graph = retworkx.PyGraph() + graph.extend_from_weighted_edge_list(edge_list) + graph.remove_node(0) + terminal_nodes = [2, 4, 5] + tree = retworkx.steiner_tree(graph, terminal_nodes, weight_fn=float) + expected_edges = [ + (2, 3, 1), + (3, 4, 1), + (3, 5, 1), + ] + steiner_tree_edge_list = tree.weighted_edge_list() + for edge in expected_edges: + self.assertIn(edge, steiner_tree_edge_list) + + def test_not_connected_steiner_tree(self): + self.graph.add_node(None) + with self.assertRaises(ValueError): + retworkx.steiner_tree(self.graph, [0, 1, 2], weight_fn=float) + + def test_steiner_tree_empty_graph(self): + graph = retworkx.PyGraph() + tree = retworkx.steiner_tree(graph, [], weight_fn=float) + self.assertEqual([], tree.weighted_edge_list()) + + def test_equal_distance_graph(self): + n = 3 + graph = retworkx.PyGraph() + graph.add_nodes_from(range(n + 5)) + graph.add_edges_from( + [ + (n, n + 1, 0.5), + (n, n + 2, 0.5), + (n + 1, n + 2, 0.5), + (n, n + 3, 0.5), + (n + 1, n + 4, 0.5), + ] + ) + graph.add_edges_from([(i, n + 2, 2) for i in range(n)]) + terminals = list(range(5)) + [n + 3, n + 4] + tree = retworkx.steiner_tree(graph, terminals, weight_fn=float) + # Assert no cycle + self.assertEqual(retworkx.cycle_basis(tree), []) + expected_edges = [ + (3, 4, 0.5), + (3, 5, 0.5), + (3, 6, 0.5), + (4, 7, 0.5), + (0, 5, 2), + (1, 5, 2), + (2, 5, 2), + ] + self.assertEqual(tree.weighted_edge_list(), expected_edges)