Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: serialisation schema #968

Merged
merged 16 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci-rs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ env:
CI: true # insta snapshots behave differently on ci
SCCACHE_GHA_ENABLED: "true"
RUSTC_WRAPPER: "sccache"
HUGR_TEST_SCHEMA: "1"

jobs:
# Check if changes were made to the relevant files.
Expand Down
11 changes: 10 additions & 1 deletion hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
PolyFuncType,
Type,
TypeRow,
SumType,
TypeBound,
)

NodeID = int
Expand Down Expand Up @@ -126,7 +128,7 @@ class SumValue(BaseModel):

c: Literal["Sum"] = Field("Sum", title="ValueTag")
tag: int
typ: Type
typ: SumType
vs: list["Value"]

class Config:
Expand Down Expand Up @@ -475,6 +477,12 @@ class Lift(DataflowOp):
new_extension: ExtensionId


class AliasDecl(BaseOp):
op: Literal["AliasDecl"] = "AliasDecl"
name: str
bound: TypeBound


class OpType(RootModel):
"""A constant operation."""

Expand All @@ -501,6 +509,7 @@ class OpType(RootModel):
| Tag
| Lift
| DFG
| AliasDecl
) = Field(discriminator="op")


Expand Down
23 changes: 23 additions & 0 deletions hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Literal, Optional
from pydantic import BaseModel
from .tys import Type, SumType, PolyFuncType
from .ops import Value


class TestingHugr(BaseModel):
"""A serializable representation of a Hugr Type, SumType, PolyFuncType, or
Value. Intended for testing only."""

version: Literal["v1"] = "v1"
typ: Optional[Type] = None
sum_type: Optional[SumType] = None
poly_func_type: Optional[PolyFuncType] = None
value: Optional[Value] = None

@classmethod
def get_version(cls) -> str:
"""Return the version of the schema."""
return cls().version

class Config:
title = "HugrTesting"
41 changes: 33 additions & 8 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,20 @@ class TupleParam(BaseModel):
params: list["TypeParam"]


class ExtensionsParam(BaseModel):
tp: Literal["Extensions"] = "Extensions"


class TypeParam(RootModel):
"""A type parameter."""

root: Annotated[
TypeTypeParam | BoundedNatParam | OpaqueParam | ListParam | TupleParam,
TypeTypeParam
| BoundedNatParam
| OpaqueParam
| ListParam
| TupleParam
| ExtensionsParam,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tp")

Expand Down Expand Up @@ -145,9 +154,7 @@ class Array(MultiContainer):


class UnitSum(BaseModel):
"""Simple predicate where all variants are empty tuples."""

t: Literal["Sum"] = "Sum"
"""Simple sum type where all variants are empty tuples."""

s: Literal["Unit"] = "Unit"
size: int
Expand All @@ -156,8 +163,6 @@ class UnitSum(BaseModel):
class GeneralSum(BaseModel):
"""General sum type that explicitly stores the types of the variants."""

t: Literal["Sum"] = "Sum"

s: Literal["General"] = "General"
rows: list["TypeRow"]

Expand All @@ -166,6 +171,11 @@ class SumType(RootModel):
root: Union[UnitSum, GeneralSum] = Field(discriminator="s")


class TaggedSumType(BaseModel):
t: Literal["Sum"] = "Sum"
st: SumType


# ----------------------------------------------
# --------------- ClassicType ------------------
# ----------------------------------------------
Expand Down Expand Up @@ -254,7 +264,7 @@ def join(*bs: "TypeBound") -> "TypeBound":


class Opaque(BaseModel):
"""An opaque operation that can be downcasted by the extensions that define it."""
"""An opaque Type that can be downcasted by the extensions that define it."""

t: Literal["Opaque"] = "Opaque"
extension: ExtensionId
Expand All @@ -263,6 +273,14 @@ class Opaque(BaseModel):
bound: TypeBound


class Alias(BaseModel):
"""An Alias Type"""

t: Literal["Alias"] = "Alias"
bound: TypeBound
name: str


# ----------------------------------------------
# --------------- LinearType -------------------
# ----------------------------------------------
Expand All @@ -278,7 +296,14 @@ class Type(RootModel):
"""A HUGR type."""

root: Annotated[
Qubit | Variable | USize | FunctionType | Array | SumType | Opaque,
Qubit
| Variable
| USize
| FunctionType
| Array
| TaggedSumType
| Opaque
| Alias,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="t")

Expand Down
47 changes: 42 additions & 5 deletions hugr/src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,29 @@ use super::{HugrMut, HugrView};
/// recent version of the format. We keep the `Deserialize` implementations for
/// older versions to allow for backwards compatibility.
///
/// The Generic `SerHugr` is always instantiated to the most recent version of
/// the format outside this module.
///
/// Make sure to order the variants from newest to oldest, as the deserializer
/// will try to deserialize them in order.
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "version", rename_all = "lowercase")]
enum Versioned {
enum Versioned<SerHugr> {
/// Version 0 of the HUGR serialization format.
V0,
/// Version 1 of the HUGR serialization format.
V1(SerHugrV1),
V1(SerHugr),

#[serde(other)]
Unsupported,
}

impl<T> Versioned<T> {
pub fn new(t: T) -> Self {
Self::V1(t)
}
}

#[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
struct NodeSer {
parent: Node,
Expand All @@ -62,6 +71,34 @@ struct SerHugrV1 {
encoder: Option<String>,
}

/// Version 1 of the Testing HUGR serialisation format, see `testing_hugr.py`.
#[cfg(test)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
struct SerTestingV1 {
typ: Option<crate::types::Type>,
sum_type: Option<crate::types::SumType>,
poly_func_type: Option<crate::types::PolyFuncType>,
value: Option<crate::ops::Value>,
}

macro_rules! impl_sertesting_from {
($typ:ty, $field:ident) => {
#[cfg(test)]
impl From<$typ> for SerTestingV1 {
fn from(v: $typ) -> Self {
let mut r: Self = Default::default();
r.$field = Some(v);
r
}
}
};
}

impl_sertesting_from!(crate::types::Type, typ);
impl_sertesting_from!(crate::types::SumType, sum_type);
impl_sertesting_from!(crate::types::PolyFuncType, poly_func_type);
impl_sertesting_from!(crate::ops::Value, value);

/// Errors that can occur while serializing a HUGR.
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
Expand Down Expand Up @@ -99,7 +136,7 @@ impl Serialize for Hugr {
S: serde::Serializer,
{
let shg: SerHugrV1 = self.try_into().map_err(serde::ser::Error::custom)?;
let versioned = Versioned::V1(shg);
let versioned = Versioned::new(shg);
versioned.serialize(serializer)
}
}
Expand All @@ -109,7 +146,7 @@ impl<'de> Deserialize<'de> for Hugr {
where
D: Deserializer<'de>,
{
let shg = Versioned::deserialize(deserializer)?;
let shg: Versioned<SerHugrV1> = Versioned::deserialize(deserializer)?;
match shg {
Versioned::V0 => Err(serde::de::Error::custom(
"Version 0 HUGR serialization format is not supported.",
Expand Down
83 changes: 81 additions & 2 deletions hugr/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@ use crate::builder::{
test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, USIZE_T};
use crate::extension::prelude::{BOOL_T, QB_T, USIZE_T};
use crate::extension::simple_op::MakeRegisteredOp;
use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::NodeType;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::Value;
use crate::ops::{dataflow::IOTrait, Input, Module, Noop, Output, DFG};
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES};
use crate::std_extensions::logic::NotOp;
use crate::types::{FunctionType, Type};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType, SumType, Type, TypeBound};
use crate::{type_row, OutgoingPort};
use itertools::Itertools;
use jsonschema::{Draft, JSONSchema};
Expand All @@ -22,10 +25,13 @@ use portgraph::LinkView;
use portgraph::{
multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, PortView, UnmanagedDenseMap,
};
use rstest::rstest;

const NAT: Type = crate::extension::prelude::USIZE_T;
const QB: Type = crate::extension::prelude::QB_T;

type TestingModel = SerTestingV1;

lazy_static! {
static ref SCHEMA: JSONSchema = {
let schema_val: serde_json::Value = serde_json::from_str(include_str!(
Expand All @@ -37,6 +43,16 @@ lazy_static! {
.compile(&schema_val)
.expect("Schema is invalid.")
};
static ref TESTING_SCHEMA: JSONSchema = {
let schema_val: serde_json::Value = serde_json::from_str(include_str!(
"../../../../specification/schema/testing_hugr_schema_v1.json"
))
.unwrap();
JSONSchema::options()
.with_draft(Draft::Draft7)
.compile(&schema_val)
.expect("Schema is invalid.")
};
}

#[test]
Expand Down Expand Up @@ -124,6 +140,12 @@ pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr {
new_hugr
}

fn check_testing_roundtrip(t: TestingModel) {
let before = Versioned::new(t);
let after = ser_roundtrip_validate(&before, Some(&TESTING_SCHEMA));
assert_eq!(before, after);
}

/// Generate an optype for a node with a matching amount of inputs and outputs.
fn gen_optype(g: &MultiPortGraph, node: portgraph::NodeIndex) -> OpType {
let inputs = g.num_inputs(node);
Expand Down Expand Up @@ -312,3 +334,60 @@ fn serialize_types_roundtrip() {
let t = Type::new_unit_sum(4);
assert_eq!(ser_roundtrip(&t), t);
}

#[rstest]
#[case(BOOL_T)]
#[case(USIZE_T)]
#[case(INT_TYPES[2].clone())]
#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Any)))]
#[case(Type::new_var_use(2, TypeBound::Copyable))]
#[case(Type::new_tuple(type_row![BOOL_T,QB_T]))]
#[case(Type::new_sum([type_row![BOOL_T,QB_T], type_row![Type::new_unit_sum(4)]]))]
#[case(Type::new_function(FunctionType::new_endo(type_row![QB_T,BOOL_T,USIZE_T])))]
Comment on lines +339 to +346
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use hypothesis/proptest for at least some of these?

Copy link
Collaborator Author

@doug-q doug-q Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that it would be better. Hugr-Valued Values would not be possible, and anything with type variables may be tricky (if it turns out they have to make sense).

I would prefer to leave this as-is for now and convert to proptest in a second pass. Do lmk if you disagree.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On further investigation, it seems like it will be fine to construct and round trip types with typevars that don't make sense.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have agreed to try proptest, I hope you are ok merging this PR first?

fn roundtrip_type(#[case] typ: Type) {
check_testing_roundtrip(typ.into())
}

#[rstest]
#[case(SumType::new_unary(2))]
#[case(SumType::new([type_row![USIZE_T, QB_T], type_row![]]))]
fn roundtrip_sumtype(#[case] sum_type: SumType) {
check_testing_roundtrip(sum_type.into())
}

#[rstest]
#[case(Value::unit())]
#[case(Value::true_val())]
#[case(Value::unit_sum(3,5).unwrap())]
#[case(Value::extension(ConstF64::new(-1.5)))]
#[case(Value::extension(ConstF64::new(0.0)))]
#[case(Value::extension(ConstF64::new(-0.0)))]
// These cases fail
// #[case(Value::extension(ConstF64::new(std::f64::NAN)))]
// #[case(Value::extension(ConstF64::new(std::f64::INFINITY)))]
// #[case(Value::extension(ConstF64::new(std::f64::NEG_INFINITY)))]
#[case(Value::extension(ConstF64::new(std::f64::MIN_POSITIVE)))]
#[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())]
#[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))]
#[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())]
fn roundtrip_value(#[case] value: Value) {
check_testing_roundtrip(value.into())
}

fn polyfunctype1() -> PolyFuncType {
let mut extension_set = ExtensionSet::new();
extension_set.insert_type_var(1);
let function_type = FunctionType::new_endo(type_row![]).with_extension_delta(extension_set);
PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type)
}

#[rstest]
#[case(FunctionType::new_endo(type_row![]).into())]
#[case(polyfunctype1())]
#[case(PolyFuncType::new([TypeParam::Opaque { ty: int_custom_type(TypeArg::BoundedNat { n: 1 }) }], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Copyable)])))]
#[case(PolyFuncType::new([TypeBound::Eq.into()], FunctionType::new_endo(type_row![Type::new_var_use(0, TypeBound::Eq)])))]
#[case(PolyFuncType::new([TypeParam::List { param: Box::new(TypeBound::Any.into()) }], FunctionType::new_endo(type_row![])))]
#[case(PolyFuncType::new([TypeParam::Tuple { params: [TypeBound::Any.into(), TypeParam::bounded_nat(2.try_into().unwrap())].into() }], FunctionType::new_endo(type_row![])))]
fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) {
check_testing_roundtrip(poly_func_type.into())
}
Loading
Loading