From 41187ae6aa08768b24ef46b39d0e870db7e75f93 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Thu, 25 Apr 2024 14:19:54 +0100 Subject: [PATCH] Add tests for PolyFuncType serialisation --- hugr-py/src/hugr/serialization/tys.py | 4 +- hugr/src/hugr/serialize.rs | 44 +++++++++++-------- .../std_extensions/arithmetic/int_types.rs | 2 +- hugr/src/types/type_param.rs | 1 - specification/schema/hugr_schema_v1.json | 19 ++++++++ .../schema/testing_hugr_schema_v1.json | 19 ++++++++ 6 files changed, 68 insertions(+), 21 deletions(-) diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index dc13076141..416ef51670 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -74,12 +74,14 @@ class TupleParam(BaseModel): tp: Literal["Tuple"] = "Tuple" params: list["TypeParam"] +class ExtensionsParam(BaseModel): + tp: Literal["Extensions"] = "Extensions" class TypeParam(RootModel): """A type parameter.""" root: Annotated[ - TypeTypeParam | BoundedNatParam | OpaqueParam | ListParam | TupleParam, + TypeTypeParam | BoundedNatParam | OpaqueParam | ListParam | TupleParam | ExtensionsParam, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tp") diff --git a/hugr/src/hugr/serialize.rs b/hugr/src/hugr/serialize.rs index c67947e95e..59deb15d3e 100644 --- a/hugr/src/hugr/serialize.rs +++ b/hugr/src/hugr/serialize.rs @@ -283,9 +283,12 @@ pub mod test { use crate::ops::{dataflow::IOTrait, Input, Module, Noop, Output, DFG}; use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - use crate::std_extensions::arithmetic::int_types::{ConstIntS, INT_TYPES}; + use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstIntS, INT_TYPES}; use crate::std_extensions::logic::NotOp; - use crate::types::{FunctionType, SumType, Type, TypeBound}; + use crate::types::{ + type_param::{TypeArg, TypeParam}, + FunctionType, PolyFuncType, SumType, Type, TypeBound, + }; use crate::{type_row, OutgoingPort}; use itertools::Itertools; use jsonschema::{Draft, JSONSchema}; @@ -638,20 +641,25 @@ pub mod test { check_testing_roundtrip(SerTesting { value }) } - // fn polyfunctype1() -> PolyFuncType { - // let mut extension_set = ExtensionSet::new(); - // extension_set.insert_type_var(1); - // let function_type = FunctionType::new_endo(type_row![]).with_extension_delta(extension_set); - // PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type) - // } - // #[rstest] - // #[case(FunctionType::new_endo(type_row![]).into())] - // #[case(polyfunctype1())] - // fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) { - // #[derive(Serialize, Deserialize, PartialEq, Debug)] - // struct SerTesting { - // poly_func_type: PolyFuncType, - // } - // check_testing_roundtrip(SerTesting { poly_func_type }) - // } + fn polyfunctype1() -> PolyFuncType { + let mut extension_set = ExtensionSet::new(); + extension_set.insert_type_var(1); + let function_type = FunctionType::new_endo(type_row![]).with_extension_delta(extension_set); + PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type) + } + + #[rstest] + #[case(FunctionType::new_endo(type_row![]).into())] + #[case(polyfunctype1())] + #[case(PolyFuncType::new([TypeParam::Opaque { ty: int_custom_type(TypeArg::BoundedNat { n: 1 }) }], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))] + #[case(PolyFuncType::new([TypeBound::Eq.into()], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Eq)])))] + #[case(PolyFuncType::new([TypeParam::List { param: Box::new(TypeBound::Any.into()) }], FunctionType::new_endo(type_row![])))] + #[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FunctionType::new_endo(type_row![])))] + fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct SerTesting { + poly_func_type: PolyFuncType, + } + check_testing_roundtrip(SerTesting { poly_func_type }) + } } diff --git a/hugr/src/std_extensions/arithmetic/int_types.rs b/hugr/src/std_extensions/arithmetic/int_types.rs index 7c7724cc89..058a88c02a 100644 --- a/hugr/src/std_extensions/arithmetic/int_types.rs +++ b/hugr/src/std_extensions/arithmetic/int_types.rs @@ -20,7 +20,7 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int /// Identifier for the integer type. pub const INT_TYPE_ID: SmolStr = SmolStr::new_inline("int"); -fn int_custom_type(width_arg: TypeArg) -> CustomType { +pub(crate) fn int_custom_type(width_arg: TypeArg) -> CustomType { CustomType::new(INT_TYPE_ID, [width_arg], EXTENSION_ID, TypeBound::Eq) } diff --git a/hugr/src/types/type_param.rs b/hugr/src/types/type_param.rs index 467e222eef..ffe0f6166f 100644 --- a/hugr/src/types/type_param.rs +++ b/hugr/src/types/type_param.rs @@ -136,7 +136,6 @@ impl From for TypeArg { } } - /// A statically-known argument value to an operation. #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[non_exhaustive] diff --git a/specification/schema/hugr_schema_v1.json b/specification/schema/hugr_schema_v1.json index ba3b0f4930..3fb8c8961b 100644 --- a/specification/schema/hugr_schema_v1.json +++ b/specification/schema/hugr_schema_v1.json @@ -590,6 +590,21 @@ "title": "ExtensionsArg", "type": "object" }, + "ExtensionsParam": { + "properties": { + "tp": { + "const": "Extensions", + "default": "Extensions", + "enum": [ + "Extensions" + ], + "title": "Tp", + "type": "string" + } + }, + "title": "ExtensionsParam", + "type": "object" + }, "FuncDecl": { "description": "External function declaration, linked at runtime.", "properties": { @@ -1543,6 +1558,7 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", + "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "Opaque": "#/$defs/OpaqueParam", "Tuple": "#/$defs/TupleParam", @@ -1565,6 +1581,9 @@ }, { "$ref": "#/$defs/TupleParam" + }, + { + "$ref": "#/$defs/ExtensionsParam" } ], "title": "TypeParam" diff --git a/specification/schema/testing_hugr_schema_v1.json b/specification/schema/testing_hugr_schema_v1.json index dfaec239b1..bfaf6b5d12 100644 --- a/specification/schema/testing_hugr_schema_v1.json +++ b/specification/schema/testing_hugr_schema_v1.json @@ -183,6 +183,21 @@ "title": "ExtensionsArg", "type": "object" }, + "ExtensionsParam": { + "properties": { + "tp": { + "const": "Extensions", + "default": "Extensions", + "enum": [ + "Extensions" + ], + "title": "Tp", + "type": "string" + } + }, + "title": "ExtensionsParam", + "type": "object" + }, "FunctionType": { "description": "A graph encoded as a value. It contains a concrete signature and a set of required resources.", "properties": { @@ -652,6 +667,7 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", + "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "Opaque": "#/$defs/OpaqueParam", "Tuple": "#/$defs/TupleParam", @@ -674,6 +690,9 @@ }, { "$ref": "#/$defs/TupleParam" + }, + { + "$ref": "#/$defs/ExtensionsParam" } ], "title": "TypeParam"