Skip to content

Commit

Permalink
Support casting Utf8 to Boolean
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed May 26, 2022
1 parent 2ba1ef4 commit f2d2e41
Showing 1 changed file with 63 additions and 7 deletions.
70 changes: 63 additions & 7 deletions arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Dictionary(_, value_type), _) => can_cast_types(value_type, to_type),
(_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type),

(_, Boolean) => DataType::is_numeric(from_type),
(_, Boolean) => DataType::is_numeric(from_type) || from_type == &Utf8,
(Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,

(Utf8, LargeUtf8) => true,
Expand Down Expand Up @@ -280,6 +280,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
///
/// Behavior:
/// * Boolean to Utf8: `true` => '1', `false` => `0`
/// * Utf8 to boolean: `true`, `yes`, `on`, `1` => `true`, `false`, `no`, `off`, `0` => `false`,
/// short variants are accepted, other strings return null or error
/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings
/// in integer casts return null
/// * Numeric to boolean: 0 returns `false`, any other value returns `true`
Expand All @@ -293,7 +295,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
/// Unsupported Casts
/// * To or from `StructArray`
/// * List to primitive
/// * Utf8 to boolean
/// * Interval and duration
pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
cast_with_options(array, to_type, &DEFAULT_CAST_OPTIONS)
Expand Down Expand Up @@ -396,6 +397,8 @@ macro_rules! cast_decimal_to_float {
///
/// Behavior:
/// * Boolean to Utf8: `true` => '1', `false` => `0`
/// * Utf8 to boolean: `true`, `yes`, `on`, `1` => `true`, `false`, `no`, `off`, `0` => `false`,
/// short variants are accepted, other strings return null or error
/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings
/// in integer casts return null
/// * Numeric to boolean: 0 returns `false`, any other value returns `true`
Expand All @@ -409,7 +412,6 @@ macro_rules! cast_decimal_to_float {
/// Unsupported Casts
/// * To or from `StructArray`
/// * List to primitive
/// * Utf8 to boolean
pub fn cast_with_options(
array: &ArrayRef,
to_type: &DataType,
Expand Down Expand Up @@ -643,10 +645,7 @@ pub fn cast_with_options(
Int64 => cast_numeric_to_bool::<Int64Type>(array),
Float32 => cast_numeric_to_bool::<Float32Type>(array),
Float64 => cast_numeric_to_bool::<Float64Type>(array),
Utf8 => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
Utf8 => cast_utf8_to_boolean(array, cast_options),
_ => Err(ArrowError::CastError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
Expand Down Expand Up @@ -1661,6 +1660,34 @@ fn cast_string_to_timestamp_ns<Offset: OffsetSizeTrait>(
Ok(Arc::new(array) as ArrayRef)
}

/// Casts Utf8 to Boolean
fn cast_utf8_to_boolean(from: &ArrayRef, cast_options: &CastOptions) -> Result<ArrayRef> {
let array = as_string_array(from);

let output_array = array
.iter()
.map(|value| match value {
Some(value) => match value.to_ascii_lowercase().trim() {
"t" | "tr" | "tru" | "true" | "y" | "ye" | "yes" | "on" | "1" => {
Ok(Some(true))
}
"f" | "fa" | "fal" | "fals" | "false" | "n" | "no" | "of" | "off"
| "0" => Ok(Some(false)),
invalid_value => match cast_options.safe {
true => Ok(None),
false => Err(ArrowError::CastError(format!(
"Cannot cast string '{}' to value of Boolean type",
invalid_value,
))),
},
},
None => Ok(None),
})
.collect::<Result<BooleanArray>>()?;

Ok(Arc::new(output_array))
}

/// Cast numeric types to Boolean
///
/// Any zero value returns `false` while non-zero returns `true`
Expand Down Expand Up @@ -2638,6 +2665,35 @@ mod tests {
}
}

#[test]
fn test_cast_utf8_to_bool() {
let a = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]);
let array = Arc::new(a) as ArrayRef;
let b = cast(&array, &DataType::Boolean).unwrap();
let c = b.as_any().downcast_ref::<BooleanArray>().unwrap();
assert!(c.value(0));
assert!(!c.value(1));
assert!(!c.is_valid(2));
assert!(c.value(3));
assert!(!c.is_valid(4));
}

#[test]
fn test_cast_with_options_utf8_to_bool() {
let a = StringArray::from(vec!["true", "false", "invalid", " Y ", ""]);
let array = Arc::new(a) as ArrayRef;
let result =
cast_with_options(&array, &DataType::Boolean, &CastOptions { safe: false });
match result {
Ok(_) => panic!("expected error"),
Err(e) => {
assert!(e.to_string().contains(
"Cast error: Cannot cast string 'invalid' to value of Boolean type"
))
}
}
}

#[test]
fn test_cast_bool_to_i32() {
let a = BooleanArray::from(vec![Some(true), Some(false), None]);
Expand Down

0 comments on commit f2d2e41

Please sign in to comment.