Skip to content

Commit

Permalink
refactor!: use dedicated structs for varargs
Browse files Browse the repository at this point in the history
and remove the automatic closure conversion
  • Loading branch information
ss2165 committed Nov 21, 2023
1 parent cb9edde commit 0748200
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 108 deletions.
64 changes: 14 additions & 50 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use crate::types::type_param::{check_type_args, TypeArg, TypeParam};
use crate::types::{FunctionType, PolyFuncType};
use crate::Hugr;

pub trait ComputeSignature: Send + Sync {
/// Trait necessary for binary computations of OpDef signature
pub trait CustomSignatureFunc: Send + Sync {
/// Compute signature of node given the operation name,
/// values for the type parameters,
/// and 'misc' data from the extension definition YAML
Expand All @@ -27,23 +28,9 @@ pub trait ComputeSignature: Send + Sync {
misc: &HashMap<String, serde_yaml::Value>,
extension_registry: &ExtensionRegistry,
) -> Result<PolyFuncType, SignatureError>;
}

// Note this is very much a utility, rather than definitive;
// one can only do so much without the ExtensionRegistry!
impl<F, R: Into<PolyFuncType>> ComputeSignature for F
where
F: Fn(&[TypeArg]) -> Result<R, SignatureError> + Send + Sync,
{
fn compute_signature(
&self,
_name: &SmolStr,
arg_values: &[TypeArg],
_misc: &HashMap<String, serde_yaml::Value>,
_extension_registry: &ExtensionRegistry,
) -> Result<PolyFuncType, SignatureError> {
Ok(self(arg_values)?.into())
}
/// The declared type parameters which require values in order for signature to
/// be computed.
fn static_params(&self) -> &[TypeParam];
}

pub trait ValidateTypeArgs: Send + Sync {
Expand Down Expand Up @@ -99,30 +86,6 @@ pub trait CustomLowerFunc: Send + Sync {
) -> Option<Hugr>;
}

/// Compute a signature by recording the type parameters and a custom function
/// for computing the signature.
pub struct CustomSignatureFunc {
/// Type parameters passed to [func]. (The returned [PolyFuncType]
/// may require further type parameters, not declared here.)
static_params: Vec<TypeParam>,
func: Box<dyn ComputeSignature>,
}

impl CustomSignatureFunc {
/// Build custom computation from a function that takes in type arguments
/// and returns a signature.
pub fn from_function<F, R>(static_params: impl Into<Vec<TypeParam>>, func: F) -> Self
where
R: Into<PolyFuncType>,
F: Fn(&[TypeArg]) -> Result<R, SignatureError> + Send + Sync + 'static,
{
Self {
static_params: static_params.into(),
func: Box::new(func),
}
}
}

/// Encode a signature as `PolyFuncType` but optionally allow validating type
/// arguments via a custom binary.
#[derive(serde::Deserialize, serde::Serialize)]
Expand Down Expand Up @@ -164,7 +127,7 @@ pub enum SignatureFunc {
#[serde(rename = "signature")]
TypeScheme(CustomValidator),
#[serde(skip)]
CustomFunc(CustomSignatureFunc),
CustomFunc(Box<dyn CustomSignatureFunc>),
}

impl Default for Box<dyn ValidateTypeArgs> {
Expand All @@ -173,9 +136,9 @@ impl Default for Box<dyn ValidateTypeArgs> {
}
}

impl<T: Into<CustomSignatureFunc>> From<T> for SignatureFunc {
impl<T: CustomSignatureFunc + 'static> From<T> for SignatureFunc {
fn from(v: T) -> Self {
Self::CustomFunc(v.into())
Self::CustomFunc(Box::new(v))
}
}

Expand Down Expand Up @@ -216,10 +179,11 @@ impl SignatureFunc {
let static_params = self.static_params();
let (static_args, other_args) =
arg_values.split_at(min(static_params.len(), arg_values.len()));

dbg!(&static_args, &other_args);
check_type_args(static_args, static_params)?;
let pf =
func.func
.compute_signature(name, static_args, misc, extension_registry)?;
let pf = func.compute_signature(name, static_args, misc, extension_registry)?;
dbg!(pf.params());
(pf, other_args)
}
})
Expand All @@ -228,7 +192,7 @@ impl SignatureFunc {
fn static_params(&self) -> &[TypeParam] {
match self {
SignatureFunc::TypeScheme(ts) => ts.poly_func.params(),
SignatureFunc::CustomFunc(func) => &func.static_params,
SignatureFunc::CustomFunc(func) => func.static_params(),
}
}
}
Expand Down Expand Up @@ -414,7 +378,7 @@ impl Extension {
}
}

/// Create an OpDef with `PolyFuncType`, `CustomSignatureFunc` or `CustomValidator`
/// Create an OpDef with `PolyFuncType`, `impl CustomSignatureFunc` or `CustomValidator`
/// ; and no "misc" or "lowering functions" defined.
pub fn add_op_simple(
&mut self,
Expand Down
67 changes: 43 additions & 24 deletions src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,46 @@ use lazy_static::lazy_static;
use smol_str::SmolStr;

use crate::{
extension::{op_def::CustomSignatureFunc, ExtensionId, TypeDefBound},
extension::{ExtensionId, TypeDefBound},
ops::LeafOp,
types::{
type_param::{TypeArg, TypeParam},
CustomCheckFailure, CustomType, FunctionType, Type, TypeBound,
CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound,
},
values::{CustomConst, KnownTypeConst},
Extension,
};

use super::ExtensionRegistry;
use super::{CustomSignatureFunc, ExtensionRegistry, SignatureError};
struct ArrayOpCustom;

const MAX: &[TypeParam; 1] = &[TypeParam::max_nat()];
impl CustomSignatureFunc for ArrayOpCustom {
fn compute_signature(
&self,
_name: &SmolStr,
arg_values: &[TypeArg],
_misc: &std::collections::HashMap<String, serde_yaml::Value>,
_extension_registry: &ExtensionRegistry,
) -> Result<PolyFuncType, SignatureError> {
let [TypeArg::BoundedNat { n }] = *arg_values else {
panic!("Should have been checked already.")
};
let elem_ty_var = Type::new_var_use(0, TypeBound::Any);

let var_arg_row = vec![elem_ty_var.clone(); n as usize];
let other_row = vec![array_type(elem_ty_var.clone(), TypeArg::BoundedNat { n })];

Ok(PolyFuncType::new(
vec![TypeParam::Type(TypeBound::Any)],
FunctionType::new(var_arg_row, other_row),
))
}

fn static_params(&self) -> &[TypeParam] {
MAX
}
}

/// Name of prelude extension.
pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude");
Expand All @@ -34,26 +63,16 @@ lazy_static! {
prelude
.add_type(
SmolStr::new_inline("array"),
vec![TypeParam::Type(TypeBound::Any), TypeParam::max_nat()],
vec![ TypeParam::max_nat(),TypeParam::Type(TypeBound::Any)],
"array".into(),
TypeDefBound::FromParams(vec![0]),
TypeDefBound::FromParams(vec![1]),
)
.unwrap();

prelude
.add_op_simple(
SmolStr::new_inline(NEW_ARRAY_OP_ID),
"Create a new array from elements".to_string(),
CustomSignatureFunc::from_function(vec![TypeParam::Type(TypeBound::Any), TypeParam::max_nat()],
|args: &[TypeArg]| {
let [TypeArg::Type { ty }, TypeArg::BoundedNat { n }] = args else {
panic!("should have been checked already.")
};
Ok(FunctionType::new(
vec![ty.clone(); *n as usize],
vec![array_type(ty.clone(), *n)],
))
}),
ArrayOpCustom,
)
.unwrap();

Expand Down Expand Up @@ -98,13 +117,10 @@ pub const USIZE_T: Type = Type::new_extension(USIZE_CUSTOM_T);
pub const BOOL_T: Type = Type::new_unit_sum(2);

/// Initialize a new array of element type `element_ty` of length `size`
pub fn array_type(element_ty: Type, size: u64) -> Type {
pub fn array_type(element_ty: Type, size: TypeArg) -> Type {
let array_def = PRELUDE.get_type("array").unwrap();
let custom_t = array_def
.instantiate(vec![
TypeArg::Type { ty: element_ty },
TypeArg::BoundedNat { n: size },
])
.instantiate(vec![size, TypeArg::Type { ty: element_ty }])
.unwrap();
Type::new_extension(custom_t)
}
Expand All @@ -118,8 +134,8 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp {
.instantiate_extension_op(
NEW_ARRAY_OP_ID,
vec![
TypeArg::Type { ty: element_ty },
TypeArg::BoundedNat { n: size },
TypeArg::Type { ty: element_ty },
],
&PRELUDE_REGISTRY,
)
Expand Down Expand Up @@ -179,7 +195,10 @@ impl KnownTypeConst for ConstUsize {

#[cfg(test)]
mod test {
use crate::builder::{DFGBuilder, Dataflow, DataflowHugr};
use crate::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
types::FunctionType,
};

use super::*;

Expand All @@ -188,7 +207,7 @@ mod test {
fn test_new_array() {
let mut b = DFGBuilder::new(FunctionType::new(
vec![QB_T, QB_T],
vec![array_type(QB_T, 2)],
vec![array_type(QB_T, TypeArg::BoundedNat { n: 2 })],
))
.unwrap();

Expand Down
49 changes: 27 additions & 22 deletions src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! Basic logical operations.
use itertools::Itertools;
use smol_str::SmolStr;

use crate::{
Expand Down Expand Up @@ -28,9 +27,33 @@ pub const OR_NAME: &str = "Or";
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic");

fn logic_op_sig() -> impl CustomSignatureFunc {
struct LogicOpCustom;

const MAX: &[TypeParam; 1] = &[TypeParam::max_nat()];
impl CustomSignatureFunc for LogicOpCustom {
fn compute_signature(
&self,
_name: &SmolStr,
arg_values: &[TypeArg],
_misc: &std::collections::HashMap<String, serde_yaml::Value>,
_extension_registry: &crate::extension::ExtensionRegistry,
) -> Result<crate::types::PolyFuncType, crate::extension::SignatureError> {
let [TypeArg::BoundedNat { n }] = *arg_values else {
panic!("Should have been checked already.")
};
let var_arg_row = vec![BOOL_T; n as usize];
Ok(FunctionType::new(var_arg_row, vec![BOOL_T]).into())
}

fn static_params(&self) -> &[TypeParam] {
MAX
}
}
LogicOpCustom
}
/// Extension for basic logical operations.
fn extension() -> Extension {
const H_INT: TypeParam = TypeParam::max_nat();
let mut extension = Extension::new(EXTENSION_ID);

extension
Expand All @@ -45,33 +68,15 @@ fn extension() -> Extension {
.add_op_simple(
SmolStr::new_inline(AND_NAME),
"logical 'and'".into(),
CustomSignatureFunc::from_function(vec![H_INT], |arg_values: &[TypeArg]| {
let Ok(TypeArg::BoundedNat { n }) = arg_values.iter().exactly_one() else {
panic!("should be covered by validation.")
};

Ok(FunctionType::new(
vec![BOOL_T; *n as usize],
type_row![BOOL_T],
))
}),
logic_op_sig(),
)
.unwrap();

extension
.add_op_simple(
SmolStr::new_inline(OR_NAME),
"logical 'or'".into(),
CustomSignatureFunc::from_function(vec![H_INT], |arg_values: &[TypeArg]| {
let Ok(TypeArg::BoundedNat { n }) = arg_values.iter().exactly_one() else {
panic!("should be covered by validation.")
};

Ok(FunctionType::new(
vec![BOOL_T; *n as usize],
type_row![BOOL_T],
))
}),
logic_op_sig(),
)
.unwrap();

Expand Down
15 changes: 5 additions & 10 deletions src/types/poly_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ pub(crate) mod test {
#[test]
fn test_mismatched_args() -> Result<(), SignatureError> {
let ar_def = PRELUDE.get_type("array").unwrap();
let typarams = [TypeParam::Type(TypeBound::Any), TypeParam::max_nat()];
let typarams = [TypeParam::max_nat(), TypeParam::Type(TypeBound::Any)];
let [tyvar, szvar] =
[0, 1].map(|i| TypeArg::new_var_use(i, typarams.get(i).unwrap().clone()));

Expand All @@ -295,20 +295,20 @@ pub(crate) mod test {

// Sanity check (good args)
good_ts.instantiate(
&[TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 5 }],
&[TypeArg::BoundedNat { n: 5 }, TypeArg::Type { ty: USIZE_T }],
&PRELUDE_REGISTRY,
)?;

let wrong_args = good_ts.instantiate(
&[TypeArg::BoundedNat { n: 5 }, TypeArg::Type { ty: USIZE_T }],
&[TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 5 }],
&PRELUDE_REGISTRY,
);
assert_eq!(
wrong_args,
Err(SignatureError::TypeArgMismatch(
TypeArgError::TypeMismatch {
param: typarams[0].clone(),
arg: TypeArg::BoundedNat { n: 5 }
arg: TypeArg::Type { ty: USIZE_T }
}
))
);
Expand Down Expand Up @@ -459,12 +459,7 @@ pub(crate) mod test {

// The standard library new_array does not allow passing in a variable for size.
fn new_array(ty: Type, s: TypeArg) -> Type {
let array_def = PRELUDE.get_type("array").unwrap();
Type::new_extension(
array_def
.instantiate(vec![TypeArg::Type { ty }, s])
.unwrap(),
)
crate::extension::prelude::array_type(ty, s)
}

const USIZE_TA: TypeArg = TypeArg::Type { ty: USIZE_T };
Expand Down
6 changes: 4 additions & 2 deletions src/types/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{PolyFuncType, SumType, Type, TypeBound, TypeEnum, TypeRow};
use super::{PolyFuncType, SumType, Type, TypeArg, TypeBound, TypeEnum, TypeRow};

use super::custom::CustomType;

Expand Down Expand Up @@ -48,7 +48,9 @@ impl From<SerSimpleType> for Type {
SerSimpleType::G(sig) => Type::new_function(*sig),
SerSimpleType::Tuple { inner } => Type::new_tuple(inner),
SerSimpleType::Sum(sum) => sum.into(),
SerSimpleType::Array { inner, len } => array_type((*inner).into(), len),
SerSimpleType::Array { inner, len } => {
array_type((*inner).into(), TypeArg::BoundedNat { n: len })
}
SerSimpleType::Opaque(custom) => Type::new_extension(custom),
SerSimpleType::Alias(a) => Type::new_alias(a),
SerSimpleType::V { i, b } => Type::new_var_use(i, b),
Expand Down

0 comments on commit 0748200

Please sign in to comment.