diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 91a1c23634..bf0c07be06 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -13,6 +13,7 @@ use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; use crate::Hugr; +mod serialize_signature_func; /// Trait necessary for binary computations of OpDef signature pub trait CustomSignatureFunc: Send + Sync { @@ -163,65 +164,6 @@ pub enum SignatureFunc { MissingComputeFunc, } -mod serialize_signature_func { - use serde::{Deserialize, Serialize}; - - use super::{PolyFuncTypeRV, SignatureFunc}; - #[derive(serde::Deserialize, serde::Serialize)] - struct SerSignatureFunc { - /// If the type scheme is available explicitly, store it. - signature: Option, - /// Whether an associated binary function is expected. - /// If `signature` is `None`, a true value here indicates a custom compute function. - /// If `signature` is not `None`, a true value here indicates a custom validation function. - binary: bool, - } - - pub(super) fn serialize( - value: &super::SignatureFunc, - serializer: S, - ) -> Result - where - S: serde::Serializer, - { - match value { - SignatureFunc::PolyFuncType(custom) => SerSignatureFunc { - signature: Some(custom.poly_func.clone()), - binary: custom.validate.is_some(), - }, - SignatureFunc::MissingValidateFunc(poly_func) => SerSignatureFunc { - signature: Some(poly_func.clone()), - binary: true, - }, - SignatureFunc::CustomFunc(_) => SerSignatureFunc { - signature: None, - binary: true, - }, - SignatureFunc::MissingComputeFunc => SerSignatureFunc { - signature: None, - binary: false, - }, - } - .serialize(serializer) - } - - pub(super) fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let SerSignatureFunc { signature, binary } = SerSignatureFunc::deserialize(deserializer)?; - - match (signature, binary) { - (Some(sig), false) => Ok(sig.into()), - (Some(sig), true) => Ok(SignatureFunc::MissingValidateFunc(sig)), - (None, true) => Ok(SignatureFunc::MissingComputeFunc), - (None, false) => Err(serde::de::Error::custom( - "No signature provided and custom computation not expected.", - )), - } - } -} - #[derive(PartialEq, Eq, Debug)] struct NoValidate; impl ValidateTypeArgs for NoValidate { diff --git a/hugr-core/src/extension/op_def/serialize_signature_func.rs b/hugr-core/src/extension/op_def/serialize_signature_func.rs new file mode 100644 index 0000000000..c8f69a930c --- /dev/null +++ b/hugr-core/src/extension/op_def/serialize_signature_func.rs @@ -0,0 +1,169 @@ +use serde::{Deserialize, Serialize}; + +use super::{PolyFuncTypeRV, SignatureFunc}; +#[derive(serde::Deserialize, serde::Serialize, PartialEq, Debug)] +struct SerSignatureFunc { + /// If the type scheme is available explicitly, store it. + signature: Option, + /// Whether an associated binary function is expected. + /// If `signature` is `None`, a true value here indicates a custom compute function. + /// If `signature` is not `None`, a true value here indicates a custom validation function. + binary: bool, +} + +pub(super) fn serialize(value: &super::SignatureFunc, serializer: S) -> Result +where + S: serde::Serializer, +{ + match value { + SignatureFunc::PolyFuncType(custom) => SerSignatureFunc { + signature: Some(custom.poly_func.clone()), + binary: custom.validate.is_some(), + }, + SignatureFunc::MissingValidateFunc(poly_func) => SerSignatureFunc { + signature: Some(poly_func.clone()), + binary: true, + }, + SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => SerSignatureFunc { + signature: None, + binary: true, + }, + } + .serialize(serializer) +} + +pub(super) fn deserialize<'de, D>(deserializer: D) -> Result +where + D: serde::Deserializer<'de>, +{ + let SerSignatureFunc { signature, binary } = SerSignatureFunc::deserialize(deserializer)?; + + match (signature, binary) { + (Some(sig), false) => Ok(sig.into()), + (Some(sig), true) => Ok(SignatureFunc::MissingValidateFunc(sig)), + (None, true) => Ok(SignatureFunc::MissingComputeFunc), + (None, false) => Err(serde::de::Error::custom( + "No signature provided and custom computation not expected.", + )), + } +} +#[derive(serde::Deserialize, serde::Serialize, Debug)] +/// Wrapper we can derive serde for, to allow round-trip serialization +struct Wrapper { + #[serde( + serialize_with = "serialize", + deserialize_with = "deserialize", + flatten + )] + inner: SignatureFunc, +} + +#[cfg(test)] +mod test { + use cool_asserts::assert_matches; + use serde::de::Error; + + use super::*; + use crate::{ + extension::{op_def::NoValidate, prelude::USIZE_T, CustomSignatureFunc, CustomValidator}, + types::{FuncValueType, Signature, TypeArg}, + }; + // Define test-only conversions via serialization roundtrip + impl TryFrom for SignatureFunc { + type Error = serde_json::Error; + fn try_from(value: SerSignatureFunc) -> Result { + let ser = serde_json::to_value(value).unwrap(); + let w: Wrapper = serde_json::from_value(ser)?; + Ok(w.inner) + } + } + + impl From for SerSignatureFunc { + fn from(value: SignatureFunc) -> Self { + let ser = serde_json::to_value(Wrapper { inner: value }).unwrap(); + serde_json::from_value(ser).unwrap() + } + } + struct CustomSig; + + impl CustomSignatureFunc for CustomSig { + fn compute_signature<'o, 'a: 'o>( + &'a self, + _arg_values: &[TypeArg], + _def: &'o crate::extension::op_def::OpDef, + _extension_registry: &crate::extension::ExtensionRegistry, + ) -> Result { + Ok(Default::default()) + } + + fn static_params(&self) -> &[crate::types::type_param::TypeParam] { + &[] + } + } + #[test] + fn test_serial_sig_func() { + // test round-trip + let sig: FuncValueType = Signature::new_endo(USIZE_T.clone()).into(); + let simple: SignatureFunc = sig.clone().into(); + let ser: SerSignatureFunc = simple.into(); + let expected_ser = SerSignatureFunc { + signature: Some(sig.clone().into()), + binary: false, + }; + + assert_eq!(ser, expected_ser); + let deser = SignatureFunc::try_from(ser).unwrap(); + assert_matches!( deser, + SignatureFunc::PolyFuncType(CustomValidator { + poly_func, + validate, + }) => { + assert_eq!(poly_func, sig.clone().into()); + assert!(validate.is_none()); + }); + + let with_custom: SignatureFunc = + CustomValidator::new_with_validator(sig.clone(), NoValidate).into(); + let ser: SerSignatureFunc = with_custom.into(); + let expected_ser = SerSignatureFunc { + signature: Some(sig.clone().into()), + binary: true, + }; + assert_eq!(ser, expected_ser); + let deser = SignatureFunc::try_from(ser).unwrap(); + assert_matches!(&deser, + SignatureFunc::MissingValidateFunc(poly_func) => { + assert_eq!(poly_func, &PolyFuncTypeRV::from(sig.clone())); + } + ); + + // re-serializing should give the same result + assert_eq!(SerSignatureFunc::from(deser), expected_ser); + + let custom: SignatureFunc = CustomSig.into(); + let ser: SerSignatureFunc = custom.into(); + let expected_ser = SerSignatureFunc { + signature: None, + binary: true, + }; + assert_eq!(ser, expected_ser); + + let deser = SignatureFunc::try_from(ser).unwrap(); + assert_matches!(&deser, &SignatureFunc::MissingComputeFunc); + + assert_eq!(SerSignatureFunc::from(deser), expected_ser); + + let bad_ser = SerSignatureFunc { + signature: None, + binary: false, + }; + + let err = SignatureFunc::try_from(bad_ser).unwrap_err(); + + assert_eq!( + err.to_string(), + serde_json::Error::custom("No signature provided and custom computation not expected.") + .to_string() + ); + } +} diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index e771679515..65b48ce460 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -517,6 +517,18 @@ fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { }); } +#[test] +// test all standard extension serialisations are valid against scheme +fn std_extensions_valid() { + let std_reg = crate::std_extensions::std_reg(); + for (_, ext) in std_reg.into_iter() { + let val = serde_json::to_value(ext).unwrap(); + NamedSchema::check_schemas(&val, get_schemas(true)); + // check deserialises correctly, can't check equality because of custom binaries. + let _: crate::extension::Extension = serde_json::from_value(val).unwrap(); + } +} + mod proptest { use super::check_testing_roundtrip; use super::{NodeSer, SimpleOpDef}; diff --git a/hugr-core/src/std_extensions.rs b/hugr-core/src/std_extensions.rs index b93534df45..8ce6906745 100644 --- a/hugr-core/src/std_extensions.rs +++ b/hugr-core/src/std_extensions.rs @@ -2,7 +2,24 @@ //! //! These may be moved to other crates in the future, or dropped altogether. +use crate::extension::ExtensionRegistry; + pub mod arithmetic; pub mod collections; pub mod logic; pub mod ptr; + +/// Extension registry with all standard extensions and prelude. +pub fn std_reg() -> ExtensionRegistry { + ExtensionRegistry::try_new([ + crate::extension::prelude::PRELUDE.to_owned(), + arithmetic::int_ops::EXTENSION.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + logic::EXTENSION.to_owned(), + ptr::EXTENSION.to_owned(), + ]) + .unwrap() +}