Skip to content

Commit

Permalink
support LargeList in array_prepend and array_append (#8679)
Browse files Browse the repository at this point in the history
* support largelist

* fix cast error

* fix cast

* add tests

* fix conflict

* s TODO comment for future tests

add TODO comment for future tests

---------

Co-authored-by: hwj <[email protected]>
  • Loading branch information
Weijun-H and hwj authored Jan 2, 2024
1 parent c1fe3dd commit 67baf10
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 91 deletions.
23 changes: 19 additions & 4 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,11 @@ pub fn arrays_into_list_array(
/// assert_eq!(base_type(&data_type), DataType::Int32);
/// ```
pub fn base_type(data_type: &DataType) -> DataType {
if let DataType::List(field) = data_type {
base_type(field.data_type())
} else {
data_type.to_owned()
match data_type {
DataType::List(field) | DataType::LargeList(field) => {
base_type(field.data_type())
}
_ => data_type.to_owned(),
}
}

Expand Down Expand Up @@ -462,6 +463,20 @@ pub fn coerced_type_with_base_type_only(
field.is_nullable(),
)))
}
DataType::LargeList(field) => {
let data_type = match field.data_type() {
DataType::LargeList(_) => {
coerced_type_with_base_type_only(field.data_type(), base_type)
}
_ => base_type.to_owned(),
};

DataType::LargeList(Arc::new(Field::new(
field.name(),
data_type,
field.is_nullable(),
)))
}

_ => base_type.clone(),
}
Expand Down
24 changes: 12 additions & 12 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,18 @@ fn get_valid_types(
&new_base_type,
);

if let DataType::List(ref field) = array_type {
let elem_type = field.data_type();
if is_append {
Ok(vec![vec![array_type.clone(), elem_type.to_owned()]])
} else {
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
match array_type {
DataType::List(ref field) | DataType::LargeList(ref field) => {
let elem_type = field.data_type();
if is_append {
Ok(vec![vec![array_type.clone(), elem_type.to_owned()]])
} else {
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
}
}
} else {
Ok(vec![vec![]])
_ => Ok(vec![vec![]]),
}
}

let valid_types = match signature {
TypeSignature::Variadic(valid_types) => valid_types
.iter()
Expand Down Expand Up @@ -311,9 +311,9 @@ fn coerced_from<'a>(
Utf8 | LargeUtf8 => Some(type_into.clone()),
Null if can_cast_types(type_from, type_into) => Some(type_into.clone()),

// Only accept list with the same number of dimensions unless the type is Null.
// List with different dimensions should be handled in TypeSignature or other places before this.
List(_)
// Only accept list and largelist with the same number of dimensions unless the type is Null.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this.
List(_) | LargeList(_)
if datafusion_common::utils::base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Expand Down
144 changes: 71 additions & 73 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,6 @@ macro_rules! downcast_arg {
}};
}

/// Downcasts multiple arguments into a single concrete type
/// $ARGS: &[ArrayRef]
/// $ARRAY_TYPE: type to downcast to
///
/// $returns a Vec<$ARRAY_TYPE>
macro_rules! downcast_vec {
($ARGS:expr, $ARRAY_TYPE:ident) => {{
$ARGS
.iter()
.map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() {
Some(array) => Ok(array),
_ => internal_err!("failed to downcast"),
})
}};
}

/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array.
///
/// # Arguments
Expand Down Expand Up @@ -832,17 +816,20 @@ pub fn array_pop_back(args: &[ArrayRef]) -> Result<ArrayRef> {
///
/// # Examples
///
/// general_append_and_prepend(
/// generic_append_and_prepend(
/// [1, 2, 3], 4, append => [1, 2, 3, 4]
/// 5, [6, 7, 8], prepend => [5, 6, 7, 8]
/// )
fn general_append_and_prepend(
list_array: &ListArray,
fn generic_append_and_prepend<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
element_array: &ArrayRef,
data_type: &DataType,
is_append: bool,
) -> Result<ArrayRef> {
let mut offsets = vec![0];
) -> Result<ArrayRef>
where
i64: TryInto<O>,
{
let mut offsets = vec![O::usize_as(0)];
let values = list_array.values();
let original_data = values.to_data();
let element_data = element_array.to_data();
Expand All @@ -858,21 +845,21 @@ fn general_append_and_prepend(
let element_index = 1;

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
let start = offset_window[0] as usize;
let end = offset_window[1] as usize;
let start = offset_window[0].to_usize().unwrap();
let end = offset_window[1].to_usize().unwrap();
if is_append {
mutable.extend(values_index, start, end);
mutable.extend(element_index, row_index, row_index + 1);
} else {
mutable.extend(element_index, row_index, row_index + 1);
mutable.extend(values_index, start, end);
}
offsets.push(offsets[row_index] + (end - start + 1) as i32);
offsets.push(offsets[row_index] + O::usize_as(end - start + 1));
}

let data = mutable.freeze();

Ok(Arc::new(ListArray::try_new(
Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new("item", data_type.to_owned(), true)),
OffsetBuffer::new(offsets.into()),
arrow_array::make_array(data),
Expand Down Expand Up @@ -938,36 +925,6 @@ pub fn gen_range(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(arr)
}

/// Array_append SQL function
pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_append expects two arguments");
}

let list_array = as_list_array(&args[0])?;
let element_array = &args[1];

let res = match list_array.value_type() {
DataType::List(_) => concat_internal(args)?,
DataType::Null => {
return make_array(&[
list_array.values().to_owned(),
element_array.to_owned(),
]);
}
data_type => {
return general_append_and_prepend(
list_array,
element_array,
&data_type,
true,
);
}
};

Ok(res)
}

/// Array_sort SQL function
pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.is_empty() || args.len() > 3 {
Expand Down Expand Up @@ -1051,32 +1008,71 @@ fn order_nulls_first(modifier: &str) -> Result<bool> {
}
}

/// Array_prepend SQL function
pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_prepend expects two arguments");
}

let list_array = as_list_array(&args[1])?;
let element_array = &args[0];
fn general_append_and_prepend<O: OffsetSizeTrait>(
args: &[ArrayRef],
is_append: bool,
) -> Result<ArrayRef>
where
i64: TryInto<O>,
{
let (list_array, element_array) = if is_append {
let list_array = as_generic_list_array::<O>(&args[0])?;
let element_array = &args[1];
check_datatypes("array_append", &[element_array, list_array.values()])?;
(list_array, element_array)
} else {
let list_array = as_generic_list_array::<O>(&args[1])?;
let element_array = &args[0];
check_datatypes("array_prepend", &[list_array.values(), element_array])?;
(list_array, element_array)
};

check_datatypes("array_prepend", &[element_array, list_array.values()])?;
let res = match list_array.value_type() {
DataType::List(_) => concat_internal(args)?,
DataType::Null => return make_array(&[element_array.to_owned()]),
DataType::List(_) => concat_internal::<i32>(args)?,
DataType::LargeList(_) => concat_internal::<i64>(args)?,
DataType::Null => {
return make_array(&[
list_array.values().to_owned(),
element_array.to_owned(),
]);
}
data_type => {
return general_append_and_prepend(
return generic_append_and_prepend::<O>(
list_array,
element_array,
&data_type,
false,
is_append,
);
}
};

Ok(res)
}

/// Array_append SQL function
pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_append expects two arguments");
}

match args[0].data_type() {
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, true),
_ => general_append_and_prepend::<i32>(args, true),
}
}

/// Array_prepend SQL function
pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_prepend expects two arguments");
}

match args[1].data_type() {
DataType::LargeList(_) => general_append_and_prepend::<i64>(args, false),
_ => general_append_and_prepend::<i32>(args, false),
}
}

fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
let args_ndim = args
.iter()
Expand Down Expand Up @@ -1114,11 +1110,13 @@ fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
}

// Concatenate arrays on the same row.
fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
fn concat_internal<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let args = align_array_dimensions(args.to_vec())?;

let list_arrays =
downcast_vec!(args, ListArray).collect::<Result<Vec<&ListArray>>>()?;
let list_arrays = args
.iter()
.map(|arg| as_generic_list_array::<O>(arg))
.collect::<Result<Vec<_>>>()?;

// Assume number of rows is the same for all arrays
let row_count = list_arrays[0].len();
Expand Down Expand Up @@ -1165,7 +1163,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

let list_arr = ListArray::new(
let list_arr = GenericListArray::<O>::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(compute::concat(elements.as_slice())?),
Expand All @@ -1192,7 +1190,7 @@ pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

concat_internal(new_args.as_slice())
concat_internal::<i32>(new_args.as_slice())
}

/// Array_empty SQL function
Expand Down
Loading

0 comments on commit 67baf10

Please sign in to comment.