Skip to content

Commit

Permalink
Add metric closure function
Browse files Browse the repository at this point in the history
This commit adds a metric closure function to retworkx. The metric
closure of a graph is the complete graph in which each edge is weighted
by the shortest path distance between the nodes in the graph. This is
the first step towards implementing the minimum steiner graph function
because the first step for that is to compute the metric closure for the
graph.

Additionally, this function is entirely self contained to a new rust module
steiner_tree.rs (which we'll eventually add the steiner_tree() function
too) in the interest of showing the eventual organizational structure we
want for Qiskit#300 (although having an algorithms namespace might make sense
for Qiskit#300). In a separate PR addressing Qiskit#300 we should start to
reorganize the rust code in lib.rs like this to make things easier to
find.

Partially implements Qiskit#389
  • Loading branch information
mtreinish committed Jul 20, 2021
1 parent 9f24a93 commit 0eb166b
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 6 deletions.
10 changes: 10 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,16 @@ Connectivity and Cycles
retworkx.cycle_basis
retworkx.digraph_find_cycle

.. _approximations:

Approximations and Heuristics
-----------------------------

.. autosummary::
:toctree: stubs

retworkx.metric_closure

.. _other-algorithms:

Other Algorithm Functions
Expand Down
9 changes: 9 additions & 0 deletions releasenotes/notes/steiner_tree-3e5282c65095868a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Added a new function, :func:`~retworkx.metric_closure`, which is used
to generate the metric closure of a given graph. This function is used in
the computation of for calculating the Steiner Tree using the algorithm
from
`Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees".
<https://link.springer.com/article/10.1007/BF00288961>`__
31 changes: 25 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ mod iterators;
mod k_shortest_path;
mod layout;
mod max_weight_matching;
mod steiner_tree;
mod union;

use std::cmp::{Ordering, Reverse};
use std::collections::{BTreeSet, BinaryHeap};
use std::sync::RwLock;

use hashbrown::{HashMap, HashSet};

Expand Down Expand Up @@ -64,6 +66,7 @@ use crate::iterators::{
NodesCountMapping, PathLengthMapping, PathMapping, Pos2DMapping,
WeightedEdgeList,
};
use crate::steiner_tree::*;

trait NodesRemoved {
fn nodes_removed(&self) -> bool;
Expand Down Expand Up @@ -2473,10 +2476,11 @@ fn _all_pairs_dijkstra_path_lengths<Ty: EdgeType + Sync>(
})
}

fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
pub fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
py: Python,
graph: &StableGraph<PyObject, PyObject, Ty>,
edge_cost_fn: PyObject,
distances: Option<&mut HashMap<usize, HashMap<NodeIndex, f64>>>,
) -> PyResult<AllPairsPathMapping> {
if graph.node_count() == 0 {
return Ok(AllPairsPathMapping {
Expand Down Expand Up @@ -2520,20 +2524,30 @@ fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
}
};
let node_indices: Vec<NodeIndex> = graph.node_indices().collect();
Ok(AllPairsPathMapping {
let temp_distances: RwLock<HashMap<usize, HashMap<NodeIndex, f64>>> =
if distances.is_some() {
RwLock::new(HashMap::with_capacity(graph.node_count()))
} else {
// Avoid extra allocation if HashMap isn't used
RwLock::new(HashMap::new())
};
let out_map = AllPairsPathMapping {
paths: node_indices
.into_par_iter()
.map(|x| {
let mut paths: HashMap<NodeIndex, Vec<NodeIndex>> =
HashMap::with_capacity(graph.node_count());
dijkstra::dijkstra(
let distance = dijkstra::dijkstra(
graph,
x,
None,
|e| edge_cost(e.id()),
Some(&mut paths),
)
.unwrap();
if distances.is_some() {
temp_distances.write().unwrap().insert(x.index(), distance);
}
let index = x.index();
let out_paths = PathMapping {
paths: paths
Expand All @@ -2558,7 +2572,11 @@ fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
(index, out_paths)
})
.collect(),
})
};
if let Some(x) = distances {
*x = temp_distances.read().unwrap().clone()
};
Ok(out_map)
}

/// Calculate the the shortest length from all nodes in a
Expand Down Expand Up @@ -2632,7 +2650,7 @@ pub fn digraph_all_pairs_dijkstra_shortest_paths(
graph: &digraph::PyDiGraph,
edge_cost_fn: PyObject,
) -> PyResult<AllPairsPathMapping> {
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn)
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn, None)
}

/// Calculate the the shortest length from all nodes in a
Expand Down Expand Up @@ -2698,7 +2716,7 @@ pub fn graph_all_pairs_dijkstra_shortest_paths(
graph: &graph::PyGraph,
edge_cost_fn: PyObject,
) -> PyResult<AllPairsPathMapping> {
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn)
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn, None)
}

/// Compute the A* shortest path for a PyGraph
Expand Down Expand Up @@ -4879,6 +4897,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(digraph_spring_layout))?;
m.add_wrapped(wrap_pyfunction!(digraph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(graph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(metric_closure))?;
m.add_class::<digraph::PyDiGraph>()?;
m.add_class::<graph::PyGraph>()?;
m.add_class::<iterators::BFSSuccessors>()?;
Expand Down
110 changes: 110 additions & 0 deletions src/steiner_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

// This module was originally copied and forked from the upstream petgraph
// repository, specifically:
// https://github.com/petgraph/petgraph/blob/0.5.1/src/dijkstra.rs
// this was necessary to modify the error handling to allow python callables
// to be use for the input functions for edge_cost and return any exceptions
// raised in Python instead of panicking

use hashbrown::{HashMap, HashSet};

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::Python;

use petgraph::graph::NodeIndex;

use crate::_all_pairs_dijkstra_shortest_paths;
use crate::graph;

#[derive(Debug)]
struct MetricClosureEdge {
source: usize,
target: usize,
distance: f64,
path: Vec<usize>,
}

/// Return the metric closure of a graph
///
/// The metric closure of a graph is the complete graph in which each edge is
/// weighted by the shortest path distance between the nodes in the graph.
///
/// :param PyGraph graph: The input graph to find the metric closure for
/// :param weight_fn: A callable object that will be passed an edge's
/// weight/data payload and expected to return a ``float``. For example,
/// you can use ``weight_fn=float`` to cast every weight as a float
///
/// :return: A metric closure graph from the input graph
/// :rtype: PyGraph
#[pyfunction]
#[pyo3(text_signature = "(graph, weight_fn, /)")]
pub fn metric_closure(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
) -> PyResult<graph::PyGraph> {
let mut out_graph = graph.clone();
out_graph.graph.clear_edges();
let edges = _metric_closure_edges(py, graph, weight_fn)?;
for edge in edges {
out_graph.graph.add_edge(
NodeIndex::new(edge.source),
NodeIndex::new(edge.target),
(edge.distance, edge.path).to_object(py),
);
}
Ok(out_graph)
}

fn _metric_closure_edges(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
) -> PyResult<Vec<MetricClosureEdge>> {
let mut out_vec = Vec::with_capacity(graph.graph.edge_count());
let mut distances = HashMap::with_capacity(graph.graph.node_count());
let paths = _all_pairs_dijkstra_shortest_paths(
py,
&graph.graph,
weight_fn,
Some(&mut distances),
)?
.paths;
let path_keys: HashSet<usize> = paths
.keys()
.filter(|x| !paths[x].paths.is_empty())
.copied()
.collect();
let mut nodes: HashSet<usize> =
graph.graph.node_indices().map(|x| x.index()).collect();
if nodes.difference(&path_keys).count() > 0 {
return Err(PyValueError::new_err("The input graph must be a connected graph. The metric closure is not defined for a graph with unconnected nodes"));
}
for node in graph.graph.node_indices().map(|x| x.index()) {
let path_map = &paths[&node].paths;
nodes.remove(&node);
let distance = &distances[&node];
for v in &nodes {
let v_index = NodeIndex::new(*v);
out_vec.push(MetricClosureEdge {
source: node,
target: *v,
distance: distance[&v_index],
path: path_map[v].clone(),
});
}
}
Ok(out_vec)
}
71 changes: 71 additions & 0 deletions tests/graph/test_steiner_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pprint
import unittest

import retworkx


class TestMetricClosure(unittest.TestCase):
def setUp(self):
self.graph = retworkx.PyGraph()
self.graph.add_node(None)
self.graph.extend_from_weighted_edge_list(
[
(1, 2, 10),
(2, 3, 10),
(3, 4, 10),
(4, 5, 10),
(5, 6, 10),
(2, 7, 1),
(7, 5, 1),
]
)
self.graph.remove_node(0)

def test_metric_closure(self):
closure_graph = retworkx.metric_closure(self.graph, weight_fn=float)
expected_edges = [
(1, 2, (10.0, [1, 2])),
(1, 3, (20.0, [1, 2, 3])),
(1, 4, (22.0, [1, 2, 7, 5, 4])),
(1, 5, (12.0, [1, 2, 7, 5])),
(1, 6, (22.0, [1, 2, 7, 5, 6])),
(1, 7, (11.0, [1, 2, 7])),
(2, 3, (10.0, [2, 3])),
(2, 4, (12.0, [2, 7, 5, 4])),
(2, 5, (2.0, [2, 7, 5])),
(2, 6, (12, [2, 7, 5, 6])),
(2, 7, (1.0, [2, 7])),
(3, 4, (10.0, [3, 4])),
(3, 5, (12.0, [3, 2, 7, 5])),
(3, 6, (22.0, [3, 2, 7, 5, 6])),
(3, 7, (11.0, [3, 2, 7])),
(4, 5, (10.0, [4, 5])),
(4, 6, (20.0, [4, 5, 6])),
(4, 7, (11.0, [4, 5, 7])),
(5, 6, (10.0, [5, 6])),
(5, 7, (1.0, [5, 7])),
(6, 7, (11.0, [6, 5, 7])),
]
edges = list(closure_graph.weighted_edge_list())
for edge in expected_edges:
found = False
if edge in edges:
found = True
if not found:

if (
edge[1],
edge[0],
(edge[2][0], list(reversed(edge[2][1]))),
) in edges:
found = True
if not found:
self.fail(
f"edge: {edge} nor it's reverse not found in metric "
f"closure output:\n{pprint.pformat(edges)}"
)

def test_not_connected_metric_closure(self):
self.graph.add_node(None)
with self.assertRaises(ValueError):
retworkx.metric_closure(self.graph, weight_fn=float)

0 comments on commit 0eb166b

Please sign in to comment.