From cb4170b50a54c466897afc83583f01dca23544c0 Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Sat, 3 Dec 2022 23:39:25 +0800 Subject: [PATCH] Get the round result for decimal to a decimal with smaller scale (#3224) * support cast decimal for round when the option is false * fix conflict after merge * fix error case * change to wrapping api --- arrow-cast/src/cast.rs | 143 ++++++++++++++++++++++++++++++++--------- 1 file changed, 111 insertions(+), 32 deletions(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index be767f137cd8..8d28a6cc772d 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -2164,6 +2164,7 @@ fn cast_decimal_to_decimal( if BYTE_WIDTH1 == 16 { let array = array.as_any().downcast_ref::().unwrap(); if BYTE_WIDTH2 == 16 { + // the div must be greater or equal than 10 let div = 10_i128 .pow_checked((input_scale - output_scale) as u32) .map_err(|_| { @@ -2172,10 +2173,23 @@ fn cast_decimal_to_decimal( *output_scale, )) })?; + let half = div / 2; + let neg_half = -half; array .try_unary::<_, Decimal128Type, _>(|v| { - v.checked_div(div).ok_or_else(|| { + // cast to smaller scale, need to round the result + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= 0 && r >= half { + d.checked_add(1) + } else if v < 0 && r <= neg_half { + d.checked_sub(1) + } else { + Some(d) + } + .ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). Overflowing on {:?}", Decimal128Type::PREFIX, @@ -2199,9 +2213,23 @@ fn cast_decimal_to_decimal( )) })?; + let half = div / i256::from_i128(2_i128); + let neg_half = -half; + array .try_unary::<_, Decimal256Type, _>(|v| { - i256::from_i128(v).checked_div(div).ok_or_else(|| { + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let v = i256::from_i128(v); + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= i256::ZERO && r >= half { + d.checked_add(i256::ONE) + } else if v < i256::ZERO && r <= neg_half { + d.checked_sub(i256::ONE) + } else { + Some(d) + } + .ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). Overflowing on {:?}", Decimal256Type::PREFIX, @@ -2226,10 +2254,21 @@ fn cast_decimal_to_decimal( *output_scale, )) })?; + let half = div / i256::from_i128(2_i128); + let neg_half = -half; if BYTE_WIDTH2 == 16 { array .try_unary::<_, Decimal128Type, _>(|v| { - v.checked_div(div).ok_or_else(|| { + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= i256::ZERO && r >= half { + d.checked_add(i256::ONE) + } else if v < i256::ZERO && r <= neg_half { + d.checked_sub(i256::ONE) + } else { + Some(d) + }.ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). Overflowing on {:?}", Decimal128Type::PREFIX, @@ -2250,7 +2289,17 @@ fn cast_decimal_to_decimal( } else { array .try_unary::<_, Decimal256Type, _>(|v| { - v.checked_div(div).ok_or_else(|| { + // the div must be gt_eq 10, we don't need to check the overflow for the `div`/`mod` operation + let d = v.wrapping_div(div); + let r = v.wrapping_rem(div); + if v >= i256::ZERO && r >= half { + d.checked_add(i256::ONE) + } else if v < i256::ZERO && r <= neg_half { + d.checked_sub(i256::ONE) + } else { + Some(d) + } + .ok_or_else(|| { ArrowError::CastError(format!( "Cannot cast to {:?}({}, {}). Overflowing on {:?}", Decimal256Type::PREFIX, @@ -3621,6 +3670,26 @@ mod tests { } } } + + let cast_option = CastOptions { safe: false }; + let casted_array_with_option = + cast_with_options($INPUT_ARRAY, $OUTPUT_TYPE, &cast_option).unwrap(); + let result_array = casted_array_with_option + .as_any() + .downcast_ref::<$OUTPUT_TYPE_ARRAY>() + .unwrap(); + assert_eq!($OUTPUT_TYPE, result_array.data_type()); + assert_eq!(result_array.len(), $OUTPUT_VALUES.len()); + for (i, x) in $OUTPUT_VALUES.iter().enumerate() { + match x { + Some(x) => { + assert_eq!(result_array.value(i), *x); + } + None => { + assert!(result_array.is_null(i)); + } + } + } }; } @@ -3647,6 +3716,44 @@ mod tests { } #[test] + #[cfg(not(feature = "force_validate"))] + #[should_panic( + expected = "5789604461865809771178549250434395392663499233282028201972879200395656481997 cannot be casted to 128-bit integer for Decimal128" + )] + fn test_cast_decimal_to_decimal_round_with_error() { + // decimal256 to decimal128 overflow + let array = vec![ + Some(i256::from_i128(1123454)), + Some(i256::from_i128(2123456)), + Some(i256::from_i128(-3123453)), + Some(i256::from_i128(-3123456)), + None, + Some(i256::MAX), + Some(i256::MIN), + ]; + let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap(); + let array = Arc::new(input_decimal_array) as ArrayRef; + let input_type = DataType::Decimal256(76, 4); + let output_type = DataType::Decimal128(20, 3); + assert!(can_cast_types(&input_type, &output_type)); + generate_cast_test_case!( + &array, + Decimal128Array, + &output_type, + vec![ + Some(112345_i128), + Some(212346_i128), + Some(-312345_i128), + Some(-312346_i128), + None, + None, + None, + ] + ); + } + + #[test] + #[cfg(not(feature = "force_validate"))] fn test_cast_decimal_to_decimal_round() { let array = vec![ Some(1123454), @@ -3734,34 +3841,6 @@ mod tests { None ] ); - - // decimal256 to decimal128 overflow - let array = vec![ - Some(i256::from_i128(1123454)), - Some(i256::from_i128(2123456)), - Some(i256::from_i128(-3123453)), - Some(i256::from_i128(-3123456)), - None, - Some(i256::MAX), - Some(i256::MIN), - ]; - let input_decimal_array = create_decimal256_array(array, 76, 4).unwrap(); - let array = Arc::new(input_decimal_array) as ArrayRef; - assert!(can_cast_types(&input_type, &output_type)); - generate_cast_test_case!( - &array, - Decimal128Array, - &output_type, - vec![ - Some(112345_i128), - Some(212346_i128), - Some(-312345_i128), - Some(-312346_i128), - None, - None, - None - ] - ); } #[test]