From 0063b46c5ab7107108733789ae8f25352fff8354 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 12:32:38 +0000 Subject: [PATCH] refactor: use `int_type_var(usize)` to simplify int_ops and conversions --- src/std_extensions/arithmetic/conversions.rs | 35 ++-- src/std_extensions/arithmetic/int_ops.rs | 204 +++++++------------ src/std_extensions/arithmetic/int_types.rs | 19 +- 3 files changed, 105 insertions(+), 153 deletions(-) diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index cbac81a85..63207c219 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -6,30 +6,27 @@ use crate::{ PRELUDE, }, type_row, - types::{FunctionType, PolyFuncType, Type}, + types::{FunctionType, PolyFuncType}, Extension, }; -use super::int_types::{int_type_var, INT_TYPE_ID}; +use super::int_types::int_type_var; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -fn ftoi_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - let body = FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_type_var)]); +fn ftoi_sig(temp_reg: &ExtensionRegistry) -> Result { + let body = FunctionType::new( + type_row![FLOAT64_TYPE], + vec![sum_with_error(int_type_var(0))], + ); PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } -fn itof_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - let body = FunctionType::new(vec![int_type_var], type_row![FLOAT64_TYPE]); +fn itof_sig(temp_reg: &ExtensionRegistry) -> Result { + let body = FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]); PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } @@ -43,12 +40,8 @@ pub fn extension() -> Extension { super::float_types::EXTENSION_ID, ]), ); - let int_types_extension = super::int_types::extension(); - let int_type_def = int_types_extension.get_type(&INT_TYPE_ID).unwrap(); - let int_type_var = int_type_var(0, int_type_def).unwrap(); let temp_reg: ExtensionRegistry = [ - extension.clone(), - int_types_extension, + super::int_types::EXTENSION.to_owned(), super::float_types::extension(), PRELUDE.to_owned(), ] @@ -57,28 +50,28 @@ pub fn extension() -> Extension { .add_op_type_scheme_simple( "trunc_u".into(), "float to unsigned int".to_owned(), - ftoi_sig(int_type_var.clone(), &temp_reg).unwrap(), + ftoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "trunc_s".into(), "float to signed int".to_owned(), - ftoi_sig(int_type_var.clone(), &temp_reg).unwrap(), + ftoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "convert_u".into(), "unsigned int to float".to_owned(), - itof_sig(int_type_var.clone(), &temp_reg).unwrap(), + itof_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "convert_s".into(), "signed int to float".to_owned(), - itof_sig(int_type_var, &temp_reg).unwrap(), + itof_sig(&temp_reg).unwrap(), ) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 18afe2655..6fdb8a4a3 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,6 +1,6 @@ //! Basic integer operations. -use super::int_types::{get_log_width, int_type, int_type_var, INT_TYPE_ID, LOG_WIDTH_TYPE_PARAM}; +use super::int_types::{get_log_width, int_type, int_type_var, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::type_row; @@ -54,52 +54,36 @@ fn int_polytype( ) } -fn itob_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype(1, vec![int_type_var], type_row![BOOL_T], temp_reg) +fn itob_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, vec![int_type_var(0)], type_row![BOOL_T], temp_reg) } -fn btoi_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype(1, type_row![BOOL_T], vec![int_type_var], temp_reg) +fn btoi_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, type_row![BOOL_T], vec![int_type_var(0)], temp_reg) } -fn icmp_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype(1, vec![int_type_var; 2], type_row![BOOL_T], temp_reg) +fn icmp_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, vec![int_type_var(0); 2], type_row![BOOL_T], temp_reg) } -fn ibinop_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn ibinop_sig(temp_reg: &ExtensionRegistry) -> Result { + let int_type_var = int_type_var(0); + int_polytype( 1, vec![int_type_var.clone(); 2], - vec![int_type_var.clone()], + vec![int_type_var], temp_reg, ) } -fn iunop_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn iunop_sig(temp_reg: &ExtensionRegistry) -> Result { + let int_type_var = int_type_var(0); int_polytype(1, vec![int_type_var.clone()], vec![int_type_var], temp_reg) } -fn idivmod_checked_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - let intpair: TypeRow = vec![int_type_var_0, int_type_var_1].into(); +fn idivmod_checked_sig(temp_reg: &ExtensionRegistry) -> Result { + let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); int_polytype( 2, intpair.clone(), @@ -108,78 +92,44 @@ fn idivmod_checked_sig( ) } -fn idivmod_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - let intpair: TypeRow = vec![int_type_var_0, int_type_var_1].into(); +fn idivmod_sig(temp_reg: &ExtensionRegistry) -> Result { + let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)], temp_reg) } -fn idiv_checked_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn idiv_checked_sig(temp_reg: &ExtensionRegistry) -> Result { int_polytype( 2, - vec![int_type_var_0.clone(), int_type_var_1], - vec![sum_with_error(int_type_var_0)], + vec![int_type_var(1)], + vec![sum_with_error(int_type_var(0))], temp_reg, ) } -fn idiv_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype( - 2, - vec![int_type_var_0.clone(), int_type_var_1], - vec![int_type_var_0], - temp_reg, - ) +fn idiv_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)], temp_reg) } -fn imod_checked_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn imod_checked_sig(temp_reg: &ExtensionRegistry) -> Result { int_polytype( 2, - vec![int_type_var_0, int_type_var_1.clone()], - vec![sum_with_error(int_type_var_1)], + vec![int_type_var(0), int_type_var(1).clone()], + vec![sum_with_error(int_type_var(1))], temp_reg, ) } -fn imod_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn imod_sig(temp_reg: &ExtensionRegistry) -> Result { int_polytype( 2, - vec![int_type_var_0, int_type_var_1.clone()], - vec![int_type_var_1], + vec![int_type_var(0), int_type_var(1).clone()], + vec![int_type_var(1)], temp_reg, ) } -fn ish_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype( - 2, - vec![int_type_var_0.clone(), int_type_var_1], - vec![int_type_var_0], - temp_reg, - ) +fn ish_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)], temp_reg) } /// Extension for basic integer operations. @@ -188,13 +138,13 @@ pub fn extension() -> Extension { EXTENSION_ID, ExtensionSet::singleton(&super::int_types::EXTENSION_ID), ); - let int_types_extension = super::int_types::extension(); - let int_type_def = int_types_extension.get_type(&INT_TYPE_ID).unwrap(); - let int_type_var_0 = int_type_var(0, int_type_def).unwrap(); - let int_type_var_1 = int_type_var(1, int_type_def).unwrap(); - let temp_reg: ExtensionRegistry = - [extension.clone(), int_types_extension, PRELUDE.to_owned()].into(); + let temp_reg: ExtensionRegistry = [ + extension.clone(), + super::int_types::EXTENSION.to_owned(), + PRELUDE.to_owned(), + ] + .into(); extension .add_op_custom_sig_simple( @@ -233,140 +183,140 @@ pub fn extension() -> Extension { .add_op_type_scheme_simple( "itobool".into(), "convert to bool (1 is true, 0 is false)".to_owned(), - itob_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + itob_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ifrombool".into(), "convert from bool (1 is true, 0 is false)".to_owned(), - btoi_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + btoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ieq".into(), "equality test".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ine".into(), "inequality test".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ilt_u".into(), "\"less than\" as unsigned integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ilt_s".into(), "\"less than\" as signed integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "igt_u".into(), "\"greater than\" as unsigned integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "igt_s".into(), "\"greater than\" as signed integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ile_u".into(), "\"less than or equal\" as unsigned integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ile_s".into(), "\"less than or equal\" as signed integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ige_u".into(), "\"greater than or equal\" as unsigned integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ige_s".into(), "\"greater than or equal\" as signed integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imax_u".into(), "maximum of unsigned integers".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imax_s".into(), "maximum of signed integers".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imin_u".into(), "minimum of unsigned integers".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imin_s".into(), "minimum of signed integers".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "isub".into(), "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ineg".into(), "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - iunop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + iunop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imul".into(), "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension @@ -375,7 +325,7 @@ pub fn extension() -> Extension { "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { "shift first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits dropped, rightmost bits set to zero" .to_owned(), - ish_sig(int_type_var_0.clone(), int_type_var_1.clone(), &temp_reg).unwrap(), + ish_sig(&temp_reg).unwrap(), ) .unwrap(); extension @@ -511,7 +461,7 @@ pub fn extension() -> Extension { "shift first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits dropped, leftmost bits set to zero)" .to_owned(), - ish_sig(int_type_var_0.clone(), int_type_var_1.clone(), &temp_reg).unwrap(), + ish_sig(&temp_reg).unwrap(), ) .unwrap(); extension @@ -520,7 +470,7 @@ pub fn extension() -> Extension { "rotate first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits replace rightmost bits)" .to_owned(), - ish_sig(int_type_var_0.clone(), int_type_var_1.clone(), &temp_reg).unwrap(), + ish_sig(&temp_reg).unwrap(), ) .unwrap(); extension @@ -529,7 +479,7 @@ pub fn extension() -> Extension { "rotate first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits replace leftmost bits)" .to_owned(), - ish_sig(int_type_var_0.clone(), int_type_var_1.clone(), &temp_reg).unwrap(), + ish_sig( &temp_reg).unwrap(), ) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 526ea9d67..7a67de28a 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -5,7 +5,7 @@ use std::num::NonZeroU64; use smol_str::SmolStr; use crate::{ - extension::{ExtensionId, SignatureError, TypeDef}, + extension::ExtensionId, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, @@ -198,11 +198,20 @@ pub fn extension() -> Extension { extension } +lazy_static! { + /// Lazy reference to int types extension. + pub static ref EXTENSION: Extension = extension(); +} + /// get an integer type variable, given the integer type definition -pub(super) fn int_type_var(var_id: usize, int_type_def: &TypeDef) -> Result { - Ok(Type::new_extension(int_type_def.instantiate(vec![ - TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM), - ])?)) +pub(super) fn int_type_var(var_id: usize) -> Type { + Type::new_extension( + EXTENSION + .get_type(&INT_TYPE_ID) + .unwrap() + .instantiate(vec![TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM)]) + .unwrap(), + ) } #[cfg(test)] mod test {