Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored bounded_int_trim. #7062

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 36 additions & 42 deletions corelib/src/internal/bounded_int.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -113,56 +113,49 @@ extern fn bounded_int_constrain<T, const BOUNDARY: felt252, impl H: ConstrainHel
value: T,
) -> Result<H::LowT, H::HighT> implicits(RangeCheck) nopanic;

/// A helper trait for trimming a `BoundedInt` instance.
pub trait TrimHelper<T, const TRIMMED_VALUE: felt252> {
/// A helper trait for trimming a `BoundedInt` instance min value.
pub trait TrimMinHelper<T> {
type Target;
}
/// A helper trait for trimming a `BoundedInt` instance max value.
pub trait TrimMaxHelper<T> {
type Target;
}
mod trim_impl {
pub impl Impl<
T, const TRIMMED_VALUE: felt252, const MIN: felt252, const MAX: felt252,
> of super::TrimHelper<T, TRIMMED_VALUE> {
pub impl Min<T, const MIN: felt252, const MAX: felt252> of super::TrimMinHelper<T> {
type Target = super::BoundedInt<MIN, MAX>;
}
pub impl Max<T, const MIN: felt252, const MAX: felt252> of super::TrimMaxHelper<T> {
type Target = super::BoundedInt<MIN, MAX>;
}
}
impl U8TrimBelow = trim_impl::Impl<u8, 0, 1, 0xff>;
impl U8TrimAbove = trim_impl::Impl<u8, 0xff, 0, 0xfe>;
impl I8TrimBelow = trim_impl::Impl<i8, -0x80, -0x7f, 0x7f>;
impl I8TrimAbove = trim_impl::Impl<i8, 0x7f, -0x80, 0x7e>;
impl U16TrimBelow = trim_impl::Impl<u16, 0, 1, 0xffff>;
impl U16TrimAbove = trim_impl::Impl<u16, 0xffff, 0, 0xfffe>;
impl I16TrimBelow = trim_impl::Impl<i16, -0x8000, -0x7fff, 0x7fff>;
impl I16TrimAbove = trim_impl::Impl<i16, 0x7fff, -0x8000, 0x7ffe>;
impl U32TrimBelow = trim_impl::Impl<u32, 0, 1, 0xffffffff>;
impl U32TrimAbove = trim_impl::Impl<u32, 0xffffffff, 0, 0xfffffffe>;
impl I32TrimBelow = trim_impl::Impl<i32, -0x80000000, -0x7fffffff, 0x7fffffff>;
impl I32TrimAbove = trim_impl::Impl<i32, 0x7fffffff, -0x80000000, 0x7ffffffe>;
impl U64TrimBelow = trim_impl::Impl<u64, 0, 1, 0xffffffffffffffff>;
impl U64TrimAbove = trim_impl::Impl<u64, 0xffffffffffffffff, 0, 0xfffffffffffffffe>;
impl I64TrimBelow =
trim_impl::Impl<i64, -0x8000000000000000, -0x7fffffffffffffff, 0x7fffffffffffffff>;
impl I64TrimAbove =
trim_impl::Impl<i64, 0x7fffffffffffffff, -0x8000000000000000, 0x7ffffffffffffffe>;
impl U128TrimBelow = trim_impl::Impl<u128, 0, 1, 0xffffffffffffffffffffffffffffffff>;
impl U128TrimAbove =
trim_impl::Impl<
u128, 0xffffffffffffffffffffffffffffffff, 0, 0xfffffffffffffffffffffffffffffffe,
>;
impl U8TrimBelow = trim_impl::Min<u8, 1, 0xff>;
impl U8TrimAbove = trim_impl::Max<u8, 0, 0xfe>;
impl I8TrimBelow = trim_impl::Min<i8, -0x7f, 0x7f>;
impl I8TrimAbove = trim_impl::Max<i8, -0x80, 0x7e>;
impl U16TrimBelow = trim_impl::Min<u16, 1, 0xffff>;
impl U16TrimAbove = trim_impl::Max<u16, 0, 0xfffe>;
impl I16TrimBelow = trim_impl::Min<i16, -0x7fff, 0x7fff>;
impl I16TrimAbove = trim_impl::Max<i16, -0x8000, 0x7ffe>;
impl U32TrimBelow = trim_impl::Min<u32, 1, 0xffffffff>;
impl U32TrimAbove = trim_impl::Max<u32, 0, 0xfffffffe>;
impl I32TrimBelow = trim_impl::Min<i32, -0x7fffffff, 0x7fffffff>;
impl I32TrimAbove = trim_impl::Max<i32, -0x80000000, 0x7ffffffe>;
impl U64TrimBelow = trim_impl::Min<u64, 1, 0xffffffffffffffff>;
impl U64TrimAbove = trim_impl::Max<u64, 0, 0xfffffffffffffffe>;
impl I64TrimBelow = trim_impl::Min<i64, -0x7fffffffffffffff, 0x7fffffffffffffff>;
impl I64TrimAbove = trim_impl::Max<i64, -0x8000000000000000, 0x7ffffffffffffffe>;
impl U128TrimBelow = trim_impl::Min<u128, 1, 0xffffffffffffffffffffffffffffffff>;
impl U128TrimAbove = trim_impl::Max<u128, 0, 0xfffffffffffffffffffffffffffffffe>;
impl I128TrimBelow =
trim_impl::Impl<
i128,
-0x80000000000000000000000000000000,
-0x7fffffffffffffffffffffffffffffff,
0x7fffffffffffffffffffffffffffffff,
>;
trim_impl::Min<i128, -0x7fffffffffffffffffffffffffffffff, 0x7fffffffffffffffffffffffffffffff>;
impl I128TrimAbove =
trim_impl::Impl<
i128,
0x7fffffffffffffffffffffffffffffff,
-0x80000000000000000000000000000000,
0x7ffffffffffffffffffffffffffffffe,
>;
trim_impl::Max<i128, -0x80000000000000000000000000000000, 0x7ffffffffffffffffffffffffffffffe>;

extern fn bounded_int_trim<T, const TRIMMED_VALUE: felt252, impl H: TrimHelper<T, TRIMMED_VALUE>>(
extern fn bounded_int_trim_min<T, impl H: TrimMinHelper<T>>(
value: T,
) -> core::internal::OptionRev<H::Target> nopanic;
extern fn bounded_int_trim_max<T, impl H: TrimMaxHelper<T>>(
value: T,
) -> core::internal::OptionRev<H::Target> nopanic;

Expand Down Expand Up @@ -272,5 +265,6 @@ impl MulMinusOneNegateHelper<T, impl H: MulHelper<T, MinusOne>> of NegateHelper<
pub use {
bounded_int_add as add, bounded_int_sub as sub, bounded_int_mul as mul,
bounded_int_div_rem as div_rem, bounded_int_constrain as constrain,
bounded_int_is_zero as is_zero, bounded_int_trim as trim,
bounded_int_is_zero as is_zero, bounded_int_trim_min as trim_min,
bounded_int_trim_max as trim_max,
};
106 changes: 44 additions & 62 deletions corelib/src/test/integer_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2159,81 +2159,63 @@ mod bounded_int {
#[test]
fn test_trim() {
use core::internal::OptionRev;
assert!(bounded_int::trim::<u8, 0>(0) == OptionRev::None);
assert!(bounded_int::trim::<u8, 0>(1) == OptionRev::Some(1));
assert!(bounded_int::trim::<u8, 0xff>(0xff) == OptionRev::None);
assert!(bounded_int::trim::<u8, 0xff>(0xfe) == OptionRev::Some(0xfe));
assert!(bounded_int::trim::<i8, -0x80>(-0x80) == OptionRev::None);
assert!(bounded_int::trim::<i8, -0x80>(1) == OptionRev::Some(1));
assert!(bounded_int::trim::<i8, 0x7f>(0x7f) == OptionRev::None);
assert!(bounded_int::trim::<i8, 0x7f>(1) == OptionRev::Some(1));

assert!(bounded_int::trim::<u16, 0>(0) == OptionRev::None);
assert!(bounded_int::trim::<u16, 0>(1) == OptionRev::Some(1));
assert!(bounded_int::trim::<u16, 0xffff>(0xffff) == OptionRev::None);
assert!(bounded_int::trim::<u16, 0xffff>(0xfffe) == OptionRev::Some(0xfffe));
assert!(bounded_int::trim::<i16, -0x8000>(-0x8000) == OptionRev::None);
assert!(bounded_int::trim::<i16, -0x8000>(1) == OptionRev::Some(1));
assert!(bounded_int::trim::<i16, 0x7fff>(0x7fff) == OptionRev::None);
assert!(bounded_int::trim::<i16, 0x7fff>(1) == OptionRev::Some(1));

assert!(bounded_int::trim::<u32, 0>(0) == OptionRev::None);
assert!(bounded_int::trim::<u32, 0>(1) == OptionRev::Some(1));
assert!(bounded_int::trim::<u32, 0xffffffff>(0xffffffff) == OptionRev::None);
assert!(bounded_int::trim::<u32, 0xffffffff>(0xfffffffe) == OptionRev::Some(0xfffffffe));
assert!(bounded_int::trim::<i32, -0x80000000>(-0x80000000) == OptionRev::None);
assert!(bounded_int::trim::<i32, -0x80000000>(1) == OptionRev::Some(1));
assert!(bounded_int::trim::<i32, 0x7fffffff>(0x7fffffff) == OptionRev::None);
assert!(bounded_int::trim::<i32, 0x7fffffff>(1) == OptionRev::Some(1));

assert!(bounded_int::trim::<u64, 0>(0) == OptionRev::None);
assert!(bounded_int::trim::<u64, 0>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_min::<u8>(0) == OptionRev::None);
assert!(bounded_int::trim_min::<u8>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_max::<u8>(0xff) == OptionRev::None);
assert!(bounded_int::trim_max::<u8>(0xfe) == OptionRev::Some(0xfe));
assert!(bounded_int::trim_min::<i8>(-0x80) == OptionRev::None);
assert!(bounded_int::trim_min::<i8>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_max::<i8>(0x7f) == OptionRev::None);
assert!(bounded_int::trim_max::<i8>(1) == OptionRev::Some(1));

assert!(bounded_int::trim_min::<u16>(0) == OptionRev::None);
assert!(bounded_int::trim_min::<u16>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_max::<u16>(0xffff) == OptionRev::None);
assert!(bounded_int::trim_max::<u16>(0xfffe) == OptionRev::Some(0xfffe));
assert!(bounded_int::trim_min::<i16>(-0x8000) == OptionRev::None);
assert!(bounded_int::trim_min::<i16>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_max::<i16>(0x7fff) == OptionRev::None);
assert!(bounded_int::trim_max::<i16>(1) == OptionRev::Some(1));

assert!(bounded_int::trim_min::<u32>(0) == OptionRev::None);
assert!(bounded_int::trim_min::<u32>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_max::<u32>(0xffffffff) == OptionRev::None);
assert!(bounded_int::trim_max::<u32>(0xfffffffe) == OptionRev::Some(0xfffffffe));
assert!(bounded_int::trim_min::<i32>(-0x80000000) == OptionRev::None);
assert!(bounded_int::trim_min::<i32>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_max::<i32>(0x7fffffff) == OptionRev::None);
assert!(bounded_int::trim_max::<i32>(1) == OptionRev::Some(1));

assert!(bounded_int::trim_min::<u64>(0) == OptionRev::None);
assert!(bounded_int::trim_min::<u64>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_max::<u64>(0xffffffffffffffff) == OptionRev::None);
assert!(
bounded_int::trim::<u64, 0xffffffffffffffff>(0xffffffffffffffff) == OptionRev::None,
bounded_int::trim_max::<u64>(0xfffffffffffffffe) == OptionRev::Some(0xfffffffffffffffe),
);
assert!(
bounded_int::trim::<
u64, 0xffffffffffffffff,
>(0xfffffffffffffffe) == OptionRev::Some(0xfffffffffffffffe),
);
assert!(
bounded_int::trim::<i64, -0x8000000000000000>(-0x8000000000000000) == OptionRev::None,
);
assert!(bounded_int::trim::<i64, -0x8000000000000000>(1) == OptionRev::Some(1));
assert!(
bounded_int::trim::<i64, 0x7fffffffffffffff>(0x7fffffffffffffff) == OptionRev::None,
);
assert!(bounded_int::trim::<i64, 0x7fffffffffffffff>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_min::<i64>(-0x8000000000000000) == OptionRev::None);
assert!(bounded_int::trim_min::<i64>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_max::<i64>(0x7fffffffffffffff) == OptionRev::None);
assert!(bounded_int::trim_max::<i64>(1) == OptionRev::Some(1));

assert!(bounded_int::trim::<u128, 0>(0) == OptionRev::None);
assert!(bounded_int::trim::<u128, 0>(1) == OptionRev::Some(1));
assert!(bounded_int::trim_min::<u128>(0) == OptionRev::None);
assert!(bounded_int::trim_min::<u128>(1) == OptionRev::Some(1));
assert!(
bounded_int::trim::<
u128, 0xffffffffffffffffffffffffffffffff,
>(0xffffffffffffffffffffffffffffffff) == OptionRev::None,
bounded_int::trim_max::<u128>(0xffffffffffffffffffffffffffffffff) == OptionRev::None,
);
assert!(
bounded_int::trim::<
u128, 0xffffffffffffffffffffffffffffffff,
bounded_int::trim_max::<
u128,
>(
0xfffffffffffffffffffffffffffffffe,
) == OptionRev::Some(0xfffffffffffffffffffffffffffffffe),
);
assert!(
bounded_int::trim::<
i128, -0x80000000000000000000000000000000,
>(-0x80000000000000000000000000000000) == OptionRev::None,
);
assert!(
bounded_int::trim::<i128, -0x80000000000000000000000000000000>(1) == OptionRev::Some(1),
);
assert!(
bounded_int::trim::<
i128, 0x7fffffffffffffffffffffffffffffff,
>(0x7fffffffffffffffffffffffffffffff) == OptionRev::None,
bounded_int::trim_min::<i128>(-0x80000000000000000000000000000000) == OptionRev::None,
);
assert!(bounded_int::trim_min::<i128>(1) == OptionRev::Some(1));
assert!(
bounded_int::trim::<i128, 0x7fffffffffffffffffffffffffffffff>(1) == OptionRev::Some(1),
bounded_int::trim_max::<i128>(0x7fffffffffffffffffffffffffffffff) == OptionRev::None,
);
assert!(bounded_int::trim_max::<i128>(1) == OptionRev::Some(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ pub fn core_libfunc_ap_change<InfoProvider: InvocationApChangeInfoProvider>(
ApChange::Known(1 + if libfunc.boundary.is_zero() { 0 } else { 1 }),
]
}
BoundedIntConcreteLibfunc::Trim(libfunc) => {
BoundedIntConcreteLibfunc::TrimMin(libfunc)
| BoundedIntConcreteLibfunc::TrimMax(libfunc) => {
let ap_change = if libfunc.trimmed_value.is_zero() { 0 } else { 1 };
vec![ApChange::Known(ap_change), ApChange::Known(ap_change)]
}
Expand Down
3 changes: 2 additions & 1 deletion crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,8 @@ pub fn core_libfunc_cost(
.into(),
]
}
BoundedIntConcreteLibfunc::Trim(libfunc) => {
BoundedIntConcreteLibfunc::TrimMin(libfunc)
| BoundedIntConcreteLibfunc::TrimMax(libfunc) => {
let steps: BranchCost =
ConstCost::steps(if libfunc.trimmed_value.is_zero() { 1 } else { 2 }).into();
vec![steps.clone(), steps]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ pub fn build(
BoundedIntConcreteLibfunc::Constrain(libfunc) => {
build_constrain(builder, &libfunc.boundary)
}
BoundedIntConcreteLibfunc::Trim(libfunc) => build_trim(builder, &libfunc.trimmed_value),
BoundedIntConcreteLibfunc::TrimMin(libfunc)
| BoundedIntConcreteLibfunc::TrimMax(libfunc) => {
build_trim(builder, &libfunc.trimmed_value)
}
BoundedIntConcreteLibfunc::IsZero(_) => build_is_zero(builder),
BoundedIntConcreteLibfunc::WrapNonZero(_) => build_identity(builder),
}
Expand Down
81 changes: 43 additions & 38 deletions crates/cairo-lang-sierra/src/extensions/modules/bounded_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ define_libfunc_hierarchy! {
Mul(BoundedIntMulLibfunc),
DivRem(BoundedIntDivRemLibfunc),
Constrain(BoundedIntConstrainLibfunc),
Trim(BoundedIntTrimLibfunc),
TrimMin(BoundedIntTrimLibfunc<false>),
TrimMax(BoundedIntTrimLibfunc<true>),
IsZero(BoundedIntIsZeroLibfunc),
WrapNonZero(BoundedIntWrapNonZeroLibfunc),
}, BoundedIntConcreteLibfunc
Expand Down Expand Up @@ -388,67 +389,71 @@ impl SignatureBasedConcreteLibfunc for BoundedIntConstrainConcreteLibfunc {
/// Libfunc for trimming a BoundedInt<Min, Max> by removing `Min` or `Max` from the range.
/// The libfunc is also applicable for standard types such as u* and i*.
#[derive(Default)]
pub struct BoundedIntTrimLibfunc {}
impl NamedLibfunc for BoundedIntTrimLibfunc {
pub struct BoundedIntTrimLibfunc<const IS_MAX: bool> {}
impl<const IS_MAX: bool> NamedLibfunc for BoundedIntTrimLibfunc<IS_MAX> {
type Concrete = BoundedIntTrimConcreteLibfunc;

const STR_ID: &'static str = "bounded_int_trim";
const STR_ID: &'static str =
if IS_MAX { "bounded_int_trim_max" } else { "bounded_int_trim_min" };

fn specialize_signature(
&self,
context: &dyn SignatureSpecializationContext,
args: &[GenericArg],
) -> Result<LibfuncSignature, SpecializationError> {
let (ty, trimmed_value) = match args {
[GenericArg::Type(ty), GenericArg::Value(trimmed_value)] => Ok((ty, trimmed_value)),
[_, _] => Err(SpecializationError::UnsupportedGenericArg),
_ => Err(SpecializationError::WrongNumberOfGenericArgs),
}?;
Ok(Self::Concrete::new::<IS_MAX>(context, args)?.signature)
}

fn specialize(
&self,
context: &dyn SpecializationContext,
args: &[GenericArg],
) -> Result<Self::Concrete, SpecializationError> {
Self::Concrete::new::<IS_MAX>(context.upcast(), args)
}
}

pub struct BoundedIntTrimConcreteLibfunc {
pub trimmed_value: BigInt,
signature: LibfuncSignature,
}
impl BoundedIntTrimConcreteLibfunc {
fn new<const IS_MAX: bool>(
context: &dyn SignatureSpecializationContext,
args: &[GenericArg],
) -> Result<Self, SpecializationError> {
let ty = args_as_single_type(args)?;
let ty_info = context.get_type_info(ty.clone())?;
let mut range = Range::from_type_info(&ty_info)?;
if trimmed_value == &range.lower {
range.lower += 1;
let range = Range::from_type_info(&ty_info)?;
let (res_ty, trimmed_value) = if IS_MAX {
(
bounded_int_ty(context, range.lower.clone(), range.upper.clone() - 2)?,
range.upper - 1,
)
} else {
range.upper -= 1;
require(&range.upper == trimmed_value)
.ok_or(SpecializationError::UnsupportedGenericArg)?;
}
(
bounded_int_ty(context, range.lower.clone() + 1, range.upper.clone() - 1)?,
range.lower,
)
};
let ap_change = SierraApChange::Known { new_vars_only: trimmed_value.is_zero() };
Ok(LibfuncSignature {
let signature = LibfuncSignature {
param_signatures: vec![ParamSignature::new(ty.clone())],
branch_signatures: vec![
BranchSignature { vars: vec![], ap_change: ap_change.clone() },
BranchSignature {
vars: vec![OutputVarInfo {
ty: bounded_int_ty(context, range.lower, range.upper - 1)?,
ty: res_ty,
ref_info: OutputVarReferenceInfo::SameAsParam { param_idx: 0 },
}],
ap_change,
},
],
fallthrough: Some(0),
})
}

fn specialize(
&self,
context: &dyn SpecializationContext,
args: &[GenericArg],
) -> Result<Self::Concrete, SpecializationError> {
let trimmed_value = match args {
[GenericArg::Type(_), GenericArg::Value(trimmed_value)] => Ok(trimmed_value.clone()),
[_, _] => Err(SpecializationError::UnsupportedGenericArg),
_ => Err(SpecializationError::WrongNumberOfGenericArgs),
}?;
let context = context.upcast();
Ok(Self::Concrete { trimmed_value, signature: self.specialize_signature(context, args)? })
};
Ok(Self { trimmed_value, signature })
}
}

pub struct BoundedIntTrimConcreteLibfunc {
pub trimmed_value: BigInt,
signature: LibfuncSignature,
}
impl SignatureBasedConcreteLibfunc for BoundedIntTrimConcreteLibfunc {
fn signature(&self) -> &LibfuncSignature {
&self.signature
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"bounded_int_is_zero",
"bounded_int_mul",
"bounded_int_sub",
"bounded_int_trim",
"bounded_int_trim_max",
"bounded_int_trim_min",
"bounded_int_wrap_non_zero",
"box_forward_snapshot",
"branch_align",
Expand Down
Loading
Loading