Skip to content

Commit

Permalink
feat: Support array_sort(list_sort) (#8279)
Browse files Browse the repository at this point in the history
* Minor: Improve the document format of JoinHashMap

* list sort

* fix: example doc

* fix: ci

* fix: doc error

* fix pb

* like DuckDB function semantics

* fix ci

* fix pb

* fix: doc

* add table test

* fix: not as expected

* fix: return null

* resolve conflicts

* doc

* merge
  • Loading branch information
Asura7969 authored Dec 6, 2023
1 parent fa8a0d9 commit d9d8ddd
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 12 deletions.
8 changes: 8 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ pub enum BuiltinScalarFunction {
// array functions
/// array_append
ArrayAppend,
/// array_sort
ArraySort,
/// array_concat
ArrayConcat,
/// array_has
Expand Down Expand Up @@ -398,6 +400,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Tanh => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
BuiltinScalarFunction::ArrayAppend => Volatility::Immutable,
BuiltinScalarFunction::ArraySort => Volatility::Immutable,
BuiltinScalarFunction::ArrayConcat => Volatility::Immutable,
BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable,
BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable,
Expand Down Expand Up @@ -545,6 +548,7 @@ impl BuiltinScalarFunction {
Ok(data_type)
}
BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayConcat => {
let mut expr_type = Null;
let mut max_dims = 0;
Expand Down Expand Up @@ -909,6 +913,9 @@ impl BuiltinScalarFunction {
// for now, the list is small, as we do not have many built-in functions.
match self {
BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArraySort => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayConcat => {
Expand Down Expand Up @@ -1558,6 +1565,7 @@ impl BuiltinScalarFunction {
"array_push_back",
"list_push_back",
],
BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"],
BuiltinScalarFunction::ArrayConcat => {
&["array_concat", "array_cat", "list_concat", "list_cat"]
}
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ scalar_expr!(
"appends an element to the end of an array."
);

scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array.");

scalar_expr!(
ArrayPopBack,
array_pop_back,
Expand Down Expand Up @@ -1184,6 +1186,7 @@ mod test {
test_scalar_expr!(FromUnixtime, from_unixtime, unixtime);

test_scalar_expr!(ArrayAppend, array_append, array, element);
test_scalar_expr!(ArraySort, array_sort, array, desc, null_first);
test_scalar_expr!(ArrayPopFront, array_pop_front, array);
test_scalar_expr!(ArrayPopBack, array_pop_back, array);
test_unary_scalar_expr!(ArrayDims, array_dims);
Expand Down
83 changes: 81 additions & 2 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type};
use arrow::row::{RowConverter, SortField};
use arrow_buffer::NullBuffer;

use arrow_schema::FieldRef;
use arrow_schema::{FieldRef, SortOptions};
use datafusion_common::cast::{
as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
as_null_array, as_string_array,
Expand Down Expand Up @@ -693,7 +693,7 @@ fn general_append_and_prepend(
/// # Arguments
///
/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values.
///
///
/// # Examples
///
/// gen_range(3) => [0, 1, 2]
Expand Down Expand Up @@ -777,6 +777,85 @@ pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(res)
}

/// Array_sort SQL function
pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
let sort_option = match args.len() {
1 => None,
2 => {
let sort = as_string_array(&args[1])?.value(0);
Some(SortOptions {
descending: order_desc(sort)?,
nulls_first: true,
})
}
3 => {
let sort = as_string_array(&args[1])?.value(0);
let nulls_first = as_string_array(&args[2])?.value(0);
Some(SortOptions {
descending: order_desc(sort)?,
nulls_first: order_nulls_first(nulls_first)?,
})
}
_ => return internal_err!("array_sort expects 1 to 3 arguments"),
};

let list_array = as_list_array(&args[0])?;
let row_count = list_array.len();

let mut array_lengths = vec![];
let mut arrays = vec![];
let mut valid = BooleanBufferBuilder::new(row_count);
for i in 0..row_count {
if list_array.is_null(i) {
array_lengths.push(0);
valid.append(false);
} else {
let arr_ref = list_array.value(i);
let arr_ref = arr_ref.as_ref();

let sorted_array = compute::sort(arr_ref, sort_option)?;
array_lengths.push(sorted_array.len());
arrays.push(sorted_array);
valid.append(true);
}
}

// Assume all arrays have the same data type
let data_type = list_array.value_type();
let buffer = valid.finish();

let elements = arrays
.iter()
.map(|a| a.as_ref())
.collect::<Vec<&dyn Array>>();

let list_arr = ListArray::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(array_lengths),
Arc::new(compute::concat(elements.as_slice())?),
Some(NullBuffer::new(buffer)),
);
Ok(Arc::new(list_arr))
}

fn order_desc(modifier: &str) -> Result<bool> {
match modifier.to_uppercase().as_str() {
"DESC" => Ok(true),
"ASC" => Ok(false),
_ => internal_err!("the second parameter of array_sort expects DESC or ASC"),
}
}

fn order_nulls_first(modifier: &str) -> Result<bool> {
match modifier.to_uppercase().as_str() {
"NULLS FIRST" => Ok(true),
"NULLS LAST" => Ok(false),
_ => internal_err!(
"the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
),
}
}

/// Array_prepend SQL function
pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[1])?;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayAppend => {
Arc::new(|args| make_scalar_function(array_expressions::array_append)(args))
}
BuiltinScalarFunction::ArraySort => {
Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args))
}
BuiltinScalarFunction::ArrayConcat => {
Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ enum ScalarFunction {
Levenshtein = 125;
SubstrIndex = 126;
FindInSet = 127;
ArraySort = 128;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ use datafusion_expr::{
array_except, array_has, array_has_all, array_has_any, array_intersect, array_length,
array_ndims, array_position, array_positions, array_prepend, array_remove,
array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all,
array_replace_n, array_slice, array_to_string, arrow_typeof, ascii, asin, asinh,
atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length,
chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date,
current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp,
array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin,
asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil,
character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot,
current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest,
encode, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero,
lcm, left, levenshtein, ln, log, log10, log2,
Expand Down Expand Up @@ -463,6 +464,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Rtrim => Self::Rtrim,
ScalarFunction::ToTimestamp => Self::ToTimestamp,
ScalarFunction::ArrayAppend => Self::ArrayAppend,
ScalarFunction::ArraySort => Self::ArraySort,
ScalarFunction::ArrayConcat => Self::ArrayConcat,
ScalarFunction::ArrayEmpty => Self::ArrayEmpty,
ScalarFunction::ArrayExcept => Self::ArrayExcept,
Expand Down Expand Up @@ -1343,6 +1345,11 @@ pub fn parse_expr(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
ScalarFunction::ArraySort => Ok(array_sort(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
parse_expr(&args[2], registry)?,
)),
ScalarFunction::ArrayPopFront => {
Ok(array_pop_front(parse_expr(&args[0], registry)?))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Rtrim => Self::Rtrim,
BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp,
BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend,
BuiltinScalarFunction::ArraySort => Self::ArraySort,
BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat,
BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty,
BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept,
Expand Down
50 changes: 44 additions & 6 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,44 @@ select make_array(['a','b'], null);
----
[[a, b], ]

## array_sort (aliases: `list_sort`)
query ???
select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, 3, null, 2), 'ASC'), array_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST');
----
[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1]

query ?
select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values;
----
[10, 9, 8, 7, 6, 5, 4, 3, 2, ]
[20, 18, 17, 16, 15, 14, 13, 12, 11, ]
[30, 29, 28, 27, 26, 25, 23, 22, 21, ]
[40, 39, 38, 37, 35, 34, 33, 32, 31, ]
NULL
[50, 49, 48, 47, 46, 45, 44, 43, 42, 41]
[60, 59, 58, 57, 56, 55, 54, 52, 51, ]
[70, 69, 68, 67, 66, 65, 64, 63, 62, 61]

query ?
select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values;
----
[, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[, 11, 12, 13, 14, 15, 16, 17, 18, 20]
[, 21, 22, 23, 25, 26, 27, 28, 29, 30]
[, 31, 32, 33, 34, 35, 37, 38, 39, 40]
NULL
[41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
[, 51, 52, 54, 55, 56, 57, 58, 59, 60]
[61, 62, 63, 64, 65, 66, 67, 68, 69, 70]


## list_sort (aliases: `array_sort`)
query ???
select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3, null, 2), 'ASC'), list_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST');
----
[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1]


## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`)

# TODO: array_append with NULLs
Expand Down Expand Up @@ -1224,7 +1262,7 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma

# array_repeat scalar function #1
query ????????
select
select
array_repeat(1, 5),
array_repeat(3.14, 3),
array_repeat('l', 4),
Expand Down Expand Up @@ -1257,7 +1295,7 @@ AS VALUES
(0, 3, 3.3, 'datafusion', make_array(8, 9));

query ??????
select
select
array_repeat(column2, column1),
array_repeat(column3, column1),
array_repeat(column4, column1),
Expand All @@ -1272,7 +1310,7 @@ from array_repeat_table;
[] [] [] [] [3, 3, 3] []

statement ok
drop table array_repeat_table;
drop table array_repeat_table;

## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`)

Expand Down Expand Up @@ -2188,7 +2226,7 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0,
[1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o]

query ???
select
select
array_remove(make_array(1, null, 2, 3), 2),
array_remove(make_array(1.1, null, 2.2, 3.3), 1.1),
array_remove(make_array('a', null, 'bc'), 'a');
Expand Down Expand Up @@ -2887,7 +2925,7 @@ from array_intersect_table_3D;
query ??????
SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)),
array_intersect(make_array(1,3,5), make_array(2,4,6)),
array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
array_intersect(make_array(true, false), make_array(true)),
array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)),
array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4]))
Expand Down Expand Up @@ -2918,7 +2956,7 @@ NULL
query ??????
SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)),
list_intersect(make_array(1,3,5), make_array(2,4,6)),
list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
list_intersect(make_array(true, false), make_array(true)),
list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)),
list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4]))
Expand Down
Loading

0 comments on commit d9d8ddd

Please sign in to comment.