forked from Qiskit/rustworkx
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit adds a metric closure function to retworkx. The metric closure of a graph is the complete graph in which each edge is weighted by the shortest path distance between the nodes in the graph. This is the first step towards implementing the minimum steiner graph function because the first step for that is to compute the metric closure for the graph. Additionally, this function is entirely self contained to a new rust module steiner_tree.rs (which we'll eventually add the steiner_tree() function too) in the interest of showing the eventual organizational structure we want for Qiskit#300 (although having an algorithms namespace might make sense for Qiskit#300). In a separate PR addressing Qiskit#300 we should start to reorganize the rust code in lib.rs like this to make things easier to find. Partially implements Qiskit#389
- Loading branch information
Showing
5 changed files
with
225 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
--- | ||
features: | ||
- | | ||
Added a new function, :func:`~retworkx.metric_closure`, which is used | ||
to generate the metric closure of a given graph. This function is used in | ||
the computation of for calculating the Steiner Tree using the algorithm | ||
from | ||
`Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees". | ||
<https://link.springer.com/article/10.1007/BF00288961>`__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
// Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
// not use this file except in compliance with the License. You may obtain | ||
// a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
// License for the specific language governing permissions and limitations | ||
// under the License. | ||
|
||
// This module was originally copied and forked from the upstream petgraph | ||
// repository, specifically: | ||
// https://github.com/petgraph/petgraph/blob/0.5.1/src/dijkstra.rs | ||
// this was necessary to modify the error handling to allow python callables | ||
// to be use for the input functions for edge_cost and return any exceptions | ||
// raised in Python instead of panicking | ||
|
||
use hashbrown::{HashMap, HashSet}; | ||
|
||
use pyo3::exceptions::PyValueError; | ||
use pyo3::prelude::*; | ||
use pyo3::Python; | ||
|
||
use petgraph::graph::NodeIndex; | ||
|
||
use crate::_all_pairs_dijkstra_shortest_paths; | ||
use crate::graph; | ||
|
||
#[derive(Debug)] | ||
struct MetricClosureEdge { | ||
source: usize, | ||
target: usize, | ||
distance: f64, | ||
path: Vec<usize>, | ||
} | ||
|
||
/// Return the metric closure of a graph | ||
/// | ||
/// The metric closure of a graph is the complete graph in which each edge is | ||
/// weighted by the shortest path distance between the nodes in the graph. | ||
/// | ||
/// :param PyGraph graph: The input graph to find the metric closure for | ||
/// :param weight_fn: A callable object that will be passed an edge's | ||
/// weight/data payload and expected to return a ``float``. For example, | ||
/// you can use ``weight_fn=float`` to cast every weight as a float | ||
/// | ||
/// :return: A metric closure graph from the input graph | ||
/// :rtype: PyGraph | ||
#[pyfunction] | ||
#[pyo3(text_signature = "(graph, weight_fn, /)")] | ||
pub fn metric_closure( | ||
py: Python, | ||
graph: &graph::PyGraph, | ||
weight_fn: PyObject, | ||
) -> PyResult<graph::PyGraph> { | ||
let mut out_graph = graph.clone(); | ||
out_graph.graph.clear_edges(); | ||
let edges = _metric_closure_edges(py, graph, weight_fn)?; | ||
for edge in edges { | ||
out_graph.graph.add_edge( | ||
NodeIndex::new(edge.source), | ||
NodeIndex::new(edge.target), | ||
(edge.distance, edge.path).to_object(py), | ||
); | ||
} | ||
Ok(out_graph) | ||
} | ||
|
||
fn _metric_closure_edges( | ||
py: Python, | ||
graph: &graph::PyGraph, | ||
weight_fn: PyObject, | ||
) -> PyResult<Vec<MetricClosureEdge>> { | ||
let mut out_vec = Vec::with_capacity(graph.graph.edge_count()); | ||
let mut distances = HashMap::with_capacity(graph.graph.node_count()); | ||
let paths = _all_pairs_dijkstra_shortest_paths( | ||
py, | ||
&graph.graph, | ||
weight_fn, | ||
Some(&mut distances), | ||
)? | ||
.paths; | ||
let path_keys: HashSet<usize> = paths | ||
.keys() | ||
.filter(|x| !paths[x].paths.is_empty()) | ||
.copied() | ||
.collect(); | ||
let mut nodes: HashSet<usize> = | ||
graph.graph.node_indices().map(|x| x.index()).collect(); | ||
if nodes.difference(&path_keys).count() > 0 { | ||
return Err(PyValueError::new_err("The input graph must be a connected graph. The metric closure is not defined for a graph with unconnected nodes")); | ||
} | ||
for node in graph.graph.node_indices().map(|x| x.index()) { | ||
let path_map = &paths[&node].paths; | ||
nodes.remove(&node); | ||
let distance = &distances[&node]; | ||
for v in &nodes { | ||
let v_index = NodeIndex::new(*v); | ||
out_vec.push(MetricClosureEdge { | ||
source: node, | ||
target: *v, | ||
distance: distance[&v_index], | ||
path: path_map[v].clone(), | ||
}); | ||
} | ||
} | ||
Ok(out_vec) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import pprint | ||
import unittest | ||
|
||
import retworkx | ||
|
||
|
||
class TestMetricClosure(unittest.TestCase): | ||
def setUp(self): | ||
self.graph = retworkx.PyGraph() | ||
self.graph.add_node(None) | ||
self.graph.extend_from_weighted_edge_list( | ||
[ | ||
(1, 2, 10), | ||
(2, 3, 10), | ||
(3, 4, 10), | ||
(4, 5, 10), | ||
(5, 6, 10), | ||
(2, 7, 1), | ||
(7, 5, 1), | ||
] | ||
) | ||
self.graph.remove_node(0) | ||
|
||
def test_metric_closure(self): | ||
closure_graph = retworkx.metric_closure(self.graph, weight_fn=float) | ||
expected_edges = [ | ||
(1, 2, (10.0, [1, 2])), | ||
(1, 3, (20.0, [1, 2, 3])), | ||
(1, 4, (22.0, [1, 2, 7, 5, 4])), | ||
(1, 5, (12.0, [1, 2, 7, 5])), | ||
(1, 6, (22.0, [1, 2, 7, 5, 6])), | ||
(1, 7, (11.0, [1, 2, 7])), | ||
(2, 3, (10.0, [2, 3])), | ||
(2, 4, (12.0, [2, 7, 5, 4])), | ||
(2, 5, (2.0, [2, 7, 5])), | ||
(2, 6, (12, [2, 7, 5, 6])), | ||
(2, 7, (1.0, [2, 7])), | ||
(3, 4, (10.0, [3, 4])), | ||
(3, 5, (12.0, [3, 2, 7, 5])), | ||
(3, 6, (22.0, [3, 2, 7, 5, 6])), | ||
(3, 7, (11.0, [3, 2, 7])), | ||
(4, 5, (10.0, [4, 5])), | ||
(4, 6, (20.0, [4, 5, 6])), | ||
(4, 7, (11.0, [4, 5, 7])), | ||
(5, 6, (10.0, [5, 6])), | ||
(5, 7, (1.0, [5, 7])), | ||
(6, 7, (11.0, [6, 5, 7])), | ||
] | ||
edges = list(closure_graph.weighted_edge_list()) | ||
for edge in expected_edges: | ||
found = False | ||
if edge in edges: | ||
found = True | ||
if not found: | ||
|
||
if ( | ||
edge[1], | ||
edge[0], | ||
(edge[2][0], list(reversed(edge[2][1]))), | ||
) in edges: | ||
found = True | ||
if not found: | ||
self.fail( | ||
f"edge: {edge} nor it's reverse not found in metric " | ||
f"closure output:\n{pprint.pformat(edges)}" | ||
) | ||
|
||
def test_not_connected_metric_closure(self): | ||
self.graph.add_node(None) | ||
with self.assertRaises(ValueError): | ||
retworkx.metric_closure(self.graph, weight_fn=float) |