Skip to content

Commit

Permalink
Fix handling of index holes in PyDiGraph pickling (#116)
Browse files Browse the repository at this point in the history
This commit fixes an issue with pickling (and, by extension, deep-copying)
PyDiGraph or PyDAG objects that have holes in their node id lists.
Previously, the node ids were not preserved across pickling, which produced
a compacted list instead of the original node ids. For example, a
PyDiGraph with node ids [0, 1, 3] would come back from pickling/deepcopy
with node ids [0, 1, 2], while otherwise being identical. This commit fixes
the issue by adding a check for holes to the __setstate__ method and
advancing the node id past each hole (a temporary placeholder node is added
at every missing index and removed afterwards), reproducing a 1:1 mapping
with the original node ids prior to pickling.
  • Loading branch information
mtreinish authored Sep 21, 2020
1 parent 547c30b commit 4116d1a
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 39 deletions.
4 changes: 0 additions & 4 deletions retworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ class PyDAG(PyDiGraph):
ensure that no cycles are added, ensuring that the PyDAG class truly
represents a directed acyclic graph.
.. note::
When using ``copy.deepcopy()`` or pickling node indexes are not
guaranteed to be preserved.
PyDAG is a subclass of the PyDiGraph class and behaves identically to
the :class:`~retworkx.PyDiGraph` class.
"""
Expand Down
52 changes: 32 additions & 20 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ use super::{
/// With check_cycle set to true any calls to :meth:`PyDiGraph.add_edge` will
/// ensure that no cycles are added, ensuring that the PyDiGraph class truly
/// represents a directed acyclic graph.
///
/// .. note::
/// When using ``copy.deepcopy()`` or pickling node indexes are not
/// guaranteed to be preserved.
///
#[pyclass(module = "retworkx", subclass)]
#[text_signature = "(/, check_cycle=False)"]
pub struct PyDiGraph {
Expand Down Expand Up @@ -334,35 +329,52 @@ impl PyDiGraph {
Ok(out_dict.into())
}

fn __setstate__(&mut self, state: PyObject) -> PyResult<()> {
let mut node_mapping: HashMap<usize, NodeIndex> = HashMap::new();
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
self.graph = StableDiGraph::<PyObject, PyObject>::new();
let gil = Python::acquire_gil();
let py = gil.python();
let dict_state = state.cast_as::<PyDict>(py)?;

let nodes_dict =
dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list =
dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
for raw_index in nodes_dict.keys().iter() {
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
let index: usize = tmp_index.extract()?;
let raw_data = nodes_dict.get_item(index).unwrap();
let node_index = self.graph.add_node(raw_data.into());
node_mapping.insert(index, node_index);
node_indices.push(tmp_index.extract()?);
}
let max_index: usize = *node_indices.iter().max().unwrap();
if max_index + 1 != node_indices.len() {
self.node_removed = true;
}
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());
}
None => {
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
}
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0).downcast::<PyLong>()?;
let tmp_p_index: usize = raw_p_index.extract()?;
let p_index: usize = raw_p_index.extract()?;
let raw_c_index = edge.get_item(1).downcast::<PyLong>()?;
let tmp_c_index: usize = raw_c_index.extract()?;
let c_index: usize = raw_c_index.extract()?;
let edge_data = edge.get_item(2);

let p_index = node_mapping.get(&tmp_p_index).unwrap();
let c_index = node_mapping.get(&tmp_c_index).unwrap();
self.graph.add_edge(*p_index, *c_index, edge_data.into());
self.graph.add_edge(
NodeIndex::new(p_index),
NodeIndex::new(c_index),
edge_data.into(),
);
}
Ok(())
}
Expand Down
33 changes: 18 additions & 15 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,26 +263,29 @@ impl PyGraph {
dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list =
dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
let mut index_count = 0;
for raw_index in nodes_dict.keys().iter() {
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
let index: usize = tmp_index.extract()?;
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
if index > index_count + 1 {
let diff = index - (index_count + 1);
for _ in 0..diff {
node_indices.push(tmp_index.extract()?);
}
let max_index: usize = *node_indices.iter().max().unwrap();
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());
}
None => {
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
}
}
let raw_data = nodes_dict.get_item(index).unwrap();
let out_index = self.graph.add_node(raw_data.into());
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
index_count = out_index.index();
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}

for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0).downcast::<PyLong>()?;
Expand Down
12 changes: 12 additions & 0 deletions tests/test_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,15 @@ def test_isomorphic_compare_nodes_identical(self):
self.assertTrue(
retworkx.is_isomorphic_node_match(
dag_a, dag_b, lambda x, y: x == y))

# Regression test: deep-copying a PyDAG whose node indexes contain a hole
# must keep the original indexes instead of compacting them.
def test_deepcopy_with_holes(self):
dag_a = retworkx.PyDAG()
node_a = dag_a.add_node('a_1')
node_b = dag_a.add_node('a_2')
dag_a.add_edge(node_a, node_b, 'edge_1')
node_c = dag_a.add_node('a_3')
dag_a.add_edge(node_b, node_c, 'edge_2')
# Removing the second of the three nodes leaves a hole in the index sequence.
dag_a.remove_node(node_b)
dag_b = copy.deepcopy(dag_a)
self.assertIsInstance(dag_b, retworkx.PyDAG)
# The copy must report the same surviving indexes as the original graph.
self.assertEqual([node_a, node_c], dag_b.node_indexes())

0 comments on commit 4116d1a

Please sign in to comment.