diff --git a/docs/source/api.rst b/docs/source/api.rst index 0874e52867..18d3b198b6 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -59,6 +59,7 @@ Traversal retworkx.dfs_search retworkx.bfs_successors retworkx.bfs_search + retworkx.dijkstra_search retworkx.topological_sort retworkx.lexicographical_topological_sort retworkx.descendants @@ -67,6 +68,7 @@ Traversal retworkx.collect_bicolor_runs retworkx.visit.DFSVisitor retworkx.visit.BFSVisitor + retworkx.visit.DijkstraVisitor retworkx.TopologicalSorter .. _dag-algorithms: @@ -270,6 +272,7 @@ the functions from the explicitly typed based on the data type. retworkx.digraph_betweenness_centrality retworkx.digraph_unweighted_average_shortest_path_length retworkx.digraph_bfs_search + retworkx.digraph_dijkstra_search .. _api-functions-pygraph: @@ -315,6 +318,7 @@ typed API based on the data type. retworkx.graph_betweenness_centrality retworkx.graph_unweighted_average_shortest_path_length retworkx.graph_bfs_search + retworkx.graph_dijkstra_search Exceptions ========== diff --git a/releasenotes/notes/dijkstra-search-2d1899241f5166ea.yaml b/releasenotes/notes/dijkstra-search-2d1899241f5166ea.yaml new file mode 100644 index 0000000000..ae84e55236 --- /dev/null +++ b/releasenotes/notes/dijkstra-search-2d1899241f5166ea.yaml @@ -0,0 +1,36 @@ +--- +features: + - | + Added a new :func:`~retworkx.dijkstra_search` (and it's per type variants + :func:`~retworkx.graph_dijkstra_search` and :func:`~retworkx.digraph_dijkstra_search`) + that traverses the graph using dijkstra algorithm and emits events at specified + points. The events are handled by a visitor object that subclasses + :class:`~retworkx.visit.DijkstraVisitor` through the appropriate callback functions. + For example: + + .. jupyter-execute:: + + import retworkx + from retworkx.visit import DijkstraVisitor + + + class DijkstraTreeEdgesRecorder(retworkx.visit.DijkstraVisitor): + + def __init__(self): + self.edges = [] + self.parents = dict() + + def discover_vertex(self, v, _): + u = self.parents.get(v, None) + if u is not None: + self.edges.append((u, v)) + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + graph = retworkx.PyGraph() + graph.extend_from_weighted_edge_list([(1, 3, 1), (0, 1, 10), (2, 1, 1), (0, 2, 1)]) + vis = DijkstraTreeEdgesRecorder() + retworkx.graph_dijkstra_search(graph, [0], float, vis) + print('Tree edges:', vis.edges) diff --git a/retworkx-core/src/traversal/dijkstra_visit.rs b/retworkx-core/src/traversal/dijkstra_visit.rs new file mode 100644 index 0000000000..0893e8542a --- /dev/null +++ b/retworkx-core/src/traversal/dijkstra_visit.rs @@ -0,0 +1,304 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// This module was originally copied and forked from the upstream petgraph +// repository, specifically: +// https://github.com/petgraph/petgraph/blob/0.5.1/src/dijkstra.rs +// this was necessary to modify the error handling to allow python callables +// to be use for the input functions for edge_cost and return any exceptions +// raised in Python instead of panicking + +use std::collections::BinaryHeap; +use std::hash::Hash; + +use hashbrown::hash_map::Entry::{Occupied, Vacant}; +use hashbrown::HashMap; + +use petgraph::algo::Measure; +use petgraph::visit::{ControlFlow, EdgeRef, IntoEdges, VisitMap, Visitable}; + +use crate::min_scored::MinScored; + +use super::try_control; + +macro_rules! try_control_with_result { + ($e:expr, $p:stmt) => { + try_control_with_result!($e, $p, ()); + }; + ($e:expr, $p:stmt, $q:stmt) => { + match $e { + x => { + if x.should_break() { + return Ok(x); + } else if x.should_prune() { + $p + } else { + $q + } + } + } + }; +} + +/// A dijkstra search visitor event. +#[derive(Copy, Clone, Debug)] +pub enum DijkstraEvent { + /// This is invoked when a vertex is encountered for the first time and + /// it's popped from the queue. Together with the node, we report the optimal + /// distance of the node. + Discover(N, K), + /// This is invoked on every out-edge of each vertex after it is discovered. + ExamineEdge(N, N, E), + /// Upon examination, if the distance of the target of the edge is reduced, this event is emitted. + EdgeRelaxed(N, N, E), + /// Upon examination, if the edge is not relaxed, this event is emitted. + EdgeNotRelaxed(N, N, E), + /// All edges from a node have been reported. + Finish(N), +} + +/// Dijkstra traversal of a graph. +/// +/// Starting points are the nodes in the iterator `starts` (specify just one +/// start vertex *x* by using `Some(x)`). +/// +/// The traversal emits discovery and finish events for each reachable vertex, +/// and edge classification of each reachable edge. `visitor` is called for each +/// event, see [`DijkstraEvent`] for possible values. +/// +/// The return value should implement the trait [`ControlFlow`], and can be used to change +/// the control flow of the search. +/// +/// [`Control`](petgraph::visit::Control) Implements [`ControlFlow`] such that `Control::Continue` resumes the search. +/// `Control::Break` will stop the visit early, returning the contained value. +/// `Control::Prune` will stop traversing any additional edges from the current +/// node and proceed immediately to the `Finish` event. +/// +/// There are implementations of [`ControlFlow`] for `()`, and [`Result`] where +/// `C: ControlFlow`. The implementation for `()` will continue until finished. +/// For [`Result`], upon encountering an `E` it will break, otherwise acting the same as `C`. +/// +/// ***Panics** if you attempt to prune a node from its `Finish` event. +/// +/// The pseudo-code for the Dijkstra algorithm is listed below, with the annotated +/// event points, for which the given visitor object will be called with the +/// appropriate method. +/// +/// ```norust +/// DIJKSTRA(G, source, weight) +/// for each vertex u in V +/// d[u] := infinity +/// p[u] := u +/// end for +/// d[source] := 0 +/// INSERT(Q, source) +/// while (Q != Ø) +/// u := EXTRACT-MIN(Q) discover vertex u +/// for each vertex v in Adj[u] examine edge (u,v) +/// if (weight[(u,v)] + d[u] < d[v]) edge (u,v) relaxed +/// d[v] := weight[(u,v)] + d[u] +/// p[v] := u +/// DECREASE-KEY(Q, v) +/// else edge (u,v) not relaxed +/// ... +/// if (d[v] was originally infinity) +/// INSERT(Q, v) +/// end for finish vertex u +/// end while +/// ``` +/// +/// # Example returning [`Control`](petgraph::visit::Control). +/// +/// Find the shortest path from vertex 0 to 5, and exit the visit as soon as +/// we reach the goal vertex. +/// +/// ``` +/// use retworkx_core::petgraph::prelude::*; +/// use retworkx_core::petgraph::graph::node_index as n; +/// use retworkx_core::petgraph::visit::Control; +/// +/// use retworkx_core::traversal::{DijkstraEvent, dijkstra_search}; +/// +/// let gr: Graph<(), ()> = Graph::from_edges(&[ +/// (0, 1), (0, 2), (0, 3), (0, 4), +/// (1, 3), +/// (2, 3), (2, 4), +/// (4, 5), +/// ]); +/// +/// // record each predecessor, mapping node → node +/// let mut predecessor = vec![NodeIndex::end(); gr.node_count()]; +/// let start = n(0); +/// let goal = n(5); +/// dijkstra_search( +/// &gr, +/// Some(start), +/// |edge| -> Result { +/// Ok(1) +/// }, +/// |event| { +/// match event { +/// DijkstraEvent::Discover(v, _) => { +/// if v == goal { +/// return Control::Break(v); +/// } +/// }, +/// DijkstraEvent::EdgeRelaxed(u, v, _) => { +/// predecessor[v.index()] = u; +/// }, +/// _ => {} +/// }; +/// +/// Control::Continue +/// }, +/// ).unwrap(); +/// +/// let mut next = goal; +/// let mut path = vec![next]; +/// while next != start { +/// let pred = predecessor[next.index()]; +/// path.push(pred); +/// next = pred; +/// } +/// path.reverse(); +/// assert_eq!(&path, &[n(0), n(4), n(5)]); +/// ``` +pub fn dijkstra_search( + graph: G, + starts: I, + mut edge_cost: F, + mut visitor: H, +) -> Result +where + G: IntoEdges + Visitable, + G::NodeId: Eq + Hash, + I: IntoIterator, + F: FnMut(G::EdgeRef) -> Result, + K: Measure + Copy, + H: FnMut(DijkstraEvent) -> C, + C: ControlFlow, +{ + let visited = &mut graph.visit_map(); + + for start in starts { + // `dijkstra_visitor` returns a "signal" to either continue or exit early + // but it never "prunes", so we use `unreachable`. + try_control!( + dijkstra_visitor( + graph, + start, + &mut edge_cost, + &mut visitor, + visited + ), + unreachable!() + ); + } + + Ok(C::continuing()) +} + +pub fn dijkstra_visitor( + graph: G, + start: G::NodeId, + mut edge_cost: F, + mut visitor: V, + visited: &mut G::Map, +) -> Result +where + G: IntoEdges + Visitable, + G::NodeId: Eq + Hash, + F: FnMut(G::EdgeRef) -> Result, + K: Measure + Copy, + V: FnMut(DijkstraEvent) -> C, + C: ControlFlow, +{ + if visited.is_visited(&start) { + return Ok(C::continuing()); + } + + let mut scores = HashMap::new(); + let mut visit_next = BinaryHeap::new(); + let zero_score = K::default(); + scores.insert(start, zero_score); + visit_next.push(MinScored(zero_score, start)); + + while let Some(MinScored(node_score, node)) = visit_next.pop() { + if !visited.visit(node) { + continue; + } + + try_control_with_result!( + visitor(DijkstraEvent::Discover(node, node_score)), + continue + ); + + for edge in graph.edges(node) { + let next = edge.target(); + try_control_with_result!( + visitor(DijkstraEvent::ExamineEdge(node, next, edge.weight())), + continue + ); + + if visited.is_visited(&next) { + continue; + } + + let cost = edge_cost(edge)?; + let next_score = node_score + cost; + match scores.entry(next) { + Occupied(ent) => { + if next_score < *ent.get() { + try_control_with_result!( + visitor(DijkstraEvent::EdgeRelaxed( + node, + next, + edge.weight() + )), + continue + ); + *ent.into_mut() = next_score; + visit_next.push(MinScored(next_score, next)); + } else { + try_control_with_result!( + visitor(DijkstraEvent::EdgeNotRelaxed( + node, + next, + edge.weight() + )), + continue + ); + } + } + Vacant(ent) => { + try_control_with_result!( + visitor(DijkstraEvent::EdgeRelaxed( + node, + next, + edge.weight() + )), + continue + ); + ent.insert(next_score); + visit_next.push(MinScored(next_score, next)); + } + } + } + + try_control_with_result!( + visitor(DijkstraEvent::Finish(node)), + panic!("Pruning on the `DijkstraEvent::Finish` is not supported!") + ); + } + + Ok(C::continuing()) +} diff --git a/retworkx-core/src/traversal/mod.rs b/retworkx-core/src/traversal/mod.rs index 2a06c30070..2a62544810 100644 --- a/retworkx-core/src/traversal/mod.rs +++ b/retworkx-core/src/traversal/mod.rs @@ -15,10 +15,12 @@ mod bfs_visit; mod dfs_edges; mod dfs_visit; +mod dijkstra_visit; pub use bfs_visit::{breadth_first_search, BfsEvent}; pub use dfs_edges::dfs_edges; pub use dfs_visit::{depth_first_search, DfsEvent}; +pub use dijkstra_visit::{dijkstra_search, DijkstraEvent}; /// Return if the expression is a break value, execute the provided statement /// if it is a prune value. diff --git a/retworkx/__init__.py b/retworkx/__init__.py index 1bebcd0785..ae9c3bdd2e 100644 --- a/retworkx/__init__.py +++ b/retworkx/__init__.py @@ -1993,3 +1993,70 @@ def _digraph_dfs_search(graph, source, visitor): @dfs_search.register(PyGraph) def _graph_dfs_search(graph, source, visitor): return graph_dfs_search(graph, source, visitor) + + +@functools.singledispatch +def dijkstra_search(graph, source, weight_fn, visitor): + """Dijkstra traversal of a graph. + + The pseudo-code for the Dijkstra algorithm is listed below, with the annotated + event points, for which the given visitor object will be called with the + appropriate method. + + :: + + DIJKSTRA(G, source, weight) + for each vertex u in V + d[u] := infinity + p[u] := u + end for + d[source] := 0 + INSERT(Q, source) + while (Q != Ø) + u := EXTRACT-MIN(Q) discover vertex u + for each vertex v in Adj[u] examine edge (u,v) + if (weight[(u,v)] + d[u] < d[v]) edge (u,v) relaxed + d[v] := weight[(u,v)] + d[u] + p[v] := u + DECREASE-KEY(Q, v) + else edge (u,v) not relaxed + ... + if (d[v] was originally infinity) + INSERT(Q, v) + end for finish vertex u + end while + + If an exception is raised inside the callback function, the graph traversal + will be stopped immediately. You can exploit this to exit early by raising a + :class:`~retworkx.visit.StopSearch` exception, in which case the search function + will return but without raising back the exception. You can also prune part of the + search tree by raising :class:`~retworkx.visit.PruneSearch`. + + .. note:: + + Graph can **not** be mutated while traversing. + + :param graph: The graph to be used. This can be a :class:`~retworkx.PyGraph` + or a :class:`~retworkx.PyDiGraph`. + :param List[int] source: An optional list of node indices to use as the starting nodes + for the dijkstra search. If this is not specified then a source + will be chosen arbitrarly and repeated until all components of the + graph are searched. + :param weight_fn: An optional weight function for an edge. It will accept + a single argument, the edge's weight object and will return a float which + will be used to represent the weight/cost of the edge. If not specified, + a default value of cost ``1.0`` will be used for each edge. + :param visitor: A visitor object that is invoked at the event points inside the + algorithm. This should be a subclass of :class:`~retworkx.visit.DijkstraVisitor`. + """ + raise TypeError("Invalid Input Type %s for graph" % type(graph)) + + +@dijkstra_search.register(PyDiGraph) +def _digraph_dijkstra_search(graph, source, weight_fn, visitor): + return digraph_dijkstra_search(graph, source, weight_fn, visitor) + + +@dijkstra_search.register(PyGraph) +def _graph_dijkstra_search(graph, source, weight_fn, visitor): + return graph_dijkstra_search(graph, source, weight_fn, visitor) diff --git a/retworkx/visit.py b/retworkx/visit.py index d0f21fb682..7cce5d7aed 100644 --- a/retworkx/visit.py +++ b/retworkx/visit.py @@ -33,10 +33,7 @@ def discover_vertex(self, v): def finish_vertex(self, v): """ - This is invoked on vertex `v` after all of its out edges have been - added to the search tree and all of the adjacent vertices have been - discovered, but before the out-edges of the adjacent vertices have - been examined. + This is invoked on vertex `v` after all of its out edges have been examined. """ return @@ -120,3 +117,43 @@ def forward_or_cross_edge(self, e): In an undirected graph this method is never called. """ return + + +class DijkstraVisitor: + """A visitor object that is invoked at the event-points inside the + :func:`~retworkx.dijkstra_search` algorithm. By default, it performs no + action, and should be used as a base class in order to be useful. + """ + + def discover_vertex(self, v, score): + """ + This is invoked when a vertex is encountered for the first time and + it's popped from the queue. Together with the node, we report the optimal + distance of the node. + """ + return + + def finish_vertex(self, v): + """ + This is invoked on vertex `v` after all of its out edges have been examined. + """ + return + + def examine_edge(self, edge): + """ + This is invoked on every out-edge of each vertex after it is discovered. + """ + return + + def edge_relaxed(self, edge): + """ + Upon examination, if the distance of the target of the edge is reduced, + this event is emitted. + """ + return + + def edge_not_relaxed(self, edge): + """ + Upon examination, if the edge is not relaxed, this event is emitted. + """ + return diff --git a/src/lib.rs b/src/lib.rs index 4dd68ece2e..e5c9285ecf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -309,6 +309,8 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(bfs_successors))?; m.add_wrapped(wrap_pyfunction!(graph_bfs_search))?; m.add_wrapped(wrap_pyfunction!(digraph_bfs_search))?; + m.add_wrapped(wrap_pyfunction!(graph_dijkstra_search))?; + m.add_wrapped(wrap_pyfunction!(digraph_dijkstra_search))?; m.add_wrapped(wrap_pyfunction!(dag_longest_path))?; m.add_wrapped(wrap_pyfunction!(dag_longest_path_length))?; m.add_wrapped(wrap_pyfunction!(dag_weighted_longest_path))?; diff --git a/src/traversal/dijkstra_visit.rs b/src/traversal/dijkstra_visit.rs new file mode 100644 index 0000000000..40a0b2df74 --- /dev/null +++ b/src/traversal/dijkstra_visit.rs @@ -0,0 +1,66 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use pyo3::prelude::*; + +use petgraph::stable_graph::NodeIndex; +use petgraph::visit::Control; + +use crate::{PruneSearch, StopSearch}; +use retworkx_core::traversal::DijkstraEvent; + +#[derive(FromPyObject)] +pub struct PyDijkstraVisitor { + discover_vertex: PyObject, + finish_vertex: PyObject, + examine_edge: PyObject, + edge_relaxed: PyObject, + edge_not_relaxed: PyObject, +} + +pub fn dijkstra_handler( + py: Python, + vis: &PyDijkstraVisitor, + event: DijkstraEvent, +) -> PyResult> { + let res = match event { + DijkstraEvent::Discover(u, score) => { + vis.discover_vertex.call1(py, (u.index(), score)) + } + DijkstraEvent::ExamineEdge(u, v, weight) => { + let edge = (u.index(), v.index(), weight); + vis.examine_edge.call1(py, (edge,)) + } + DijkstraEvent::EdgeRelaxed(u, v, weight) => { + let edge = (u.index(), v.index(), weight); + vis.edge_relaxed.call1(py, (edge,)) + } + DijkstraEvent::EdgeNotRelaxed(u, v, weight) => { + let edge = (u.index(), v.index(), weight); + vis.edge_not_relaxed.call1(py, (edge,)) + } + DijkstraEvent::Finish(u) => vis.finish_vertex.call1(py, (u.index(),)), + }; + + match res { + Err(e) => { + if e.is_instance::(py) { + Ok(Control::Prune) + } else if e.is_instance::(py) { + Ok(Control::Break(())) + } else { + Err(e) + } + } + Ok(_) => Ok(Control::Continue), + } +} diff --git a/src/traversal/mod.rs b/src/traversal/mod.rs index f28ba6f69e..21f170fe0d 100644 --- a/src/traversal/mod.rs +++ b/src/traversal/mod.rs @@ -12,15 +12,19 @@ mod bfs_visit; pub mod dfs_visit; +mod dijkstra_visit; use bfs_visit::{bfs_handler, PyBfsVisitor}; use dfs_visit::{dfs_handler, PyDfsVisitor}; +use dijkstra_visit::{dijkstra_handler, PyDijkstraVisitor}; use retworkx_core::traversal::{ - breadth_first_search, depth_first_search, dfs_edges, + breadth_first_search, depth_first_search, dfs_edges, dijkstra_search, }; -use super::{digraph, graph, iterators}; +use super::{digraph, graph, iterators, CostFn}; + +use std::convert::TryFrom; use hashbrown::HashSet; @@ -558,3 +562,153 @@ pub fn graph_dfs_search( Ok(()) } + +/// Dijkstra traversal of a directed graph. +/// +/// The pseudo-code for the Dijkstra algorithm is listed below, with the annotated +/// event points, for which the given visitor object will be called with the +/// appropriate method. +/// +/// :: +/// +/// DIJKSTRA(G, source, weight) +/// for each vertex u in V +/// d[u] := infinity +/// p[u] := u +/// end for +/// d[source] := 0 +/// INSERT(Q, source) +/// while (Q != Ø) +/// u := EXTRACT-MIN(Q) discover vertex u +/// for each vertex v in Adj[u] examine edge (u,v) +/// if (weight[(u,v)] + d[u] < d[v]) edge (u,v) relaxed +/// d[v] := weight[(u,v)] + d[u] +/// p[v] := u +/// DECREASE-KEY(Q, v) +/// else edge (u,v) not relaxed +/// ... +/// if (d[v] was originally infinity) +/// INSERT(Q, v) +/// end for finish vertex u +/// end while +/// +/// If an exception is raised inside the callback function, the graph traversal +/// will be stopped immediately. You can exploit this to exit early by raising a +/// :class:`~retworkx.visit.StopSearch` exception, in which case the search function +/// will return but without raising back the exception. You can also prune part of the +/// search tree by raising :class:`~retworkx.visit.PruneSearch`. +/// +/// .. note:: +/// +/// Graph can **not** be mutated while traversing. +/// +/// :param PyDiGraph graph: The graph to be used. +/// :param List[int] source: An optional list of node indices to use as the starting nodes +/// for the dijkstra search. If this is not specified then a source +/// will be chosen arbitrarly and repeated until all components of the +/// graph are searched. +/// :param weight_fn: An optional weight function for an edge. It will accept +/// a single argument, the edge's weight object and will return a float which +/// will be used to represent the weight/cost of the edge. If not specified, +/// a default value of cost ``1.0`` will be used for each edge. +/// :param visitor: A visitor object that is invoked at the event points inside the +/// algorithm. This should be a subclass of :class:`~retworkx.visit.DijkstraVisitor`. +#[pyfunction] +#[pyo3(text_signature = "(graph, source, weight_fn, visitor)")] +pub fn digraph_dijkstra_search( + py: Python, + graph: &digraph::PyDiGraph, + source: Option>, + weight_fn: Option, + visitor: PyDijkstraVisitor, +) -> PyResult<()> { + let starts: Vec<_> = match source { + Some(nx) => nx.into_iter().map(NodeIndex::new).collect(), + None => graph.graph.node_indices().collect(), + }; + + let edge_cost_fn = CostFn::try_from((weight_fn, 1.0))?; + dijkstra_search( + &graph.graph, + starts, + |e| edge_cost_fn.call(py, e.weight()), + |event| dijkstra_handler(py, &visitor, event), + )??; + + Ok(()) +} + +/// Dijkstra traversal of an undirected graph. +/// +/// The pseudo-code for the Dijkstra algorithm is listed below, with the annotated +/// event points, for which the given visitor object will be called with the +/// appropriate method. +/// +/// :: +/// +/// DIJKSTRA(G, source, weight) +/// for each vertex u in V +/// d[u] := infinity +/// p[u] := u +/// end for +/// d[source] := 0 +/// INSERT(Q, source) +/// while (Q != Ø) +/// u := EXTRACT-MIN(Q) discover vertex u +/// for each vertex v in Adj[u] examine edge (u,v) +/// if (weight[(u,v)] + d[u] < d[v]) edge (u,v) relaxed +/// d[v] := weight[(u,v)] + d[u] +/// p[v] := u +/// DECREASE-KEY(Q, v) +/// else edge (u,v) not relaxed +/// ... +/// if (d[v] was originally infinity) +/// INSERT(Q, v) +/// end for finish vertex u +/// end while +/// +/// If an exception is raised inside the callback function, the graph traversal +/// will be stopped immediately. You can exploit this to exit early by raising a +/// :class:`~retworkx.visit.StopSearch` exception, in which case the search function +/// will return but without raising back the exception. You can also prune part of the +/// search tree by raising :class:`~retworkx.visit.PruneSearch`. +/// +/// .. note:: +/// +/// Graph can **not** be mutated while traversing. +/// +/// :param PyGraph graph: The graph to be used. +/// :param List[int] source: An optional list of node indices to use as the starting nodes +/// for the dijkstra search. If this is not specified then a source +/// will be chosen arbitrarly and repeated until all components of the +/// graph are searched. +/// :param weight_fn: An optional weight function for an edge. It will accept +/// a single argument, the edge's weight object and will return a float which +/// will be used to represent the weight/cost of the edge. If not specified, +/// a default value of cost ``1.0`` will be used for each edge. +/// :param visitor: A visitor object that is invoked at the event points inside the +/// algorithm. This should be a subclass of :class:`~retworkx.visit.DijkstraVisitor`. +#[pyfunction] +#[pyo3(text_signature = "(graph, source, weight_fn, visitor)")] +pub fn graph_dijkstra_search( + py: Python, + graph: &graph::PyGraph, + source: Option>, + weight_fn: Option, + visitor: PyDijkstraVisitor, +) -> PyResult<()> { + let starts: Vec<_> = match source { + Some(nx) => nx.into_iter().map(NodeIndex::new).collect(), + None => graph.graph.node_indices().collect(), + }; + + let edge_cost_fn = CostFn::try_from((weight_fn, 1.0))?; + dijkstra_search( + &graph.graph, + starts, + |e| edge_cost_fn.call(py, e.weight()), + |event| dijkstra_handler(py, &visitor, event), + )??; + + Ok(()) +} diff --git a/tests/digraph/test_dijkstra_search.py b/tests/digraph/test_dijkstra_search.py new file mode 100644 index 0000000000..3bb6950783 --- /dev/null +++ b/tests/digraph/test_dijkstra_search.py @@ -0,0 +1,191 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import retworkx + + +class TestDijkstraSearch(unittest.TestCase): + def setUp(self): + self.graph = retworkx.PyDiGraph() + self.graph.extend_from_weighted_edge_list( + [ + (0, 1, 1), + (0, 2, 2), + (1, 3, 10), + (2, 1, 1), + (2, 5, 1), + (2, 6, 1), + (5, 3, 1), + (4, 7, 1), + ] + ) + + def test_digraph_dijkstra_tree_edges(self): + class DijkstraTreeEdgesRecorder(retworkx.visit.DijkstraVisitor): + def __init__(self): + self.edges = [] + self.parents = dict() + + def discover_vertex(self, v, _): + u = self.parents.get(v, None) + if u is not None: + self.edges.append((u, v)) + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + vis = DijkstraTreeEdgesRecorder() + retworkx.digraph_dijkstra_search(self.graph, [0], float, vis) + self.assertEqual(vis.edges, [(0, 1), (0, 2), (2, 6), (2, 5), (5, 3)]) + + def test_digraph_dijkstra_tree_edges_no_starting_point(self): + class DijkstraTreeEdgesRecorder(retworkx.visit.DijkstraVisitor): + def __init__(self): + self.edges = [] + self.parents = dict() + + def discover_vertex(self, v, _): + u = self.parents.get(v, None) + if u is not None: + self.edges.append((u, v)) + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + vis = DijkstraTreeEdgesRecorder() + retworkx.digraph_dijkstra_search(self.graph, None, float, vis) + self.assertEqual( + vis.edges, [(0, 1), (0, 2), (2, 6), (2, 5), (5, 3), (4, 7)] + ) + + def test_digraph_dijkstra_goal_search_with_stop_search_exception(self): + class GoalSearch(retworkx.visit.DijkstraVisitor): + + goal = 3 + + def __init__(self): + self.parents = {} + self.opt_goal_cost = None + + def discover_vertex(self, v, score): + if v == self.goal: + self.opt_goal_cost = score + raise retworkx.visit.StopSearch + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + def reconstruct_path(self): + v = self.goal + path = [v] + while v in self.parents: + v = self.parents[v] + path.append(v) + + path.reverse() + return path + + vis = GoalSearch() + retworkx.digraph_dijkstra_search(self.graph, [0], float, vis) + self.assertEqual(vis.reconstruct_path(), [0, 2, 5, 3]) + self.assertEqual(vis.opt_goal_cost, 4.0) + + def test_digraph_dijkstra_goal_search_with_custom_exception(self): + class StopIfGoalFound(Exception): + pass + + class GoalSearch(retworkx.visit.DijkstraVisitor): + + goal = 3 + + def __init__(self): + self.parents = {} + self.opt_goal_cost = None + + def discover_vertex(self, v, score): + if v == self.goal: + self.opt_goal_cost = score + raise StopIfGoalFound + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + def reconstruct_path(self): + v = self.goal + path = [v] + while v in self.parents: + v = self.parents[v] + path.append(v) + + path.reverse() + return path + + vis = GoalSearch() + try: + retworkx.digraph_dijkstra_search(self.graph, [0], float, vis) + except StopIfGoalFound: + pass + self.assertEqual(vis.reconstruct_path(), [0, 2, 5, 3]) + self.assertEqual(vis.opt_goal_cost, 4.0) + + def test_digraph_dijkstra_goal_search_with_prohibited_edges(self): + class GoalSearch(retworkx.visit.DijkstraVisitor): + + goal = 3 + prohibited = [(5, 3)] + + def __init__(self): + self.parents = {} + self.opt_goal_cost = None + + def discover_vertex(self, v, score): + if v == self.goal: + self.opt_goal_cost = score + raise retworkx.visit.StopSearch + + def examine_edge(self, edge): + u, v, _ = edge + if (u, v) in self.prohibited: + raise retworkx.visit.PruneSearch + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + def reconstruct_path(self): + v = self.goal + path = [v] + while v in self.parents: + v = self.parents[v] + path.append(v) + + path.reverse() + return path + + vis = GoalSearch() + retworkx.digraph_dijkstra_search(self.graph, [0], float, vis) + self.assertEqual(vis.reconstruct_path(), [0, 1, 3]) + self.assertEqual(vis.opt_goal_cost, 11.0) + + def test_digraph_prune_edge_not_relaxed(self): + class PruneEdgeNotRelaxed(retworkx.visit.DijkstraVisitor): + def edge_not_relaxed(self, _): + raise retworkx.visit.PruneSearch + + vis = PruneEdgeNotRelaxed() + retworkx.digraph_dijkstra_search(self.graph, [0], float, vis) diff --git a/tests/graph/test_dijkstra_search.py b/tests/graph/test_dijkstra_search.py new file mode 100644 index 0000000000..18e713c340 --- /dev/null +++ b/tests/graph/test_dijkstra_search.py @@ -0,0 +1,191 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +import retworkx + + +class TestDijkstraSearch(unittest.TestCase): + def setUp(self): + self.graph = retworkx.PyGraph() + self.graph.extend_from_weighted_edge_list( + [ + (0, 1, 1), + (0, 2, 2), + (1, 3, 10), + (2, 1, 1), + (2, 5, 1), + (2, 6, 1), + (5, 3, 1), + (4, 7, 1), + ] + ) + + def test_graph_dijkstra_tree_edges(self): + class DijkstraTreeEdgesRecorder(retworkx.visit.DijkstraVisitor): + def __init__(self): + self.edges = [] + self.parents = dict() + + def discover_vertex(self, v, _): + u = self.parents.get(v, None) + if u is not None: + self.edges.append((u, v)) + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + vis = DijkstraTreeEdgesRecorder() + retworkx.graph_dijkstra_search(self.graph, [0], float, vis) + self.assertEqual(vis.edges, [(0, 1), (0, 2), (2, 6), (2, 5), (5, 3)]) + + def test_graph_dijkstra_tree_edges_no_starting_point(self): + class DijkstraTreeEdgesRecorder(retworkx.visit.DijkstraVisitor): + def __init__(self): + self.edges = [] + self.parents = dict() + + def discover_vertex(self, v, _): + u = self.parents.get(v, None) + if u is not None: + self.edges.append((u, v)) + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + vis = DijkstraTreeEdgesRecorder() + retworkx.graph_dijkstra_search(self.graph, None, float, vis) + self.assertEqual( + vis.edges, [(0, 1), (0, 2), (2, 6), (2, 5), (5, 3), (4, 7)] + ) + + def test_graph_dijkstra_goal_search_with_stop_search_exception(self): + class GoalSearch(retworkx.visit.DijkstraVisitor): + + goal = 3 + + def __init__(self): + self.parents = {} + self.opt_goal_cost = None + + def discover_vertex(self, v, score): + if v == self.goal: + self.opt_goal_cost = score + raise retworkx.visit.StopSearch + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + def reconstruct_path(self): + v = self.goal + path = [v] + while v in self.parents: + v = self.parents[v] + path.append(v) + + path.reverse() + return path + + vis = GoalSearch() + retworkx.graph_dijkstra_search(self.graph, [0], float, vis) + self.assertEqual(vis.reconstruct_path(), [0, 2, 5, 3]) + self.assertEqual(vis.opt_goal_cost, 4.0) + + def test_graph_dijkstra_goal_search_with_custom_exception(self): + class StopIfGoalFound(Exception): + pass + + class GoalSearch(retworkx.visit.DijkstraVisitor): + + goal = 3 + + def __init__(self): + self.parents = {} + self.opt_goal_cost = None + + def discover_vertex(self, v, score): + if v == self.goal: + self.opt_goal_cost = score + raise StopIfGoalFound + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + def reconstruct_path(self): + v = self.goal + path = [v] + while v in self.parents: + v = self.parents[v] + path.append(v) + + path.reverse() + return path + + vis = GoalSearch() + try: + retworkx.graph_dijkstra_search(self.graph, [0], float, vis) + except StopIfGoalFound: + pass + self.assertEqual(vis.reconstruct_path(), [0, 2, 5, 3]) + self.assertEqual(vis.opt_goal_cost, 4.0) + + def test_graph_dijkstra_goal_search_with_prohibited_edges(self): + class GoalSearch(retworkx.visit.DijkstraVisitor): + + goal = 3 + prohibited = [(5, 3)] + + def __init__(self): + self.parents = {} + self.opt_goal_cost = None + + def discover_vertex(self, v, score): + if v == self.goal: + self.opt_goal_cost = score + raise retworkx.visit.StopSearch + + def examine_edge(self, edge): + u, v, _ = edge + if (u, v) in self.prohibited: + raise retworkx.visit.PruneSearch + + def edge_relaxed(self, edge): + u, v, _ = edge + self.parents[v] = u + + def reconstruct_path(self): + v = self.goal + path = [v] + while v in self.parents: + v = self.parents[v] + path.append(v) + + path.reverse() + return path + + vis = GoalSearch() + retworkx.graph_dijkstra_search(self.graph, [0], float, vis) + self.assertEqual(vis.reconstruct_path(), [0, 1, 3]) + self.assertEqual(vis.opt_goal_cost, 11.0) + + def test_graph_prune_edge_not_relaxed(self): + class PruneEdgeNotRelaxed(retworkx.visit.DijkstraVisitor): + def edge_not_relaxed(self, _): + raise retworkx.visit.PruneSearch + + vis = PruneEdgeNotRelaxed() + retworkx.graph_dijkstra_search(self.graph, [0], float, vis)