Skip to content

Commit

Permalink
Add metric closure function (Qiskit#390)
Browse files Browse the repository at this point in the history
* Add metric closure function

This commit adds a metric closure function to retworkx. The metric
closure of a graph is the complete graph in which each edge is weighted
by the shortest path distance between the nodes in the graph. This is
the first step towards implementing the minimum steiner graph function
because the first step for that is to compute the metric closure for the
graph.

Additionally, this function is entirely self contained to a new rust module
steiner_tree.rs (which we'll eventually add the steiner_tree() function
too) in the interest of showing the eventual organizational structure we
want for Qiskit#300 (although having an algorithms namespace might make sense
for Qiskit#300). In a separate PR addressing Qiskit#300 we should start to
reorganize the rust code in lib.rs like this to make things easier to
find.

Partially implements Qiskit#389

* Move docs to Other Algorithms section

* Fix allocation size for output vector

* Rework disconnected graph logic

The previous check for a disconnected graph only caught graphs with
completely disconnected nodes and didn't detect all disconnected
graphs. This commit fixes the logic by check the first node in the graph
has a path to all other nodes.

* Fix handling of empty graph to not overflow

* Update src/lib.rs

Co-authored-by: Ivan Carvalho <[email protected]>

* Rework final loop to be over hashmap instead of indices

Co-authored-by: Ivan Carvalho <[email protected]>
  • Loading branch information
mtreinish and IvanIsCoding authored Aug 4, 2021
1 parent 4b0aafa commit 0919144
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ Other Algorithm Functions
retworkx.core_number
retworkx.graph_greedy_color
retworkx.digraph_union
retworkx.metric_closure

Generators
==========
Expand Down
9 changes: 9 additions & 0 deletions releasenotes/notes/steiner_tree-3e5282c65095868a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
features:
- |
Added a new function, :func:`~retworkx.metric_closure`, which is used
to generate the metric closure of a given graph. This function is used in
the computation of for calculating the Steiner Tree using the algorithm
from
`Kou, Markowsky & Berman (1981). "A fast algorithm for Steiner trees".
<https://link.springer.com/article/10.1007/BF00288961>`__
31 changes: 25 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ mod iterators;
mod k_shortest_path;
mod layout;
mod max_weight_matching;
mod steiner_tree;
mod union;

use std::cmp::{Ordering, Reverse};
use std::collections::{BTreeSet, BinaryHeap};
use std::sync::RwLock;

use ahash::RandomState;
use hashbrown::{HashMap, HashSet};
Expand Down Expand Up @@ -67,6 +69,7 @@ use crate::iterators::{
NodesCountMapping, PathLengthMapping, PathMapping, Pos2DMapping,
WeightedEdgeList,
};
use steiner_tree::*;

trait NodesRemoved {
fn nodes_removed(&self) -> bool;
Expand Down Expand Up @@ -2713,10 +2716,11 @@ fn _all_pairs_dijkstra_path_lengths<Ty: EdgeType + Sync>(
})
}

fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
pub fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
py: Python,
graph: &StableGraph<PyObject, PyObject, Ty>,
edge_cost_fn: PyObject,
distances: Option<&mut HashMap<usize, HashMap<NodeIndex, f64>>>,
) -> PyResult<AllPairsPathMapping> {
if graph.node_count() == 0 {
return Ok(AllPairsPathMapping {
Expand Down Expand Up @@ -2760,20 +2764,30 @@ fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
}
};
let node_indices: Vec<NodeIndex> = graph.node_indices().collect();
Ok(AllPairsPathMapping {
let temp_distances: RwLock<HashMap<usize, HashMap<NodeIndex, f64>>> =
if distances.is_some() {
RwLock::new(HashMap::with_capacity(graph.node_count()))
} else {
// Avoid extra allocation if HashMap isn't used
RwLock::new(HashMap::new())
};
let out_map = AllPairsPathMapping {
paths: node_indices
.into_par_iter()
.map(|x| {
let mut paths: HashMap<NodeIndex, Vec<NodeIndex>> =
HashMap::with_capacity(graph.node_count());
dijkstra::dijkstra(
let distance = dijkstra::dijkstra(
graph,
x,
None,
|e| edge_cost(e.id()),
Some(&mut paths),
)
.unwrap();
if distances.is_some() {
temp_distances.write().unwrap().insert(x.index(), distance);
}
let index = x.index();
let out_paths = PathMapping {
paths: paths
Expand All @@ -2798,7 +2812,11 @@ fn _all_pairs_dijkstra_shortest_paths<Ty: EdgeType + Sync>(
(index, out_paths)
})
.collect(),
})
};
if let Some(x) = distances {
*x = temp_distances.read().unwrap().clone()
};
Ok(out_map)
}

/// Calculate the the shortest length from all nodes in a
Expand Down Expand Up @@ -2872,7 +2890,7 @@ pub fn digraph_all_pairs_dijkstra_shortest_paths(
graph: &digraph::PyDiGraph,
edge_cost_fn: PyObject,
) -> PyResult<AllPairsPathMapping> {
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn)
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn, None)
}

/// Calculate the the shortest length from all nodes in a
Expand Down Expand Up @@ -2938,7 +2956,7 @@ pub fn graph_all_pairs_dijkstra_shortest_paths(
graph: &graph::PyGraph,
edge_cost_fn: PyObject,
) -> PyResult<AllPairsPathMapping> {
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn)
_all_pairs_dijkstra_shortest_paths(py, &graph.graph, edge_cost_fn, None)
}

/// Compute the A* shortest path for a PyGraph
Expand Down Expand Up @@ -5122,6 +5140,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(digraph_spring_layout))?;
m.add_wrapped(wrap_pyfunction!(digraph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(graph_num_shortest_paths_unweighted))?;
m.add_wrapped(wrap_pyfunction!(metric_closure))?;
m.add_class::<digraph::PyDiGraph>()?;
m.add_class::<graph::PyGraph>()?;
m.add_class::<iterators::BFSSuccessors>()?;
Expand Down
115 changes: 115 additions & 0 deletions src/steiner_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

use hashbrown::{HashMap, HashSet};

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::Python;

use petgraph::graph::NodeIndex;

use crate::_all_pairs_dijkstra_shortest_paths;
use crate::graph;

struct MetricClosureEdge {
source: usize,
target: usize,
distance: f64,
path: Vec<usize>,
}

/// Return the metric closure of a graph
///
/// The metric closure of a graph is the complete graph in which each edge is
/// weighted by the shortest path distance between the nodes in the graph.
///
/// :param PyGraph graph: The input graph to find the metric closure for
/// :param weight_fn: A callable object that will be passed an edge's
/// weight/data payload and expected to return a ``float``. For example,
/// you can use ``weight_fn=float`` to cast every weight as a float
///
/// :return: A metric closure graph from the input graph
/// :rtype: PyGraph
#[pyfunction]
#[pyo3(text_signature = "(graph, weight_fn, /)")]
pub fn metric_closure(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
) -> PyResult<graph::PyGraph> {
let mut out_graph = graph.clone();
out_graph.graph.clear_edges();
let edges = _metric_closure_edges(py, graph, weight_fn)?;
for edge in edges {
out_graph.graph.add_edge(
NodeIndex::new(edge.source),
NodeIndex::new(edge.target),
(edge.distance, edge.path).to_object(py),
);
}
Ok(out_graph)
}

fn _metric_closure_edges(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
) -> PyResult<Vec<MetricClosureEdge>> {
let node_count = graph.graph.node_count();
if node_count == 0 {
return Ok(Vec::new());
}
let mut out_vec = Vec::with_capacity(node_count * (node_count - 1) / 2);
let mut distances = HashMap::with_capacity(graph.graph.node_count());
let paths = _all_pairs_dijkstra_shortest_paths(
py,
&graph.graph,
weight_fn,
Some(&mut distances),
)?
.paths;
let mut nodes: HashSet<usize> =
graph.graph.node_indices().map(|x| x.index()).collect();
let first_node = graph
.graph
.node_indices()
.map(|x| x.index())
.next()
.unwrap();
let path_keys: HashSet<usize> =
paths[&first_node].paths.keys().copied().collect();
// first_node will always be missing from path_keys so if the difference
// is > 1 with nodes that means there is another node in the graph that
// first_node doesn't have a path to.
if nodes.difference(&path_keys).count() > 1 {
return Err(PyValueError::new_err(
"The input graph must be a connected graph. The metric closure is \
not defined for a graph with unconnected nodes",
));
}
for (node, path) in paths {
let path_map = path.paths;
nodes.remove(&node);
let distance = &distances[&node];
for v in &nodes {
let v_index = NodeIndex::new(*v);
out_vec.push(MetricClosureEdge {
source: node,
target: *v,
distance: distance[&v_index],
path: path_map[v].clone(),
});
}
}
Ok(out_vec)
}
114 changes: 114 additions & 0 deletions tests/graph/test_steiner_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import pprint
import unittest

import retworkx


class TestMetricClosure(unittest.TestCase):
def setUp(self):
self.graph = retworkx.PyGraph()
self.graph.add_node(None)
self.graph.extend_from_weighted_edge_list(
[
(1, 2, 10),
(2, 3, 10),
(3, 4, 10),
(4, 5, 10),
(5, 6, 10),
(2, 7, 1),
(7, 5, 1),
]
)
self.graph.remove_node(0)

def test_metric_closure(self):
closure_graph = retworkx.metric_closure(self.graph, weight_fn=float)
expected_edges = [
(1, 2, (10.0, [1, 2])),
(1, 3, (20.0, [1, 2, 3])),
(1, 4, (22.0, [1, 2, 7, 5, 4])),
(1, 5, (12.0, [1, 2, 7, 5])),
(1, 6, (22.0, [1, 2, 7, 5, 6])),
(1, 7, (11.0, [1, 2, 7])),
(2, 3, (10.0, [2, 3])),
(2, 4, (12.0, [2, 7, 5, 4])),
(2, 5, (2.0, [2, 7, 5])),
(2, 6, (12, [2, 7, 5, 6])),
(2, 7, (1.0, [2, 7])),
(3, 4, (10.0, [3, 4])),
(3, 5, (12.0, [3, 2, 7, 5])),
(3, 6, (22.0, [3, 2, 7, 5, 6])),
(3, 7, (11.0, [3, 2, 7])),
(4, 5, (10.0, [4, 5])),
(4, 6, (20.0, [4, 5, 6])),
(4, 7, (11.0, [4, 5, 7])),
(5, 6, (10.0, [5, 6])),
(5, 7, (1.0, [5, 7])),
(6, 7, (11.0, [6, 5, 7])),
]
edges = list(closure_graph.weighted_edge_list())
for edge in expected_edges:
found = False
if edge in edges:
found = True
if not found:

if (
edge[1],
edge[0],
(edge[2][0], list(reversed(edge[2][1]))),
) in edges:
found = True
if not found:
self.fail(
f"edge: {edge} nor it's reverse not found in metric "
f"closure output:\n{pprint.pformat(edges)}"
)

def test_not_connected_metric_closure(self):
self.graph.add_node(None)
with self.assertRaises(ValueError):
retworkx.metric_closure(self.graph, weight_fn=float)

def test_partially_connected_metric_closure(self):
graph = retworkx.PyGraph()
graph.add_node(None)
graph.extend_from_weighted_edge_list(
[
(1, 2, 10),
(2, 3, 10),
(3, 4, 10),
(4, 5, 10),
(5, 6, 10),
(2, 7, 1),
(7, 5, 1),
]
)
graph.extend_from_weighted_edge_list(
[
(0, 8, 20),
(0, 9, 20),
(0, 10, 20),
(8, 10, 10),
(9, 10, 5),
]
)
with self.assertRaises(ValueError):
retworkx.metric_closure(graph, weight_fn=float)

def test_metric_closure_empty_graph(self):
graph = retworkx.PyGraph()
closure = retworkx.metric_closure(graph, weight_fn=float)
self.assertEqual([], closure.weighted_edge_list())

0 comments on commit 0919144

Please sign in to comment.