Skip to content

Commit

Permalink
Add constant folding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Aug 22, 2024
1 parent 6100512 commit 9da92fd
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
46 changes: 46 additions & 0 deletions hugr-core/src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ impl MakeRegisteredOp for FloatOps {

#[cfg(test)]
mod test {
use cgmath::AbsDiffEq;
use rstest::rstest;

use super::*;

#[test]
Expand All @@ -148,4 +151,47 @@ mod test {
assert!(name.as_str().starts_with('f'));
}
}

#[rstest]
#[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
#[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
#[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
#[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
#[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
#[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
#[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
#[case::fround(FloatOps::fround, &[42.42], &[42.])]
fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
use crate::ops::Value;
use crate::std_extensions::arithmetic::float_types::ConstF64;

let consts: Vec<_> = inputs
.iter()
.enumerate()
.map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
.collect();

let res = op
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, expected) in outputs.iter().enumerate() {
let res_val: f64 = res
.get(i)
.unwrap()
.1
.get_custom_value::<ConstF64>()
.expect("This function assumes all incoming constants are floats.")
.value();

assert!(
res_val.abs_diff_eq(expected, f64::EPSILON),
"expected {:?}, got {:?}",
expected,
res_val
);
}
}
}
50 changes: 50 additions & 0 deletions hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ fn sum_ty_with_err(t: Type) -> Type {

#[cfg(test)]
mod test {
use rstest::rstest;

use crate::{
ops::{dataflow::DataflowOpTrait, ExtensionOp},
std_extensions::arithmetic::int_types::int_type,
Expand Down Expand Up @@ -456,4 +458,52 @@ mod test {
assert_eq!(ConcreteIntOp::from_op(&ext_op).unwrap(), o);
assert_eq!(IntOpDef::from_op(&ext_op).unwrap(), IntOpDef::itobool);
}

#[rstest]
#[case::iadd(IntOpDef::iadd.with_log_width(5), &[1, 2], &[3], 5)]
#[case::isub(IntOpDef::isub.with_log_width(5), &[5, 2], &[3], 5)]
#[case::imul(IntOpDef::imul.with_log_width(5), &[2, 8], &[16], 5)]
#[case::idiv(IntOpDef::idiv_u.with_log_width(5), &[37, 8], &[4], 5)]
#[case::imod(IntOpDef::imod_u.with_log_width(5), &[43, 8], &[3], 5)]
#[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)]
#[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)]
#[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)]
fn int_fold(
#[case] op: ConcreteIntOp,
#[case] inputs: &[u64],
#[case] outputs: &[u64],
#[case] log_width: u8,
) {
use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::ConstInt;

let consts: Vec<_> = inputs
.iter()
.enumerate()
.map(|(i, &x)| {
(
i.into(),
Value::extension(ConstInt::new_u(log_width, x).unwrap()),
)
})
.collect();

let res = op
.to_extension_op()
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, &expected) in outputs.iter().enumerate() {
let res_val: u64 = res
.get(i)
.unwrap()
.1
.get_custom_value::<ConstInt>()
.expect("This function assumes all incoming constants are floats.")
.value_u();

assert_eq!(res_val, expected);
}
}
}

0 comments on commit 9da92fd

Please sign in to comment.