Skip to content

Commit

Permalink
feat(llvm): Tail loop emission (#1749)
Browse files Browse the repository at this point in the history
Work in progress. Will address #1688

---------

Co-authored-by: Douglas Wilson <[email protected]>
Co-authored-by: Seyon Sivarajah <[email protected]>
Co-authored-by: Douglas Wilson <[email protected]>
  • Loading branch information
4 people authored Dec 11, 2024
1 parent a866bb5 commit 89545e8
Show file tree
Hide file tree
Showing 11 changed files with 767 additions and 9 deletions.
2 changes: 1 addition & 1 deletion hugr-llvm/src/emit/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ impl<'c, 'a, H: HugrView> EmitFuncContext<'c, 'a, H> {
Ok(r)
}

/// Returns a [RowMailBox] mapped to thie ouput wires of `node`. When emitting a node
/// Returns a [RowMailBox] mapped to the output wires of `node`. When emitting a node
/// output values are written to this mailbox.
pub fn node_outs_rmb<'hugr, OT: 'hugr>(
&mut self,
Expand Down
15 changes: 14 additions & 1 deletion hugr-llvm/src/emit/func/mailbox.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{borrow::Cow, rc::Rc};

use anyhow::Result;
use anyhow::{bail, Result};
use delegate::delegate;
use inkwell::{
builder::Builder,
Expand Down Expand Up @@ -148,6 +148,19 @@ impl<'c> RowMailBox<'c> {
builder: &Builder<'c>,
vs: impl IntoIterator<Item = BasicValueEnum<'c>>,
) -> Result<()> {
let vs = vs.into_iter().collect_vec();
#[cfg(debug_assertions)]
{
let actual_types = vs.clone().into_iter().map(|x| x.get_type()).collect_vec();
let expected_types = self.get_types().collect_vec();
if actual_types != expected_types {
bail!(
"RowMailbox::write: Expected types {:?}, got {:?}",
expected_types,
actual_types
);
}
}
zip_eq(self.0.iter(), vs).try_for_each(|(mb, v)| mb.write(builder, v))
}

Expand Down
59 changes: 56 additions & 3 deletions hugr-llvm/src/emit/ops.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::{anyhow, bail, Result};
use hugr_core::ops::{
constant::Sum, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant,
LoadFunction, OpTag, OpTrait, OpType, Output, Tag, Value, CFG,
LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG,
};
use hugr_core::{
hugr::views::SiblingGraph,
Expand Down Expand Up @@ -58,7 +58,7 @@ where
.ok_or(anyhow!("DataflowParentEmitter: Output taken twice"))
}

pub fn emit_children(mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> {
pub fn emit_children(&mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> {
use petgraph::visit::Topo;
let node = self.node;
if !OpTag::DataflowParent.is_superset(node.tag()) {
Expand Down Expand Up @@ -306,6 +306,59 @@ fn emit_cfg<'c, H: HugrView>(
cfg::CfgEmitter::new(context, args)?.emit_children(context)
}

fn emit_tail_loop<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, TailLoop, H>,
) -> Result<()> {
let node = args.node();

// Make a block to jump to when we `Break`
let out_bb = context.new_basic_block("loop_out", None);
// A block for the body of the loop
let body_bb = context.new_basic_block("loop_body", Some(out_bb));

let (body_i_node, body_o_node) = node.get_io().unwrap();
let body_i_rmb = context.node_outs_rmb(body_i_node)?;
let body_o_rmb = context.node_ins_rmb(body_o_node)?;

body_i_rmb.write(context.builder(), args.inputs)?;
context.builder().build_unconditional_branch(body_bb)?;

let control_llvm_sum_type = {
let sum_ty = SumType::new([node.just_inputs.clone(), node.just_outputs.clone()]);
context.llvm_sum_type(sum_ty)?
};

context.build_positioned(body_bb, move |context| {
let inputs = body_i_rmb.read_vec(context.builder(), [])?;
emit_dataflow_parent(
context,
EmitOpArgs {
node,
inputs,
outputs: body_o_rmb.promise(),
},
)?;
let dataflow_outputs = body_o_rmb.read_vec(context.builder(), [])?;
let control_val = LLVMSumValue::try_new(dataflow_outputs[0], control_llvm_sum_type)?;
let mut outputs = Some(args.outputs);

control_val.build_destructure(context.builder(), |builder, tag, mut values| {
values.extend(dataflow_outputs[1..].iter().copied());
if tag == 0 {
body_i_rmb.write(builder, values)?;
builder.build_unconditional_branch(body_bb)?;
} else {
outputs.take().unwrap().finish(builder, values)?;
builder.build_unconditional_branch(out_bb)?;
}
Ok(())
})
})?;
context.builder().position_at_end(out_bb);
Ok(())
}

fn emit_optype<'c, H: HugrView>(
context: &mut EmitFuncContext<'c, '_, H>,
args: EmitOpArgs<'c, '_, OpType, H>,
Expand All @@ -330,7 +383,7 @@ fn emit_optype<'c, H: HugrView>(
context.push_todo_func(node.into_ot(fd));
Ok(())
}

OpType::TailLoop(x) => emit_tail_loop(context, args.into_ot(x)),
_ => Err(anyhow!("Invalid child for Dataflow Parent: {node}")),
}
}
Expand Down
61 changes: 61 additions & 0 deletions hugr-llvm/src/emit/snapshots/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
---
source: hugr-llvm/src/emit/test.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i64 @_hl.main.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
br label %loop_body

loop_body: ; preds = %12, %entry_block
%"9_0.0" = phi i64 [ 3, %entry_block ], [ %14, %12 ]
%"9_1.0" = phi i64 [ 7, %entry_block ], [ %0, %12 ]
%0 = mul i64 %"9_1.0", 2
%1 = icmp eq i64 %"9_0.0", 0
%2 = select i1 %1, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
%3 = extractvalue { i32, {}, {} } %2, 0
switch i32 %3, label %4 [
i32 1, label %6
]

4: ; preds = %loop_body
%5 = extractvalue { i32, {}, {} } %2, 1
br label %cond_17_case_0

6: ; preds = %loop_body
%7 = extractvalue { i32, {}, {} } %2, 2
br label %cond_17_case_1

loop_out: ; preds = %15
ret i64 %0

cond_17_case_0: ; preds = %4
%8 = sub i64 %"9_0.0", 1
%9 = insertvalue { i64 } undef, i64 %8, 0
%10 = insertvalue { i32, { i64 }, {} } { i32 0, { i64 } poison, {} poison }, { i64 } %9, 1
br label %cond_exit_17

cond_17_case_1: ; preds = %6
br label %cond_exit_17

cond_exit_17: ; preds = %cond_17_case_1, %cond_17_case_0
%"011.0" = phi { i32, { i64 }, {} } [ %10, %cond_17_case_0 ], [ { i32 1, { i64 } poison, {} undef }, %cond_17_case_1 ]
%11 = extractvalue { i32, { i64 }, {} } %"011.0", 0
switch i32 %11, label %12 [
i32 1, label %15
]

12: ; preds = %cond_exit_17
%13 = extractvalue { i32, { i64 }, {} } %"011.0", 1
%14 = extractvalue { i64 } %13, 0
br label %loop_body

15: ; preds = %cond_exit_17
%16 = extractvalue { i32, { i64 }, {} } %"011.0", 2
br label %loop_out
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
---
source: hugr-llvm/src/emit/test.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define i64 @_hl.main.1() {
alloca_block:
%"0" = alloca i64, align 8
%"7_0" = alloca i64, align 8
%"5_0" = alloca i64, align 8
%"8_0" = alloca i64, align 8
%"9_0" = alloca i64, align 8
%"9_1" = alloca i64, align 8
%"17_0" = alloca { i32, { i64 }, {} }, align 8
%"16_0" = alloca i64, align 8
%"15_0" = alloca i64, align 8
%"12_0" = alloca i64, align 8
%"13_0" = alloca { i32, {}, {} }, align 8
%"011" = alloca { i32, { i64 }, {} }, align 8
%"013" = alloca i64, align 8
%"23_0" = alloca i64, align 8
%"20_0" = alloca i64, align 8
%"24_0" = alloca i64, align 8
%"25_0" = alloca { i32, { i64 }, {} }, align 8
%"019" = alloca i64, align 8
%"29_0" = alloca { i32, { i64 }, {} }, align 8
%"27_0" = alloca i64, align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store i64 7, i64* %"7_0", align 4
store i64 3, i64* %"5_0", align 4
%"5_01" = load i64, i64* %"5_0", align 4
%"7_02" = load i64, i64* %"7_0", align 4
store i64 %"5_01", i64* %"9_0", align 4
store i64 %"7_02", i64* %"9_1", align 4
br label %loop_body

loop_body: ; preds = %12, %entry_block
%"9_03" = load i64, i64* %"9_0", align 4
%"9_14" = load i64, i64* %"9_1", align 4
store i64 2, i64* %"15_0", align 4
store i64 0, i64* %"12_0", align 4
store i64 %"9_03", i64* %"9_0", align 4
store i64 %"9_14", i64* %"9_1", align 4
%"9_15" = load i64, i64* %"9_1", align 4
%"15_06" = load i64, i64* %"15_0", align 4
%0 = mul i64 %"9_15", %"15_06"
store i64 %0, i64* %"16_0", align 4
%"9_07" = load i64, i64* %"9_0", align 4
%"12_08" = load i64, i64* %"12_0", align 4
%1 = icmp eq i64 %"9_07", %"12_08"
%2 = select i1 %1, { i32, {}, {} } { i32 1, {} poison, {} undef }, { i32, {}, {} } { i32 0, {} undef, {} poison }
store { i32, {}, {} } %2, { i32, {}, {} }* %"13_0", align 4
%"13_09" = load { i32, {}, {} }, { i32, {}, {} }* %"13_0", align 4
%"9_010" = load i64, i64* %"9_0", align 4
%3 = extractvalue { i32, {}, {} } %"13_09", 0
switch i32 %3, label %4 [
i32 1, label %6
]

4: ; preds = %loop_body
%5 = extractvalue { i32, {}, {} } %"13_09", 1
store i64 %"9_010", i64* %"013", align 4
br label %cond_17_case_0

6: ; preds = %loop_body
%7 = extractvalue { i32, {}, {} } %"13_09", 2
store i64 %"9_010", i64* %"019", align 4
br label %cond_17_case_1

loop_out: ; preds = %15
%"8_026" = load i64, i64* %"8_0", align 4
store i64 %"8_026", i64* %"0", align 4
%"027" = load i64, i64* %"0", align 4
ret i64 %"027"

cond_17_case_0: ; preds = %4
%"014" = load i64, i64* %"013", align 4
store i64 1, i64* %"23_0", align 4
store i64 %"014", i64* %"20_0", align 4
%"20_015" = load i64, i64* %"20_0", align 4
%"23_016" = load i64, i64* %"23_0", align 4
%8 = sub i64 %"20_015", %"23_016"
store i64 %8, i64* %"24_0", align 4
%"24_017" = load i64, i64* %"24_0", align 4
%9 = insertvalue { i64 } undef, i64 %"24_017", 0
%10 = insertvalue { i32, { i64 }, {} } { i32 0, { i64 } poison, {} poison }, { i64 } %9, 1
store { i32, { i64 }, {} } %10, { i32, { i64 }, {} }* %"25_0", align 4
%"25_018" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"25_0", align 4
store { i32, { i64 }, {} } %"25_018", { i32, { i64 }, {} }* %"011", align 4
br label %cond_exit_17

cond_17_case_1: ; preds = %6
%"020" = load i64, i64* %"019", align 4
store { i32, { i64 }, {} } { i32 1, { i64 } poison, {} undef }, { i32, { i64 }, {} }* %"29_0", align 4
%"29_021" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"29_0", align 4
store { i32, { i64 }, {} } %"29_021", { i32, { i64 }, {} }* %"011", align 4
store i64 %"020", i64* %"27_0", align 4
br label %cond_exit_17

cond_exit_17: ; preds = %cond_17_case_1, %cond_17_case_0
%"012" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"011", align 4
store { i32, { i64 }, {} } %"012", { i32, { i64 }, {} }* %"17_0", align 4
%"17_022" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"17_0", align 4
%"16_023" = load i64, i64* %"16_0", align 4
store { i32, { i64 }, {} } %"17_022", { i32, { i64 }, {} }* %"17_0", align 4
store i64 %"16_023", i64* %"16_0", align 4
%"17_024" = load { i32, { i64 }, {} }, { i32, { i64 }, {} }* %"17_0", align 4
%"16_025" = load i64, i64* %"16_0", align 4
%11 = extractvalue { i32, { i64 }, {} } %"17_024", 0
switch i32 %11, label %12 [
i32 1, label %15
]

12: ; preds = %cond_exit_17
%13 = extractvalue { i32, { i64 }, {} } %"17_024", 1
%14 = extractvalue { i64 } %13, 0
store i64 %14, i64* %"9_0", align 4
store i64 %"16_025", i64* %"9_1", align 4
br label %loop_body

15: ; preds = %cond_exit_17
%16 = extractvalue { i32, { i64 }, {} } %"17_024", 2
store i64 %"16_025", i64* %"8_0", align 4
br label %loop_out
}
Loading

0 comments on commit 89545e8

Please sign in to comment.