diff --git a/src/dag_isomorphism.rs b/src/dag_isomorphism.rs index 9173f40cc1..b7d460a984 100644 --- a/src/dag_isomorphism.rs +++ b/src/dag_isomorphism.rs @@ -16,12 +16,16 @@ use fixedbitset::FixedBitSet; use std::marker; +use hashbrown::HashMap; + use super::digraph::PyDiGraph; use pyo3::prelude::*; +use petgraph::algo; use petgraph::stable_graph::NodeIndex; -use petgraph::visit::GetAdjacencyMatrix; +use petgraph::stable_graph::StableDiGraph; +use petgraph::visit::{EdgeRef, GetAdjacencyMatrix, IntoEdgeReferences}; use petgraph::{Directed, Incoming}; #[derive(Debug)] @@ -190,6 +194,36 @@ pub fn is_isomorphic(dag0: &PyDiGraph, dag1: &PyDiGraph) -> PyResult { Ok(res.unwrap_or(false)) } +fn clone_graph(py: Python, dag: &PyDiGraph) -> PyDiGraph { + // NOTE: this is a hacky workaround to handle non-contiguous node ids in + // VF2. The code which was forked from petgraph was written assuming the + // Graph type and not StableGraph so it makes an implicit assumption on + // node_bound() == node_count() which isn't true with removals on + // StableGraph. This compacts the node ids as a workaround until VF2State + // and try_match can be rewitten to handle this (and likely contributed + // upstream to petgraph too). + let mut new_graph = StableDiGraph::::new(); + let mut id_map: HashMap = HashMap::new(); + for node_index in dag.graph.node_indices() { + let node_data = dag.graph.node_weight(node_index).unwrap(); + let new_index = new_graph.add_node(node_data.clone_ref(py)); + id_map.insert(node_index, new_index); + } + for edge in dag.graph.edge_references() { + let edge_w = edge.weight(); + let p_index = id_map.get(&edge.source()).unwrap(); + let c_index = id_map.get(&edge.target()).unwrap(); + new_graph.add_edge(*p_index, *c_index, edge_w.clone_ref(py)); + } + + PyDiGraph { + graph: new_graph, + cycle_state: algo::DfsSpace::default(), + check_cycle: dag.check_cycle, + node_removed: false, + } +} + /// [Graph] Return `true` if the graphs `g0` and `g1` are isomorphic. /// /// Using the VF2 algorithm, examining both syntactic and semantic @@ -197,6 +231,7 @@ pub fn is_isomorphic(dag0: &PyDiGraph, dag1: &PyDiGraph) -> PyResult { /// /// The graphs should not be multigraphs. pub fn is_isomorphic_matching( + py: Python, dag0: &PyDiGraph, dag1: &PyDiGraph, mut node_match: F, @@ -206,15 +241,36 @@ where F: FnMut(&PyObject, &PyObject) -> PyResult, G: FnMut(&PyObject, &PyObject) -> PyResult, { - let g0 = &dag0.graph; - let g1 = &dag1.graph; + let inner_temp_dag0: PyDiGraph; + let inner_temp_dag1: PyDiGraph; + let dag0_out = if dag0.node_removed { + inner_temp_dag0 = clone_graph(py, dag0); + &inner_temp_dag0 + } else { + dag0 + }; + let dag1_out = if dag1.node_removed { + inner_temp_dag1 = clone_graph(py, dag1); + &inner_temp_dag1 + } else { + dag1 + }; + let g0 = &dag0_out.graph; + let g1 = &dag1_out.graph; + if g0.node_count() != g1.node_count() || g0.edge_count() != g1.edge_count() { return Ok(false); } - let mut st = [Vf2State::new(dag0), Vf2State::new(dag1)]; - let res = try_match(&mut st, dag0, dag1, &mut node_match, &mut edge_match)?; + let mut st = [Vf2State::new(&dag0_out), Vf2State::new(&dag1_out)]; + let res = try_match( + &mut st, + &dag0_out, + &dag1_out, + &mut node_match, + &mut edge_match, + )?; Ok(res.unwrap_or(false)) } diff --git a/src/lib.rs b/src/lib.rs index 8f8d9e4271..149cd05972 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -221,6 +221,7 @@ fn is_isomorphic_node_match( Ok(true) } let res = dag_isomorphism::is_isomorphic_matching( + py, first, second, compare_nodes, diff --git a/tests/test_isomorphic.py b/tests/test_isomorphic.py index c53757412d..a40e72af58 100644 --- a/tests/test_isomorphic.py +++ b/tests/test_isomorphic.py @@ -10,6 +10,7 @@ # License for the specific language governing permissions and limitations # under the License. +import copy import unittest import retworkx @@ -92,3 +93,86 @@ def test_isomorphic_compare_nodes_identical(self): self.assertTrue( retworkx.is_isomorphic_node_match( dag_a, dag_b, lambda x, y: x == y)) + + def test_isomorphic_compare_nodes_with_removals(self): + dag_a = retworkx.PyDAG() + dag_b = retworkx.PyDAG() + + qr_0_in = dag_a.add_node('qr[0]') + qr_1_in = dag_a.add_node('qr[1]') + cr_0_in = dag_a.add_node('cr[0]') + qr_0_out = dag_a.add_node('qr[0]') + qr_1_out = dag_a.add_node('qr[1]') + cr_0_out = dag_a.add_node('qr[0]') + cu1 = dag_a.add_child(qr_0_in, 'cu1', 'qr[0]') + dag_a.add_edge(qr_1_in, cu1, 'qr[1]') + measure_0 = dag_a.add_child(cr_0_in, 'measure', 'cr[0]') + dag_a.add_edge(cu1, measure_0, 'qr[0]') + measure_1 = dag_a.add_child(cu1, 'measure', 'qr[1]') + dag_a.add_edge(measure_0, measure_1, 'cr[0]') + dag_a.add_edge(measure_1, qr_1_out, 'qr[1]') + dag_a.add_edge(measure_1, cr_0_out, 'cr[0]') + dag_a.add_edge(measure_0, qr_0_out, 'qr[0]') + dag_a.remove_node(cu1) + dag_a.add_edge(qr_0_in, measure_0, 'qr[0]') + dag_a.add_edge(qr_1_in, measure_1, 'qr[1]') + + qr_0_in = dag_b.add_node('qr[0]') + qr_1_in = dag_b.add_node('qr[1]') + cr_0_in = dag_b.add_node('cr[0]') + qr_0_out = dag_b.add_node('qr[0]') + qr_1_out = dag_b.add_node('qr[1]') + cr_0_out = dag_b.add_node('qr[0]') + measure_0 = dag_b.add_child(cr_0_in, 'measure', 'cr[0]') + dag_b.add_edge(qr_0_in, measure_0, 'qr[0]') + measure_1 = dag_b.add_child(qr_1_in, 'measure', 'qr[1]') + dag_b.add_edge(measure_1, qr_1_out, 'qr[1]') + dag_b.add_edge(measure_1, cr_0_out, 'cr[0]') + dag_b.add_edge(measure_0, measure_1, 'cr[0]') + dag_b.add_edge(measure_0, qr_0_out, 'qr[0]') + + self.assertTrue( + retworkx.is_isomorphic_node_match( + dag_a, dag_b, lambda x, y: x == y)) + + def test_isomorphic_compare_nodes_with_removals_deepcopy(self): + dag_a = retworkx.PyDAG() + dag_b = retworkx.PyDAG() + + qr_0_in = dag_a.add_node('qr[0]') + qr_1_in = dag_a.add_node('qr[1]') + cr_0_in = dag_a.add_node('cr[0]') + qr_0_out = dag_a.add_node('qr[0]') + qr_1_out = dag_a.add_node('qr[1]') + cr_0_out = dag_a.add_node('qr[0]') + cu1 = dag_a.add_child(qr_0_in, 'cu1', 'qr[0]') + dag_a.add_edge(qr_1_in, cu1, 'qr[1]') + measure_0 = dag_a.add_child(cr_0_in, 'measure', 'cr[0]') + dag_a.add_edge(cu1, measure_0, 'qr[0]') + measure_1 = dag_a.add_child(cu1, 'measure', 'qr[1]') + dag_a.add_edge(measure_0, measure_1, 'cr[0]') + dag_a.add_edge(measure_1, qr_1_out, 'qr[1]') + dag_a.add_edge(measure_1, cr_0_out, 'cr[0]') + dag_a.add_edge(measure_0, qr_0_out, 'qr[0]') + dag_a.remove_node(cu1) + dag_a.add_edge(qr_0_in, measure_0, 'qr[0]') + dag_a.add_edge(qr_1_in, measure_1, 'qr[1]') + + qr_0_in = dag_b.add_node('qr[0]') + qr_1_in = dag_b.add_node('qr[1]') + cr_0_in = dag_b.add_node('cr[0]') + qr_0_out = dag_b.add_node('qr[0]') + qr_1_out = dag_b.add_node('qr[1]') + cr_0_out = dag_b.add_node('qr[0]') + measure_0 = dag_b.add_child(cr_0_in, 'measure', 'cr[0]') + dag_b.add_edge(qr_0_in, measure_0, 'qr[0]') + measure_1 = dag_b.add_child(qr_1_in, 'measure', 'qr[1]') + dag_b.add_edge(measure_1, qr_1_out, 'qr[1]') + dag_b.add_edge(measure_1, cr_0_out, 'cr[0]') + dag_b.add_edge(measure_0, measure_1, 'cr[0]') + dag_b.add_edge(measure_0, qr_0_out, 'qr[0]') + + self.assertTrue( + retworkx.is_isomorphic_node_match( + copy.deepcopy(dag_a), copy.deepcopy(dag_b), + lambda x, y: x == y))