Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix panic error at shortest paths #1134

Merged
merged 4 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions releasenotes/notes/fix-panic-3962ad36788cab00.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
fixes:
- |
Fixed an issue with the Dijkstra path functions:
* :func:`rustworkx.dijkstra_shortest_paths`
* :func:`rustworkx.dijkstra_shortest_path_lengths`
* :func:`rustworkx.bellman_ford_shortest_path_lengths`
* :func:`rustworkx.bellman_ford_shortest_paths`
* :func:`rustworkx.astar_shortest_path`
where a `Pyo3.PanicException`were raise with no much detail at the moment
of pass in the `source` argument the index of an out of bound node.
Fixed `#1117 <https://github.com/Qiskit/rustworkx/issues/1117>`__
64 changes: 63 additions & 1 deletion src/shortest_path/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ pub fn graph_dijkstra_shortest_paths(
default_weight: f64,
) -> PyResult<PathMapping> {
let start = NodeIndex::new(source);
if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{source}\" out of graph bound"
)));
}

let goal_index: Option<NodeIndex> = target.map(NodeIndex::new);
let mut paths: DictMap<NodeIndex, Vec<NodeIndex>> = DictMap::with_capacity(graph.node_count());

Expand Down Expand Up @@ -217,6 +223,12 @@ pub fn digraph_dijkstra_shortest_paths(
as_undirected: bool,
) -> PyResult<PathMapping> {
let start = NodeIndex::new(source);
if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{source}\" out of graph bound"
)));
}

let goal_index: Option<NodeIndex> = target.map(NodeIndex::new);
let mut paths: DictMap<NodeIndex, Vec<NodeIndex>> = DictMap::with_capacity(graph.node_count());
let cost_fn = CostFn::try_from((weight_fn, default_weight))?;
Expand Down Expand Up @@ -371,10 +383,16 @@ pub fn graph_dijkstra_shortest_path_lengths(
edge_cost_fn: PyObject,
goal: Option<usize>,
) -> PyResult<PathLengthMapping> {
let edge_cost_callable = CostFn::from(edge_cost_fn);
let start = NodeIndex::new(node);
let edge_cost_callable = CostFn::from(edge_cost_fn);
let goal_index: Option<NodeIndex> = goal.map(NodeIndex::new);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let res: Vec<Option<f64>> = dijkstra(
&graph.graph,
start,
Expand Down Expand Up @@ -445,6 +463,12 @@ pub fn digraph_dijkstra_shortest_path_lengths(
let start = NodeIndex::new(node);
let goal_index: Option<NodeIndex> = goal.map(NodeIndex::new);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let res: Vec<Option<f64>> = dijkstra(
&graph.graph,
start,
Expand Down Expand Up @@ -671,6 +695,12 @@ pub fn digraph_astar_shortest_path(
let estimate_cost_callable = CostFn::from(estimate_cost_fn);
let start = NodeIndex::new(node);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let astar_res = astar(
&graph.graph,
start,
Expand Down Expand Up @@ -729,6 +759,12 @@ pub fn graph_astar_shortest_path(
let estimate_cost_callable = CostFn::from(estimate_cost_fn);
let start = NodeIndex::new(node);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let astar_res = astar(
&graph.graph,
start,
Expand Down Expand Up @@ -1561,6 +1597,12 @@ pub fn digraph_bellman_ford_shortest_path_lengths(

let start = NodeIndex::new(node);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let res: Option<Vec<Option<f64>>> =
bellman_ford(&graph.graph, start, |e| edge_cost(e.id()), None)?;

Expand Down Expand Up @@ -1640,6 +1682,12 @@ pub fn graph_bellman_ford_shortest_path_lengths(

let start = NodeIndex::new(node);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{node}\" out of graph bound"
)));
}

let res: Option<Vec<Option<f64>>> =
bellman_ford(&graph.graph, start, |e| edge_cost(e.id()), None)?;

Expand Down Expand Up @@ -1715,6 +1763,13 @@ pub fn graph_bellman_ford_shortest_paths(
default_weight: f64,
) -> PyResult<PathMapping> {
let start = NodeIndex::new(source);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{source}\" out of graph bound"
)));
}

let mut paths: DictMap<NodeIndex, Vec<NodeIndex>> = DictMap::with_capacity(graph.node_count());

let edge_weights: Vec<Option<f64>> =
Expand Down Expand Up @@ -1801,6 +1856,13 @@ pub fn digraph_bellman_ford_shortest_paths(
}

let start = NodeIndex::new(source);

if !graph.graph.contains_node(start) {
return Err(PyIndexError::new_err(format!(
"Node source index \"{source}\" out of graph bound"
)));
}

let mut paths: DictMap<NodeIndex, Vec<NodeIndex>> = DictMap::with_capacity(graph.node_count());

let edge_weights: Vec<Option<f64>> =
Expand Down
14 changes: 14 additions & 0 deletions tests/digraph/test_astar.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,17 @@ def test_astar_with_invalid_weights(self):
edge_cost_fn=lambda _: invalid_weight,
estimate_cost_fn=lambda _: 0,
)

def test_astar_with_invalid_source_node(self):
g = rustworkx.PyDAG()
a = g.add_node("A")
b = g.add_node("B")
g.add_edge(a, b, 7)
with self.assertRaises(IndexError):
rustworkx.digraph_astar_shortest_path(
g,
len(g.node_indices()) + 1,
goal_fn=lambda goal: goal == "B",
edge_cost_fn=lambda x: float(x),
estimate_cost_fn=lambda _: 0,
)
12 changes: 12 additions & 0 deletions tests/digraph/test_bellman_ford.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,15 @@ def test_raises_negative_cycle_all_pairs_bellman_ford_path_lenghts(self):

with self.assertRaises(rustworkx.NegativeCycle):
rustworkx.all_pairs_bellman_ford_path_lengths(graph, float)

def test_raises_index_error_bellman_ford_paths(self):
with self.assertRaises(IndexError):
rustworkx.digraph_bellman_ford_shortest_paths(
self.graph, len(self.graph.node_indices()) + 1, weight_fn=lambda x: float(x)
)

def test_raises_index_error_bellman_ford_path_lenghts(self):
with self.assertRaises(IndexError):
rustworkx.digraph_bellman_ford_shortest_path_lengths(
self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: float(x)
)
10 changes: 10 additions & 0 deletions tests/digraph/test_dijkstra.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,13 @@ def all_pairs_dijkstra_lenghts_with_invalid_weights(self):
rustworkx.digraph_all_pairs_dijkstra_path_lengths(
graph, edge_cost_fn=lambda _: invalid_weight
)

def test_dijkstra_path_digraph_with_invalid_source(self):
with self.assertRaises(IndexError):
rustworkx.dijkstra_shortest_paths(self.graph, len(self.graph.node_indices()) + 1)

def test_dijkstra_path_digraph_lengths_with_invalid_source(self):
with self.assertRaises(IndexError):
rustworkx.dijkstra_shortest_path_lengths(
self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: x
)
14 changes: 14 additions & 0 deletions tests/graph/test_astar.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,17 @@ def test_astar_with_invalid_weights(self):
edge_cost_fn=lambda _: invalid_weight,
estimate_cost_fn=lambda _: 0,
)

def test_astar_with_invalid_source_node(self):
g = rustworkx.PyGraph()
a = g.add_node("A")
b = g.add_node("B")
g.add_edge(a, b, 7)
with self.assertRaises(IndexError):
rustworkx.graph_astar_shortest_path(
g,
len(g.node_indices()) + 1,
goal_fn=lambda goal: goal == "B",
edge_cost_fn=lambda x: float(x),
estimate_cost_fn=lambda _: 0,
)
12 changes: 12 additions & 0 deletions tests/graph/test_bellman_ford.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,15 @@ def test_raises_negative_cycle_all_pairs_bellman_ford_path_lenghts(self):

with self.assertRaises(rustworkx.NegativeCycle):
rustworkx.all_pairs_bellman_ford_path_lengths(graph, float)

def test_raises_index_error_bellman_ford_paths(self):
with self.assertRaises(IndexError):
rustworkx.graph_bellman_ford_shortest_paths(
self.graph, len(self.graph.node_indices()) + 1, weight_fn=lambda x: float(x)
)

def test_raises_index_error_bellman_ford_path_lenghts(self):
with self.assertRaises(IndexError):
rustworkx.graph_bellman_ford_shortest_path_lengths(
self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=lambda x: float(x)
)
10 changes: 10 additions & 0 deletions tests/graph/test_dijkstra.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ def dijkstra_with_invalid_weights(self):
as_undirected=as_undirected,
)

def test_dijkstra_path_with_invalid_source(self):
with self.assertRaises(IndexError):
rustworkx.dijkstra_shortest_paths(self.graph, len(self.graph.node_indices()) + 1)

def test_dijkstra_path_lengths_with_invalid_source(self):
with self.assertRaises(IndexError):
rustworkx.dijkstra_shortest_path_lengths(
self.graph, len(self.graph.node_indices()) + 1, edge_cost_fn=float
)

def dijkstra_lengths_with_invalid_weights(self):
graph = rustworkx.generators.path_graph(2)
for invalid_weight in [float("nan"), -1]:
Expand Down
16 changes: 5 additions & 11 deletions tests/graph/test_max_weight_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,11 @@ def get_nx_weight(edge):
if (u, v) not in nx_matches:
if (v, u) not in nx_matches:
print(
"seed {} failed. Element {} and it's "
"reverse {} not found in networkx output.\nrustworkx"
" output: {}\nnetworkx output: {}\nedge list: {}\n"
"falling back to checking for a valid solution".format(
seed,
(u, v),
(v, u),
rx_matches,
nx_matches,
list(rx_graph.weighted_edge_list()),
)
f"seed {seed} failed. Element {(u, v)} and it's "
f"reverse {(v, u)} not found in networkx output.\nrustworkx"
f" output: {rx_matches}\nnetworkx output: {nx_matches}"
f"\nedge list: {list(rx_graph.weighted_edge_list())}\n"
"falling back to checking for a valid solution"
)
not_match = True
break
Expand Down
Loading