-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat!: Serialised extensions #1371
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
078da5d
refactor!: Use internal tagging for TypeDefBound for pydantic compati…
ss2165 066dcde
feat: pydantic models for extension definition
ss2165 b9ae43a
feat: generate top level schemas including extensions
ss2165 f175105
feat!: Use custom serialization for SignatureFunc.
ss2165 c5ab7d0
skip lower_funcs if empty
ss2165 0f3ec3e
error rather than panic on missing function
ss2165 6b923d8
add some explanation comments
ss2165 6b7c788
CustomValidator docstrings
ss2165 bf071dc
feat: be explicit about missing validation
ss2165 0e3b160
feat: extract serialisation in to module and add test
ss2165 e31839c
simplify `ignore_missing_validation`
ss2165 a47252a
separate explicit and custom validate variants
ss2165 4cd9af6
fix `CustomValidator` docstrings
ss2165 21642ea
minor review suggestions
ss2165 8bcb85c
improve tests
ss2165 b243b17
move `NoValidate` to test only
ss2165 7010d0e
remove unused method
ss2165 ac834de
convenience constructors for `TypeDefBound`
ss2165 9990e78
fix schema gen script
ss2165 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ use crate::ops::{OpName, OpNameRef}; | |
use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; | ||
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature}; | ||
use crate::Hugr; | ||
mod serialize_signature_func; | ||
|
||
/// Trait necessary for binary computations of OpDef signature | ||
pub trait CustomSignatureFunc: Send + Sync { | ||
|
@@ -114,29 +115,19 @@ pub trait CustomLowerFunc: Send + Sync { | |
) -> Option<Hugr>; | ||
} | ||
|
||
/// Encode a signature as `PolyFuncTypeRV` but optionally allow validating type | ||
/// Encode a signature as [PolyFuncTypeRV] but with additional validation of type | ||
/// arguments via a custom binary. The binary cannot be serialized so will be | ||
/// lost over a serialization round-trip. | ||
#[derive(serde::Deserialize, serde::Serialize)] | ||
pub struct CustomValidator { | ||
#[serde(flatten)] | ||
poly_func: PolyFuncTypeRV, | ||
#[serde(skip)] | ||
/// Custom function for validating type arguments before returning the signature. | ||
pub(crate) validate: Box<dyn ValidateTypeArgs>, | ||
} | ||
|
||
impl CustomValidator { | ||
/// Encode a signature using a `PolyFuncTypeRV` | ||
pub fn from_polyfunc(poly_func: impl Into<PolyFuncTypeRV>) -> Self { | ||
Self { | ||
poly_func: poly_func.into(), | ||
validate: Default::default(), | ||
} | ||
} | ||
|
||
/// Encode a signature using a `PolyFuncTypeRV`, with a custom function for | ||
/// validating type arguments before returning the signature. | ||
pub fn new_with_validator( | ||
pub fn new( | ||
poly_func: impl Into<PolyFuncTypeRV>, | ||
validate: impl ValidateTypeArgs + 'static, | ||
) -> Self { | ||
|
@@ -147,37 +138,19 @@ impl CustomValidator { | |
} | ||
} | ||
|
||
/// The two ways in which an OpDef may compute the Signature of each operation node. | ||
#[derive(serde::Deserialize, serde::Serialize)] | ||
/// The ways in which an OpDef may compute the Signature of each operation node. | ||
pub enum SignatureFunc { | ||
// Note: except for serialization, we could have type schemes just implement the same | ||
// CustomSignatureFunc trait too, and replace this enum with Box<dyn CustomSignatureFunc>. | ||
// However instead we treat all CustomFunc's as non-serializable. | ||
/// A PolyFuncType (polymorphic function type), with optional custom | ||
/// validation for provided type arguments, | ||
#[serde(rename = "signature")] | ||
PolyFuncType(CustomValidator), | ||
#[serde(skip)] | ||
/// An explicit polymorphic function type. | ||
PolyFuncType(PolyFuncTypeRV), | ||
/// A polymorphic function type (like [Self::PolyFuncType] but also with a custom binary for validating type arguments. | ||
CustomValidator(CustomValidator), | ||
/// Serialized declaration specified a custom validate binary but it was not provided. | ||
MissingValidateFunc(PolyFuncTypeRV), | ||
/// A custom binary which computes a polymorphic function type given values | ||
/// for its static type parameters. | ||
CustomFunc(Box<dyn CustomSignatureFunc>), | ||
} | ||
struct NoValidate; | ||
impl ValidateTypeArgs for NoValidate { | ||
fn validate<'o, 'a: 'o>( | ||
&self, | ||
_arg_values: &[TypeArg], | ||
_def: &'o OpDef, | ||
_extension_registry: &ExtensionRegistry, | ||
) -> Result<(), SignatureError> { | ||
Ok(()) | ||
} | ||
} | ||
|
||
impl Default for Box<dyn ValidateTypeArgs> { | ||
fn default() -> Self { | ||
Box::new(NoValidate) | ||
} | ||
/// Serialized declaration specified a custom compute binary but it was not provided. | ||
MissingComputeFunc, | ||
} | ||
|
||
impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc { | ||
|
@@ -188,39 +161,50 @@ impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc { | |
|
||
impl From<PolyFuncType> for SignatureFunc { | ||
fn from(value: PolyFuncType) -> Self { | ||
Self::PolyFuncType(CustomValidator::from_polyfunc(value)) | ||
Self::PolyFuncType(value.into()) | ||
} | ||
} | ||
|
||
impl From<PolyFuncTypeRV> for SignatureFunc { | ||
fn from(v: PolyFuncTypeRV) -> Self { | ||
Self::PolyFuncType(CustomValidator::from_polyfunc(v)) | ||
Self::PolyFuncType(v) | ||
} | ||
} | ||
|
||
impl From<FuncValueType> for SignatureFunc { | ||
fn from(v: FuncValueType) -> Self { | ||
Self::PolyFuncType(CustomValidator::from_polyfunc(v)) | ||
Self::PolyFuncType(v.into()) | ||
} | ||
} | ||
|
||
impl From<Signature> for SignatureFunc { | ||
fn from(v: Signature) -> Self { | ||
Self::PolyFuncType(CustomValidator::from_polyfunc(FuncValueType::from(v))) | ||
Self::PolyFuncType(FuncValueType::from(v).into()) | ||
} | ||
} | ||
|
||
impl From<CustomValidator> for SignatureFunc { | ||
fn from(v: CustomValidator) -> Self { | ||
Self::PolyFuncType(v) | ||
Self::CustomValidator(v) | ||
} | ||
} | ||
|
||
impl SignatureFunc { | ||
fn static_params(&self) -> &[TypeParam] { | ||
match self { | ||
SignatureFunc::PolyFuncType(ts) => ts.poly_func.params(), | ||
fn static_params(&self) -> Result<&[TypeParam], SignatureError> { | ||
Ok(match self { | ||
SignatureFunc::PolyFuncType(ts) | ||
| SignatureFunc::CustomValidator(CustomValidator { poly_func: ts, .. }) | ||
| SignatureFunc::MissingValidateFunc(ts) => ts.params(), | ||
SignatureFunc::CustomFunc(func) => func.static_params(), | ||
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc), | ||
acl-cqc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}) | ||
} | ||
|
||
/// If the signature is missing a custom validation function, ignore and treat as | ||
/// self-contained type scheme (with no custom validation). | ||
pub fn ignore_missing_validation(&mut self) { | ||
if let SignatureFunc::MissingValidateFunc(ts) = self { | ||
*self = SignatureFunc::PolyFuncType(ts.clone()); | ||
} | ||
} | ||
|
||
|
@@ -243,10 +227,11 @@ impl SignatureFunc { | |
) -> Result<Signature, SignatureError> { | ||
let temp: PolyFuncTypeRV; // to keep alive | ||
let (pf, args) = match &self { | ||
SignatureFunc::PolyFuncType(custom) => { | ||
SignatureFunc::CustomValidator(custom) => { | ||
custom.validate.validate(args, def, exts)?; | ||
(&custom.poly_func, args) | ||
} | ||
SignatureFunc::PolyFuncType(ts) => (ts, args), | ||
SignatureFunc::CustomFunc(func) => { | ||
let static_params = func.static_params(); | ||
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); | ||
|
@@ -255,6 +240,10 @@ impl SignatureFunc { | |
temp = func.compute_signature(static_args, def, exts)?; | ||
(&temp, other_args) | ||
} | ||
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc), | ||
SignatureFunc::MissingValidateFunc(_) => { | ||
return Err(SignatureError::MissingValidateFunc) | ||
} | ||
}; | ||
|
||
let mut res = pf.instantiate(args, exts)?; | ||
|
@@ -268,8 +257,11 @@ impl SignatureFunc { | |
impl Debug for SignatureFunc { | ||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { | ||
match self { | ||
Self::PolyFuncType(ts) => ts.poly_func.fmt(f), | ||
Self::CustomValidator(ts) => ts.poly_func.fmt(f), | ||
Self::PolyFuncType(ts) => ts.fmt(f), | ||
Self::CustomFunc { .. } => f.write_str("<custom sig>"), | ||
Self::MissingComputeFunc => f.write_str("<missing custom sig>"), | ||
Self::MissingValidateFunc(_) => f.write_str("<missing custom validation>"), | ||
} | ||
} | ||
} | ||
|
@@ -321,10 +313,11 @@ pub struct OpDef { | |
#[serde(default, skip_serializing_if = "HashMap::is_empty")] | ||
misc: HashMap<String, serde_json::Value>, | ||
|
||
#[serde(flatten)] | ||
#[serde(with = "serialize_signature_func", flatten)] | ||
signature_func: SignatureFunc, | ||
// Some operations cannot lower themselves and tools that do not understand them | ||
// can only treat them as opaque/black-box ops. | ||
#[serde(default, skip_serializing_if = "Vec::is_empty")] | ||
pub(crate) lower_funcs: Vec<LowerFunc>, | ||
|
||
/// Operations can optionally implement [`ConstFold`] to implement constant folding. | ||
|
@@ -344,7 +337,8 @@ impl OpDef { | |
) -> Result<(), SignatureError> { | ||
let temp: PolyFuncTypeRV; // to keep alive | ||
let (pf, args) = match &self.signature_func { | ||
SignatureFunc::PolyFuncType(ts) => (&ts.poly_func, args), | ||
SignatureFunc::CustomValidator(ts) => (&ts.poly_func, args), | ||
SignatureFunc::PolyFuncType(ts) => (ts, args), | ||
SignatureFunc::CustomFunc(custom) => { | ||
let (static_args, other_args) = | ||
args.split_at(min(custom.static_params().len(), args.len())); | ||
|
@@ -355,6 +349,10 @@ impl OpDef { | |
temp = custom.compute_signature(static_args, self, exts)?; | ||
(&temp, other_args) | ||
} | ||
SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc), | ||
SignatureFunc::MissingValidateFunc(_) => { | ||
return Err(SignatureError::MissingValidateFunc) | ||
} | ||
}; | ||
args.iter() | ||
.try_for_each(|ta| ta.validate(exts, var_decls))?; | ||
|
@@ -409,14 +407,14 @@ impl OpDef { | |
} | ||
|
||
/// Returns a reference to the params of this [`OpDef`]. | ||
pub fn params(&self) -> &[TypeParam] { | ||
pub fn params(&self) -> Result<&[TypeParam], SignatureError> { | ||
self.signature_func.static_params() | ||
} | ||
|
||
pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> { | ||
// TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams | ||
// for both type scheme and custom binary | ||
if let SignatureFunc::PolyFuncType(ts) = &self.signature_func { | ||
if let SignatureFunc::CustomValidator(ts) = &self.signature_func { | ||
// The type scheme may contain row variables so be of variable length; | ||
// these will have to be substituted to fixed-length concrete types when | ||
// the OpDef is instantiated into an actual OpType. | ||
|
@@ -557,12 +555,13 @@ pub(super) mod test { | |
// a compile error here. To fix: modify the fields matched on here, | ||
// maintaining the lack of `..` and, for each part that is | ||
// serializable, ensure we are checking it for equality below. | ||
SignatureFunc::PolyFuncType(CustomValidator { | ||
SignatureFunc::CustomValidator(CustomValidator { | ||
poly_func, | ||
validate: _, | ||
}) => Some(poly_func.clone()), | ||
// This is ruled out by `new()` but leave it here for later. | ||
SignatureFunc::CustomFunc(_) => None, | ||
}) | ||
| SignatureFunc::PolyFuncType(poly_func) | ||
| SignatureFunc::MissingValidateFunc(poly_func) => Some(poly_func.clone()), | ||
SignatureFunc::CustomFunc(_) | SignatureFunc::MissingComputeFunc => None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect the comment on the previous line can go (I suspect "later" meant "when we have implemented serialization", essentially!), please clarify if I'm wrong |
||
}; | ||
|
||
let get_lower_funcs = |lfs: &Vec<LowerFunc>| { | ||
|
@@ -787,9 +786,7 @@ pub(super) mod test { | |
|
||
use crate::{ | ||
builder::test::simple_dfg_hugr, | ||
extension::{ | ||
op_def::LowerFunc, CustomValidator, ExtensionId, ExtensionSet, OpDef, SignatureFunc, | ||
}, | ||
extension::{op_def::LowerFunc, ExtensionId, ExtensionSet, OpDef, SignatureFunc}, | ||
types::PolyFuncTypeRV, | ||
}; | ||
|
||
|
@@ -801,7 +798,7 @@ pub(super) mod test { | |
// this is not serialized. When it is, we should generate | ||
// examples here . | ||
any::<PolyFuncTypeRV>() | ||
.prop_map(|x| SignatureFunc::PolyFuncType(CustomValidator::from_polyfunc(x))) | ||
.prop_map(SignatureFunc::PolyFuncType) | ||
.boxed() | ||
} | ||
} | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about a missing custom validation func? In either case I think we just have to trust the cache, maybe with a warning, so I guess if the warning is the same then we don't have to distinguish, is that the plan?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure - in the case of missing validation function do you trust the cached signature or use the type scheme to generate the signature and check against cache as you would without custom validation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Given the validation function can only say one of two things - invalid, or use the type scheme - you can try the latter (which might reject if TypeArgs don't match the TypeParams, say) and see if that matches; that might get you an error, but even if the typescheme says ok, if there's a binary validation function that you haven't got, then that still has to be a warning