Skip to content

Commit

Permalink
Support DictionaryArray in unary kernel (#1990)
Browse files Browse the repository at this point in the history
* Init

* More

* Fix clippy

* Apply on dictionary values directly in unary_dict.

* Fix clippy

* Avoid validate when constructing new dictionary array
  • Loading branch information
viirya authored Jul 6, 2022
1 parent b156cce commit 62053a8
Showing 1 changed file with 177 additions and 4 deletions.
181 changes: 177 additions & 4 deletions arrow/src/compute/kernels/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@

//! Defines kernels suitable to perform operations to primitive arrays.
use crate::array::{Array, ArrayData, PrimitiveArray};
use crate::array::{Array, ArrayData, ArrayRef, DictionaryArray, PrimitiveArray};
use crate::buffer::Buffer;
use crate::datatypes::ArrowPrimitiveType;
use crate::datatypes::{
ArrowNumericType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type,
Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use crate::error::{ArrowError, Result};
use std::sync::Arc;

#[inline]
fn into_primitive_array_data<I: ArrowPrimitiveType, O: ArrowPrimitiveType>(
Expand Down Expand Up @@ -78,10 +83,128 @@ where
PrimitiveArray::<O>::from(data)
}

/// A helper function that applies an unary function to a dictionary array with primitive value type.
#[allow(clippy::redundant_closure)]
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
where
K: ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
let dict_values = array
.values()
.as_any()
.downcast_ref::<PrimitiveArray<T>>()
.unwrap();

let values = dict_values
.iter()
.map(|v| v.map(|value| op(value)))
.collect::<PrimitiveArray<T>>();

let keys = array.keys();

let mut data = ArrayData::builder(array.data_type().clone())
.len(keys.len())
.add_buffer(keys.data().buffers()[0].clone())
.add_child_data(values.data().clone());

match keys.data().null_buffer() {
Some(buffer) if keys.data().null_count() > 0 => {
data = data
.null_bit_buffer(Some(buffer.clone()))
.null_count(keys.data().null_count());
}
_ => data = data.null_count(0),
}

let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
Ok(Arc::new(new_dict))
}

/// Applies an unary function to an array with primitive values.
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
match array.data_type() {
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
DataType::Int8 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap(),
op,
),
DataType::Int16 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<Int16Type>>()
.unwrap(),
op,
),
DataType::Int32 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<Int32Type>>()
.unwrap(),
op,
),
DataType::Int64 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<Int64Type>>()
.unwrap(),
op,
),
DataType::UInt8 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<UInt8Type>>()
.unwrap(),
op,
),
DataType::UInt16 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<UInt16Type>>()
.unwrap(),
op,
),
DataType::UInt32 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<UInt32Type>>()
.unwrap(),
op,
),
DataType::UInt64 => unary_dict::<_, F, T>(
array
.as_any()
.downcast_ref::<DictionaryArray<UInt64Type>>()
.unwrap(),
op,
),
t => Err(ArrowError::NotYetImplemented(format!(
"Cannot perform unary operation on dictionary array of key type {}.",
t
))),
},
_ => Ok(Arc::new(unary::<T, F, T>(
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
op,
))),
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::array::{as_primitive_array, Float64Array};
use crate::array::{
as_primitive_array, Float64Array, PrimitiveBuilder, PrimitiveDictionaryBuilder,
};
use crate::datatypes::{Float64Type, Int32Type, Int8Type};

#[test]
fn test_unary_f64_slice() {
Expand All @@ -93,6 +216,56 @@ mod tests {
assert_eq!(
result,
Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
)
);

let result = unary_dyn::<_, Float64Type>(input_slice, |n| n + 1.0).unwrap();

assert_eq!(
result.as_any().downcast_ref::<Float64Array>().unwrap(),
&Float64Array::from(vec![None, Some(7.8), None, Some(8.2)])
);
}

#[test]
fn test_unary_dict_and_unary_dyn() {
let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
builder.append(5).unwrap();
builder.append(6).unwrap();
builder.append(7).unwrap();
builder.append(8).unwrap();
builder.append_null().unwrap();
builder.append(9).unwrap();
let dictionary_array = builder.finish();

let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
builder.append(6).unwrap();
builder.append(7).unwrap();
builder.append(8).unwrap();
builder.append(9).unwrap();
builder.append_null().unwrap();
builder.append(10).unwrap();
let expected = builder.finish();

let result = unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
assert_eq!(
result
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap(),
&expected
);

let result = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
assert_eq!(
result
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap(),
&expected
);
}
}

0 comments on commit 62053a8

Please sign in to comment.