diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index 98038fc18f..37ad705bc2 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -136,6 +136,11 @@ pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); +/// Return a Sum type with the first variant as the given type and the second an Error. +pub fn sum_with_error(ty: Type) -> Type { + Type::new_sum(vec![ty, ERROR_TYPE]) +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// Structure for holding constant usize values. pub struct ConstUsize(u64); diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 4ce3e7e199..cbac81a85a 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,7 +1,10 @@ //! Conversions between integer and floating-point values. use crate::{ - extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE}, + extension::{ + prelude::sum_with_error, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, + PRELUDE, + }, type_row, types::{FunctionType, PolyFuncType, Type}, Extension, @@ -17,13 +20,7 @@ fn ftoi_sig( int_type_var: Type, temp_reg: &ExtensionRegistry, ) -> Result { - let body = FunctionType::new( - type_row![FLOAT64_TYPE], - vec![Type::new_sum(vec![ - int_type_var, - crate::extension::prelude::ERROR_TYPE, - ])], - ); + let body = FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_type_var)]); PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index cac087d6af..18afe26551 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,7 +1,7 @@ //! Basic integer operations. use super::int_types::{get_log_width, int_type, int_type_var, INT_TYPE_ID, LOG_WIDTH_TYPE_PARAM}; -use crate::extension::prelude::{BOOL_T, ERROR_TYPE}; +use crate::extension::prelude::{sum_with_error, BOOL_T}; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; @@ -37,7 +37,7 @@ fn inarrow_sig(arg_values: &[TypeArg]) -> Result { } Ok(FunctionType::new( vec![int_type(arg0.clone())], - vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])], + vec![sum_with_error(int_type(arg1.clone()))], )) } @@ -103,7 +103,7 @@ fn idivmod_checked_sig( int_polytype( 2, intpair.clone(), - vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])], + vec![sum_with_error(Type::new_tuple(intpair))], temp_reg, ) } @@ -125,7 +125,7 @@ fn idiv_checked_sig( int_polytype( 2, vec![int_type_var_0.clone(), int_type_var_1], - vec![Type::new_sum(vec![int_type_var_0, ERROR_TYPE])], + vec![sum_with_error(int_type_var_0)], temp_reg, ) } @@ -151,7 +151,7 @@ fn imod_checked_sig( int_polytype( 2, vec![int_type_var_0, int_type_var_1.clone()], - vec![Type::new_sum(vec![int_type_var_1, ERROR_TYPE])], + vec![sum_with_error(int_type_var_1)], temp_reg, ) } @@ -566,10 +566,7 @@ mod test { let sig = inarrow_sig(&[ta(2), ta(1)]).unwrap(); assert_eq!( sig, - FunctionType::new( - vec![int_type(ta(2))], - vec![Type::new_sum(vec![int_type(ta(1)), ERROR_TYPE])], - ) + FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))],) ); inarrow_sig(&[ta(1), ta(2)]).unwrap_err(); diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 7db1225e22..cc830f9303 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -75,7 +75,7 @@ fn extension() -> Extension { let temp_reg: ExtensionRegistry = [extension.clone()].into(); let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); - let (l, e) = list_types(list_type_def); + let (l, e) = list_and_elem_type(list_type_def); extension .add_op_type_scheme_simple( POP_NAME, @@ -103,16 +103,6 @@ fn extension() -> Extension { extension } -fn list_types(list_type_def: &TypeDef) -> (Type, Type) { - let elem_type = Type::new_var_use(0, TypeBound::Any); - let list_type = Type::new_extension( - list_type_def - .instantiate(vec![TypeArg::new_var_use(0, TP)]) - .unwrap(), - ); - (list_type, elem_type) -} - lazy_static! { /// Collections extension definition. pub static ref EXTENSION: Extension = extension(); @@ -122,6 +112,15 @@ fn get_type(name: &str) -> &TypeDef { EXTENSION.get_type(name).unwrap() } +fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) { + let elem_type = Type::new_var_use(0, TypeBound::Any); + let list_type = Type::new_extension( + list_type_def + .instantiate(vec![TypeArg::new_var_use(0, TP)]) + .unwrap(), + ); + (list_type, elem_type) +} #[cfg(test)] mod test { use crate::{ diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index d432fadadb..794b98c9de 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -47,7 +47,7 @@ fn extension() -> Extension { "logical 'and'".into(), vec![H_INT], |arg_values: &[TypeArg]| { - let TypeArg::BoundedNat { n } = arg_values.iter().exactly_one().unwrap() else { + let Ok(TypeArg::BoundedNat { n }) = arg_values.iter().exactly_one() else { panic!("should be covered by validation.") }; @@ -65,7 +65,7 @@ fn extension() -> Extension { "logical 'or'".into(), vec![H_INT], |arg_values: &[TypeArg]| { - let TypeArg::BoundedNat { n } = arg_values.iter().exactly_one().unwrap() else { + let Ok(TypeArg::BoundedNat { n }) = arg_values.iter().exactly_one() else { panic!("should be covered by validation.") };