-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat!: Serialised extensions #1371
Changes from 12 commits
078da5d
066dcde
b9ae43a
f175105
c5ab7d0
0f3ec3e
6b923d8
6b7c788
bf071dc
0e3b160
e31839c
a47252a
4cd9af6
21642ea
8bcb85c
b243b17
7010d0e
ac834de
9990e78
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ use crate::ops::{OpName, OpNameRef}; | |
use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; | ||
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; | ||
use crate::Hugr; | ||
mod serialize_signature_func; | ||
|
||
/// Trait necessary for binary computations of OpDef signature | ||
pub trait CustomSignatureFunc: Send + Sync { | ||
|
@@ -117,26 +118,17 @@ pub trait CustomLowerFunc: Send + Sync { | |
/// Encode a signature as `PolyFuncTypeRV` but optionally allow validating type | ||
/// arguments via a custom binary. The binary cannot be serialized so will be | ||
/// lost over a serialization round-trip. | ||
#[derive(serde::Deserialize, serde::Serialize)] | ||
pub struct CustomValidator { | ||
#[serde(flatten)] | ||
poly_func: PolyFuncTypeRV, | ||
#[serde(skip)] | ||
//Optional custom function for validating type arguments before returning the signature. | ||
//If None, no custom validation is performed. | ||
pub(crate) validate: Box<dyn ValidateTypeArgs>, | ||
} | ||
|
||
impl CustomValidator { | ||
/// Encode a signature using a `PolyFuncTypeRV` | ||
pub fn from_polyfunc(poly_func: impl Into<PolyFuncTypeRV>) -> Self { | ||
Self { | ||
poly_func: poly_func.into(), | ||
validate: Default::default(), | ||
} | ||
} | ||
|
||
/// Encode a signature using a `PolyFuncTypeRV`, with a custom function for | ||
/// validating type arguments before returning the signature. | ||
pub fn new_with_validator( | ||
pub fn new( | ||
poly_func: impl Into<PolyFuncTypeRV>, | ||
validate: impl ValidateTypeArgs + 'static, | ||
) -> Self { | ||
|
@@ -147,21 +139,22 @@ impl CustomValidator { | |
} | ||
} | ||
|
||
/// The two ways in which an OpDef may compute the Signature of each operation node. | ||
#[derive(serde::Deserialize, serde::Serialize)] | ||
/// The ways in which an OpDef may compute the Signature of each operation node. | ||
pub enum SignatureFunc { | ||
// Note: except for serialization, we could have type schemes just implement the same | ||
// CustomSignatureFunc trait too, and replace this enum with Box<dyn CustomSignatureFunc>. | ||
// However instead we treat all CustomFunc's as non-serializable. | ||
/// A PolyFuncType (polymorphic function type), with optional custom | ||
/// validation for provided type arguments, | ||
#[serde(rename = "signature")] | ||
PolyFuncType(CustomValidator), | ||
#[serde(skip)] | ||
/// An explicit polymorphic function type. | ||
PolyFuncType(PolyFuncTypeRV), | ||
/// A polymorphic function type with a custom binary for validating type arguments. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we could slightly emphasize that this is mostly the same as the previous case but PLUS a custom binary |
||
CustomValidator(CustomValidator), | ||
/// Serialized declaration specified a custom validate binary but it was not provided. | ||
MissingValidateFunc(PolyFuncTypeRV), | ||
/// A custom binary which computes a polymorphic function type given values | ||
/// for its static type parameters. | ||
CustomFunc(Box<dyn CustomSignatureFunc>), | ||
/// Serialized declaration specified a custom compute binary but it was not provided. | ||
MissingComputeFunc, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about a missing custom validation func? In either case I think we just have to trust the cache, maybe with a warning, so I guess if the warning is the same then we don't have to distinguish, is that the plan? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure - in the case of missing validation function do you trust the cached signature or use the type scheme to generate the signature and check against cache as you would without custom validation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. Given the validation function can only say one of two things - invalid, or use the type scheme - you can try the latter (which might reject if TypeArgs don't match the TypeParams, say) and see if that matches; that might get you an error, but even if the typescheme says ok, if there's a binary validation function that you haven't got, then that still has to be a warning |
||
} | ||
|
||
#[derive(PartialEq, Eq, Debug)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Methinks we should not need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, moved it just to the test that used it |
||
struct NoValidate; | ||
impl ValidateTypeArgs for NoValidate { | ||
fn validate<'o, 'a: 'o>( | ||
|
@@ -188,39 +181,50 @@ impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc { | |
|
||
impl From<PolyFuncType> for SignatureFunc { | ||
fn from(value: PolyFuncType) -> Self { | ||
Self::PolyFuncType(CustomValidator::from_polyfunc(value)) | ||
Self::PolyFuncType(value.into()) | ||
} | ||
} | ||
|
||
impl From<PolyFuncTypeRV> for SignatureFunc { | ||
fn from(v: PolyFuncTypeRV) -> Self { | ||
Self::PolyFuncType(CustomValidator::from_polyfunc(v)) | ||
Self::PolyFuncType(v) | ||
} | ||
} | ||
|
||
impl From<FuncValueType> for SignatureFunc { | ||
fn from(v: FuncValueType) -> Self { | ||
Self::PolyFuncType(CustomValidator::from_polyfunc(v)) | ||
Self::PolyFuncType(v.into()) | ||
} | ||
} | ||
|
||
impl From<Signature> for SignatureFunc { | ||
fn from(v: Signature) -> Self { | ||
Self::PolyFuncType(CustomValidator::from_polyfunc(FuncValueType::from(v))) | ||
Self::PolyFuncType(FuncValueType::from(v).into()) | ||
} | ||
} | ||
|
||
impl From<CustomValidator> for SignatureFunc { | ||
fn from(v: CustomValidator) -> Self { | ||
Self::PolyFuncType(v) | ||
Self::CustomValidator(v) | ||
} | ||
} | ||
|
||
impl SignatureFunc { | ||
fn static_params(&self) -> &[TypeParam] { | ||
match self { | ||
SignatureFunc::PolyFuncType(ts) => ts.poly_func.params(), | ||
fn static_params(&self) -> Result<&[TypeParam], SignatureError> { | ||
Ok(match self { | ||
SignatureFunc::PolyFuncType(ts) | ||
| SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. }) | ||
| SignatureFunc::MissingValidateFunc(ts) => ts.params(), | ||
SignatureFunc::CustomFunc(func) => func.static_params(), | ||
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc), | ||
acl-cqc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}) | ||
} | ||
|
||
/// If the signature is missing a custom validation function, ignore and treat as | ||
/// self-contained type scheme (with no custom validation). | ||
pub fn ignore_missing_validation(&mut self) { | ||
if let SignatureFunc::MissingValidateFunc(ts) = self { | ||
*self = SignatureFunc::PolyFuncType(ts.clone()); | ||
} | ||
} | ||
|
||
|
@@ -243,10 +247,11 @@ impl SignatureFunc { | |
) -> Result<Signature, SignatureError> { | ||
let temp: PolyFuncTypeRV; // to keep alive | ||
let (pf, args) = match &self { | ||
SignatureFunc::PolyFuncType(custom) => { | ||
custom.validate.validate(args, def, exts)?; | ||
SignatureFunc::CustomValidator(custom) => { | ||
custom.validate.as_ref().validate(args, def, exts)?; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. presumably this extra There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wasn't needed, removed |
||
(&custom.poly_func, args) | ||
} | ||
SignatureFunc::PolyFuncType(ts) => (ts, args), | ||
SignatureFunc::CustomFunc(func) => { | ||
let static_params = func.static_params(); | ||
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); | ||
|
@@ -255,6 +260,10 @@ impl SignatureFunc { | |
temp = func.compute_signature(static_args, def, exts)?; | ||
(&temp, other_args) | ||
} | ||
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc), | ||
SignatureFunc::MissingValidateFunc(_) => { | ||
return Err(SignatureError::MissingValidateFunc) | ||
} | ||
}; | ||
|
||
let mut res = pf.instantiate(args, exts)?; | ||
|
@@ -268,8 +277,11 @@ impl SignatureFunc { | |
impl Debug for SignatureFunc { | ||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | ||
match self { | ||
Self::PolyFuncType(ts) => ts.poly_func.fmt(f), | ||
Self::CustomValidator(ts) => ts.poly_func.fmt(f), | ||
Self::PolyFuncType(ts) => ts.fmt(f), | ||
Self::CustomFunc { .. } => f.write_str("<custom sig>"), | ||
Self::MissingComputeFunc => f.write_str("<missing custom sig>"), | ||
Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"), | ||
} | ||
} | ||
} | ||
|
@@ -321,10 +333,11 @@ pub struct OpDef { | |
#[serde(default, skip_serializing_if = "HashMap::is_empty")] | ||
misc: HashMap<String, serde_json::Value>, | ||
|
||
#[serde(flatten)] | ||
#[serde(with = "serialize_signature_func", flatten)] | ||
signature_func: SignatureFunc, | ||
// Some operations cannot lower themselves and tools that do not understand them | ||
// can only treat them as opaque/black-box ops. | ||
#[serde(default, skip_serializing_if = "Vec::is_empty")] | ||
pub(crate) lower_funcs: Vec<LowerFunc>, | ||
|
||
/// Operations can optionally implement [`ConstFold`] to implement constant folding. | ||
|
@@ -344,7 +357,8 @@ impl OpDef { | |
) -> Result<(), SignatureError> { | ||
let temp: PolyFuncTypeRV; // to keep alive | ||
let (pf, args) = match &self.signature_func { | ||
SignatureFunc::PolyFuncType(ts) => (&ts.poly_func, args), | ||
SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args), | ||
SignatureFunc::PolyFuncType(ts) => (ts, args), | ||
SignatureFunc::CustomFunc(custom) => { | ||
let (static_args, other_args) = | ||
args.split_at(min(custom.static_params().len(), args.len())); | ||
|
@@ -355,6 +369,10 @@ impl OpDef { | |
temp = custom.compute_signature(static_args, self, exts)?; | ||
(&temp, other_args) | ||
} | ||
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc), | ||
SignatureFunc::MissingValidateFunc(_) => { | ||
return Err(SignatureError::MissingValidateFunc) | ||
} | ||
}; | ||
args.iter() | ||
.try_for_each(|ta| ta.validate(exts, var_decls))?; | ||
|
@@ -409,14 +427,14 @@ impl OpDef { | |
} | ||
|
||
/// Returns a reference to the params of this [`OpDef`]. | ||
pub fn params(&self) -> &[TypeParam] { | ||
pub fn params(&self) -> Result<&[TypeParam], SignatureError> { | ||
self.signature_func.static_params() | ||
} | ||
|
||
pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> { | ||
// TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams | ||
// for both type scheme and custom binary | ||
if let SignatureFunc::PolyFuncType(ts) = &self.signature_func { | ||
if let SignatureFunc::CustomValidator(ts) = &self.signature_func { | ||
// The type scheme may contain row variables so be of variable length; | ||
// these will have to be substituted to fixed-length concrete types when | ||
// the OpDef is instantiated into an actual OpType. | ||
|
@@ -557,12 +575,14 @@ pub(super) mod test { | |
// a compile error here. To fix: modify the fields matched on here, | ||
// maintaining the lack of `..` and, for each part that is | ||
// serializable, ensure we are checking it for equality below. | ||
SignatureFunc::PolyFuncType(CustomValidator { | ||
SignatureFunc::CustomValidator(CustomValidator { | ||
poly_func, | ||
validate: _, | ||
}) => Some(poly_func.clone()), | ||
}) | ||
| SignatureFunc::PolyFuncType(poly_func) | ||
| SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()), | ||
// This is ruled out by `new()` but leave it here for later. | ||
SignatureFunc::CustomFunc(_) => None, | ||
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect the comment on the previous line can go (I suspect "later" meant "when we have implemented serialization", essentially!), please clarify if I'm wrong |
||
}; | ||
|
||
let get_lower_funcs = |lfs: &Vec<LowerFunc>| { | ||
|
@@ -787,9 +807,7 @@ pub(super) mod test { | |
|
||
use crate::{ | ||
builder::test::simple_dfg_hugr, | ||
extension::{ | ||
op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc, | ||
}, | ||
extension::{op_def::LowerFunc, ExtensionId, ExtensionSet, OpDef, SignatureFunc}, | ||
types::PolyFuncTypeRV, | ||
}; | ||
|
||
|
@@ -801,7 +819,7 @@ pub(super) mod test { | |
// this is not serialized. When it is, we should generate | ||
// examples here . | ||
any::<PolyFuncTypeRV>() | ||
.prop_map(|x| SignatureFunc::PolyFuncType(CustomValidator::from_polyfunc(x))) | ||
.prop_map(SignatureFunc::PolyFuncType) | ||
.boxed() | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't look very optional ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, I'd started commenting in a different direction, but making this non-optional is a different way of doing what I was going to suggest, that works for me. Just correct the comment tho :)