From 89545e82aa83869374120a71ec45cc85f9bbb19b Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 11 Dec 2024 15:52:15 +0000 Subject: [PATCH] feat(llvm): Tail loop emission (#1749) Work in progress. Will address #1688 --------- Co-authored-by: Douglas Wilson Co-authored-by: Seyon Sivarajah Co-authored-by: Douglas Wilson <141026920+doug-q@users.noreply.github.com> --- hugr-llvm/src/emit/func.rs | 2 +- hugr-llvm/src/emit/func/mailbox.rs | 15 +- hugr-llvm/src/emit/ops.rs | 59 +++++- ...mit__test__test_fns__tail_loop@llvm14.snap | 61 ++++++ ...est_fns__tail_loop@pre-mem2reg@llvm14.snap | 129 ++++++++++++ ...t_fns__tail_loop@pre-mem2reg@llvm14_2.snap | 168 +++++++++++++++ ...st__test_fns__tail_loop_simple@llvm14.snap | 36 ++++ ...__tail_loop_simple@pre-mem2reg@llvm14.snap | 55 +++++ ...tail_loop_simple@pre-mem2reg@llvm14_3.snap | 55 +++++ hugr-llvm/src/emit/test.rs | 194 +++++++++++++++++- hugr-llvm/src/sum.rs | 2 +- 11 files changed, 767 insertions(+), 9 deletions(-) create mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@llvm14.snap create mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@pre-mem2reg@llvm14.snap create mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@pre-mem2reg@llvm14_2.snap create mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@llvm14.snap create mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@pre-mem2reg@llvm14.snap create mode 100644 hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@pre-mem2reg@llvm14_3.snap diff --git a/hugr-llvm/src/emit/func.rs b/hugr-llvm/src/emit/func.rs index 0a09dc9b3..0d954d55f 100644 --- a/hugr-llvm/src/emit/func.rs +++ b/hugr-llvm/src/emit/func.rs @@ -252,7 +252,7 @@ impl<'c, 'a, H: HugrView> EmitFuncContext<'c, 'a, H> { Ok(r) } - /// Returns a [RowMailBox] mapped to thie ouput wires of `node`. When emitting a node + /// Returns a [RowMailBox] mapped to the output wires of `node`. When emitting a node /// output values are written to this mailbox. pub fn node_outs_rmb<'hugr, OT: 'hugr>( &mut self, diff --git a/hugr-llvm/src/emit/func/mailbox.rs b/hugr-llvm/src/emit/func/mailbox.rs index a5c8b2e7b..5dc25fb29 100644 --- a/hugr-llvm/src/emit/func/mailbox.rs +++ b/hugr-llvm/src/emit/func/mailbox.rs @@ -1,6 +1,6 @@ use std::{borrow::Cow, rc::Rc}; -use anyhow::Result; +use anyhow::{bail, Result}; use delegate::delegate; use inkwell::{ builder::Builder, @@ -148,6 +148,19 @@ impl<'c> RowMailBox<'c> { builder: &Builder<'c>, vs: impl IntoIterator>, ) -> Result<()> { + let vs = vs.into_iter().collect_vec(); + #[cfg(debug_assertions)] + { + let actual_types = vs.clone().into_iter().map(|x| x.get_type()).collect_vec(); + let expected_types = self.get_types().collect_vec(); + if actual_types != expected_types { + bail!( + "RowMailbox::write: Expected types {:?}, got {:?}", + expected_types, + actual_types + ); + } + } zip_eq(self.0.iter(), vs).try_for_each(|(mb, v)| mb.write(builder, v)) } diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 5fdbd5a76..3989ecd4b 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, bail, Result}; use hugr_core::ops::{ constant::Sum, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant, - LoadFunction, OpTag, OpTrait, OpType, Output, Tag, Value, CFG, + LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, }; use hugr_core::{ hugr::views::SiblingGraph, @@ -58,7 +58,7 @@ where .ok_or(anyhow!("DataflowParentEmitter: Output taken twice")) } - pub fn emit_children(mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> { + pub fn emit_children(&mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> { use petgraph::visit::Topo; let node = self.node; if !OpTag::DataflowParent.is_superset(node.tag()) { @@ -306,6 +306,59 @@ fn emit_cfg<'c, H: HugrView>( cfg::CfgEmitter::new(context, args)?.emit_children(context) } +fn emit_tail_loop<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, '_, H>, + args: EmitOpArgs<'c, '_, TailLoop, H>, +) -> Result<()> { + let node = args.node(); + + // Make a block to jump to when we `Break` + let out_bb = context.new_basic_block("loop_out", None); + // A block for the body of the loop + let body_bb = context.new_basic_block("loop_body", Some(out_bb)); + + let (body_i_node, body_o_node) = node.get_io().unwrap(); + let body_i_rmb = context.node_outs_rmb(body_i_node)?; + let body_o_rmb = context.node_ins_rmb(body_o_node)?; + + body_i_rmb.write(context.builder(), args.inputs)?; + context.builder().build_unconditional_branch(body_bb)?; + + let control_llvm_sum_type = { + let sum_ty = SumType::new([node.just_inputs.clone(), node.just_outputs.clone()]); + context.llvm_sum_type(sum_ty)? + }; + + context.build_positioned(body_bb, move |context| { + let inputs = body_i_rmb.read_vec(context.builder(), [])?; + emit_dataflow_parent( + context, + EmitOpArgs { + node, + inputs, + outputs: body_o_rmb.promise(), + }, + )?; + let dataflow_outputs = body_o_rmb.read_vec(context.builder(), [])?; + let control_val = LLVMSumValue::try_new(dataflow_outputs[0], control_llvm_sum_type)?; + let mut outputs = Some(args.outputs); + + control_val.build_destructure(context.builder(), |builder, tag, mut values| { + values.extend(dataflow_outputs[1..].iter().copied()); + if tag == 0 { + body_i_rmb.write(builder, values)?; + builder.build_unconditional_branch(body_bb)?; + } else { + outputs.take().unwrap().finish(builder, values)?; + builder.build_unconditional_branch(out_bb)?; + } + Ok(()) + }) + })?; + context.builder().position_at_end(out_bb); + Ok(()) +} + fn emit_optype<'c, H: HugrView>( context: &mut EmitFuncContext<'c, '_, H>, args: EmitOpArgs<'c, '_, OpType, H>, @@ -330,7 +383,7 @@ fn emit_optype<'c, H: HugrView>( context.push_todo_func(node.into_ot(fd)); Ok(()) } - + OpType::TailLoop(x) => emit_tail_loop(context, args.into_ot(x)), _ => Err(anyhow!("Invalid child for Dataflow Parent: {node}")), } } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@llvm14.snap new file mode 100644 index 000000000..7765ae0a2 --- /dev/null +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@llvm14.snap @@ -0,0 +1,61 @@ +--- +source: hugr-llvm/src/emit/test.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i64 @_hl.main.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + br label %loop_body + +loop_body: ; preds = %12, %entry_block + %"9_0.0" = phi i64 [ 3, %entry_block ], [ %14, %12 ] + %"9_1.0" = phi i64 [ 7, %entry_block ], [ %0, %12 ] + %0 = mul i64 %"9_1.0", 2 + %1 = icmp eq i64 %"9_0.0", 0 + %2 = select i1 %1, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison } + %3 = extractvalue { i32, {}, {} } %2, 0 + switch i32 %3, label %4 [ + i32 1, label %6 + ] + +4: ; preds = %loop_body + %5 = extractvalue { i32, {}, {} } %2, 1 + br label %cond_17_case_0 + +6: ; preds = %loop_body + %7 = extractvalue { i32, {}, {} } %2, 2 + br label %cond_17_case_1 + +loop_out: ; preds = %15 + ret i64 %0 + +cond_17_case_0: ; preds = %4 + %8 = sub i64 %"9_0.0", 1 + %9 = insertvalue { i64 } undef, i64 %8, 0 + %10 = insertvalue { i32, { i64 }, {} } { i32 0, { i64 } poison, {} poison }, { i64 } %9, 1 + br label %cond_exit_17 + +cond_17_case_1: ; preds = %6 + br label %cond_exit_17 + +cond_exit_17: ; preds = %cond_17_case_1, %cond_17_case_0 + %"011.0" = phi { i32, { i64 }, {} } [ %10, %cond_17_case_0 ], [ { i32 1, { i64 } poison, {} undef }, %cond_17_case_1 ] + %11 = extractvalue { i32, { i64 }, {} } %"011.0", 0 + switch i32 %11, label %12 [ + i32 1, label %15 + ] + +12: ; preds = %cond_exit_17 + %13 = extractvalue { i32, { i64 }, {} } %"011.0", 1 + %14 = extractvalue { i64 } %13, 0 + br label %loop_body + +15: ; preds = %cond_exit_17 + %16 = extractvalue { i32, { i64 }, {} } %"011.0", 2 + br label %loop_out +} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..6f9d1cb13 --- /dev/null +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@pre-mem2reg@llvm14.snap @@ -0,0 +1,129 @@ +--- +source: hugr-llvm/src/emit/test.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define i64 @_hl.main.1() { +alloca_block: + %"0" = alloca i64, align 8 + %"7_0" = alloca i64, align 8 + %"5_0" = alloca i64, align 8 + %"8_0" = alloca i64, align 8 + %"9_0" = alloca i64, align 8 + %"9_1" = alloca i64, align 8 + %"17_0" = alloca { i32, { i64 }, {} }, align 8 + %"16_0" = alloca i64, align 8 + %"15_0" = alloca i64, align 8 + %"12_0" = alloca i64, align 8 + %"13_0" = alloca { i32, {}, {} }, align 8 + %"011" = alloca { i32, { i64 }, {} }, align 8 + %"013" = alloca i64, align 8 + %"23_0" = alloca i64, align 8 + %"20_0" = alloca i64, align 8 + %"24_0" = alloca i64, align 8 + %"25_0" = alloca { i32, { i64 }, {} }, align 8 + %"019" = alloca i64, align 8 + %"29_0" = alloca { i32, { i64 }, {} }, align 8 + %"27_0" = alloca i64, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i64 7, i64* %"7_0", align 4 + store i64 3, i64* %"5_0", align 4 + %"5_01" = load i64, i64* %"5_0", align 4 + %"7_02" = load i64, i64* %"7_0", align 4 + store i64 %"5_01", i64* %"9_0", align 4 + store i64 %"7_02", i64* %"9_1", align 4 + br label %loop_body + +loop_body: ; preds = %12, %entry_block + %"9_03" = load i64, i64* %"9_0", align 4 + %"9_14" = load i64, i64* %"9_1", align 4 + store i64 2, i64* %"15_0", align 4 + store i64 0, i64* %"12_0", align 4 + store i64 %"9_03", i64* %"9_0", align 4 + store i64 %"9_14", i64* %"9_1", align 4 + %"9_15" = load i64, i64* %"9_1", align 4 + %"15_06" = load i64, i64* %"15_0", align 4 + %0 = mul i64 %"9_15", %"15_06" + store i64 %0, i64* %"16_0", align 4 + %"9_07" = load i64, i64* %"9_0", align 4 + %"12_08" = load i64, i64* %"12_0", align 4 + %1 = icmp eq i64 %"9_07", %"12_08" + %2 = select i1 %1, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison } + store { i32, {}, {} } %2, { i32, {}, {} }* %"13_0", align 4 + %"13_09" = load { i32, {}, {} }, { i32, {}, {} }* %"13_0", align 4 + %"9_010" = load i64, i64* %"9_0", align 4 + %3 = extractvalue { i32, {}, {} } %"13_09", 0 + switch i32 %3, label %4 [ + i32 1, label %6 + ] + +4: ; preds = %loop_body + %5 = extractvalue { i32, {}, {} } %"13_09", 1 + store i64 %"9_010", i64* %"013", align 4 + br label %cond_17_case_0 + +6: ; preds = %loop_body + %7 = extractvalue { i32, {}, {} } %"13_09", 2 + store i64 %"9_010", i64* %"019", align 4 + br label %cond_17_case_1 + +loop_out: ; preds = %15 + %"8_026" = load i64, i64* %"8_0", align 4 + store i64 %"8_026", i64* %"0", align 4 + %"027" = load i64, i64* %"0", align 4 + ret i64 %"027" + +cond_17_case_0: ; preds = %4 + %"014" = load i64, i64* %"013", align 4 + store i64 1, i64* %"23_0", align 4 + store i64 %"014", i64* %"20_0", align 4 + %"20_015" = load i64, i64* %"20_0", align 4 + %"23_016" = load i64, i64* %"23_0", align 4 + %8 = sub i64 %"20_015", %"23_016" + store i64 %8, i64* %"24_0", align 4 + %"24_017" = load i64, i64* %"24_0", align 4 + %9 = insertvalue { i64 } undef, i64 %"24_017", 0 + %10 = insertvalue { i32, { i64 }, {} } { i32 0, { i64 } poison, {} poison }, { i64 } %9, 1 + store { i32, { i64 }, {} } %10, { i32, { i64 }, {} }* %"25_0", align 4 + %"25_018" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"25_0", align 4 + store { i32, { i64 }, {} } %"25_018", { i32, { i64 }, {} }* %"011", align 4 + br label %cond_exit_17 + +cond_17_case_1: ; preds = %6 + %"020" = load i64, i64* %"019", align 4 + store { i32, { i64 }, {} } { i32 1, { i64 } poison, {} undef }, { i32, { i64 }, {} }* %"29_0", align 4 + %"29_021" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"29_0", align 4 + store { i32, { i64 }, {} } %"29_021", { i32, { i64 }, {} }* %"011", align 4 + store i64 %"020", i64* %"27_0", align 4 + br label %cond_exit_17 + +cond_exit_17: ; preds = %cond_17_case_1, %cond_17_case_0 + %"012" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"011", align 4 + store { i32, { i64 }, {} } %"012", { i32, { i64 }, {} }* %"17_0", align 4 + %"17_022" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"17_0", align 4 + %"16_023" = load i64, i64* %"16_0", align 4 + store { i32, { i64 }, {} } %"17_022", { i32, { i64 }, {} }* %"17_0", align 4 + store i64 %"16_023", i64* %"16_0", align 4 + %"17_024" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"17_0", align 4 + %"16_025" = load i64, i64* %"16_0", align 4 + %11 = extractvalue { i32, { i64 }, {} } %"17_024", 0 + switch i32 %11, label %12 [ + i32 1, label %15 + ] + +12: ; preds = %cond_exit_17 + %13 = extractvalue { i32, { i64 }, {} } %"17_024", 1 + %14 = extractvalue { i64 } %13, 0 + store i64 %14, i64* %"9_0", align 4 + store i64 %"16_025", i64* %"9_1", align 4 + br label %loop_body + +15: ; preds = %cond_exit_17 + %16 = extractvalue { i32, { i64 }, {} } %"17_024", 2 + store i64 %"16_025", i64* %"8_0", align 4 + br label %loop_out +} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@pre-mem2reg@llvm14_2.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@pre-mem2reg@llvm14_2.snap new file mode 100644 index 000000000..4b5b4ccf4 --- /dev/null +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop@pre-mem2reg@llvm14_2.snap @@ -0,0 +1,168 @@ +--- +source: hugr-llvm/src/emit/test.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { { {} }, i64 } @_hl.main.1(i64 %0, i64 %1) { +alloca_block: + %"0" = alloca { {} }, align 8 + %"1" = alloca i64, align 8 + %"2_0" = alloca i64, align 8 + %"2_1" = alloca i64, align 8 + %"4_0" = alloca { {} }, align 8 + %"4_1" = alloca i64, align 8 + %"5_0" = alloca i64, align 8 + %"5_1" = alloca i64, align 8 + %"12_0" = alloca { i32, { i64 }, { { {} } } }, align 8 + %"12_1" = alloca i64, align 8 + %"8_0" = alloca i64, align 8 + %"10_0" = alloca { i32, { i64 }, { { {} } } }, align 8 + %"08" = alloca { i32, { i64 }, { { {} } } }, align 8 + %"19" = alloca i64, align 8 + %"012" = alloca i64, align 8 + %"113" = alloca i64, align 8 + %"19_0" = alloca i64, align 8 + %"17_0" = alloca i64, align 8 + %"14_0" = alloca i64, align 8 + %"14_1" = alloca i64, align 8 + %"20_0" = alloca i64, align 8 + %"21_0" = alloca i64, align 8 + %"22_0" = alloca { i32, { i64 }, { { {} } } }, align 8 + %"023" = alloca { {} }, align 8 + %"124" = alloca i64, align 8 + %"27_0" = alloca { {} }, align 8 + %"28_0" = alloca { i32, { i64 }, { { {} } } }, align 8 + %"24_0" = alloca { {} }, align 8 + %"24_1" = alloca i64, align 8 + %"9_0" = alloca { i32, {}, {} }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i64 %0, i64* %"2_0", align 4 + store i64 %1, i64* %"2_1", align 4 + %"2_01" = load i64, i64* %"2_0", align 4 + %"2_12" = load i64, i64* %"2_1", align 4 + store i64 %"2_01", i64* %"5_0", align 4 + store i64 %"2_12", i64* %"5_1", align 4 + br label %loop_body + +loop_body: ; preds = %20, %entry_block + %"5_03" = load i64, i64* %"5_0", align 4 + %"5_14" = load i64, i64* %"5_1", align 4 + store i64 0, i64* %"8_0", align 4 + store i64 %"5_03", i64* %"5_0", align 4 + store i64 %"5_14", i64* %"5_1", align 4 + %"5_05" = load i64, i64* %"5_0", align 4 + %2 = insertvalue { i64 } undef, i64 %"5_05", 0 + %3 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %2, 1 + store { i32, { i64 }, { { {} } } } %3, { i32, { i64 }, { { {} } } }* %"10_0", align 4 + %"10_06" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"10_0", align 4 + %"5_17" = load i64, i64* %"5_1", align 4 + %4 = extractvalue { i32, { i64 }, { { {} } } } %"10_06", 0 + switch i32 %4, label %5 [ + i32 1, label %8 + ] + +5: ; preds = %loop_body + %6 = extractvalue { i32, { i64 }, { { {} } } } %"10_06", 1 + %7 = extractvalue { i64 } %6, 0 + store i64 %7, i64* %"012", align 4 + store i64 %"5_17", i64* %"113", align 4 + br label %cond_12_case_0 + +8: ; preds = %loop_body + %9 = extractvalue { i32, { i64 }, { { {} } } } %"10_06", 2 + %10 = extractvalue { { {} } } %9, 0 + store { {} } %10, { {} }* %"023", align 1 + store i64 %"5_17", i64* %"124", align 4 + br label %cond_12_case_1 + +loop_out: ; preds = %23 + %"4_036" = load { {} }, { {} }* %"4_0", align 1 + %"4_137" = load i64, i64* %"4_1", align 4 + store { {} } %"4_036", { {} }* %"0", align 1 + store i64 %"4_137", i64* %"1", align 4 + %"038" = load { {} }, { {} }* %"0", align 1 + %"139" = load i64, i64* %"1", align 4 + %mrv = insertvalue { { {} }, i64 } undef, { {} } %"038", 0 + %mrv40 = insertvalue { { {} }, i64 } %mrv, i64 %"139", 1 + ret { { {} }, i64 } %mrv40 + +cond_12_case_0: ; preds = %5 + %"014" = load i64, i64* %"012", align 4 + %"115" = load i64, i64* %"113", align 4 + store i64 2, i64* %"19_0", align 4 + store i64 1, i64* %"17_0", align 4 + store i64 %"014", i64* %"14_0", align 4 + store i64 %"115", i64* %"14_1", align 4 + %"14_116" = load i64, i64* %"14_1", align 4 + %"19_017" = load i64, i64* %"19_0", align 4 + %11 = mul i64 %"14_116", %"19_017" + store i64 %11, i64* %"20_0", align 4 + %"14_018" = load i64, i64* %"14_0", align 4 + %"17_019" = load i64, i64* %"17_0", align 4 + %12 = sub i64 %"14_018", %"17_019" + store i64 %12, i64* %"21_0", align 4 + %"21_020" = load i64, i64* %"21_0", align 4 + %13 = insertvalue { i64 } undef, i64 %"21_020", 0 + %14 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %13, 1 + store { i32, { i64 }, { { {} } } } %14, { i32, { i64 }, { { {} } } }* %"22_0", align 4 + %"22_021" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"22_0", align 4 + %"20_022" = load i64, i64* %"20_0", align 4 + store { i32, { i64 }, { { {} } } } %"22_021", { i32, { i64 }, { { {} } } }* %"08", align 4 + store i64 %"20_022", i64* %"19", align 4 + br label %cond_exit_12 + +cond_12_case_1: ; preds = %8 + %"025" = load { {} }, { {} }* %"023", align 1 + %"126" = load i64, i64* %"124", align 4 + store { {} } undef, { {} }* %"27_0", align 1 + %"27_027" = load { {} }, { {} }* %"27_0", align 1 + %15 = insertvalue { { {} } } undef, { {} } %"27_027", 0 + %16 = insertvalue { i32, { i64 }, { { {} } } } { i32 1, { i64 } poison, { { {} } } poison }, { { {} } } %15, 2 + store { i32, { i64 }, { { {} } } } %16, { i32, { i64 }, { { {} } } }* %"28_0", align 4 + store { {} } %"025", { {} }* %"24_0", align 1 + store i64 %"126", i64* %"24_1", align 4 + %"28_028" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"28_0", align 4 + %"24_129" = load i64, i64* %"24_1", align 4 + store { i32, { i64 }, { { {} } } } %"28_028", { i32, { i64 }, { { {} } } }* %"08", align 4 + store i64 %"24_129", i64* %"19", align 4 + br label %cond_exit_12 + +cond_exit_12: ; preds = %cond_12_case_1, %cond_12_case_0 + %"010" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"08", align 4 + %"111" = load i64, i64* %"19", align 4 + store { i32, { i64 }, { { {} } } } %"010", { i32, { i64 }, { { {} } } }* %"12_0", align 4 + store i64 %"111", i64* %"12_1", align 4 + %"12_030" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"12_0", align 4 + %"12_131" = load i64, i64* %"12_1", align 4 + store { i32, { i64 }, { { {} } } } %"12_030", { i32, { i64 }, { { {} } } }* %"12_0", align 4 + store i64 %"12_131", i64* %"12_1", align 4 + %"5_032" = load i64, i64* %"5_0", align 4 + %"8_033" = load i64, i64* %"8_0", align 4 + %17 = icmp ule i64 %"5_032", %"8_033" + %18 = select i1 %17, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison } + store { i32, {}, {} } %18, { i32, {}, {} }* %"9_0", align 4 + %"12_034" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"12_0", align 4 + %"12_135" = load i64, i64* %"12_1", align 4 + %19 = extractvalue { i32, { i64 }, { { {} } } } %"12_034", 0 + switch i32 %19, label %20 [ + i32 1, label %23 + ] + +20: ; preds = %cond_exit_12 + %21 = extractvalue { i32, { i64 }, { { {} } } } %"12_034", 1 + %22 = extractvalue { i64 } %21, 0 + store i64 %22, i64* %"5_0", align 4 + store i64 %"12_135", i64* %"5_1", align 4 + br label %loop_body + +23: ; preds = %cond_exit_12 + %24 = extractvalue { i32, { i64 }, { { {} } } } %"12_034", 2 + %25 = extractvalue { { {} } } %24, 0 + store { {} } %25, { {} }* %"4_0", align 1 + store i64 %"12_135", i64* %"4_1", align 4 + br label %loop_out +} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@llvm14.snap new file mode 100644 index 000000000..408f61aeb --- /dev/null +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@llvm14.snap @@ -0,0 +1,36 @@ +--- +source: hugr-llvm/src/emit/test.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { {} } @_hl.main.1(i64 %0) { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + br label %loop_body + +loop_body: ; preds = %4, %entry_block + %"5_0.0" = phi i64 [ %0, %entry_block ], [ %6, %4 ] + %1 = insertvalue { i64 } undef, i64 %"5_0.0", 0 + %2 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %1, 1 + %3 = extractvalue { i32, { i64 }, { { {} } } } %2, 0 + switch i32 %3, label %4 [ + i32 1, label %7 + ] + +4: ; preds = %loop_body + %5 = extractvalue { i32, { i64 }, { { {} } } } %2, 1 + %6 = extractvalue { i64 } %5, 0 + br label %loop_body + +7: ; preds = %loop_body + %8 = extractvalue { i32, { i64 }, { { {} } } } %2, 2 + %9 = extractvalue { { {} } } %8, 0 + br label %loop_out + +loop_out: ; preds = %7 + ret { {} } %9 +} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..7c1a36ea0 --- /dev/null +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@pre-mem2reg@llvm14.snap @@ -0,0 +1,55 @@ +--- +source: hugr-llvm/src/emit/test.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { {} } @_hl.main.1(i64 %0) { +alloca_block: + %"0" = alloca { {} }, align 8 + %"2_0" = alloca i64, align 8 + %"4_0" = alloca { {} }, align 8 + %"5_0" = alloca i64, align 8 + %"7_0" = alloca { i32, { i64 }, { { {} } } }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i64 %0, i64* %"2_0", align 4 + %"2_01" = load i64, i64* %"2_0", align 4 + store i64 %"2_01", i64* %"5_0", align 4 + br label %loop_body + +loop_body: ; preds = %4, %entry_block + %"5_02" = load i64, i64* %"5_0", align 4 + store i64 %"5_02", i64* %"5_0", align 4 + %"5_03" = load i64, i64* %"5_0", align 4 + %1 = insertvalue { i64 } undef, i64 %"5_03", 0 + %2 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %1, 1 + store { i32, { i64 }, { { {} } } } %2, { i32, { i64 }, { { {} } } }* %"7_0", align 4 + %"7_04" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"7_0", align 4 + store { i32, { i64 }, { { {} } } } %"7_04", { i32, { i64 }, { { {} } } }* %"7_0", align 4 + %"7_05" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"7_0", align 4 + %3 = extractvalue { i32, { i64 }, { { {} } } } %"7_05", 0 + switch i32 %3, label %4 [ + i32 1, label %7 + ] + +4: ; preds = %loop_body + %5 = extractvalue { i32, { i64 }, { { {} } } } %"7_05", 1 + %6 = extractvalue { i64 } %5, 0 + store i64 %6, i64* %"5_0", align 4 + br label %loop_body + +7: ; preds = %loop_body + %8 = extractvalue { i32, { i64 }, { { {} } } } %"7_05", 2 + %9 = extractvalue { { {} } } %8, 0 + store { {} } %9, { {} }* %"4_0", align 1 + br label %loop_out + +loop_out: ; preds = %7 + %"4_06" = load { {} }, { {} }* %"4_0", align 1 + store { {} } %"4_06", { {} }* %"0", align 1 + %"07" = load { {} }, { {} }* %"0", align 1 + ret { {} } %"07" +} diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@pre-mem2reg@llvm14_3.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@pre-mem2reg@llvm14_3.snap new file mode 100644 index 000000000..7c1a36ea0 --- /dev/null +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__tail_loop_simple@pre-mem2reg@llvm14_3.snap @@ -0,0 +1,55 @@ +--- +source: hugr-llvm/src/emit/test.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define { {} } @_hl.main.1(i64 %0) { +alloca_block: + %"0" = alloca { {} }, align 8 + %"2_0" = alloca i64, align 8 + %"4_0" = alloca { {} }, align 8 + %"5_0" = alloca i64, align 8 + %"7_0" = alloca { i32, { i64 }, { { {} } } }, align 8 + br label %entry_block + +entry_block: ; preds = %alloca_block + store i64 %0, i64* %"2_0", align 4 + %"2_01" = load i64, i64* %"2_0", align 4 + store i64 %"2_01", i64* %"5_0", align 4 + br label %loop_body + +loop_body: ; preds = %4, %entry_block + %"5_02" = load i64, i64* %"5_0", align 4 + store i64 %"5_02", i64* %"5_0", align 4 + %"5_03" = load i64, i64* %"5_0", align 4 + %1 = insertvalue { i64 } undef, i64 %"5_03", 0 + %2 = insertvalue { i32, { i64 }, { { {} } } } { i32 0, { i64 } poison, { { {} } } poison }, { i64 } %1, 1 + store { i32, { i64 }, { { {} } } } %2, { i32, { i64 }, { { {} } } }* %"7_0", align 4 + %"7_04" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"7_0", align 4 + store { i32, { i64 }, { { {} } } } %"7_04", { i32, { i64 }, { { {} } } }* %"7_0", align 4 + %"7_05" = load { i32, { i64 }, { { {} } } }, { i32, { i64 }, { { {} } } }* %"7_0", align 4 + %3 = extractvalue { i32, { i64 }, { { {} } } } %"7_05", 0 + switch i32 %3, label %4 [ + i32 1, label %7 + ] + +4: ; preds = %loop_body + %5 = extractvalue { i32, { i64 }, { { {} } } } %"7_05", 1 + %6 = extractvalue { i64 } %5, 0 + store i64 %6, i64* %"5_0", align 4 + br label %loop_body + +7: ; preds = %loop_body + %8 = extractvalue { i32, { i64 }, { { {} } } } %"7_05", 2 + %9 = extractvalue { { {} } } %8, 0 + store { {} } %9, { {} }* %"4_0", align 1 + br label %loop_out + +loop_out: ; preds = %7 + %"4_06" = load { {} }, { {} }* %"4_0", align 1 + store { {} } %"4_06", { {} }* %"0", align 1 + %"07" = load { {} }, { {} }* %"0", align 1 + ret { {} } %"07" +} diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index cdf3e024c..52fce5154 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -165,7 +165,7 @@ impl SimpleHugrConfig { // unvalidated. // println!("{}", mod_b.hugr().mermaid_string()); - mod_b.finish_hugr().unwrap() + mod_b.finish_hugr().unwrap_or_else(|e| panic!("{e}")) } } @@ -246,7 +246,7 @@ mod test_fns { use super::*; use crate::custom::CodegenExtsBuilder; use crate::extension::int::add_int_extensions; - use crate::types::HugrFuncType; + use crate::types::{HugrFuncType, HugrSumType}; use hugr_core::builder::DataflowSubContainer; use hugr_core::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}; @@ -261,7 +261,7 @@ mod test_fns { use hugr_core::{type_row, Hugr}; use itertools::Itertools; - use rstest::rstest; + use rstest::{fixture, rstest}; use std::iter; use crate::test::*; @@ -549,6 +549,194 @@ mod test_fns { check_emission!(hugr, llvm_ctx); } + #[rstest] + fn tail_loop_simple(mut llvm_ctx: TestContext) { + let hugr = { + let just_input = usize_t(); + let just_output = Type::UNIT; + let input_v = TypeRow::from(vec![just_input.clone()]); + let output_v = TypeRow::from(vec![just_output.clone()]); + + SimpleHugrConfig::new() + .with_extensions(PRELUDE_REGISTRY.clone()) + .with_ins(input_v) + .with_outs(output_v) + .finish(|mut builder: DFGW| { + let [just_in_w] = builder.input_wires_arr(); + let mut tail_b = builder + .tail_loop_builder( + [(just_input.clone(), just_in_w)], + [], + vec![just_output.clone()].into(), + ) + .unwrap(); + + let input = tail_b.input(); + let [inp_w] = input.outputs_arr(); + + let loop_sig = tail_b.loop_signature().unwrap().clone(); + + // builder.add_dataflow_op(ops::Noop, input_wires) + + let sum_inp_w = tail_b.make_continue(loop_sig.clone(), [inp_w]).unwrap(); + + let outs @ [_] = tail_b + .finish_with_outputs(sum_inp_w, []) + .unwrap() + .outputs_arr(); + builder.finish_with_outputs(outs).unwrap() + }) + }; + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); + + check_emission!(hugr, llvm_ctx); + } + + #[fixture] + fn terminal_loop(#[default(3)] iters: u64, #[default(7)] input: u64) -> Hugr { + /* + Computes roughly the following: + ```python + def terminal_loop(counter: int, val: int) -> int: + while True: + val = val * 2 + if counter == 0: + break + else: + counter -= 1 + return val + ``` + */ + let int_ty = int_types::int_type(6); + let just_input = int_ty.clone(); + let other_ty = int_ty.clone(); + + let mut registry = PRELUDE_REGISTRY.clone(); + registry.register(int_ops::EXTENSION.clone()).unwrap(); + registry.register(int_types::EXTENSION.clone()).unwrap(); + + SimpleHugrConfig::new() + .with_extensions(registry) + .with_outs(int_ty.clone()) + .finish(|mut builder: DFGW| { + let just_in_w = builder.add_load_value(ConstInt::new_u(6, iters).unwrap()); + let other_w = builder.add_load_value(ConstInt::new_u(6, input).unwrap()); + + let tail_l = { + let mut tail_b = builder + .tail_loop_builder( + [(just_input.clone(), just_in_w)], + [(other_ty.clone(), other_w)], + type_row![], + ) + .unwrap(); + let [loop_int_w, other_w] = tail_b.input_wires_arr(); + + let zero = ConstInt::new_u(6, 0).unwrap(); + let zero_w = tail_b.add_load_value(zero); + let [eq_0] = tail_b + .add_dataflow_op( + int_ops::IntOpDef::ieq.with_log_width(6), + [loop_int_w, zero_w], + ) + .unwrap() + .outputs_arr(); + + let loop_sig = tail_b.loop_signature().unwrap().clone(); + + let two = ConstInt::new_u(6, 2).unwrap(); + let two_w = tail_b.add_load_value(two); + + let [other_mul_2] = tail_b + .add_dataflow_op( + int_ops::IntOpDef::imul.with_log_width(6), + [other_w, two_w], + ) + .unwrap() + .outputs_arr(); + let cond = { + let mut cond_b = tail_b + .conditional_builder( + ([type_row![], type_row![]], eq_0), + vec![(just_input.clone(), loop_int_w)], + vec![HugrSumType::new(vec![ + vec![just_input.clone()].into(), + vec![], + ]) + .into()] + .into(), + ) + .unwrap(); + + // If the check is false, we subtract 1 and continue + let _false_case = { + let mut false_case_b = cond_b.case_builder(0).unwrap(); + let [counter] = false_case_b.input_wires_arr(); + let one = ConstInt::new_u(6, 1).unwrap(); + let one_w = false_case_b.add_load_value(one); + + let [counter] = false_case_b + .add_dataflow_op( + int_ops::IntOpDef::isub.with_log_width(6), + [counter, one_w], + ) + .unwrap() + .outputs_arr(); + let tag_continue = false_case_b + .make_continue(loop_sig.clone(), [counter]) + .unwrap(); + + false_case_b.finish_with_outputs([tag_continue]).unwrap() + }; + let _true_case = { + // In the true case, we break and output true along with the "other" input wire + let mut true_case_b = cond_b.case_builder(1).unwrap(); + + let [_counter] = true_case_b.input_wires_arr(); + + let tagged_break = + true_case_b.make_break(loop_sig.clone(), []).unwrap(); + true_case_b.finish_with_outputs([tagged_break]).unwrap() + }; + + cond_b.finish_sub_container().unwrap() + }; + tail_b + .finish_with_outputs(cond.out_wire(0), [other_mul_2]) + .unwrap() + }; + let [out_int] = tail_l.outputs_arr(); + println!("{}", builder.hugr().mermaid_string()); + builder + .finish_with_outputs([out_int]) + .unwrap_or_else(|e| panic!("{e}")) + }) + } + + #[rstest] + fn tail_loop(mut llvm_ctx: TestContext, #[with(3, 7)] terminal_loop: Hugr) { + llvm_ctx.add_extensions(add_int_extensions); + + check_emission!(terminal_loop, llvm_ctx); + } + + #[rstest] + #[case(3, 7)] + #[case(2, 1)] + #[case(20, 0)] + fn tail_loop_exec( + mut exec_ctx: TestContext, + #[case] iters: u64, + #[case] input: u64, + #[with(iters, input)] terminal_loop: Hugr, + ) { + exec_ctx.add_extensions(add_int_extensions); + assert_eq!( + input * 1 << (iters + 1), + exec_ctx.exec_hugr_u64(terminal_loop, "main") + ); + } + #[rstest] fn test_exec(mut exec_ctx: TestContext) { let hugr = SimpleHugrConfig::new() diff --git a/hugr-llvm/src/sum.rs b/hugr-llvm/src/sum.rs index 31c6b5357..059b37873 100644 --- a/hugr-llvm/src/sum.rs +++ b/hugr-llvm/src/sum.rs @@ -249,7 +249,7 @@ impl<'c> LLVMSumValue<'c> { pub fn build_destructure( &self, builder: &Builder<'c>, - handler: impl Fn(&Builder<'c>, usize, Vec>) -> Result<()>, + mut handler: impl FnMut(&Builder<'c>, usize, Vec>) -> Result<()>, ) -> Result<()> { let orig_bb = builder .get_insert_block()