Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add op's extension to signature check in resolve_opaque_op #1317

Merged
merged 17 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ mod test {
use super::*;
use cool_asserts::assert_matches;

use crate::extension::{ExtensionId, ExtensionSet};
use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::utils::test_quantum_extension::{
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
Expand Down Expand Up @@ -295,16 +296,20 @@ mod test {

#[test]
fn with_nonlinear_and_outputs() {
let missing_ext: ExtensionId = "MissingExt".try_into().unwrap();
let my_custom_op = CustomOp::new_opaque(OpaqueOp::new(
"MissingRsrc".try_into().unwrap(),
missing_ext.clone(),
"MyOp",
"unknown op".to_string(),
vec![],
FunctionType::new(vec![QB, NAT], vec![QB]),
));
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.with_extension_delta(ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
missing_ext,
]))
.into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();
Expand Down
68 changes: 43 additions & 25 deletions hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -643,38 +643,56 @@ mod test {
}

#[test]
fn test_invalid() -> Result<(), Box<dyn std::error::Error>> {
fn test_invalid() {
let unknown_ext: ExtensionId = "unknown_ext".try_into().unwrap();
let utou = FunctionType::new_endo(vec![USIZE_T]);
let mk_op = |s| {
CustomOp::new_opaque(OpaqueOp::new(
ExtensionId::new("unknown_ext").unwrap(),
unknown_ext.clone(),
s,
String::new(),
vec![],
utou.clone(),
))
};
let mut h = DFGBuilder::new(FunctionType::new(
type_row![USIZE_T, BOOL_T],
type_row![USIZE_T],
))?;
let mut h = DFGBuilder::new(
FunctionType::new(type_row![USIZE_T, BOOL_T], type_row![USIZE_T])
.with_extension_delta(unknown_ext.clone()),
)
.unwrap();
let [i, b] = h.input_wires_arr();
let mut cond = h.conditional_builder(
(vec![type_row![]; 2], b),
[(USIZE_T, i)],
type_row![USIZE_T],
)?;
let mut case1 = cond.case_builder(0)?;
let foo = case1.add_dataflow_op(mk_op("foo"), case1.input_wires())?;
let case1 = case1.finish_with_outputs(foo.outputs())?.node();
let mut case2 = cond.case_builder(1)?;
let bar = case2.add_dataflow_op(mk_op("bar"), case2.input_wires())?;
let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs())?;
let baz = baz_dfg.add_dataflow_op(mk_op("baz"), baz_dfg.input_wires())?;
let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs())?;
let case2 = case2.finish_with_outputs(baz_dfg.outputs())?.node();
let cond = cond.finish_sub_container()?;
let h = h.finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY)?;
let mut cond = h
.conditional_builder_exts(
(vec![type_row![]; 2], b),
[(USIZE_T, i)],
type_row![USIZE_T],
unknown_ext.clone(),
)
.unwrap();
let mut case1 = cond.case_builder(0).unwrap();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why ? -> unwrap everywhere? It does make things a lot longer and look like you are changing more than you are.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unwrap is much better than ? in tests because you get a src location on failure. I changed them all to be able to fix the test and thought I might as well leave it in a debuggable state.

let foo = case1
.add_dataflow_op(mk_op("foo"), case1.input_wires())
.unwrap();
let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node();
let mut case2 = cond.case_builder(1).unwrap();
let bar = case2
.add_dataflow_op(mk_op("bar"), case2.input_wires())
.unwrap();
let mut baz_dfg = case2
.dfg_builder(
utou.clone().with_extension_delta(unknown_ext.clone()),
bar.outputs(),
)
.unwrap();
let baz = baz_dfg
.add_dataflow_op(mk_op("baz"), baz_dfg.input_wires())
.unwrap();
let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap();
let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node();
let cond = cond.finish_sub_container().unwrap();
let h = h
.finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY)
.unwrap();

let mut r_hugr = Hugr::new(h.get_optype(cond.node()).clone());
let r1 = r_hugr.add_node_with_parent(
Expand All @@ -701,7 +719,7 @@ mod test {
rep.verify(&h).unwrap();
{
let mut target = h.clone();
let node_map = rep.clone().apply(&mut target)?;
let node_map = rep.clone().apply(&mut target).unwrap();
let new_case2 = *node_map.get(&r2).unwrap();
assert_eq!(target.get_parent(baz.node()), Some(new_case2));
}
Expand All @@ -716,7 +734,8 @@ mod test {
// Root node type needs to be that of common parent of the removed nodes:
let mut rep2 = rep.clone();
rep2.replacement
.replace_op(rep2.replacement.root(), h.root_type().clone())?;
.replace_op(rep2.replacement.root(), h.root_type().clone())
.unwrap();
assert_eq!(
check_same_errors(rep2),
ReplaceError::WrongRootNodeTag {
Expand Down Expand Up @@ -815,6 +834,5 @@ mod test {
}),
ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)
);
Ok(())
}
}
34 changes: 26 additions & 8 deletions hugr-core/src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl DataflowOpTrait for CustomOp {
/// The signature of the operation.
fn signature(&self) -> FunctionType {
match self {
Self::Opaque(op) => op.signature.clone(),
Self::Opaque(op) => op.signature(),
Self::Extension(ext_op) => ext_op.signature(),
}
}
Expand Down Expand Up @@ -276,7 +276,15 @@ impl DataflowOpTrait for ExtensionOp {
}
}

/// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`]
/// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`].
///
/// All [CustomOp]s are serialised as `OpaqueOp`s.
///
/// The signature of a [CustomOp] always includes that op's extension. We do not
/// require that the `signature` field of [OpaqueOp] contains `extension`,
/// instead we are careful to add it whenever we look at the `signature` of an
/// `OpaqueOp`. This is a small efficiency in serialisation and allows us to
/// be more liberal in deserialisation.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct OpaqueOp {
Expand All @@ -286,6 +294,9 @@ pub struct OpaqueOp {
#[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
description: String, // cache in advance so description() can return &str
args: Vec<TypeArg>,
// note that `signature` may not include `extension`. Thus this field must
doug-q marked this conversation as resolved.
Show resolved Hide resolved
// remain private, and should be accessed through
// `DataflowOpTrait::signature`.
signature: FunctionType,
}

Expand Down Expand Up @@ -343,7 +354,9 @@ impl DataflowOpTrait for OpaqueOp {
}

fn signature(&self) -> FunctionType {
self.signature.clone()
self.signature
.clone()
.with_extension_delta(self.extension().clone())
}
}

Expand Down Expand Up @@ -392,9 +405,8 @@ pub fn resolve_opaque_op(
r.name().clone(),
));
};
let ext_op =
ExtensionOp::new(def.clone(), opaque.args.clone(), extension_registry).unwrap();
if opaque.signature != ext_op.signature {
let ext_op = ExtensionOp::new(def.clone(), opaque.args.clone(), extension_registry)?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change from unwrap to ? here necessitates the addition of CustomOpError::SignatureError? Good move :)

if opaque.signature() != ext_op.signature() {
return Err(CustomOpError::SignatureMismatch {
extension: opaque.extension.clone(),
op: def.name().clone(),
Expand Down Expand Up @@ -425,10 +437,14 @@ pub enum CustomOpError {
stored: FunctionType,
computed: FunctionType,
},
/// An error in computing the signature of the ExtensionOp
#[error(transparent)]
SignatureError(#[from] SignatureError),
}

#[cfg(test)]
mod test {

use crate::{
extension::prelude::{BOOL_T, QB_T, USIZE_T},
std_extensions::arithmetic::{
Expand All @@ -453,13 +469,15 @@ mod test {
assert_eq!(op.name(), "res.op");
assert_eq!(DataflowOpTrait::description(&op), "desc");
assert_eq!(op.args(), &[TypeArg::Type { ty: USIZE_T }]);
assert_eq!(op.signature(), sig);
assert_eq!(
op.signature(),
sig.with_extension_delta(op.extension().clone())
);
assert!(op.is_opaque());
assert!(!op.is_extension_op());
}

#[test]
#[should_panic] // https://github.com/CQCL/hugr/issues/1315
fn resolve_opaque_op() {
let registry = &INT_OPS_REGISTRY;
let i0 = &INT_TYPES[0];
Expand Down
3 changes: 1 addition & 2 deletions hugr-py/src/hugr/std/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from hugr import tys, val

#: HUGR 64-bit IEEE 754-2019 floating point type.
FLOAT_EXT_ID = "arithmetic.float.types"
FLOAT_T = tys.Opaque(
extension=FLOAT_EXT_ID,
extension="arithmetic.float.types",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: don't really see any need to revert here

id="float64",
args=[],
bound=tys.TypeBound.Copyable,
Expand Down
4 changes: 1 addition & 3 deletions hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def to_custom(self) -> Custom:
return Custom(
"idivmod_u",
tys.FunctionType(
input=[int_t(self.arg1)] * 2,
output=[int_t(self.arg2)] * 2,
extension_reqs=[OPS_EXTENSION],
input=[int_t(self.arg1)] * 2, output=[int_t(self.arg2)] * 2
),
extension=OPS_EXTENSION,
args=[tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)],
Expand Down
3 changes: 1 addition & 2 deletions hugr-py/src/hugr/std/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ class _NotDef(AsCustomOp):
"""Not operation."""

def to_custom(self) -> Custom:
sig = tys.FunctionType.endo([tys.Bool], [EXTENSION_ID])
return Custom("Not", sig, extension=EXTENSION_ID)
return Custom("Not", tys.FunctionType.endo([tys.Bool]), extension=EXTENSION_ID)

def __call__(self, a: ComWire) -> Command:
return DataflowOp.__call__(self, a)
Expand Down
20 changes: 5 additions & 15 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from hugr.hugr import Hugr
from hugr.ops import AsCustomOp, Command, Custom, DataflowOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.std.float import FLOAT_EXT_ID, FLOAT_T
from hugr.std.float import FLOAT_T

if TYPE_CHECKING:
from hugr.ops import ComWire
Expand Down Expand Up @@ -48,7 +48,7 @@ def __call__(self, q: ComWire) -> Command:
def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo([tys.Qubit], extension_reqs=[QUANTUM_EXTENSION_ID]),
tys.FunctionType.endo([tys.Qubit]),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -70,9 +70,7 @@ class _Enum(Enum):
def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo(
[tys.Qubit] * 2, extension_reqs=[QUANTUM_EXTENSION_ID]
),
tys.FunctionType.endo([tys.Qubit] * 2),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -92,11 +90,7 @@ class MeasureDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Measure",
tys.FunctionType(
[tys.Qubit],
[tys.Qubit, tys.Bool],
extension_reqs=[QUANTUM_EXTENSION_ID],
),
tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -112,11 +106,7 @@ class RzDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Rz",
tys.FunctionType(
[tys.Qubit, FLOAT_T],
[tys.Qubit],
extension_reqs=[QUANTUM_EXTENSION_ID, FLOAT_EXT_ID],
),
tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]),
extension=QUANTUM_EXTENSION_ID,
)

Expand Down
Loading