From 0b608a5a9c47f2ef1d8c99e4c3d31fa353166d40 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Sat, 25 May 2024 20:32:28 +0100 Subject: [PATCH] WIP try to parametrize by bool --- hugr/src/extension/op_def.rs | 2 +- hugr/src/types.rs | 159 ++++++++++++++++++++++++----------- hugr/src/types/poly_func.rs | 48 ++++++----- hugr/src/types/serialize.rs | 39 +++++---- hugr/src/types/signature.rs | 73 +++++++++++----- hugr/src/types/type_param.rs | 2 +- hugr/src/types/type_row.rs | 90 ++++++++++++-------- 7 files changed, 262 insertions(+), 151 deletions(-) diff --git a/hugr/src/extension/op_def.rs b/hugr/src/extension/op_def.rs index d91aa68900..5efeccbc7c 100644 --- a/hugr/src/extension/op_def.rs +++ b/hugr/src/extension/op_def.rs @@ -408,7 +408,7 @@ impl OpDef { // The type scheme may contain row variables so be of variable length; // these will have to be substituted to fixed-length concrete types when // the OpDef is instantiated into an actual OpType. - ts.poly_func.validate_var_len(exts)?; + ts.poly_func.validate(exts)?; } Ok(()) } diff --git a/hugr/src/types.rs b/hugr/src/types.rs index cb8cd07068..a1bdc66f41 100644 --- a/hugr/src/types.rs +++ b/hugr/src/types.rs @@ -38,7 +38,7 @@ pub type TypeName = SmolStr; pub type TypeNameRef = str; /// The kinds of edges in a HUGR, excluding Hierarchy. -#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, Debug, serde::Serialize, serde::Deserialize)] #[non_exhaustive] pub enum EdgeKind { /// Control edges of a CFG region. @@ -130,7 +130,7 @@ pub enum SumType { Unit { size: u8 }, /// General case of a Sum type. #[allow(missing_docs)] - General { rows: Vec }, + General { rows: Vec> }, } impl std::fmt::Display for SumType { @@ -152,7 +152,7 @@ impl SumType { /// Initialize a new sum type. pub fn new(variants: impl IntoIterator) -> Self where - V: Into, + V: Into>, { let rows = variants.into_iter().map(Into::into).collect_vec(); @@ -170,7 +170,7 @@ impl SumType { } /// Report the tag'th variant, if it exists. - pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> { + pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> { match self { SumType::Unit { size } if tag < (*size as usize) => Some(Type::EMPTY_TYPEROW_REF), SumType::General { rows } => rows.get(tag), @@ -187,8 +187,8 @@ impl SumType { } } -impl From for Type { - fn from(sum: SumType) -> Type { +impl From for Type { + fn from(sum: SumType) -> Self { match sum { SumType::Unit { size } => Type::new_unit_sum(size), SumType::General { rows } => Type::new_sum(rows), @@ -199,7 +199,7 @@ impl From for Type { #[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)] #[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))] /// Core types -pub enum TypeEnum { +pub enum TypeEnum { // TODO optimise with Box ? // or some static version of this? #[allow(missing_docs)] @@ -223,14 +223,29 @@ pub enum TypeEnum { #[display(fmt = "Variable({})", _0)] Variable(usize, TypeBound), /// Variable index, and cache of inner TypeBound - matches a [TypeParam::List] of [TypeParam::Type] - /// of this bound (checked in validation) + /// of this bound (checked in validation). Should only exist for `Type` and `TypeEnum`. #[display(fmt = "RowVar({})", _0)] RowVariable(usize, TypeBound), #[allow(missing_docs)] #[display(fmt = "{}", "_0")] Sum(#[cfg_attr(test, proptest(strategy = "any_with::(params)"))] SumType), } -impl TypeEnum { + +/*impl PartialEq for TypeEnum { + fn eq(&self, other: &TypeEnum) -> bool { + match (self, other) { + (TypeEnum::Extension(e1), TypeEnum::Extension(e2)) => e1 == e2, + (TypeEnum::Alias(a1), TypeEnum::Alias(a2)) => a1 == a2, + (TypeEnum::Function(f1), TypeEnum::Function(f2)) => f1==f2, + (TypeEnum::Variable(i1, b1), TypeEnum::Variable(i2, b2)) => i1==i2 && b1==b2, + (TypeEnum::RowVariable(i1, b1), TypeEnum::RowVariable(i2, b2)) => i1==i2 && b1==b2, + (TypeEnum::Sum(s1), TypeEnum::Sum(s2)) => s1 == s2, + _ => false + } + } +}*/ + +impl TypeEnum { /// The smallest type bound that covers the whole type. fn least_upper_bound(&self) -> TypeBound { match self { @@ -249,10 +264,10 @@ impl TypeEnum { } #[derive( - Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, + Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, )] #[display(fmt = "{}", "_0")] -#[serde(into = "serialize::SerSimpleType", from = "serialize::SerSimpleType")] +#[serde(into = "serialize::SerSimpleType", try_from = "serialize::SerSimpleType")] /// A HUGR type - the valid types of [EdgeKind::Value] and [EdgeKind::Const] edges. /// Such an edge is valid if the ports on either end agree on the [Type]. /// Types have an optional [TypeBound] which places limits on the valid @@ -273,24 +288,30 @@ impl TypeEnum { /// let func_type = Type::new_function(FunctionType::new_endo(vec![])); /// assert_eq!(func_type.least_upper_bound(), TypeBound::Copyable); /// ``` -pub struct Type(TypeEnum, TypeBound); +pub struct Type(TypeEnum, TypeBound); -impl Type { +/*impl PartialEq for Type { + fn eq(&self, other: &Type) -> bool { + self.0 == other.0 && self.1 == other.1 + } +}*/ + +impl Type { /// An empty `TypeRow`. Provided here for convenience - pub const EMPTY_TYPEROW: TypeRow = type_row![]; + pub const EMPTY_TYPEROW: TypeRow = type_row![]; /// Unit type (empty tuple). pub const UNIT: Self = Self(TypeEnum::Sum(SumType::Unit { size: 1 }), TypeBound::Eq); - const EMPTY_TYPEROW_REF: &'static TypeRow = &Self::EMPTY_TYPEROW; + const EMPTY_TYPEROW_REF: &'static TypeRow = &Self::EMPTY_TYPEROW; /// Initialize a new function type. - pub fn new_function(fun_ty: impl Into) -> Self { + pub fn new_function(fun_ty: impl Into>) -> Self { Self::new(TypeEnum::Function(Box::new(fun_ty.into()))) } /// Initialize a new tuple type by providing the elements. #[inline(always)] - pub fn new_tuple(types: impl Into) -> Self { + pub fn new_tuple(types: impl Into>) -> Self { let row = types.into(); match row.len() { 0 => Self::UNIT, @@ -300,7 +321,7 @@ impl Type { /// Initialize a new sum type by providing the possible variant types. #[inline(always)] - pub fn new_sum(variants: impl IntoIterator) -> Self where { + pub fn new_sum(variants: impl IntoIterator>) -> Self where { Self::new(TypeEnum::Sum(SumType::new(variants))) } @@ -316,7 +337,7 @@ impl Type { Self::new(TypeEnum::Alias(alias)) } - fn new(type_e: TypeEnum) -> Self { + fn new(type_e: TypeEnum) -> Self { let bound = type_e.least_upper_bound(); Self(type_e, bound) } @@ -335,19 +356,6 @@ impl Type { Self(TypeEnum::Variable(idx, bound), bound) } - /// New use (occurrence) of the row variable with specified index. - /// `bound` must be exactly that with which the variable was declared - /// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound), - /// which may be narrower than required for the use. - /// For use in [OpDef] type schemes, or function types, only, - /// not [FuncDefn] type schemes or as a Hugr port type. - /// - /// [OpDef]: crate::extension::OpDef - /// [FuncDefn]: crate::ops::FuncDefn - pub const fn new_row_var_use(idx: usize, bound: TypeBound) -> Self { - Self(TypeEnum::RowVariable(idx, bound), bound) - } - /// Report the least upper [TypeBound] #[inline(always)] pub const fn least_upper_bound(&self) -> TypeBound { @@ -356,7 +364,7 @@ impl Type { /// Report the component TypeEnum. #[inline(always)] - pub const fn as_type_enum(&self) -> &TypeEnum { + pub const fn as_type_enum(&self) -> &TypeEnum { &self.0 } @@ -382,7 +390,6 @@ impl Type { /// [TypeDef]: crate::extension::TypeDef pub(crate) fn validate( &self, - allow_row_vars: bool, extension_registry: &ExtensionRegistry, var_decls: &[TypeParam], ) -> Result<(), SignatureError> { @@ -391,16 +398,16 @@ impl Type { match &self.0 { TypeEnum::Sum(SumType::General { rows }) => rows .iter() - .try_for_each(|row| row.validate_var_len(extension_registry, var_decls)), + .try_for_each(|row| row.validate(extension_registry, var_decls)), TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there TypeEnum::Alias(_) => Ok(()), TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls), // Function values may be passed around without knowing their arity // (i.e. with row vars) as long as they are not called: - TypeEnum::Function(ft) => ft.validate_var_len(extension_registry, var_decls), + TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls), TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()), TypeEnum::RowVariable(idx, bound) => { - if allow_row_vars { + if RV { check_typevar_decl(var_decls, *idx, &TypeParam::new_list(*bound)) } else { Err(SignatureError::RowVarWhereTypeExpected { idx: *idx }) @@ -411,29 +418,68 @@ impl Type { /// Applies a substitution to a type. /// This may result in a row of types, if this [Type] is not really a single type but actually a row variable - /// Invariants may be confirmed by validation: - /// * If [Type::validate]`(false)` returns successfully, this method will return a Vec containing exactly one type - /// * If [Type::validate]`(false)` fails, but `(true)` succeeds, this method may (depending on structure of self) - /// return a Vec containing any number of [Type]s. These may (or not) pass [Type::validate] - fn substitute(&self, t: &Substitution) -> Vec { + /// (of course this can only occur for a `Type`, will always return exactly one element. + fn subst_vec(&self, s: &Substitution) -> Vec { match &self.0 { - TypeEnum::RowVariable(idx, bound) => t.apply_rowvar(*idx, *bound), - TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()], + TypeEnum::RowVariable(idx, bound) => + if RV {s.apply_rowvar(idx, bound)} // ALAN Argh, type error here as Type != Type even inside "if RV" + else {panic!("Row Variable outside Row - should not have validated?")}, + TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone().into()], TypeEnum::Variable(idx, bound) => { - let TypeArg::Type { ty } = t.apply_var(*idx, &((*bound).into())) else { + let TypeArg::Type { ty } = s.apply_var(*idx, &((*bound).into())) else { panic!("Variable was not a type - try validate() first") }; - vec![ty] + vec![ty] // ALAN argh, can't convert parametrically from Type to Type } - TypeEnum::Extension(cty) => vec![Type::new_extension(cty.substitute(t))], - TypeEnum::Function(bf) => vec![Type::new_function(bf.substitute(t))], + TypeEnum::Extension(cty) => vec![Type::new_extension(cty.substitute(s))], + TypeEnum::Function(bf) => vec![Type::new_function(bf.substitute(s))], TypeEnum::Sum(SumType::General { rows }) => { - vec![Type::new_sum(rows.iter().map(|r| r.substitute(t)))] + vec![Type::new_sum(rows.iter().map(|r| r.substitute(s)))] } } } } +impl TryFrom> for Type { + type Error = ConvertError; + fn try_from(value: Type) -> Result { + Ok(Self(match value.0 { + TypeEnum::Extension(e) => TypeEnum::Extension(e), + TypeEnum::Alias(a) => TypeEnum::Alias(a), + TypeEnum::Function(ft) => TypeEnum::Function(ft), + TypeEnum::Variable(i, b) => TypeEnum::Variable(i,b), + TypeEnum::RowVariable(_, _) => return Err(ConvertError), + TypeEnum::Sum(st) => TypeEnum::Sum(st) + }, value.1)) + } +} + +impl Type { + fn substitute(&self, s: &Substitution) -> Self { + let v = self.subst_vec(s); + let [r] = v.try_into().unwrap(); // No row vars, so every Type produces exactly one + r + } +} + +impl Type { + /// New use (occurrence) of the row variable with specified index. + /// `bound` must match that with which the variable was declared + /// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound). + /// For use in [OpDef], not [FuncDefn], type schemes only. + /// + /// [OpDef]: crate::extension::OpDef + /// [FuncDefn]: crate::ops::FuncDefn + pub const fn new_row_var(idx: usize, bound: TypeBound) -> Self { + Self(TypeEnum::RowVariable(idx, bound), bound) + } + + fn substitute(&self, s: &Substitution) -> Vec { + self.subst_vec(s) + } + +} + /// Details a replacement of type variables with a finite list of known values. /// (Variables out of the range of the list will result in a panic) pub(crate) struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry); @@ -448,7 +494,7 @@ impl<'a> Substitution<'a> { arg.clone() } - fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec { + fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec> { let arg = self .0 .get(idx) @@ -476,6 +522,19 @@ impl<'a> Substitution<'a> { } } +impl From> for Type { + fn from(value: Type) -> Self { + Self(match value.0 { + TypeEnum::Alias(a) => TypeEnum::Alias(a), + TypeEnum::Extension(e) => TypeEnum::Extension(e), + TypeEnum::Function(ft) => TypeEnum::Function(ft), + TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound), + TypeEnum::RowVariable(_, _) => panic!("Type should not contain row variables"), + TypeEnum::Sum(st) => TypeEnum::Sum(st), + }, value.1) + } +} + pub(crate) fn check_typevar_decl( decls: &[TypeParam], idx: usize, diff --git a/hugr/src/types/poly_func.rs b/hugr/src/types/poly_func.rs index e152a3c607..fb66fa2de3 100644 --- a/hugr/src/types/poly_func.rs +++ b/hugr/src/types/poly_func.rs @@ -27,7 +27,7 @@ use super::{FunctionType, Substitution}; "params.iter().map(ToString::to_string).join(\" \")", "body" )] -pub struct PolyFuncType { +pub struct PolyFuncType { /// The declared type parameters, i.e., these must be instantiated with /// the same number of [TypeArg]s before the function can be called. This /// defines the indices used by variables inside the body. @@ -35,11 +35,11 @@ pub struct PolyFuncType { params: Vec, /// Template for the function. May contain variables up to length of [Self::params] #[cfg_attr(test, proptest(strategy = "any_with::(params)"))] - body: FunctionType, + body: FunctionType, } -impl From for PolyFuncType { - fn from(body: FunctionType) -> Self { +impl From> for PolyFuncType { + fn from(body: FunctionType) -> Self { Self { params: vec![], body, @@ -47,7 +47,7 @@ impl From for PolyFuncType { } } -impl TryFrom for FunctionType { +impl TryFrom for FunctionType { /// If the PolyFuncType is not a monomorphic FunctionType, fail with the binders type Error = Vec; @@ -60,36 +60,26 @@ impl TryFrom for FunctionType { } } -impl PolyFuncType { +impl PolyFuncType { /// The type parameters, aka binders, over which this type is polymorphic pub fn params(&self) -> &[TypeParam] { &self.params } /// The body of the type, a function type. - pub fn body(&self) -> &FunctionType { + pub fn body(&self) -> &FunctionType { &self.body } /// Create a new PolyFuncType given the kinds of the variables it declares /// and the underlying [FunctionType]. - pub fn new(params: impl Into>, body: FunctionType) -> Self { + pub fn new(params: impl Into>, body: FunctionType) -> Self { Self { params: params.into(), body, } } - /// Validates this instance, checking that the types in the body are - /// wellformed with respect to the registry, and the type variables declared. - /// Allows both inputs and outputs to contain [RowVariable]s - /// - /// [RowVariable]: [crate::types::TypeEnum::RowVariable] - pub fn validate_var_len(&self, reg: &ExtensionRegistry) -> Result<(), SignatureError> { - // TODO https://github.com/CQCL/hugr/issues/624 validate TypeParams declared here, too - self.body.validate_var_len(reg, &self.params) - } - /// Instantiates an outer [PolyFuncType], i.e. with no free variables /// (as ensured by [Self::validate]), into a monomorphic type. /// @@ -100,7 +90,7 @@ impl PolyFuncType { &self, args: &[TypeArg], ext_reg: &ExtensionRegistry, - ) -> Result { + ) -> Result, SignatureError> { // Check that args are applicable, and that we have a value for each binder, // i.e. each possible free variable within the body. check_type_args(args, &self.params)?; @@ -108,6 +98,18 @@ impl PolyFuncType { } } +impl PolyFuncType { + /// Validates this instance, checking that the types in the body are + /// wellformed with respect to the registry, and the type variables declared. + /// Allows both inputs and outputs to contain [RowVariable]s if-and-only-if RV is true + /// + /// [RowVariable]: [crate::types::TypeEnum::RowVariable] + pub fn validate(&self, reg: &ExtensionRegistry) -> Result<(), SignatureError> { + // TODO https://github.com/CQCL/hugr/issues/624 validate TypeParams declared here, too + self.body.validate(reg, &self.params) + } +} + #[cfg(test)] pub(crate) mod test { use std::num::NonZeroU64; @@ -132,14 +134,14 @@ pub(crate) mod test { ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap(); } - impl PolyFuncType { + impl PolyFuncType { fn new_validated( params: impl Into>, - body: FunctionType, + body: impl Into>, extension_registry: &ExtensionRegistry, ) -> Result { - let res = Self::new(params, body); - res.validate_var_len(extension_registry)?; + let res = Self::new(params, body.into()); + res.validate(extension_registry)?; Ok(res) } } diff --git a/hugr/src/types/serialize.rs b/hugr/src/types/serialize.rs index e4585d2f54..d086a91e8e 100644 --- a/hugr/src/types/serialize.rs +++ b/hugr/src/types/serialize.rs @@ -10,7 +10,7 @@ use crate::ops::AliasDecl; pub(super) enum SerSimpleType { Q, I, - G(Box), + G(Box>), Sum(SumType), Array { inner: Box, len: u64 }, Opaque(CustomType), @@ -19,17 +19,14 @@ pub(super) enum SerSimpleType { R { i: usize, b: TypeBound }, } -impl From for SerSimpleType { - fn from(value: Type) -> Self { - if value == QB_T { - return SerSimpleType::Q; - } - if value == USIZE_T { - return SerSimpleType::I; - } - // TODO short circuiting for array. - let Type(value, _) = value; - match value { +impl From> for SerSimpleType { + fn from(value: Type) -> Self { + // ALAN argh these comparisons fail. If we define Type as implementing PartialEq, + // they succeed, but that leads to other problems. + // Similarly, we can't compare value==QB_T.into() because we cannot `impl From for Type` + if value == QB_T { return SerSimpleType::Q }; + if value == USIZE_T { return SerSimpleType::I }; + match value.0 { TypeEnum::Extension(o) => SerSimpleType::Opaque(o), TypeEnum::Alias(a) => SerSimpleType::Alias(a), TypeEnum::Function(sig) => SerSimpleType::G(sig), @@ -40,20 +37,22 @@ impl From for SerSimpleType { } } -impl From for Type { - fn from(value: SerSimpleType) -> Type { - match value { - SerSimpleType::Q => QB_T, - SerSimpleType::I => USIZE_T, +impl TryFrom for Type { + type Error; + fn try_from(value: SerSimpleType) -> Result { + Ok(match value { + SerSimpleType::Q => QB_T.into(), + SerSimpleType::I => USIZE_T.into(), SerSimpleType::G(sig) => Type::new_function(*sig), SerSimpleType::Sum(st) => st.into(), SerSimpleType::Array { inner, len } => { - array_type(TypeArg::BoundedNat { n: len }, (*inner).into()) + array_type(TypeArg::BoundedNat { n: len }, (*inner).try_into().unwrap()).into() } SerSimpleType::Opaque(o) => Type::new_extension(o), SerSimpleType::Alias(a) => Type::new_alias(a), SerSimpleType::V { i, b } => Type::new_var_use(i, b), - SerSimpleType::R { i, b } => Type::new_row_var_use(i, b), - } + // ALAN ugh, can't use new_row_var because that returns Type not Type. + value@SerSimpleType::R { i, b } => if RV {Type(TypeEnum::RowVariable(i, b), b)} else {return Err(format!("Row Variable {:?} serialized where no row vars allowed", value))} + }) } } diff --git a/hugr/src/types/signature.rs b/hugr/src/types/signature.rs index 1cf731eda0..dee097978c 100644 --- a/hugr/src/types/signature.rs +++ b/hugr/src/types/signature.rs @@ -21,35 +21,41 @@ use {crate::proptest::RecursionDepth, ::proptest::prelude::*, proptest_derive::A /// and also the target (value) of a call (static). /// /// [Graph]: crate::ops::constant::Value::Function -pub struct FunctionType { +pub struct FunctionType { /// Value inputs of the function. #[cfg_attr(test, proptest(strategy = "any_with::(params)"))] - pub input: TypeRow, + pub input: TypeRow, /// Value outputs of the function. #[cfg_attr(test, proptest(strategy = "any_with::(params)"))] - pub output: TypeRow, + pub output: TypeRow, /// The extension requirements which are added by the operation pub extension_reqs: ExtensionSet, } -impl FunctionType { +pub type Signature = FunctionType; + +impl FunctionType { /// Builder method, add extension_reqs to an FunctionType pub fn with_extension_delta(mut self, rs: impl Into) -> Self { self.extension_reqs = self.extension_reqs.union(rs.into()); self } +} - pub(super) fn validate_var_len( +impl FunctionType { + pub(super) fn validate( &self, extension_registry: &ExtensionRegistry, var_decls: &[TypeParam], ) -> Result<(), SignatureError> { - self.input.validate_var_len(extension_registry, var_decls)?; + self.input.validate(extension_registry, var_decls)?; self.output - .validate_var_len(extension_registry, var_decls)?; + .validate(extension_registry, var_decls)?; self.extension_reqs.validate(var_decls) } +} +impl FunctionType { pub(crate) fn substitute(&self, tr: &Substitution) -> Self { FunctionType { input: self.input.substitute(tr), @@ -57,19 +63,15 @@ impl FunctionType { extension_reqs: self.extension_reqs.substitute(tr), } } -} -impl FunctionType { /// The number of wires in the signature. #[inline(always)] pub fn is_empty(&self) -> bool { self.input.is_empty() && self.output.is_empty() } -} -impl FunctionType { /// Create a new signature with specified inputs and outputs. - pub fn new(input: impl Into, output: impl Into) -> Self { + pub fn new(input: impl Into>, output: impl Into>) -> Self { Self { input: input.into(), output: output.into(), @@ -78,11 +80,13 @@ impl FunctionType { } /// Create a new signature with the same input and output types (signature of an endomorphic /// function). - pub fn new_endo(linear: impl Into) -> Self { + pub fn new_endo(linear: impl Into>) -> Self { let linear = linear.into(); Self::new(linear.clone(), linear) } +} +impl FunctionType { /// Returns the type of a value [`Port`]. Returns `None` if the port is out /// of bounds. #[inline] @@ -157,10 +161,12 @@ impl FunctionType { pub fn output_count(&self) -> usize { self.port_count(Direction::Outgoing) } +} +impl FunctionType { /// Returns a slice of the types for the given direction. #[inline] - pub fn types(&self, dir: Direction) -> &[Type] { + pub fn types(&self, dir: Direction) -> &[Type] { match dir { Direction::Incoming => &self.input, Direction::Outgoing => &self.output, @@ -169,31 +175,33 @@ impl FunctionType { /// Returns a slice of the input types. #[inline] - pub fn input_types(&self) -> &[Type] { + pub fn input_types(&self) -> &[Type] { self.types(Direction::Incoming) } /// Returns a slice of the output types. #[inline] - pub fn output_types(&self) -> &[Type] { + pub fn output_types(&self) -> &[Type] { self.types(Direction::Outgoing) } #[inline] /// Returns the input row - pub fn input(&self) -> &TypeRow { + pub fn input(&self) -> &TypeRow { &self.input } #[inline] /// Returns the output row - pub fn output(&self) -> &TypeRow { + pub fn output(&self) -> &TypeRow { &self.output } } -impl FunctionType { +impl FunctionType { /// If this FunctionType contains any row variables, return one. + /// (Note, we could define for FunctionType; obviously will only + /// ever return Some if RV==true) pub fn find_rowvar(&self) -> Option<(usize, TypeBound)> { self.input .iter() @@ -203,7 +211,8 @@ impl FunctionType { _ => None, }) } - +} +impl FunctionType { /// Returns the `Port`s in the signature for a given direction. #[inline] pub fn ports(&self, dir: Direction) -> impl Iterator { @@ -225,7 +234,29 @@ impl FunctionType { } } -impl Display for FunctionType { +impl From> for FunctionType { + fn from(value: FunctionType) -> Self { + Self { + input: value.input.into(), + output: value.input.into(), + extension_reqs: value.extension_reqs + } + } +} + +impl TryFrom> for FunctionType { + type Error; + + fn try_from(value: FunctionType) -> Result { + Ok(Self { + input: value.input.try_into()?, + output: value.output.try_into()?, + extension_reqs: value.extension_reqs + }) + } +} + +impl Display for FunctionType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if !self.input.is_empty() { self.input.fmt(f)?; diff --git a/hugr/src/types/type_param.rs b/hugr/src/types/type_param.rs index e7019fe7c4..567c41e489 100644 --- a/hugr/src/types/type_param.rs +++ b/hugr/src/types/type_param.rs @@ -142,7 +142,7 @@ impl From for TypeParam { } /// A statically-known argument value to an operation. -#[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)] #[non_exhaustive] #[serde(tag = "tya")] pub enum TypeArg { diff --git a/hugr/src/types/type_row.rs b/hugr/src/types/type_row.rs index 5be3ddc2f1..bdfaee7cab 100644 --- a/hugr/src/types/type_row.rs +++ b/hugr/src/types/type_row.rs @@ -19,12 +19,12 @@ use itertools::Itertools; #[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] #[non_exhaustive] #[serde(transparent)] -pub struct TypeRow { +pub struct TypeRow { /// The datatypes in the row. - types: Cow<'static, [Type]>, + types: Cow<'static, [Type]>, } -impl Display for TypeRow { +impl Display for TypeRow { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_char('[')?; display_list(self.types.as_ref(), f)?; @@ -32,7 +32,7 @@ impl Display for TypeRow { } } -impl TypeRow { +impl TypeRow { /// Create a new empty row. pub const fn new() -> Self { Self { @@ -41,31 +41,48 @@ impl TypeRow { } /// Returns a new `TypeRow` with `xs` concatenated onto `self`. - pub fn extend<'a>(&'a self, rest: impl IntoIterator) -> Self { + pub fn extend<'a>(&'a self, rest: impl IntoIterator>) -> Self { self.iter().chain(rest).cloned().collect_vec().into() } /// Returns a reference to the types in the row. - pub fn as_slice(&self) -> &[Type] { + pub fn as_slice(&self) -> &[Type] { &self.types } + /// Applies a substitution to the row. + /// For `TypeRow`, note this may change the length of the row. + /// For `TypeRow`, guaranteed not to change the length of the row. + pub(super) fn substitute(&self, s: &Substitution) -> Self { + self + .iter() + .flat_map(|ty| ty.subst_vec(s)) + .collect::>() + .into() + } + delegate! { to self.types { /// Iterator over the types in the row. - pub fn iter(&self) -> impl Iterator; - - /// Returns the number of types in the row. - pub fn len(&self) -> usize; + pub fn iter(&self) -> impl Iterator>; /// Mutable vector of the types in the row. - pub fn to_mut(&mut self) -> &mut Vec; + pub fn to_mut(&mut self) -> &mut Vec>; /// Allow access (consumption) of the contained elements - pub fn into_owned(self) -> Vec; + pub fn into_owned(self) -> Vec>; /// Returns `true` if the row contains no types. pub fn is_empty(&self) -> bool ; + } + } +} + +impl TypeRow { + delegate! { + to self.types { + /// Returns the number of types in the row. + pub fn len(&self) -> usize; #[inline(always)] /// Returns the type at the specified index. Returns `None` if out of bounds. @@ -78,39 +95,42 @@ impl TypeRow { pub fn get_mut(&mut self, offset: usize) -> Option<&mut Type>; } } +} - /// Applies a substitution to the row. Note this may change the length - /// if-and-only-if the row contains any [RowVariable]s. - /// - /// [RowVariable]: [crate::types::TypeEnum::RowVariable] - pub(super) fn substitute(&self, tr: &Substitution) -> TypeRow { - let res = self - .iter() - .flat_map(|ty| ty.substitute(tr)) - .collect::>() - .into(); - res - } - - pub(super) fn validate_var_len( +impl TypeRow { + pub(super) fn validate( &self, exts: &ExtensionRegistry, var_decls: &[TypeParam], ) -> Result<(), SignatureError> { self.iter() - .try_for_each(|t| t.validate(true, exts, var_decls)) + .try_for_each(|t| t.validate(exts, var_decls)) + } +} + +impl From for TypeRow { + fn from(value: TypeRow) -> Self { + Self::from(value.into_owned().into_iter().map_into().collect::>>()) + } +} + +impl TryFrom> for TypeRow { + type Error; + + fn try_from(value: TypeRow) -> Result { + Ok(Self::from(value.into_owned().into_iter().map(|t| t.try_into()).collect::,_>>()?)) } } -impl Default for TypeRow { +impl Default for TypeRow { fn default() -> Self { Self::new() } } -impl From for TypeRow +impl From for TypeRow where - F: Into>, + F: Into]>>, { fn from(types: F) -> Self { Self { @@ -119,23 +139,23 @@ where } } -impl From for TypeRow { - fn from(t: Type) -> Self { +impl From> for TypeRow { + fn from(t: Type) -> Self { Self { types: vec![t].into(), } } } -impl Deref for TypeRow { - type Target = [Type]; +impl Deref for TypeRow { + type Target = [Type]; fn deref(&self) -> &Self::Target { self.as_slice() } } -impl DerefMut for TypeRow { +impl DerefMut for TypeRow { fn deref_mut(&mut self) -> &mut Self::Target { self.types.to_mut() }