Skip to content

Commit

Permalink
Add support for PolyFuncType in testing_schema, add a few tests
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Apr 25, 2024
1 parent 7e24fb1 commit df801eb
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 583 deletions.
3 changes: 2 additions & 1 deletion hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal, Optional
from pydantic import BaseModel
from .tys import Type, SumType
from .tys import Type, SumType, PolyFuncType
from .ops import Value


Expand All @@ -10,6 +10,7 @@ class TestingHugr(BaseModel):
version: Literal["v1"] = "v1"
typ: Optional[Type] = None
sum_type: Optional[SumType] = None
poly_func_type: Optional[PolyFuncType] = None
value: Optional[Value] = None

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class Array(MultiContainer):


class UnitSum(BaseModel):
"""Simple predicate where all variants are empty tuples."""
"""Simple sum type where all variants are empty tuples."""

s: Literal["Unit"] = "Unit"
size: int
Expand Down Expand Up @@ -255,7 +255,7 @@ def join(*bs: "TypeBound") -> "TypeBound":


class Opaque(BaseModel):
"""An opaque operation that can be downcasted by the extensions that define it."""
"""An opaque Type that can be downcasted by the extensions that define it."""

t: Literal["Opaque"] = "Opaque"
extension: ExtensionId
Expand All @@ -265,7 +265,7 @@ class Opaque(BaseModel):


class Alias(BaseModel):
"""TODO"""
"""An Alias Type"""

t: Literal["Alias"] = "Alias"
bound: TypeBound
Expand Down
25 changes: 23 additions & 2 deletions hugr/src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,13 +588,14 @@ pub mod test {
}

#[rstest]
#[case(Type::new_unit_sum(1))]
#[case(BOOL_T)]
#[case(USIZE_T)]
#[case(INT_TYPES[2].clone())]
#[case(Type::new_alias(crate::ops::AliasDecl::new("t", TypeBound::Any)))]
#[case(Type::new_var_use(2, TypeBound::Copyable))]
#[case(Type::new_tuple(type_row![BOOL_T,QB_T]))]
#[case(Type::new_sum([type_row![BOOL_T,QB_T], type_row![Type::new_unit_sum(4)]]))]
#[case(Type::new_function(FunctionType::new_endo(type_row![QB_T,BOOL_T,USIZE_T])))]
fn roundtrip_type(#[case] typ: Type) {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct SerTesting {
Expand All @@ -604,7 +605,7 @@ pub mod test {
}

#[rstest]
#[case(SumType::new([type_row![],type_row![]]))]
#[case(SumType::new_unary(2))]
#[case(SumType::new([type_row![USIZE_T, QB_T], type_row![]]))]
fn roundtrip_sumtype(#[case] sum_type: SumType) {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
Expand All @@ -617,6 +618,7 @@ pub mod test {
#[rstest]
#[case(Value::unit())]
#[case(Value::true_val())]
#[case(Value::unit_sum(3,5).unwrap())]
#[case(Value::extension(ConstF64::new(-1.5)))]
#[case(Value::extension(ConstF64::new(0.0)))]
#[case(Value::extension(ConstF64::new(-0.0)))]
Expand All @@ -626,11 +628,30 @@ pub mod test {
// #[case(Value::extension(ConstF64::new(std::f64::NEG_INFINITY)))]
#[case(Value::extension(ConstF64::new(std::f64::MIN_POSITIVE)))]
#[case(Value::sum(1,[Value::extension(ConstIntS::new(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())]
#[case(Value::tuple([Value::false_val(), Value::extension(ConstIntS::new(2,1).unwrap())]))]
#[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())]
fn roundtrip_value(#[case] value: Value) {
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct SerTesting {
value: Value,
}
check_testing_roundtrip(SerTesting { value })
}

// fn polyfunctype1() -> PolyFuncType {
// let mut extension_set = ExtensionSet::new();
// extension_set.insert_type_var(1);
// let function_type = FunctionType::new_endo(type_row![]).with_extension_delta(extension_set);
// PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type)
// }
// #[rstest]
// #[case(FunctionType::new_endo(type_row![]).into())]
// #[case(polyfunctype1())]
// fn roundtrip_polyfunctype(#[case] poly_func_type: PolyFuncType) {
// #[derive(Serialize, Deserialize, PartialEq, Debug)]
// struct SerTesting {
// poly_func_type: PolyFuncType,
// }
// check_testing_roundtrip(SerTesting { poly_func_type })
// }
}
11 changes: 8 additions & 3 deletions hugr/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,17 @@ impl SumType {

let len: usize = rows.len();
if len <= (u8::MAX as usize) && rows.iter().all(TypeRow::is_empty) {
Self::Unit { size: len as u8 }
Self::new_unary(len as u8)
} else {
Self::General { rows }
}
}

/// New UnitSum with empty Tuple variants
pub const fn new_unary(size: u8) -> Self {
Self::Unit { size }
}

/// Report the tag'th variant, if it exists.
pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> {
match self {
Expand Down Expand Up @@ -312,13 +317,13 @@ impl Type {
/// New UnitSum with empty Tuple variants
pub const fn new_unit_sum(size: u8) -> Self {
// should be the only way to avoid going through SumType::new
Self(TypeEnum::Sum(SumType::Unit { size }), TypeBound::Eq)
Self(TypeEnum::Sum(SumType::new_unary(size)), TypeBound::Eq)
}

/// New use (occurrence) of the type variable with specified index.
/// For use in type schemes only: `bound` must match that with which the
/// variable was declared (i.e. as a [TypeParam::Type]`(bound)`).
pub fn new_var_use(idx: usize, bound: TypeBound) -> Self {
pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self {
Self(TypeEnum::Variable(idx, bound), bound)
}

Expand Down
7 changes: 7 additions & 0 deletions hugr/src/types/type_param.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,19 @@ impl From<TypeBound> for TypeParam {
}
}

impl From<UpperBound> for TypeParam {
fn from(bound: UpperBound) -> Self {
Self::BoundedNat { bound }
}
}

impl From<Type> for TypeArg {
fn from(ty: Type) -> Self {
Self::Type { ty }
}
}


/// A statically-known argument value to an operation.
#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
#[non_exhaustive]
Expand Down
56 changes: 28 additions & 28 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions specification/schema/hugr_schema_v1.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"$defs": {
"Alias": {
"description": "TODO",
"description": "An Alias Type",
"properties": {
"t": {
"const": "Alias",
Expand Down Expand Up @@ -1063,7 +1063,7 @@
"title": "OpType"
},
"Opaque": {
"description": "An opaque operation that can be downcasted by the extensions that define it.",
"description": "An opaque Type that can be downcasted by the extensions that define it.",
"properties": {
"t": {
"const": "Opaque",
Expand Down Expand Up @@ -1628,7 +1628,7 @@
"type": "object"
},
"UnitSum": {
"description": "Simple predicate where all variants are empty tuples.",
"description": "Simple sum type where all variants are empty tuples.",
"properties": {
"s": {
"const": "Unit",
Expand Down
Loading

0 comments on commit df801eb

Please sign in to comment.