Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reverse inplace function for digraph #853

Merged
1 change: 1 addition & 0 deletions rustworkx/digraph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class PyDiGraph(Generic[S, T]):
deliminator: Optional[str] = ...,
weight_fn: Optional[Callable[[T], str]] = ...,
) -> None: ...
def reverse(self) -> None: ...
def __delitem__(self, idx: int, /) -> None: ...
def __getitem__(self, idx: int, /) -> S: ...
def __getstate__(self) -> Any: ...
Expand Down
50 changes: 50 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2720,6 +2720,56 @@ impl PyDiGraph {
self.clone()
}

/// Reverse the direction of all edges in the graph, in place.
///
/// This method modifies the graph instance to reverse the direction of all edges.
/// It does so by iterating over all edges in the graph and removing each edge,
/// then adding a new edge in the opposite direction with the same weight.
///
/// # Arguments
///
/// * `py`: A reference to the Python interpreter.
///
/// # Returns
///
/// None.
///
/// # Panics
///
/// This method will panic if the edge indices or weights are not valid.
///
/// # Examples
///
/// ```
/// use rustworkx::{DiGraph, WeightedGraph};
///
/// let mut graph = DiGraph::<(), i32>::new();
/// graph.add_edge(0, 1, 3).unwrap();
/// graph.add_edge(1, 2, 5).unwrap();
/// graph.add_edge(2, 3, 2).unwrap();
///
/// graph.reverse(py);
///
/// assert_eq!(graph.edges().collect::<Vec<_>>(), vec![(3, 2), (2, 1), (1, 0)]);
/// ```
#[pyo3(text_signature = "(self)")]
pub fn reverse(&mut self, py: Python) {
let indices = self.graph.edge_indices().collect::<Vec<EdgeIndex>>();
for idx in indices {
let (source_node, dest_node) = self
.graph
.edge_endpoints(idx)
.unwrap();
let weight = self
.graph
.edge_weight(idx)
.unwrap()
.clone_ref(py);
self.graph.remove_edge(idx);
self.graph.add_edge(dest_node, source_node, weight);
}
}

/// Return the number of nodes in the graph
fn __len__(&self) -> PyResult<usize> {
Ok(self.graph.node_count())
Expand Down
41 changes: 41 additions & 0 deletions tests/rustworkx_tests/digraph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,44 @@ def test_extend_from_weighted_edge_list(self):
graph.extend_from_weighted_edge_list(edge_list)
self.assertEqual(len(graph), 4)
self.assertEqual(["a", "b", "c", "d", "e"], graph.edges())

def test_reverse_graph(self):
graph = rustworkx.PyDiGraph()
graph.add_nodes_from([i for i in range(4)])
edge_list = [
(0, 1, "a"),
(1, 2, "b"),
(0, 2, "c"),
(2, 3, "d"),
(0, 3, "e"),
]
graph.add_edges_from(edge_list)
graph.reverse()
self.assertEqual([(1, 0), (2, 1), (2, 0), (3, 2), (3, 0)], graph.edge_list())

def test_reverse_large_graph(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two small test cases that I think would be good to add is one with an empty graph. The other is for a graph where there is an edge removed in the middle (which will leave a hole in the edge indices). I don't expect there are any problems with the method's behavior with these, but they're common edge cases that are good to verify (because a lot of other methods have had issues in the past with both).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey Sorry for the late response and update, I've added the tests

LARGE_AMOUNT_OF_NODES = 10000000

graph = rustworkx.PyDiGraph()
graph.add_nodes_from(range(LARGE_AMOUNT_OF_NODES))
edge_list = list(zip(range(LARGE_AMOUNT_OF_NODES), range(1, LARGE_AMOUNT_OF_NODES)))
weighted_edge_list = [(s, d, "a") for s, d in edge_list]
graph.add_edges_from(weighted_edge_list)
graph.reverse()
reversed_edge_list = [(d, s) for s, d in edge_list]
self.assertEqual(reversed_edge_list, graph.edge_list())

def test_reverse_empty_graph(self):
graph = rustworkx.PyDiGraph()
edges_before = graph.edge_list()
graph.reverse()
self.assertEqual(graph.edge_list(), edges_before)

def test_removed_middle_node_reverse(self):
graph = rustworkx.PyDiGraph()
graph.add_nodes_from(list(range(5)))
edge_list = [(0, 1), (2, 1), (1, 3), (3, 4), (4, 0)]
graph.extend_from_edge_list(edge_list)
graph.remove_node(1)
graph.reverse()
self.assertEqual(graph.edge_list(), [(4, 3), (0, 4)])