Skip to content

Commit

Permalink
feat: Lower array repeat and scan ops (#1717)
Browse files Browse the repository at this point in the history
Closes #1681
  • Loading branch information
mark-koch authored Dec 10, 2024
1 parent 40e903f commit 685adaf
Show file tree
Hide file tree
Showing 2 changed files with 408 additions and 13 deletions.
52 changes: 51 additions & 1 deletion hugr-llvm/src/extension/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::iter::once;

use anyhow::{anyhow, bail, ensure, Ok, Result};
use hugr_core::extension::prelude::{ERROR_TYPE_NAME, STRING_TYPE_NAME};
use hugr_core::{
extension::{
prelude::{
self, error_type, ArrayOp, ArrayOpDef, ConstError, ConstExternalSymbol, ConstString,
self,
array::{ArrayRepeat, ArrayScan},
error_type, ArrayOp, ArrayOpDef, ConstError, ConstExternalSymbol, ConstString,
ConstUsize, MakeTuple, TupleOpDef, UnpackTuple, ARRAY_TYPE_NAME,
},
simple_op::MakeExtensionOp as _,
Expand Down Expand Up @@ -89,6 +93,30 @@ pub trait PreludeCodegen: Clone {
array::emit_array_op(self, ctx, op, inputs, outputs)
}

/// Emit a [hugr_core::extension::prelude::array::ArrayRepeat] op.
fn emit_array_repeat<'c, H: HugrView>(
&self,
ctx: &mut EmitFuncContext<'c, '_, H>,
op: ArrayRepeat,
func: BasicValueEnum<'c>,
) -> Result<BasicValueEnum<'c>> {
array::emit_repeat_op(ctx, op, func)
}

/// Emit a [hugr_core::extension::prelude::array::ArrayScan] op.
///
/// Returns the resulting array and the final values of the accumulators.
fn emit_array_scan<'c, H: HugrView>(
&self,
ctx: &mut EmitFuncContext<'c, '_, H>,
op: ArrayScan,
src_array: BasicValueEnum<'c>,
func: BasicValueEnum<'c>,
initial_accs: &[BasicValueEnum<'c>],
) -> Result<(BasicValueEnum<'c>, Vec<BasicValueEnum<'c>>)> {
array::emit_scan_op(ctx, op, src_array, func, initial_accs)
}

/// Emit a [hugr_core::extension::prelude::PRINT_OP_ID] node.
fn emit_print<H: HugrView>(
&self,
Expand Down Expand Up @@ -312,6 +340,28 @@ fn add_prelude_extensions<'a, H: HugrView + 'a>(
)
}
})
.extension_op(prelude::PRELUDE_ID, prelude::array::ARRAY_REPEAT_OP_ID, {
let pcg = pcg.clone();
move |context, args| {
let func = args.inputs[0];
let op = ArrayRepeat::from_extension_op(args.node().as_ref())?;
let arr = pcg.emit_array_repeat(context, op, func)?;
args.outputs.finish(context.builder(), [arr])
}
})
.extension_op(prelude::PRELUDE_ID, prelude::array::ARRAY_SCAN_OP_ID, {
let pcg = pcg.clone();
move |context, args| {
let src_array = args.inputs[0];
let func = args.inputs[1];
let initial_accs = &args.inputs[2..];
let op = ArrayScan::from_extension_op(args.node().as_ref())?;
let (tgt_array, final_accs) =
pcg.emit_array_scan(context, op, src_array, func, initial_accs)?;
args.outputs
.finish(context.builder(), once(tgt_array).chain(final_accs))
}
})
.extension_op(prelude::PRELUDE_ID, prelude::PRINT_OP_ID, {
let pcg = pcg.clone();
move |context, args| {
Expand Down
Loading

0 comments on commit 685adaf

Please sign in to comment.