Skip to content

Commit

Permalink
Add edge_indices methods to graph objects
Browse files Browse the repository at this point in the history
This commit adds a new method to the PyDiGraph (and PyDAG) and PyGraph
classes, node_indices, which returns a list of a edge indices for all
edges in the graph.
  • Loading branch information
mtreinish committed May 18, 2021
1 parent d052f77 commit 32a5aa8
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ Custom Return Types

retworkx.BFSSuccessors
retworkx.NodeIndices
retworkx.EdgeIndices
retworkx.EdgeList
retworkx.WeightedEdgeList
retworkx.PathMapping
Expand Down
12 changes: 12 additions & 0 deletions releasenotes/notes/add-edge-indices-method-c1868ab1dab61b18.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
features:
- |
Added a new method, :meth:`~retworkx.PyDiGraph.edge_indices`, to the
:class:`~retworkx.PyDiGraph` and :class:`~retworkx.PyGraph`
(:meth:`~retworkx.PyGraph.edge_indices`) that will return a list of
edge indices for every edge in the graph object.
- |
Added a new custom return type :class:`~retworkx.EdgeIndices` which is
returned by :class:`retworkx.PyDiGraph.edge_indices` and
:class:`retworkx.PyGraph.edge_indices`. It is equivalent to a read-only
list of integers that represent a list of edge indices.
13 changes: 12 additions & 1 deletion src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use petgraph::visit::{
};

use super::dot_utils::build_dot;
use super::iterators::{EdgeList, NodeIndices, WeightedEdgeList};
use super::iterators::{EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList};
use super::{
is_directed_acyclic_graph, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes,
NoSuitableNeighbors, NodesRemoved,
Expand Down Expand Up @@ -534,6 +534,17 @@ impl PyDiGraph {
.collect()
}

/// Return a list of all edge indices.
///
/// :returns: A list of all the edge indices in the graph
/// :rtype: EdgeIndices
#[text_signature = "(self)"]
pub fn edge_indices(&self) -> EdgeIndices {
EdgeIndices {
edges: self.graph.edge_indices().map(|edge| edge.index()).collect(),
}
}

/// Return a list of all node data.
///
/// :returns: A list of all the node data objects in the graph
Expand Down
13 changes: 12 additions & 1 deletion src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use ndarray::prelude::*;
use numpy::PyReadonlyArray2;

use super::dot_utils::build_dot;
use super::iterators::{EdgeList, NodeIndices, WeightedEdgeList};
use super::iterators::{EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList};
use super::{NoEdgeBetweenNodes, NodesRemoved};

use petgraph::graph::{EdgeIndex, NodeIndex};
Expand Down Expand Up @@ -382,6 +382,17 @@ impl PyGraph {
.collect()
}

/// Return a list of all edge indices.
///
/// :returns: A list of all the edge indices in the graph
/// :rtype: EdgeIndices
#[text_signature = "(self)"]
pub fn edge_indices(&self) -> EdgeIndices {
EdgeIndices {
edges: self.graph.edge_indices().map(|edge| edge.index()).collect(),
}
}

/// Return a list of all node data.
///
/// :returns: A list of all the node data objects in the graph
Expand Down
115 changes: 115 additions & 0 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1326,3 +1326,118 @@ impl PyIterProtocol for PathLengthMappingItems {
}
}
}

/// A custom class for the return of edge indices
///
/// This class is a container class for the results of functions that
/// return a list of edge indices. It implements the Python sequence
/// protocol. So you can treat the return as a read-only sequence/list
/// that is integer indexed. If you want to use it as an iterator you
/// can by wrapping it in an ``iter()`` that will yield the results in
/// order.
///
/// For example::
///
/// import retworkx
///
/// graph = retworkx.generators.directed_path_graph(5)
/// edges = retworkx.edge_indexes(0)
/// # Index based access
/// third_element = edges[2]
/// # Use as iterator
/// edges_iter = iter(edges)
/// first_element = next(edges_iter)
/// second_element = next(edges_iter)
///
#[pyclass(module = "retworkx", gc)]
#[derive(Clone)]
pub struct EdgeIndices {
pub edges: Vec<usize>,
}

#[pymethods]
impl EdgeIndices {
#[new]
fn new() -> EdgeIndices {
EdgeIndices { edges: Vec::new() }
}

fn __getstate__(&self) -> Vec<usize> {
self.edges.clone()
}

fn __setstate__(&mut self, state: Vec<usize>) {
self.edges = state;
}
}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for EdgeIndices {
fn __richcmp__(
&self,
other: &'p PySequence,
op: pyo3::basic::CompareOp,
) -> PyResult<bool> {
let compare = |other: &PySequence| -> PyResult<bool> {
if other.len()? as usize != self.edges.len() {
return Ok(false);
}
for i in 0..self.edges.len() {
let other_raw = other.get_item(i.try_into().unwrap())?;
let other_value: usize = other_raw.extract()?;
if other_value != self.edges[i] {
return Ok(false);
}
}
Ok(true)
};
match op {
pyo3::basic::CompareOp::Eq => compare(other),
pyo3::basic::CompareOp::Ne => match compare(other) {
Ok(res) => Ok(!res),
Err(err) => Err(err),
},
_ => Err(PyNotImplementedError::new_err(
"Comparison not implemented",
)),
}
}

fn __str__(&self) -> PyResult<String> {
let str_vec: Vec<String> =
self.edges.iter().map(|n| format!("{}", n)).collect();
Ok(format!("EdgeIndices[{}]", str_vec.join(", ")))
}

fn __hash__(&self) -> PyResult<u64> {
let mut hasher = DefaultHasher::new();
for index in &self.edges {
hasher.write_usize(*index);
}
Ok(hasher.finish())
}
}

#[pyproto]
impl PySequenceProtocol for EdgeIndices {
fn __len__(&self) -> PyResult<usize> {
Ok(self.edges.len())
}

fn __getitem__(&'p self, idx: isize) -> PyResult<usize> {
if idx >= self.edges.len().try_into().unwrap() {
Err(PyIndexError::new_err(format!("Invalid index, {}", idx)))
} else {
Ok(self.edges[idx as usize])
}
}
}

#[pyproto]
impl PyGCProtocol for EdgeIndices {
fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> {
Ok(())
}

fn __clear__(&mut self) {}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3619,6 +3619,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<graph::PyGraph>()?;
m.add_class::<iterators::BFSSuccessors>()?;
m.add_class::<iterators::NodeIndices>()?;
m.add_class::<iterators::EdgeIndices>()?;
m.add_class::<iterators::EdgeList>()?;
m.add_class::<iterators::WeightedEdgeList>()?;
m.add_class::<iterators::PathMapping>()?;
Expand Down
12 changes: 12 additions & 0 deletions tests/digraph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def test_edges_empty(self):
dag.add_node("a")
self.assertEqual([], dag.edges())

def test_edge_indices(self):
dag = retworkx.PyDAG()
node_a = dag.add_node("a")
node_b = dag.add_child(node_a, "b", "Edgy")
dag.add_child(node_b, "c", "Super edgy")
self.assertEqual([0, 1], dag.edge_indices())

def test_edge_indices_empty(self):
dag = retworkx.PyDAG()
dag.add_node("a")
self.assertEqual([], dag.edge_indices())

def test_add_duplicates(self):
dag = retworkx.PyDAG()
node_a = dag.add_node("a")
Expand Down
14 changes: 14 additions & 0 deletions tests/graph/test_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ def test_edges_empty(self):
graph.add_node("a")
self.assertEqual([], graph.edges())

def test_edge_indices(self):
graph = retworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
graph.add_edge(node_a, node_b, "Edgy")
node_c = graph.add_node("c")
graph.add_edge(node_b, node_c, "Super edgy")
self.assertEqual([0, 1], graph.edge_indices())

def test_get_edge_indices_empty(self):
graph = retworkx.PyGraph()
graph.add_node("a")
self.assertEqual([], graph.edge_indices())

def test_add_duplicates(self):
graph = retworkx.PyGraph()
node_a = graph.add_node("a")
Expand Down
60 changes: 60 additions & 0 deletions tests/test_custom_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,66 @@ def test_hash(self):
self.assertEqual(hash_res, hash(res))


class TestEdgeIndicesComparisons(unittest.TestCase):
def setUp(self):
self.dag = retworkx.PyDiGraph()
node_a = self.dag.add_node("a")
node_b = self.dag.add_child(node_a, "b", "Edgy")
self.dag.add_child(node_b, "c", "Super Edgy")

def test__eq__match(self):
self.assertTrue(self.dag.edge_indices() == [0, 1])

def test__eq__not_match(self):
self.assertFalse(self.dag.edge_indices() == [1, 2])

def test__eq__different_length(self):
self.assertFalse(self.dag.edge_indices() == [0, 1, 2, 3])

def test__eq__invalid_type(self):
with self.assertRaises(TypeError):
self.dag.edge_indices() == ["a", None]

def test__ne__match(self):
self.assertFalse(self.dag.edge_indices() != [0, 1])

def test__ne__not_match(self):
self.assertTrue(self.dag.edge_indices() != [1, 2])

def test__ne__different_length(self):
self.assertTrue(self.dag.edge_indices() != [0, 1, 2, 3])

def test__ne__invalid_type(self):
with self.assertRaises(TypeError):
self.dag.edge_indices() != ["a", None]

def test__gt__not_implemented(self):
with self.assertRaises(NotImplementedError):
self.dag.edge_indices() > [2, 1]

def test_deepcopy(self):
edges = self.dag.edge_indices()
edges_copy = copy.deepcopy(edges)
self.assertEqual(edges, edges_copy)

def test_pickle(self):
edges = self.dag.edge_indices()
edges_pickle = pickle.dumps(edges)
edges_copy = pickle.loads(edges_pickle)
self.assertEqual(edges, edges_copy)

def test_str(self):
res = self.dag.edge_indices()
self.assertEqual("EdgeIndices[0, 1]", str(res))

def test_hash(self):
res = self.dag.edge_indices()
hash_res = hash(res)
self.assertIsInstance(hash_res, int)
# Assert hash is stable
self.assertEqual(hash_res, hash(res))


class TestEdgeListComparisons(unittest.TestCase):
def setUp(self):
self.dag = retworkx.PyDAG()
Expand Down

0 comments on commit 32a5aa8

Please sign in to comment.