From fffb696defe6ebb86161410172938051b9451c67 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 9 Nov 2023 13:57:25 +0000 Subject: [PATCH 01/15] feat: add "simple" version of `add_op_type_scheme` --- src/extension/op_def.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 6a6ba0f97..60916914a 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -331,6 +331,23 @@ impl Extension { SignatureFunc::TypeScheme(type_scheme), ) } + + /// Create an OpDef with a signature (inputs+outputs) read from e.g. + /// declarative YAML; and no "misc" or "lowering functions" defined. + pub fn add_op_type_scheme_simple( + &mut self, + name: SmolStr, + description: String, + type_scheme: PolyFuncType, + ) -> Result<&OpDef, ExtensionBuildError> { + self.add_op( + name, + description, + Default::default(), + vec![], + SignatureFunc::TypeScheme(type_scheme), + ) + } } #[cfg(test)] From 4ec9c756067c66dda4bb67ddab0fe196e08a1b92 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 13:34:16 +0000 Subject: [PATCH 02/15] refactor: use PolyFuncType for list definition --- src/std_extensions/collections.rs | 68 +++++++++++++++---------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 40b10f6a2..7db1225e2 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -5,10 +5,10 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}, + extension::{ExtensionId, ExtensionRegistry, TypeDef, TypeDefBound}, types::{ type_param::{TypeArg, TypeParam}, - CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow, + CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, }, values::{CustomConst, Value}, Extension, @@ -59,6 +59,7 @@ impl CustomConst for ListValue { crate::values::downcast_equal_consts(self, other) } } +const TP: TypeParam = TypeParam::Type(TypeBound::Any); fn extension() -> Extension { let mut extension = Extension::new(EXTENSION_NAME); @@ -66,43 +67,52 @@ fn extension() -> Extension { extension .add_type( LIST_TYPENAME, - vec![TypeParam::Type(TypeBound::Any)], + vec![TP], "Generic dynamically sized list of type T.".into(), TypeDefBound::FromParams(vec![0]), ) .unwrap(); + let temp_reg: ExtensionRegistry = [extension.clone()].into(); + let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); + + let (l, e) = list_types(list_type_def); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( POP_NAME, "Pop from back of list".into(), - vec![TypeParam::Type(TypeBound::Any)], - move |args: &[TypeArg]| { - let (list_type, element_type) = list_types(args)?; - Ok(FunctionType { - input: TypeRow::from(vec![list_type.clone()]), - output: TypeRow::from(vec![list_type, element_type]), - extension_reqs: ExtensionSet::singleton(&EXTENSION_NAME), - }) - }, + PolyFuncType::new_validated( + vec![TP], + FunctionType::new(vec![l.clone()], vec![l.clone(), e.clone()]), + &temp_reg, + ) + .unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( PUSH_NAME, "Push to back of list".into(), - vec![TypeParam::Type(TypeBound::Any)], - move |args: &[TypeArg]| { - let (list_type, element_type) = list_types(args)?; - Ok(FunctionType { - output: TypeRow::from(vec![list_type.clone()]), - input: TypeRow::from(vec![list_type, element_type]), - extension_reqs: ExtensionSet::singleton(&EXTENSION_NAME), - }) - }, + PolyFuncType::new_validated( + vec![TP], + FunctionType::new(vec![l.clone(), e], vec![l]), + &temp_reg, + ) + .unwrap(), ) .unwrap(); extension } + +fn list_types(list_type_def: &TypeDef) -> (Type, Type) { + let elem_type = Type::new_var_use(0, TypeBound::Any); + let list_type = Type::new_extension( + list_type_def + .instantiate(vec![TypeArg::new_var_use(0, TP)]) + .unwrap(), + ); + (list_type, elem_type) +} + lazy_static! { /// Collections extension definition. pub static ref EXTENSION: Extension = extension(); @@ -112,16 +122,6 @@ fn get_type(name: &str) -> &TypeDef { EXTENSION.get_type(name).unwrap() } -fn list_types(args: &[TypeArg]) -> Result<(Type, Type), SignatureError> { - let list_custom_type = get_type(&LIST_TYPENAME).instantiate(args)?; - let [TypeArg::Type { ty: element_type }] = args else { - panic!("should be checked by def.") - }; - - let list_type: Type = Type::new_extension(list_custom_type); - Ok((list_type, element_type.clone())) -} - #[cfg(test)] mod test { use crate::{ @@ -130,7 +130,7 @@ mod test { OpDef, PRELUDE, }, std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, - types::{type_param::TypeArg, Type}, + types::{type_param::TypeArg, Type, TypeRow}, Extension, }; From 7eb7bc8568e0d93405741dcd42a74f2211c0ce18 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 13:46:56 +0000 Subject: [PATCH 03/15] refactor: use type scheme for logic not --- src/std_extensions/logic.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 978c520c8..30fb317f1 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -34,11 +34,10 @@ fn extension() -> Extension { let mut extension = Extension::new(EXTENSION_ID); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( SmolStr::new_inline(NOT_NAME), "logical 'not'".into(), - vec![], - |_arg_values: &[TypeArg]| Ok(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])), + FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]).into(), ) .unwrap(); From 3f3e5ea98e77b1bab04164219fcd183ab5f2ac50 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 14:01:13 +0000 Subject: [PATCH 04/15] test: improve logic extension test coverage --- src/hugr/validate/test.rs | 10 ++++----- src/std_extensions/logic.rs | 42 ++++++++++++++++--------------------- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index a4a4dd1a2..2cf5d1812 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -13,7 +13,7 @@ use crate::macros::const_extension_ids; use crate::ops::dataflow::IOTrait; use crate::ops::{self, LeafOp, OpType}; use crate::std_extensions::logic; -use crate::std_extensions::logic::test::{and_op, not_op}; +use crate::std_extensions::logic::test::{and_op, not_op, or_op}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{CustomType, FunctionType, Type, TypeBound, TypeRow}; use crate::{type_row, Direction, IncomingPort, Node}; @@ -612,12 +612,12 @@ fn dfg_with_cycles() -> Result<(), HugrError> { type_row![BOOL_T], )); let [input, output] = h.get_io(h.root()).unwrap(); - let and = h.add_node_with_parent(h.root(), and_op())?; + let or = h.add_node_with_parent(h.root(), or_op())?; let not1 = h.add_node_with_parent(h.root(), not_op())?; let not2 = h.add_node_with_parent(h.root(), not_op())?; - h.connect(input, 0, and, 0)?; - h.connect(and, 0, not1, 0)?; - h.connect(not1, 0, and, 1)?; + h.connect(input, 0, or, 0)?; + h.connect(or, 0, not1, 0)?; + h.connect(not1, 0, or, 1)?; h.connect(input, 1, not2, 0)?; h.connect(not2, 0, output, 0)?; // The graph contains a cycle: diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 30fb317f1..d432fadad 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -7,7 +7,7 @@ use crate::{ extension::{prelude::BOOL_T, ExtensionId}, ops, type_row, types::{ - type_param::{TypeArg, TypeArgError, TypeParam}, + type_param::{TypeArg, TypeParam}, FunctionType, }, Extension, @@ -47,19 +47,12 @@ fn extension() -> Extension { "logical 'and'".into(), vec![H_INT], |arg_values: &[TypeArg]| { - let a = arg_values.iter().exactly_one().unwrap(); - let n: u64 = match a { - TypeArg::BoundedNat { n } => *n, - _ => { - return Err(TypeArgError::TypeMismatch { - arg: a.clone(), - param: H_INT, - } - .into()); - } + let TypeArg::BoundedNat { n } = arg_values.iter().exactly_one().unwrap() else { + panic!("should be covered by validation.") }; + Ok(FunctionType::new( - vec![BOOL_T; n as usize], + vec![BOOL_T; *n as usize], type_row![BOOL_T], )) }, @@ -72,19 +65,12 @@ fn extension() -> Extension { "logical 'or'".into(), vec![H_INT], |arg_values: &[TypeArg]| { - let a = arg_values.iter().exactly_one().unwrap(); - let n: u64 = match a { - TypeArg::BoundedNat { n } => *n, - _ => { - return Err(TypeArgError::TypeMismatch { - arg: a.clone(), - param: H_INT, - } - .into()); - } + let TypeArg::BoundedNat { n } = arg_values.iter().exactly_one().unwrap() else { + panic!("should be covered by validation.") }; + Ok(FunctionType::new( - vec![BOOL_T; n as usize], + vec![BOOL_T; *n as usize], type_row![BOOL_T], )) }, @@ -114,7 +100,7 @@ pub(crate) mod test { Extension, }; - use super::{extension, AND_NAME, EXTENSION, FALSE_NAME, NOT_NAME, TRUE_NAME}; + use super::{extension, AND_NAME, EXTENSION, FALSE_NAME, NOT_NAME, OR_NAME, TRUE_NAME}; #[test] fn test_logic_extension() { @@ -143,6 +129,14 @@ pub(crate) mod test { .into() } + /// Generate a logic extension and "or" operation over [`crate::prelude::BOOL_T`] + pub(crate) fn or_op() -> LeafOp { + EXTENSION + .instantiate_extension_op(OR_NAME, [TypeArg::BoundedNat { n: 2 }], &EMPTY_REG) + .unwrap() + .into() + } + /// Generate a logic extension and "not" operation over [`crate::prelude::BOOL_T`] pub(crate) fn not_op() -> LeafOp { EXTENSION From 8be903e239dab1c26a0468ba5c0d216e0dc9f962 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 14:04:11 +0000 Subject: [PATCH 05/15] test: improve float_types coverage --- src/std_extensions/arithmetic/float_types.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 5413b6183..32b7815ef 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -100,6 +100,11 @@ mod test { fn test_float_consts() { let const_f64_1 = ConstF64::new(1.0); let const_f64_2 = ConstF64::new(2.0); + + assert_eq!(const_f64_1.value(), 1.0); + assert_eq!(*const_f64_2, 2.0); + assert_eq!(const_f64_1.name(), "f64(1)"); + assert!(const_f64_1.equal_consts(&ConstF64::new(1.0))); assert_ne!(const_f64_1, const_f64_2); assert_eq!(const_f64_1, ConstF64::new(1.0)); } From f5f994b26ea6fc30e417cd04ecc5cb4956ba341d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 14:08:26 +0000 Subject: [PATCH 06/15] refactor: use PolyFuncType everywhere in float_ops --- src/std_extensions/arithmetic/float_ops.rs | 78 ++++++++-------------- 1 file changed, 28 insertions(+), 50 deletions(-) diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 5a8907300..fe131b5b2 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -1,9 +1,9 @@ //! Basic floating-point operations. use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError}, + extension::{ExtensionId, ExtensionSet}, type_row, - types::{type_param::TypeArg, FunctionType}, + types::{FunctionType, PolyFuncType}, Extension, }; @@ -12,25 +12,20 @@ use super::float_types::FLOAT64_TYPE; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); -fn fcmp_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( +fn fcmp_sig() -> PolyFuncType { + FunctionType::new( type_row![FLOAT64_TYPE; 2], type_row![crate::extension::prelude::BOOL_T], - )) + ) + .into() } -fn fbinop_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - type_row![FLOAT64_TYPE; 2], - type_row![FLOAT64_TYPE], - )) +fn fbinop_sig() -> PolyFuncType { + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into() } -fn funop_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - type_row![FLOAT64_TYPE], - type_row![FLOAT64_TYPE], - )) +fn funop_sig() -> PolyFuncType { + FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into() } /// Extension for basic arithmetic operations. @@ -41,77 +36,60 @@ pub fn extension() -> Extension { ); extension - .add_op_custom_sig_simple("feq".into(), "equality test".to_owned(), vec![], fcmp_sig) + .add_op_type_scheme_simple("feq".into(), "equality test".to_owned(), fcmp_sig()) .unwrap(); extension - .add_op_custom_sig_simple("fne".into(), "inequality test".to_owned(), vec![], fcmp_sig) + .add_op_type_scheme_simple("fne".into(), "inequality test".to_owned(), fcmp_sig()) .unwrap(); extension - .add_op_custom_sig_simple("flt".into(), "\"less than\"".to_owned(), vec![], fcmp_sig) + .add_op_type_scheme_simple("flt".into(), "\"less than\"".to_owned(), fcmp_sig()) .unwrap(); extension - .add_op_custom_sig_simple( - "fgt".into(), - "\"greater than\"".to_owned(), - vec![], - fcmp_sig, - ) + .add_op_type_scheme_simple("fgt".into(), "\"greater than\"".to_owned(), fcmp_sig()) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "fle".into(), "\"less than or equal\"".to_owned(), - vec![], - fcmp_sig, + fcmp_sig(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "fge".into(), "\"greater than or equal\"".to_owned(), - vec![], - fcmp_sig, + fcmp_sig(), ) .unwrap(); extension - .add_op_custom_sig_simple("fmax".into(), "maximum".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fmax".into(), "maximum".to_owned(), fbinop_sig()) .unwrap(); extension - .add_op_custom_sig_simple("fmin".into(), "minimum".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fmin".into(), "minimum".to_owned(), fbinop_sig()) .unwrap(); extension - .add_op_custom_sig_simple("fadd".into(), "addition".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fadd".into(), "addition".to_owned(), fbinop_sig()) .unwrap(); extension - .add_op_custom_sig_simple("fsub".into(), "subtraction".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fsub".into(), "subtraction".to_owned(), fbinop_sig()) .unwrap(); extension - .add_op_custom_sig_simple("fneg".into(), "negation".to_owned(), vec![], funop_sig) + .add_op_type_scheme_simple("fneg".into(), "negation".to_owned(), funop_sig()) .unwrap(); extension - .add_op_custom_sig_simple( - "fabs".into(), - "absolute value".to_owned(), - vec![], - funop_sig, - ) + .add_op_type_scheme_simple("fabs".into(), "absolute value".to_owned(), funop_sig()) .unwrap(); extension - .add_op_custom_sig_simple( - "fmul".into(), - "multiplication".to_owned(), - vec![], - fbinop_sig, - ) + .add_op_type_scheme_simple("fmul".into(), "multiplication".to_owned(), fbinop_sig()) .unwrap(); extension - .add_op_custom_sig_simple("fdiv".into(), "division".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fdiv".into(), "division".to_owned(), fbinop_sig()) .unwrap(); extension - .add_op_custom_sig_simple("ffloor".into(), "floor".to_owned(), vec![], funop_sig) + .add_op_type_scheme_simple("ffloor".into(), "floor".to_owned(), funop_sig()) .unwrap(); extension - .add_op_custom_sig_simple("fceil".into(), "ceiling".to_owned(), vec![], funop_sig) + .add_op_type_scheme_simple("fceil".into(), "ceiling".to_owned(), funop_sig()) .unwrap(); extension From 6355fa2e74c739419bc1eb5ace6c108e6747238e Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 14:23:45 +0000 Subject: [PATCH 07/15] test: improve int_ops coverage --- src/std_extensions/arithmetic/int_types.rs | 37 +++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 536be3774..dfa56d2e2 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -232,6 +232,7 @@ mod test { assert_ne!(const_u32_7, const_u64_7); assert_ne!(const_u32_7, const_u32_8); assert_eq!(const_u32_7, ConstIntU::new(5, 7)); + assert_matches!( ConstIntU::new(3, 256), Err(ConstTypeError::CustomCheckFail(_)) @@ -244,6 +245,40 @@ mod test { ConstIntS::new(3, 128), Err(ConstTypeError::CustomCheckFail(_)) ); - assert_matches!(ConstIntS::new(3, -128), Ok(_)); + assert!(ConstIntS::new(3, -128).is_ok()); + + let const_u32_7 = const_u32_7.unwrap(); + assert!(const_u32_7.equal_consts(&ConstIntU::new(5, 7).unwrap())); + assert_eq!(const_u32_7.log_width(), 5); + assert_eq!(const_u32_7.value(), 7); + assert!(const_u32_7 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 5 })) + .is_ok()); + assert!(const_u32_7 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 6 })) + .is_err()); + + assert_eq!(const_u32_7.name(), "u5(7)"); + assert!(const_u32_7 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 19 })) + .is_err()); + + let const_i32_2 = ConstIntS::new(5, -2).unwrap(); + assert!(const_i32_2.equal_consts(&ConstIntS::new(5, -2).unwrap())); + assert_eq!(const_i32_2.log_width(), 5); + assert_eq!(const_i32_2.value(), -2); + assert!(const_i32_2 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 5 })) + .is_ok()); + assert!(const_i32_2 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 6 })) + .is_err()); + assert!(const_i32_2 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 19 })) + .is_err()); + assert_eq!(const_i32_2.name(), "i5(-2)"); + + ConstIntS::new(50, -2).unwrap_err(); + ConstIntU::new(50, 2).unwrap_err(); } } From 76b1995cb43644d3090b8e5c86996e7a3be6b240 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 14:40:40 +0000 Subject: [PATCH 08/15] refactor: use PolyFuncType for conversions extension --- src/std_extensions/arithmetic/conversions.rs | 71 ++++++++++++-------- src/std_extensions/arithmetic/int_types.rs | 2 +- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index d07b97f62..77a591de8 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,36 +1,46 @@ //! Conversions between integer and floating-point values. use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError}, + extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, PRELUDE}, type_row, - types::{type_param::TypeArg, FunctionType, Type}, - utils::collect_array, + types::{type_param::TypeArg, FunctionType, PolyFuncType, Type}, Extension, }; -use super::int_types::int_type; +use super::int_types::INT_TYPE_ID; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -fn ftoi_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( +fn int_type_var(var_id: usize, int_type_def: &TypeDef) -> Result { + Ok(Type::new_extension(int_type_def.instantiate(vec![ + TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM), + ])?)) +} + +fn ftoi_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + let body = FunctionType::new( type_row![FLOAT64_TYPE], vec![Type::new_sum(vec![ - int_type(arg.clone()), + int_type_var, crate::extension::prelude::ERROR_TYPE, ])], - )) + ); + + PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } -fn itof_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone())], - type_row![FLOAT64_TYPE], - )) +fn itof_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + let body = FunctionType::new(vec![int_type_var], type_row![FLOAT64_TYPE]); + + PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } /// Extension for basic arithmetic operations. @@ -42,37 +52,42 @@ pub fn extension() -> Extension { super::float_types::EXTENSION_ID, ]), ); - + let int_types_extension = super::int_types::extension(); + let int_type_def = int_types_extension.get_type(&INT_TYPE_ID).unwrap(); + let int_type_var = int_type_var(0, int_type_def).unwrap(); + let temp_reg: ExtensionRegistry = [ + extension.clone(), + int_types_extension, + super::float_types::extension(), + PRELUDE.to_owned(), + ] + .into(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "trunc_u".into(), "float to unsigned int".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ftoi_sig, + ftoi_sig(int_type_var.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "trunc_s".into(), "float to signed int".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ftoi_sig, + ftoi_sig(int_type_var.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "convert_u".into(), "unsigned int to float".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - itof_sig, + itof_sig(int_type_var.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "convert_s".into(), "signed int to float".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - itof_sig, + itof_sig(int_type_var, &temp_reg).unwrap(), ) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index dfa56d2e2..0ad7c7ad6 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -18,7 +18,7 @@ use lazy_static::lazy_static; pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int.types"); /// Identifier for the integer type. -const INT_TYPE_ID: SmolStr = SmolStr::new_inline("int"); +pub const INT_TYPE_ID: SmolStr = SmolStr::new_inline("int"); fn int_custom_type(width_arg: TypeArg) -> CustomType { CustomType::new(INT_TYPE_ID, [width_arg], EXTENSION_ID, TypeBound::Eq) From 1508ead3be1cadc682cea3976f2477cf5e70a67b Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 15:15:40 +0000 Subject: [PATCH 09/15] refactor: use PolyFunctType to define int ops extension --- src/std_extensions/arithmetic/conversions.rs | 12 +- src/std_extensions/arithmetic/int_ops.rs | 427 ++++++++++--------- src/std_extensions/arithmetic/int_types.rs | 8 +- 3 files changed, 242 insertions(+), 205 deletions(-) diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 77a591de8..4ce3e7e19 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,24 +1,18 @@ //! Conversions between integer and floating-point values. use crate::{ - extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, PRELUDE}, + extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE}, type_row, - types::{type_param::TypeArg, FunctionType, PolyFuncType, Type}, + types::{FunctionType, PolyFuncType, Type}, Extension, }; -use super::int_types::INT_TYPE_ID; +use super::int_types::{int_type_var, INT_TYPE_ID}; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -fn int_type_var(var_id: usize, int_type_def: &TypeDef) -> Result { - Ok(Type::new_extension(int_type_def.instantiate(vec![ - TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM), - ])?)) -} - fn ftoi_sig( int_type_var: Type, temp_reg: &ExtensionRegistry, diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 4a6aae198..cac087d6a 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,9 +1,10 @@ //! Basic integer operations. -use super::int_types::{get_log_width, int_type, type_arg, LOG_WIDTH_TYPE_PARAM}; +use super::int_types::{get_log_width, int_type, int_type_var, INT_TYPE_ID, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{BOOL_T, ERROR_TYPE}; +use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::type_row; -use crate::types::FunctionType; +use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, @@ -40,100 +41,145 @@ fn inarrow_sig(arg_values: &[TypeArg]) -> Result { )) } -fn itob_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - vec![int_type(type_arg(0))], - type_row![BOOL_T], - )) +fn int_polytype( + n_vars: usize, + input: impl Into, + output: impl Into, + temp_reg: &ExtensionRegistry, +) -> Result { + PolyFuncType::new_validated( + vec![LOG_WIDTH_TYPE_PARAM; n_vars], + FunctionType::new(input, output), + temp_reg, + ) } -fn btoi_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - type_row![BOOL_T], - vec![int_type(type_arg(0))], - )) +fn itob_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype(1, vec![int_type_var], type_row![BOOL_T], temp_reg) } -fn icmp_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone()); 2], - type_row![BOOL_T], - )) +fn btoi_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype(1, type_row![BOOL_T], vec![int_type_var], temp_reg) } -fn ibinop_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone()); 2], - vec![int_type(arg.clone())], - )) +fn icmp_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype(1, vec![int_type_var; 2], type_row![BOOL_T], temp_reg) } -fn iunop_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone())], - vec![int_type(arg.clone())], - )) +fn ibinop_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 1, + vec![int_type_var.clone(); 2], + vec![int_type_var.clone()], + temp_reg, + ) } -fn idivmod_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into(); - Ok(FunctionType::new( +fn iunop_sig( + int_type_var: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype(1, vec![int_type_var.clone()], vec![int_type_var], temp_reg) +} + +fn idivmod_checked_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + let intpair: TypeRow = vec![int_type_var_0, int_type_var_1].into(); + int_polytype( + 2, intpair.clone(), vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])], - )) + temp_reg, + ) } -fn idivmod_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into(); - Ok(FunctionType::new( - intpair.clone(), - vec![Type::new_tuple(intpair)], - )) +fn idivmod_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + let intpair: TypeRow = vec![int_type_var_0, int_type_var_1].into(); + int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)], temp_reg) } -fn idiv_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![Type::new_sum(vec![int_type(arg0.clone()), ERROR_TYPE])], - )) +fn idiv_checked_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0.clone(), int_type_var_1], + vec![Type::new_sum(vec![int_type_var_0, ERROR_TYPE])], + temp_reg, + ) } -fn idiv_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg0.clone())], - )) +fn idiv_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0.clone(), int_type_var_1], + vec![int_type_var_0], + temp_reg, + ) } -fn imod_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])], - )) +fn imod_checked_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0, int_type_var_1.clone()], + vec![Type::new_sum(vec![int_type_var_1, ERROR_TYPE])], + temp_reg, + ) } -fn imod_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg1.clone())], - )) +fn imod_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0, int_type_var_1.clone()], + vec![int_type_var_1], + temp_reg, + ) } -fn ish_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg0.clone())], - )) +fn ish_sig( + int_type_var_0: Type, + int_type_var_1: Type, + temp_reg: &ExtensionRegistry, +) -> Result { + int_polytype( + 2, + vec![int_type_var_0.clone(), int_type_var_1], + vec![int_type_var_0], + temp_reg, + ) } /// Extension for basic integer operations. @@ -142,6 +188,13 @@ pub fn extension() -> Extension { EXTENSION_ID, ExtensionSet::singleton(&super::int_types::EXTENSION_ID), ); + let int_types_extension = super::int_types::extension(); + let int_type_def = int_types_extension.get_type(&INT_TYPE_ID).unwrap(); + let int_type_var_0 = int_type_var(0, int_type_def).unwrap(); + let int_type_var_1 = int_type_var(1, int_type_def).unwrap(); + + let temp_reg: ExtensionRegistry = + [extension.clone(), int_types_extension, PRELUDE.to_owned()].into(); extension .add_op_custom_sig_simple( @@ -177,347 +230,306 @@ pub fn extension() -> Extension { ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "itobool".into(), "convert to bool (1 is true, 0 is false)".to_owned(), - vec![], - itob_sig, + itob_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ifrombool".into(), "convert from bool (1 is true, 0 is false)".to_owned(), - vec![], - btoi_sig, + btoi_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ieq".into(), "equality test".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ine".into(), "inequality test".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ilt_u".into(), "\"less than\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ilt_s".into(), "\"less than\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "igt_u".into(), "\"greater than\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "igt_s".into(), "\"greater than\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ile_u".into(), "\"less than or equal\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ile_s".into(), "\"less than or equal\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ige_u".into(), "\"greater than or equal\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ige_s".into(), "\"greater than or equal\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imax_u".into(), "maximum of unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imax_s".into(), "maximum of signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imin_u".into(), "minimum of unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imin_s".into(), "minimum of signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "isub".into(), "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ineg".into(), "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - iunop_sig, + iunop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imul".into(), "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "idivmod_checked_u".into(), "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r TypeArg { + TypeArg::BoundedNat { n } + } + #[test] + fn test_binary_signatures() { + let sig = iwiden_sig(&[ta(3), ta(4)]).unwrap(); + assert_eq!( + sig, + FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) + ); + + iwiden_sig(&[ta(4), ta(3)]).unwrap_err(); + + let sig = inarrow_sig(&[ta(2), ta(1)]).unwrap(); + assert_eq!( + sig, + FunctionType::new( + vec![int_type(ta(2))], + vec![Type::new_sum(vec![int_type(ta(1)), ERROR_TYPE])], + ) + ); + + inarrow_sig(&[ta(1), ta(2)]).unwrap_err(); + } } diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 0ad7c7ad6..526ea9d67 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -5,7 +5,7 @@ use std::num::NonZeroU64; use smol_str::SmolStr; use crate::{ - extension::ExtensionId, + extension::{ExtensionId, SignatureError, TypeDef}, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, @@ -198,6 +198,12 @@ pub fn extension() -> Extension { extension } +/// get an integer type variable, given the integer type definition +pub(super) fn int_type_var(var_id: usize, int_type_def: &TypeDef) -> Result { + Ok(Type::new_extension(int_type_def.instantiate(vec![ + TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM), + ])?)) +} #[cfg(test)] mod test { use cool_asserts::assert_matches; From 09dbb1bf19bc6430197c27082bc1a36155dc7f8a Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Fri, 10 Nov 2023 15:25:40 +0000 Subject: [PATCH 10/15] refactor: use PolyFuncType for utils ops --- src/utils.rs | 41 ++++++++++++----------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/src/utils.rs b/src/utils.rs index de37f348c..9ce62ba37 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -34,23 +34,23 @@ pub(crate) mod test_quantum_extension { use crate::{ extension::{ prelude::{BOOL_T, QB_T}, - ExtensionId, ExtensionRegistry, SignatureError, PRELUDE, + ExtensionId, ExtensionRegistry, PRELUDE, }, ops::LeafOp, std_extensions::arithmetic::float_types::FLOAT64_TYPE, type_row, - types::{FunctionType, TypeArg}, + types::{FunctionType, PolyFuncType}, Extension, }; use lazy_static::lazy_static; - fn one_qb_func(_: &[TypeArg]) -> Result { - Ok(FunctionType::new_linear(type_row![QB_T])) + fn one_qb_func() -> PolyFuncType { + FunctionType::new_linear(type_row![QB_T]).into() } - fn two_qb_func(_: &[TypeArg]) -> Result { - Ok(FunctionType::new_linear(type_row![QB_T, QB_T])) + fn two_qb_func() -> PolyFuncType { + FunctionType::new_linear(type_row![QB_T, QB_T]).into() } /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum"); @@ -58,42 +58,25 @@ pub(crate) mod test_quantum_extension { let mut extension = Extension::new(EXTENSION_ID); extension - .add_op_custom_sig_simple( - SmolStr::new_inline("H"), - "Hadamard".into(), - vec![], - one_qb_func, - ) + .add_op_type_scheme_simple(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func()) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( SmolStr::new_inline("RzF64"), "Rotation specified by float".into(), - vec![], - |_: &[_]| { - Ok(FunctionType::new( - type_row![QB_T, FLOAT64_TYPE], - type_row![QB_T], - )) - }, + FunctionType::new(type_row![QB_T, FLOAT64_TYPE], type_row![QB_T]).into(), ) .unwrap(); extension - .add_op_custom_sig_simple(SmolStr::new_inline("CX"), "CX".into(), vec![], two_qb_func) + .add_op_type_scheme_simple(SmolStr::new_inline("CX"), "CX".into(), two_qb_func()) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( SmolStr::new_inline("Measure"), "Measure a qubit, returning the qubit and the measurement result.".into(), - vec![], - |_arg_values: &[TypeArg]| { - Ok(FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T])) - // TODO add logic as an extension delta when inference is - // done? - // https://github.com/CQCL-DEV/hugr/issues/425 - }, + FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T]).into(), ) .unwrap(); From 3ddd6ab9c727299d8dc36c3302297452d2037c2f Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 12:12:01 +0000 Subject: [PATCH 11/15] address some review comments mainly `sum_with_error` --- src/extension/prelude.rs | 5 +++++ src/std_extensions/arithmetic/conversions.rs | 13 +++++------- src/std_extensions/arithmetic/int_ops.rs | 15 ++++++-------- src/std_extensions/collections.rs | 21 ++++++++++---------- src/std_extensions/logic.rs | 4 ++-- 5 files changed, 28 insertions(+), 30 deletions(-) diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index 98038fc18..37ad705bc 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -136,6 +136,11 @@ pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); +/// Return a Sum type with the first variant as the given type and the second an Error. +pub fn sum_with_error(ty: Type) -> Type { + Type::new_sum(vec![ty, ERROR_TYPE]) +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// Structure for holding constant usize values. pub struct ConstUsize(u64); diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 4ce3e7e19..cbac81a85 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,7 +1,10 @@ //! Conversions between integer and floating-point values. use crate::{ - extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE}, + extension::{ + prelude::sum_with_error, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, + PRELUDE, + }, type_row, types::{FunctionType, PolyFuncType, Type}, Extension, @@ -17,13 +20,7 @@ fn ftoi_sig( int_type_var: Type, temp_reg: &ExtensionRegistry, ) -> Result { - let body = FunctionType::new( - type_row![FLOAT64_TYPE], - vec![Type::new_sum(vec![ - int_type_var, - crate::extension::prelude::ERROR_TYPE, - ])], - ); + let body = FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_type_var)]); PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index cac087d6a..18afe2655 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,7 +1,7 @@ //! Basic integer operations. use super::int_types::{get_log_width, int_type, int_type_var, INT_TYPE_ID, LOG_WIDTH_TYPE_PARAM}; -use crate::extension::prelude::{BOOL_T, ERROR_TYPE}; +use crate::extension::prelude::{sum_with_error, BOOL_T}; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; @@ -37,7 +37,7 @@ fn inarrow_sig(arg_values: &[TypeArg]) -> Result { } Ok(FunctionType::new( vec![int_type(arg0.clone())], - vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])], + vec![sum_with_error(int_type(arg1.clone()))], )) } @@ -103,7 +103,7 @@ fn idivmod_checked_sig( int_polytype( 2, intpair.clone(), - vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])], + vec![sum_with_error(Type::new_tuple(intpair))], temp_reg, ) } @@ -125,7 +125,7 @@ fn idiv_checked_sig( int_polytype( 2, vec![int_type_var_0.clone(), int_type_var_1], - vec![Type::new_sum(vec![int_type_var_0, ERROR_TYPE])], + vec![sum_with_error(int_type_var_0)], temp_reg, ) } @@ -151,7 +151,7 @@ fn imod_checked_sig( int_polytype( 2, vec![int_type_var_0, int_type_var_1.clone()], - vec![Type::new_sum(vec![int_type_var_1, ERROR_TYPE])], + vec![sum_with_error(int_type_var_1)], temp_reg, ) } @@ -566,10 +566,7 @@ mod test { let sig = inarrow_sig(&[ta(2), ta(1)]).unwrap(); assert_eq!( sig, - FunctionType::new( - vec![int_type(ta(2))], - vec![Type::new_sum(vec![int_type(ta(1)), ERROR_TYPE])], - ) + FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))],) ); inarrow_sig(&[ta(1), ta(2)]).unwrap_err(); diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 7db1225e2..cc830f930 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -75,7 +75,7 @@ fn extension() -> Extension { let temp_reg: ExtensionRegistry = [extension.clone()].into(); let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); - let (l, e) = list_types(list_type_def); + let (l, e) = list_and_elem_type(list_type_def); extension .add_op_type_scheme_simple( POP_NAME, @@ -103,16 +103,6 @@ fn extension() -> Extension { extension } -fn list_types(list_type_def: &TypeDef) -> (Type, Type) { - let elem_type = Type::new_var_use(0, TypeBound::Any); - let list_type = Type::new_extension( - list_type_def - .instantiate(vec![TypeArg::new_var_use(0, TP)]) - .unwrap(), - ); - (list_type, elem_type) -} - lazy_static! { /// Collections extension definition. pub static ref EXTENSION: Extension = extension(); @@ -122,6 +112,15 @@ fn get_type(name: &str) -> &TypeDef { EXTENSION.get_type(name).unwrap() } +fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) { + let elem_type = Type::new_var_use(0, TypeBound::Any); + let list_type = Type::new_extension( + list_type_def + .instantiate(vec![TypeArg::new_var_use(0, TP)]) + .unwrap(), + ); + (list_type, elem_type) +} #[cfg(test)] mod test { use crate::{ diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index d432fadad..794b98c9d 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -47,7 +47,7 @@ fn extension() -> Extension { "logical 'and'".into(), vec![H_INT], |arg_values: &[TypeArg]| { - let TypeArg::BoundedNat { n } = arg_values.iter().exactly_one().unwrap() else { + let Ok(TypeArg::BoundedNat { n }) = arg_values.iter().exactly_one() else { panic!("should be covered by validation.") }; @@ -65,7 +65,7 @@ fn extension() -> Extension { "logical 'or'".into(), vec![H_INT], |arg_values: &[TypeArg]| { - let TypeArg::BoundedNat { n } = arg_values.iter().exactly_one().unwrap() else { + let Ok(TypeArg::BoundedNat { n }) = arg_values.iter().exactly_one() else { panic!("should be covered by validation.") }; From c068c403068193060d1f98a6bd5480f468590f30 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 12:32:38 +0000 Subject: [PATCH 12/15] refactor: use `int_type_var(usize)` to simplify int_ops and conversions --- src/std_extensions/arithmetic/conversions.rs | 35 ++-- src/std_extensions/arithmetic/int_ops.rs | 204 +++++++------------ src/std_extensions/arithmetic/int_types.rs | 19 +- 3 files changed, 105 insertions(+), 153 deletions(-) diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index cbac81a85..63207c219 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -6,30 +6,27 @@ use crate::{ PRELUDE, }, type_row, - types::{FunctionType, PolyFuncType, Type}, + types::{FunctionType, PolyFuncType}, Extension, }; -use super::int_types::{int_type_var, INT_TYPE_ID}; +use super::int_types::int_type_var; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -fn ftoi_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - let body = FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_type_var)]); +fn ftoi_sig(temp_reg: &ExtensionRegistry) -> Result { + let body = FunctionType::new( + type_row![FLOAT64_TYPE], + vec![sum_with_error(int_type_var(0))], + ); PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } -fn itof_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - let body = FunctionType::new(vec![int_type_var], type_row![FLOAT64_TYPE]); +fn itof_sig(temp_reg: &ExtensionRegistry) -> Result { + let body = FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]); PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } @@ -43,12 +40,8 @@ pub fn extension() -> Extension { super::float_types::EXTENSION_ID, ]), ); - let int_types_extension = super::int_types::extension(); - let int_type_def = int_types_extension.get_type(&INT_TYPE_ID).unwrap(); - let int_type_var = int_type_var(0, int_type_def).unwrap(); let temp_reg: ExtensionRegistry = [ - extension.clone(), - int_types_extension, + super::int_types::EXTENSION.to_owned(), super::float_types::extension(), PRELUDE.to_owned(), ] @@ -57,28 +50,28 @@ pub fn extension() -> Extension { .add_op_type_scheme_simple( "trunc_u".into(), "float to unsigned int".to_owned(), - ftoi_sig(int_type_var.clone(), &temp_reg).unwrap(), + ftoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "trunc_s".into(), "float to signed int".to_owned(), - ftoi_sig(int_type_var.clone(), &temp_reg).unwrap(), + ftoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "convert_u".into(), "unsigned int to float".to_owned(), - itof_sig(int_type_var.clone(), &temp_reg).unwrap(), + itof_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "convert_s".into(), "signed int to float".to_owned(), - itof_sig(int_type_var, &temp_reg).unwrap(), + itof_sig(&temp_reg).unwrap(), ) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 18afe2655..6fdb8a4a3 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,6 +1,6 @@ //! Basic integer operations. -use super::int_types::{get_log_width, int_type, int_type_var, INT_TYPE_ID, LOG_WIDTH_TYPE_PARAM}; +use super::int_types::{get_log_width, int_type, int_type_var, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::type_row; @@ -54,52 +54,36 @@ fn int_polytype( ) } -fn itob_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype(1, vec![int_type_var], type_row![BOOL_T], temp_reg) +fn itob_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, vec![int_type_var(0)], type_row![BOOL_T], temp_reg) } -fn btoi_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype(1, type_row![BOOL_T], vec![int_type_var], temp_reg) +fn btoi_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, type_row![BOOL_T], vec![int_type_var(0)], temp_reg) } -fn icmp_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype(1, vec![int_type_var; 2], type_row![BOOL_T], temp_reg) +fn icmp_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, vec![int_type_var(0); 2], type_row![BOOL_T], temp_reg) } -fn ibinop_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn ibinop_sig(temp_reg: &ExtensionRegistry) -> Result { + let int_type_var = int_type_var(0); + int_polytype( 1, vec![int_type_var.clone(); 2], - vec![int_type_var.clone()], + vec![int_type_var], temp_reg, ) } -fn iunop_sig( - int_type_var: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn iunop_sig(temp_reg: &ExtensionRegistry) -> Result { + let int_type_var = int_type_var(0); int_polytype(1, vec![int_type_var.clone()], vec![int_type_var], temp_reg) } -fn idivmod_checked_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - let intpair: TypeRow = vec![int_type_var_0, int_type_var_1].into(); +fn idivmod_checked_sig(temp_reg: &ExtensionRegistry) -> Result { + let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); int_polytype( 2, intpair.clone(), @@ -108,78 +92,44 @@ fn idivmod_checked_sig( ) } -fn idivmod_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - let intpair: TypeRow = vec![int_type_var_0, int_type_var_1].into(); +fn idivmod_sig(temp_reg: &ExtensionRegistry) -> Result { + let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)], temp_reg) } -fn idiv_checked_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn idiv_checked_sig(temp_reg: &ExtensionRegistry) -> Result { int_polytype( 2, - vec![int_type_var_0.clone(), int_type_var_1], - vec![sum_with_error(int_type_var_0)], + vec![int_type_var(1)], + vec![sum_with_error(int_type_var(0))], temp_reg, ) } -fn idiv_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype( - 2, - vec![int_type_var_0.clone(), int_type_var_1], - vec![int_type_var_0], - temp_reg, - ) +fn idiv_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)], temp_reg) } -fn imod_checked_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn imod_checked_sig(temp_reg: &ExtensionRegistry) -> Result { int_polytype( 2, - vec![int_type_var_0, int_type_var_1.clone()], - vec![sum_with_error(int_type_var_1)], + vec![int_type_var(0), int_type_var(1).clone()], + vec![sum_with_error(int_type_var(1))], temp_reg, ) } -fn imod_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { +fn imod_sig(temp_reg: &ExtensionRegistry) -> Result { int_polytype( 2, - vec![int_type_var_0, int_type_var_1.clone()], - vec![int_type_var_1], + vec![int_type_var(0), int_type_var(1).clone()], + vec![int_type_var(1)], temp_reg, ) } -fn ish_sig( - int_type_var_0: Type, - int_type_var_1: Type, - temp_reg: &ExtensionRegistry, -) -> Result { - int_polytype( - 2, - vec![int_type_var_0.clone(), int_type_var_1], - vec![int_type_var_0], - temp_reg, - ) +fn ish_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)], temp_reg) } /// Extension for basic integer operations. @@ -188,13 +138,13 @@ pub fn extension() -> Extension { EXTENSION_ID, ExtensionSet::singleton(&super::int_types::EXTENSION_ID), ); - let int_types_extension = super::int_types::extension(); - let int_type_def = int_types_extension.get_type(&INT_TYPE_ID).unwrap(); - let int_type_var_0 = int_type_var(0, int_type_def).unwrap(); - let int_type_var_1 = int_type_var(1, int_type_def).unwrap(); - let temp_reg: ExtensionRegistry = - [extension.clone(), int_types_extension, PRELUDE.to_owned()].into(); + let temp_reg: ExtensionRegistry = [ + extension.clone(), + super::int_types::EXTENSION.to_owned(), + PRELUDE.to_owned(), + ] + .into(); extension .add_op_custom_sig_simple( @@ -233,140 +183,140 @@ pub fn extension() -> Extension { .add_op_type_scheme_simple( "itobool".into(), "convert to bool (1 is true, 0 is false)".to_owned(), - itob_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + itob_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ifrombool".into(), "convert from bool (1 is true, 0 is false)".to_owned(), - btoi_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + btoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ieq".into(), "equality test".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ine".into(), "inequality test".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ilt_u".into(), "\"less than\" as unsigned integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ilt_s".into(), "\"less than\" as signed integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "igt_u".into(), "\"greater than\" as unsigned integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "igt_s".into(), "\"greater than\" as signed integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ile_u".into(), "\"less than or equal\" as unsigned integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ile_s".into(), "\"less than or equal\" as signed integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ige_u".into(), "\"greater than or equal\" as unsigned integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ige_s".into(), "\"greater than or equal\" as signed integers".to_owned(), - icmp_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imax_u".into(), "maximum of unsigned integers".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imax_s".into(), "maximum of signed integers".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imin_u".into(), "minimum of unsigned integers".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imin_s".into(), "minimum of signed integers".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "isub".into(), "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "ineg".into(), "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - iunop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + iunop_sig(&temp_reg).unwrap(), ) .unwrap(); extension .add_op_type_scheme_simple( "imul".into(), "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(int_type_var_0.clone(), &temp_reg).unwrap(), + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension @@ -375,7 +325,7 @@ pub fn extension() -> Extension { "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { "shift first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits dropped, rightmost bits set to zero" .to_owned(), - ish_sig(int_type_var_0.clone(), int_type_var_1.clone(), &temp_reg).unwrap(), + ish_sig(&temp_reg).unwrap(), ) .unwrap(); extension @@ -511,7 +461,7 @@ pub fn extension() -> Extension { "shift first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits dropped, leftmost bits set to zero)" .to_owned(), - ish_sig(int_type_var_0.clone(), int_type_var_1.clone(), &temp_reg).unwrap(), + ish_sig(&temp_reg).unwrap(), ) .unwrap(); extension @@ -520,7 +470,7 @@ pub fn extension() -> Extension { "rotate first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits replace rightmost bits)" .to_owned(), - ish_sig(int_type_var_0.clone(), int_type_var_1.clone(), &temp_reg).unwrap(), + ish_sig(&temp_reg).unwrap(), ) .unwrap(); extension @@ -529,7 +479,7 @@ pub fn extension() -> Extension { "rotate first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits replace leftmost bits)" .to_owned(), - ish_sig(int_type_var_0.clone(), int_type_var_1.clone(), &temp_reg).unwrap(), + ish_sig( &temp_reg).unwrap(), ) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 526ea9d67..7a67de28a 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -5,7 +5,7 @@ use std::num::NonZeroU64; use smol_str::SmolStr; use crate::{ - extension::{ExtensionId, SignatureError, TypeDef}, + extension::ExtensionId, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, ConstTypeError, CustomCheckFailure, CustomType, Type, TypeBound, @@ -198,11 +198,20 @@ pub fn extension() -> Extension { extension } +lazy_static! { + /// Lazy reference to int types extension. + pub static ref EXTENSION: Extension = extension(); +} + /// get an integer type variable, given the integer type definition -pub(super) fn int_type_var(var_id: usize, int_type_def: &TypeDef) -> Result { - Ok(Type::new_extension(int_type_def.instantiate(vec![ - TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM), - ])?)) +pub(super) fn int_type_var(var_id: usize) -> Type { + Type::new_extension( + EXTENSION + .get_type(&INT_TYPE_ID) + .unwrap() + .instantiate(vec![TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM)]) + .unwrap(), + ) } #[cfg(test)] mod test { From aeb7ef6a68a899e57740676ff4c0eda2bb4dadb8 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 12:37:28 +0000 Subject: [PATCH 13/15] only call simple functions once --- src/std_extensions/arithmetic/float_ops.rs | 47 ++++++++++++++-------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index fe131b5b2..4e2da641d 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -35,61 +35,76 @@ pub fn extension() -> Extension { ExtensionSet::singleton(&super::float_types::EXTENSION_ID), ); + let fcmp_sig = fcmp_sig(); + let fbinop_sig = fbinop_sig(); + let funop_sig = funop_sig(); extension - .add_op_type_scheme_simple("feq".into(), "equality test".to_owned(), fcmp_sig()) + .add_op_type_scheme_simple("feq".into(), "equality test".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("fne".into(), "inequality test".to_owned(), fcmp_sig()) + .add_op_type_scheme_simple("fne".into(), "inequality test".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("flt".into(), "\"less than\"".to_owned(), fcmp_sig()) + .add_op_type_scheme_simple("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("fgt".into(), "\"greater than\"".to_owned(), fcmp_sig()) + .add_op_type_scheme_simple( + "fgt".into(), + "\"greater than\"".to_owned(), + fcmp_sig.clone(), + ) .unwrap(); extension .add_op_type_scheme_simple( "fle".into(), "\"less than or equal\"".to_owned(), - fcmp_sig(), + fcmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "fge".into(), "\"greater than or equal\"".to_owned(), - fcmp_sig(), + fcmp_sig, ) .unwrap(); extension - .add_op_type_scheme_simple("fmax".into(), "maximum".to_owned(), fbinop_sig()) + .add_op_type_scheme_simple("fmax".into(), "maximum".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("fmin".into(), "minimum".to_owned(), fbinop_sig()) + .add_op_type_scheme_simple("fmin".into(), "minimum".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("fadd".into(), "addition".to_owned(), fbinop_sig()) + .add_op_type_scheme_simple("fadd".into(), "addition".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("fsub".into(), "subtraction".to_owned(), fbinop_sig()) + .add_op_type_scheme_simple("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("fneg".into(), "negation".to_owned(), funop_sig()) + .add_op_type_scheme_simple("fneg".into(), "negation".to_owned(), funop_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("fabs".into(), "absolute value".to_owned(), funop_sig()) + .add_op_type_scheme_simple( + "fabs".into(), + "absolute value".to_owned(), + funop_sig.clone(), + ) .unwrap(); extension - .add_op_type_scheme_simple("fmul".into(), "multiplication".to_owned(), fbinop_sig()) + .add_op_type_scheme_simple( + "fmul".into(), + "multiplication".to_owned(), + fbinop_sig.clone(), + ) .unwrap(); extension - .add_op_type_scheme_simple("fdiv".into(), "division".to_owned(), fbinop_sig()) + .add_op_type_scheme_simple("fdiv".into(), "division".to_owned(), fbinop_sig) .unwrap(); extension - .add_op_type_scheme_simple("ffloor".into(), "floor".to_owned(), funop_sig()) + .add_op_type_scheme_simple("ffloor".into(), "floor".to_owned(), funop_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple("fceil".into(), "ceiling".to_owned(), funop_sig()) + .add_op_type_scheme_simple("fceil".into(), "ceiling".to_owned(), funop_sig) .unwrap(); extension From d93d5e655fafdcf40f7c503f2d33e5cbde94cb8f Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 12:46:21 +0000 Subject: [PATCH 14/15] remove another unecessary extension --- src/std_extensions/arithmetic/int_ops.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 6fdb8a4a3..cc09fc4a3 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -139,12 +139,8 @@ pub fn extension() -> Extension { ExtensionSet::singleton(&super::int_types::EXTENSION_ID), ); - let temp_reg: ExtensionRegistry = [ - extension.clone(), - super::int_types::EXTENSION.to_owned(), - PRELUDE.to_owned(), - ] - .into(); + let temp_reg: ExtensionRegistry = + [super::int_types::EXTENSION.to_owned(), PRELUDE.to_owned()].into(); extension .add_op_custom_sig_simple( From b5159e9c71d820d731142ffeaa1a5d5bc15a005d Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 13:34:39 +0000 Subject: [PATCH 15/15] inline single use functions --- src/std_extensions/arithmetic/float_ops.rs | 28 +++++++--------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 4e2da641d..7dc20f40d 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -12,22 +12,6 @@ use super::float_types::FLOAT64_TYPE; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); -fn fcmp_sig() -> PolyFuncType { - FunctionType::new( - type_row![FLOAT64_TYPE; 2], - type_row![crate::extension::prelude::BOOL_T], - ) - .into() -} - -fn fbinop_sig() -> PolyFuncType { - FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into() -} - -fn funop_sig() -> PolyFuncType { - FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into() -} - /// Extension for basic arithmetic operations. pub fn extension() -> Extension { let mut extension = Extension::new_with_reqs( @@ -35,9 +19,15 @@ pub fn extension() -> Extension { ExtensionSet::singleton(&super::float_types::EXTENSION_ID), ); - let fcmp_sig = fcmp_sig(); - let fbinop_sig = fbinop_sig(); - let funop_sig = funop_sig(); + let fcmp_sig: PolyFuncType = FunctionType::new( + type_row![FLOAT64_TYPE; 2], + type_row![crate::extension::prelude::BOOL_T], + ) + .into(); + let fbinop_sig: PolyFuncType = + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into(); + let funop_sig: PolyFuncType = + FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into(); extension .add_op_type_scheme_simple("feq".into(), "equality test".to_owned(), fcmp_sig.clone()) .unwrap();