diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 6069ee0e9360..9e117b6d053f 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -86,6 +86,60 @@ AS VALUES (NULL, NULL, NULL, NULL) ; +statement ok +CREATE TABLE array_has_table_1D +AS VALUES + (make_array(1, 2), 1, make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3,5)), + (make_array(3, 4, 5), 2, make_array(1,2,3,4), make_array(2,5), make_array(2,4,6), make_array(1,3,5)) +; + +statement ok +CREATE TABLE array_has_table_1D_Float +AS VALUES + (make_array(1.0, 2.0), 1.0, make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)), + (make_array(3.0, 4.0, 5.0), 2.0, make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) +; + +statement ok +CREATE TABLE array_has_table_1D_Boolean +AS VALUES + (make_array(true, true, true), false, make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)), + (make_array(false, false, false), false, make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) +; + +statement ok +CREATE TABLE array_has_table_1D_UTF8 +AS VALUES + (make_array('a', 'bc', 'def'), 'bc', make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')), + (make_array('a', 'bc', 'def'), 'defg', make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) +; + +statement ok +CREATE TABLE array_has_table_2D +AS VALUES + (make_array([1,2]), make_array(1,3), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])), + (make_array([3,4], [5]), make_array(5), make_array([1,2,3,4], 
[5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) +; + +statement ok +CREATE TABLE array_has_table_2D_float +AS VALUES + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])), + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) +; + +statement ok +CREATE TABLE array_has_table_3D +AS VALUES + (make_array([[1,2]]), make_array([1])), + (make_array([[1,2]]), make_array([1,2])), + (make_array([[1,2]]), make_array([1,2,3])), + (make_array([[1], [2]]), make_array([2])), + (make_array([[1], [2]]), make_array([1], [2])), + (make_array([[1], [2]], [[2], [3]]), make_array([1], [2], [3])), + (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) +; + statement ok CREATE TABLE arrays_values_without_nulls AS VALUES @@ -1164,48 +1218,129 @@ NULL 1 1 2 NULL 1 2 1 NULL -## array_contains +## array_has/array_has_all/array_has_any + +query BBBBBBBBBBBB +select array_has(make_array(1,2), 1), + array_has(make_array(1,2,NULL), 1), + array_has(make_array([2,3], [3,4]), make_array(2,3)), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([1], [2,3])), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([4,5], [6])), + array_has(make_array([[1], [2,3]], [[4,5], [6]]), make_array([1])), + array_has(make_array([[[1]]]), make_array([[1]])), + array_has(make_array([[[1]]], [[[1], [2]]]), make_array([[2]])), + array_has(make_array([[[1]]], [[[1], [2]]]), make_array([[1], [2]])), + list_has(make_array(1,2,3), 4), + array_contains(make_array(1,2,3), 3), + list_contains(make_array(1,2,3), 0) +; +---- +true true true true true false true false true false true false -# array_contains scalar function #1 query BBB -select array_contains(make_array(1, 2, 3), make_array(1, 1, 2, 3)), array_contains([1, 2, 3], [1, 1, 2]), array_contains([1, 2, 3], [2, 1, 3, 1]); +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D; ---- true true 
true +false false false -# array_contains scalar function #2 -query BB -select array_contains([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 3]]), array_contains([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 4]); +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_Float; ---- -true true +true true false +false false true -# array_contains scalar function #3 query BBB -select array_contains(make_array(1, 2, 3), make_array(1, 2, 3, 4)), array_contains([1, 2, 3], [1, 1, 4]), array_contains([1, 2, 3], [2, 1, 3, 4]); +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_Boolean; ---- -false false false +false true true +true true true -# array_contains scalar function #4 -query BB -select array_contains([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 5]]), array_contains([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 5]); +query BBB +select array_has(column1, column2), + array_has_all(column3, column4), + array_has_any(column5, column6) +from array_has_table_1D_UTF8; ---- -false false +true true false +false false true -# array_contains scalar function #5 query BB -select array_contains([true, true, false, true, false], [true, false, false]), array_contains([true, false, true], [true, true]); +select array_has(column1, column2), + array_has_all(column3, column4) +from array_has_table_2D; +---- +false true +true false + +query B +select array_has_all(column1, column2) +from array_has_table_2D_float; +---- +true +false + +query B +select array_has(column1, column2) from array_has_table_3D; +---- +false +true +false +false +true +false +true + +query BBBB +select array_has(column1, make_array(5, 6)), + array_has(column1, make_array(7, NULL)), + array_has(column2, 5.5), + array_has(column3, 'o') +from arrays; +---- +false false false true +true false true false +true false false true +false true false false +false false false false +false 
false false false + +query BBBBBBBBBBBBB +select array_has_all(make_array(1,2,3), make_array(1,3)), + array_has_all(make_array(1,2,3), make_array(1,4)), + array_has_all(make_array([1,2], [3,4]), make_array([1,2])), + array_has_all(make_array([1,2], [3,4]), make_array([1,3])), + array_has_all(make_array([1,2], [3,4]), make_array([1,2], [3,4], [5,6])), + array_has_all(make_array([[1,2,3]]), make_array([[1]])), + array_has_all(make_array([[1,2,3]]), make_array([[1,2,3]])), + array_has_any(make_array(1,2,3), make_array(1,10,100)), + array_has_any(make_array(1,2,3), make_array(10,100)), + array_has_any(make_array([1,2], [3,4]), make_array([1,10], [10,4])), + array_has_any(make_array([1,2], [3,4]), make_array([10,20], [3,4])), + array_has_any(make_array([[1,2,3]]), make_array([[1,2,3], [4,5,6]])), + array_has_any(make_array([[1,2,3]]), make_array([[1,2,3]], [[4,5,6]])) +; ---- -true true +true false true false false false true true false false true false true -# array_contains scalar function #6 -query BB -select array_contains(make_array(true, true, true), make_array(false, false)), array_contains([false, false, false], [true, true]); +query BBBB +select list_has_all(make_array(1,2,3), make_array(4,5,6)), + list_has_all(make_array(1,2,3), make_array(1,2)), + list_has_any(make_array(1,2,3), make_array(4,5,6)), + list_has_any(make_array(1,2,3), make_array(1,2,4)) +; ---- -false false - +false true false true ### Array operators tests - - ## array concatenate operator # array concatenate operator with scalars #1 (like array_concat scalar function) @@ -1296,7 +1431,6 @@ select make_array(f0) from fixed_size_list_array [[1, 2], [3, 4]] - ### Delete tables @@ -1312,5 +1446,29 @@ drop table arrays; statement ok drop table arrays_values; +statement ok +drop table arrays_values_v2; + +statement ok +drop table array_has_table_1D; + +statement ok +drop table array_has_table_1D_Float; + +statement ok +drop table array_has_table_1D_Boolean; + +statement ok +drop table 
array_has_table_1D_UTF8; + +statement ok +drop table array_has_table_2D; + +statement ok +drop table array_has_table_2D_float; + +statement ok +drop table array_has_table_3D; + statement ok drop table arrays_values_without_nulls; diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index dded73c088b4..6914de686aa3 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -119,8 +119,12 @@ pub enum BuiltinScalarFunction { ArrayAppend, /// array_concat ArrayConcat, - /// array_contains - ArrayContains, + /// array_has + ArrayHas, + /// array_has_all + ArrayHasAll, + /// array_has_any + ArrayHasAny, /// array_dims ArrayDims, /// array_fill @@ -330,7 +334,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, - BuiltinScalarFunction::ArrayContains => Volatility::Immutable, + BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, + BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, + BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, BuiltinScalarFunction::ArrayFill => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, @@ -501,7 +507,9 @@ impl BuiltinScalarFunction { Ok(expr_type) } - BuiltinScalarFunction::ArrayContains => Ok(Boolean), + BuiltinScalarFunction::ArrayHasAll + | BuiltinScalarFunction::ArrayHasAny + | BuiltinScalarFunction::ArrayHas => Ok(Boolean), BuiltinScalarFunction::ArrayDims => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } @@ -808,7 +816,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::ArrayContains => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayHasAll + | 
BuiltinScalarFunction::ArrayHasAny + | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayLength => { @@ -1278,8 +1288,12 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::ArrayConcat => { &["array_concat", "array_cat", "list_concat", "list_cat"] } - BuiltinScalarFunction::ArrayContains => &["array_contains"], BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], + BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], + BuiltinScalarFunction::ArrayHas => { + &["array_has", "list_has", "array_contains", "list_contains"] + } BuiltinScalarFunction::ArrayFill => &["array_fill"], BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 30d9580c42ee..418aa8d8f8a9 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -536,10 +536,22 @@ scalar_expr!( ); nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays."); scalar_expr!( - ArrayContains, - array_contains, + ArrayHas, + array_has, first_array second_array, - "returns true, if each element of the second array appearing in the first array, otherwise false." +"Returns true, if the element appears in the first array, otherwise false." +); +scalar_expr!( + ArrayHasAll, + array_has_all, + first_array second_array, +"Returns true if each element of the second array appears in the first array; otherwise, it returns false." 
+); +scalar_expr!( + ArrayHasAny, + array_has_any, + first_array second_array, +"Returns true if at least one element of the second array appears in the first array; otherwise, it returns false." ); scalar_expr!( ArrayDims, diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 47558ecc262c..239b2a6a48ac 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1375,69 +1375,275 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -macro_rules! contains { - ($FIRST_ARRAY:expr, $SECOND_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let first_array = downcast_arg!($FIRST_ARRAY, $ARRAY_TYPE); - let second_array = downcast_arg!($SECOND_ARRAY, $ARRAY_TYPE); - let mut res = true; - for x in second_array.values().iter().dedup() { - if !first_array.values().contains(x) { - res = false; - break; +macro_rules! non_list_contains { + ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ + let sub_array = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); + let mut boolean_builder = BooleanArray::builder($ARRAY.len()); + + for (arr, elem) in $ARRAY.iter().zip(sub_array.iter()) { + if let (Some(arr), Some(elem)) = (arr, elem) { + let arr = downcast_arg!(arr, $ARRAY_TYPE); + let res = arr.iter().dedup().flatten().any(|x| x == elem); + boolean_builder.append_value(res); } } + Ok(Arc::new(boolean_builder.finish())) + }}; +} + +/// Array_has SQL function +pub fn array_has(args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 2); + + let array = args[0].as_list::(); + + match args[1].data_type() { + DataType::List(_) => { + let sub_array = args[1].as_list::(); + let mut boolean_builder = BooleanArray::builder(array.len()); + + for (arr, elem) in array.iter().zip(sub_array.iter()) { + if let (Some(arr), Some(elem)) = (arr, elem) { + let list_arr = arr.as_list::(); + let res = list_arr.iter().dedup().flatten().any(|x| *x == *elem); + 
boolean_builder.append_value(res); + } + } + Ok(Arc::new(boolean_builder.finish())) + } + + // Int64, Int32, Int16, Int8 + // UInt64, UInt32, UInt16, UInt8 + DataType::Int64 => { + non_list_contains!(array, args[1], Int64Array) + } + DataType::Int32 => { + non_list_contains!(array, args[1], Int32Array) + } + DataType::Int16 => { + non_list_contains!(array, args[1], Int16Array) + } + DataType::Int8 => { + non_list_contains!(array, args[1], Int8Array) + } + DataType::UInt64 => { + non_list_contains!(array, args[1], UInt64Array) + } + DataType::UInt32 => { + non_list_contains!(array, args[1], UInt32Array) + } + DataType::UInt16 => { + non_list_contains!(array, args[1], UInt16Array) + } + DataType::UInt8 => { + non_list_contains!(array, args[1], UInt8Array) + } + DataType::Float64 => { + non_list_contains!(array, args[1], Float64Array) + } + DataType::Float32 => { + non_list_contains!(array, args[1], Float32Array) + } + DataType::Utf8 => { + non_list_contains!(array, args[1], StringArray) + } + DataType::LargeUtf8 => { + non_list_contains!(array, args[1], LargeStringArray) + } + DataType::Boolean => { + non_list_contains!(array, args[1], BooleanArray) + } + data_type => Err(DataFusionError::NotImplemented(format!( + "Array_has is not implemented for '{data_type:?}'" + ))), + } +} + +macro_rules! 
array_has_any_non_list_check { + ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ + let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); + + let mut res = false; + for elem in sub_arr.iter().dedup() { + if let Some(elem) = elem { + res |= arr.iter().dedup().flatten().any(|x| x == elem); + } else { + return Err(DataFusionError::Internal(format!( + "array_has_any does not support Null type for element in sub_array" + ))); + } + } res }}; } -/// Array_contains SQL function -pub fn array_contains(args: &[ArrayRef]) -> Result { - fn concat_inner_lists(arg: ArrayRef) -> Result { - match arg.data_type() { - DataType::List(field) => match field.data_type() { - DataType::List(..) => { - concat_inner_lists(array_concat(&[as_list_array(&arg)? - .values() - .clone()])?) +/// Array_has_any SQL function +pub fn array_has_any(args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 2); + + let array = args[0].as_list::(); + let sub_array = args[1].as_list::(); + + let mut boolean_builder = BooleanArray::builder(array.len()); + for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { + let res = match (arr.data_type(), sub_arr.data_type()) { + (DataType::List(_), DataType::List(_)) => { + let arr = downcast_arg!(arr, ListArray); + let sub_arr = downcast_arg!(sub_arr, ListArray); + + let mut res = false; + for elem in sub_arr.iter().dedup().flatten() { + res |= arr.iter().dedup().flatten().any(|x| *x == *elem); + } + res } - _ => Ok(as_list_array(&arg)?.values().clone()), - }, - data_type => Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." 
- ))), + // Int64, Int32, Int16, Int8 + // UInt64, UInt32, UInt16, UInt8 + (DataType::Int64, DataType::Int64) => { + array_has_any_non_list_check!(arr, sub_arr, Int64Array) + } + (DataType::Int32, DataType::Int32) => { + array_has_any_non_list_check!(arr, sub_arr, Int32Array) + } + (DataType::Int16, DataType::Int16) => { + array_has_any_non_list_check!(arr, sub_arr, Int16Array) + } + (DataType::Int8, DataType::Int8) => { + array_has_any_non_list_check!(arr, sub_arr, Int8Array) + } + (DataType::UInt64, DataType::UInt64) => { + array_has_any_non_list_check!(arr, sub_arr, UInt64Array) + } + (DataType::UInt32, DataType::UInt32) => { + array_has_any_non_list_check!(arr, sub_arr, UInt32Array) + } + (DataType::UInt16, DataType::UInt16) => { + array_has_any_non_list_check!(arr, sub_arr, UInt16Array) + } + (DataType::UInt8, DataType::UInt8) => { + array_has_any_non_list_check!(arr, sub_arr, UInt8Array) + } + + (DataType::Float64, DataType::Float64) => { + array_has_any_non_list_check!(arr, sub_arr, Float64Array) + } + (DataType::Float32, DataType::Float32) => { + array_has_any_non_list_check!(arr, sub_arr, Float32Array) + } + (DataType::Boolean, DataType::Boolean) => { + array_has_any_non_list_check!(arr, sub_arr, BooleanArray) + } + // Utf8, LargeUtf8 + (DataType::Utf8, DataType::Utf8) => { + array_has_any_non_list_check!(arr, sub_arr, StringArray) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + array_has_any_non_list_check!(arr, sub_arr, LargeStringArray) + } + + (arr_type, sub_arr_type) => Err(DataFusionError::NotImplemented(format!( + "Array_has_any is not implemented for '{arr_type:?}' and '{sub_arr_type:?}'", + )))?, + }; + boolean_builder.append_value(res); } } + Ok(Arc::new(boolean_builder.finish())) +} - let concat_first_array = concat_inner_lists(args[0].clone())?.clone(); - let concat_second_array = concat_inner_lists(args[1].clone())?.clone(); +macro_rules! 
array_has_all_non_list_check { + ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ + let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - let res = match (concat_first_array.data_type(), concat_second_array.data_type()) { - (DataType::Utf8, DataType::Utf8) => contains!(concat_first_array, concat_second_array, StringArray), - (DataType::LargeUtf8, DataType::LargeUtf8) => contains!(concat_first_array, concat_second_array, LargeStringArray), - (DataType::Boolean, DataType::Boolean) => { - let first_array = downcast_arg!(concat_first_array, BooleanArray); - let second_array = downcast_arg!(concat_second_array, BooleanArray); - compute::bool_or(first_array) == compute::bool_or(second_array) - } - (DataType::Float32, DataType::Float32) => contains!(concat_first_array, concat_second_array, Float32Array), - (DataType::Float64, DataType::Float64) => contains!(concat_first_array, concat_second_array, Float64Array), - (DataType::Int8, DataType::Int8) => contains!(concat_first_array, concat_second_array, Int8Array), - (DataType::Int16, DataType::Int16) => contains!(concat_first_array, concat_second_array, Int16Array), - (DataType::Int32, DataType::Int32) => contains!(concat_first_array, concat_second_array, Int32Array), - (DataType::Int64, DataType::Int64) => contains!(concat_first_array, concat_second_array, Int64Array), - (DataType::UInt8, DataType::UInt8) => contains!(concat_first_array, concat_second_array, UInt8Array), - (DataType::UInt16, DataType::UInt16) => contains!(concat_first_array, concat_second_array, UInt16Array), - (DataType::UInt32, DataType::UInt32) => contains!(concat_first_array, concat_second_array, UInt32Array), - (DataType::UInt64, DataType::UInt64) => contains!(concat_first_array, concat_second_array, UInt64Array), - (first_array_data_type, second_array_data_type) => { - return Err(DataFusionError::NotImplemented(format!( - "Array_contains is not implemented for types '{first_array_data_type:?}' 
and '{second_array_data_type:?}'." - ))) + let mut res = true; + for elem in sub_arr.iter().dedup() { + if let Some(elem) = elem { + res &= arr.iter().dedup().flatten().any(|x| x == elem); + } else { + return Err(DataFusionError::Internal(format!( + "array_has_all does not support Null type for element in sub_array" + ))); + } } - }; + res + }}; +} + +/// Array_has_all SQL function +pub fn array_has_all(args: &[ArrayRef]) -> Result { + assert_eq!(args.len(), 2); + + let array = args[0].as_list::(); + let sub_array = args[1].as_list::(); + + let mut boolean_builder = BooleanArray::builder(array.len()); + for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { + let res = match (arr.data_type(), sub_arr.data_type()) { + (DataType::List(_), DataType::List(_)) => { + let arr = downcast_arg!(arr, ListArray); + let sub_arr = downcast_arg!(sub_arr, ListArray); + + let mut res = true; + for elem in sub_arr.iter().dedup().flatten() { + res &= arr.iter().dedup().flatten().any(|x| *x == *elem); + } + res + } + // Int64, Int32, Int16, Int8 + // UInt64, UInt32, UInt16, UInt8 + (DataType::Int64, DataType::Int64) => { + array_has_all_non_list_check!(arr, sub_arr, Int64Array) + } + (DataType::Int32, DataType::Int32) => { + array_has_all_non_list_check!(arr, sub_arr, Int32Array) + } + (DataType::Int16, DataType::Int16) => { + array_has_all_non_list_check!(arr, sub_arr, Int16Array) + } + (DataType::Int8, DataType::Int8) => { + array_has_all_non_list_check!(arr, sub_arr, Int8Array) + } + (DataType::UInt64, DataType::UInt64) => { + array_has_all_non_list_check!(arr, sub_arr, UInt64Array) + } + (DataType::UInt32, DataType::UInt32) => { + array_has_all_non_list_check!(arr, sub_arr, UInt32Array) + } + (DataType::UInt16, DataType::UInt16) => { + array_has_all_non_list_check!(arr, sub_arr, UInt16Array) + } + (DataType::UInt8, DataType::UInt8) => { + array_has_all_non_list_check!(arr, sub_arr, UInt8Array) + } - 
Ok(Arc::new(BooleanArray::from(vec![res]))) + (DataType::Float64, DataType::Float64) => { + array_has_all_non_list_check!(arr, sub_arr, Float64Array) + } + (DataType::Float32, DataType::Float32) => { + array_has_all_non_list_check!(arr, sub_arr, Float32Array) + } + (DataType::Boolean, DataType::Boolean) => { + array_has_all_non_list_check!(arr, sub_arr, BooleanArray) + } + (DataType::Utf8, DataType::Utf8) => { + array_has_all_non_list_check!(arr, sub_arr, StringArray) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + array_has_all_non_list_check!(arr, sub_arr, LargeStringArray) + } + (arr_type, sub_arr_type) => Err(DataFusionError::NotImplemented(format!( + "Array_has_all is not implemented for '{arr_type:?}' and '{sub_arr_type:?}'", + )))?, + }; + boolean_builder.append_value(res); + } + } + Ok(Arc::new(boolean_builder.finish())) } #[cfg(test)] @@ -2070,63 +2276,6 @@ mod tests { assert_eq!(result, &UInt64Array::from_value(2, 1)); } - #[test] - fn test_array_contains() { - // array_contains([1, 2, 3, 4], array_append([1, 2, 3, 4], 3)) = t - let first_array = return_array().into_array(1); - let second_array = array_append(&[ - first_array.clone(), - Arc::new(Int64Array::from(vec![Some(3)])), - ]) - .expect("failed to initialize function array_contains"); - - let arr = array_contains(&[first_array.clone(), second_array]) - .expect("failed to initialize function array_contains"); - let result = as_boolean_array(&arr); - - assert_eq!(result, &BooleanArray::from(vec![true])); - - // array_contains([1, 2, 3, 4], array_append([1, 2, 3, 4], 5)) = f - let second_array = array_append(&[ - first_array.clone(), - Arc::new(Int64Array::from(vec![Some(5)])), - ]) - .expect("failed to initialize function array_contains"); - - let arr = array_contains(&[first_array.clone(), second_array]) - .expect("failed to initialize function array_contains"); - let result = as_boolean_array(&arr); - - assert_eq!(result, &BooleanArray::from(vec![false])); - } - - #[test] - fn 
test_nested_array_contains() { - // array_contains([[1, 2, 3, 4], [5, 6, 7, 8]], array_append([1, 2, 3, 4], 3)) = t - let first_array = return_nested_array().into_array(1); - let array = return_array().into_array(1); - let second_array = - array_append(&[array.clone(), Arc::new(Int64Array::from(vec![Some(3)]))]) - .expect("failed to initialize function array_contains"); - - let arr = array_contains(&[first_array.clone(), second_array]) - .expect("failed to initialize function array_contains"); - let result = as_boolean_array(&arr); - - assert_eq!(result, &BooleanArray::from(vec![true])); - - // array_contains([[1, 2, 3, 4], [5, 6, 7, 8]], array_append([1, 2, 3, 4], 9)) = f - let second_array = - array_append(&[array.clone(), Arc::new(Int64Array::from(vec![Some(9)]))]) - .expect("failed to initialize function array_contains"); - - let arr = array_contains(&[first_array.clone(), second_array]) - .expect("failed to initialize function array_contains"); - let result = as_boolean_array(&arr); - - assert_eq!(result, &BooleanArray::from(vec![false])); - } - fn return_array() -> ColumnarValue { let args = [ ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 14279d70064b..f48823bff514 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -416,8 +416,14 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayConcat => { Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) } - BuiltinScalarFunction::ArrayContains => { - Arc::new(|args| make_scalar_function(array_expressions::array_contains)(args)) + BuiltinScalarFunction::ArrayHasAll => { + Arc::new(|args| make_scalar_function(array_expressions::array_has_all)(args)) + } + BuiltinScalarFunction::ArrayHasAny => { + Arc::new(|args| make_scalar_function(array_expressions::array_has_any)(args)) + } + BuiltinScalarFunction::ArrayHas => { + 
Arc::new(|args| make_scalar_function(array_expressions::array_has)(args)) } BuiltinScalarFunction::ArrayDims => { Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index a1caa4c62185..8192a403d33e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -563,10 +563,12 @@ enum ScalarFunction { ArrayToString = 97; Cardinality = 98; TrimArray = 99; - ArrayContains = 100; Encode = 101; Decode = 102; Cot = 103; + ArrayHas = 104; + ArrayHasAny = 105; + ArrayHasAll = 106; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4155a052cff2..05bfbd089dfe 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -18064,10 +18064,12 @@ impl serde::Serialize for ScalarFunction { Self::ArrayToString => "ArrayToString", Self::Cardinality => "Cardinality", Self::TrimArray => "TrimArray", - Self::ArrayContains => "ArrayContains", Self::Encode => "Encode", Self::Decode => "Decode", Self::Cot => "Cot", + Self::ArrayHas => "ArrayHas", + Self::ArrayHasAny => "ArrayHasAny", + Self::ArrayHasAll => "ArrayHasAll", }; serializer.serialize_str(variant) } @@ -18179,10 +18181,12 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayToString", "Cardinality", "TrimArray", - "ArrayContains", "Encode", "Decode", "Cot", + "ArrayHas", + "ArrayHasAny", + "ArrayHasAll", ]; struct GeneratedVisitor; @@ -18325,10 +18329,12 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayToString" => Ok(ScalarFunction::ArrayToString), "Cardinality" => Ok(ScalarFunction::Cardinality), "TrimArray" => Ok(ScalarFunction::TrimArray), - "ArrayContains" => Ok(ScalarFunction::ArrayContains), "Encode" => Ok(ScalarFunction::Encode), "Decode" => Ok(ScalarFunction::Decode), "Cot" => Ok(ScalarFunction::Cot), + "ArrayHas" 
=> Ok(ScalarFunction::ArrayHas), + "ArrayHasAny" => Ok(ScalarFunction::ArrayHasAny), + "ArrayHasAll" => Ok(ScalarFunction::ArrayHasAll), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index af0703460a68..f50754494d1d 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2262,10 +2262,12 @@ pub enum ScalarFunction { ArrayToString = 97, Cardinality = 98, TrimArray = 99, - ArrayContains = 100, Encode = 101, Decode = 102, Cot = 103, + ArrayHas = 104, + ArrayHasAny = 105, + ArrayHasAll = 106, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2374,10 +2376,12 @@ impl ScalarFunction { ScalarFunction::ArrayToString => "ArrayToString", ScalarFunction::Cardinality => "Cardinality", ScalarFunction::TrimArray => "TrimArray", - ScalarFunction::ArrayContains => "ArrayContains", ScalarFunction::Encode => "Encode", ScalarFunction::Decode => "Decode", ScalarFunction::Cot => "Cot", + ScalarFunction::ArrayHas => "ArrayHas", + ScalarFunction::ArrayHasAny => "ArrayHasAny", + ScalarFunction::ArrayHasAll => "ArrayHasAll", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -2483,10 +2487,12 @@ impl ScalarFunction { "ArrayToString" => Some(Self::ArrayToString), "Cardinality" => Some(Self::Cardinality), "TrimArray" => Some(Self::TrimArray), - "ArrayContains" => Some(Self::ArrayContains), "Encode" => Some(Self::Encode), "Decode" => Some(Self::Decode), "Cot" => Some(Self::Cot), + "ArrayHas" => Some(Self::ArrayHas), + "ArrayHasAny" => Some(Self::ArrayHasAny), + "ArrayHasAll" => Some(Self::ArrayHasAll), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index a3718090ed08..674588692d98 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -36,12 +36,12 @@ use datafusion_common::{ }; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_contains, array_dims, - array_fill, array_length, array_ndims, array_position, array_positions, - array_prepend, array_remove, array_replace, array_to_string, ascii, asin, asinh, - atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, - chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, - current_time, date_bin, date_part, date_trunc, degrees, digest, exp, + abs, acos, acosh, array, array_append, array_concat, array_dims, array_fill, + array_has, array_has_all, array_has_any, array_length, array_ndims, array_position, + array_positions, array_prepend, array_remove, array_replace, array_to_string, ascii, + asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, + character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, + current_date, current_time, date_bin, date_part, date_trunc, degrees, digest, exp, expr::{self, InList, Sort, WindowFunction}, factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -451,7 +451,9 @@ impl 
From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ToTimestamp => Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, ScalarFunction::ArrayConcat => Self::ArrayConcat, - ScalarFunction::ArrayContains => Self::ArrayContains, + ScalarFunction::ArrayHasAll => Self::ArrayHasAll, + ScalarFunction::ArrayHasAny => Self::ArrayHasAny, + ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, ScalarFunction::ArrayFill => Self::ArrayFill, ScalarFunction::ArrayLength => Self::ArrayLength, @@ -1219,7 +1221,15 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), - ScalarFunction::ArrayContains => Ok(array_contains( + ScalarFunction::ArrayHasAll => Ok(array_has_all( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::ArrayHasAny => Ok(array_has_any( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), + ScalarFunction::ArrayHas => Ok(array_has( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e899eee6fefe..072bc84d5452 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1382,7 +1382,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, - BuiltinScalarFunction::ArrayContains => Self::ArrayContains, + BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll, + BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, + BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, BuiltinScalarFunction::ArrayFill => Self::ArrayFill, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, diff 
--git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index cf1e2f58af07..14cf8dc2ac89 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -179,24 +179,26 @@ Unlike to some databases the math functions in Datafusion works the same way as ## Array Expressions -| Function | Notes | -| ----------------------------------------- | -------------------------------------------------------------------------------------------------------------- | -| array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | -| array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | -| array_contains(first_array, second_array) | Returns true, if each element of the second array appearing in the first array, otherwise false. | -| array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | -| array_fill(element, array) | Returns an array filled with copies of the given value. | -| array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | -| array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | -| array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | -| array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | -| array_prepend(array, element) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | -| array_remove(array, element) | Removes all elements equal to the given value from the array. | -| array_replace(array, from, to) | Replaces a specified element with another specified element. 
| -| array_to_string(array, delimeter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | -| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | -| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | -| trim_array(array, n) | Removes the last n elements from the array. | +| Function | Notes | +| ------------------------------------ | -------------------------------------------------------------------------------------------------------------- | +| array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | +| array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | +| array_has(array, element) | Returns true if the array contains the element. `array_has([1,2,3], 1) -> true` | +| array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array. `array_has_all([1,2,3], [1,3]) -> true` | +| array_has_any(array, sub-array) | Returns true if any elements exist in both arrays. `array_has_any([1,2,3], [1,4]) -> true` | +| array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | +| array_fill(element, array) | Returns an array filled with copies of the given value. | +| array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | +| array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | +| array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | +| array_positions(array, element) | Searches for an element in the array, returns all occurrences. 
`array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | +| array_prepend(element, array) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | +| array_remove(array, element) | Removes all elements equal to the given value from the array. | +| array_replace(array, from, to) | Replaces a specified element with another specified element. | +| array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | +| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | +| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | +| trim_array(array, n) | Removes the last n elements from the array. | ## Regular Expressions diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 31b075821b10..bcdd3832523e 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1522,19 +1522,49 @@ array_concat(array[, ..., array_n]) - list_cat - list_concat -### `array_contains` +### `array_has` -Returns true, if each element of the second array appears in the first array, otherwise false. +Returns true if the array contains the element. ``` -array_contains(first_array, second_array) +array_has(array, element) ``` #### Arguments -- **first_array**: Array expression. +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **element**: Scalar or Array expression. + Can be a constant, column, or function, and any combination of array operators. + +### `array_has_all` + +Returns true if all elements of sub-array exist in array. + +``` +array_has_all(array, sub-array) +``` + +#### Arguments + +- **array**: Array expression. 
+ Can be a constant, column, or function, and any combination of array operators. +- **sub-array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. + +### `array_has_any` + +Returns true if any elements exist in both arrays. + +``` +array_has_any(array, sub-array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **second_array**: Array expression. +- **sub-array**: Array expression. Can be a constant, column, or function, and any combination of array operators. ### `array_dims`