Skip to content

Commit

Permalink
try adding FunctionTypeVarArgs
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 16, 2024
1 parent d728714 commit 5ec995f
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 35 deletions.
4 changes: 2 additions & 2 deletions hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
types::EdgeKind,
};

use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE_REGISTRY};
use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE_REGISTRY};
use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow};

use itertools::Itertools;
Expand Down Expand Up @@ -87,7 +87,7 @@ pub trait Container {
name: impl Into<String>,
signature: PolyFuncType,
) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
let body = signature.body().clone();
let body = signature.body_norowvars()?.clone();
let f_node = self.add_child_node(NodeType::new_pure(ops::FuncDefn {
name: name.into(),
signature,
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl FunctionBuilder<Hugr> {
///
/// Error in adding DFG child nodes.
pub fn new(name: impl Into<String>, signature: PolyFuncType) -> Result<Self, BuildError> {
let body = signature.body().clone();
let body = signature.body_norowvars()?.clone();
let op = ops::FuncDefn {
signature,
name: name.into(),
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/builder/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
op_desc: "crate::ops::OpType::FuncDecl",
})?
.clone();
let body = signature.body().clone();
let body = signature.body_norowvars()?.clone();
self.hugr_mut()
.replace_op(
f_node,
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use super::{

use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType};
use crate::types::{FunctionType, FunctionTypeVarArgs, PolyFuncType, Substitution};
use crate::Hugr;

/// Trait necessary for binary computations of OpDef signature
Expand Down
11 changes: 5 additions & 6 deletions hugr/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ fn inner_row_variables() -> Result<(), Box<dyn std::error::Error>> {
#[rstest]
#[case(false)]
#[case(true)]
fn no_outer_row_variables(#[case] connect: bool) -> Result<(), Box<dyn std::error::Error>> {
fn no_outer_row_variables(#[case] connect: bool) {
let e = extension_with_eval_parallel();
let tv = Type::new_row_var_use(0, TypeBound::Copyable);
let mut fb = FunctionBuilder::new(
Expand All @@ -669,15 +669,15 @@ fn no_outer_row_variables(#[case] connect: bool) -> Result<(), Box<dyn std::erro
if connect { vec![tv.clone()] } else { vec![] },
),
),
)?;
).unwrap();
let [func_arg] = fb.input_wires_arr();
let i = fb.add_load_value(crate::extension::prelude::ConstUsize::new(5));
let ev =
e.instantiate_extension_op("eval", [seq1ty(USIZE_T), seq1ty(tv)], &PRELUDE_REGISTRY)?;
let ev = fb.add_dataflow_op(ev, [func_arg, i])?;
e.instantiate_extension_op("eval", [seq1ty(USIZE_T), seq1ty(tv)], &PRELUDE_REGISTRY).unwrap();
let ev = fb.add_dataflow_op(ev, [func_arg, i]).unwrap();
let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap();
if connect {
fb.set_outputs(ev.outputs())?;
fb.set_outputs(ev.outputs()).unwrap();
}
assert_eq!(
fb.finish_hugr(&reg).unwrap_err(),
Expand All @@ -686,7 +686,6 @@ fn no_outer_row_variables(#[case] connect: bool) -> Result<(), Box<dyn std::erro
cause: SignatureError::RowTypeVarOutsideRow { idx: 0 }
}
);
Ok(())
}

#[test]
Expand Down
3 changes: 2 additions & 1 deletion hugr/src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ impl StaticTag for FuncDefn {

impl DataflowParent for FuncDefn {
fn inner_signature(&self) -> FunctionType {
self.signature.body().clone()
// ok by validation
self.signature.body_norowvars().unwrap().clone()
}
}

Expand Down
13 changes: 9 additions & 4 deletions hugr/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::utils::display_list_with_separator;
pub use check::SumTypeError;
pub use custom::CustomType;
pub use poly_func::PolyFuncType;
pub use signature::FunctionType;
pub use signature::{FunctionType, FunctionTypeVarArgs};
use smol_str::SmolStr;
pub use type_param::TypeArg;
pub use type_row::TypeRow;
Expand Down Expand Up @@ -205,7 +205,7 @@ pub enum TypeEnum {
Alias(AliasDecl),
#[allow(missing_docs)]
#[display(fmt = "Function({})", "_0")]
Function(Box<FunctionType>),
Function(Box<FunctionTypeVarArgs>),
// Index into TypeParams, and cache of TypeBound (checked in validation)
#[allow(missing_docs)]
#[display(fmt = "Variable({})", _0)]
Expand Down Expand Up @@ -272,7 +272,7 @@ impl Type {
const EMPTY_TYPEROW_REF: &'static TypeRow = &Self::EMPTY_TYPEROW;

/// Initialize a new function type.
pub fn new_function(fun_ty: impl Into<FunctionType>) -> Self {
pub fn new_function(fun_ty: impl Into<FunctionTypeVarArgs>) -> Self {
Self::new(TypeEnum::Function(Box::new(fun_ty.into())))
}

Expand Down Expand Up @@ -359,6 +359,11 @@ impl Type {
}
}

/// TODO docs
pub fn is_row_var(&self) -> bool {
matches!(self.0, TypeEnum::RowVariable(_, _))
}

/// Checks that this [Type] represents a single Type, not a row variable,
/// that all variables used within are in the provided list of bound variables,
/// and that for each [CustomType], the corresponding
Expand Down Expand Up @@ -434,7 +439,7 @@ impl Type {

/// Details a replacement of type variables with a finite list of known values.
/// (Variables out of the range of the list will result in a panic)
pub(crate) struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry);
pub(crate) struct Substitution<'a>(pub(crate) &'a [TypeArg], pub(crate) &'a ExtensionRegistry);

impl<'a> Substitution<'a> {
pub(crate) fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg {
Expand Down
49 changes: 41 additions & 8 deletions hugr/src/types/poly_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::extension::{ExtensionRegistry, SignatureError};
use itertools::Itertools;

use super::type_param::{check_type_args, TypeArg, TypeParam};
use super::{FunctionType, Substitution};
use super::{FunctionType, FunctionTypeVarArgs, Substitution};

/// A polymorphic type scheme, i.e. of a [FuncDecl], [FuncDefn] or [OpDef].
/// (Nodes/operations in the Hugr are not polymorphic.)
Expand All @@ -20,25 +20,32 @@ use super::{FunctionType, Substitution};
"params.iter().map(ToString::to_string).join(\" \")",
"body"
)]

pub struct PolyFuncType {
/// The declared type parameters, i.e., these must be instantiated with
/// the same number of [TypeArg]s before the function can be called. This
/// defines the indices used by variables inside the body.
params: Vec<TypeParam>,
/// Template for the function. May contain variables up to length of [Self::params]
body: FunctionType,
body: FunctionTypeVarArgs,
}

impl From<FunctionType> for PolyFuncType {
fn from(body: FunctionType) -> Self {
impl From<FunctionTypeVarArgs> for PolyFuncType {
fn from(body: FunctionTypeVarArgs) -> Self {
Self {
params: vec![],
body,
}
}
}

impl TryFrom<PolyFuncType> for FunctionType {
impl From<FunctionType> for PolyFuncType {
fn from(body: FunctionType) -> Self {
Into::<FunctionTypeVarArgs>::into(body).into()
}
}

impl TryFrom<PolyFuncType> for FunctionTypeVarArgs {
/// If the PolyFuncType is not a monomorphic FunctionType, fail with the binders
type Error = Vec<TypeParam>;

Expand All @@ -51,23 +58,41 @@ impl TryFrom<PolyFuncType> for FunctionType {
}
}

impl TryFrom<PolyFuncType> for FunctionType {
/// If the PolyFuncType is not a monomorphic FunctionType, fail with the binders
type Error = PolyFuncType;

fn try_from(value: PolyFuncType) -> Result<Self, Self::Error> {
if let Ok(ftva) = TryInto::<FunctionTypeVarArgs>::try_into(value.clone()){
ftva.try_into().map_err(|_| value)
} else {
Err(value)
}
}
}

impl PolyFuncType {
/// The type parameters, aka binders, over which this type is polymorphic
pub fn params(&self) -> &[TypeParam] {
&self.params
}

/// The body of the type, a function type.
pub fn body(&self) -> &FunctionType {
pub fn body(&self) -> &FunctionTypeVarArgs {
&self.body
}

/// The body of the type, a function type.
pub fn body_norowvars(&self) -> Result<&FunctionType,SignatureError> {
self.body.try_as_ref()
}

/// Create a new PolyFuncType given the kinds of the variables it declares
/// and the underlying [FunctionType].
pub fn new(params: impl Into<Vec<TypeParam>>, body: FunctionType) -> Self {
pub fn new(params: impl Into<Vec<TypeParam>>, body: impl Into<FunctionTypeVarArgs>) -> Self {
Self {
params: params.into(),
body,
body: body.into(),
}
}

Expand All @@ -92,6 +117,14 @@ impl PolyFuncType {
args: &[TypeArg],
ext_reg: &ExtensionRegistry,
) -> Result<FunctionType, SignatureError> {
self.instantiate_varargs(args,ext_reg).and_then(TryInto::try_into)
}

pub(crate) fn instantiate_varargs(
&self,
args: &[TypeArg],
ext_reg: &ExtensionRegistry,
) -> Result<FunctionTypeVarArgs, SignatureError> {
// Check that args are applicable, and that we have a value for each binder,
// i.e. each possible free variable within the body.
check_type_args(args, &self.params)?;
Expand Down
4 changes: 2 additions & 2 deletions hugr/src/types/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{FunctionType, SumType, Type, TypeArg, TypeBound, TypeEnum};
use super::{FunctionType, FunctionTypeVarArgs, SumType, Type, TypeArg, TypeBound, TypeEnum};

use super::custom::CustomType;

Expand All @@ -10,7 +10,7 @@ use crate::ops::AliasDecl;
pub(super) enum SerSimpleType {
Q,
I,
G(Box<FunctionType>),
G(Box<FunctionTypeVarArgs>),
Sum(SumType),
Array { inner: Box<SerSimpleType>, len: u64 },
Opaque(CustomType),
Expand Down
88 changes: 79 additions & 9 deletions hugr/src/types/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use itertools::Either;
use std::fmt::{self, Display, Write};

use super::type_param::TypeParam;
use super::{Substitution, Type, TypeRow};
use super::{Substitution, Type, TypeEnum, TypeRow};

use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError};
use crate::{Direction, IncomingPort, OutgoingPort, Port};
Expand All @@ -25,25 +25,95 @@ pub struct FunctionType {
pub extension_reqs: ExtensionSet,
}

impl FunctionType {
/// Builder method, add extension_reqs to an FunctionType
pub fn with_extension_delta(mut self, rs: impl Into<ExtensionSet>) -> Self {
self.extension_reqs = self.extension_reqs.union(rs.into());
self
#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct FunctionTypeVarArgs(FunctionType);

impl FunctionTypeVarArgs {
pub fn any_rowvars(&self) -> bool {
self.rowvar().is_some()
}

pub fn rowvar(&self) -> Option<usize> {
self.0.input.iter().chain(self.0.output.iter()).filter_map(|x| match x.as_type_enum() {
TypeEnum::RowVariable(i,_) => Some(*i),
_ => None
}).next()
}

pub fn try_as_ref(&self) -> Result<&FunctionType,SignatureError> {
if let Some(idx) = self.rowvar() {
Err(SignatureError::RowTypeVarOutsideRow { idx })
} else {
Ok(&self.0)
}
}

pub(super) fn validate_varargs(
&self,
extension_registry: &ExtensionRegistry,
var_decls: &[TypeParam],
) -> Result<(), SignatureError> {
self.input.validate_var_len(extension_registry, var_decls)?;
self.output
self.0.input.validate_var_len(extension_registry, var_decls)?;
self.0.output
.validate_var_len(extension_registry, var_decls)?;
self.extension_reqs.validate(var_decls)
self.0.extension_reqs.validate(var_decls)
}

pub(crate) fn substitute(&self, tr: &Substitution) -> Self {
Self(FunctionType {
input: self.0.input.substitute(tr),
output: self.0.output.substitute(tr),
extension_reqs: self.0.extension_reqs.substitute(tr),
})
}
}

impl PartialEq<FunctionType> for FunctionTypeVarArgs {
fn eq(&self, other: &FunctionType) -> bool {
&self.0 == other
}
}

impl PartialEq<FunctionTypeVarArgs> for FunctionType {
fn eq(&self, other: &FunctionTypeVarArgs) -> bool {
self == &other.0
}
}

impl Display for FunctionTypeVarArgs {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

impl TryFrom<FunctionTypeVarArgs> for FunctionType {
type Error = SignatureError;
fn try_from(value: FunctionTypeVarArgs) -> Result<Self, Self::Error> {
if let Some(idx) = value.rowvar() {
Err(SignatureError::RowTypeVarOutsideRow { idx })
} else {
Ok(value.0)
}
}
}

impl From<FunctionType> for FunctionTypeVarArgs {
fn from(value: FunctionType) -> Self {
Self(value)
}
}


impl FunctionType {
/// Builder method, add extension_reqs to an FunctionType
pub fn with_extension_delta(mut self, rs: impl Into<ExtensionSet>) -> Self {
self.extension_reqs = self.extension_reqs.union(rs.into());
self
}


pub(crate) fn substitute(&self, tr: &Substitution) -> Self {
// TODO assert no row vars in result
FunctionType {
input: self.input.substitute(tr),
output: self.output.substitute(tr),
Expand Down

0 comments on commit 5ec995f

Please sign in to comment.