diff --git a/src/digraph.rs b/src/digraph.rs index 2d8603b83..712da1988 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -41,8 +41,8 @@ use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::visit::{ - GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable, - Visitable, + EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, + NodeIndexable, Visitable, }; use super::dot_utils::build_dot; @@ -273,85 +273,205 @@ impl PyDiGraph { } fn __getstate__(&self, py: Python) -> PyResult { - let out_dict = PyDict::new(py); - let node_dict = PyDict::new(py); - let mut out_list: Vec = Vec::with_capacity(self.graph.edge_count()); - out_dict.set_item("nodes", node_dict)?; - out_dict.set_item("nodes_removed", self.node_removed)?; - out_dict.set_item("multigraph", self.multigraph)?; - let dir = petgraph::Direction::Incoming; - for node_index in self.graph.node_indices() { - let node_data = self.graph.node_weight(node_index).unwrap(); - node_dict.set_item(node_index.index(), node_data)?; - for edge in self.graph.edges_directed(node_index, dir) { - let edge_w = edge.weight(); - let triplet = (edge.source().index(), edge.target().index(), edge_w).to_object(py); - out_list.push(triplet); - } + let mut nodes: Vec = Vec::with_capacity(self.graph.node_count()); + let mut edges: Vec = Vec::with_capacity(self.graph.edge_bound()); + + // save nodes to a list along with its index + for node_idx in self.graph.node_indices() { + let node_data = self.graph.node_weight(node_idx).unwrap(); + nodes.push((node_idx.index(), node_data).to_object(py)); + } + + // edges are saved with none (deleted edges) instead of their index to save space + for i in 0..self.graph.edge_bound() { + let idx = EdgeIndex::new(i); + let edge = match self.graph.edge_weight(idx) { + Some(edge_w) => { + let endpoints = self.graph.edge_endpoints(idx).unwrap(); + (endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py) + } + None => py.None(), + }; + edges.push(edge); } - let py_out_list: PyObject = PyList::new(py, out_list).into(); - out_dict.set_item("edges", py_out_list)?; - Ok(out_dict.into()) + + let outdict = PyDict::new(py); + let nodes_lst: PyObject = PyList::new(py, nodes).into(); + let edges_lst: PyObject = PyList::new(py, edges).into(); + outdict.set_item("nodes", nodes_lst)?; + outdict.set_item("edges", edges_lst)?; + outdict.set_item( + "node_removed", + (self.graph.node_bound() != self.graph.node_count()).to_object(py), + )?; + outdict.set_item("multigraph", self.multigraph)?; + Ok(outdict.into()) } fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - self.graph = StablePyGraph::::new(); let dict_state = state.cast_as::(py)?; + let nodes_lst = dict_state.get_item("nodes").unwrap().downcast::()?; + let edges_lst = dict_state.get_item("edges").unwrap().downcast::()?; - let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::()?; - let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; - let nodes_removed_raw = dict_state - .get_item("nodes_removed") - .unwrap() - .downcast::()?; - self.node_removed = nodes_removed_raw.extract()?; - let multigraph_raw = dict_state + self.graph = StablePyGraph::::new(); + self.multigraph = dict_state .get_item("multigraph") .unwrap() - .downcast::()?; - self.multigraph = multigraph_raw.extract()?; - let mut node_indices: Vec = Vec::new(); - for raw_index in nodes_dict.keys() { - let tmp_index = raw_index.downcast::()?; - node_indices.push(tmp_index.extract()?); - } - if node_indices.is_empty() { + .downcast::()? + .extract()?; + self.node_removed = dict_state + .get_item("node_removed") + .unwrap() + .downcast::()? + .extract()?; + + // graph is empty, stop early + if nodes_lst.is_empty() { return Ok(()); } - let max_index: usize = *node_indices.iter().max().unwrap(); - if max_index + 1 != node_indices.len() { - self.node_removed = true; - } - let mut tmp_nodes: Vec = Vec::new(); - let mut node_count: usize = 0; - while max_index >= self.graph.node_bound() { - match nodes_dict.get_item(node_count) { - Some(raw_data) => { - self.graph.add_node(raw_data.into()); + + if !self.node_removed { + for item in nodes_lst.iter() { + let node_w = item + .downcast::() + .unwrap() + .get_item(1) + .unwrap() + .extract() + .unwrap(); + self.graph.add_node(node_w); + } + } else { + if nodes_lst.len() == 1 { + // graph has only one node, handle logic here to save one if in the loop later + let item = nodes_lst + .get_item(0) + .unwrap() + .downcast::() + .unwrap(); + let node_idx: usize = item.get_item(0).unwrap().extract().unwrap(); + let node_w = item.get_item(1).unwrap().extract().unwrap(); + + for _i in 0..node_idx { + self.graph.add_node(py.None()); } - None => { + self.graph.add_node(node_w); + for i in 0..node_idx { + self.graph.remove_node(NodeIndex::new(i)); + } + } else { + let last_item = nodes_lst + .get_item(nodes_lst.len() - 1) + .unwrap() + .downcast::() + .unwrap(); + + // use a pointer to iter the node list + let mut pointer = 0; + let mut next_node_idx: usize = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + + // list of temporary nodes that will be removed later to re-create holes + let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); + let mut tmp_nodes: Vec = + Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); + + let second_last_node_idx: usize = nodes_lst + .get_item(nodes_lst.len() - 2) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + + for i in 0..(second_last_node_idx + 1) { + if i < next_node_idx { + // node does not exist + let tmp_node = self.graph.add_node(py.None()); + tmp_nodes.push(tmp_node); + } else { + // add node to the graph, and update the next available node index + let item = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap(); + + let node_w = item.get_item(1).unwrap().extract().unwrap(); + self.graph.add_node(node_w); + pointer += 1; + next_node_idx = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + } + } + + for _i in (second_last_node_idx + 1)..next_node_idx { let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); } - }; - node_count += 1; - } - for tmp_node in tmp_nodes { - self.graph.remove_node(tmp_node); - } - for raw_edge in edges_list.iter() { - let edge = raw_edge.downcast::()?; - let raw_p_index = edge.get_item(0)?.downcast::()?; - let p_index: usize = raw_p_index.extract()?; - let raw_c_index = edge.get_item(1)?.downcast::()?; - let c_index: usize = raw_c_index.extract()?; - let edge_data = edge.get_item(2)?; - self.graph.add_edge( - NodeIndex::new(p_index), - NodeIndex::new(c_index), - edge_data.into(), - ); + let last_node_w = last_item.get_item(1).unwrap().extract().unwrap(); + self.graph.add_node(last_node_w); + for tmp_node in tmp_nodes { + self.graph.remove_node(tmp_node); + } + } } + + // to ensure O(1) on edge deletion, use a temporary node to store missing edges + let tmp_node = self.graph.add_node(py.None()); + + for item in edges_lst { + if item.is_none() { + // add a temporary edge that will be deleted later to re-create the hole + self.graph.add_edge(tmp_node, tmp_node, py.None()); + } else { + let triple = item.downcast::().unwrap(); + let edge_p: usize = triple + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + let edge_c: usize = triple + .get_item(1) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + let edge_w = triple.get_item(2).unwrap().extract().unwrap(); + self.graph + .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); + } + } + + // remove the temporary node will remove all deleted edges in bulk, + // the cost is equal to the number of edges + self.graph.remove_node(tmp_node); + Ok(()) } diff --git a/src/graph.rs b/src/graph.rs index da6436e06..2a8cbe5ab 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -44,7 +44,8 @@ use petgraph::algo; use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::visit::{ - GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable, + EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, + NodeIndexable, }; /// A class for creating undirected graphs @@ -167,82 +168,205 @@ impl PyGraph { } fn __getstate__(&self, py: Python) -> PyResult { - let out_dict = PyDict::new(py); - let node_dict = PyDict::new(py); - let mut out_list: Vec = Vec::with_capacity(self.graph.edge_count()); - out_dict.set_item("nodes", node_dict)?; - out_dict.set_item("nodes_removed", self.node_removed)?; - out_dict.set_item("multigraph", self.multigraph)?; - for node_index in self.graph.node_indices() { - let node_data = self.graph.node_weight(node_index).unwrap(); - node_dict.set_item(node_index.index(), node_data)?; + let mut nodes: Vec = Vec::with_capacity(self.graph.node_count()); + let mut edges: Vec = Vec::with_capacity(self.graph.edge_bound()); + + // save nodes to a list along with its index + for node_idx in self.graph.node_indices() { + let node_data = self.graph.node_weight(node_idx).unwrap(); + nodes.push((node_idx.index(), node_data).to_object(py)); } - for edge in self.graph.edge_indices() { - let edge_w = self.graph.edge_weight(edge); - let endpoints = self.graph.edge_endpoints(edge).unwrap(); - let triplet = (endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py); - out_list.push(triplet); + // edges are saved with none (deleted edges) instead of their index to save space + for i in 0..self.graph.edge_bound() { + let idx = EdgeIndex::new(i); + let edge = match self.graph.edge_weight(idx) { + Some(edge_w) => { + let endpoints = self.graph.edge_endpoints(idx).unwrap(); + (endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py) + } + None => py.None(), + }; + edges.push(edge); } - let py_out_list: PyObject = PyList::new(py, out_list).into(); - out_dict.set_item("edges", py_out_list)?; - Ok(out_dict.into()) + + let outdict = PyDict::new(py); + let nodes_lst: PyObject = PyList::new(py, nodes).into(); + let edges_lst: PyObject = PyList::new(py, edges).into(); + outdict.set_item("nodes", nodes_lst)?; + outdict.set_item("edges", edges_lst)?; + outdict.set_item( + "node_removed", + (self.graph.node_bound() != self.graph.node_count()).to_object(py), + )?; + outdict.set_item("multigraph", self.multigraph)?; + Ok(outdict.into()) } fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - self.graph = StablePyGraph::::default(); let dict_state = state.cast_as::(py)?; - let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::()?; - let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; - let nodes_removed_raw = dict_state - .get_item("nodes_removed") - .unwrap() - .downcast::()?; - self.node_removed = nodes_removed_raw.extract()?; - let multigraph_raw = dict_state + let nodes_lst = dict_state.get_item("nodes").unwrap().downcast::()?; + let edges_lst = dict_state.get_item("edges").unwrap().downcast::()?; + + self.graph = StablePyGraph::::default(); + self.multigraph = dict_state .get_item("multigraph") .unwrap() - .downcast::()?; - self.multigraph = multigraph_raw.extract()?; + .downcast::()? + .extract()?; + self.node_removed = dict_state + .get_item("node_removed") + .unwrap() + .downcast::()? + .extract()?; - let mut node_indices: Vec = Vec::new(); - for raw_index in nodes_dict.keys() { - let tmp_index = raw_index.downcast::()?; - node_indices.push(tmp_index.extract()?); - } - if node_indices.is_empty() { + // graph is empty, stop early + if nodes_lst.is_empty() { return Ok(()); } - let max_index: usize = *node_indices.iter().max().unwrap(); - let mut tmp_nodes: Vec = Vec::new(); - let mut node_count: usize = 0; - while max_index >= self.graph.node_bound() { - match nodes_dict.get_item(node_count) { - Some(raw_data) => { - self.graph.add_node(raw_data.into()); + + if !self.node_removed { + for item in nodes_lst.iter() { + let node_w = item + .downcast::() + .unwrap() + .get_item(1) + .unwrap() + .extract() + .unwrap(); + self.graph.add_node(node_w); + } + } else { + if nodes_lst.len() == 1 { + // graph has only one node, handle logic here to save one if in the loop later + let item = nodes_lst + .get_item(0) + .unwrap() + .downcast::() + .unwrap(); + let node_idx: usize = item.get_item(0).unwrap().extract().unwrap(); + let node_w = item.get_item(1).unwrap().extract().unwrap(); + + for _i in 0..node_idx { + self.graph.add_node(py.None()); } - None => { + self.graph.add_node(node_w); + for i in 0..node_idx { + self.graph.remove_node(NodeIndex::new(i)); + } + } else { + let last_item = nodes_lst + .get_item(nodes_lst.len() - 1) + .unwrap() + .downcast::() + .unwrap(); + + // use a pointer to iter the node list + let mut pointer = 0; + let mut next_node_idx: usize = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + + // list of temporary nodes that will be removed later to re-create holes + let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); + let mut tmp_nodes: Vec = + Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); + + let second_last_node_idx: usize = nodes_lst + .get_item(nodes_lst.len() - 2) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + + for i in 0..(second_last_node_idx + 1) { + if i < next_node_idx { + // node does not exist + let tmp_node = self.graph.add_node(py.None()); + tmp_nodes.push(tmp_node); + } else { + // add node to the graph, and update the next available node index + let item = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap(); + + let node_w = item.get_item(1).unwrap().extract().unwrap(); + self.graph.add_node(node_w); + pointer += 1; + next_node_idx = nodes_lst + .get_item(pointer) + .unwrap() + .downcast::() + .unwrap() + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + } + } + + for _i in (second_last_node_idx + 1)..next_node_idx { let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); } - }; - node_count += 1; - } - for tmp_node in tmp_nodes { - self.graph.remove_node(tmp_node); + let last_node_w = last_item.get_item(1).unwrap().extract().unwrap(); + self.graph.add_node(last_node_w); + for tmp_node in tmp_nodes { + self.graph.remove_node(tmp_node); + } + } } - for raw_edge in edges_list.iter() { - let edge = raw_edge.downcast::()?; - let raw_p_index = edge.get_item(0)?.downcast::()?; - let parent: usize = raw_p_index.extract()?; - let p_index = NodeIndex::new(parent); - let raw_c_index = edge.get_item(1)?.downcast::()?; - let child: usize = raw_c_index.extract()?; - let c_index = NodeIndex::new(child); - let edge_data = edge.get_item(2)?; - - self.graph.add_edge(p_index, c_index, edge_data.into()); + + // to ensure O(1) on edge deletion, use a temporary node to store missing edges + let tmp_node = self.graph.add_node(py.None()); + + for item in edges_lst { + if item.is_none() { + // add a temporary edge that will be deleted later to re-create the hole + self.graph.add_edge(tmp_node, tmp_node, py.None()); + } else { + let triple = item.downcast::().unwrap(); + let edge_p: usize = triple + .get_item(0) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + let edge_c: usize = triple + .get_item(1) + .unwrap() + .downcast::() + .unwrap() + .extract() + .unwrap(); + let edge_w = triple.get_item(2).unwrap().extract().unwrap(); + self.graph + .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); + } } + + // remove the temporary node will remove all deleted edges in bulk, + // the cost is equal to the number of edges + self.graph.remove_node(tmp_node); + Ok(()) } diff --git a/tests/digraph/test_pickle.py b/tests/digraph/test_pickle.py new file mode 100644 index 000000000..4c03cea72 --- /dev/null +++ b/tests/digraph/test_pickle.py @@ -0,0 +1,41 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import pickle +import unittest + +import retworkx + + +class TestPickleDiGraph(unittest.TestCase): + def test_noweight_graph(self): + g = retworkx.PyDAG() + for i in range(4): + g.add_node(None) + g.add_edges_from_no_data([(0, 1), (1, 2), (3, 0), (3, 1)]) + g.remove_node(0) + + gprime = pickle.loads(pickle.dumps(g)) + self.assertEqual([1, 2, 3], gprime.node_indices()) + self.assertEqual([None, None, None], gprime.nodes()) + self.assertEqual({1: (1, 2, None), 3: (3, 1, None)}, dict(gprime.edge_index_map())) + + def test_weight_graph(self): + g = retworkx.PyDAG() + g.add_nodes_from(["A", "B", "C", "D"]) + g.add_edges_from([(0, 1, "A -> B"), (1, 2, "B -> C"), (3, 0, "D -> A"), (3, 1, "D -> B")]) + g.remove_node(0) + + gprime = pickle.loads(pickle.dumps(g)) + self.assertEqual([1, 2, 3], gprime.node_indices()) + self.assertEqual(["B", "C", "D"], gprime.nodes()) + self.assertEqual({1: (1, 2, "B -> C"), 3: (3, 1, "D -> B")}, dict(gprime.edge_index_map())) diff --git a/tests/graph/test_pickle.py b/tests/graph/test_pickle.py new file mode 100644 index 000000000..25dc9ab56 --- /dev/null +++ b/tests/graph/test_pickle.py @@ -0,0 +1,41 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import pickle +import unittest + +import retworkx + + +class TestPickleGraph(unittest.TestCase): + def test_noweight_graph(self): + g = retworkx.PyGraph() + for i in range(4): + g.add_node(None) + g.add_edges_from_no_data([(0, 1), (1, 2), (3, 0), (3, 1)]) + g.remove_node(0) + + gprime = pickle.loads(pickle.dumps(g)) + self.assertEqual([1, 2, 3], gprime.node_indices()) + self.assertEqual([None, None, None], gprime.nodes()) + self.assertEqual({1: (1, 2, None), 3: (3, 1, None)}, dict(gprime.edge_index_map())) + + def test_weight_graph(self): + g = retworkx.PyGraph() + g.add_nodes_from(["A", "B", "C", "D"]) + g.add_edges_from([(0, 1, "A -> B"), (1, 2, "B -> C"), (3, 0, "D -> A"), (3, 1, "D -> B")]) + g.remove_node(0) + + gprime = pickle.loads(pickle.dumps(g)) + self.assertEqual([1, 2, 3], gprime.node_indices()) + self.assertEqual(["B", "C", "D"], gprime.nodes()) + self.assertEqual({1: (1, 2, "B -> C"), 3: (3, 1, "D -> B")}, dict(gprime.edge_index_map()))