Skip to content

Commit

Permalink
Added i*_diff and Ord.
Browse files Browse the repository at this point in the history
commit-id:fce91b1f
  • Loading branch information
orizi committed Jul 25, 2023
1 parent f1fecf5 commit aa7a5e3
Show file tree
Hide file tree
Showing 24 changed files with 405 additions and 155 deletions.
106 changes: 106 additions & 0 deletions corelib/src/integer.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1942,6 +1942,27 @@ impl I8MulEq of MulEq<i8> {
}
}

/// If `lhs` >= `rhs` returns `Ok(lhs - rhs)` else returns `Err(2**8 + lhs - rhs)`.
extern fn i8_diff(lhs: i8, rhs: i8) -> Result<u8, u8> implicits(RangeCheck) nopanic;
impl I8PartialOrd of PartialOrd<i8> {
#[inline(always)]
fn le(lhs: i8, rhs: i8) -> bool {
i8_diff(rhs, lhs).into_is_ok()
}
#[inline(always)]
fn ge(lhs: i8, rhs: i8) -> bool {
i8_diff(lhs, rhs).into_is_ok()
}
#[inline(always)]
fn lt(lhs: i8, rhs: i8) -> bool {
i8_diff(lhs, rhs).into_is_err()
}
#[inline(always)]
fn gt(lhs: i8, rhs: i8) -> bool {
i8_diff(rhs, lhs).into_is_err()
}
}

#[derive(Copy, Drop)]
extern type i16;
impl NumericLiterali16 of NumericLiteral<i16>;
Expand Down Expand Up @@ -2013,6 +2034,27 @@ impl I16MulEq of MulEq<i16> {
}
}

/// If `lhs` >= `rhs` returns `Ok(lhs - rhs)` else returns `Err(2**16 + lhs - rhs)`.
extern fn i16_diff(lhs: i16, rhs: i16) -> Result<u16, u16> implicits(RangeCheck) nopanic;
impl I16PartialOrd of PartialOrd<i16> {
#[inline(always)]
fn le(lhs: i16, rhs: i16) -> bool {
i16_diff(rhs, lhs).into_is_ok()
}
#[inline(always)]
fn ge(lhs: i16, rhs: i16) -> bool {
i16_diff(lhs, rhs).into_is_ok()
}
#[inline(always)]
fn lt(lhs: i16, rhs: i16) -> bool {
i16_diff(lhs, rhs).into_is_err()
}
#[inline(always)]
fn gt(lhs: i16, rhs: i16) -> bool {
i16_diff(rhs, lhs).into_is_err()
}
}

#[derive(Copy, Drop)]
extern type i32;
impl NumericLiterali32 of NumericLiteral<i32>;
Expand Down Expand Up @@ -2084,6 +2126,27 @@ impl I32MulEq of MulEq<i32> {
}
}

/// If `lhs` >= `rhs` returns `Ok(lhs - rhs)` else returns `Err(2**32 + lhs - rhs)`.
extern fn i32_diff(lhs: i32, rhs: i32) -> Result<u32, u32> implicits(RangeCheck) nopanic;
impl I32PartialOrd of PartialOrd<i32> {
#[inline(always)]
fn le(lhs: i32, rhs: i32) -> bool {
i32_diff(rhs, lhs).into_is_ok()
}
#[inline(always)]
fn ge(lhs: i32, rhs: i32) -> bool {
i32_diff(lhs, rhs).into_is_ok()
}
#[inline(always)]
fn lt(lhs: i32, rhs: i32) -> bool {
i32_diff(lhs, rhs).into_is_err()
}
#[inline(always)]
fn gt(lhs: i32, rhs: i32) -> bool {
i32_diff(rhs, lhs).into_is_err()
}
}

#[derive(Copy, Drop)]
extern type i64;
impl NumericLiterali64 of NumericLiteral<i64>;
Expand Down Expand Up @@ -2155,6 +2218,27 @@ impl I64MulEq of MulEq<i64> {
}
}

/// If `lhs` >= `rhs` returns `Ok(lhs - rhs)` else returns `Err(2**64 + lhs - rhs)`.
extern fn i64_diff(lhs: i64, rhs: i64) -> Result<u64, u64> implicits(RangeCheck) nopanic;
impl I64PartialOrd of PartialOrd<i64> {
#[inline(always)]
fn le(lhs: i64, rhs: i64) -> bool {
i64_diff(rhs, lhs).into_is_ok()
}
#[inline(always)]
fn ge(lhs: i64, rhs: i64) -> bool {
i64_diff(lhs, rhs).into_is_ok()
}
#[inline(always)]
fn lt(lhs: i64, rhs: i64) -> bool {
i64_diff(lhs, rhs).into_is_err()
}
#[inline(always)]
fn gt(lhs: i64, rhs: i64) -> bool {
i64_diff(rhs, lhs).into_is_err()
}
}

#[derive(Copy, Drop)]
extern type i128;
impl NumericLiterali128 of NumericLiteral<i128>;
Expand Down Expand Up @@ -2212,3 +2296,25 @@ impl I128SubEq of SubEq<i128> {
self = Sub::sub(self, other);
}
}


/// If `lhs` >= `rhs` returns `Ok(lhs - rhs)` else returns `Err(2**128 + lhs - rhs)`.
extern fn i128_diff(lhs: i128, rhs: i128) -> Result<u128, u128> implicits(RangeCheck) nopanic;
impl I128PartialOrd of PartialOrd<i128> {
#[inline(always)]
fn le(lhs: i128, rhs: i128) -> bool {
i128_diff(rhs, lhs).into_is_ok()
}
#[inline(always)]
fn ge(lhs: i128, rhs: i128) -> bool {
i128_diff(lhs, rhs).into_is_ok()
}
#[inline(always)]
fn lt(lhs: i128, rhs: i128) -> bool {
i128_diff(lhs, rhs).into_is_err()
}
#[inline(always)]
fn gt(lhs: i128, rhs: i128) -> bool {
i128_diff(rhs, lhs).into_is_err()
}
}
65 changes: 65 additions & 0 deletions corelib/src/test/integer_test.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::result::ResultTrait;
use traits::{Into, TryInto};
use core::traits::Default;
use option::OptionTrait;
Expand Down Expand Up @@ -1146,6 +1147,15 @@ fn test_i8_operators() {
assert_eq(@(2_i8 * -4_i8), @-8_i8, '2 * -4 == -8');
assert_eq(@(-1_i8 * -3_i8), @3_i8, '-1 * -3 == 3');
assert_eq(@(-2_i8 * -4_i8), @8_i8, '-2 * -4 == 8');
assert_lt(1_i8, 4_i8, '1 < 4');
assert_le(1_i8, 4_i8, '1 <= 4');
assert(!(4_i8 < 4_i8), '!(4 < 4)');
assert_le(5_i8, 5_i8, '5 <= 5');
assert(!(5_i8 <= 4_i8), '!(5 <= 8)');
assert_gt(5_i8, 2_i8, '5 > 2');
assert_ge(5_i8, 2_i8, '5 >= 2');
assert(!(3_i8 > 3_i8), '!(3 > 3)');
assert_ge(3_i8, 3_i8, '3 >= 3');
}

#[test]
Expand Down Expand Up @@ -1240,6 +1250,15 @@ fn test_i16_operators() {
assert_eq(@(2_i16 * -4_i16), @-8_i16, '2 * -4 == -8');
assert_eq(@(-1_i16 * -3_i16), @3_i16, '-1 * -3 == 3');
assert_eq(@(-2_i16 * -4_i16), @8_i16, '-2 * -4 == 8');
assert_lt(1_i16, 4_i16, '1 < 4');
assert_le(1_i16, 4_i16, '1 <= 4');
assert(!(4_i16 < 4_i16), '!(4 < 4)');
assert_le(5_i16, 5_i16, '5 <= 5');
assert(!(5_i16 <= 4_i16), '!(5 <= 8)');
assert_gt(5_i16, 2_i16, '5 > 2');
assert_ge(5_i16, 2_i16, '5 >= 2');
assert(!(3_i16 > 3_i16), '!(3 > 3)');
assert_ge(3_i16, 3_i16, '3 >= 3');
}

#[test]
Expand Down Expand Up @@ -1334,6 +1353,15 @@ fn test_i32_operators() {
assert_eq(@(2_i32 * -4_i32), @-8_i32, '2 * -4 == -8');
assert_eq(@(-1_i32 * -3_i32), @3_i32, '-1 * -3 == 3');
assert_eq(@(-2_i32 * -4_i32), @8_i32, '-2 * -4 == 8');
assert_lt(1_i32, 4_i32, '1 < 4');
assert_le(1_i32, 4_i32, '1 <= 4');
assert(!(4_i32 < 4_i32), '!(4 < 4)');
assert_le(5_i32, 5_i32, '5 <= 5');
assert(!(5_i32 <= 4_i32), '!(5 <= 8)');
assert_gt(5_i32, 2_i32, '5 > 2');
assert_ge(5_i32, 2_i32, '5 >= 2');
assert(!(3_i32 > 3_i32), '!(3 > 3)');
assert_ge(3_i32, 3_i32, '3 >= 3');
}

#[test]
Expand Down Expand Up @@ -1436,6 +1464,15 @@ fn test_i64_operators() {
assert_eq(@(2_i64 * -4_i64), @-8_i64, '2 * -4 == -8');
assert_eq(@(-1_i64 * -3_i64), @3_i64, '-1 * -3 == 3');
assert_eq(@(-2_i64 * -4_i64), @8_i64, '-2 * -4 == 8');
assert_lt(1_i64, 4_i64, '1 < 4');
assert_le(1_i64, 4_i64, '1 <= 4');
assert(!(4_i64 < 4_i64), '!(4 < 4)');
assert_le(5_i64, 5_i64, '5 <= 5');
assert(!(5_i64 <= 4_i64), '!(5 <= 8)');
assert_gt(5_i64, 2_i64, '5 > 2');
assert_ge(5_i64, 2_i64, '5 >= 2');
assert(!(3_i64 > 3_i64), '!(3 > 3)');
assert_ge(3_i64, 3_i64, '3 >= 3');
}

#[test]
Expand Down Expand Up @@ -1530,6 +1567,15 @@ fn test_i128_operators() {
assert_eq(@(-3_i128 + -6_i128), @-9_i128, '-3 + -6 == -9');
assert_eq(@(-3_i128 - -1_i128), @-2_i128, '-3 - -1 == -2');
assert_eq(@(-231_i128 - -131_i128), @-100_i128, '-231--131=-100');
assert_lt(1_i128, 4_i128, '1 < 4');
assert_le(1_i128, 4_i128, '1 <= 4');
assert(!(4_i128 < 4_i128), '!(4 < 4)');
assert_le(5_i128, 5_i128, '5 <= 5');
assert(!(5_i128 <= 4_i128), '!(5 <= 8)');
assert_gt(5_i128, 2_i128, '5 > 2');
assert_ge(5_i128, 2_i128, '5 >= 2');
assert(!(3_i128 > 3_i128), '!(3 > 3)');
assert_ge(3_i128, 3_i128, '3 >= 3');
}

#[test]
Expand Down Expand Up @@ -1579,3 +1625,22 @@ fn test_i128_add_overflow_2() {
fn test_i128_add_underflow() {
-0x64000000000000000000000000000000_i128 + -0x1e000000000000000000000000000000_i128;
}

#[test]
fn test_signed_int_diff() {
assert_eq(@integer::i8_diff(3, 3).unwrap(), @0, 'i8: 3 - 3 == 0');
assert_eq(@integer::i8_diff(4, 3).unwrap(), @1, 'i8: 4 - 3 == 1');
assert_eq(@integer::i8_diff(3, 5).unwrap_err(), @~(2 - 1), 'i8: 3 - 5 == -2');
assert_eq(@integer::i16_diff(3, 3).unwrap(), @0, 'i16: 3 - 3 == 0');
assert_eq(@integer::i16_diff(4, 3).unwrap(), @1, 'i16: 4 - 3 == 1');
assert_eq(@integer::i16_diff(3, 5).unwrap_err(), @~(2 - 1), 'i16: 3 - 5 == -2');
assert_eq(@integer::i32_diff(3, 3).unwrap(), @0, 'i32: 3 - 3 == 0');
assert_eq(@integer::i32_diff(4, 3).unwrap(), @1, 'i32: 4 - 3 == 1');
assert_eq(@integer::i32_diff(3, 5).unwrap_err(), @~(2 - 1), 'i32: 3 - 5 == -2');
assert_eq(@integer::i64_diff(3, 3).unwrap(), @0, 'i64: 3 - 3 == 0');
assert_eq(@integer::i64_diff(4, 3).unwrap(), @1, 'i64: 4 - 3 == 1');
assert_eq(@integer::i64_diff(3, 5).unwrap_err(), @~(2 - 1), 'i64: 3 - 5 == -2');
assert_eq(@integer::i128_diff(3, 3).unwrap(), @0, 'i128: 3 - 3 == 0');
assert_eq(@integer::i128_diff(4, 3).unwrap(), @1, 'i128: 4 - 3 == 1');
assert_eq(@integer::i128_diff(3, 5).unwrap_err(), @~(2 - 1), 'i128: 3 - 5 == -2');
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ pub fn core_libfunc_ap_change<InfoProvider: InvocationApChangeInfoProvider>(
Sint128Concrete::Operation(_) => {
vec![ApChange::Known(3), ApChange::Known(4), ApChange::Known(4)]
}
Sint128Concrete::Diff(_) => vec![ApChange::Known(2), ApChange::Known(3)],
},
CoreConcreteLibfunc::Mem(libfunc) => match libfunc {
MemConcreteLibfunc::StoreTemp(libfunc) => {
Expand Down Expand Up @@ -325,5 +326,6 @@ fn sint_ap_change<TSintTraits: SintTraits + IntMulTraits + IsZeroTraits>(
SintConcrete::Operation(_) => {
vec![ApChange::Known(4), ApChange::Known(4), ApChange::Known(4)]
}
SintConcrete::Diff(_) => vec![ApChange::Known(2), ApChange::Known(3)],
}
}
8 changes: 8 additions & 0 deletions crates/cairo-lang-sierra-gas/src/core_libfunc_cost_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,10 @@ fn sint_libfunc_cost<TSintTraits: SintTraits + IsZeroTraits + IntMulTraits>(
ConstCost { steps: 6, holes: 0, range_checks: 1 }.into(),
ConstCost { steps: 6, holes: 0, range_checks: 1 }.into(),
],
SintConcrete::Diff(_) => vec![
(ConstCost { steps: 3, holes: 0, range_checks: 1 }).into(),
(ConstCost { steps: 5, holes: 0, range_checks: 1 }).into(),
],
}
}

Expand Down Expand Up @@ -656,6 +660,10 @@ fn s128_libfunc_cost(libfunc: &Sint128Concrete) -> Vec<BranchCost> {
ConstCost { steps: 6, holes: 0, range_checks: 1 }.into(),
ConstCost { steps: 6, holes: 0, range_checks: 1 }.into(),
],
Sint128Concrete::Diff(_) => vec![
ConstCost { steps: 3, holes: 0, range_checks: 1 }.into(),
ConstCost { steps: 5, holes: 0, range_checks: 1 }.into(),
],
}
}

Expand Down
89 changes: 88 additions & 1 deletion crates/cairo-lang-sierra-to-casm/src/invocations/int/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ use cairo_lang_casm::builder::CasmBuilder;
use cairo_lang_casm::casm_build_extend;
use cairo_lang_casm::cell_expression::CellExpression;
use cairo_lang_sierra::extensions::int::{IntConstConcreteLibfunc, IntTraits};
use num_bigint::BigInt;

use super::{CompiledInvocation, CompiledInvocationBuilder, InvocationError};
use crate::invocations::{add_input_variables, CostValidationInfo};
use crate::invocations::{
add_input_variables, get_non_fallthrough_statement_id, CostValidationInfo,
};
use crate::references::ReferenceExpression;

pub mod signed;
Expand Down Expand Up @@ -45,3 +48,87 @@ pub fn build_small_wide_mul(
CostValidationInfo::default(),
))
}

/// Handles a small integer diff operation.
/// absolute distance between the inputs must be smaller than `limit`.
fn build_small_diff(
builder: CompiledInvocationBuilder<'_>,
limit: BigInt,
) -> Result<CompiledInvocation, InvocationError> {
let failure_handle_statement_id = get_non_fallthrough_statement_id(&builder);
let [range_check, a, b] = builder.try_get_single_cells()?;
let mut casm_builder = CasmBuilder::default();
add_input_variables! {casm_builder,
buffer(0) range_check;
deref a;
deref b;
};
casm_build_extend! {casm_builder,
let orig_range_check = range_check;
tempvar a_ge_b;
tempvar a_minus_b = a - b;
const u128_limit = (BigInt::from(u128::MAX) + 1) as BigInt;
const limit = limit;
hint TestLessThan {lhs: a_minus_b, rhs: limit} into {dst: a_ge_b};
jump NoOverflow if a_ge_b != 0;
// Overflow (negative):
// Here we know that 0 - (limit - 1) <= a - b < 0.
tempvar fixed_a_minus_b = a_minus_b + u128_limit;
assert fixed_a_minus_b = *(range_check++);
let wrapping_a_minus_b = a_minus_b + limit;
jump Target;
NoOverflow:
assert a_minus_b = *(range_check++);
};
Ok(builder.build_from_casm_builder(
casm_builder,
[
("Fallthrough", &[&[range_check], &[a_minus_b]], None),
("Target", &[&[range_check], &[wrapping_a_minus_b]], Some(failure_handle_statement_id)),
],
CostValidationInfo {
range_check_info: Some((orig_range_check, range_check)),
extra_costs: None,
},
))
}

/// Handles a 128 bit diff operation.
fn build_128bit_diff(
builder: CompiledInvocationBuilder<'_>,
) -> Result<CompiledInvocation, InvocationError> {
let failure_handle_statement_id = get_non_fallthrough_statement_id(&builder);
let [range_check, a, b] = builder.try_get_single_cells()?;
let mut casm_builder = CasmBuilder::default();
add_input_variables! {casm_builder,
buffer(0) range_check;
deref a;
deref b;
};
casm_build_extend! {casm_builder,
let orig_range_check = range_check;
tempvar a_ge_b;
tempvar a_minus_b = a - b;
const u128_limit = (BigInt::from(u128::MAX) + 1) as BigInt;
hint TestLessThan {lhs: a_minus_b, rhs: u128_limit} into {dst: a_ge_b};
jump NoOverflow if a_ge_b != 0;
// Overflow (negative):
// Here we know that 0 - (2**128 - 1) <= a - b < 0.
tempvar wrapping_a_minus_b = a_minus_b + u128_limit;
assert wrapping_a_minus_b = *(range_check++);
jump Target;
NoOverflow:
assert a_minus_b = *(range_check++);
};
Ok(builder.build_from_casm_builder(
casm_builder,
[
("Fallthrough", &[&[range_check], &[a_minus_b]], None),
("Target", &[&[range_check], &[wrapping_a_minus_b]], Some(failure_handle_statement_id)),
],
CostValidationInfo {
range_check_info: Some((orig_range_check, range_check)),
extra_costs: None,
},
))
}
Loading

0 comments on commit aa7a5e3

Please sign in to comment.