diff --git a/docs/source/api.rst b/docs/source/api.rst index f0b3dc29da..3851cb577c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -4,23 +4,18 @@ retworkx API .. py:class:: PyDAG A class for creating direct acyclic graphs. - The PyDAG class is constructed using the Rust library `daggy`_ which is - itself built on the Rust library `petgraph`_. The limitations and quirks - with both libraries dictate how this operates. The biggest thing to be - aware of when using the PyDAG class is that while node and edge indexes - are used for accessing elements on the DAG, node removal can change the - index of a node `petgraph`_. The limitations and quirks - with both libraries dictate how this operates. The biggest thing to be - aware of when using the PyDAG class is that while node and edge indexes - are used for accessing elements on the DAG, node removal can change the - indexes of nodes. Basically when a node in the middle of the dag is - removed the last index is moved to fill that spot. This means either - you have to track that event, or on node removal update the indexes for - the nodes you care about. + The PyDAG class is constructed using the Rust library `petgraph`_ around + the ``StableGraph`` type. The limitations and quirks with this library and + type dictate how this operates. The biggest thing to be aware of when using + the PyDAG class is that an integer node and edge index is used for accessing + elements on the DAG, not the data/weight of nodes and edges. .. py:method:: __init__(self): Initialize an empty DAG. + .. py:method:: __len__(self): + Return the number of nodes in the graph. Use via ``len()`` function + .. py:method:: edges(self): Return a list of all edge data. @@ -49,6 +44,14 @@ retworkx API :returns: A list of the node data for all the parent neighbor nodes :rtype: list + .. py:method:: get_node_data(self, node): + Return the node data for a given node index + + :param int node: The index for the node + + :returns: The data object set for that node + :raises IndexError: when an invalid node index is provided + .. py:method:: get_edge_data(self, node_a, node_b): Return the edge data for the edge between 2 nodes. @@ -158,6 +161,34 @@ retworkx API :raises NoEdgeBetweenNodes if the DAG is broken and an edge can't be found to a neighbor node + .. py:method:: in_edges(self, node): + Get the index and edge data for all parents of a node. + + This will return a list of tuples with the parent index the node index + and the edge data. This can be used to recreate add_edge() calls. + + :param int node: The index of the node to get the edges for + + :returns in_edges: A list of tuples of the form: + (parent_index, node_index, edge_data) + :rtype: list + :raises NoEdgeBetweenNodes if the DAG is broken and an edge can't be + found to a neighbor node + + .. py:method:: out_edges(self, node): + Get the index and edge data for all children of a node. + + This will return a list of tuples with the child index the node index + and the edge data. This can be used to recreate add_edge() calls. + + :param int node: The index of the node to get the edges for + + :returns out_edges: A list of tuples of the form: + (node_index, child_index, edge_data) + :rtype: list + :raises NoEdgeBetweenNodes if the DAG is broken and an edge can't be + found to a neighbor node + .. py:method:: in_degree(self, node): Get the degree of a node for inbound edges. @@ -166,6 +197,14 @@ retworkx API :returns degree: The inbound degree for the specified node :rtype: int + .. py:method:: out_degree(self, node): + Get the degree of a node for outbound edges. + + :param int node: The index of the node to find the outbound degree of + + :returns degree: The outbound degree for the specified node + :rtype: int + .. py:method:: remove_edge(self, parent, child): Remove an edge between 2 nodes. @@ -183,7 +222,6 @@ retworkx API :param int edge: The index of the edge to remove -.. _daggy: https://github.com/mitchmindtree/daggy .. _petgraph: https://github.com/bluss/petgraph .. py:function:: dag_longest_path_length(graph): diff --git a/src/lib.rs b/src/lib.rs index 035c86a956..6b25bd9f4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,8 +20,9 @@ use std::collections::HashMap; use std::iter; use std::ops::{Index, IndexMut}; +use pyo3::class::PyMappingProtocol; use pyo3::create_exception; -use pyo3::exceptions::Exception; +use pyo3::exceptions::{Exception, IndexError}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use pyo3::wrap_pyfunction; @@ -294,6 +295,15 @@ impl PyDAG { Ok(data) } + pub fn get_node_data(&self, node: usize) -> PyResult<&PyObject> { + let index = NodeIndex::new(node); + let node = match self.graph.node_weight(index) { + Some(node) => node, + None => return Err(IndexError::py_err("No node found for index")), + }; + Ok(node) + } + pub fn get_all_edge_data( &self, py: Python, @@ -451,6 +461,51 @@ impl PyDAG { } Ok(out_dict.into()) } + + pub fn in_edges(&mut self, py: Python, node: usize) -> PyResult { + let index = NodeIndex::new(node); + let dir = petgraph::Direction::Incoming; + let neighbors = self.graph.neighbors_directed(index, dir); + let mut out_list: Vec = Vec::new(); + for neighbor in neighbors { + let edge = match self.graph.find_edge(neighbor, index) { + Some(edge) => edge, + None => { + return Err(NoEdgeBetweenNodes::py_err( + "No edge found between nodes", + )) + } + }; + let edge_w = self.graph.edge_weight(edge); + let triplet = + (neighbor.index(), node, edge_w.unwrap()).to_object(py); + out_list.push(triplet) + } + Ok(PyList::new(py, out_list).into()) + } + + pub fn out_edges(&mut self, py: Python, node: usize) -> PyResult { + let index = NodeIndex::new(node); + let dir = petgraph::Direction::Outgoing; + let neighbors = self.graph.neighbors_directed(index, dir); + let mut out_list: Vec = Vec::new(); + for neighbor in neighbors { + let edge = match self.graph.find_edge(index, neighbor) { + Some(edge) => edge, + None => { + return Err(NoEdgeBetweenNodes::py_err( + "No edge found between nodes", + )) + } + }; + let edge_w = self.graph.edge_weight(edge); + let triplet = + (node, neighbor.index(), edge_w.unwrap()).to_object(py); + out_list.push(triplet) + } + Ok(PyList::new(py, out_list).into()) + } + // pub fn add_nodes_from(&self) -> PyResult<()> { // // } @@ -466,6 +521,20 @@ impl PyDAG { let neighbors = self.graph.neighbors_directed(index, dir); neighbors.count() } + + pub fn out_degree(&self, node: usize) -> usize { + let index = NodeIndex::new(node); + let dir = petgraph::Direction::Outgoing; + let neighbors = self.graph.neighbors_directed(index, dir); + neighbors.count() + } +} + +#[pyproto] +impl PyMappingProtocol for PyDAG { + fn __len__(&self) -> PyResult { + Ok(self.graph.node_count()) + } } fn must_check_for_cycle(dag: &PyDAG, a: NodeIndex, b: NodeIndex) -> bool { diff --git a/tests/test_adj.py b/tests/test_adj.py index 4208d8dcd5..7d152aac53 100644 --- a/tests/test_adj.py +++ b/tests/test_adj.py @@ -44,6 +44,25 @@ def test_neighbor_dir_surrounded(self): res = dag.adj_direction(node_b, True) self.assertEqual({node_a: {'a': 1}}, res) + def test_single_neighbor_dir_out_edges(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + node_b = dag.add_child(node_a, 'b', {'a': 1}) + node_c = dag.add_child(node_a, 'c', {'a': 2}) + res = dag.out_edges(node_a) + self.assertEqual([(node_a, node_c, {'a': 2}), + (node_a, node_b, {'a': 1})], res) + + def test_neighbor_dir_surrounded_in_out_edges(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + node_b = dag.add_child(node_a, 'b', {'a': 1}) + node_c = dag.add_child(node_b, 'c', {'a': 2}) + res = dag.out_edges(node_b) + self.assertEqual([(node_b, node_c, {'a': 2})], res) + res = dag.in_edges(node_b) + self.assertEqual([(node_a, node_b, {'a': 1})], res) + def test_no_neighbor(self): dag = retworkx.PyDAG() node_a = dag.add_node('a') @@ -62,3 +81,17 @@ def test_in_direction_none(self): for i in range(5): dag.add_child(node_a, i, None) self.assertEqual(0, dag.in_degree(node_a)) + + def test_out_direction(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + for i in range(5): + dag.add_parent(node_a, i, None) + self.assertEqual(0, dag.out_degree(node_a)) + + def test_out_direction_none(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + for i in range(5): + dag.add_child(node_a, i, None) + self.assertEqual(5, dag.out_degree(node_a)) diff --git a/tests/test_nodes.py b/tests/test_nodes.py index c8b56f7eee..56a4940e60 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -40,3 +40,25 @@ def test_topo_sort(self): dag.add_parent(3, 'A parent', None) res = retworkx.topological_sort(dag) self.assertEqual([6, 0, 5, 4, 3, 2, 1], res) + + def test_get_node_data(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + node_b = dag.add_child(node_a, 'b', "Edgy") + self.assertEqual('b', dag.get_node_data(node_b)) + + def test_get_node_data_bad_index(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + node_b = dag.add_child(node_a, 'b', "Edgy") + self.assertRaises(IndexError, dag.get_node_data, 42) + + def test_pydag_length(self): + dag = retworkx.PyDAG() + node_a = dag.add_node('a') + node_b = dag.add_child(node_a, 'b', "Edgy") + self.assertEqual(2, len(dag)) + + def test_pydag_length_empty(self): + dag = retworkx.PyDAG() + self.assertEqual(0, len(dag))