Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metric closure function #390

Merged
merged 10 commits into from
Aug 4, 2021
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ Other Algorithm Functions
retworkx.core_number
retworkx.graph_greedy_color
retworkx.digraph_union
retworkx.metric_closure

Generators
==========
Expand Down
9 changes: 9 additions & 0 deletions releasenotes/notes/steiner_tree-3e5282c65095868a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Added a new function, :func:`~retworkx.metric_closure`, which is used
to generate the metric closure of a given graph. This function is used in
the computation of for calculating the Steiner Tree using the algorithm
from
`Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees".
<https://link.springer.com/article/10.1007/BF00288961>`__
31 changes: 25 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ mod iterators;
mod k_shortest_path;
mod layout;
mod max_weight_matching;
mod steiner_tree;
mod union;

use std::cmp::{Ordering, Reverse};
use std::collections::{BTreeSet, BinaryHeap};
use std::sync::RwLock;

use hashbrown::{HashMap, HashSet};

Expand Down Expand Up @@ -65,6 +67,7 @@ use crate::iterators::{
NodesCountMapping, PathLengthMapping, PathMapping, Pos2DMapping,
WeightedEdgeList,
};
use crate::steiner_tree::*;

trait NodesRemoved {
fn nodes_removed(&self) -> bool;
Expand Down Expand Up @@ -2575,10 +2578,11 @@ fn _all_pairs_dijkstra_path_lengths<Ty: EdgeType + Sync>(
})
}

fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
pub fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
py: Python,
graph: &StableGraph<PyObject, PyObject, Ty>,
edge_cost_fn: PyObject,
distances: Option<&mut HashMap<usize, HashMap<NodeIndex, f64>>>,
) -> PyResult<AllPairsPathMapping> {
if graph.node_count() == 0 {
return Ok(AllPairsPathMapping {
Expand Down Expand Up @@ -2622,20 +2626,30 @@ fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
}
};
let node_indices: Vec<NodeIndex> = graph.node_indices().collect();
Ok(AllPairsPathMapping {
let temp_distances: RwLock<HashMap<usize, HashMap<NodeIndex, f64>>> =
if distances.is_some() {
RwLock::new(HashMap::with_capacity(graph.node_count()))
} else {
// Avoid extra allocation if HashMap isn't used
RwLock::new(HashMap::new())
};
let out_map = AllPairsPathMapping {
paths: node_indices
.into_par_iter()
.map(|x| {
let mut paths: HashMap<NodeIndex, Vec<NodeIndex>> =
HashMap::with_capacity(graph.node_count());
dijkstra::dijkstra(
let distance = dijkstra::dijkstra(
graph,
x,
None,
|e| edge_cost(e.id()),
Some(&mut paths),
)
.unwrap();
if distances.is_some() {
temp_distances.write().unwrap().insert(x.index(), distance);
}
let index = x.index();
let out_paths = PathMapping {
paths: paths
Expand All @@ -2660,7 +2674,11 @@ fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
(index, out_paths)
})
.collect(),
})
};
if let Some(x) = distances {
*x = temp_distances.read().unwrap().clone()
};
Ok(out_map)
}

/// Calculate the the shortest length from all nodes in a
Expand Down Expand Up @@ -2734,7 +2752,7 @@ pub fn digraph_all_pairs_dijkstra_shortest_paths(
graph: &digraph::PyDiGraph,
edge_cost_fn: PyObject,
) -> PyResult<AllPairsPathMapping> {
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn)
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn, None)
}

/// Calculate the the shortest length from all nodes in a
Expand Down Expand Up @@ -2800,7 +2818,7 @@ pub fn graph_all_pairs_dijkstra_shortest_paths(
graph: &graph::PyGraph,
edge_cost_fn: PyObject,
) -> PyResult<AllPairsPathMapping> {
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn)
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn, None)
}

/// Compute the A* shortest path for a PyGraph
Expand Down Expand Up @@ -4983,6 +5001,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(digraph_spring_layout))?;
m.add_wrapped(wrap_pyfunction!(digraph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(graph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(metric_closure))?;
m.add_class::<digraph::PyDiGraph>()?;
m.add_class::<graph::PyGraph>()?;
m.add_class::<iterators::BFSSuccessors>()?;
Expand Down
111 changes: 111 additions & 0 deletions src/steiner_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use hashbrown::{HashMap, HashSet};

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::Python;

use petgraph::graph::NodeIndex;

use crate::_all_pairs_dijkstra_shortest_paths;
use crate::graph;

struct MetricClosureEdge {
source: usize,
target: usize,
distance: f64,
path: Vec<usize>,
}

/// Return the metric closure of a graph
///
/// The metric closure of a graph is the complete graph in which each edge is
/// weighted by the shortest path distance between the nodes in the graph.
///
/// :param PyGraph graph: The input graph to find the metric closure for
/// :param weight_fn: A callable object that will be passed an edge's
/// weight/data payload and expected to return a ``float``. For example,
/// you can use ``weight_fn=float`` to cast every weight as a float
///
/// :return: A metric closure graph from the input graph
/// :rtype: PyGraph
#[pyfunction]
#[pyo3(text_signature = "(graph, weight_fn, /)")]
pub fn metric_closure(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
) -> PyResult<graph::PyGraph> {
let mut out_graph = graph.clone();
out_graph.graph.clear_edges();
let edges = _metric_closure_edges(py, graph, weight_fn)?;
for edge in edges {
out_graph.graph.add_edge(
NodeIndex::new(edge.source),
NodeIndex::new(edge.target),
(edge.distance, edge.path).to_object(py),
);
}
Ok(out_graph)
}

fn _metric_closure_edges(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
) -> PyResult<Vec<MetricClosureEdge>> {
let node_count = graph.graph.node_count();
let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2);
let mut distances = HashMap::with_capacity(graph.graph.node_count());
let paths = _all_pairs_dijkstra_shortest_paths(
py,
&graph.graph,
weight_fn,
Some(&mut distances),
)?
.paths;
let mut nodes: HashSet<usize> =
graph.graph.node_indices().map(|x| x.index()).collect();
let first_node = match graph.graph.node_indices().map(|x| x.index()).next()
{
Some(node) => node,
None => return Ok(Vec::new()),
};
let path_keys: HashSet<usize> =
paths[&first_node].paths.keys().copied().collect();
// first_node will always be missing from path_keys so if the difference
// is > 1 with nodes that means there is another node in the graph that
// first_node doesn't have a path to.
if nodes.difference(&path_keys).count() > 1 {
return Err(PyValueError::new_err(
"The input graph must be a connected graph. The metric closure is \
not defined for a graph with unconnected nodes",
));
}
for node in graph.graph.node_indices().map(|x| x.index()) {
let path_map = &paths[&node].paths;
nodes.remove(&node);
let distance = &distances[&node];
for v in &nodes {
let v_index = NodeIndex::new(*v);
out_vec.push(MetricClosureEdge {
source: node,
target: *v,
distance: distance[&v_index],
path: path_map[v].clone(),
});
}
}
Ok(out_vec)
}
114 changes: 114 additions & 0 deletions tests/graph/test_steiner_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import pprint
import unittest

import retworkx


class TestMetricClosure(unittest.TestCase):
def setUp(self):
self.graph = retworkx.PyGraph()
self.graph.add_node(None)
self.graph.extend_from_weighted_edge_list(
[
(1, 2, 10),
(2, 3, 10),
(3, 4, 10),
(4, 5, 10),
(5, 6, 10),
(2, 7, 1),
(7, 5, 1),
]
)
self.graph.remove_node(0)

def test_metric_closure(self):
closure_graph = retworkx.metric_closure(self.graph, weight_fn=float)
expected_edges = [
(1, 2, (10.0, [1, 2])),
(1, 3, (20.0, [1, 2, 3])),
(1, 4, (22.0, [1, 2, 7, 5, 4])),
(1, 5, (12.0, [1, 2, 7, 5])),
(1, 6, (22.0, [1, 2, 7, 5, 6])),
(1, 7, (11.0, [1, 2, 7])),
(2, 3, (10.0, [2, 3])),
(2, 4, (12.0, [2, 7, 5, 4])),
(2, 5, (2.0, [2, 7, 5])),
(2, 6, (12, [2, 7, 5, 6])),
(2, 7, (1.0, [2, 7])),
(3, 4, (10.0, [3, 4])),
(3, 5, (12.0, [3, 2, 7, 5])),
(3, 6, (22.0, [3, 2, 7, 5, 6])),
(3, 7, (11.0, [3, 2, 7])),
(4, 5, (10.0, [4, 5])),
(4, 6, (20.0, [4, 5, 6])),
(4, 7, (11.0, [4, 5, 7])),
(5, 6, (10.0, [5, 6])),
(5, 7, (1.0, [5, 7])),
(6, 7, (11.0, [6, 5, 7])),
]
edges = list(closure_graph.weighted_edge_list())
for edge in expected_edges:
found = False
if edge in edges:
found = True
if not found:

if (
edge[1],
edge[0],
(edge[2][0], list(reversed(edge[2][1]))),
) in edges:
found = True
if not found:
self.fail(
f"edge: {edge} nor it's reverse not found in metric "
f"closure output:\n{pprint.pformat(edges)}"
)

def test_not_connected_metric_closure(self):
self.graph.add_node(None)
with self.assertRaises(ValueError):
retworkx.metric_closure(self.graph, weight_fn=float)

def test_partially_connected_metric_closure(self):
graph = retworkx.PyGraph()
graph.add_node(None)
graph.extend_from_weighted_edge_list(
[
(1, 2, 10),
(2, 3, 10),
(3, 4, 10),
(4, 5, 10),
(5, 6, 10),
(2, 7, 1),
(7, 5, 1),
]
)
graph.extend_from_weighted_edge_list(
[
(0, 8, 20),
(0, 9, 20),
(0, 10, 20),
(8, 10, 10),
(9, 10, 5),
]
)
with self.assertRaises(ValueError):
retworkx.metric_closure(graph, weight_fn=float)

def test_metric_closure_empty_graph(self):
graph = retworkx.PyGraph()
closure = retworkx.metric_closure(graph, weight_fn=float)
self.assertEqual([], closure.weighted_edge_list())