Skip to content

Commit

Permalink
fix tests after rebasing
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Dec 10, 2024
1 parent 588f09c commit 5360763
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 42 deletions.
10 changes: 2 additions & 8 deletions hugr-core/src/extension/resolution/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::extension::resolution::{
resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError,
};
use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet};
use crate::ops::{CallIndirect, ExtensionOp, Input, OpType, Tag, Value};
use crate::ops::{CallIndirect, ExtensionOp, Input, OpTrait, OpType, Tag, Value};
use crate::std_extensions::arithmetic::float_types::float64_type;
use crate::std_extensions::arithmetic::int_ops;
use crate::std_extensions::arithmetic::int_types::{self, int_type};
Expand Down Expand Up @@ -222,12 +222,6 @@ fn resolve_hugr_extensions() {
#[rstest]
fn dropped_weak_extensions() {
let (ext_a, op_a) = make_extension("dummy.a", "op_a");
let build_extensions = ExtensionRegistry::new([
PRELUDE.to_owned(),
ext_a.clone(),
float_types::EXTENSION.to_owned(),
]);

let mut func = FunctionBuilder::new(
"dummy_fn",
Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta(
Expand All @@ -241,7 +235,7 @@ fn dropped_weak_extensions() {
let [_func_i0, func_i1] = func.input_wires_arr();
func.add_dataflow_op(op_a, vec![func_i1]).unwrap();

let hugr = func.finish_hugr(&build_extensions).unwrap();
let hugr = func.finish_hugr().unwrap();

// Do a serialization roundtrip to drop the references.
let ser = serde_json::to_string(&hugr).unwrap();
Expand Down
12 changes: 3 additions & 9 deletions hugr-llvm/src/extension/prelude/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,7 @@ mod test {
.unwrap();
let v = func.add_load_value(ConstInt::new_u(6, value).unwrap());
let func_id = func.finish_with_outputs(vec![v]).unwrap();
let func_v = builder
.load_func(func_id.handle(), &[], &exec_registry())
.unwrap();
let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
let repeat = ArrayRepeat::new(int_ty.clone(), size, exec_extension_set());
let arr = builder
.add_dataflow_op(repeat, vec![func_v])
Expand Down Expand Up @@ -1023,9 +1021,7 @@ mod test {
let delta = func.add_load_value(ConstInt::new_u(6, inc).unwrap());
let out = func.add_iadd(6, elem, delta).unwrap();
let func_id = func.finish_with_outputs(vec![out]).unwrap();
let func_v = builder
.load_func(func_id.handle(), &[], &exec_registry())
.unwrap();
let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
let scan = ArrayScan::new(
int_ty.clone(),
int_ty.clone(),
Expand Down Expand Up @@ -1102,9 +1098,7 @@ mod test {
.unwrap()
.out_wire(0);
let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap();
let func_v = builder
.load_func(func_id.handle(), &[], &exec_registry())
.unwrap();
let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
let scan = ArrayScan::new(
int_ty.clone(),
Type::UNIT,
Expand Down
41 changes: 16 additions & 25 deletions hugr-passes/src/dataflow/test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use ascent::{lattice::BoundedLattice, Lattice};

use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder};
use hugr_core::extension::PRELUDE_REGISTRY;
use hugr_core::hugr::views::{DescendantsGraph, HierarchyView};
use hugr_core::ops::handle::DfgID;
use hugr_core::ops::TailLoop;
Expand All @@ -10,7 +9,7 @@ use hugr_core::{
builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer},
extension::{
prelude::{bool_t, UnpackTuple},
ExtensionSet, EMPTY_REG,
ExtensionSet,
},
ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value},
type_row,
Expand Down Expand Up @@ -58,7 +57,7 @@ fn test_make_tuple() {
let v1 = builder.add_load_value(Value::false_val());
let v2 = builder.add_load_value(Value::true_val());
let v3 = builder.make_tuple([v1, v2]).unwrap();
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap();
let hugr = builder.finish_hugr().unwrap();

let results = Machine::new(&hugr).run(TestContext, []);

Expand All @@ -74,7 +73,7 @@ fn test_unpack_tuple_const() {
.add_dataflow_op(UnpackTuple::new(vec![bool_t(); 2].into()), [v])
.unwrap()
.outputs_arr();
let hugr = builder.finish_hugr(&PRELUDE_REGISTRY).unwrap();
let hugr = builder.finish_hugr().unwrap();

let results = Machine::new(&hugr).run(TestContext, []);

Expand All @@ -100,7 +99,7 @@ fn test_tail_loop_never_iterates() {
.unwrap();
let tail_loop = tlb.finish_with_outputs(tagged.out_wire(0), []).unwrap();
let [tl_o] = tail_loop.outputs_arr();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let hugr = builder.finish_hugr().unwrap();

let results = Machine::new(&hugr).run(TestContext, []);

Expand Down Expand Up @@ -135,7 +134,7 @@ fn test_tail_loop_always_iterates() {
let tail_loop = tlb.finish_with_outputs(r_w, [true_w]).unwrap();

let [tl_o1, tl_o2] = tail_loop.outputs_arr();
let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let hugr = builder.finish_hugr().unwrap();

let results = Machine::new(&hugr).run(TestContext, []);

Expand Down Expand Up @@ -172,7 +171,7 @@ fn test_tail_loop_two_iters() {
let [in_w1, in_w2] = tlb.input_wires_arr();
let tail_loop = tlb.finish_with_outputs(in_w1, [in_w2, in_w1]).unwrap();

let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let hugr = builder.finish_hugr().unwrap();
let [o_w1, o_w2] = tail_loop.outputs_arr();

let results = Machine::new(&hugr).run(TestContext, []);
Expand Down Expand Up @@ -235,7 +234,7 @@ fn test_tail_loop_containing_conditional() {

let tail_loop = tlb.finish_with_outputs(r, []).unwrap();

let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let hugr = builder.finish_hugr().unwrap();
let [o_w1, o_w2] = tail_loop.outputs_arr();

let results = Machine::new(&hugr).run(TestContext, []);
Expand Down Expand Up @@ -284,7 +283,7 @@ fn test_conditional() {

let [cond_o1, cond_o2] = cond.outputs_arr();

let hugr = builder.finish_hugr(&EMPTY_REG).unwrap();
let hugr = builder.finish_hugr().unwrap();

let arg_pv = PartialValue::new_variant(1, []).join(PartialValue::new_variant(
2,
Expand Down Expand Up @@ -363,7 +362,7 @@ fn xor_and_cfg() -> Hugr {
builder.branch(&a, tru, &b).unwrap(); // if true
builder.branch(&a, fals, &x).unwrap(); // if false
builder.branch(&b, 0, &x).unwrap();
builder.finish_hugr(&EMPTY_REG).unwrap()
builder.finish_hugr().unwrap()
}

#[rstest]
Expand Down Expand Up @@ -410,16 +409,14 @@ fn test_call(
let func_defn = func_bldr.finish_with_outputs([v]).unwrap();
let [a, b] = builder.input_wires_arr();
let [a2] = builder
.call(func_defn.handle(), &[], [a], &EMPTY_REG)
.call(func_defn.handle(), &[], [a])
.unwrap()
.outputs_arr();
let [b2] = builder
.call(func_defn.handle(), &[], [b], &EMPTY_REG)
.call(func_defn.handle(), &[], [b])
.unwrap()
.outputs_arr();
let hugr = builder
.finish_hugr_with_outputs([a2, b2], &EMPTY_REG)
.unwrap();
let hugr = builder.finish_hugr_with_outputs([a2, b2]).unwrap();

let results = Machine::new(&hugr).run(TestContext, [(0.into(), inp0), (1.into(), inp1)]);

Expand All @@ -439,9 +436,7 @@ fn test_region() {
.unwrap();
let nested_ins = nested.input_wires();
let nested = nested.finish_with_outputs(nested_ins).unwrap();
let hugr = builder
.finish_prelude_hugr_with_outputs(nested.outputs())
.unwrap();
let hugr = builder.finish_hugr_with_outputs(nested.outputs()).unwrap();
let [nested_input, _] = hugr.get_io(nested.node()).unwrap();
let whole_hugr_results = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]);
assert_eq!(
Expand Down Expand Up @@ -494,21 +489,17 @@ fn test_module() {
.unwrap();
let [inp] = f2.input_wires_arr();
let cst_true = f2.add_load_value(Value::true_val());
let f2_call = f2
.call(leaf_fn.handle(), &[], [inp, cst_true], &EMPTY_REG)
.unwrap();
let f2_call = f2.call(leaf_fn.handle(), &[], [inp, cst_true]).unwrap();
let f2 = f2.finish_with_outputs(f2_call.outputs()).unwrap();

let mut main = modb
.define_function("main", Signature::new(bool_t(), vec![bool_t(); 2]))
.unwrap();
let [inp] = main.input_wires_arr();
let cst_false = main.add_load_value(Value::false_val());
let main_call = main
.call(leaf_fn.handle(), &[], [inp, cst_false], &EMPTY_REG)
.unwrap();
let main_call = main.call(leaf_fn.handle(), &[], [inp, cst_false]).unwrap();
main.finish_with_outputs(main_call.outputs()).unwrap();
let hugr = modb.finish_hugr(&EMPTY_REG).unwrap();
let hugr = modb.finish_hugr().unwrap();
let [f2_inp, _] = hugr.get_io(f2.node()).unwrap();

let results_just_main = Machine::new(&hugr).run(TestContext, [(0.into(), pv_true())]);
Expand Down

0 comments on commit 5360763

Please sign in to comment.