diff --git a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index 445317912..793621349 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -1,9 +1,13 @@ +use std::iter::once; + use anyhow::{anyhow, bail, ensure, Ok, Result}; use hugr_core::extension::prelude::{ERROR_TYPE_NAME, STRING_TYPE_NAME}; use hugr_core::{ extension::{ prelude::{ - self, error_type, ArrayOp, ArrayOpDef, ConstError, ConstExternalSymbol, ConstString, + self, + array::{ArrayRepeat, ArrayScan}, + error_type, ArrayOp, ArrayOpDef, ConstError, ConstExternalSymbol, ConstString, ConstUsize, MakeTuple, TupleOpDef, UnpackTuple, ARRAY_TYPE_NAME, }, simple_op::MakeExtensionOp as _, @@ -89,6 +93,30 @@ pub trait PreludeCodegen: Clone { array::emit_array_op(self, ctx, op, inputs, outputs) } + /// Emit a [hugr_core::extension::prelude::array::ArrayRepeat] op. + fn emit_array_repeat<'c, H: HugrView>( + &self, + ctx: &mut EmitFuncContext<'c, '_, H>, + op: ArrayRepeat, + func: BasicValueEnum<'c>, + ) -> Result> { + array::emit_repeat_op(ctx, op, func) + } + + /// Emit a [hugr_core::extension::prelude::array::ArrayScan] op. + /// + /// Returns the resulting array and the final values of the accumulators. + fn emit_array_scan<'c, H: HugrView>( + &self, + ctx: &mut EmitFuncContext<'c, '_, H>, + op: ArrayScan, + src_array: BasicValueEnum<'c>, + func: BasicValueEnum<'c>, + initial_accs: &[BasicValueEnum<'c>], + ) -> Result<(BasicValueEnum<'c>, Vec>)> { + array::emit_scan_op(ctx, op, src_array, func, initial_accs) + } + /// Emit a [hugr_core::extension::prelude::PRINT_OP_ID] node. fn emit_print( &self, @@ -312,6 +340,28 @@ fn add_prelude_extensions<'a, H: HugrView + 'a>( ) } }) + .extension_op(prelude::PRELUDE_ID, prelude::array::ARRAY_REPEAT_OP_ID, { + let pcg = pcg.clone(); + move |context, args| { + let func = args.inputs[0]; + let op = ArrayRepeat::from_extension_op(args.node().as_ref())?; + let arr = pcg.emit_array_repeat(context, op, func)?; + args.outputs.finish(context.builder(), [arr]) + } + }) + .extension_op(prelude::PRELUDE_ID, prelude::array::ARRAY_SCAN_OP_ID, { + let pcg = pcg.clone(); + move |context, args| { + let src_array = args.inputs[0]; + let func = args.inputs[1]; + let initial_accs = &args.inputs[2..]; + let op = ArrayScan::from_extension_op(args.node().as_ref())?; + let (tgt_array, final_accs) = + pcg.emit_array_scan(context, op, src_array, func, initial_accs)?; + args.outputs + .finish(context.builder(), once(tgt_array).chain(final_accs)) + } + }) .extension_op(prelude::PRELUDE_ID, prelude::PRINT_OP_ID, { let pcg = pcg.clone(); move |context, args| { diff --git a/hugr-llvm/src/extension/prelude/array.rs b/hugr-llvm/src/extension/prelude/array.rs index fa33b7407..a2037a0f5 100644 --- a/hugr-llvm/src/extension/prelude/array.rs +++ b/hugr-llvm/src/extension/prelude/array.rs @@ -2,7 +2,10 @@ use anyhow::{anyhow, Ok, Result}; use hugr_core::{ extension::{ - prelude::{array_type, option_type, ArrayOp, ArrayOpDef}, + prelude::{ + array::{ArrayRepeat, ArrayScan}, + array_type, option_type, ArrayOp, ArrayOpDef, + }, simple_op::MakeRegisteredOp, }, ops::DataflowOpTrait as _, @@ -12,29 +15,28 @@ use hugr_core::{ use inkwell::{ builder::{Builder, BuilderError}, types::BasicType, - values::{ArrayValue, BasicValue as _, BasicValueEnum, IntValue, PointerValue}, + values::{ArrayValue, BasicValue as _, BasicValueEnum, CallableValue, IntValue, PointerValue}, IntPredicate, }; +use itertools::Itertools; use crate::{ - emit::{EmitFuncContext, RowPromise}, + emit::{deaggregate_call_result, EmitFuncContext, RowPromise}, sum::LLVMSumType, types::{HugrType, TypingSession}, }; use super::PreludeCodegen; -/// Helper function to allocate an array on the stack and pass a pointer to it -/// to a closure. +/// Helper function to allocate an array on the stack. /// -/// The pointer forwarded to the closure is a pointer to the first element of -/// the array. I.e. it is of type `array.get_element_type().ptr_type()` not -/// `array.ptr_type()` -fn with_array_alloca<'c, T, E: From>( +/// Returns two pointers: The first one is a pointer to the first element of the +/// array (i.e. it is of type `array.get_element_type().ptr_type()`) whereas the +/// second one points to the whole array value, i.e. it is of type `array.ptr_type()`. +fn build_array_alloca<'c>( builder: &Builder<'c>, array: ArrayValue<'c>, - go: impl FnOnce(PointerValue<'c>) -> Result, -) -> Result { +) -> Result<(PointerValue<'c>, PointerValue<'c>), BuilderError> { let array_ty = array.get_type(); let array_len: IntValue<'c> = { let ctx = builder.get_insert_block().unwrap().get_context(); @@ -45,9 +47,66 @@ fn with_array_alloca<'c, T, E: From>( .build_bit_cast(ptr, array_ty.ptr_type(Default::default()), "")? .into_pointer_value(); builder.build_store(array_ptr, array)?; + Result::Ok((ptr, array_ptr)) +} + +/// Helper function to allocate an array on the stack and pass a pointer to it +/// to a closure. +/// +/// The pointer forwarded to the closure is a pointer to the first element of +/// the array. I.e. it is of type `array.get_element_type().ptr_type()` not +/// `array.ptr_type()` +fn with_array_alloca<'c, T, E: From>( + builder: &Builder<'c>, + array: ArrayValue<'c>, + go: impl FnOnce(PointerValue<'c>) -> Result, +) -> Result { + let (ptr, _) = build_array_alloca(builder, array)?; go(ptr) } +/// Helper function to build a loop that repeats for a given number of iterations. +/// +/// The provided closure is called to build the loop body. Afterwards, the builder is positioned at +/// the end of the loop exit block. +fn build_loop<'c, T, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + iters: IntValue<'c>, + go: impl FnOnce(&mut EmitFuncContext<'c, '_, H>, IntValue<'c>) -> Result, +) -> Result { + let builder = ctx.builder(); + let idx_ty = ctx.iw_context().i32_type(); + let idx_ptr = builder.build_alloca(idx_ty, "")?; + builder.build_store(idx_ptr, idx_ty.const_zero())?; + + let exit_block = ctx.new_basic_block("", None); + + let (body_block, val) = ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| { + let idx = ctx.builder().build_load(idx_ptr, "")?.into_int_value(); + let val = go(ctx, idx)?; + let builder = ctx.builder(); + let inc_idx = builder.build_int_add(idx, idx_ty.const_int(1, false), "")?; + builder.build_store(idx_ptr, inc_idx)?; + // Branch to the head is built later + Ok((bb, val)) + })?; + + let head_block = ctx.build_positioned_new_block("", Some(body_block), |ctx, bb| { + let builder = ctx.builder(); + let idx = builder.build_load(idx_ptr, "")?.into_int_value(); + let cmp = builder.build_int_compare(IntPredicate::ULT, idx, iters, "")?; + builder.build_conditional_branch(cmp, body_block, exit_block)?; + Ok(bb) + })?; + + let builder = ctx.builder(); + builder.build_unconditional_branch(head_block)?; + builder.position_at_end(body_block); + builder.build_unconditional_branch(head_block)?; + ctx.builder().position_at_end(exit_block); + Ok(val) +} + pub fn emit_array_op<'c, H: HugrView>( pcg: &impl PreludeCodegen, ctx: &mut EmitFuncContext<'c, '_, H>, @@ -384,13 +443,103 @@ fn emit_pop_op<'c>( ret_ty.build_tag(builder, 1, vec![elem_v, array_v]) } +/// Emits an [ArrayRepeat] op. +pub fn emit_repeat_op<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + op: ArrayRepeat, + func: BasicValueEnum<'c>, +) -> Result> { + let builder = ctx.builder(); + let array_len = ctx.iw_context().i32_type().const_int(op.size, false); + let array_ty = ctx.llvm_type(&op.elem_ty)?.array_type(op.size as u32); + let (ptr, array_ptr) = build_array_alloca(builder, array_ty.get_undef())?; + build_loop(ctx, array_len, |ctx, idx| { + let builder = ctx.builder(); + let func_ptr = CallableValue::try_from(func.into_pointer_value()) + .map_err(|_| anyhow!("ArrayOpDef::repeat expects a function pointer"))?; + let v = builder + .build_call(func_ptr, &[], "")? + .try_as_basic_value() + .left() + .ok_or(anyhow!("ArrayOpDef::repeat function must return a value"))?; + let elem_addr = unsafe { builder.build_in_bounds_gep(ptr, &[idx], "")? }; + builder.build_store(elem_addr, v)?; + Ok(()) + })?; + + let builder = ctx.builder(); + let array_v = builder.build_load(array_ptr, "")?; + Ok(array_v) +} + +/// Emits an [ArrayScan] op. +/// +/// Returns the resulting array and the final values of the accumulators. +pub fn emit_scan_op<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + op: ArrayScan, + src_array: BasicValueEnum<'c>, + func: BasicValueEnum<'c>, + initial_accs: &[BasicValueEnum<'c>], +) -> Result<(BasicValueEnum<'c>, Vec>)> { + let builder = ctx.builder(); + let ts = ctx.typing_session(); + let array_len = ctx.iw_context().i32_type().const_int(op.size, false); + let tgt_array_ty = ts.llvm_type(&op.tgt_ty)?.array_type(op.size as u32); + let (src_ptr, _) = build_array_alloca(builder, src_array.into_array_value())?; + let (tgt_ptr, tgt_array_ptr) = build_array_alloca(builder, tgt_array_ty.get_undef())?; + + let acc_tys: Vec<_> = op.acc_tys.iter().map(|ty| ts.llvm_type(ty)).try_collect()?; + let acc_ptrs: Vec<_> = acc_tys + .iter() + .map(|ty| builder.build_alloca(*ty, "")) + .try_collect()?; + for (ptr, initial_val) in acc_ptrs.iter().zip(initial_accs) { + builder.build_store(*ptr, *initial_val)?; + } + + build_loop(ctx, array_len, |ctx, idx| { + let builder = ctx.builder(); + let func_ptr = CallableValue::try_from(func.into_pointer_value()) + .map_err(|_| anyhow!("ArrayOpDef::scan expects a function pointer"))?; + let src_elem_addr = unsafe { builder.build_in_bounds_gep(src_ptr, &[idx], "")? }; + let src_elem = builder.build_load(src_elem_addr, "")?; + let mut args = vec![src_elem.into()]; + for ptr in acc_ptrs.iter() { + args.push(builder.build_load(*ptr, "")?.into()); + } + let call = builder.build_call(func_ptr, args.as_slice(), "")?; + let call_results = deaggregate_call_result(builder, call, 1 + acc_tys.len())?; + let tgt_elem_addr = unsafe { builder.build_in_bounds_gep(tgt_ptr, &[idx], "")? }; + builder.build_store(tgt_elem_addr, call_results[0])?; + for (ptr, next_act) in acc_ptrs.iter().zip(call_results[1..].iter()) { + builder.build_store(*ptr, *next_act)?; + } + Ok(()) + })?; + + let builder = ctx.builder(); + let tgt_array_v = builder.build_load(tgt_array_ptr, "")?; + let final_accs = acc_ptrs + .into_iter() + .map(|ptr| builder.build_load(ptr, "")) + .try_collect()?; + Ok((tgt_array_v, final_accs)) +} + #[cfg(test)] mod test { + use hugr_core::builder::Container as _; + use hugr_core::extension::prelude::array::ArrayRepeat; + use hugr_core::extension::ExtensionSet; + use hugr_core::ops::Tag; + use hugr_core::types::Type; use hugr_core::{ builder::{Dataflow, DataflowSubContainer, SubContainer}, extension::{ prelude::{ - self, array_type, bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder as _, + self, array::ArrayScan, array_type, bool_t, option_type, usize_t, ConstUsize, + UnwrapBuilder as _, }, ExtensionRegistry, }, @@ -403,6 +552,7 @@ mod test { logic, }, type_row, + types::Signature, }; use itertools::Itertools as _; use rstest::rstest; @@ -454,6 +604,15 @@ mod test { .unwrap() } + fn exec_extension_set() -> ExtensionSet { + ExtensionSet::from_iter([ + int_types::EXTENSION_ID, + int_ops::EXTENSION_ID, + logic::EXTENSION_ID, + prelude::PRELUDE_ID, + ]) + } + #[rstest] #[case(0, 1)] #[case(1, 2)] @@ -779,4 +938,190 @@ mod test { exec_ctx.add_extensions(|cge| cge.add_default_prelude_extensions().add_int_extensions()); assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main")); } + + #[rstest] + #[case(5, 42, 0)] + #[case(5, 42, 1)] + #[case(5, 42, 2)] + #[case(5, 42, 3)] + #[case(5, 42, 4)] + fn exec_repeat( + mut exec_ctx: TestContext, + #[case] size: u64, + #[case] value: u64, + #[case] idx: u64, + ) { + // We build a HUGR that: + // - Contains a nested function that returns `value` + // - Creates an array of length `size` populated via this function + // - Looks up the value at `idx` and returns it + + let int_ty = int_type(6); + let hugr = SimpleHugrConfig::new() + .with_outs(int_ty.clone()) + .with_extensions(exec_registry()) + .finish_with_exts(|mut builder, reg| { + let mut func = builder + .define_function( + "foo", + Signature::new(vec![], vec![int_ty.clone()]) + .with_extension_delta(exec_extension_set()), + ) + .unwrap(); + let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); + let func_id = func.finish_with_outputs(vec![v]).unwrap(); + let func_v = builder + .load_func(func_id.handle(), &[], &exec_registry()) + .unwrap(); + let repeat = ArrayRepeat::new(int_ty.clone(), size, exec_extension_set()); + let arr = builder + .add_dataflow_op(repeat, vec![func_v]) + .unwrap() + .out_wire(0); + let idx_v = builder.add_load_value(ConstUsize::new(idx)); + let get_res = builder + .add_array_get(int_ty.clone(), size, arr, idx_v) + .unwrap(); + let [elem] = builder + .build_unwrap_sum(reg, 1, option_type(vec![int_ty.clone()]), get_res) + .unwrap(); + builder.finish_with_outputs([elem]).unwrap() + }); + exec_ctx.add_extensions(|cge| cge.add_default_prelude_extensions().add_int_extensions()); + assert_eq!(value, exec_ctx.exec_hugr_u64(hugr, "main")); + } + + #[rstest] + #[case(10, 1)] + #[case(10, 2)] + #[case(0, 1)] + fn exec_scan_map(mut exec_ctx: TestContext, #[case] size: u64, #[case] inc: u64) { + // We build a HUGR that: + // - Creates an array [1, 2, 3, ..., size] + // - Maps a function that increments each element by `inc` + // - Returns the sum of the array elements + let int_ty = int_type(6); + let hugr = SimpleHugrConfig::new() + .with_outs(int_ty.clone()) + .with_extensions(exec_registry()) + .finish_with_exts(|mut builder, reg| { + let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap()); + let new_array_args = (0..size) + .map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap())) + .collect_vec(); + let arr = builder + .add_new_array(int_ty.clone(), new_array_args) + .unwrap(); + + let mut func = builder + .define_function( + "foo", + Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]) + .with_extension_delta(exec_extension_set()), + ) + .unwrap(); + let [elem] = func.input_wires_arr(); + let delta = func.add_load_value(ConstInt::new_u(6, inc).unwrap()); + let out = func.add_iadd(6, elem, delta).unwrap(); + let func_id = func.finish_with_outputs(vec![out]).unwrap(); + let func_v = builder + .load_func(func_id.handle(), &[], &exec_registry()) + .unwrap(); + let scan = ArrayScan::new( + int_ty.clone(), + int_ty.clone(), + vec![], + size, + exec_extension_set(), + ); + let mut arr = builder + .add_dataflow_op(scan, [arr, func_v]) + .unwrap() + .out_wire(0); + + for i in 0..size { + let array_size = size - i; + let pop_res = builder + .add_array_pop_left(int_ty.clone(), array_size, arr) + .unwrap(); + let [elem, new_arr] = builder + .build_unwrap_sum( + reg, + 1, + option_type(vec![ + int_ty.clone(), + array_type(array_size - 1, int_ty.clone()), + ]), + pop_res, + ) + .unwrap(); + arr = new_arr; + r = builder.add_iadd(6, r, elem).unwrap(); + } + builder.finish_with_outputs([r]).unwrap() + }); + exec_ctx.add_extensions(|cge| cge.add_default_prelude_extensions().add_int_extensions()); + let expected: u64 = (inc..size + inc).sum(); + assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main")); + } + + #[rstest] + #[case(0)] + #[case(1)] + #[case(10)] + fn exec_scan_fold(mut exec_ctx: TestContext, #[case] size: u64) { + // We build a HUGR that: + // - Creates an array [1, 2, 3, ..., size] + // - Sums up the elements of the array using a scan and returns that sum + + let int_ty = int_type(6); + let hugr = SimpleHugrConfig::new() + .with_outs(int_ty.clone()) + .with_extensions(exec_registry()) + .finish_with_exts(|mut builder, _reg| { + let new_array_args = (0..size) + .map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap())) + .collect_vec(); + let arr = builder + .add_new_array(int_ty.clone(), new_array_args) + .unwrap(); + + let mut func = builder + .define_function( + "foo", + Signature::new( + vec![int_ty.clone(), int_ty.clone()], + vec![Type::UNIT, int_ty.clone()], + ) + .with_extension_delta(exec_extension_set()), + ) + .unwrap(); + let [elem, acc] = func.input_wires_arr(); + let acc = func.add_iadd(6, elem, acc).unwrap(); + let unit = func + .add_dataflow_op(Tag::new(0, vec![type_row![]]), []) + .unwrap() + .out_wire(0); + let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap(); + let func_v = builder + .load_func(func_id.handle(), &[], &exec_registry()) + .unwrap(); + let scan = ArrayScan::new( + int_ty.clone(), + Type::UNIT, + vec![int_ty.clone()], + size, + exec_extension_set(), + ); + let zero = builder.add_load_value(ConstInt::new_u(6, 0).unwrap()); + let sum = builder + .add_dataflow_op(scan, [arr, func_v, zero]) + .unwrap() + .out_wire(1); + builder.finish_with_outputs([sum]).unwrap() + }); + exec_ctx.add_extensions(|cge| cge.add_default_prelude_extensions().add_int_extensions()); + let expected: u64 = (0..size).sum(); + assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main")); + } }