Skip to content

Commit

Permalink
WIP try to parametrize by bool
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed May 25, 2024
1 parent d6d9a4e commit 0b608a5
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 151 deletions.
2 changes: 1 addition & 1 deletion hugr/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ impl OpDef {
// The type scheme may contain row variables so be of variable length;
// these will have to be substituted to fixed-length concrete types when
// the OpDef is instantiated into an actual OpType.
ts.poly_func.validate_var_len(exts)?;
ts.poly_func.validate(exts)?;
}
Ok(())
}
Expand Down
159 changes: 109 additions & 50 deletions hugr/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub type TypeName = SmolStr;
pub type TypeNameRef = str;

/// The kinds of edges in a HUGR, excluding Hierarchy.
#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
#[derive(Clone, PartialEq, Debug, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum EdgeKind {
/// Control edges of a CFG region.
Expand Down Expand Up @@ -130,7 +130,7 @@ pub enum SumType {
Unit { size: u8 },
/// General case of a Sum type.
#[allow(missing_docs)]
General { rows: Vec<TypeRow> },
General { rows: Vec<TypeRow<true>> },
}

impl std::fmt::Display for SumType {
Expand All @@ -152,7 +152,7 @@ impl SumType {
/// Initialize a new sum type.
pub fn new<V>(variants: impl IntoIterator<Item = V>) -> Self
where
V: Into<TypeRow>,
V: Into<TypeRow<true>>,
{
let rows = variants.into_iter().map(Into::into).collect_vec();

Expand All @@ -170,7 +170,7 @@ impl SumType {
}

/// Report the tag'th variant, if it exists.
pub fn get_variant(&self, tag: usize) -> Option<&TypeRow> {
pub fn get_variant(&self, tag: usize) -> Option<&TypeRow<true>> {
match self {
SumType::Unit { size } if tag < (*size as usize) => Some(Type::EMPTY_TYPEROW_REF),
SumType::General { rows } => rows.get(tag),
Expand All @@ -187,8 +187,8 @@ impl SumType {
}
}

impl From<SumType> for Type {
fn from(sum: SumType) -> Type {
impl <const RV:bool> From<SumType> for Type<RV> {
fn from(sum: SumType) -> Self {
match sum {
SumType::Unit { size } => Type::new_unit_sum(size),
SumType::General { rows } => Type::new_sum(rows),
Expand All @@ -199,7 +199,7 @@ impl From<SumType> for Type {
#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)]
#[cfg_attr(test, derive(Arbitrary), proptest(params = "RecursionDepth"))]
/// Core types
pub enum TypeEnum {
pub enum TypeEnum<const ROWVARS:bool=false> {
// TODO optimise with Box<CustomType> ?
// or some static version of this?
#[allow(missing_docs)]
Expand All @@ -223,14 +223,29 @@ pub enum TypeEnum {
#[display(fmt = "Variable({})", _0)]
Variable(usize, TypeBound),
/// Variable index, and cache of inner TypeBound - matches a [TypeParam::List] of [TypeParam::Type]
/// of this bound (checked in validation)
/// of this bound (checked in validation). Should only exist for `Type<true>` and `TypeEnum<true>`.
#[display(fmt = "RowVar({})", _0)]
RowVariable(usize, TypeBound),
#[allow(missing_docs)]
#[display(fmt = "{}", "_0")]
Sum(#[cfg_attr(test, proptest(strategy = "any_with::<SumType>(params)"))] SumType),
}
impl TypeEnum {

/*impl <const RV:bool> PartialEq<TypeEnum> for TypeEnum<RV> {
fn eq(&self, other: &TypeEnum) -> bool {
match (self, other) {
(TypeEnum::Extension(e1), TypeEnum::Extension(e2)) => e1 == e2,
(TypeEnum::Alias(a1), TypeEnum::Alias(a2)) => a1 == a2,
(TypeEnum::Function(f1), TypeEnum::Function(f2)) => f1==f2,
(TypeEnum::Variable(i1, b1), TypeEnum::Variable(i2, b2)) => i1==i2 && b1==b2,
(TypeEnum::RowVariable(i1, b1), TypeEnum::RowVariable(i2, b2)) => i1==i2 && b1==b2,
(TypeEnum::Sum(s1), TypeEnum::Sum(s2)) => s1 == s2,
_ => false
}
}
}*/

impl <const RV:bool> TypeEnum<RV> {
/// The smallest type bound that covers the whole type.
fn least_upper_bound(&self) -> TypeBound {
match self {
Expand All @@ -249,10 +264,10 @@ impl TypeEnum {
}

#[derive(
Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize,
Clone, Debug, PartialEq, Eq, derive_more::Display, serde::Serialize, serde::Deserialize,
)]
#[display(fmt = "{}", "_0")]
#[serde(into = "serialize::SerSimpleType", from = "serialize::SerSimpleType")]
#[serde(into = "serialize::SerSimpleType", try_from = "serialize::SerSimpleType")]
/// A HUGR type - the valid types of [EdgeKind::Value] and [EdgeKind::Const] edges.
/// Such an edge is valid if the ports on either end agree on the [Type].
/// Types have an optional [TypeBound] which places limits on the valid
Expand All @@ -273,24 +288,30 @@ impl TypeEnum {
/// let func_type = Type::new_function(FunctionType::new_endo(vec![]));
/// assert_eq!(func_type.least_upper_bound(), TypeBound::Copyable);
/// ```
pub struct Type(TypeEnum, TypeBound);
pub struct Type<const ROWVARS:bool=false>(TypeEnum<ROWVARS>, TypeBound);

impl Type {
/*impl<const RV:bool> PartialEq<Type> for Type<RV> {
fn eq(&self, other: &Type) -> bool {
self.0 == other.0 && self.1 == other.1
}
}*/

impl<const RV:bool> Type<RV> {
/// An empty `TypeRow`. Provided here for convenience
pub const EMPTY_TYPEROW: TypeRow = type_row![];
pub const EMPTY_TYPEROW: TypeRow<RV> = type_row![];
/// Unit type (empty tuple).
pub const UNIT: Self = Self(TypeEnum::Sum(SumType::Unit { size: 1 }), TypeBound::Eq);

const EMPTY_TYPEROW_REF: &'static TypeRow = &Self::EMPTY_TYPEROW;
const EMPTY_TYPEROW_REF: &'static TypeRow<RV> = &Self::EMPTY_TYPEROW;

/// Initialize a new function type.
pub fn new_function(fun_ty: impl Into<FunctionType>) -> Self {
pub fn new_function(fun_ty: impl Into<FunctionType<true>>) -> Self {
Self::new(TypeEnum::Function(Box::new(fun_ty.into())))
}

/// Initialize a new tuple type by providing the elements.
#[inline(always)]
pub fn new_tuple(types: impl Into<TypeRow>) -> Self {
pub fn new_tuple(types: impl Into<TypeRow<true>>) -> Self {
let row = types.into();
match row.len() {
0 => Self::UNIT,
Expand All @@ -300,7 +321,7 @@ impl Type {

/// Initialize a new sum type by providing the possible variant types.
#[inline(always)]
pub fn new_sum(variants: impl IntoIterator<Item = TypeRow>) -> Self where {
pub fn new_sum(variants: impl IntoIterator<Item = TypeRow<true>>) -> Self where {
Self::new(TypeEnum::Sum(SumType::new(variants)))
}

Expand All @@ -316,7 +337,7 @@ impl Type {
Self::new(TypeEnum::Alias(alias))
}

fn new(type_e: TypeEnum) -> Self {
fn new(type_e: TypeEnum<RV>) -> Self {
let bound = type_e.least_upper_bound();
Self(type_e, bound)
}
Expand All @@ -335,19 +356,6 @@ impl Type {
Self(TypeEnum::Variable(idx, bound), bound)
}

/// New use (occurrence) of the row variable with specified index.
/// `bound` must be exactly that with which the variable was declared
/// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound),
/// which may be narrower than required for the use.
/// For use in [OpDef] type schemes, or function types, only,
/// not [FuncDefn] type schemes or as a Hugr port type.
///
/// [OpDef]: crate::extension::OpDef
/// [FuncDefn]: crate::ops::FuncDefn
pub const fn new_row_var_use(idx: usize, bound: TypeBound) -> Self {
Self(TypeEnum::RowVariable(idx, bound), bound)
}

/// Report the least upper [TypeBound]
#[inline(always)]
pub const fn least_upper_bound(&self) -> TypeBound {
Expand All @@ -356,7 +364,7 @@ impl Type {

/// Report the component TypeEnum.
#[inline(always)]
pub const fn as_type_enum(&self) -> &TypeEnum {
pub const fn as_type_enum(&self) -> &TypeEnum<RV> {
&self.0
}

Expand All @@ -382,7 +390,6 @@ impl Type {
/// [TypeDef]: crate::extension::TypeDef
pub(crate) fn validate(
&self,
allow_row_vars: bool,
extension_registry: &ExtensionRegistry,
var_decls: &[TypeParam],
) -> Result<(), SignatureError> {
Expand All @@ -391,16 +398,16 @@ impl Type {
match &self.0 {
TypeEnum::Sum(SumType::General { rows }) => rows
.iter()
.try_for_each(|row| row.validate_var_len(extension_registry, var_decls)),
.try_for_each(|row| row.validate(extension_registry, var_decls)),
TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there
TypeEnum::Alias(_) => Ok(()),
TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls),
// Function values may be passed around without knowing their arity
// (i.e. with row vars) as long as they are not called:
TypeEnum::Function(ft) => ft.validate_var_len(extension_registry, var_decls),
TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls),
TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()),
TypeEnum::RowVariable(idx, bound) => {
if allow_row_vars {
if RV {
check_typevar_decl(var_decls, *idx, &TypeParam::new_list(*bound))
} else {
Err(SignatureError::RowVarWhereTypeExpected { idx: *idx })
Expand All @@ -411,29 +418,68 @@ impl Type {

/// Applies a substitution to a type.
/// This may result in a row of types, if this [Type] is not really a single type but actually a row variable
/// Invariants may be confirmed by validation:
/// * If [Type::validate]`(false)` returns successfully, this method will return a Vec containing exactly one type
/// * If [Type::validate]`(false)` fails, but `(true)` succeeds, this method may (depending on structure of self)
/// return a Vec containing any number of [Type]s. These may (or not) pass [Type::validate]
fn substitute(&self, t: &Substitution) -> Vec<Self> {
/// (of course this can only occur for a `Type<true``). For a `Type<false>`, will always return exactly one element.
fn subst_vec(&self, s: &Substitution) -> Vec<Self> {
match &self.0 {
TypeEnum::RowVariable(idx, bound) => t.apply_rowvar(*idx, *bound),
TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()],
TypeEnum::RowVariable(idx, bound) =>
if RV {s.apply_rowvar(idx, bound)} // ALAN Argh, type error here as Type<true> != Type<RV> even inside "if RV"
else {panic!("Row Variable outside Row - should not have validated?")},
TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone().into()],
TypeEnum::Variable(idx, bound) => {
let TypeArg::Type { ty } = t.apply_var(*idx, &((*bound).into())) else {
let TypeArg::Type { ty } = s.apply_var(*idx, &((*bound).into())) else {
panic!("Variable was not a type - try validate() first")
};
vec![ty]
vec![ty] // ALAN argh, can't convert parametrically from Type<false> to Type<RV>
}
TypeEnum::Extension(cty) => vec![Type::new_extension(cty.substitute(t))],
TypeEnum::Function(bf) => vec![Type::new_function(bf.substitute(t))],
TypeEnum::Extension(cty) => vec![Type::new_extension(cty.substitute(s))],
TypeEnum::Function(bf) => vec![Type::new_function(bf.substitute(s))],
TypeEnum::Sum(SumType::General { rows }) => {
vec![Type::new_sum(rows.iter().map(|r| r.substitute(t)))]
vec![Type::new_sum(rows.iter().map(|r| r.substitute(s)))]
}
}
}
}

impl TryFrom<Type<true>> for Type<false> {
type Error = ConvertError;
fn try_from(value: Type<true>) -> Result<Self, Self::Error> {
Ok(Self(match value.0 {
TypeEnum::Extension(e) => TypeEnum::Extension(e),
TypeEnum::Alias(a) => TypeEnum::Alias(a),
TypeEnum::Function(ft) => TypeEnum::Function(ft),
TypeEnum::Variable(i, b) => TypeEnum::Variable(i,b),
TypeEnum::RowVariable(_, _) => return Err(ConvertError),
TypeEnum::Sum(st) => TypeEnum::Sum(st)
}, value.1))
}
}

impl Type<false> {
fn substitute(&self, s: &Substitution) -> Self {
let v = self.subst_vec(s);
let [r] = v.try_into().unwrap(); // No row vars, so every Type<false> produces exactly one
r
}
}

impl Type<true> {
/// New use (occurrence) of the row variable with specified index.
/// `bound` must match that with which the variable was declared
/// (i.e. as a [TypeParam::List]` of a `[TypeParam::Type]` of that bound).
/// For use in [OpDef], not [FuncDefn], type schemes only.
///
/// [OpDef]: crate::extension::OpDef
/// [FuncDefn]: crate::ops::FuncDefn
pub const fn new_row_var(idx: usize, bound: TypeBound) -> Self {
Self(TypeEnum::RowVariable(idx, bound), bound)
}

fn substitute(&self, s: &Substitution) -> Vec<Self> {
self.subst_vec(s)
}

}

/// Details a replacement of type variables with a finite list of known values.
/// (Variables out of the range of the list will result in a panic)
pub(crate) struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry);
Expand All @@ -448,7 +494,7 @@ impl<'a> Substitution<'a> {
arg.clone()
}

fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec<Type> {
fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec<Type<true>> {
let arg = self
.0
.get(idx)
Expand Down Expand Up @@ -476,6 +522,19 @@ impl<'a> Substitution<'a> {
}
}

impl From<Type<false>> for Type<true> {
fn from(value: Type<false>) -> Self {
Self(match value.0 {
TypeEnum::Alias(a) => TypeEnum::Alias(a),
TypeEnum::Extension(e) => TypeEnum::Extension(e),
TypeEnum::Function(ft) => TypeEnum::Function(ft),
TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound),
TypeEnum::RowVariable(_, _) => panic!("Type<false> should not contain row variables"),
TypeEnum::Sum(st) => TypeEnum::Sum(st),
}, value.1)
}
}

pub(crate) fn check_typevar_decl(
decls: &[TypeParam],
idx: usize,
Expand Down
Loading

0 comments on commit 0b608a5

Please sign in to comment.