forked from Qiskit/rustworkx
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add metric closure function (Qiskit#390)
* Add metric closure function This commit adds a metric closure function to retworkx. The metric closure of a graph is the complete graph in which each edge is weighted by the shortest path distance between the nodes in the graph. This is the first step towards implementing the minimum steiner graph function because the first step for that is to compute the metric closure for the graph. Additionally, this function is entirely self contained to a new rust module steiner_tree.rs (which we'll eventually add the steiner_tree() function too) in the interest of showing the eventual organizational structure we want for Qiskit#300 (although having an algorithms namespace might make sense for Qiskit#300). In a separate PR addressing Qiskit#300 we should start to reorganize the rust code in lib.rs like this to make things easier to find. Partially implements Qiskit#389 * Move docs to Other Algorithms section * Fix allocation size for output vector * Rework disconnected graph logic The previous check for a disconnected graph only caught graphs with completely disconnected nodes and didn't detect all disconnected graphs. This commit fixes the logic by check the first node in the graph has a path to all other nodes. * Fix handling of empty graph to not overflow * Update src/lib.rs Co-authored-by: Ivan Carvalho <[email protected]> * Rework final loop to be over hashmap instead of indices Co-authored-by: Ivan Carvalho <[email protected]>
- Loading branch information
1 parent
4b0aafa
commit 0919144
Showing
5 changed files
with
264 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
--- | ||
features: | ||
- | | ||
Added a new function, :func:`~retworkx.metric_closure`, which is used | ||
to generate the metric closure of a given graph. This function is used in | ||
the computation of for calculating the Steiner Tree using the algorithm | ||
from | ||
`Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees". | ||
<https://link.springer.com/article/10.1007/BF00288961>`__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
// Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
// not use this file except in compliance with the License. You may obtain | ||
// a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
// License for the specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use hashbrown::{HashMap, HashSet}; | ||
|
||
use pyo3::exceptions::PyValueError; | ||
use pyo3::prelude::*; | ||
use pyo3::Python; | ||
|
||
use petgraph::graph::NodeIndex; | ||
|
||
use crate::_all_pairs_dijkstra_shortest_paths; | ||
use crate::graph; | ||
|
||
struct MetricClosureEdge { | ||
source: usize, | ||
target: usize, | ||
distance: f64, | ||
path: Vec<usize>, | ||
} | ||
|
||
/// Return the metric closure of a graph | ||
/// | ||
/// The metric closure of a graph is the complete graph in which each edge is | ||
/// weighted by the shortest path distance between the nodes in the graph. | ||
/// | ||
/// :param PyGraph graph: The input graph to find the metric closure for | ||
/// :param weight_fn: A callable object that will be passed an edge's | ||
/// weight/data payload and expected to return a ``float``. For example, | ||
/// you can use ``weight_fn=float`` to cast every weight as a float | ||
/// | ||
/// :return: A metric closure graph from the input graph | ||
/// :rtype: PyGraph | ||
#[pyfunction] | ||
#[pyo3(text_signature = "(graph, weight_fn, /)")] | ||
pub fn metric_closure( | ||
py: Python, | ||
graph: &graph::PyGraph, | ||
weight_fn: PyObject, | ||
) -> PyResult<graph::PyGraph> { | ||
let mut out_graph = graph.clone(); | ||
out_graph.graph.clear_edges(); | ||
let edges = _metric_closure_edges(py, graph, weight_fn)?; | ||
for edge in edges { | ||
out_graph.graph.add_edge( | ||
NodeIndex::new(edge.source), | ||
NodeIndex::new(edge.target), | ||
(edge.distance, edge.path).to_object(py), | ||
); | ||
} | ||
Ok(out_graph) | ||
} | ||
|
||
fn _metric_closure_edges( | ||
py: Python, | ||
graph: &graph::PyGraph, | ||
weight_fn: PyObject, | ||
) -> PyResult<Vec<MetricClosureEdge>> { | ||
let node_count = graph.graph.node_count(); | ||
if node_count == 0 { | ||
return Ok(Vec::new()); | ||
} | ||
let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2); | ||
let mut distances = HashMap::with_capacity(graph.graph.node_count()); | ||
let paths = _all_pairs_dijkstra_shortest_paths( | ||
py, | ||
&graph.graph, | ||
weight_fn, | ||
Some(&mut distances), | ||
)? | ||
.paths; | ||
let mut nodes: HashSet<usize> = | ||
graph.graph.node_indices().map(|x| x.index()).collect(); | ||
let first_node = graph | ||
.graph | ||
.node_indices() | ||
.map(|x| x.index()) | ||
.next() | ||
.unwrap(); | ||
let path_keys: HashSet<usize> = | ||
paths[&first_node].paths.keys().copied().collect(); | ||
// first_node will always be missing from path_keys so if the difference | ||
// is > 1 with nodes that means there is another node in the graph that | ||
// first_node doesn't have a path to. | ||
if nodes.difference(&path_keys).count() > 1 { | ||
return Err(PyValueError::new_err( | ||
"The input graph must be a connected graph. The metric closure is \ | ||
not defined for a graph with unconnected nodes", | ||
)); | ||
} | ||
for (node, path) in paths { | ||
let path_map = path.paths; | ||
nodes.remove(&node); | ||
let distance = &distances[&node]; | ||
for v in &nodes { | ||
let v_index = NodeIndex::new(*v); | ||
out_vec.push(MetricClosureEdge { | ||
source: node, | ||
target: *v, | ||
distance: distance[&v_index], | ||
path: path_map[v].clone(), | ||
}); | ||
} | ||
} | ||
Ok(out_vec) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
# not use this file except in compliance with the License. You may obtain | ||
# a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
# License for the specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import pprint | ||
import unittest | ||
|
||
import retworkx | ||
|
||
|
||
class TestMetricClosure(unittest.TestCase): | ||
def setUp(self): | ||
self.graph = retworkx.PyGraph() | ||
self.graph.add_node(None) | ||
self.graph.extend_from_weighted_edge_list( | ||
[ | ||
(1, 2, 10), | ||
(2, 3, 10), | ||
(3, 4, 10), | ||
(4, 5, 10), | ||
(5, 6, 10), | ||
(2, 7, 1), | ||
(7, 5, 1), | ||
] | ||
) | ||
self.graph.remove_node(0) | ||
|
||
def test_metric_closure(self): | ||
closure_graph = retworkx.metric_closure(self.graph, weight_fn=float) | ||
expected_edges = [ | ||
(1, 2, (10.0, [1, 2])), | ||
(1, 3, (20.0, [1, 2, 3])), | ||
(1, 4, (22.0, [1, 2, 7, 5, 4])), | ||
(1, 5, (12.0, [1, 2, 7, 5])), | ||
(1, 6, (22.0, [1, 2, 7, 5, 6])), | ||
(1, 7, (11.0, [1, 2, 7])), | ||
(2, 3, (10.0, [2, 3])), | ||
(2, 4, (12.0, [2, 7, 5, 4])), | ||
(2, 5, (2.0, [2, 7, 5])), | ||
(2, 6, (12, [2, 7, 5, 6])), | ||
(2, 7, (1.0, [2, 7])), | ||
(3, 4, (10.0, [3, 4])), | ||
(3, 5, (12.0, [3, 2, 7, 5])), | ||
(3, 6, (22.0, [3, 2, 7, 5, 6])), | ||
(3, 7, (11.0, [3, 2, 7])), | ||
(4, 5, (10.0, [4, 5])), | ||
(4, 6, (20.0, [4, 5, 6])), | ||
(4, 7, (11.0, [4, 5, 7])), | ||
(5, 6, (10.0, [5, 6])), | ||
(5, 7, (1.0, [5, 7])), | ||
(6, 7, (11.0, [6, 5, 7])), | ||
] | ||
edges = list(closure_graph.weighted_edge_list()) | ||
for edge in expected_edges: | ||
found = False | ||
if edge in edges: | ||
found = True | ||
if not found: | ||
|
||
if ( | ||
edge[1], | ||
edge[0], | ||
(edge[2][0], list(reversed(edge[2][1]))), | ||
) in edges: | ||
found = True | ||
if not found: | ||
self.fail( | ||
f"edge: {edge} nor it's reverse not found in metric " | ||
f"closure output:\n{pprint.pformat(edges)}" | ||
) | ||
|
||
def test_not_connected_metric_closure(self): | ||
self.graph.add_node(None) | ||
with self.assertRaises(ValueError): | ||
retworkx.metric_closure(self.graph, weight_fn=float) | ||
|
||
def test_partially_connected_metric_closure(self): | ||
graph = retworkx.PyGraph() | ||
graph.add_node(None) | ||
graph.extend_from_weighted_edge_list( | ||
[ | ||
(1, 2, 10), | ||
(2, 3, 10), | ||
(3, 4, 10), | ||
(4, 5, 10), | ||
(5, 6, 10), | ||
(2, 7, 1), | ||
(7, 5, 1), | ||
] | ||
) | ||
graph.extend_from_weighted_edge_list( | ||
[ | ||
(0, 8, 20), | ||
(0, 9, 20), | ||
(0, 10, 20), | ||
(8, 10, 10), | ||
(9, 10, 5), | ||
] | ||
) | ||
with self.assertRaises(ValueError): | ||
retworkx.metric_closure(graph, weight_fn=float) | ||
|
||
def test_metric_closure_empty_graph(self): | ||
graph = retworkx.PyGraph() | ||
closure = retworkx.metric_closure(graph, weight_fn=float) | ||
self.assertEqual([], closure.weighted_edge_list()) |