From 392c63b816da625ce358043f1ae2cf16046ced81 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 10 Dec 2024 09:30:37 +0000 Subject: [PATCH] WIP: Tail loop emission --- hugr-llvm/src/emit/ops.rs | 78 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 6 deletions(-) diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 5fdbd5a76..c8991e514 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -1,15 +1,15 @@ use anyhow::{anyhow, bail, Result}; use hugr_core::ops::{ constant::Sum, Call, CallIndirect, Case, Conditional, Const, ExtensionOp, Input, LoadConstant, - LoadFunction, OpTag, OpTrait, OpType, Output, Tag, Value, CFG, + LoadFunction, OpTag, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, }; use hugr_core::{ hugr::views::SiblingGraph, types::{SumType, Type, TypeEnum}, HugrView, NodeIndex, }; -use inkwell::types::BasicTypeEnum; -use inkwell::values::{BasicValueEnum, CallableValue}; +use inkwell::types::{BasicTypeEnum, IntType}; +use inkwell::values::{BasicValueEnum, CallableValue, IntValue}; use itertools::{zip_eq, Itertools}; use petgraph::visit::Walker; @@ -21,7 +21,7 @@ use crate::{ use super::{ deaggregate_call_result, - func::{EmitFuncContext, RowPromise}, + func::{EmitFuncContext, RowMailBox, RowPromise}, EmitOpArgs, }; @@ -31,6 +31,7 @@ struct DataflowParentEmitter<'c, 'hugr, OT, H> { node: FatNode<'hugr, OT, H>, inputs: Option>>, outputs: Option>, + output_vals: Option>>, } impl<'c, 'hugr, OT: OpTrait, H: HugrView> DataflowParentEmitter<'c, 'hugr, OT, H> @@ -42,6 +43,7 @@ where node: args.node, inputs: Some(args.inputs), outputs: Some(args.outputs), + output_vals: None, } } @@ -58,7 +60,7 @@ where .ok_or(anyhow!("DataflowParentEmitter: Output taken twice")) } - pub fn emit_children(mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> { + pub fn emit_children(&mut self, context: &mut EmitFuncContext<'c, '_, H>) -> Result<()> { use petgraph::visit::Topo; let node = self.node; if !OpTag::DataflowParent.is_superset(node.tag()) { @@ -306,6 +308,70 @@ fn emit_cfg<'c, H: HugrView>( cfg::CfgEmitter::new(context, args)?.emit_children(context) } +fn emit_tail_loop<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, '_, H>, + args: EmitOpArgs<'c, '_, TailLoop, H> +) -> Result<()> { + // TODO: Switch on the tag in loop_body to see where to go next + // TODO: Handle "other" args + + + // Make a block to jump to when we `Break` + let out_bb = context.new_basic_block("loop_out", None); + // A block for the body of the loop + let body_bb = context.new_basic_block("loop_body", Some(out_bb)); + // Pack input data into a sum type - do we need this? + let prep_bb = context.new_basic_block("loop_prep", Some(body_bb)); + + context.builder().build_unconditional_branch(prep_bb); + + let sum_ty = SumType::new([args.node().just_inputs.clone(), args.node().just_outputs.clone()]); + let outs_rmb = context.node_outs_rmb(args.node)?; + + { + let builder = context.builder(); + builder.position_at_end(prep_bb); + let body_in_row = args.node.just_inputs.clone(); + let body_in_len = body_in_row.len(); + let body_in_tuple = context.llvm_sum_type(SumType::new_tuple(body_in_row))?; + + let mut loop_inputs = args.inputs.clone(); + let other_inputs = loop_inputs.split_off(body_in_len); + let loop_input_ptr = builder.build_alloca(body_in_tuple.clone(), "loop_input")?; + + let body_in_tup = body_in_tuple.build_tag(builder, 0, loop_inputs)?; + builder.build_store(loop_input_ptr, body_in_tup); + builder.build_unconditional_branch(body_bb); + + builder.position_at_end(body_bb); + }; + + + // Emit the body of the loop into the right block + let mut dfpe = DataflowParentEmitter::new(args); + dfpe.emit_children(context)?; + + // After the body we need to unpack the row type, then jump to the right block + let builder = context.builder(); + let output_vals: Vec = outs_rmb.read(builder, []).unwrap(); + let output_types: Vec<_> = outs_rmb.get_types().collect(); + let llvm_sum_ty = LLVMSumType::try_new(&context.typing_session(), sum_ty)?; + + println!("{:?}", output_vals); + let sum_output = LLVMSumValue::try_new(output_vals[0], llvm_sum_ty)?; + let tag = sum_output.build_get_tag(builder)?; + + let tag = IntValue::try_from(output_vals[0]).unwrap(); + let continue_tag = context.iw_context().i64_type().const_int(0, false); + let break_tag = context.iw_context().i64_type().const_int(1, false); + // TODO: Make this a conditional branch instead of switch + builder.build_switch(tag, out_bb, &[(break_tag, out_bb), (continue_tag, prep_bb)]); +du + // Return Ok so we can see the insta emission with + // `cargo insta test` for debugging + Ok(()) +} + fn emit_optype<'c, H: HugrView>( context: &mut EmitFuncContext<'c, '_, H>, args: EmitOpArgs<'c, '_, OpType, H>, @@ -330,7 +396,7 @@ fn emit_optype<'c, H: HugrView>( context.push_todo_func(node.into_ot(fd)); Ok(()) } - + OpType::TailLoop(x) => emit_tail_loop(context, args.into_ot(x)), _ => Err(anyhow!("Invalid child for Dataflow Parent: {node}")), } }