Skip to content

Commit

Permalink
Add tests for PolyFuncType serialisation
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Apr 25, 2024
1 parent df801eb commit 41187ae
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 21 deletions.
4 changes: 3 additions & 1 deletion hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,14 @@ class TupleParam(BaseModel):
tp: Literal["Tuple"] = "Tuple"
params: list["TypeParam"]

class ExtensionsParam(BaseModel):
tp: Literal["Extensions"] = "Extensions"

class TypeParam(RootModel):
"""A type parameter."""

root: Annotated[
TypeTypeParam | BoundedNatParam | OpaqueParam | ListParam | TupleParam,
TypeTypeParam | BoundedNatParam | OpaqueParam | ListParam | TupleParam | ExtensionsParam,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tp")

Expand Down
44 changes: 26 additions & 18 deletions hugr/src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,12 @@ pub mod test {
use crate::ops::{dataflow::IOTrait, Input, Module, Noop, Output, DFG};
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstIntS, INT_TYPES};
use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstIntS, INT_TYPES};
use crate::std_extensions::logic::NotOp;
use crate::types::{FunctionType, SumType, Type, TypeBound};
use crate::types::{
type_param::{TypeArg, TypeParam},
FunctionType, PolyFuncType, SumType, Type, TypeBound,
};
use crate::{type_row, OutgoingPort};
use itertools::Itertools;
use jsonschema::{Draft, JSONSchema};
Expand Down Expand Up @@ -638,20 +641,25 @@ pub mod test {
check_testing_roundtrip(SerTesting { value })
}

// fn polyfunctype1() -> PolyFuncType {
// let mut extension_set = ExtensionSet::new();
// extension_set.insert_type_var(1);
// let function_type = FunctionType::new_endo(type_row![]).with_extension_delta(extension_set);
// PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type)
// }
// #[rstest]
// #[case(FunctionType::new_endo(type_row![]).into())]
// #[case(polyfunctype1())]
// fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) {
// #[derive(Serialize, Deserialize, PartialEq, Debug)]
// struct SerTesting {
// poly_func_type: PolyFuncType,
// }
// check_testing_roundtrip(SerTesting { poly_func_type })
// }
fn polyfunctype1() -> PolyFuncType {
let mut extension_set = ExtensionSet::new();
extension_set.insert_type_var(1);
let function_type = FunctionType::new_endo(type_row![]).with_extension_delta(extension_set);
PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type)
}

#[rstest]
#[case(FunctionType::new_endo(type_row![]).into())]
#[case(polyfunctype1())]
#[case(PolyFuncType::new([TypeParam::Opaque { ty: int_custom_type(TypeArg::BoundedNat { n: 1 }) }], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))]
#[case(PolyFuncType::new([TypeBound::Eq.into()], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Eq)])))]
#[case(PolyFuncType::new([TypeParam::List { param: Box::new(TypeBound::Any.into()) }], FunctionType::new_endo(type_row![])))]
#[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FunctionType::new_endo(type_row![])))]
fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct SerTesting {
poly_func_type: PolyFuncType,
}
check_testing_roundtrip(SerTesting { poly_func_type })
}
}
2 changes: 1 addition & 1 deletion hugr/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int
/// Identifier for the integer type.
pub const INT_TYPE_ID: SmolStr = SmolStr::new_inline("int");

fn int_custom_type(width_arg: TypeArg) -> CustomType {
pub(crate) fn int_custom_type(width_arg: TypeArg) -> CustomType {
CustomType::new(INT_TYPE_ID, [width_arg], EXTENSION_ID, TypeBound::Eq)
}

Expand Down
1 change: 0 additions & 1 deletion hugr/src/types/type_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ impl From<Type> for TypeArg {
}
}


/// A statically-known argument value to an operation.
#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[non_exhaustive]
Expand Down
19 changes: 19 additions & 0 deletions specification/schema/hugr_schema_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,21 @@
"title": "ExtensionsArg",
"type": "object"
},
"ExtensionsParam": {
"properties": {
"tp": {
"const": "Extensions",
"default": "Extensions",
"enum": [
"Extensions"
],
"title": "Tp",
"type": "string"
}
},
"title": "ExtensionsParam",
"type": "object"
},
"FuncDecl": {
"description": "External function declaration, linked at runtime.",
"properties": {
Expand Down Expand Up @@ -1543,6 +1558,7 @@
"discriminator": {
"mapping": {
"BoundedNat": "#/$defs/BoundedNatParam",
"Extensions": "#/$defs/ExtensionsParam",
"List": "#/$defs/ListParam",
"Opaque": "#/$defs/OpaqueParam",
"Tuple": "#/$defs/TupleParam",
Expand All @@ -1565,6 +1581,9 @@
},
{
"$ref": "#/$defs/TupleParam"
},
{
"$ref": "#/$defs/ExtensionsParam"
}
],
"title": "TypeParam"
Expand Down
19 changes: 19 additions & 0 deletions specification/schema/testing_hugr_schema_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,21 @@
"title": "ExtensionsArg",
"type": "object"
},
"ExtensionsParam": {
"properties": {
"tp": {
"const": "Extensions",
"default": "Extensions",
"enum": [
"Extensions"
],
"title": "Tp",
"type": "string"
}
},
"title": "ExtensionsParam",
"type": "object"
},
"FunctionType": {
"description": "A graph encoded as a value. It contains a concrete signature and a set of required resources.",
"properties": {
Expand Down Expand Up @@ -652,6 +667,7 @@
"discriminator": {
"mapping": {
"BoundedNat": "#/$defs/BoundedNatParam",
"Extensions": "#/$defs/ExtensionsParam",
"List": "#/$defs/ListParam",
"Opaque": "#/$defs/OpaqueParam",
"Tuple": "#/$defs/TupleParam",
Expand All @@ -674,6 +690,9 @@
},
{
"$ref": "#/$defs/TupleParam"
},
{
"$ref": "#/$defs/ExtensionsParam"
}
],
"title": "TypeParam"
Expand Down

0 comments on commit 41187ae

Please sign in to comment.