Skip to content

Commit

Permalink
feat: support LargeList in flatten (#9110)
Browse files Browse the repository at this point in the history
* support FixedSizeList in flatten

* Refactor flatten function and add test cases

* remove redundant tests
  • Loading branch information
Weijun-H authored Feb 5, 2024
1 parent e61c71d commit 195f825
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 25 deletions.
9 changes: 4 additions & 5 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Flatten => {
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
DataType::List(field) => match field.data_type() {
DataType::List(_) => get_base_type(field.data_type()),
_ => Ok(data_type.to_owned()),
},
_ => internal_err!("Not reachable, data_type should be List"),
DataType::List(field) if matches!(field.data_type(), DataType::List(_)) => get_base_type(field.data_type()),
DataType::LargeList(field) if matches!(field.data_type(), DataType::LargeList(_)) => get_base_type(field.data_type()),
DataType::Null | DataType::List(_) | DataType::LargeList(_) => Ok(data_type.to_owned()),
_ => internal_err!("Not reachable, data_type should be List or LargeList"),
}
}

Expand Down
52 changes: 36 additions & 16 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2246,38 +2246,41 @@ fn generic_list_cardinality<O: OffsetSizeTrait>(
}

// Create new offsets that are euqiavlent to `flatten` the array.
fn get_offsets_for_flatten(
offsets: OffsetBuffer<i32>,
indexes: OffsetBuffer<i32>,
) -> OffsetBuffer<i32> {
fn get_offsets_for_flatten<O: OffsetSizeTrait>(
offsets: OffsetBuffer<O>,
indexes: OffsetBuffer<O>,
) -> OffsetBuffer<O> {
let buffer = offsets.into_inner();
let offsets: Vec<i32> = indexes.iter().map(|i| buffer[*i as usize]).collect();
let offsets: Vec<O> = indexes
.iter()
.map(|i| buffer[i.to_usize().unwrap()])
.collect();
OffsetBuffer::new(offsets.into())
}

fn flatten_internal(
array: &dyn Array,
indexes: Option<OffsetBuffer<i32>>,
) -> Result<ListArray> {
let list_arr = as_list_array(array)?;
fn flatten_internal<O: OffsetSizeTrait>(
list_arr: GenericListArray<O>,
indexes: Option<OffsetBuffer<O>>,
) -> Result<GenericListArray<O>> {
let (field, offsets, values, _) = list_arr.clone().into_parts();
let data_type = field.data_type();

match data_type {
// Recursively get the base offsets for flattened array
DataType::List(_) => {
DataType::List(_) | DataType::LargeList(_) => {
let sub_list = as_generic_list_array::<O>(&values)?;
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
flatten_internal(&values, Some(offsets))
flatten_internal::<O>(sub_list.clone(), Some(offsets))
} else {
flatten_internal(&values, Some(offsets))
flatten_internal::<O>(sub_list.clone(), Some(offsets))
}
}
// Reach the base level, create a new list array
_ => {
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
let list_arr = ListArray::new(field, offsets, values, None);
let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
Ok(list_arr)
} else {
Ok(list_arr.clone())
Expand All @@ -2292,8 +2295,25 @@ pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
return exec_err!("flatten expects one argument");
}

let flattened_array = flatten_internal(&args[0], None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
let array_type = args[0].data_type();
match array_type {
DataType::List(_) => {
let list_arr = as_list_array(&args[0])?;
let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
DataType::LargeList(_) => {
let list_arr = as_large_list_array(&args[0])?;
let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
DataType::Null => Ok(args[0].clone()),
_ => {
exec_err!("flatten does not support type '{array_type:?}'")
}
}

// Ok(Arc::new(flattened_array) as ArrayRef)
}

/// Dispatch array length computation based on the offset type.
Expand Down
42 changes: 38 additions & 4 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,17 @@ AS VALUES
(make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))
;

statement ok
CREATE TABLE large_flatten_table
AS
SELECT
arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1,
arrow_cast(column2, 'LargeList(LargeList(LargeList(Int64)))') AS column2,
arrow_cast(column3, 'LargeList(LargeList(LargeList(LargeList(Int64))))') AS column3,
arrow_cast(column4, 'LargeList(LargeList(Float64))') AS column4
FROM flatten_table
;

statement ok
CREATE TABLE array_has_table_1D
AS VALUES
Expand Down Expand Up @@ -5345,19 +5356,28 @@ select array_concat(column1, [7]) from arrays_values_v2;
[7]

# flatten
# follow DuckDB
query ?
select flatten(NULL);
----
NULL

# flatten with scalar values #1
query ???
select flatten(make_array(1, 2, 1, 3, 2)),
flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]));
----
[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]

query ????
select column1, column2, column3, column4 from flatten_table;
query ???
select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')),
flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')),
flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))'));
----
[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]]
[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]

# flatten with column values
query ????
select flatten(column1),
flatten(column2),
Expand All @@ -5368,6 +5388,17 @@ from flatten_table;
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

query ????
select flatten(column1),
flatten(column2),
flatten(column3),
flatten(column4)
from large_flatten_table;
----
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

## empty
# empty scalar function #1
query B
select empty(make_array(1));
Expand Down Expand Up @@ -5746,6 +5777,9 @@ drop table fixed_size_nested_arrays_with_repeating_elements;
statement ok
drop table flatten_table;

statement ok
drop table large_flatten_table;

statement ok
drop table arrays_values_without_nulls;

Expand Down

0 comments on commit 195f825

Please sign in to comment.