Skip to content

Commit

Permalink
Add steiner_tree function
Browse files Browse the repository at this point in the history
This commit adds a function to find an approximation of the minimum
Steiner tree using the algorithm described in "A fast algorithm for
Steiner tree" Kou, Markowsky & Berman (1981).
https://link.springer.com/article/10.1007/BF00288961
This algorithm produces a tree whose weight is within a (2 - (2 / t))
factor of the weight of the optimal Steiner tree where t is the number
of terminal nodes.

Closes Qiskit#389
  • Loading branch information
mtreinish committed Jul 20, 2021
1 parent 7b7a20d commit 77f3983
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use pyo3::Python;
use super::digraph;
use super::graph;

fn pairwise<I>(right: I) -> impl Iterator<Item = (Option<I::Item>, I::Item)>
pub fn pairwise<I>(right: I) -> impl Iterator<Item = (Option<I::Item>, I::Item)>
where
I: IntoIterator + Clone,
{
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4898,6 +4898,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(digraph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(graph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(metric_closure))?;
m.add_wrapped(wrap_pyfunction!(steiner_tree))?;
m.add_class::<digraph::PyDiGraph>()?;
m.add_class::<graph::PyGraph>()?;
m.add_class::<iterators::BFSSuccessors>()?;
Expand Down
98 changes: 98 additions & 0 deletions src/steiner_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@
// License for the specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;

use hashbrown::{HashMap, HashSet};
use rayon::prelude::*;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::Python;

use petgraph::graph::NodeIndex;
use petgraph::unionfind::UnionFind;
use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable};

use crate::_all_pairs_dijkstra_shortest_paths;
use crate::generators::pairwise;
use crate::graph;

struct MetricClosureEdge {
Expand Down Expand Up @@ -104,3 +110,95 @@ fn _metric_closure_edges(
}
Ok(out_vec)
}

/// Return an approximation to the minimum Steiner tree of a graph.
///
/// :param PyGraph graph:
/// :param list terminal_nodes:
/// :param weight_fn:
///
/// :returns: An approximation to the minimal steiner tree of ``graph`` induced
/// by ``terminal_nodes``.
#[pyfunction]
#[pyo3(text_signature = "(graph, terminal_nodes, weight_fn, /)")]
pub fn steiner_tree(
py: Python,
graph: &graph::PyGraph,
terminal_nodes: Vec<usize>,
weight_fn: PyObject,
) -> PyResult<graph::PyGraph> {
let terminal_node_set: HashSet<usize> =
terminal_nodes.into_iter().collect();
let metric_edges = _metric_closure_edges(py, graph, weight_fn)?;
// Calculate mst edges from metric closure edge list:
let mut subgraphs = UnionFind::<usize>::new(graph.graph.node_bound());
let mut edge_list: Vec<MetricClosureEdge> =
Vec::with_capacity(metric_edges.len());
for edge in metric_edges {
if !terminal_node_set.contains(&edge.source)
|| !terminal_node_set.contains(&edge.target)
{
continue;
}
let weight = edge.distance;
if weight.is_nan() {
return Err(PyValueError::new_err("NaN found as an edge weight"));
}
edge_list.push(edge);
}
edge_list.par_sort_unstable_by(|a, b| {
let weight_a = a.distance;
let weight_b = b.distance;
weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less)
});
let mut mst_edges: Vec<MetricClosureEdge> = Vec::new();
for float_edge_pair in edge_list {
let u = float_edge_pair.source;
let v = float_edge_pair.target;
if subgraphs.union(u, v) {
let w = float_edge_pair.distance;
let path = float_edge_pair.path;
mst_edges.push(MetricClosureEdge {
source: u,
target: v,
distance: w,
path,
});
}
}
// Generate the output graph from the MST of the metric closure
let out_edge_list: Vec<[usize; 2]> = mst_edges
.iter()
.map(|edge| pairwise(edge.path.clone()))
.flatten()
.filter_map(|x| x.0.map(|a| [a, x.1]))
.collect();
let out_edges: HashSet<(usize, usize)> =
out_edge_list.iter().map(|x| (x[0], x[1])).collect();
let mut out_graph = graph.clone();
let out_nodes: HashSet<NodeIndex> = out_edge_list
.iter()
.map(|x| x.iter())
.flatten()
.copied()
.map(NodeIndex::new)
.collect();
for node in graph
.graph
.node_indices()
.filter(|node| !out_nodes.contains(node))
{
out_graph.graph.remove_node(node);
out_graph.node_removed = true;
}
for edge in graph.graph.edge_references().filter(|edge| {
let source = edge.source().index();
let target = edge.target().index();
!out_edges.contains(&(source, target))
&& !out_edges.contains(&(target, source))
}) {
out_graph.graph.remove_edge(edge.id());
}
// TODO: Deduplicate edges with min weight for multigraphs
Ok(out_graph)
}
15 changes: 15 additions & 0 deletions tests/graph/test_steiner_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,18 @@ def test_not_connected_metric_closure(self):
self.graph.add_node(None)
with self.assertRaises(ValueError):
retworkx.metric_closure(self.graph, weight_fn=float)

def test_steiner_graph(self):
steiner_tree = retworkx.steiner_tree(
self.graph, [1, 2, 3, 4, 5], weight_fn=float
)
expected_steiner_tree = [
(1, 2, 10),
(2, 3, 10),
(2, 7, 1),
(3, 4, 10),
(7, 5, 1),
]
steiner_tree_edge_list = steiner_tree.weighted_edge_list()
for edge in expected_steiner_tree:
self.assertIn(edge, steiner_tree_edge_list)

0 comments on commit 77f3983

Please sign in to comment.