Skip to content

Commit

Permalink
separate explicit and custom validate variants
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jul 30, 2024
1 parent e31839c commit a47252a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 61 deletions.
75 changes: 32 additions & 43 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,47 +120,37 @@ pub trait CustomLowerFunc: Send + Sync {
/// lost over a serialization round-trip.
pub struct CustomValidator {
poly_func: PolyFuncTypeRV,
/// Optional custom function for validating type arguments before returning the signature.
/// If None, no custom validation is performed.
pub(crate) validate: Option<Box<dyn ValidateTypeArgs>>,
//Optional custom function for validating type arguments before returning the signature.
//If None, no custom validation is performed.
pub(crate) validate: Box<dyn ValidateTypeArgs>,
}

impl CustomValidator {
/// Encode a signature using a `PolyFuncTypeRV`
pub fn from_polyfunc(poly_func: impl Into<PolyFuncTypeRV>) -> Self {
Self {
poly_func: poly_func.into(),
validate: Default::default(),
}
}

/// Encode a signature using a `PolyFuncTypeRV`, with a custom function for
/// validating type arguments before returning the signature.
pub fn new_with_validator(
pub fn new(
poly_func: impl Into<PolyFuncTypeRV>,
validate: impl ValidateTypeArgs + 'static,
) -> Self {
Self {
poly_func: poly_func.into(),
validate: Some(Box::new(validate)),
validate: Box::new(validate),
}
}
}

/// The two ways in which an OpDef may compute the Signature of each operation node.
/// The ways in which an OpDef may compute the Signature of each operation node.
pub enum SignatureFunc {
// Note: except for serialization, we could have type schemes just implement the same
// CustomSignatureFunc trait too, and replace this enum with Box<dyn CustomSignatureFunc>.
// However instead we treat all CustomFunc's as non-serializable.
/// A PolyFuncType (polymorphic function type), with optional custom
/// validation for provided type arguments,
PolyFuncType(CustomValidator),
/// Declaration specified a custom validate binary but it was not provided.
/// An explicit polymorphic function type.
PolyFuncType(PolyFuncTypeRV),
/// A polymorphic function type with a custom binary for validating type arguments.
CustomValidator(CustomValidator),
/// Serialized declaration specified a custom validate binary but it was not provided.
MissingValidateFunc(PolyFuncTypeRV),
/// A custom binary which computes a polymorphic function type given values
/// for its static type parameters.
CustomFunc(Box<dyn CustomSignatureFunc>),
/// Declaration specified a custom compute binary but it was not provided.
/// Serialized declaration specified a custom compute binary but it was not provided.
MissingComputeFunc,
}

Expand Down Expand Up @@ -191,38 +181,39 @@ impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {

impl From<PolyFuncType> for SignatureFunc {
fn from(value: PolyFuncType) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(value))
Self::PolyFuncType(value.into())
}
}

impl From<PolyFuncTypeRV> for SignatureFunc {
fn from(v: PolyFuncTypeRV) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(v))
Self::PolyFuncType(v)
}
}

impl From<FuncValueType> for SignatureFunc {
fn from(v: FuncValueType) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(v))
Self::PolyFuncType(v.into())
}
}

impl From<Signature> for SignatureFunc {
fn from(v: Signature) -> Self {
Self::PolyFuncType(CustomValidator::from_polyfunc(FuncValueType::from(v)))
Self::PolyFuncType(FuncValueType::from(v).into())
}
}

impl From<CustomValidator> for SignatureFunc {
fn from(v: CustomValidator) -> Self {
Self::PolyFuncType(v)
Self::CustomValidator(v)
}
}

impl SignatureFunc {
fn static_params(&self) -> Result<&[TypeParam], SignatureError> {
Ok(match self {
SignatureFunc::PolyFuncType(CustomValidator { poly_func: ts, .. })
SignatureFunc::PolyFuncType(ts)
| SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. })
| SignatureFunc::MissingValidateFunc(ts) => ts.params(),
SignatureFunc::CustomFunc(func) => func.static_params(),
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc),
Expand All @@ -233,7 +224,7 @@ impl SignatureFunc {
/// self-contained type scheme (with no custom validation).
pub fn ignore_missing_validation(&mut self) {
if let SignatureFunc::MissingValidateFunc(ts) = self {
*self = SignatureFunc::PolyFuncType(CustomValidator::from_polyfunc(ts.clone()));
*self = SignatureFunc::PolyFuncType(ts.clone());
}
}

Expand All @@ -256,14 +247,11 @@ impl SignatureFunc {
) -> Result<Signature, SignatureError> {
let temp: PolyFuncTypeRV; // to keep alive
let (pf, args) = match &self {
SignatureFunc::PolyFuncType(custom) => {
custom
.validate
.as_ref()
.unwrap_or(&Default::default())
.validate(args, def, exts)?;
SignatureFunc::CustomValidator(custom) => {
custom.validate.as_ref().validate(args, def, exts)?;
(&custom.poly_func, args)
}
SignatureFunc::PolyFuncType(ts) => (ts, args),
SignatureFunc::CustomFunc(func) => {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
Expand All @@ -289,7 +277,8 @@ impl SignatureFunc {
impl Debug for SignatureFunc {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::PolyFuncType(ts) => ts.poly_func.fmt(f),
Self::CustomValidator(ts) => ts.poly_func.fmt(f),
Self::PolyFuncType(ts) => ts.fmt(f),
Self::CustomFunc { .. } => f.write_str("<custom sig>"),
Self::MissingComputeFunc => f.write_str("<missing custom sig>"),
Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"),
Expand Down Expand Up @@ -368,7 +357,8 @@ impl OpDef {
) -> Result<(), SignatureError> {
let temp: PolyFuncTypeRV; // to keep alive
let (pf, args) = match &self.signature_func {
SignatureFunc::PolyFuncType(ts) => (&ts.poly_func, args),
SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args),
SignatureFunc::PolyFuncType(ts) => (ts, args),
SignatureFunc::CustomFunc(custom) => {
let (static_args, other_args) =
args.split_at(min(custom.static_params().len(), args.len()));
Expand Down Expand Up @@ -444,7 +434,7 @@ impl OpDef {
pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> {
// TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams
// for both type scheme and custom binary
if let SignatureFunc::PolyFuncType(ts) = &self.signature_func {
if let SignatureFunc::CustomValidator(ts) = &self.signature_func {
// The type scheme may contain row variables so be of variable length;
// these will have to be substituted to fixed-length concrete types when
// the OpDef is instantiated into an actual OpType.
Expand Down Expand Up @@ -585,10 +575,11 @@ pub(super) mod test {
// a compile error here. To fix: modify the fields matched on here,
// maintaining the lack of `..` and, for each part that is
// serializable, ensure we are checking it for equality below.
SignatureFunc::PolyFuncType(CustomValidator {
SignatureFunc::CustomValidator(CustomValidator {
poly_func,
validate: _,
})
| SignatureFunc::PolyFuncType(poly_func)
| SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()),
// This is ruled out by `new()` but leave it here for later.
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None,
Expand Down Expand Up @@ -816,9 +807,7 @@ pub(super) mod test {

use crate::{
builder::test::simple_dfg_hugr,
extension::{
op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc,
},
extension::{op_def::LowerFunc, ExtensionId, ExtensionSet, OpDef, SignatureFunc},
types::PolyFuncTypeRV,
};

Expand All @@ -830,7 +819,7 @@ pub(super) mod test {
// this is not serialized. When it is, we should generate
// examples here .
any::<PolyFuncTypeRV>()
.prop_map(|x| SignatureFunc::PolyFuncType(CustomValidator::from_polyfunc(x)))
.prop_map(SignatureFunc::PolyFuncType)
.boxed()
}
}
Expand Down
25 changes: 9 additions & 16 deletions hugr-core/src/extension/op_def/serialize_signature_func.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};

use super::{PolyFuncTypeRV, SignatureFunc};
use super::{CustomValidator, PolyFuncTypeRV, SignatureFunc};
#[derive(serde::Deserialize, serde::Serialize, PartialEq, Debug, Clone)]
struct SerSignatureFunc {
/// If the type scheme is available explicitly, store it.
Expand All @@ -16,11 +16,12 @@ where
S: serde::Serializer,
{
match value {
SignatureFunc::PolyFuncType(custom) => SerSignatureFunc {
signature: Some(custom.poly_func.clone()),
binary: custom.validate.is_some(),
SignatureFunc::PolyFuncType(poly) => SerSignatureFunc {
signature: Some(poly.clone()),
binary: false,
},
SignatureFunc::MissingValidateFunc(poly_func) => SerSignatureFunc {
SignatureFunc::CustomValidator(CustomValidator { poly_func, .. })
| SignatureFunc::MissingValidateFunc(poly_func) => SerSignatureFunc {
signature: Some(poly_func.clone()),
binary: true,
},
Expand Down Expand Up @@ -114,16 +115,11 @@ mod test {
assert_eq!(ser, expected_ser);
let deser = SignatureFunc::try_from(ser).unwrap();
assert_matches!( deser,
SignatureFunc::PolyFuncType(CustomValidator {
poly_func,
validate,
}) => {
SignatureFunc::PolyFuncType(poly_func) => {
assert_eq!(poly_func, sig.clone().into());
assert!(validate.is_none());
});

let with_custom: SignatureFunc =
CustomValidator::new_with_validator(sig.clone(), NoValidate).into();
let with_custom: SignatureFunc = CustomValidator::new(sig.clone(), NoValidate).into();
let ser: SerSignatureFunc = with_custom.into();
let expected_ser = SerSignatureFunc {
signature: Some(sig.clone().into()),
Expand All @@ -144,10 +140,7 @@ mod test {
);

deser.ignore_missing_validation();
assert_matches!(
&deser,
&SignatureFunc::PolyFuncType(CustomValidator { validate: None, .. })
);
assert_matches!(&deser, &SignatureFunc::PolyFuncType(_));

let custom: SignatureFunc = CustomSig.into();
let ser: SerSignatureFunc = custom.into();
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ impl MakeOpDef for IntOpDef {
fn signature(&self) -> SignatureFunc {
use IntOpDef::*;
match self {
iwiden_s | iwiden_u => CustomValidator::new_with_validator(
iwiden_s | iwiden_u => CustomValidator::new(
int_polytype(2, vec![int_tv(0)], vec![int_tv(1)]),
IOValidator { f_ge_s: false },
)
.into(),
inarrow_s | inarrow_u => CustomValidator::new_with_validator(
inarrow_s | inarrow_u => CustomValidator::new(
int_polytype(2, int_tv(0), sum_ty_with_err(int_tv(1))),
IOValidator { f_ge_s: true },
)
Expand Down

0 comments on commit a47252a

Please sign in to comment.