Skip to content

Commit

Permalink
Add pickle support to custom return types (Qiskit#213)
Browse files Browse the repository at this point in the history
* Add pickle support to custom return types

In recent releases we have started to add custom return types instead of
returning lists of of objects to avoid the overhead of converting large
lists from rust to python. However, these custom return types were
missing support for pickling, which was causing issue for places where
they were being used with Python's multiprocessing or deepcopy. This has
caused issues in places where a list was returned before. This commit
fixes this by adding the necessary methods to the custom return classes
to enable python to pickle them so they can avoid these issues.

* Bump version string to prepare for bugfix release

* Fix test lint
  • Loading branch information
mtreinish authored Dec 3, 2020
1 parent d289260 commit 0ce32e0
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "retworkx"
description = "A python graph library implemented in Rust"
version = "0.7.0"
version = "0.7.1"
authors = ["Matthew Treinish <[email protected]>"]
license = "Apache-2.0"
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@
# built documents.
#
# The short X.Y version.
version = '0.7.0'
version = '0.7.1'
# The full version, including alpha/beta/rc tags.
release = '0.7.0'
release = '0.7.1'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def readme():

setup(
name="retworkx",
version="0.7.0",
version="0.7.1",
description="A python graph library implemented in Rust",
long_description=readme(),
long_description_content_type='text/markdown',
Expand Down
67 changes: 66 additions & 1 deletion src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,24 @@ use pyo3::types::PySequence;
#[pyclass(module = "retworkx")]
pub struct BFSSuccessors {
pub bfs_successors: Vec<(PyObject, Vec<PyObject>)>,
pub index: usize,
}

#[pymethods]
impl BFSSuccessors {
#[new]
fn new() -> Self {
BFSSuccessors {
bfs_successors: Vec::new(),
}
}

fn __getstate__(&self) -> Vec<(PyObject, Vec<PyObject>)> {
self.bfs_successors.clone()
}

fn __setstate__(&mut self, state: Vec<(PyObject, Vec<PyObject>)>) {
self.bfs_successors = state;
}
}

#[pyproto]
Expand Down Expand Up @@ -136,6 +153,22 @@ pub struct NodeIndices {
pub nodes: Vec<usize>,
}

#[pymethods]
impl NodeIndices {
#[new]
fn new() -> NodeIndices {
NodeIndices { nodes: Vec::new() }
}

fn __getstate__(&self) -> Vec<usize> {
self.nodes.clone()
}

fn __setstate__(&mut self, state: Vec<usize>) {
self.nodes = state;
}
}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for NodeIndices {
fn __richcmp__(
Expand Down Expand Up @@ -211,6 +244,22 @@ pub struct EdgeList {
pub edges: Vec<(usize, usize)>,
}

#[pymethods]
impl EdgeList {
#[new]
fn new() -> EdgeList {
EdgeList { edges: Vec::new() }
}

fn __getstate__(&self) -> Vec<(usize, usize)> {
self.edges.clone()
}

fn __setstate__(&mut self, state: Vec<(usize, usize)>) {
self.edges = state;
}
}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for EdgeList {
fn __richcmp__(
Expand Down Expand Up @@ -286,6 +335,22 @@ pub struct WeightedEdgeList {
pub edges: Vec<(usize, usize, PyObject)>,
}

#[pymethods]
impl WeightedEdgeList {
#[new]
fn new() -> WeightedEdgeList {
WeightedEdgeList { edges: Vec::new() }
}

fn __getstate__(&self) -> Vec<(usize, usize, PyObject)> {
self.edges.clone()
}

fn __setstate__(&mut self, state: Vec<(usize, usize, PyObject)>) {
self.edges = state;
}
}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for WeightedEdgeList {
fn __richcmp__(
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,6 @@ fn bfs_successors(
}
Ok(iterators::BFSSuccessors {
bfs_successors: out_list,
index: 0,
})
}

Expand Down
46 changes: 46 additions & 0 deletions tests/test_custom_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# License for the specific language governing permissions and limitations
# under the License.

import copy
import pickle
import unittest

import retworkx
Expand Down Expand Up @@ -63,6 +65,17 @@ def test__gt__not_implemented(self):
with self.assertRaises(NotImplementedError):
retworkx.bfs_successors(self.dag, 0) > [('b', ['c'])]

def test_deepcopy(self):
bfs = retworkx.bfs_successors(self.dag, 0)
bfs_copy = copy.deepcopy(bfs)
self.assertEqual(bfs, bfs_copy)

def test_pickle(self):
bfs = retworkx.bfs_successors(self.dag, 0)
bfs_pickle = pickle.dumps(bfs)
bfs_copy = pickle.loads(bfs_pickle)
self.assertEqual(bfs, bfs_copy)


class TestNodeIndicesComparisons(unittest.TestCase):

Expand Down Expand Up @@ -101,6 +114,17 @@ def test__gt__not_implemented(self):
with self.assertRaises(NotImplementedError):
self.dag.node_indexes() > [2, 1]

def test_deepcopy(self):
nodes = self.dag.node_indexes()
nodes_copy = copy.deepcopy(nodes)
self.assertEqual(nodes, nodes_copy)

def test_pickle(self):
nodes = self.dag.node_indexes()
nodes_pickle = pickle.dumps(nodes)
nodes_copy = pickle.loads(nodes_pickle)
self.assertEqual(nodes, nodes_copy)


class TestEdgeListComparisons(unittest.TestCase):

Expand Down Expand Up @@ -137,6 +161,17 @@ def test__gt__not_implemented(self):
with self.assertRaises(NotImplementedError):
self.dag.edge_list() > [(2, 1)]

def test_deepcopy(self):
edges = self.dag.edge_list()
edges_copy = copy.deepcopy(edges)
self.assertEqual(edges, edges_copy)

def test_pickle(self):
edges = self.dag.edge_list()
edges_pickle = pickle.dumps(edges)
edges_copy = pickle.loads(edges_pickle)
self.assertEqual(edges, edges_copy)


class TestWeightedEdgeListComparisons(unittest.TestCase):

Expand Down Expand Up @@ -174,3 +209,14 @@ def test__ne__invalid_type(self):
def test__gt__not_implemented(self):
with self.assertRaises(NotImplementedError):
self.dag.weighted_edge_list() > [(2, 1, 'Not Edgy')]

def test_deepcopy(self):
edges = self.dag.weighted_edge_list()
edges_copy = copy.deepcopy(edges)
self.assertEqual(edges, edges_copy)

def test_pickle(self):
edges = self.dag.weighted_edge_list()
edges_pickle = pickle.dumps(edges)
edges_copy = pickle.loads(edges_pickle)
self.assertEqual(edges, edges_copy)

0 comments on commit 0ce32e0

Please sign in to comment.