diff --git a/docs/source/api.rst b/docs/source/api.rst index 01b9ac88c2..b8532017e0 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -174,6 +174,7 @@ Custom Return Types retworkx.BFSSuccessors retworkx.NodeIndices + retworkx.EdgeIndices retworkx.EdgeList retworkx.WeightedEdgeList retworkx.PathMapping diff --git a/releasenotes/notes/add-edge-indices-method-c1868ab1dab61b18.yaml b/releasenotes/notes/add-edge-indices-method-c1868ab1dab61b18.yaml new file mode 100644 index 0000000000..c39cc7eb2e --- /dev/null +++ b/releasenotes/notes/add-edge-indices-method-c1868ab1dab61b18.yaml @@ -0,0 +1,12 @@ +--- +features: + - | + Added a new method, :meth:`~retworkx.PyDiGraph.edge_indices`, to the + :class:`~retworkx.PyDiGraph` and :class:`~retworkx.PyGraph` + (:meth:`~retworkx.PyGraph.edge_indices`) that will return a list of + edge indices for every edge in the graph object. + - | + Added a new custom return type :class:`~retworkx.EdgeIndices` which is + returned by :class:`retworkx.PyDiGraph.edge_indices` and + :class:`retworkx.PyGraph.edge_indices`. It is equivalent to a read-only + list of integers that represent a list of edge indices. diff --git a/src/digraph.rs b/src/digraph.rs index 536dfb5d3e..3083f7a077 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -46,7 +46,7 @@ use petgraph::visit::{ }; use super::dot_utils::build_dot; -use super::iterators::{EdgeList, NodeIndices, WeightedEdgeList}; +use super::iterators::{EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList}; use super::{ is_directed_acyclic_graph, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes, NoSuitableNeighbors, NodesRemoved, @@ -534,6 +534,17 @@ impl PyDiGraph { .collect() } + /// Return a list of all edge indices. + /// + /// :returns: A list of all the edge indices in the graph + /// :rtype: EdgeIndices + #[text_signature = "(self)"] + pub fn edge_indices(&self) -> EdgeIndices { + EdgeIndices { + edges: self.graph.edge_indices().map(|edge| edge.index()).collect(), + } + } + /// Return a list of all node data. /// /// :returns: A list of all the node data objects in the graph diff --git a/src/graph.rs b/src/graph.rs index 32a9b7c538..914771a0be 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -32,7 +32,7 @@ use ndarray::prelude::*; use numpy::PyReadonlyArray2; use super::dot_utils::build_dot; -use super::iterators::{EdgeList, NodeIndices, WeightedEdgeList}; +use super::iterators::{EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList}; use super::{NoEdgeBetweenNodes, NodesRemoved}; use petgraph::graph::{EdgeIndex, NodeIndex}; @@ -382,6 +382,17 @@ impl PyGraph { .collect() } + /// Return a list of all edge indices. + /// + /// :returns: A list of all the edge indices in the graph + /// :rtype: EdgeIndices + #[text_signature = "(self)"] + pub fn edge_indices(&self) -> EdgeIndices { + EdgeIndices { + edges: self.graph.edge_indices().map(|edge| edge.index()).collect(), + } + } + /// Return a list of all node data. /// /// :returns: A list of all the node data objects in the graph diff --git a/src/iterators.rs b/src/iterators.rs index 58f5f10832..d8ffb9041f 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -1326,3 +1326,118 @@ impl PyIterProtocol for PathLengthMappingItems { } } } + +/// A custom class for the return of edge indices +/// +/// This class is a container class for the results of functions that +/// return a list of edge indices. It implements the Python sequence +/// protocol. So you can treat the return as a read-only sequence/list +/// that is integer indexed. If you want to use it as an iterator you +/// can by wrapping it in an ``iter()`` that will yield the results in +/// order. +/// +/// For example:: +/// +/// import retworkx +/// +/// graph = retworkx.generators.directed_path_graph(5) +/// edges = retworkx.edge_indexes(0) +/// # Index based access +/// third_element = edges[2] +/// # Use as iterator +/// edges_iter = iter(edges) +/// first_element = next(edges_iter) +/// second_element = next(edges_iter) +/// +#[pyclass(module = "retworkx", gc)] +#[derive(Clone)] +pub struct EdgeIndices { + pub edges: Vec, +} + +#[pymethods] +impl EdgeIndices { + #[new] + fn new() -> EdgeIndices { + EdgeIndices { edges: Vec::new() } + } + + fn __getstate__(&self) -> Vec { + self.edges.clone() + } + + fn __setstate__(&mut self, state: Vec) { + self.edges = state; + } +} + +#[pyproto] +impl<'p> PyObjectProtocol<'p> for EdgeIndices { + fn __richcmp__( + &self, + other: &'p PySequence, + op: pyo3::basic::CompareOp, + ) -> PyResult { + let compare = |other: &PySequence| -> PyResult { + if other.len()? as usize != self.edges.len() { + return Ok(false); + } + for i in 0..self.edges.len() { + let other_raw = other.get_item(i.try_into().unwrap())?; + let other_value: usize = other_raw.extract()?; + if other_value != self.edges[i] { + return Ok(false); + } + } + Ok(true) + }; + match op { + pyo3::basic::CompareOp::Eq => compare(other), + pyo3::basic::CompareOp::Ne => match compare(other) { + Ok(res) => Ok(!res), + Err(err) => Err(err), + }, + _ => Err(PyNotImplementedError::new_err( + "Comparison not implemented", + )), + } + } + + fn __str__(&self) -> PyResult { + let str_vec: Vec = + self.edges.iter().map(|n| format!("{}", n)).collect(); + Ok(format!("EdgeIndices[{}]", str_vec.join(", "))) + } + + fn __hash__(&self) -> PyResult { + let mut hasher = DefaultHasher::new(); + for index in &self.edges { + hasher.write_usize(*index); + } + Ok(hasher.finish()) + } +} + +#[pyproto] +impl PySequenceProtocol for EdgeIndices { + fn __len__(&self) -> PyResult { + Ok(self.edges.len()) + } + + fn __getitem__(&'p self, idx: isize) -> PyResult { + if idx >= self.edges.len().try_into().unwrap() { + Err(PyIndexError::new_err(format!("Invalid index, {}", idx))) + } else { + Ok(self.edges[idx as usize]) + } + } +} + +#[pyproto] +impl PyGCProtocol for EdgeIndices { + fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> { + Ok(()) + } + + fn __clear__(&mut self) {} +} diff --git a/src/lib.rs b/src/lib.rs index 4e36f1b5c2..88e629d26c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3619,6 +3619,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tests/digraph/test_edges.py b/tests/digraph/test_edges.py index b1991e95e2..42293551e3 100644 --- a/tests/digraph/test_edges.py +++ b/tests/digraph/test_edges.py @@ -104,6 +104,18 @@ def test_edges_empty(self): dag.add_node("a") self.assertEqual([], dag.edges()) + def test_edge_indices(self): + dag = retworkx.PyDAG() + node_a = dag.add_node("a") + node_b = dag.add_child(node_a, "b", "Edgy") + dag.add_child(node_b, "c", "Super edgy") + self.assertEqual([0, 1], dag.edge_indices()) + + def test_edge_indices_empty(self): + dag = retworkx.PyDAG() + dag.add_node("a") + self.assertEqual([], dag.edge_indices()) + def test_add_duplicates(self): dag = retworkx.PyDAG() node_a = dag.add_node("a") diff --git a/tests/graph/test_edges.py b/tests/graph/test_edges.py index 4573197d79..3532fb2919 100644 --- a/tests/graph/test_edges.py +++ b/tests/graph/test_edges.py @@ -120,6 +120,20 @@ def test_edges_empty(self): graph.add_node("a") self.assertEqual([], graph.edges()) + def test_edge_indices(self): + graph = retworkx.PyGraph() + node_a = graph.add_node("a") + node_b = graph.add_node("b") + graph.add_edge(node_a, node_b, "Edgy") + node_c = graph.add_node("c") + graph.add_edge(node_b, node_c, "Super edgy") + self.assertEqual([0, 1], graph.edge_indices()) + + def test_get_edge_indices_empty(self): + graph = retworkx.PyGraph() + graph.add_node("a") + self.assertEqual([], graph.edge_indices()) + def test_add_duplicates(self): graph = retworkx.PyGraph() node_a = graph.add_node("a") diff --git a/tests/test_custom_return_types.py b/tests/test_custom_return_types.py index 138a6aee16..834a554a8f 100644 --- a/tests/test_custom_return_types.py +++ b/tests/test_custom_return_types.py @@ -151,6 +151,66 @@ def test_hash(self): self.assertEqual(hash_res, hash(res)) +class TestEdgeIndicesComparisons(unittest.TestCase): + def setUp(self): + self.dag = retworkx.PyDiGraph() + node_a = self.dag.add_node("a") + node_b = self.dag.add_child(node_a, "b", "Edgy") + self.dag.add_child(node_b, "c", "Super Edgy") + + def test__eq__match(self): + self.assertTrue(self.dag.edge_indices() == [0, 1]) + + def test__eq__not_match(self): + self.assertFalse(self.dag.edge_indices() == [1, 2]) + + def test__eq__different_length(self): + self.assertFalse(self.dag.edge_indices() == [0, 1, 2, 3]) + + def test__eq__invalid_type(self): + with self.assertRaises(TypeError): + self.dag.edge_indices() == ["a", None] + + def test__ne__match(self): + self.assertFalse(self.dag.edge_indices() != [0, 1]) + + def test__ne__not_match(self): + self.assertTrue(self.dag.edge_indices() != [1, 2]) + + def test__ne__different_length(self): + self.assertTrue(self.dag.edge_indices() != [0, 1, 2, 3]) + + def test__ne__invalid_type(self): + with self.assertRaises(TypeError): + self.dag.edge_indices() != ["a", None] + + def test__gt__not_implemented(self): + with self.assertRaises(NotImplementedError): + self.dag.edge_indices() > [2, 1] + + def test_deepcopy(self): + edges = self.dag.edge_indices() + edges_copy = copy.deepcopy(edges) + self.assertEqual(edges, edges_copy) + + def test_pickle(self): + edges = self.dag.edge_indices() + edges_pickle = pickle.dumps(edges) + edges_copy = pickle.loads(edges_pickle) + self.assertEqual(edges, edges_copy) + + def test_str(self): + res = self.dag.edge_indices() + self.assertEqual("EdgeIndices[0, 1]", str(res)) + + def test_hash(self): + res = self.dag.edge_indices() + hash_res = hash(res) + self.assertIsInstance(hash_res, int) + # Assert hash is stable + self.assertEqual(hash_res, hash(res)) + + class TestEdgeListComparisons(unittest.TestCase): def setUp(self): self.dag = retworkx.PyDAG()