diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 77a591de8c..4ce3e7e199 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,24 +1,18 @@ //! Conversions between integer and floating-point values. use crate::{ - extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, PRELUDE}, + extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE}, type_row, - types::{type_param::TypeArg, FunctionType, PolyFuncType, Type}, + types::{FunctionType, PolyFuncType, Type}, Extension, }; -use super::int_types::INT_TYPE_ID; +use super::int_types::{int_type_var, INT_TYPE_ID}; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -fn int_type_var(var_id: usize, int_type_def: &TypeDef) -> Result { - Ok(Type::new_extension(int_type_def.instantiate(vec![ - TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM), - ])?)) -} - fn ftoi_sig( int_type_var: Type, temp_reg: &ExtensionRegistry, diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 4a6aae1983..cac087d6af 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,9 +1,10 @@ //! Basic integer operations. -use super::int_types::{get_log_width, int_type, type_arg, LOG_WIDTH_TYPE_PARAM}; +use super::int_types::{get_log_width, int_type, int_type_var, INT_TYPE_ID, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{BOOL_T, ERROR_TYPE}; +use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::type_row; -use crate::types::FunctionType; +use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, @@ -40,100 +41,145 @@ fn inarrow_sig(arg_values: &[TypeArg]) -> Result { )) } -fn itob_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - vec![int_type(type_arg(0))], - type_row![BOOL_T], - )) +fn int_polytype( + n_vars: usize, + input: impl Into, + output: impl Into, + temp_reg: &ExtensionRegistry, +) -> Result { + PolyFuncType::new_validated( + vec![LOG_WIDTH_TYPE_PARAM; n_vars], + FunctionType::new(input, output), + temp_reg, + ) } -fn btoi_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - type_row![BOOL_T], - vec![int_type(type_arg(0))], - )) +fn itob_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype(1, vec![int_type_var], type_row![BOOL_T], temp_reg) } -fn icmp_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone()); 2], - type_row![BOOL_T], - )) +fn btoi_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype(1, type_row![BOOL_T], vec![int_type_var], temp_reg) } -fn ibinop_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone()); 2], - vec![int_type(arg.clone())], - )) +fn icmp_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype(1, vec![int_type_var; 2], type_row![BOOL_T], temp_reg) } -fn iunop_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone())], - vec![int_type(arg.clone())], - )) +fn ibinop_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 1, + vec![int_type_var.clone(); 2], + vec![int_type_var.clone()], + temp_reg, + ) } -fn idivmod_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into(); - Ok(FunctionType::new( +fn iunop_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype(1, vec![int_type_var.clone()], vec![int_type_var], temp_reg) +} + +fn idivmod_checked_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + let intpair: TypeRow = vec![int_type_var_0, int_type_var_1].into(); + int_polytype( + 2, intpair.clone(), vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])], - )) + temp_reg, + ) } -fn idivmod_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into(); - Ok(FunctionType::new( - intpair.clone(), - vec![Type::new_tuple(intpair)], - )) +fn idivmod_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + let intpair: TypeRow = vec![int_type_var_0, int_type_var_1].into(); + int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)], temp_reg) } -fn idiv_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![Type::new_sum(vec![int_type(arg0.clone()), ERROR_TYPE])], - )) +fn idiv_checked_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0.clone(), int_type_var_1], + vec![Type::new_sum(vec![int_type_var_0, ERROR_TYPE])], + temp_reg, + ) } -fn idiv_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg0.clone())], - )) +fn idiv_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0.clone(), int_type_var_1], + vec![int_type_var_0], + temp_reg, + ) } -fn imod_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])], - )) +fn imod_checked_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0, int_type_var_1.clone()], + vec![Type::new_sum(vec![int_type_var_1, ERROR_TYPE])], + temp_reg, + ) } -fn imod_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg1.clone())], - )) +fn imod_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0, int_type_var_1.clone()], + vec![int_type_var_1], + temp_reg, + ) } -fn ish_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg0.clone())], - )) +fn ish_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0.clone(), int_type_var_1], + vec![int_type_var_0], + temp_reg, + ) } /// Extension for basic integer operations. @@ -142,6 +188,13 @@ pub fn extension() -> Extension { EXTENSION_ID, ExtensionSet::singleton(&super::int_types::EXTENSION_ID), ); + let int_types_extension = super::int_types::extension(); + let int_type_def = int_types_extension.get_type(&INT_TYPE_ID).unwrap(); + let int_type_var_0 = int_type_var(0, int_type_def).unwrap(); + let int_type_var_1 = int_type_var(1, int_type_def).unwrap(); + + let temp_reg: ExtensionRegistry = + [extension.clone(), int_types_extension, PRELUDE.to_owned()].into(); extension .add_op_custom_sig_simple( @@ -177,347 +230,306 @@ pub fn extension() -> Extension { ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "itobool".into(), "convert to bool (1 is true, 0 is false)".to_owned(), - vec![], - itob_sig, + itob_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ifrombool".into(), "convert from bool (1 is true, 0 is false)".to_owned(), - vec![], - btoi_sig, + btoi_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ieq".into(), "equality test".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ine".into(), "inequality test".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ilt_u".into(), "\"less than\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ilt_s".into(), "\"less than\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "igt_u".into(), "\"greater than\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "igt_s".into(), "\"greater than\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ile_u".into(), "\"less than or equal\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ile_s".into(), "\"less than or equal\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ige_u".into(), "\"greater than or equal\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ige_s".into(), "\"greater than or equal\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imax_u".into(), "maximum of unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imax_s".into(), "maximum of signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imin_u".into(), "minimum of unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imin_s".into(), "minimum of signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "isub".into(), "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ineg".into(), "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - iunop_sig, + iunop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imul".into(), "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "idivmod_checked_u".into(), "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r TypeArg { + TypeArg::BoundedNat { n } + } + #[test] + fn test_binary_signatures() { + let sig = iwiden_sig(&[ta(3), ta(4)]).unwrap(); + assert_eq!( + sig, + FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) + ); + + iwiden_sig(&[ta(4), ta(3)]).unwrap_err(); + + let sig = inarrow_sig(&[ta(2), ta(1)]).unwrap(); + assert_eq!( + sig, + FunctionType::new( + vec![int_type(ta(2))], + vec![Type::new_sum(vec![int_type(ta(1)), ERROR_TYPE])], + ) + ); + + inarrow_sig(&[ta(1), ta(2)]).unwrap_err(); + } } diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 0ad7c7ad65..526ea9d679 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -5,7 +5,7 @@ use std::num::NonZeroU64; use smol_str::SmolStr; use crate::{ - extension::ExtensionId, + extension::{ExtensionId, SignatureError, TypeDef}, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, @@ -198,6 +198,12 @@ pub fn extension() -> Extension { extension } +/// get an integer type variable, given the integer type definition +pub(super) fn int_type_var(var_id: usize, int_type_def: &TypeDef) -> Result { + Ok(Type::new_extension(int_type_def.instantiate(vec![ + TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM), + ])?)) +} #[cfg(test)] mod test { use cool_asserts::assert_matches;