Skip to content

Commit

Permalink
Add function to approximate minimum Steiner tree (#392)
Browse files Browse the repository at this point in the history
* Add steiner_tree function

This commit adds a function to find an approximation of the minimum
Steiner tree using the algorithm described in "A fast algorithm for
Steiner tree" Kou, Markowsky & Berman (1981).
https://link.springer.com/article/10.1007/BF00288961
This algorithm produces a tree whose weight is within a (2 - (2 / t))
factor of the weight of the optimal Steiner tree where t is the number
of terminal nodes.

Closes #389

* Add more tests

* Switch to iterating over node indices in metric_closure

In #390 one of the last changes made to the PR before merging was:
3353b4b
which changed the final loop constructing the metric closure edges from
iterating over node indices and pulling the path from the hashmap for
that index to just iterating over the hashmap. That, however, had the
unintended side effect of introducing non-determinism into the output, as
the iteration order over a hashmap isn't guaranteed. This was causing
failures in the steiner tree function as the metric closure edge order
can affect the computed tree. This commit fixes this by switching back
to using the node id order for the final output generation and adds a
comment as to why.

* Fix release notes

* Apply suggestions from code review

Co-authored-by: georgios-ts <[email protected]>

* Split out edge deduplication to separate function

* Attempt to avoid cycles by using src and target in sort

Co-authored-by: georgios-ts <[email protected]>

* Remove input graph usage in edge deduplication

Co-authored-by: georgios-ts <[email protected]>

* Move steiner_tree() to Tree docs section

* Add test case with equal distances

This adds a test case with an input graph that has equal distances. The
test is based on a graph from a review comment [1] for which the function
previously produced an incorrect output (one that wasn't even a tree). Now
that the underlying issue is fixed, it is worth testing this case to ensure
we don't regress on it in the future.

[1] #392 (comment)

Co-authored-by: georgios-ts <[email protected]>
  • Loading branch information
mtreinish and georgios-ts authored Aug 13, 2021
1 parent adaeb31 commit 05ae8b3
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 9 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Tree

retworkx.minimum_spanning_edges
retworkx.minimum_spanning_tree
retworkx.steiner_tree

.. _isomorphism:

Expand Down
12 changes: 9 additions & 3 deletions releasenotes/notes/steiner_tree-3e5282c65095868a.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ features:
- |
Added a new function, :func:`~retworkx.metric_closure`, which is used
to generate the metric closure of a given graph. This function is used in
the computation of for calculating the Steiner Tree using the algorithm
the computation for calculating the Steiner Tree using the algorithm
from
`Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees".
<https://link.springer.com/article/10.1007/BF00288961>`__
Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees".
https://link.springer.com/article/10.1007/BF00288961
- |
Added a new function, :func:`~retworkx.steiner_tree`, which is used to
generate an approximation of the minimum Steiner tree using the
algorithm from:
Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees".
https://link.springer.com/article/10.1007/BF00288961
2 changes: 1 addition & 1 deletion src/generators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use pyo3::Python;
use super::digraph;
use super::graph;

fn pairwise<I>(right: I) -> impl Iterator<Item = (Option<I::Item>, I::Item)>
pub fn pairwise<I>(right: I) -> impl Iterator<Item = (Option<I::Item>, I::Item)>
where
I: IntoIterator + Clone,
{
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(digraph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(graph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(metric_closure))?;
m.add_wrapped(wrap_pyfunction!(steiner_tree))?;
m.add_class::<digraph::PyDiGraph>()?;
m.add_class::<graph::PyGraph>()?;
m.add_class::<iterators::BFSSuccessors>()?;
Expand Down
177 changes: 174 additions & 3 deletions src/steiner_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@
// License for the specific language governing permissions and limitations
// under the License.

use std::cmp::Ordering;

use hashbrown::{HashMap, HashSet};
use rayon::prelude::*;

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::Python;

use petgraph::graph::NodeIndex;
use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::unionfind::UnionFind;
use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeIndexable};

use crate::generators::pairwise;
use crate::graph;
use crate::shortest_path::all_pairs_dijkstra::all_pairs_dijkstra_shortest_paths;

Expand Down Expand Up @@ -97,8 +103,9 @@ fn _metric_closure_edges(
not defined for a graph with unconnected nodes",
));
}
for (node, path) in paths {
let path_map = path.paths;
// Iterate over node indices for a deterministic order
for node in graph.graph.node_indices().map(|x| x.index()) {
let path_map = &paths[&node].paths;
nodes.remove(&node);
let distance = &distances[&node];
for v in &nodes {
Expand All @@ -113,3 +120,167 @@ fn _metric_closure_edges(
}
Ok(out_vec)
}

/// Return an approximation to the minimum Steiner tree of a graph.
///
/// The minimum Steiner tree of ``graph`` with regard to a set of
/// ``terminal_nodes`` is a tree within ``graph`` that spans those nodes and
/// has a minimum size (measured as the sum of edge weights) among all such
/// trees.
///
/// The minimum Steiner tree can be approximated by computing the minimum
/// spanning tree of the subgraph of the metric closure of ``graph`` induced
/// by the terminal nodes, where the metric closure of ``graph`` is the
/// complete graph in which each edge is weighted by the shortest path distance
/// between nodes in ``graph``.
///
/// This algorithm [1]_ produces a tree whose weight is within a
/// :math:`(2 - (2 / t))` factor of the weight of the optimal Steiner tree
/// where :math:`t` is the number of terminal nodes.
///
/// :param PyGraph graph: The graph to compute the minimum Steiner tree for
/// :param list terminal_nodes: The list of node indices for which the Steiner
///     tree is to be computed. If only a single terminal node is given the
///     returned graph contains just that node.
/// :param weight_fn: A callable object that will be passed an edge's
///     weight/data payload and expected to return a ``float``. For example,
///     you can use ``weight_fn=float`` to cast every weight as a float.
///
/// :returns: An approximation to the minimal Steiner tree of ``graph`` induced
///     by ``terminal_nodes``.
/// :rtype: PyGraph
/// :raises ValueError: If a shortest path distance between two terminal
///     nodes is ``NaN``.
///
/// .. [1] Kou, Markowsky & Berman,
///    "A fast algorithm for Steiner trees"
///    Acta Informatica 15, 141–145 (1981).
///    https://link.springer.com/article/10.1007/BF00288961
#[pyfunction]
#[pyo3(text_signature = "(graph, terminal_nodes, weight_fn, /)")]
pub fn steiner_tree(
    py: Python,
    graph: &graph::PyGraph,
    terminal_nodes: Vec<usize>,
    weight_fn: PyObject,
) -> PyResult<graph::PyGraph> {
    let terminal_node_set: HashSet<usize> =
        terminal_nodes.into_iter().collect();
    let metric_edges =
        _metric_closure_edges(py, graph, weight_fn.clone_ref(py))?;
    // Calculate mst edges from metric closure edge list:
    let mut subgraphs = UnionFind::<usize>::new(graph.graph.node_bound());
    let mut edge_list: Vec<MetricClosureEdge> =
        Vec::with_capacity(metric_edges.len());
    for edge in metric_edges {
        // Only metric closure edges joining two terminals take part in the
        // spanning tree computation.
        if !terminal_node_set.contains(&edge.source)
            || !terminal_node_set.contains(&edge.target)
        {
            continue;
        }
        if edge.distance.is_nan() {
            return Err(PyValueError::new_err("NaN found as an edge weight"));
        }
        edge_list.push(edge);
    }
    // Break ties on distance with the endpoints so that the edge order (and
    // therefore the resulting tree) does not depend on the unstable sort.
    // NaN distances were rejected above, so the fallback is never used.
    edge_list.par_sort_unstable_by(|a, b| {
        let weight_a = (a.distance, a.source, a.target);
        let weight_b = (b.distance, b.source, b.target);
        weight_a.partial_cmp(&weight_b).unwrap_or(Ordering::Less)
    });
    // Kruskal: keep an edge only if it joins two different components.
    let mut mst_edges: Vec<MetricClosureEdge> = Vec::new();
    for float_edge_pair in edge_list {
        let u = float_edge_pair.source;
        let v = float_edge_pair.target;
        if subgraphs.union(u, v) {
            mst_edges.push(float_edge_pair);
        }
    }
    // Generate the output graph from the MST of the metric closure by
    // expanding every metric edge back into the consecutive node pairs of
    // the shortest path it stands for.
    let out_edge_list: Vec<[usize; 2]> = mst_edges
        .iter()
        .flat_map(|edge| pairwise(edge.path.clone()))
        .filter_map(|x| x.0.map(|a| [a, x.1]))
        .collect();
    let out_edges: HashSet<(usize, usize)> =
        out_edge_list.iter().map(|x| (x[0], x[1])).collect();
    let mut out_graph = graph.clone();
    // Every node on a selected path is kept. The terminals are added
    // explicitly so that a single terminal node (which produces no metric
    // closure edge at all) is still part of the returned tree.
    let out_nodes: HashSet<NodeIndex> = out_edge_list
        .iter()
        .flat_map(|x| x.iter())
        .copied()
        .chain(terminal_node_set.iter().copied())
        .map(NodeIndex::new)
        .collect();
    for node in graph
        .graph
        .node_indices()
        .filter(|node| !out_nodes.contains(node))
    {
        out_graph.graph.remove_node(node);
        out_graph.node_removed = true;
    }
    // Drop every remaining edge that is not on one of the selected paths;
    // the paths are undirected so both orientations are checked.
    for edge in graph.graph.edge_references().filter(|edge| {
        let source = edge.source().index();
        let target = edge.target().index();
        !out_edges.contains(&(source, target))
            && !out_edges.contains(&(target, source))
    }) {
        out_graph.graph.remove_edge(edge.id());
    }
    // Deduplicate potential duplicate edges
    deduplicate_edges(py, &mut out_graph, &weight_fn)?;
    Ok(out_graph)
}

/// Remove parallel edges from ``out_graph`` keeping only the lightest one.
///
/// Nothing is done unless ``out_graph`` is a multigraph. Otherwise the edges
/// are grouped by their (unordered) endpoint pair and, for every pair that is
/// connected by more than one edge, ``weight_fn`` is called on each edge
/// payload and all but the edge with the smallest weight are removed.
///
/// An error raised by ``weight_fn``, or a return value that can not be
/// converted to a float, is propagated to the caller.
fn deduplicate_edges(
    py: Python,
    out_graph: &mut graph::PyGraph,
    weight_fn: &PyObject,
) -> PyResult<()> {
    if !out_graph.multigraph {
        return Ok(());
    }
    // Group all edges by endpoint pair. The key is normalized so that both
    // orientations of an undirected edge land in the same bucket with a
    // single hash lookup.
    let mut duplicate_map: HashMap<
        [NodeIndex; 2],
        Vec<(EdgeIndex, PyObject)>,
    > = HashMap::new();
    for edge in out_graph.graph.edge_references() {
        let (source, target) = (edge.source(), edge.target());
        let key = if source <= target {
            [source, target]
        } else {
            [target, source]
        };
        duplicate_map
            .entry(key)
            .or_insert_with(Vec::new)
            .push((edge.id(), edge.weight().clone_ref(py)));
    }
    // For a node pair with > 1 edge find minimum edge and remove others
    for edges_raw in duplicate_map.values().filter(|x| x.len() > 1) {
        // Evaluate every weight of the group before touching the graph so a
        // failing weight_fn never leaves the group half processed.
        let mut edges: Vec<(EdgeIndex, f64)> =
            Vec::with_capacity(edges_raw.len());
        for (index, payload) in edges_raw {
            let res = weight_fn.call1(py, (payload,))?;
            let weight: f64 = res.extract(py)?;
            edges.push((*index, weight));
        }
        // A linear scan is enough to find the lightest edge. The group has
        // at least two entries, so indexing the first one can not panic.
        let mut lightest = edges[0];
        for candidate in &edges[1..] {
            if candidate.1 < lightest.1 {
                lightest = *candidate;
            }
        }
        for (index, _) in edges {
            if index != lightest.0 {
                out_graph.graph.remove_edge(index);
            }
        }
    }
    Ok(())
}
80 changes: 78 additions & 2 deletions tests/graph/test_steiner_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import retworkx


class TestMetricClosure(unittest.TestCase):
class TestSteinerTree(unittest.TestCase):
def setUp(self):
self.graph = retworkx.PyGraph()
self.graph = retworkx.PyGraph(multigraph=False)
self.graph.add_node(None)
self.graph.extend_from_weighted_edge_list(
[
Expand Down Expand Up @@ -112,3 +112,79 @@ def test_metric_closure_empty_graph(self):
graph = retworkx.PyGraph()
closure = retworkx.metric_closure(graph, weight_fn=float)
self.assertEqual([], closure.weighted_edge_list())

def test_steiner_graph(self):
    # The Steiner tree of the shared fixture graph for terminals 1-5 must
    # contain each of the edges listed below.
    terminals = [1, 2, 3, 4, 5]
    result = retworkx.steiner_tree(self.graph, terminals, weight_fn=float)
    found_edges = result.weighted_edge_list()
    for wanted_edge in [
        (1, 2, 10),
        (2, 3, 10),
        (2, 7, 1),
        (3, 4, 10),
        (7, 5, 1),
    ]:
        self.assertIn(wanted_edge, found_edges)

def test_steiner_graph_multigraph(self):
    # Build a default PyGraph in which nodes 2 and 3 are joined by two
    # parallel edges with very different weights (999 and 1).
    graph = retworkx.PyGraph()
    graph.extend_from_weighted_edge_list(
        [
            (1, 2, 1),
            (2, 3, 999),
            (2, 3, 1),
            (3, 4, 1),
            (3, 5, 1),
        ]
    )
    # Node 0 is not referenced by any edge above; drop it before computing
    # the tree.
    graph.remove_node(0)
    result = retworkx.steiner_tree(graph, [2, 4, 5], weight_fn=float)
    found_edges = result.weighted_edge_list()
    for wanted_edge in [(2, 3, 1), (3, 4, 1), (3, 5, 1)]:
        self.assertIn(wanted_edge, found_edges)

def test_not_connected_steiner_tree(self):
    # Adding a node without any edge makes the fixture graph unconnected,
    # which steiner_tree() is expected to reject with a ValueError.
    self.graph.add_node(None)
    terminals = [0, 1, 2]
    with self.assertRaises(ValueError):
        retworkx.steiner_tree(self.graph, terminals, weight_fn=float)

def test_steiner_tree_empty_graph(self):
    # An empty graph with no terminals yields a tree without any edges.
    empty_graph = retworkx.PyGraph()
    result = retworkx.steiner_tree(empty_graph, [], weight_fn=float)
    self.assertEqual([], result.weighted_edge_list())

def test_equal_distance_graph(self):
    # Regression test for a graph with many equal shortest path distances:
    # the result must be acyclic and match the exact edge list below.
    n = 3
    graph = retworkx.PyGraph()
    graph.add_nodes_from(range(n + 5))
    # Five edges of weight 0.5 among the nodes n .. n + 4.
    half_weight_edges = [
        (n, n + 1, 0.5),
        (n, n + 2, 0.5),
        (n + 1, n + 2, 0.5),
        (n, n + 3, 0.5),
        (n + 1, n + 4, 0.5),
    ]
    graph.add_edges_from(half_weight_edges)
    # Each of the first n nodes hangs off node n + 2 with weight 2.
    spoke_edges = [(i, n + 2, 2) for i in range(n)]
    graph.add_edges_from(spoke_edges)
    terminals = list(range(5))
    terminals += [n + 3, n + 4]
    tree = retworkx.steiner_tree(graph, terminals, weight_fn=float)
    # Assert no cycle
    self.assertEqual(retworkx.cycle_basis(tree), [])
    expected_edges = [
        (3, 4, 0.5),
        (3, 5, 0.5),
        (3, 6, 0.5),
        (4, 7, 0.5),
        (0, 5, 2),
        (1, 5, 2),
        (2, 5, 2),
    ]
    self.assertEqual(tree.weighted_edge_list(), expected_edges)

0 comments on commit 05ae8b3

Please sign in to comment.