Skip to content

Commit

Permalink
Add edge_index_map() method to graph classes (#324)
Browse files Browse the repository at this point in the history
* Add edge_index_map() method to graph classes

This commit adds a new function edge_index_map() which will return a
mapping from edge indices to a tuple of the node indices and weight for
the edge.

* Fix lint

* Add release note

* Add custom return type tests

* Add tests for method

* Apply suggestions from code review

Co-authored-by: Ali Javadi-Abhari <[email protected]>
  • Loading branch information
mtreinish and ajavadia authored May 27, 2021
1 parent 252adee commit 885f2de
Show file tree
Hide file tree
Showing 9 changed files with 492 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ Custom Return Types
retworkx.EdgeIndices
retworkx.EdgeList
retworkx.WeightedEdgeList
retworkx.EdgeIndexMap
retworkx.PathMapping
retworkx.PathLengthMapping
retworkx.Pos2DMapping
Expand Down
14 changes: 14 additions & 0 deletions releasenotes/notes/edge-index-map-cf07a035d02481a1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
features:
- |
Added a new method, :meth:`~retworkx.PyDiGraph.edge_index_map`, to the
:class:`~retworkx.PyDiGraph` and :class:`~retworkx.PyGraph`
(:meth:`~retworkx.PyGraph.edge_index_map`) that will return a read-only
mapping of edge indices to a tuple of the form
``(source_node_index, target_node_index, weight/data payload)`` for every edge in the
graph object.
- |
Added a new custom return type :class:`~retworkx.EdgeIndexMap` which is
returned by :meth:`retworkx.PyDiGraph.edge_index_map` and
:meth:`retworkx.PyGraph.edge_index_map`. It is equivalent to a read-only
dict/mapping that represent a mapping of edge indices to the edge.
31 changes: 30 additions & 1 deletion src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ use petgraph::visit::{
};

use super::dot_utils::build_dot;
use super::iterators::{EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList};
use super::iterators::{
EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList,
};
use super::{
is_directed_acyclic_graph, DAGHasCycle, DAGWouldCycle, NoEdgeBetweenNodes,
NoSuitableNeighbors, NodesRemoved,
Expand Down Expand Up @@ -902,6 +904,33 @@ impl PyDiGraph {
}
}

/// Get an edge index map
///
/// Returns a read only mapping from edge indices to the weighted edge
/// tuple. The return is a mapping of the form:
/// ``{0: (0, 1, "weight"), 1: (2, 3, 2.3)}``
///
/// :returns: An edge index map
/// :rtype: EdgeIndexMap
#[text_signature = "(self)"]
pub fn edge_index_map(&self, py: Python) -> EdgeIndexMap {
EdgeIndexMap {
edge_map: self
.edge_references()
.map(|edge| {
(
edge.id().index(),
(
edge.source().index(),
edge.target().index(),
edge.weight().clone_ref(py),
),
)
})
.collect(),
}
}

/// Remove a node from the graph.
///
/// :param int node: The index of the node to remove. If the index is not
Expand Down
31 changes: 30 additions & 1 deletion src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ use ndarray::prelude::*;
use numpy::PyReadonlyArray2;

use super::dot_utils::build_dot;
use super::iterators::{EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList};
use super::iterators::{
EdgeIndexMap, EdgeIndices, EdgeList, NodeIndices, WeightedEdgeList,
};
use super::{NoEdgeBetweenNodes, NodesRemoved};

use petgraph::graph::{EdgeIndex, NodeIndex};
Expand Down Expand Up @@ -617,6 +619,33 @@ impl PyGraph {
}
}

/// Get an edge index map
///
/// Returns a read only mapping from edge indices to the weighted edge
/// tuple. The return is a mapping of the form:
/// ``{0: (0, 1, "weight"), 1: (2, 3, 2.3)}``
///
/// :returns: An edge index map
/// :rtype: EdgeIndexMap
#[text_signature = "(self)"]
pub fn edge_index_map(&self, py: Python) -> EdgeIndexMap {
EdgeIndexMap {
edge_map: self
.edge_references()
.map(|edge| {
(
edge.id().index(),
(
edge.source().index(),
edge.target().index(),
edge.weight().clone_ref(py),
),
)
})
.collect(),
}
}

/// Remove a node from the graph.
///
/// :param int node: The index of the node to remove. If the index is not
Expand Down
263 changes: 263 additions & 0 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,269 @@ impl PyIterProtocol for PathLengthMappingItems {
}
}

/// A class representing a mapping of edge indices to a tuple of node indices
/// and weight/data payload
///
/// This class is equivalent to having a read only dict of the form::
///
/// {1: (0, 1, "weight'), 3: (2, 3, 1.2)}
///
/// It is used to efficiently represent an edge index map for a retworkx
/// graph. It behaves as a drop in replacement for a readonly ``dict``.
#[pyclass(module = "retworkx", gc)]
pub struct EdgeIndexMap {
pub edge_map: HashMap<usize, (usize, usize, PyObject)>,
}

#[pymethods]
impl EdgeIndexMap {
#[new]
fn new() -> EdgeIndexMap {
EdgeIndexMap {
edge_map: HashMap::new(),
}
}

fn __getstate__(&self) -> HashMap<usize, (usize, usize, PyObject)> {
self.edge_map.clone()
}

fn __setstate__(
&mut self,
state: HashMap<usize, (usize, usize, PyObject)>,
) {
self.edge_map = state;
}

fn keys(&self) -> EdgeIndexMapKeys {
EdgeIndexMapKeys {
edge_map_keys: self.edge_map.keys().copied().collect(),
iter_pos: 0,
}
}

fn values(&self) -> EdgeIndexMapValues {
EdgeIndexMapValues {
edge_map_values: self.edge_map.values().cloned().collect(),
iter_pos: 0,
}
}

fn items(&self) -> EdgeIndexMapItems {
let items: Vec<(usize, (usize, usize, PyObject))> =
self.edge_map.iter().map(|(k, v)| (*k, v.clone())).collect();
EdgeIndexMapItems {
edge_map_items: items,
iter_pos: 0,
}
}
}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for EdgeIndexMap {
fn __richcmp__(
&self,
other: PyObject,
op: pyo3::basic::CompareOp,
) -> PyResult<bool> {
let compare = |other: PyObject| -> PyResult<bool> {
let gil = Python::acquire_gil();
let py = gil.python();
let other_ref = other.as_ref(py);
if other_ref.len()? != self.edge_map.len() {
return Ok(false);
}
for (key, value) in &self.edge_map {
match other_ref.get_item(key) {
Ok(other_raw) => {
let other_value: (usize, usize, PyObject) =
other_raw.extract()?;
if other_value.0 != value.0
|| other_value.1 != value.1
|| value.2.as_ref(py).compare(other_value.2)?
!= std::cmp::Ordering::Equal
{
return Ok(false);
}
}
Err(ref err)
if Python::with_gil(|py| {
err.is_instance::<PyKeyError>(py)
}) =>
{
return Ok(false);
}
Err(err) => return Err(err),
}
}
Ok(true)
};
match op {
pyo3::basic::CompareOp::Eq => compare(other),
pyo3::basic::CompareOp::Ne => match compare(other) {
Ok(res) => Ok(!res),
Err(err) => Err(err),
},
_ => Err(PyNotImplementedError::new_err(
"Comparison not implemented",
)),
}
}

fn __str__(&self) -> PyResult<String> {
let mut str_vec: Vec<String> = Vec::with_capacity(self.edge_map.len());
let gil = Python::acquire_gil();
let py = gil.python();
for path in &self.edge_map {
str_vec.push(format!(
"{}: ({}, {}, {})",
path.0,
path.1 .0,
path.1 .1,
path.1 .2.as_ref(py).str()?
));
}
Ok(format!("EdgeIndexMap{{{}}}", str_vec.join(", ")))
}

fn __hash__(&self) -> PyResult<u64> {
let mut hasher = DefaultHasher::new();
let gil = Python::acquire_gil();
let py = gil.python();
for index in &self.edge_map {
hasher.write_usize(*index.0);
hasher.write(&index.1 .0.to_be_bytes());
hasher.write(&index.1 .1.to_be_bytes());
hasher.write_isize(index.1 .2.as_ref(py).hash()?);
}
Ok(hasher.finish())
}
}

#[pyproto]
impl PySequenceProtocol for EdgeIndexMap {
fn __len__(&self) -> PyResult<usize> {
Ok(self.edge_map.len())
}

fn __contains__(&self, index: usize) -> PyResult<bool> {
Ok(self.edge_map.contains_key(&index))
}
}

#[pyproto]
impl PyMappingProtocol for EdgeIndexMap {
/// Return the number of nodes in the graph
fn __len__(&self) -> PyResult<usize> {
Ok(self.edge_map.len())
}
fn __getitem__(&'p self, idx: usize) -> PyResult<(usize, usize, PyObject)> {
match self.edge_map.get(&idx) {
Some(data) => Ok(data.clone()),
None => Err(PyIndexError::new_err("No node found for index")),
}
}
}

#[pyproto]
impl PyIterProtocol for EdgeIndexMap {
fn __iter__(slf: PyRef<Self>) -> EdgeIndexMapKeys {
EdgeIndexMapKeys {
edge_map_keys: slf.edge_map.keys().copied().collect(),
iter_pos: 0,
}
}
}

#[pyproto]
impl PyGCProtocol for EdgeIndexMap {
fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> {
for edge in &self.edge_map {
visit.call(&edge.1 .2)?;
}
Ok(())
}

fn __clear__(&mut self) {
self.edge_map = HashMap::new();
}
}

#[pyclass(module = "retworkx")]
pub struct EdgeIndexMapKeys {
pub edge_map_keys: Vec<usize>,
iter_pos: usize,
}

#[pyproto]
impl PyIterProtocol for EdgeIndexMapKeys {
fn __iter__(slf: PyRef<Self>) -> Py<EdgeIndexMapKeys> {
slf.into()
}
fn __next__(
mut slf: PyRefMut<Self>,
) -> IterNextOutput<usize, &'static str> {
if slf.iter_pos < slf.edge_map_keys.len() {
let res = IterNextOutput::Yield(slf.edge_map_keys[slf.iter_pos]);
slf.iter_pos += 1;
res
} else {
IterNextOutput::Return("Ended")
}
}
}

#[pyclass(module = "retworkx")]
pub struct EdgeIndexMapValues {
pub edge_map_values: Vec<(usize, usize, PyObject)>,
iter_pos: usize,
}

#[pyproto]
impl PyIterProtocol for EdgeIndexMapValues {
fn __iter__(slf: PyRef<Self>) -> Py<EdgeIndexMapValues> {
slf.into()
}
fn __next__(
mut slf: PyRefMut<Self>,
) -> IterNextOutput<(usize, usize, PyObject), &'static str> {
if slf.iter_pos < slf.edge_map_values.len() {
let res = IterNextOutput::Yield(
slf.edge_map_values[slf.iter_pos].clone(),
);
slf.iter_pos += 1;
res
} else {
IterNextOutput::Return("Ended")
}
}
}

#[pyclass(module = "retworkx")]
pub struct EdgeIndexMapItems {
pub edge_map_items: Vec<(usize, (usize, usize, PyObject))>,
iter_pos: usize,
}

#[pyproto]
impl PyIterProtocol for EdgeIndexMapItems {
fn __iter__(slf: PyRef<Self>) -> Py<EdgeIndexMapItems> {
slf.into()
}
fn __next__(
mut slf: PyRefMut<Self>,
) -> IterNextOutput<(usize, (usize, usize, PyObject)), &'static str> {
if slf.iter_pos < slf.edge_map_items.len() {
let res =
IterNextOutput::Yield(slf.edge_map_items[slf.iter_pos].clone());
slf.iter_pos += 1;
res
} else {
IterNextOutput::Return("Ended")
}
}
}

/// A custom class for the return of edge indices
///
/// This class is a container class for the results of functions that
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4185,6 +4185,7 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<iterators::NodeIndices>()?;
m.add_class::<iterators::EdgeIndices>()?;
m.add_class::<iterators::EdgeList>()?;
m.add_class::<iterators::EdgeIndexMap>()?;
m.add_class::<iterators::WeightedEdgeList>()?;
m.add_class::<iterators::PathMapping>()?;
m.add_class::<iterators::PathLengthMapping>()?;
Expand Down
Loading

0 comments on commit 885f2de

Please sign in to comment.