Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get the round result for decimal to a decimal with smaller scale #3224

Merged
merged 7 commits into from
Dec 3, 2022
105 changes: 101 additions & 4 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,7 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
if BYTE_WIDTH1 == 16 {
let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
if BYTE_WIDTH2 == 16 {
// the div must be greater or equal than 10
let div = 10_i128
.pow_checked((input_scale - output_scale) as u32)
.map_err(|_| {
Expand All @@ -2139,10 +2140,23 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
*output_scale,
))
})?;
let half = div / 2;
let neg_half = -half;

array
.try_unary::<_, Decimal128Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// cast to smaller scale, need to round the result
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v / div;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to make consistent with the usage and clear compilation result?

let r = v % div;
if v >= 0 && r >= half {
d.checked_add(1)
} else if v < 0 && r <= neg_half {
d.checked_sub(1)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
Expand All @@ -2166,9 +2180,23 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
))
})?;

let half = div / i256::from_i128(2_i128);
let neg_half = -half;

array
.try_unary::<_, Decimal256Type, _>(|v| {
i256::from_i128(v).checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let v = i256::from_i128(v);
let d = v / div;
let r = v % div;
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
Expand All @@ -2193,10 +2221,21 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
*output_scale,
))
})?;
let half = div / i256::from_i128(2_i128);
let neg_half = -half;
if BYTE_WIDTH2 == 16 {
array
.try_unary::<_, Decimal128Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v / div;
let r = v % div;
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal128Type::PREFIX,
Expand All @@ -2217,7 +2256,17 @@ fn cast_decimal_to_decimal<const BYTE_WIDTH1: usize, const BYTE_WIDTH2: usize>(
} else {
array
.try_unary::<_, Decimal256Type, _>(|v| {
v.checked_div(div).ok_or_else(|| {
// the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation
let d = v / div;
let r = v % div;
if v >= i256::ZERO && r >= half {
d.checked_add(i256::ONE)
} else if v < i256::ZERO && r <= neg_half {
d.checked_sub(i256::ONE)
} else {
Some(d)
}
.ok_or_else(|| {
ArrowError::CastError(format!(
"Cannot cast to {:?}({}, {}). Overflowing on {:?}",
Decimal256Type::PREFIX,
Expand Down Expand Up @@ -3588,6 +3637,26 @@ mod tests {
}
}
}

let cast_option = CastOptions { safe: false };
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add cast with safe is false

let casted_array_with_option =
cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap();
let result_array = casted_array_with_option
.as_any()
.downcast_ref::<$OUTPUT_TYPE_ARRAY>()
.unwrap();
assert_eq!($OUTPUT_TYPE, result_array.data_type());
assert_eq!(result_array.len(), $OUTPUT_VALUES.len());
for (i, x) in $OUTPUT_VALUES.iter().enumerate() {
match x {
Some(x) => {
assert_eq!(result_array.value(i), *x);
}
None => {
assert!(result_array.is_null(i));
}
}
}
};
}

Expand Down Expand Up @@ -3795,6 +3864,34 @@ mod tests {
result.unwrap_err().to_string());
}

#[test]
fn test_cast_decimal256_to_decimal128_overflow() {
let input_type = DataType::Decimal256(76, 5);
let output_type = DataType::Decimal128(38, 7);
assert!(can_cast_types(&input_type, &output_type));
let array = vec![Some(i256::from_i128(i128::MAX))];
let input_decimal_array = create_decimal256_array(array, 76, 5).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
let result =
cast_with_options(&array, &output_type, &CastOptions { safe: false });
assert_eq!("Invalid argument error: 17014118346046923173168730371588410572700 cannot be casted to 128-bit integer for Decimal128",
result.unwrap_err().to_string());
}

#[test]
fn test_cast_decimal256_to_decimal256_overflow() {
let input_type = DataType::Decimal256(76, 5);
let output_type = DataType::Decimal256(76, 55);
assert!(can_cast_types(&input_type, &output_type));
let array = vec![Some(i256::from_i128(i128::MAX))];
let input_decimal_array = create_decimal256_array(array, 76, 5).unwrap();
let array = Arc::new(input_decimal_array) as ArrayRef;
let result =
cast_with_options(&array, &output_type, &CastOptions { safe: false });
assert_eq!("Cast error: Cannot cast to \"Decimal256\"(76, 55). Overflowing on 170141183460469231731687303715884105727",
result.unwrap_err().to_string());
}

#[test]
fn test_cast_decimal128_to_decimal256() {
let input_type = DataType::Decimal128(20, 3);
Expand Down