Skip to content

Commit

Permalink
fix BOOL_T -> bool_t() and types needing extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Dec 4, 2024
1 parent c34bab7 commit d387cef
Showing 1 changed file with 32 additions and 27 deletions.
59 changes: 32 additions & 27 deletions hugr-passes/src/dataflow/test.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use ascent::{lattice::BoundedLattice, Lattice};

use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder};
use hugr_core::extension::PRELUDE_REGISTRY;
use hugr_core::hugr::views::{DescendantsGraph, HierarchyView};
use hugr_core::ops::handle::DfgID;
use hugr_core::ops::TailLoop;
use hugr_core::types::TypeRow;
use hugr_core::{
builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer},
extension::{
prelude::{UnpackTuple, BOOL_T},
prelude::{bool_t, UnpackTuple},
ExtensionSet, EMPTY_REG,
},
ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value},
Expand Down Expand Up @@ -64,7 +66,7 @@ fn test_make_tuple() {
let v1 = builder.add_load_value(Value::false_val());
let v2 = builder.add_load_value(Value::true_val());
let v3 = builder.make_tuple([v1, v2]).unwrap();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap();

let results = Machine::default().run(TestContext(hugr), []);

Expand All @@ -77,10 +79,10 @@ fn test_unpack_tuple_const() {
let mut builder = DFGBuilder::new(endo_sig(vec![])).unwrap();
let v = builder.add_load_value(Value::tuple([Value::false_val(), Value::true_val()]));
let [o1, o2] = builder
.add_dataflow_op(UnpackTuple::new(type_row![BOOL_T, BOOL_T]), [v])
.add_dataflow_op(UnpackTuple::new(vec![bool_t(); 2].into()), [v])
.unwrap()
.outputs_arr();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap();

let results = Machine::default().run(TestContext(hugr), []);

Expand Down Expand Up @@ -125,14 +127,14 @@ fn test_tail_loop_always_iterates() {
Value::sum(
TailLoop::CONTINUE_TAG,
[],
SumType::new([type_row![], BOOL_T.into()]),
SumType::new([type_row![], bool_t().into()]),
)
.unwrap(),
);
let true_w = builder.add_load_value(Value::true_val());

let tlb = builder
.tail_loop_builder([], [(BOOL_T, true_w)], vec![BOOL_T].into())
.tail_loop_builder([], [(bool_t(), true_w)], vec![bool_t()].into())
.unwrap();

// r_w has tag 0, so we always continue;
Expand Down Expand Up @@ -166,14 +168,14 @@ fn test_tail_loop_two_iters() {
let tlb = builder
.tail_loop_builder_exts(
[],
[(BOOL_T, false_w), (BOOL_T, true_w)],
[(bool_t(), false_w), (bool_t(), true_w)],
type_row![],
ExtensionSet::new(),
)
.unwrap();
assert_eq!(
tlb.loop_signature().unwrap().signature(),
Signature::new_endo(type_row![BOOL_T, BOOL_T])
Signature::new_endo(vec![bool_t(); 2])
);
let [in_w1, in_w2] = tlb.input_wires_arr();
let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap();
Expand All @@ -197,9 +199,9 @@ fn test_tail_loop_two_iters() {
#[test]
fn test_tail_loop_containing_conditional() {
let mut builder = DFGBuilder::new(Signature::new_endo(vec![])).unwrap();
let control_variants = vec![type_row![BOOL_T;2]; 2];
let control_variants = vec![vec![bool_t(); 2].into(); 2];
let control_t = Type::new_sum(control_variants.clone());
let body_out_variants = vec![control_t.clone().into(), type_row![BOOL_T; 2]];
let body_out_variants = vec![TypeRow::from(control_t.clone()), vec![bool_t(); 2].into()];

let init = builder.add_load_value(
Value::sum(
Expand All @@ -211,7 +213,7 @@ fn test_tail_loop_containing_conditional() {
);

let mut tlb = builder
.tail_loop_builder([(control_t, init)], [], type_row![BOOL_T; 2])
.tail_loop_builder([(control_t, init)], [], vec![bool_t(); 2].into())
.unwrap();
let tl = tlb.loop_signature().unwrap().clone();
let [in_w] = tlb.input_wires_arr();
Expand Down Expand Up @@ -259,7 +261,7 @@ fn test_tail_loop_containing_conditional() {

#[test]
fn test_conditional() {
let variants = vec![type_row![], type_row![], type_row![BOOL_T]];
let variants = vec![type_row![], type_row![], bool_t().into()];
let cond_t = Type::new_sum(variants.clone());
let mut builder = DFGBuilder::new(Signature::new(cond_t, type_row![])).unwrap();
let [arg_w] = builder.input_wires_arr();
Expand All @@ -270,8 +272,8 @@ fn test_conditional() {
let mut cond_builder = builder
.conditional_builder(
(variants, arg_w),
[(BOOL_T, true_w)],
type_row!(BOOL_T, BOOL_T),
[(bool_t(), true_w)],
vec![bool_t(); 2].into(),
)
.unwrap();
// will be unreachable
Expand Down Expand Up @@ -325,21 +327,21 @@ fn xor_and_cfg() -> Hugr {
// T,F T,F - T,F
// T,T T,T T,F F,T
let mut builder =
CFGBuilder::new(Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T; 2])).unwrap();
CFGBuilder::new(Signature::new(vec![bool_t(); 2], vec![bool_t(); 2])).unwrap();

// entry (x, y) => (if x then A else B)(x=true, y)
let entry = builder
.entry_builder(vec![type_row![]; 2], type_row![BOOL_T;2])
.entry_builder(vec![type_row![]; 2], vec![bool_t(); 2].into())
.unwrap();
let [in_x, in_y] = entry.input_wires_arr();
let entry = entry.finish_with_outputs(in_x, [in_x, in_y]).unwrap();

// A(x==true, y) => (if y then B else X)(x, false)
let mut a = builder
.block_builder(
type_row![BOOL_T; 2],
vec![bool_t(); 2].into(),
vec![type_row![]; 2],
type_row![BOOL_T; 2],
vec![bool_t(); 2].into(),
)
.unwrap();
let [in_x, in_y] = a.input_wires_arr();
Expand All @@ -348,7 +350,11 @@ fn xor_and_cfg() -> Hugr {

// B(w, v) => X(v, w)
let mut b = builder
.block_builder(type_row![BOOL_T; 2], [type_row![]], type_row![BOOL_T; 2])
.block_builder(
vec![bool_t(); 2].into(),
[type_row![]],
vec![bool_t(); 2].into(),
)
.unwrap();
let [in_w, in_v] = b.input_wires_arr();
let [control] = b
Expand Down Expand Up @@ -407,9 +413,9 @@ fn test_call(
#[case] inp1: PartialValue<Void>,
#[case] out: PartialValue<Void>,
) {
let mut builder = DFGBuilder::new(Signature::new_endo(type_row![BOOL_T; 2])).unwrap();
let mut builder = DFGBuilder::new(Signature::new_endo(vec![bool_t(); 2])).unwrap();
let func_bldr = builder
.define_function("id", Signature::new_endo(BOOL_T))
.define_function("id", Signature::new_endo(bool_t()))
.unwrap();
let [v] = func_bldr.input_wires_arr();
let func_defn = func_bldr.finish_with_outputs([v]).unwrap();
Expand All @@ -436,12 +442,11 @@ fn test_call(

#[test]
fn test_region() {
let mut builder =
DFGBuilder::new(Signature::new(type_row![BOOL_T], type_row![BOOL_T;2])).unwrap();
let mut builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t(); 2])).unwrap();
let [in_w] = builder.input_wires_arr();
let cst_w = builder.add_load_const(Value::false_val());
let nested = builder
.dfg_builder(Signature::new_endo(type_row![BOOL_T; 2]), [in_w, cst_w])
.dfg_builder(Signature::new_endo(vec![bool_t(); 2]), [in_w, cst_w])
.unwrap();
let nested_ins = nested.input_wires();
let nested = nested.finish_with_outputs(nested_ins).unwrap();
Expand Down Expand Up @@ -490,13 +495,13 @@ fn test_region() {
fn test_module() {
let mut modb = ModuleBuilder::new();
let leaf_fn = modb
.define_function("leaf", Signature::new_endo(type_row![BOOL_T; 2]))
.define_function("leaf", Signature::new_endo(vec![bool_t(); 2]))
.unwrap();
let outs = leaf_fn.input_wires();
let leaf_fn = leaf_fn.finish_with_outputs(outs).unwrap();

let mut f2 = modb
.define_function("f2", Signature::new(BOOL_T, type_row![BOOL_T; 2]))
.define_function("f2", Signature::new(bool_t(), vec![bool_t(); 2]))
.unwrap();
let [inp] = f2.input_wires_arr();
let cst_true = f2.add_load_value(Value::true_val());
Expand All @@ -506,7 +511,7 @@ fn test_module() {
let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap();

let mut main = modb
.define_function("main", Signature::new(BOOL_T, type_row![BOOL_T; 2]))
.define_function("main", Signature::new(bool_t(), vec![bool_t(); 2]))
.unwrap();
let [inp] = main.input_wires_arr();
let cst_false = main.add_load_value(Value::false_val());
Expand Down

0 comments on commit d387cef

Please sign in to comment.