diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 632224946..7637ceb5c 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -4,14 +4,14 @@ use anyhow::{anyhow, Result}; use hugr_core::builder::{ BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer, }; -use hugr_core::extension::prelude::PRELUDE_ID; +use hugr_core::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, USIZE_T}; use hugr_core::extension::{ExtensionRegistry, ExtensionSet, EMPTY_REG}; use hugr_core::ops::handle::FuncID; use hugr_core::std_extensions::arithmetic::{ conversions, float_ops, float_types, int_ops, int_types, }; use hugr_core::std_extensions::{collections, logic}; -use hugr_core::types::TypeRow; +use hugr_core::types::{SumType, TypeRow}; use hugr_core::{Hugr, HugrView}; use inkwell::module::Module; use inkwell::passes::PassManager; @@ -551,6 +551,125 @@ mod test_fns { check_emission!(hugr, llvm_ctx); } + #[rstest] + fn tail_loop_simple(mut llvm_ctx: TestContext) { + // Infinite loop + let hugr = { + let just_input = USIZE_T; + let just_output = Type::UNIT; + let sum_ty = SumType::new(vec![just_input.clone(), just_output.clone()]); + let input_v = TypeRow::from(vec![just_input.clone()]); + let output_v = TypeRow::from(vec![just_output.clone()]); + + llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions); + + + SimpleHugrConfig::new() + .with_extensions(PRELUDE_REGISTRY.clone()) + .with_ins(input_v) + .with_outs(output_v) + .finish(|mut builder: DFGW| { + let [just_in_w] = builder.input_wires_arr(); + let mut tail_b = builder.tail_loop_builder([(just_input.clone(), just_in_w)], [], vec![just_output.clone()].into()).unwrap(); + + let input = tail_b.input(); + let [inp_w] = input.outputs_arr(); + + let loop_sig = tail_b.loop_signature().unwrap().clone(); + + // builder.add_dataflow_op(ops::Noop, input_wires) + + let sum_inp_w = tail_b.make_continue(loop_sig.clone(), [inp_w]).unwrap(); + + let outs@[_] = tail_b.finish_with_outputs(sum_inp_w, []).unwrap().outputs_arr(); + builder.finish_with_outputs(outs).unwrap() + }) + }; + check_emission!(hugr, llvm_ctx); + + } + + #[rstest] + fn tail_loop(mut llvm_ctx: TestContext) { + // Implement the hugr equivalent of the program: + let hugr = { + let int_ty = int_types::int_type(6); + let just_input = int_ty.clone(); + let just_output = Type::UNIT; + let other_ty = int_ty.clone(); + let sum_ty = SumType::new(vec![just_input.clone(), just_output.clone()]); + let input_v = TypeRow::from(vec![just_input.clone(), other_ty.clone()]); + let output_v = TypeRow::from(vec![just_output.clone(), other_ty.clone()]); + + llvm_ctx.add_extensions(add_int_extensions); + + let mut registry = PRELUDE_REGISTRY.clone(); + registry.register(int_ops::EXTENSION.clone()).unwrap(); + registry.register(int_types::EXTENSION.clone()).unwrap(); + + SimpleHugrConfig::new() + .with_extensions(registry) + .with_ins(input_v) + .with_outs(output_v) + .finish(|mut builder: DFGW| { + let [just_in_w, other_w] = builder.input_wires_arr(); + let mut tail_b = builder.tail_loop_builder([(just_input.clone(), just_in_w)], [(other_ty.clone(), other_w)], vec![just_output.clone()].into()).unwrap(); + let [sum_inp_w, other_w] = tail_b.input_wires_arr(); + + let zero = ConstInt::new_u(6, 0).unwrap(); + let zero_w = tail_b.add_load_value(zero); + let [result] = tail_b.add_dataflow_op(int_ops::IntOpDef::ile_u.with_log_width(6), [sum_inp_w, zero_w]).unwrap().outputs_arr(); + let input = tail_b.input(); + let [inp_w, other_w] = input.outputs_arr(); + + let loop_sig = tail_b.loop_signature().unwrap().clone(); + + let sum_inp_w = tail_b.make_continue(loop_sig.clone(), [inp_w]).unwrap(); + + let cond = { + let mut cond_b = tail_b + .conditional_builder( + ([just_input.into(), just_output.into()], sum_inp_w), + vec![(other_ty.clone(), other_w)], + vec![sum_ty.into(), other_ty.clone()].into(), + ) + .unwrap(); + + // If the check is false, we add 1 and continue + let mut false_case_b = cond_b.case_builder(0).unwrap(); + let [counter, val] = false_case_b.input_wires_arr(); + let one = ConstInt::new_u(6, 1).unwrap(); + let one_w = false_case_b.add_load_value(one); + + let two = ConstInt::new_u(6, 2).unwrap(); + let two_w = false_case_b.add_load_value(two); + let [val] = false_case_b.add_dataflow_op(int_ops::IntOpDef::imul.with_log_width(6), [val, two_w]).unwrap().outputs_arr(); + + let [counter] = false_case_b.add_dataflow_op(int_ops::IntOpDef::isub.with_log_width(6), [counter, one_w]).unwrap().outputs_arr(); + let tagged_counter = false_case_b.make_continue(loop_sig.clone(), [counter]).unwrap(); + + false_case_b.finish_with_outputs([tagged_counter, val]).unwrap(); + + // In the true case, we break and output true along with the "other" input wire + let mut true_case_b = cond_b.case_builder(1).unwrap(); + + let [_, val_w] = true_case_b.input_wires_arr(); + let unit_v = Value::unit_sum(0, 1).unwrap(); + let unit_w = true_case_b.add_load_value(unit_v); + let tagged_output = true_case_b.make_break(loop_sig.clone(), [unit_w]).unwrap(); + true_case_b.finish_with_outputs([tagged_output, val_w]).unwrap(); + + cond_b.finish_sub_container().unwrap() + }; + let [sum, rest] = cond.outputs_arr(); + let outs@[_,_] = tail_b.finish_with_outputs(sum, [rest]).unwrap().outputs_arr(); + builder.finish_with_outputs(outs).unwrap() + }) + }; + check_emission!(hugr, llvm_ctx); + + } + #[rstest] fn test_exec(mut exec_ctx: TestContext) { let hugr = SimpleHugrConfig::new()