From 4e6e05d8034a1755d3215aaed4ed5985f1658c24 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Mon, 20 May 2024 11:51:25 +0100 Subject: [PATCH 1/2] feat!: Remove `PartialEq` impl for `ConstF64` Comparing floats for equality is notoriously difficult. So we don't. Note that we change the serialisation round-tripping checks a bit, because `Const` nodes containt `CustomConst`s that do not implement `equals_const` will break equality checks on the Hugrs. BREAKING CHANGE: Remove `PartialEq` impl for `ConstF64` --- hugr/src/algorithm/const_fold.rs | 15 +++-- hugr/src/hugr/serialize/test.rs | 63 ++++++++++--------- hugr/src/ops/constant.rs | 5 +- .../std_extensions/arithmetic/float_types.rs | 9 +-- hugr/src/utils.rs | 11 +++- 5 files changed, 58 insertions(+), 45 deletions(-) diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index f0d66e6f6..5d4cffed3 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -253,7 +253,7 @@ mod test { use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use crate::std_extensions::logic::{self, NaryLogic, NotOp}; - use crate::utils::test::assert_fully_folded; + use crate::utils::test::{assert_fully_folded, assert_fully_folded_with}; use rstest::rstest; @@ -275,9 +275,13 @@ mod test { fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; let add_op: OpType = FloatOps::fadd.into(); - let out = fold_leaf_op(&add_op, &consts).unwrap(); + let outs = fold_leaf_op(&add_op, &consts) + .unwrap() + .into_iter() + .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) + .collect_vec(); - assert_eq!(&out[..], &[(0.into(), f2c(c))]); + assert_eq!(outs.as_slice(), &[(0.into(), c)]); } #[test] fn test_big() { @@ -529,8 +533,9 @@ mod test { let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); constant_fold_pass(&mut h0, ®); - let expected = Value::extension(ConstF64::new(1.0)); - assert_fully_folded(&h0, &expected); + assert_fully_folded_with(&h0, |v| { + v.get_custom_value::().unwrap().value() == 1.0 + }); assert_eq!(h0.node_count(), 5); // HUGR computing 1.0 / 0.0 diff --git a/hugr/src/hugr/serialize/test.rs b/hugr/src/hugr/serialize/test.rs index 12bfdb98f..049e3a096 100644 --- a/hugr/src/hugr/serialize/test.rs +++ b/hugr/src/hugr/serialize/test.rs @@ -9,8 +9,8 @@ use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG}; -use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; -use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; +use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE; +use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY; use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES}; use crate::std_extensions::logic::NotOp; use crate::types::{ @@ -94,20 +94,14 @@ impl_sertesting_from!(NodeSer, optype); #[test] fn empty_hugr_serialize() { - let hg = Hugr::default(); - assert_eq!(ser_roundtrip(&hg), hg); -} - -/// Serialize and deserialize a value. -pub fn ser_roundtrip(g: &T) -> T { - ser_roundtrip_validate(g, None) + check_hugr_roundtrip(&Hugr::default(), true); } /// Serialize and deserialize a value, optionally validating against a schema. -pub fn ser_roundtrip_validate( +pub fn ser_serialize_check_schema( g: &T, schema: Option<&JSONSchema>, -) -> T { +) -> serde_json::Value { let s = serde_json::to_string(g).unwrap(); let val: serde_json::Value = serde_json::from_str(&s).unwrap(); @@ -123,7 +117,7 @@ pub fn ser_roundtrip_validate( panic!("Serialization test failed."); } } - serde_json::from_value(val).unwrap() + val } /// Serialize and deserialize a HUGR, and check that the result is the same as the original. @@ -138,13 +132,15 @@ pub fn check_hugr_schema_roundtrip(hugr: &Hugr) -> Hugr { /// /// If `check_schema` is true, checks the serialized json against the in-tree schema. /// +/// Note that we do not literally compare the before and after `Hugr`s for +/// equality, because impls of `CustomConst` are not required to implement +/// equality checking. +/// /// Returns the deserialized HUGR. pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr { - let new_hugr: Hugr = ser_roundtrip_validate(hugr, check_schema.then_some(&SCHEMA)); - let new_hugr_strict: Hugr = - ser_roundtrip_validate(hugr, check_schema.then_some(&SCHEMA_STRICT)); - assert_eq!(new_hugr, new_hugr_strict); - + let hugr_ser = ser_serialize_check_schema(hugr, check_schema.then_some(&SCHEMA)); + let _ = ser_serialize_check_schema(hugr, check_schema.then_some(&SCHEMA_STRICT)); + let new_hugr: Hugr = serde_json::from_value(hugr_ser).unwrap(); // Original HUGR, with canonicalized node indices // // The internal port indices may still be different. @@ -159,7 +155,9 @@ pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr { for node in new_hugr.nodes() { let new_op = new_hugr.get_optype(node); let old_op = h_canon.get_optype(node); - assert_eq!(new_op, old_op); + if !new_op.is_const() { + assert_eq!(new_op, old_op); + } } // Check that the graphs are equivalent up to port renumbering. @@ -182,8 +180,13 @@ pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr { fn check_testing_roundtrip(t: impl Into) { let before = Versioned::new(t.into()); - let after_strict = ser_roundtrip_validate(&before, Some(&TESTING_SCHEMA_STRICT)); - let after = ser_roundtrip_validate(&before, Some(&TESTING_SCHEMA)); + let after_strict = serde_json::from_value(ser_serialize_check_schema( + &before, + Some(&TESTING_SCHEMA_STRICT), + )) + .unwrap(); + let after = + serde_json::from_value(ser_serialize_check_schema(&before, Some(&TESTING_SCHEMA))).unwrap(); assert_eq!(before, after); assert_eq!(after, after_strict); } @@ -347,9 +350,10 @@ fn hierarchy_order() -> Result<(), Box> { #[test] fn constants_roundtrip() -> Result<(), Box> { - let mut builder = DFGBuilder::new(FunctionType::new(vec![], vec![FLOAT64_TYPE])).unwrap(); - let w = builder.add_load_value(ConstF64::new(0.5)); - let hugr = builder.finish_hugr_with_outputs([w], &FLOAT_OPS_REGISTRY)?; + let mut builder = + DFGBuilder::new(FunctionType::new(vec![], vec![INT_TYPES[4].clone()])).unwrap(); + let w = builder.add_load_value(ConstInt::new_s(4, -2).unwrap()); + let hugr = builder.finish_hugr_with_outputs([w], &INT_OPS_REGISTRY)?; let ser = serde_json::to_string(&hugr)?; let deser = serde_json::from_str(&ser)?; @@ -363,18 +367,18 @@ fn constants_roundtrip() -> Result<(), Box> { fn serialize_types_roundtrip() { let g: Type = Type::new_function(FunctionType::new_endo(vec![])); - assert_eq!(ser_roundtrip(&g), g); + check_testing_roundtrip(g.clone()); // A Simple tuple let t = Type::new_tuple(vec![USIZE_T, g]); - assert_eq!(ser_roundtrip(&t), t); + check_testing_roundtrip(t); // A Classic sum let t = Type::new_sum([type_row![USIZE_T], type_row![FLOAT64_TYPE]]); - assert_eq!(ser_roundtrip(&t), t); + check_testing_roundtrip(t); let t = Type::new_unit_sum(4); - assert_eq!(ser_roundtrip(&t), t); + check_testing_roundtrip(t); } #[rstest] @@ -401,10 +405,7 @@ fn roundtrip_sumtype(#[case] sum_type: SumType) { #[case(Value::unit())] #[case(Value::true_val())] #[case(Value::unit_sum(3,5).unwrap())] -#[case(Value::extension(ConstF64::new(-1.5)))] -#[case(Value::extension(ConstF64::new(0.0)))] -#[case(Value::extension(ConstF64::new(-0.0)))] -#[case(Value::extension(ConstF64::new(f64::MIN_POSITIVE)))] +#[case(Value::extension(ConstInt::new_u(2,1).unwrap()))] #[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())] #[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))] #[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())] diff --git a/hugr/src/ops/constant.rs b/hugr/src/ops/constant.rs index 551415351..d0bd420d3 100644 --- a/hugr/src/ops/constant.rs +++ b/hugr/src/ops/constant.rs @@ -433,6 +433,7 @@ pub type ValueNameRef = str; mod test { use super::Value; use crate::builder::test::simple_dfg_hugr; + use crate::std_extensions::arithmetic::int_types::ConstInt; use crate::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, extension::{ @@ -605,9 +606,9 @@ mod test { const_usize.get_custom_value::(), Some(&ConstUsize::new(257)) ); - assert_eq!(const_usize.get_custom_value::(), None); + assert_eq!(const_usize.get_custom_value::(), None); assert_eq!(const_tuple.get_custom_value::(), None); - assert_eq!(const_tuple.get_custom_value::(), None); + assert_eq!(const_tuple.get_custom_value::(), None); } #[test] diff --git a/hugr/src/std_extensions/arithmetic/float_types.rs b/hugr/src/std_extensions/arithmetic/float_types.rs index 713e2e030..27cb21984 100644 --- a/hugr/src/std_extensions/arithmetic/float_types.rs +++ b/hugr/src/std_extensions/arithmetic/float_types.rs @@ -23,7 +23,7 @@ pub const FLOAT64_CUSTOM_TYPE: CustomType = /// 64-bit IEEE 754-2019 floating-point type (as [Type]) pub const FLOAT64_TYPE: Type = Type::new_extension(FLOAT64_CUSTOM_TYPE); -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] /// A floating-point value. pub struct ConstF64 { /// The value. @@ -64,8 +64,8 @@ impl CustomConst for ConstF64 { FLOAT64_TYPE } - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::ops::constant::downcast_equal_consts(self, other) + fn equal_consts(&self, _: &dyn CustomConst) -> bool { + false } fn extension_reqs(&self) -> ExtensionSet { @@ -110,8 +110,5 @@ mod test { assert_eq!(const_f64_1.value(), 1.0); assert_eq!(*const_f64_2, 2.0); assert_eq!(const_f64_1.name(), "f64(1)"); - assert!(const_f64_1.equal_consts(&ConstF64::new(1.0))); - assert_ne!(const_f64_1, const_f64_2); - assert_eq!(const_f64_1, ConstF64::new(1.0)); } } diff --git a/hugr/src/utils.rs b/hugr/src/utils.rs index 072f1105e..d7ce0e101 100644 --- a/hugr/src/utils.rs +++ b/hugr/src/utils.rs @@ -234,13 +234,22 @@ pub(crate) mod test { /// Check that a hugr just loads and returns a single expected constant. pub(crate) fn assert_fully_folded(h: &Hugr, expected_value: &Value) { + assert_fully_folded_with(h, |v| v == expected_value) + } + + /// Check that a hugr just loads and returns a single constant, and validate + /// that constant using `check_value`. + /// + /// [CustomConst::equals_const] is not required to be implemented. Use this + /// function for Values containing such a `CustomConst`. + pub(crate) fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { let mut node_count = 0; for node in h.children(h.root()) { let op = h.get_optype(node); match op { OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, - OpType::Const(c) if c.value() == expected_value => node_count += 1, + OpType::Const(c) if check_value(c.value()) => node_count += 1, _ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()), } } From b133cd73e06ec7d34e4690ec02d1e515c54cf301 Mon Sep 17 00:00:00 2001 From: Douglas Wilson Date: Tue, 21 May 2024 12:50:35 +0100 Subject: [PATCH 2/2] re-add removed assertion, negated --- hugr/src/std_extensions/arithmetic/float_types.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hugr/src/std_extensions/arithmetic/float_types.rs b/hugr/src/std_extensions/arithmetic/float_types.rs index 27cb21984..41ffbe45c 100644 --- a/hugr/src/std_extensions/arithmetic/float_types.rs +++ b/hugr/src/std_extensions/arithmetic/float_types.rs @@ -110,5 +110,7 @@ mod test { assert_eq!(const_f64_1.value(), 1.0); assert_eq!(*const_f64_2, 2.0); assert_eq!(const_f64_1.name(), "f64(1)"); + // ConstF64 does not support `equal_consts` + assert!(!const_f64_1.equal_consts(&ConstF64::new(1.0))); } }