Skip to content

Commit

Permalink
fix: use signed comparator to compare decimal128 and decimal256 (#2275)
Browse files Browse the repository at this point in the history
* fix bug: decimal cmp

* optimizer the error message

* address comment
  • Loading branch information
liukun4515 authored Aug 2, 2022
1 parent 58dc611 commit ed9fc56
Showing 1 changed file with 113 additions and 3 deletions.
116 changes: 113 additions & 3 deletions arrow/src/util/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ macro_rules! def_decimal {
"Cannot compare two Decimals with different scale: {}, {}",
self.scale, other.scale
);
self.value.partial_cmp(&other.value)
Some(singed_cmp_le_bytes(&self.value, &other.value))
}
}

Expand All @@ -226,7 +226,7 @@ macro_rules! def_decimal {
"Cannot compare two Decimals with different scale: {}, {}",
self.scale, other.scale
);
self.value.cmp(&other.value)
singed_cmp_le_bytes(&self.value, &other.value)
}
}

Expand All @@ -245,6 +245,49 @@ macro_rules! def_decimal {
};
}

// compare two signed integer which are encoded with little endian.
// left bytes and right bytes must have the same length.
fn singed_cmp_le_bytes(left: &[u8], right: &[u8]) -> Ordering {
assert_eq!(
left.len(),
right.len(),
"Can't compare bytes array with different len: {}, {}",
left.len(),
right.len()
);
assert_ne!(left.len(), 0, "Can't compare bytes array of length 0");
let len = left.len();
// the sign bit is 1, the value is negative
let left_negative = left[len - 1] >= 0x80_u8;
let right_negative = right[len - 1] >= 0x80_u8;
if left_negative != right_negative {
return match left_negative {
true => {
// left is negative value
// right is positive value
Ordering::Less
}
false => {
Ordering::Greater
}
};
}
for i in 0..len {
let l_byte = left[len - 1 - i];
let r_byte = right[len - 1 - i];
match l_byte.cmp(&r_byte) {
Ordering::Less => {
return Ordering::Less;
}
Ordering::Greater => {
return Ordering::Greater;
}
Ordering::Equal => {}
}
}
Ordering::Equal
}

def_decimal!(
Decimal128,
128,
Expand All @@ -260,8 +303,11 @@ def_decimal!(

#[cfg(test)]
mod tests {
use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256};
use crate::util::decimal::{
singed_cmp_le_bytes, BasicDecimal, Decimal128, Decimal256,
};
use num::{BigInt, Num};
use rand::random;

#[test]
fn decimal_128_to_string() {
Expand Down Expand Up @@ -368,4 +414,68 @@ mod tests {
let value = Decimal256::from_big_int(&num, 76, 4).unwrap();
assert_eq!(value.to_string(), "-574437317700748313234121683441537667865831564552201235664496608164256541.5731");
}

#[test]
fn test_lt_cmp_byte() {
for _i in 0..100 {
let left = random::<i128>();
let right = random::<i128>();
let result = singed_cmp_le_bytes(
left.to_le_bytes().as_slice(),
right.to_le_bytes().as_slice(),
);
assert_eq!(left.cmp(&right), result);
}
for _i in 0..100 {
let left = random::<i32>();
let right = random::<i32>();
let result = singed_cmp_le_bytes(
left.to_le_bytes().as_slice(),
right.to_le_bytes().as_slice(),
);
assert_eq!(left.cmp(&right), result);
}
}

#[test]
fn compare_decimal128() {
let v1 = -100_i128;
let v2 = 10000_i128;
let right = Decimal128::new_from_i128(20, 3, v2);
for v in v1..v2 {
let left = Decimal128::new_from_i128(20, 3, v);
assert!(left < right);
}

for _i in 0..100 {
let left = random::<i128>();
let right = random::<i128>();
let left_decimal = Decimal128::new_from_i128(38, 2, left);
let right_decimal = Decimal128::new_from_i128(38, 2, right);
assert_eq!(left < right, left_decimal < right_decimal);
assert_eq!(left == right, left_decimal == right_decimal)
}
}

#[test]
fn compare_decimal256() {
let v1 = -100_i128;
let v2 = 10000_i128;
let right = Decimal256::from_big_int(&BigInt::from(v2), 75, 2).unwrap();
for v in v1..v2 {
let left = Decimal256::from_big_int(&BigInt::from(v), 75, 2).unwrap();
assert!(left < right);
}

for _i in 0..100 {
let left = random::<i128>();
let right = random::<i128>();
let left_decimal =
Decimal256::from_big_int(&BigInt::from(left), 75, 2).unwrap();
let right_decimal =
Decimal256::from_big_int(&BigInt::from(right), 75, 2).unwrap();
assert_eq!(left < right, left_decimal < right_decimal);
assert_eq!(left == right, left_decimal == right_decimal)
}
}
}

0 comments on commit ed9fc56

Please sign in to comment.