Skip to content

Commit

Permalink
Add length kernel support for List Array (#1488)
Browse files Browse the repository at this point in the history
* add fn for list length
code format

Signed-off-by: remzi <[email protected]>

* add list support into length function

Signed-off-by: remzi <[email protected]>

* add tests

Signed-off-by: remzi <[email protected]>

* update doc

Signed-off-by: remzi <[email protected]>
  • Loading branch information
HaoYang670 authored Mar 28, 2022
1 parent 00accc7 commit f1eba3c
Showing 1 changed file with 107 additions and 29 deletions.
136 changes: 107 additions & 29 deletions arrow/src/compute/kernels/length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,23 @@ macro_rules! unary_offsets {
}};
}

fn octet_length_binary<O: BinaryOffsetSizeTrait, T: ArrowPrimitiveType>(
array: &dyn Array,
) -> ArrayRef
/// Length kernel for list arrays: the length of a list slot is the number of
/// elements it holds.
///
/// * `O` is the offset type of the list array (`i32` for `List`, `i64` for `LargeList`
///   at the call sites in `length`).
/// * `T` is the primitive type whose `DATA_TYPE` is used for the returned array.
///
/// The actual computation is delegated to `unary_offsets!` with an identity
/// mapping over the per-slot value.
///
/// # Panics
///
/// Panics if `array` is not a `GenericListArray<O>`; callers are expected to
/// dispatch on the array's data type first.
fn length_list<O, T>(array: &dyn Array) -> ArrayRef
where
    O: OffsetSizeTrait,
    T: ArrowPrimitiveType,
    T::Native: OffsetSizeTrait,
{
    // The target type of the downcast is fixed by the binding's annotation.
    let list: &GenericListArray<O> = array.as_any().downcast_ref().unwrap();
    unary_offsets!(list, T::DATA_TYPE, |len| len)
}

fn length_binary<O, T>(array: &dyn Array) -> ArrayRef
where
O: BinaryOffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: BinaryOffsetSizeTrait,
{
let array = array
Expand All @@ -69,10 +82,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x)
}

fn octet_length<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
array: &dyn Array,
) -> ArrayRef
fn length_string<O, T>(array: &dyn Array) -> ArrayRef
where
O: StringOffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: StringOffsetSizeTrait,
{
let array = array
Expand All @@ -82,10 +95,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x)
}

fn bit_length_impl_binary<O: BinaryOffsetSizeTrait, T: ArrowPrimitiveType>(
array: &dyn Array,
) -> ArrayRef
fn bit_length_binary<O, T>(array: &dyn Array) -> ArrayRef
where
O: BinaryOffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: BinaryOffsetSizeTrait,
{
let array = array
Expand All @@ -96,10 +109,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes)
}

fn bit_length_impl<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
array: &dyn Array,
) -> ArrayRef
fn bit_length_string<O, T>(array: &dyn Array) -> ArrayRef
where
O: StringOffsetSizeTrait,
T: ArrowPrimitiveType,
T::Native: StringOffsetSizeTrait,
{
let array = array
Expand All @@ -110,20 +123,23 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes)
}

/// Returns an array of Int32/Int64 denoting the number of bytes in each value in the array.
/// Returns an array of Int32/Int64 denoting the length of each value in the array.
/// For list array, length is the number of elements in each list.
/// For string array and binary array, length is the number of bytes of each value.
///
/// * this only accepts StringArray/Utf8, LargeString/LargeUtf8, BinaryArray and LargeBinaryArray
/// * this only accepts ListArray/LargeListArray, StringArray/LargeStringArray and BinaryArray/LargeBinaryArray
/// * length of null is null.
/// * length is in number of bytes
pub fn length(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type() {
DataType::Utf8 => Ok(octet_length::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(octet_length::<i64, Int64Type>(array)),
DataType::Binary => Ok(octet_length_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(octet_length_binary::<i64, Int64Type>(array)),
_ => Err(ArrowError::ComputeError(format!(
DataType::List(_) => Ok(length_list::<i32, Int32Type>(array)),
DataType::LargeList(_) => Ok(length_list::<i64, Int64Type>(array)),
DataType::Utf8 => Ok(length_string::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(length_string::<i64, Int64Type>(array)),
DataType::Binary => Ok(length_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(length_binary::<i64, Int64Type>(array)),
other => Err(ArrowError::ComputeError(format!(
"length not supported for {:?}",
array.data_type()
other
))),
}
}
Expand All @@ -135,19 +151,21 @@ pub fn length(array: &dyn Array) -> Result<ArrayRef> {
/// * bit_length is in number of bits
pub fn bit_length(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type() {
DataType::Utf8 => Ok(bit_length_impl::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(bit_length_impl::<i64, Int64Type>(array)),
DataType::Binary => Ok(bit_length_impl_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(bit_length_impl_binary::<i64, Int64Type>(array)),
_ => Err(ArrowError::ComputeError(format!(
DataType::Utf8 => Ok(bit_length_string::<i32, Int32Type>(array)),
DataType::LargeUtf8 => Ok(bit_length_string::<i64, Int64Type>(array)),
DataType::Binary => Ok(bit_length_binary::<i32, Int32Type>(array)),
DataType::LargeBinary => Ok(bit_length_binary::<i64, Int64Type>(array)),
other => Err(ArrowError::ComputeError(format!(
"bit_length not supported for {:?}",
array.data_type()
other
))),
}
}

#[cfg(test)]
mod tests {
use crate::datatypes::{Float32Type, Int8Type};

use super::*;

fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
Expand Down Expand Up @@ -182,6 +200,20 @@ mod tests {
}};
}

// Builds a `GenericListArray<$offset_ty>` with `$element_ty` children from
// `$value`, runs `length` on it, and asserts that the result (downcast to
// `$result_ty`) carries the same array data as `$expected`.
// Expands to a block evaluating to `Ok(())`, so it must be the tail expression
// of a function returning `Result<()>` (the `?` propagates kernel errors).
macro_rules! length_list_helper {
    ($offset_ty: ty, $result_ty: ty, $element_ty: ty, $value: expr, $expected: expr) => {{
        let list = GenericListArray::<$offset_ty>::from_iter_primitive::<$element_ty, _, _>($value);
        let computed = length(&list)?;
        let actual = computed.as_any().downcast_ref::<$result_ty>().unwrap();
        let wanted: $result_ty = $expected.into();
        assert_eq!(wanted.data(), actual.data());
        Ok(())
    }};
}

#[test]
#[cfg_attr(miri, ignore)] // running forever
fn length_test_string() -> Result<()> {
Expand Down Expand Up @@ -230,6 +262,28 @@ mod tests {
length_binary_helper!(i64, Int64Array, length, value, result)
}

/// `length` on a `ListArray` (i32 offsets) without null slots: an empty list,
/// a three-element list and a single-element list.
#[test]
fn length_test_list() -> Result<()> {
    let lists: Vec<Option<Vec<Option<i32>>>> = vec![
        Some(vec![]),
        Some(vec![Some(1), Some(2), Some(4)]),
        Some(vec![Some(0)]),
    ];
    let expected_lengths: Vec<i32> = vec![0, 3, 1];
    length_list_helper!(i32, Int32Array, Int32Type, lists, expected_lengths)
}

/// `length` on a `LargeListArray` (i64 offsets) of floats; a null *element*
/// inside a list still counts towards that list's length.
#[test]
fn length_test_large_list() -> Result<()> {
    let lists: Vec<Option<Vec<Option<f32>>>> = vec![
        Some(vec![]),
        Some(vec![Some(1.1), Some(2.2), Some(3.3)]),
        Some(vec![None]),
    ];
    let expected_lengths: Vec<i64> = vec![0, 3, 1];
    length_list_helper!(i64, Int64Array, Float32Type, lists, expected_lengths)
}

type OptionStr = Option<&'static str>;

fn length_null_cases_string() -> Vec<(Vec<OptionStr>, usize, Vec<Option<i32>>)> {
Expand Down Expand Up @@ -293,6 +347,30 @@ mod tests {
length_binary_helper!(i64, Int64Array, length, value, result)
}

/// `length` on a `ListArray` containing a null list slot: the null slot yields
/// a null length, while null elements inside a valid list are still counted.
#[test]
fn length_null_list() -> Result<()> {
    let lists: Vec<Option<Vec<Option<i8>>>> = vec![
        Some(vec![]),
        None,
        Some(vec![Some(1), None, Some(2), Some(4)]),
        Some(vec![Some(0)]),
    ];
    let expected_lengths: Vec<Option<i32>> = vec![Some(0), None, Some(4), Some(1)];
    length_list_helper!(i32, Int32Array, Int8Type, lists, expected_lengths)
}

/// `length` on a `LargeListArray` containing a null list slot: the null slot
/// yields a null length, while null elements inside a valid list are still counted.
#[test]
fn length_null_large_list() -> Result<()> {
    let lists: Vec<Option<Vec<Option<f32>>>> = vec![
        Some(vec![]),
        None,
        Some(vec![Some(1.1), None, Some(4.0)]),
        Some(vec![Some(0.1)]),
    ];
    let expected_lengths: Vec<Option<i64>> = vec![Some(0), None, Some(3), Some(1)];
    length_list_helper!(i64, Int64Array, Float32Type, lists, expected_lengths)
}

/// Tests that length is not valid for u64.
#[test]
fn length_wrong_type() {
Expand All @@ -303,7 +381,7 @@ mod tests {

/// Tests with an offset
#[test]
fn length_offsets() -> Result<()> {
fn length_offsets_string() -> Result<()> {
let a = StringArray::from(vec![Some("hello"), Some(" "), Some("world"), None]);
let b = a.slice(1, 3);
let result = length(b.as_ref())?;
Expand All @@ -316,7 +394,7 @@ mod tests {
}

#[test]
fn binary_length_offsets() -> Result<()> {
fn length_offsets_binary() -> Result<()> {
let value: Vec<Option<&[u8]>> =
vec![Some(b"hello"), Some(b" "), Some(&[0xff, 0xf8]), None];
let a = BinaryArray::from(value);
Expand Down

0 comments on commit f1eba3c

Please sign in to comment.