Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: Combine ConstIntU and ConstIntS #974

Merged
merged 5 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,14 @@ mod test {
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};
use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use crate::std_extensions::logic::{self, NaryLogic};

use rstest::rstest;

/// int to constant
fn i2c(b: u64) -> Value {
Value::extension(ConstIntU::new(5, b).unwrap())
Value::extension(ConstInt::new_u(5, b).unwrap())
}

/// float to constant
Expand Down
4 changes: 2 additions & 2 deletions hugr/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ mod test {
use crate::ops::{Lift, OpType, Value};
use crate::std_extensions::arithmetic::float_types;
use crate::std_extensions::arithmetic::int_ops::{self, IntOpDef};
use crate::std_extensions::arithmetic::int_types::{self, ConstIntU};
use crate::std_extensions::arithmetic::int_types::{self, ConstInt};
use crate::types::FunctionType;
use crate::utils::test_quantum_extension;
use crate::{type_row, Direction, HugrView, Node, Port};
Expand Down Expand Up @@ -184,7 +184,7 @@ mod test {
d: &mut DFGBuilder<T>,
) -> Result<Wire, Box<dyn std::error::Error>> {
let int_ty = &int_types::INT_TYPES[6];
let cst = Value::extension(ConstIntU::new(6, 15)?);
let cst = Value::extension(ConstInt::new_u(6, 15)?);
let c1 = d.add_load_const(cst);
let [lifted] = d
.add_dataflow_op(
Expand Down
14 changes: 7 additions & 7 deletions hugr/src/std_extensions/arithmetic/conversions/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
ops::constant::CustomConst,
std_extensions::arithmetic::{
float_types::ConstF64,
int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES},
int_types::{get_log_width, ConstInt, INT_TYPES},
},
types::ConstTypeError,
IncomingPort,
Expand Down Expand Up @@ -78,7 +78,7 @@ impl ConstFold for TruncU {
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
fold_trunc(type_args, consts, |f, log_width| {
ConstIntU::new(log_width, f.trunc() as u64).map(Into::into)
ConstInt::new_u(log_width, f.trunc() as u64).map(Into::into)
})
}
}
Expand All @@ -92,7 +92,7 @@ impl ConstFold for TruncS {
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
fold_trunc(type_args, consts, |f, log_width| {
ConstIntS::new(log_width, f.trunc() as i64).map(Into::into)
ConstInt::new_s(log_width, f.trunc() as i64).map(Into::into)
})
}
}
Expand All @@ -105,8 +105,8 @@ impl ConstFold for ConvertU {
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let u: &ConstIntU = get_input(consts)?;
let f = u.value() as f64;
let u: &ConstInt = get_input(consts)?;
let f = u.value_u() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
}
Expand All @@ -119,8 +119,8 @@ impl ConstFold for ConvertS {
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let u: &ConstIntS = get_input(consts)?;
let f = u.value() as f64;
let u: &ConstInt = get_input(consts)?;
let f = u.value_s() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
}
124 changes: 59 additions & 65 deletions hugr/src/std_extensions/arithmetic/int_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ const fn is_valid_log_width(n: u8) -> bool {
n < LOG_WIDTH_BOUND
}

/// The maximum allowed log width.
pub const LOG_WIDTH_MAX: u8 = 6;

/// The smallest forbidden log width.
pub const LOG_WIDTH_BOUND: u8 = 7;
pub const LOG_WIDTH_BOUND: u8 = LOG_WIDTH_MAX + 1;

/// Type parameter for the log width of the integer.
#[allow(clippy::assertions_on_constants)]
Expand All @@ -71,23 +74,22 @@ const fn type_arg(log_width: u8) -> TypeArg {
n: log_width as u64,
}
}
/// An unsigned integer
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ConstIntU {
log_width: u8,
value: u64,
}

/// A signed integer
/// An integer (either signed or unsigned)
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ConstIntS {
pub struct ConstInt {
log_width: u8,
value: i64,
// We always use a u64 for the value. The interpretation is:
// - as an unsigned integer, (value mod 2^N);
// - as a signed integer, (value mod 2^(N-1) - 2^(N-1)*a)
// where N = 2^log_width and a is the (N-1)th bit of x (counting from
// 0 = least significant bit).
value: u64,
}

impl ConstIntU {
/// Create a new [`ConstIntU`]
pub fn new(log_width: u8, value: u64) -> Result<Self, ConstTypeError> {
impl ConstInt {
/// Create a new [`ConstInt`] with a given width and unsigned value
pub fn new_u(log_width: u8, value: u64) -> Result<Self, ConstTypeError> {
if !is_valid_log_width(log_width) {
return Err(ConstTypeError::CustomCheckFail(
crate::types::CustomCheckFailure::Message("Invalid integer width.".to_owned()),
Expand All @@ -103,20 +105,8 @@ impl ConstIntU {
Ok(Self { log_width, value })
}

/// Returns the value of the constant
pub fn value(&self) -> u64 {
self.value
}

/// Returns the number of bits of the constant
pub fn log_width(&self) -> u8 {
self.log_width
}
}

impl ConstIntS {
/// Create a new [`ConstIntS`]
pub fn new(log_width: u8, value: i64) -> Result<Self, ConstTypeError> {
/// Create a new [`ConstInt`] with a given width and signed value
pub fn new_s(log_width: u8, value: i64) -> Result<Self, ConstTypeError> {
if !is_valid_log_width(log_width) {
return Err(ConstTypeError::CustomCheckFail(
crate::types::CustomCheckFailure::Message("Invalid integer width.".to_owned()),
Expand All @@ -130,42 +120,46 @@ impl ConstIntS {
),
));
}
Ok(Self { log_width, value })
}

/// Returns the value of the constant
pub fn value(&self) -> i64 {
self.value
Ok(Self {
log_width,
value: (if value >= 0 || log_width == LOG_WIDTH_MAX {
value
} else {
value + (1i64 << width)
}) as u64,
})
}

/// Returns the number of bits of the constant
pub fn log_width(&self) -> u8 {
self.log_width
}
}

#[typetag::serde]
impl CustomConst for ConstIntU {
fn name(&self) -> SmolStr {
format!("u{}({})", self.log_width, self.value).into()
}
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
}

fn extension_reqs(&self) -> ExtensionSet {
ExtensionSet::singleton(&EXTENSION_ID)
/// Returns the value of the constant as an unsigned integer
pub fn value_u(&self) -> u64 {
self.value
}

fn get_type(&self) -> Type {
int_type(type_arg(self.log_width))
/// Returns the value of the constant as a signed integer
pub fn value_s(&self) -> i64 {
if self.log_width == LOG_WIDTH_MAX {
self.value as i64
} else {
let width = 1u8 << self.log_width;
if ((self.value << 1) >> width) == 0 {
self.value as i64
} else {
self.value as i64 - (1i64 << width)
}
}
}
}

#[typetag::serde]
impl CustomConst for ConstIntS {
impl CustomConst for ConstInt {
fn name(&self) -> SmolStr {
format!("i{}({})", self.log_width, self.value).into()
format!("u{}({})", 1u8 << self.log_width, self.value).into()
}
fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
Expand Down Expand Up @@ -239,43 +233,43 @@ mod test {

#[test]
fn test_int_consts() {
let const_u32_7 = ConstIntU::new(5, 7);
let const_u64_7 = ConstIntU::new(6, 7);
let const_u32_8 = ConstIntU::new(5, 8);
let const_u32_7 = ConstInt::new_u(5, 7);
let const_u64_7 = ConstInt::new_u(6, 7);
let const_u32_8 = ConstInt::new_u(5, 8);
assert_ne!(const_u32_7, const_u64_7);
assert_ne!(const_u32_7, const_u32_8);
assert_eq!(const_u32_7, ConstIntU::new(5, 7));
assert_eq!(const_u32_7, ConstInt::new_u(5, 7));

assert_matches!(
ConstIntU::new(3, 256),
ConstInt::new_u(3, 256),
Err(ConstTypeError::CustomCheckFail(_))
);
assert_matches!(
ConstIntU::new(9, 256),
ConstInt::new_u(9, 256),
Err(ConstTypeError::CustomCheckFail(_))
);
assert_matches!(
ConstIntS::new(3, 128),
ConstInt::new_s(3, 128),
Err(ConstTypeError::CustomCheckFail(_))
);
assert!(ConstIntS::new(3, -128).is_ok());
assert!(ConstInt::new_s(3, -128).is_ok());

let const_u32_7 = const_u32_7.unwrap();
assert!(const_u32_7.equal_consts(&ConstIntU::new(5, 7).unwrap()));
assert!(const_u32_7.equal_consts(&ConstInt::new_u(5, 7).unwrap()));
assert_eq!(const_u32_7.log_width(), 5);
assert_eq!(const_u32_7.value(), 7);
assert_eq!(const_u32_7.value_u(), 7);
assert!(const_u32_7.validate().is_ok());

assert_eq!(const_u32_7.name(), "u5(7)");
assert_eq!(const_u32_7.name(), "u32(7)");

let const_i32_2 = ConstIntS::new(5, -2).unwrap();
assert!(const_i32_2.equal_consts(&ConstIntS::new(5, -2).unwrap()));
let const_i32_2 = ConstInt::new_s(5, -2).unwrap();
assert!(const_i32_2.equal_consts(&ConstInt::new_s(5, -2).unwrap()));
assert_eq!(const_i32_2.log_width(), 5);
assert_eq!(const_i32_2.value(), -2);
assert_eq!(const_i32_2.value_s(), -2);
assert!(const_i32_2.validate().is_ok());
assert_eq!(const_i32_2.name(), "i5(-2)");
assert_eq!(const_i32_2.name(), "u32(4294967294)");

ConstIntS::new(50, -2).unwrap_err();
ConstIntU::new(50, 2).unwrap_err();
ConstInt::new_s(50, -2).unwrap_err();
ConstInt::new_u(50, 2).unwrap_err();
}
}