Skip to content

Commit

Permalink
fix issue Qiskit#585 that pickling graph & digraph do not preserve or…
Browse files Browse the repository at this point in the history
…iginal edge index
  • Loading branch information
Binh Vu committed Apr 18, 2022
1 parent add8e5b commit f3bc9f3
Show file tree
Hide file tree
Showing 4 changed files with 451 additions and 125 deletions.
252 changes: 186 additions & 66 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::*;

use petgraph::visit::{
GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable,
Visitable,
EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered,
NodeIndexable, Visitable,
};

use super::dot_utils::build_dot;
Expand Down Expand Up @@ -273,85 +273,205 @@ impl PyDiGraph {
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let out_dict = PyDict::new(py);
let node_dict = PyDict::new(py);
let mut out_list: Vec<PyObject> = Vec::with_capacity(self.graph.edge_count());
out_dict.set_item("nodes", node_dict)?;
out_dict.set_item("nodes_removed", self.node_removed)?;
out_dict.set_item("multigraph", self.multigraph)?;
let dir = petgraph::Direction::Incoming;
for node_index in self.graph.node_indices() {
let node_data = self.graph.node_weight(node_index).unwrap();
node_dict.set_item(node_index.index(), node_data)?;
for edge in self.graph.edges_directed(node_index, dir) {
let edge_w = edge.weight();
let triplet = (edge.source().index(), edge.target().index(), edge_w).to_object(py);
out_list.push(triplet);
}
let mut nodes: Vec<PyObject> = Vec::with_capacity(self.graph.node_count());
let mut edges: Vec<PyObject> = Vec::with_capacity(self.graph.edge_bound());

// save nodes to a list along with its index
for node_idx in self.graph.node_indices() {
let node_data = self.graph.node_weight(node_idx).unwrap();
nodes.push((node_idx.index(), node_data).to_object(py));
}

// edges are saved with none (deleted edges) instead of their index to save space
for i in 0..self.graph.edge_bound() {
let idx = EdgeIndex::new(i);
let edge = match self.graph.edge_weight(idx) {
Some(edge_w) => {
let endpoints = self.graph.edge_endpoints(idx).unwrap();
(endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py)
}
None => py.None(),
};
edges.push(edge);
}
let py_out_list: PyObject = PyList::new(py, out_list).into();
out_dict.set_item("edges", py_out_list)?;
Ok(out_dict.into())

let outdict = PyDict::new(py);
let nodes_lst: PyObject = PyList::new(py, nodes).into();
let edges_lst: PyObject = PyList::new(py, edges).into();
outdict.set_item("nodes", nodes_lst)?;
outdict.set_item("edges", edges_lst)?;
outdict.set_item(
"node_removed",
(self.graph.node_bound() != self.graph.node_count()).to_object(py),
)?;
outdict.set_item("multigraph", self.multigraph)?;
Ok(outdict.into())
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
self.graph = StablePyGraph::<Directed>::new();
let dict_state = state.cast_as::<PyDict>(py)?;
let nodes_lst = dict_state.get_item("nodes").unwrap().downcast::<PyList>()?;
let edges_lst = dict_state.get_item("edges").unwrap().downcast::<PyList>()?;

let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list = dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
let nodes_removed_raw = dict_state
.get_item("nodes_removed")
.unwrap()
.downcast::<PyBool>()?;
self.node_removed = nodes_removed_raw.extract()?;
let multigraph_raw = dict_state
self.graph = StablePyGraph::<Directed>::new();
self.multigraph = dict_state
.get_item("multigraph")
.unwrap()
.downcast::<PyBool>()?;
self.multigraph = multigraph_raw.extract()?;
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
node_indices.push(tmp_index.extract()?);
}
if node_indices.is_empty() {
.downcast::<PyBool>()?
.extract()?;
self.node_removed = dict_state
.get_item("node_removed")
.unwrap()
.downcast::<PyBool>()?
.extract()?;

// graph is empty, stop early
if nodes_lst.is_empty() {
return Ok(());
}
let max_index: usize = *node_indices.iter().max().unwrap();
if max_index + 1 != node_indices.len() {
self.node_removed = true;
}
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());

if !self.node_removed {
for item in nodes_lst.iter() {
let node_w = item
.downcast::<PyTuple>()
.unwrap()
.get_item(1)
.unwrap()
.extract()
.unwrap();
self.graph.add_node(node_w);
}
} else {
if nodes_lst.len() == 1 {
// graph has only one node, handle logic here to save one if in the loop later
let item = nodes_lst
.get_item(0)
.unwrap()
.downcast::<PyTuple>()
.unwrap();
let node_idx: usize = item.get_item(0).unwrap().extract().unwrap();
let node_w = item.get_item(1).unwrap().extract().unwrap();

for _i in 0..node_idx {
self.graph.add_node(py.None());
}
None => {
self.graph.add_node(node_w);
for i in 0..node_idx {
self.graph.remove_node(NodeIndex::new(i));
}
} else {
let last_item = nodes_lst
.get_item(nodes_lst.len() - 1)
.unwrap()
.downcast::<PyTuple>()
.unwrap();

// use a pointer to iter the node list
let mut pointer = 0;
let mut next_node_idx: usize = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();

// list of temporary nodes that will be removed later to re-create holes
let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap();
let mut tmp_nodes: Vec<NodeIndex> =
Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len());

let second_last_node_idx: usize = nodes_lst
.get_item(nodes_lst.len() - 2)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();

for i in 0..(second_last_node_idx + 1) {
if i < next_node_idx {
// node does not exist
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
} else {
// add node to the graph, and update the next available node index
let item = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap();

let node_w = item.get_item(1).unwrap().extract().unwrap();
self.graph.add_node(node_w);
pointer += 1;
next_node_idx = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
}
}

for _i in (second_last_node_idx + 1)..next_node_idx {
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
}
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0)?.downcast::<PyLong>()?;
let p_index: usize = raw_p_index.extract()?;
let raw_c_index = edge.get_item(1)?.downcast::<PyLong>()?;
let c_index: usize = raw_c_index.extract()?;
let edge_data = edge.get_item(2)?;
self.graph.add_edge(
NodeIndex::new(p_index),
NodeIndex::new(c_index),
edge_data.into(),
);
let last_node_w = last_item.get_item(1).unwrap().extract().unwrap();
self.graph.add_node(last_node_w);
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
}
}

// to ensure O(1) on edge deletion, use a temporary node to store missing edges
let tmp_node = self.graph.add_node(py.None());

for item in edges_lst {
if item.is_none() {
// add a temporary edge that will be deleted later to re-create the hole
self.graph.add_edge(tmp_node, tmp_node, py.None());
} else {
let triple = item.downcast::<PyTuple>().unwrap();
let edge_p: usize = triple
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
let edge_c: usize = triple
.get_item(1)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
let edge_w = triple.get_item(2).unwrap().extract().unwrap();
self.graph
.add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w);
}
}

// remove the temporary node will remove all deleted edges in bulk,
// the cost is equal to the number of edges
self.graph.remove_node(tmp_node);

Ok(())
}

Expand Down
Loading

0 comments on commit f3bc9f3

Please sign in to comment.