Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for casting StringViewArray to DecimalArray #6720

Merged
merged 4 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ mod list_view_array;

pub use list_view_array::*;

use crate::iterator::ArrayIter;

/// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html)
pub trait Array: std::fmt::Debug + Send + Sync {
/// Returns the array as [`Any`] so that it can be
Expand Down Expand Up @@ -570,6 +572,40 @@ pub trait ArrayAccessor: Array {
unsafe fn value_unchecked(&self, index: usize) -> Self::Item;
}

/// A trait for Arrow String Arrays, currently three types are supported:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nice to start making public -- we use it in DataFusion a bit as well, so starting to consolidate on a single implementation will be good

/// - `StringArray`
/// - `LargeStringArray`
/// - `StringViewArray`
///
/// This trait helps to abstract over the different types of string arrays
/// so that we don't need to duplicate the implementation for each type.
pub trait StringArrayType<'a>: ArrayAccessor<Item = &'a str> + Sized {
/// Returns true if all data within this string array is ASCII
fn is_ascii(&self) -> bool;

/// Constructs a new iterator
fn iter(&self) -> ArrayIter<Self>;
}

impl<'a, O: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray<O> {
fn is_ascii(&self) -> bool {
GenericStringArray::<O>::is_ascii(self)
}

fn iter(&self) -> ArrayIter<Self> {
GenericStringArray::<O>::iter(self)
}
}
impl<'a> StringArrayType<'a> for &'a StringViewArray {
fn is_ascii(&self) -> bool {
StringViewArray::is_ascii(self)
}

fn iter(&self) -> ArrayIter<Self> {
StringViewArray::iter(self)
}
}

impl PartialEq for dyn Array + '_ {
fn eq(&self, other: &Self) -> bool {
self.to_data().eq(&other.to_data())
Expand Down
68 changes: 58 additions & 10 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,15 +323,16 @@ where
})
}

pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
from: &GenericStringArray<Offset>,
pub(crate) fn generic_string_to_decimal_cast<'a, T, S>(
from: &'a S,
precision: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<T>, ArrowError>
where
T: DecimalType,
T::Native: DecimalCast + ArrowNativeTypeOp,
&'a S: StringArrayType<'a>,
{
if cast_options.safe {
let iter = from.iter().map(|v| {
Expand Down Expand Up @@ -375,6 +376,37 @@ where
}
}

pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>(
from: &GenericStringArray<Offset>,
precision: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<T>, ArrowError>
where
T: DecimalType,
T::Native: DecimalCast + ArrowNativeTypeOp,
{
generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>(
from,
precision,
scale,
cast_options,
)
}

pub(crate) fn string_view_to_decimal_cast<T>(
from: &StringViewArray,
precision: u8,
scale: i8,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<T>, ArrowError>
where
T: DecimalType,
T::Native: DecimalCast + ArrowNativeTypeOp,
{
generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options)
}

/// Cast Utf8 to decimal
pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>(
from: &dyn Array,
Expand All @@ -399,14 +431,30 @@ where
)));
}

Ok(Arc::new(string_to_decimal_cast::<T, Offset>(
from.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap(),
precision,
scale,
cast_options,
)?))
let result = match from.data_type() {
DataType::Utf8View => string_view_to_decimal_cast::<T>(
from.as_any().downcast_ref::<StringViewArray>().unwrap(),
precision,
scale,
cast_options,
)?,
DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>(
from.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap(),
precision,
scale,
cast_options,
)?,
other => {
return Err(ArrowError::ComputeError(format!(
"Cannot cast {:?} to decimal",
other
)))
}
};

Ok(Arc::new(result))
}

pub(crate) fn cast_floating_point_to_decimal128<T: ArrowPrimitiveType>(
Expand Down
48 changes: 41 additions & 7 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true,
// decimal to Utf8
(Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true,
// Utf8 to decimal
(Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true,
// string to decimal
(Utf8View | Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true,
(Struct(from_fields), Struct(to_fields)) => {
from_fields.len() == to_fields.len() &&
from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| {
Expand Down Expand Up @@ -230,7 +230,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
) => true,
(Utf8 | LargeUtf8, Utf8View) => true,
(BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View ) => true,
(Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
(Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16,
(_, Utf8 | LargeUtf8) => from_type.is_primitive(),

(_, Binary | LargeBinary) => from_type.is_integer(),
Expand Down Expand Up @@ -1061,7 +1061,7 @@ pub fn cast_with_options(
*scale,
cast_options,
),
Utf8 => cast_string_to_decimal::<Decimal128Type, i32>(
Utf8View | Utf8 => cast_string_to_decimal::<Decimal128Type, i32>(
array,
*precision,
*scale,
Expand Down Expand Up @@ -1150,7 +1150,7 @@ pub fn cast_with_options(
*scale,
cast_options,
),
Utf8 => cast_string_to_decimal::<Decimal256Type, i32>(
Utf8View | Utf8 => cast_string_to_decimal::<Decimal256Type, i32>(
array,
*precision,
*scale,
Expand Down Expand Up @@ -2485,12 +2485,11 @@ where

#[cfg(test)]
mod tests {
use super::*;
use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer};
use chrono::NaiveDate;
use half::f16;

use super::*;

macro_rules! generate_cast_test_case {
($INPUT_ARRAY: expr, $OUTPUT_TYPE_ARRAY: ident, $OUTPUT_TYPE: expr, $OUTPUT_VALUES: expr) => {
let output =
Expand Down Expand Up @@ -3720,6 +3719,41 @@ mod tests {
assert!(!c.is_valid(4));
}

#[test]
fn test_cast_utf8view_to_i32() {
let array = StringViewArray::from(vec!["5", "6", "seven", "8", "9.1"]);
let b = cast(&array, &DataType::Int32).unwrap();
let c = b.as_primitive::<Int32Type>();
assert_eq!(5, c.value(0));
assert_eq!(6, c.value(1));
assert!(!c.is_valid(2));
assert_eq!(8, c.value(3));
assert!(!c.is_valid(4));
}

#[test]
fn test_cast_utf8view_to_f32() {
let array = StringViewArray::from(vec!["3", "4.56", "seven", "8.9"]);
let b = cast(&array, &DataType::Float32).unwrap();
let c = b.as_primitive::<Float32Type>();
assert_eq!(3.0, c.value(0));
assert_eq!(4.56, c.value(1));
assert!(!c.is_valid(2));
assert_eq!(8.9, c.value(3));
}

#[test]
fn test_cast_utf8view_to_decimal128() {
let array = StringViewArray::from(vec![None, Some("4"), Some("5.6"), Some("7.89")]);
let arr = Arc::new(array) as ArrayRef;
generate_cast_test_case!(
&arr,
Decimal128Array,
&DataType::Decimal128(4, 2),
vec![None, Some(400_i128), Some(560_i128), Some(789_i128)]
);
}

#[test]
fn test_cast_with_options_utf8_to_i32() {
let array = StringArray::from(vec!["5", "6", "seven", "8", "9.1"]);
Expand Down
Loading