Skip to content

Commit

Permalink
tests: Add tail loop emission insta tests
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Dec 10, 2024
1 parent 63aa1f8 commit 3869c2f
Showing 1 changed file with 121 additions and 2 deletions.
123 changes: 121 additions & 2 deletions hugr-llvm/src/emit/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use anyhow::{anyhow, Result};
use hugr_core::builder::{
BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer,
};
use hugr_core::extension::prelude::PRELUDE_ID;
use hugr_core::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, USIZE_T};
use hugr_core::extension::{ExtensionRegistry, ExtensionSet, EMPTY_REG};
use hugr_core::ops::handle::FuncID;
use hugr_core::std_extensions::arithmetic::{
conversions, float_ops, float_types, int_ops, int_types,
};
use hugr_core::std_extensions::{collections, logic};
use hugr_core::types::TypeRow;
use hugr_core::types::{SumType, TypeRow};
use hugr_core::{Hugr, HugrView};
use inkwell::module::Module;
use inkwell::passes::PassManager;
Expand Down Expand Up @@ -551,6 +551,125 @@ mod test_fns {
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn tail_loop_simple(mut llvm_ctx: TestContext) {
// Infinite loop
let hugr = {
let just_input = USIZE_T;
let just_output = Type::UNIT;
let sum_ty = SumType::new(vec![just_input.clone(), just_output.clone()]);
let input_v = TypeRow::from(vec![just_input.clone()]);
let output_v = TypeRow::from(vec![just_output.clone()]);

llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);


SimpleHugrConfig::new()
.with_extensions(PRELUDE_REGISTRY.clone())
.with_ins(input_v)
.with_outs(output_v)
.finish(|mut builder: DFGW| {
let [just_in_w] = builder.input_wires_arr();
let mut tail_b = builder.tail_loop_builder([(just_input.clone(), just_in_w)], [], vec![just_output.clone()].into()).unwrap();

let input = tail_b.input();
let [inp_w] = input.outputs_arr();

let loop_sig = tail_b.loop_signature().unwrap().clone();

// builder.add_dataflow_op(ops::Noop, input_wires)

let sum_inp_w = tail_b.make_continue(loop_sig.clone(), [inp_w]).unwrap();

let outs@[_] = tail_b.finish_with_outputs(sum_inp_w, []).unwrap().outputs_arr();
builder.finish_with_outputs(outs).unwrap()
})
};
check_emission!(hugr, llvm_ctx);

}

#[rstest]
fn tail_loop(mut llvm_ctx: TestContext) {
// Implement the hugr equivalent of the program:
let hugr = {
let int_ty = int_types::int_type(6);
let just_input = int_ty.clone();
let just_output = Type::UNIT;
let other_ty = int_ty.clone();
let sum_ty = SumType::new(vec![just_input.clone(), just_output.clone()]);
let input_v = TypeRow::from(vec![just_input.clone(), other_ty.clone()]);
let output_v = TypeRow::from(vec![just_output.clone(), other_ty.clone()]);

llvm_ctx.add_extensions(add_int_extensions);

let mut registry = PRELUDE_REGISTRY.clone();
registry.register(int_ops::EXTENSION.clone()).unwrap();
registry.register(int_types::EXTENSION.clone()).unwrap();

SimpleHugrConfig::new()
.with_extensions(registry)
.with_ins(input_v)
.with_outs(output_v)
.finish(|mut builder: DFGW| {
let [just_in_w, other_w] = builder.input_wires_arr();
let mut tail_b = builder.tail_loop_builder([(just_input.clone(), just_in_w)], [(other_ty.clone(), other_w)], vec![just_output.clone()].into()).unwrap();
let [sum_inp_w, other_w] = tail_b.input_wires_arr();

let zero = ConstInt::new_u(6, 0).unwrap();
let zero_w = tail_b.add_load_value(zero);
let [result] = tail_b.add_dataflow_op(int_ops::IntOpDef::ile_u.with_log_width(6), [sum_inp_w, zero_w]).unwrap().outputs_arr();
let input = tail_b.input();
let [inp_w, other_w] = input.outputs_arr();

let loop_sig = tail_b.loop_signature().unwrap().clone();

let sum_inp_w = tail_b.make_continue(loop_sig.clone(), [inp_w]).unwrap();

let cond = {
let mut cond_b = tail_b
.conditional_builder(
([just_input.into(), just_output.into()], sum_inp_w),
vec![(other_ty.clone(), other_w)],
vec![sum_ty.into(), other_ty.clone()].into(),
)
.unwrap();

// If the check is false, we add 1 and continue
let mut false_case_b = cond_b.case_builder(0).unwrap();
let [counter, val] = false_case_b.input_wires_arr();
let one = ConstInt::new_u(6, 1).unwrap();
let one_w = false_case_b.add_load_value(one);

let two = ConstInt::new_u(6, 2).unwrap();
let two_w = false_case_b.add_load_value(two);
let [val] = false_case_b.add_dataflow_op(int_ops::IntOpDef::imul.with_log_width(6), [val, two_w]).unwrap().outputs_arr();

let [counter] = false_case_b.add_dataflow_op(int_ops::IntOpDef::isub.with_log_width(6), [counter, one_w]).unwrap().outputs_arr();
let tagged_counter = false_case_b.make_continue(loop_sig.clone(), [counter]).unwrap();

false_case_b.finish_with_outputs([tagged_counter, val]).unwrap();

// In the true case, we break and output true along with the "other" input wire
let mut true_case_b = cond_b.case_builder(1).unwrap();

let [_, val_w] = true_case_b.input_wires_arr();
let unit_v = Value::unit_sum(0, 1).unwrap();
let unit_w = true_case_b.add_load_value(unit_v);
let tagged_output = true_case_b.make_break(loop_sig.clone(), [unit_w]).unwrap();
true_case_b.finish_with_outputs([tagged_output, val_w]).unwrap();

cond_b.finish_sub_container().unwrap()
};
let [sum, rest] = cond.outputs_arr();
let outs@[_,_] = tail_b.finish_with_outputs(sum, [rest]).unwrap().outputs_arr();
builder.finish_with_outputs(outs).unwrap()
})
};
check_emission!(hugr, llvm_ctx);

}

#[rstest]
fn test_exec(mut exec_ctx: TestContext) {
let hugr = SimpleHugrConfig::new()
Expand Down

0 comments on commit 3869c2f

Please sign in to comment.