Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Remove PartialEq impl for ConstF64 #1079

Merged
merged 3 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ mod test {
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use crate::std_extensions::logic::{self, NaryLogic, NotOp};
use crate::utils::test::assert_fully_folded;
use crate::utils::test::{assert_fully_folded, assert_fully_folded_with};

use rstest::rstest;

Expand All @@ -275,9 +275,13 @@ mod test {
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))];
let add_op: OpType = FloatOps::fadd.into();
let out = fold_leaf_op(&add_op, &consts).unwrap();
let outs = fold_leaf_op(&add_op, &consts)
.unwrap()
.into_iter()
.map(|(p, v)| (p, v.get_custom_value::<ConstF64>().unwrap().value()))
.collect_vec();

assert_eq!(&out[..], &[(0.into(), f2c(c))]);
assert_eq!(outs.as_slice(), &[(0.into(), c)]);
}
#[test]
fn test_big() {
Expand Down Expand Up @@ -529,8 +533,9 @@ mod test {
let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap();
let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h0, &reg);
let expected = Value::extension(ConstF64::new(1.0));
assert_fully_folded(&h0, &expected);
assert_fully_folded_with(&h0, |v| {
v.get_custom_value::<ConstF64>().unwrap().value() == 1.0
});
assert_eq!(h0.node_count(), 5);

// HUGR computing 1.0 / 0.0
Expand Down
63 changes: 32 additions & 31 deletions hugr/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use crate::extension::{EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG};
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES};
use crate::std_extensions::logic::NotOp;
use crate::types::{
Expand Down Expand Up @@ -94,20 +94,14 @@ impl_sertesting_from!(NodeSer, optype);

#[test]
fn empty_hugr_serialize() {
let hg = Hugr::default();
assert_eq!(ser_roundtrip(&hg), hg);
}

/// Serialize and deserialize a value.
pub fn ser_roundtrip<T: Serialize + serde::de::DeserializeOwned>(g: &T) -> T {
ser_roundtrip_validate(g, None)
check_hugr_roundtrip(&Hugr::default(), true);
}

/// Serialize and deserialize a value, optionally validating against a schema.
pub fn ser_roundtrip_validate<T: Serialize + serde::de::DeserializeOwned>(
pub fn ser_serialize_check_schema<T: Serialize + serde::de::DeserializeOwned>(
g: &T,
schema: Option<&JSONSchema>,
) -> T {
) -> serde_json::Value {
let s = serde_json::to_string(g).unwrap();
let val: serde_json::Value = serde_json::from_str(&s).unwrap();

Expand All @@ -123,7 +117,7 @@ pub fn ser_roundtrip_validate<T: Serialize + serde::de::DeserializeOwned>(
panic!("Serialization test failed.");
}
}
serde_json::from_value(val).unwrap()
val
}

/// Serialize and deserialize a HUGR, and check that the result is the same as the original.
Expand All @@ -138,13 +132,15 @@ pub fn check_hugr_schema_roundtrip(hugr: &Hugr) -> Hugr {
///
/// If `check_schema` is true, checks the serialized json against the in-tree schema.
///
/// Note that we do not literally compare the before and after `Hugr`s for
/// equality, because impls of `CustomConst` are not required to implement
/// equality checking.
///
/// Returns the deserialized HUGR.
pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr {
let new_hugr: Hugr = ser_roundtrip_validate(hugr, check_schema.then_some(&SCHEMA));
let new_hugr_strict: Hugr =
ser_roundtrip_validate(hugr, check_schema.then_some(&SCHEMA_STRICT));
assert_eq!(new_hugr, new_hugr_strict);

let hugr_ser = ser_serialize_check_schema(hugr, check_schema.then_some(&SCHEMA));
let _ = ser_serialize_check_schema(hugr, check_schema.then_some(&SCHEMA_STRICT));
let new_hugr: Hugr = serde_json::from_value(hugr_ser).unwrap();
// Original HUGR, with canonicalized node indices
//
// The internal port indices may still be different.
Expand All @@ -159,7 +155,9 @@ pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr {
for node in new_hugr.nodes() {
let new_op = new_hugr.get_optype(node);
let old_op = h_canon.get_optype(node);
assert_eq!(new_op, old_op);
if !new_op.is_const() {
assert_eq!(new_op, old_op);
}
}

// Check that the graphs are equivalent up to port renumbering.
Expand All @@ -182,8 +180,13 @@ pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr {

fn check_testing_roundtrip(t: impl Into<TestingModel>) {
let before = Versioned::new(t.into());
let after_strict = ser_roundtrip_validate(&before, Some(&TESTING_SCHEMA_STRICT));
let after = ser_roundtrip_validate(&before, Some(&TESTING_SCHEMA));
let after_strict = serde_json::from_value(ser_serialize_check_schema(
&before,
Some(&TESTING_SCHEMA_STRICT),
))
.unwrap();
let after =
serde_json::from_value(ser_serialize_check_schema(&before, Some(&TESTING_SCHEMA))).unwrap();
assert_eq!(before, after);
assert_eq!(after, after_strict);
}
Expand Down Expand Up @@ -347,9 +350,10 @@ fn hierarchy_order() -> Result<(), Box<dyn std::error::Error>> {

#[test]
fn constants_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
let mut builder = DFGBuilder::new(FunctionType::new(vec![], vec![FLOAT64_TYPE])).unwrap();
let w = builder.add_load_value(ConstF64::new(0.5));
let hugr = builder.finish_hugr_with_outputs([w], &FLOAT_OPS_REGISTRY)?;
let mut builder =
DFGBuilder::new(FunctionType::new(vec![], vec![INT_TYPES[4].clone()])).unwrap();
let w = builder.add_load_value(ConstInt::new_s(4, -2).unwrap());
let hugr = builder.finish_hugr_with_outputs([w], &INT_OPS_REGISTRY)?;

let ser = serde_json::to_string(&hugr)?;
let deser = serde_json::from_str(&ser)?;
Expand All @@ -363,18 +367,18 @@ fn constants_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
fn serialize_types_roundtrip() {
let g: Type = Type::new_function(FunctionType::new_endo(vec![]));

assert_eq!(ser_roundtrip(&g), g);
check_testing_roundtrip(g.clone());

// A Simple tuple
let t = Type::new_tuple(vec![USIZE_T, g]);
assert_eq!(ser_roundtrip(&t), t);
check_testing_roundtrip(t);

// A Classic sum
let t = Type::new_sum([type_row![USIZE_T], type_row![FLOAT64_TYPE]]);
assert_eq!(ser_roundtrip(&t), t);
check_testing_roundtrip(t);

let t = Type::new_unit_sum(4);
assert_eq!(ser_roundtrip(&t), t);
check_testing_roundtrip(t);
}

#[rstest]
Expand All @@ -401,10 +405,7 @@ fn roundtrip_sumtype(#[case] sum_type: SumType) {
#[case(Value::unit())]
#[case(Value::true_val())]
#[case(Value::unit_sum(3,5).unwrap())]
#[case(Value::extension(ConstF64::new(-1.5)))]
#[case(Value::extension(ConstF64::new(0.0)))]
#[case(Value::extension(ConstF64::new(-0.0)))]
#[case(Value::extension(ConstF64::new(f64::MIN_POSITIVE)))]
#[case(Value::extension(ConstInt::new_u(2,1).unwrap()))]
#[case(Value::sum(1,[Value::extension(ConstInt::new_u(2,1).unwrap())], SumType::new([vec![], vec![INT_TYPES[2].clone()]])).unwrap())]
#[case(Value::tuple([Value::false_val(), Value::extension(ConstInt::new_s(2,1).unwrap())]))]
#[case(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap())]
Expand Down
5 changes: 3 additions & 2 deletions hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ pub type ValueNameRef = str;
mod test {
use super::Value;
use crate::builder::test::simple_dfg_hugr;
use crate::std_extensions::arithmetic::int_types::ConstInt;
use crate::{
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
extension::{
Expand Down Expand Up @@ -605,9 +606,9 @@ mod test {
const_usize.get_custom_value::<ConstUsize>(),
Some(&ConstUsize::new(257))
);
assert_eq!(const_usize.get_custom_value::<ConstF64>(), None);
assert_eq!(const_usize.get_custom_value::<ConstInt>(), None);
assert_eq!(const_tuple.get_custom_value::<ConstUsize>(), None);
assert_eq!(const_tuple.get_custom_value::<ConstF64>(), None);
assert_eq!(const_tuple.get_custom_value::<ConstInt>(), None);
}

#[test]
Expand Down
9 changes: 3 additions & 6 deletions hugr/src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub const FLOAT64_CUSTOM_TYPE: CustomType =
/// 64-bit IEEE 754-2019 floating-point type (as [Type])
pub const FLOAT64_TYPE: Type = Type::new_extension(FLOAT64_CUSTOM_TYPE);

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// A floating-point value.
pub struct ConstF64 {
/// The value.
Expand Down Expand Up @@ -64,8 +64,8 @@ impl CustomConst for ConstF64 {
FLOAT64_TYPE
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
fn equal_consts(&self, _: &dyn CustomConst) -> bool {
false
}

fn extension_reqs(&self) -> ExtensionSet {
Expand Down Expand Up @@ -110,8 +110,5 @@ mod test {
assert_eq!(const_f64_1.value(), 1.0);
assert_eq!(*const_f64_2, 2.0);
assert_eq!(const_f64_1.name(), "f64(1)");
assert!(const_f64_1.equal_consts(&ConstF64::new(1.0)));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should keep (but negate) this test, for coverage.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea.

assert_ne!(const_f64_1, const_f64_2);
assert_eq!(const_f64_1, ConstF64::new(1.0));
}
}
11 changes: 10 additions & 1 deletion hugr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,22 @@ pub(crate) mod test {

/// Check that a hugr just loads and returns a single expected constant.
pub(crate) fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
assert_fully_folded_with(h, |v| v == expected_value)
}

/// Check that a hugr just loads and returns a single constant, and validate
/// that constant using `check_value`.
///
/// [CustomConst::equals_const] is not required to be implemented. Use this
/// function for Values containing such a `CustomConst`.
pub(crate) fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) {
let mut node_count = 0;

for node in h.children(h.root()) {
let op = h.get_optype(node);
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c.value() == expected_value => node_count += 1,
OpType::Const(c) if check_value(c.value()) => node_count += 1,
_ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()),
}
}
Expand Down