From 0ce32e0066dd743e92b4a45f651fcb7eaeb0df76 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Thu, 3 Dec 2020 15:52:01 -0500 Subject: [PATCH] Add pickle support to custom return types (#213) * Add pickle support to custom return types In recent releases we have started to add custom return types instead of returning lists of of objects to avoid the overhead of converting large lists from rust to python. However, these custom return types were missing support for pickling, which was causing issue for places where they were being used with Python's multiprocessing or deepcopy. This has caused issues in places where a list was returned before. This commit fixes this by adding the necessary methods to the custom return classes to enable python to pickle them so they can avoid these issues. * Bump version string to prepare for bugfix release * Fix test lint --- Cargo.toml | 2 +- docs/source/conf.py | 4 +- setup.py | 2 +- src/iterators.rs | 67 ++++++++++++++++++++++++++++++- src/lib.rs | 1 - tests/test_custom_return_types.py | 46 +++++++++++++++++++++ 6 files changed, 116 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ee4346d5ce..615f5360d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "retworkx" description = "A python graph library implemented in Rust" -version = "0.7.0" +version = "0.7.1" authors = ["Matthew Treinish "] license = "Apache-2.0" readme = "README.md" diff --git a/docs/source/conf.py b/docs/source/conf.py index ff3e228d7d..40ccdf9eca 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -79,9 +79,9 @@ # built documents. # # The short X.Y version. -version = '0.7.0' +version = '0.7.1' # The full version, including alpha/beta/rc tags. -release = '0.7.0' +release = '0.7.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.py b/setup.py index 7f4ce6c3a3..e14b14811f 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ def readme(): setup( name="retworkx", - version="0.7.0", + version="0.7.1", description="A python graph library implemented in Rust", long_description=readme(), long_description_content_type='text/markdown', diff --git a/src/iterators.rs b/src/iterators.rs index 0dc8dc074c..2ed4b8f8b1 100644 --- a/src/iterators.rs +++ b/src/iterators.rs @@ -42,7 +42,24 @@ use pyo3::types::PySequence; #[pyclass(module = "retworkx")] pub struct BFSSuccessors { pub bfs_successors: Vec<(PyObject, Vec)>, - pub index: usize, +} + +#[pymethods] +impl BFSSuccessors { + #[new] + fn new() -> Self { + BFSSuccessors { + bfs_successors: Vec::new(), + } + } + + fn __getstate__(&self) -> Vec<(PyObject, Vec)> { + self.bfs_successors.clone() + } + + fn __setstate__(&mut self, state: Vec<(PyObject, Vec)>) { + self.bfs_successors = state; + } } #[pyproto] @@ -136,6 +153,22 @@ pub struct NodeIndices { pub nodes: Vec, } +#[pymethods] +impl NodeIndices { + #[new] + fn new() -> NodeIndices { + NodeIndices { nodes: Vec::new() } + } + + fn __getstate__(&self) -> Vec { + self.nodes.clone() + } + + fn __setstate__(&mut self, state: Vec) { + self.nodes = state; + } +} + #[pyproto] impl<'p> PyObjectProtocol<'p> for NodeIndices { fn __richcmp__( @@ -211,6 +244,22 @@ pub struct EdgeList { pub edges: Vec<(usize, usize)>, } +#[pymethods] +impl EdgeList { + #[new] + fn new() -> EdgeList { + EdgeList { edges: Vec::new() } + } + + fn __getstate__(&self) -> Vec<(usize, usize)> { + self.edges.clone() + } + + fn __setstate__(&mut self, state: Vec<(usize, usize)>) { + self.edges = state; + } +} + #[pyproto] impl<'p> PyObjectProtocol<'p> for EdgeList { fn __richcmp__( @@ -286,6 +335,22 @@ pub struct WeightedEdgeList { pub edges: Vec<(usize, usize, PyObject)>, } +#[pymethods] +impl WeightedEdgeList { + #[new] + fn new() -> WeightedEdgeList { + WeightedEdgeList { edges: Vec::new() } + } + + fn __getstate__(&self) -> Vec<(usize, usize, PyObject)> { + self.edges.clone() + } + + fn __setstate__(&mut self, state: Vec<(usize, usize, PyObject)>) { + self.edges = state; + } +} + #[pyproto] impl<'p> PyObjectProtocol<'p> for WeightedEdgeList { fn __richcmp__( diff --git a/src/lib.rs b/src/lib.rs index 5304bdfd7d..4c8e654543 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -514,7 +514,6 @@ fn bfs_successors( } Ok(iterators::BFSSuccessors { bfs_successors: out_list, - index: 0, }) } diff --git a/tests/test_custom_return_types.py b/tests/test_custom_return_types.py index 8f00a9125d..ef47009160 100644 --- a/tests/test_custom_return_types.py +++ b/tests/test_custom_return_types.py @@ -10,6 +10,8 @@ # License for the specific language governing permissions and limitations # under the License. +import copy +import pickle import unittest import retworkx @@ -63,6 +65,17 @@ def test__gt__not_implemented(self): with self.assertRaises(NotImplementedError): retworkx.bfs_successors(self.dag, 0) > [('b', ['c'])] + def test_deepcopy(self): + bfs = retworkx.bfs_successors(self.dag, 0) + bfs_copy = copy.deepcopy(bfs) + self.assertEqual(bfs, bfs_copy) + + def test_pickle(self): + bfs = retworkx.bfs_successors(self.dag, 0) + bfs_pickle = pickle.dumps(bfs) + bfs_copy = pickle.loads(bfs_pickle) + self.assertEqual(bfs, bfs_copy) + class TestNodeIndicesComparisons(unittest.TestCase): @@ -101,6 +114,17 @@ def test__gt__not_implemented(self): with self.assertRaises(NotImplementedError): self.dag.node_indexes() > [2, 1] + def test_deepcopy(self): + nodes = self.dag.node_indexes() + nodes_copy = copy.deepcopy(nodes) + self.assertEqual(nodes, nodes_copy) + + def test_pickle(self): + nodes = self.dag.node_indexes() + nodes_pickle = pickle.dumps(nodes) + nodes_copy = pickle.loads(nodes_pickle) + self.assertEqual(nodes, nodes_copy) + class TestEdgeListComparisons(unittest.TestCase): @@ -137,6 +161,17 @@ def test__gt__not_implemented(self): with self.assertRaises(NotImplementedError): self.dag.edge_list() > [(2, 1)] + def test_deepcopy(self): + edges = self.dag.edge_list() + edges_copy = copy.deepcopy(edges) + self.assertEqual(edges, edges_copy) + + def test_pickle(self): + edges = self.dag.edge_list() + edges_pickle = pickle.dumps(edges) + edges_copy = pickle.loads(edges_pickle) + self.assertEqual(edges, edges_copy) + class TestWeightedEdgeListComparisons(unittest.TestCase): @@ -174,3 +209,14 @@ def test__ne__invalid_type(self): def test__gt__not_implemented(self): with self.assertRaises(NotImplementedError): self.dag.weighted_edge_list() > [(2, 1, 'Not Edgy')] + + def test_deepcopy(self): + edges = self.dag.weighted_edge_list() + edges_copy = copy.deepcopy(edges) + self.assertEqual(edges, edges_copy) + + def test_pickle(self): + edges = self.dag.weighted_edge_list() + edges_pickle = pickle.dumps(edges) + edges_copy = pickle.loads(edges_pickle) + self.assertEqual(edges, edges_copy)