Skip to content

Commit

Permalink
Support fixed point multiplication for DictionaryArray of Decimals (#…
Browse files Browse the repository at this point in the history
…4136)

* Add multiply_fixed_point_dyn

* Fix clippy

* For review
  • Loading branch information
viirya authored Apr 30, 2023
1 parent 9d72cc5 commit 08dc16c
Showing 1 changed file with 222 additions and 30 deletions.
252 changes: 222 additions & 30 deletions arrow-arith/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,114 @@ pub fn multiply_dyn_checked(
}
}

#[cfg(feature = "dyn_arith_dict")]
fn get_precision_scale(dt: &DataType) -> Result<(u8, i8), ArrowError> {
match dt {
DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
_ => Err(ArrowError::ComputeError(
"Cannot get precision and scale from non-decimal type".to_string(),
)),
}
}

/// Returns the precision and scale of the result of a multiplication of two decimal types,
/// and the divisor for fixed point multiplication.
fn get_fixed_point_info(
left: (u8, i8),
right: (u8, i8),
required_scale: i8,
) -> Result<(u8, i8, i256), ArrowError> {
let product_scale = left.1 + right.1;
let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION);

if required_scale > product_scale {
return Err(ArrowError::ComputeError(format!(
"Required scale {} is greater than product scale {}",
required_scale, product_scale
)));
}

let divisor =
i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);

Ok((precision, product_scale, divisor))
}

#[cfg(feature = "dyn_arith_dict")]
/// Perform `left * right` operation on two decimal arrays. If either left or right value is
/// null then the result is also null.
///
/// This performs decimal multiplication which allows precision loss if an exact representation
/// is not possible for the result, according to the required scale. In the case, the result
/// will be rounded to the required scale.
///
/// If the required scale is greater than the product scale, an error is returned.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
///
/// It is implemented for compatibility with precision loss `multiply` function provided by
/// other data processing engines. For multiplication with precision loss detection, use
/// `multiply_dyn` or `multiply_dyn_checked` instead.
pub fn multiply_fixed_point_dyn(
left: &dyn Array,
right: &dyn Array,
required_scale: i8,
) -> Result<ArrayRef, ArrowError> {
match (left.data_type(), right.data_type()) {
(
DataType::Dictionary(_, lhs_value_type),
DataType::Dictionary(_, rhs_value_type),
) if matches!(lhs_value_type.as_ref(), &DataType::Decimal128(_, _))
&& matches!(rhs_value_type.as_ref(), &DataType::Decimal128(_, _)) =>
{
downcast_dictionary_array!(
left => match left.values().data_type() {
DataType::Decimal128(_, _) => {
let lhs_precision_scale = get_precision_scale(lhs_value_type.as_ref())?;
let rhs_precision_scale = get_precision_scale(rhs_value_type.as_ref())?;

let (precision, product_scale, divisor) = get_fixed_point_info(lhs_precision_scale, rhs_precision_scale, required_scale)?;

let right = as_dictionary_array::<_>(right);

if required_scale == product_scale {
let mul = multiply_dyn(left, right)?;
let array = mul.as_any().downcast_ref::<Decimal128Array>().unwrap();
let array = array.clone().with_precision_and_scale(precision, required_scale)?;
return Ok(Arc::new(array))
}

let array = math_op_dict::<_, Decimal128Type, _>(left, right, |a, b| {
let a = i256::from_i128(a);
let b = i256::from_i128(b);

let mut mul = a.wrapping_mul(b);
mul = divide_and_round::<Decimal256Type>(mul, divisor);
mul.as_i128()
}).and_then(|a| a.with_precision_and_scale(precision, required_scale))?;

Ok(Arc::new(array))
}
t => unreachable!("Unsupported dictionary value type {}", t),
},
t => unreachable!("Unsupported data type {}", t),
)
}
(DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();

multiply_fixed_point(left, right, required_scale)
.map(|a| Arc::new(a) as ArrayRef)
}
(_, _) => Err(ArrowError::CastError(format!(
"Unsupported data type {}, {}",
left.data_type(),
right.data_type()
))),
}
}

/// Perform `left * right` operation on two decimal arrays. If either left or right value is
/// null then the result is also null.
///
Expand All @@ -1451,27 +1559,17 @@ pub fn multiply_fixed_point_checked(
right: &PrimitiveArray<Decimal128Type>,
required_scale: i8,
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
let product_scale = left.scale() + right.scale();
let precision = min(
left.precision() + right.precision() + 1,
DECIMAL128_MAX_PRECISION,
);
let (precision, product_scale, divisor) = get_fixed_point_info(
(left.precision(), left.scale()),
(right.precision(), right.scale()),
required_scale,
)?;

if required_scale == product_scale {
return multiply_checked(left, right)?
.with_precision_and_scale(precision, required_scale);
}

if required_scale > product_scale {
return Err(ArrowError::ComputeError(format!(
"Required scale {} is greater than product scale {}",
required_scale, product_scale
)));
}

let divisor =
i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);

try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
let a = i256::from_i128(a);
let b = i256::from_i128(b);
Expand Down Expand Up @@ -1505,27 +1603,17 @@ pub fn multiply_fixed_point(
right: &PrimitiveArray<Decimal128Type>,
required_scale: i8,
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
let product_scale = left.scale() + right.scale();
let precision = min(
left.precision() + right.precision() + 1,
DECIMAL128_MAX_PRECISION,
);
let (precision, product_scale, divisor) = get_fixed_point_info(
(left.precision(), left.scale()),
(right.precision(), right.scale()),
required_scale,
)?;

if required_scale == product_scale {
return multiply(left, right)?
.with_precision_and_scale(precision, required_scale);
}

if required_scale > product_scale {
return Err(ArrowError::ComputeError(format!(
"Required scale {} is greater than product scale {}",
required_scale, product_scale
)));
}

let divisor =
i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);

binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
let a = i256::from_i128(a);
let b = i256::from_i128(b);
Expand Down Expand Up @@ -3910,6 +3998,110 @@ mod tests {
);
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_decimal_multiply_fixed_point_dyn() {
// [123456789]
let a = Decimal128Array::from(vec![123456789000000000000000000])
.with_precision_and_scale(38, 18)
.unwrap();

// [10]
let b = Decimal128Array::from(vec![10000000000000000000])
.with_precision_and_scale(38, 18)
.unwrap();

// Avoid overflow by reducing the scale.
let result = multiply_fixed_point_dyn(&a, &b, 28).unwrap();
// [1234567890]
let expected = Arc::new(
Decimal128Array::from(vec![12345678900000000000000000000000000000])
.with_precision_and_scale(38, 28)
.unwrap(),
) as ArrayRef;

assert_eq!(&expected, &result);
assert_eq!(
result.as_primitive::<Decimal128Type>().value_as_string(0),
"1234567890.0000000000000000000000000000"
);

// [123456789, 10]
let a = Decimal128Array::from(vec![
123456789000000000000000000,
10000000000000000000,
])
.with_precision_and_scale(38, 18)
.unwrap();

// [10, 123456789, 12]
let b = Decimal128Array::from(vec![
10000000000000000000,
123456789000000000000000000,
12000000000000000000,
])
.with_precision_and_scale(38, 18)
.unwrap();

let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]);
let array1 = DictionaryArray::new(keys, Arc::new(a));
let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]);
let array2 = DictionaryArray::new(keys, Arc::new(b));

let result = multiply_fixed_point_dyn(&array1, &array2, 28).unwrap();
let expected = Arc::new(
Decimal128Array::from(vec![
Some(12345678900000000000000000000000000000),
Some(12345678900000000000000000000000000000),
Some(1200000000000000000000000000000),
None,
])
.with_precision_and_scale(38, 28)
.unwrap(),
) as ArrayRef;

assert_eq!(&expected, &result);
assert_eq!(
result.as_primitive::<Decimal128Type>().value_as_string(0),
"1234567890.0000000000000000000000000000"
);
assert_eq!(
result.as_primitive::<Decimal128Type>().value_as_string(1),
"1234567890.0000000000000000000000000000"
);
assert_eq!(
result.as_primitive::<Decimal128Type>().value_as_string(2),
"120.0000000000000000000000000000"
);

// Required scale is same as the product of the input scales. Behavior is same as multiply_dyn.
let a = Decimal128Array::from(vec![123, 100])
.with_precision_and_scale(3, 2)
.unwrap();

let b = Decimal128Array::from(vec![100, 123, 120])
.with_precision_and_scale(3, 2)
.unwrap();

let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), None]);
let array1 = DictionaryArray::new(keys, Arc::new(a));
let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(2), None]);
let array2 = DictionaryArray::new(keys, Arc::new(b));

let result = multiply_fixed_point_dyn(&array1, &array2, 4).unwrap();
let expected = multiply_dyn(&array1, &array2).unwrap();
let expected = Arc::new(
expected
.as_any()
.downcast_ref::<Decimal128Array>()
.unwrap()
.clone()
.with_precision_and_scale(7, 4)
.unwrap(),
) as ArrayRef;
assert_eq!(&expected, &result);
}

#[test]
fn test_timestamp_second_add_interval() {
// timestamp second + interval year month
Expand Down

0 comments on commit 08dc16c

Please sign in to comment.