Skip to content

Commit

Permalink
array_has_all done
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Jul 20, 2023
1 parent 7cfac12 commit b76ac83
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 41 deletions.
103 changes: 77 additions & 26 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,29 @@ AS VALUES
statement ok
CREATE TABLE array_has_table_1D
AS VALUES
(make_array(1, 2), 1, make_array(true, true, true), false),
(make_array(3, 4, 5), 2, make_array(false, false, false), false)
(make_array(1, 2), 1, make_array(true, true, true), false, make_array(1,2,3), make_array(1,3)),
(make_array(3, 4, 5), 2, make_array(false, false, false), false, make_array(1,2,3,4), make_array(2,5))
;

statement ok
CREATE TABLE array_has_table_1D_Float
AS VALUES
(make_array(1.0, 2.0), 1.0, make_array(1.0,2.0,3.0), make_array(1.0,3.0)),
(make_array(3.0, 4.0, 5.0), 2.0, make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0))
;

statement ok
CREATE TABLE array_has_table_2D
AS VALUES
(make_array([1,2]), make_array(1,3)),
(make_array([3,4], [5]), make_array(5))
(make_array([1,2]), make_array(1,3), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])),
(make_array([3,4], [5]), make_array(5), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10]))
;

statement ok
CREATE TABLE array_has_table_2D_float
AS VALUES
(make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])),
(make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3]))
;

statement ok
Expand All @@ -112,7 +126,6 @@ AS VALUES
(make_array([[1], [2]], [[2], [3]]), make_array([1], [2]))
;


statement ok
CREATE TABLE arrays_values_without_nulls
AS VALUES
Expand Down Expand Up @@ -1038,10 +1051,12 @@ NULL 1 1

## array_has/array_has_all/array_has_any

query B
select array_has(make_array(1,2), 1);
query BB
select array_has(make_array(1,2), 1),
array_has(make_array(1,2,NULL), 1)
;
----
true
true true

query B
select array_has(make_array([2,3], [3,4]), make_array(2,3));
Expand Down Expand Up @@ -1070,6 +1085,13 @@ from array_has_table_1D;
true false
false true

query B
select array_has(column1, column2)
from array_has_table_1D_Float;
----
true
false

query B
select array_has(column1, column2) from array_has_table_2D;
----
Expand Down Expand Up @@ -1112,26 +1134,46 @@ false true false false
false false false false
false false false false

# array_has_all scalar function #1
# query BBB
# select array_has_all(make_array(1, 2, 3), make_array(1, 1, 2, 3)), array_has_all([1, 2, 3], [1, 1, 2]), array_has_all([1, 2, 3], [2, 1, 3, 1]);
query BBBBBBB
select array_has_all(make_array(1,2,3), make_array(1,3)),
array_has_all(make_array(1,2,3), make_array(1,4)),
array_has_all(make_array([1,2], [3,4]), make_array([1,2])),
array_has_all(make_array([1,2], [3,4]), make_array([1,3])),
array_has_all(make_array([1,2], [3,4]), make_array([1,2], [3,4], [5,6])),
array_has_all(make_array([[1,2,3]]), make_array([[1]])),
array_has_all(make_array([[1,2,3]]), make_array([[1,2,3]]))
;
----
true false true false false false true

# array_has_all scalar function #2
# query BB
# select array_has_all([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 3]]), array_has_all([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 4]);
query B
select array_has_all(column5, column6)
from array_has_table_1D;
----
true
false

query B
select array_has_all(column3, column4)
from array_has_table_1D_Float;
----
true
false

query B
select array_has_all(column3, column4)
from array_has_table_2D;
----
true
false

query B
select array_has_all(column1, column2)
from array_has_table_2D_float;
----
true
false

# array_has_all scalar function #3
# query BBB
# select array_has_all(make_array(1, 2, 3), make_array(1, 2, 3, 4)), array_has_all([1, 2, 3], [1, 1, 4]), array_has_all([1, 2, 3], [2, 1, 3, 4]);
# ----
# false false false
#
# # array_has_all scalar function #4
# query BB
# select array_has_all([[1, 2], [3, 4]], [[1, 2], [3, 4], [1, 5]]), array_has_all([[[1], [2]], [[3], [4]]], [1, 2, 2, 3, 5]);
# ----
# false false
#
# # array_has_all scalar function #5
# query BB
# select array_has_all([true, true, false, true, false], [true, false, false]), array_has_all([true, false, true], [true, true]);
Expand Down Expand Up @@ -1319,8 +1361,17 @@ drop table arrays_values_v2;
statement ok
drop table array_has_table_1D;

statement ok
drop table array_has_table_1D_Float;

statement ok
drop table array_has_table_2D;

statement ok
drop table array_has_table_2D_float;

statement ok
drop table array_has_table_3D;

statement ok
drop table arrays_values_without_nulls;
12 changes: 6 additions & 6 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,18 +531,18 @@ scalar_expr!(
"appends an element to the end of an array."
);
nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays.");
scalar_expr!(
ArrayHasAll,
array_has_all,
first_array second_array,
"Returns true if each element of the second array appears in the first array; otherwise, it returns false."
);
scalar_expr!(
ArrayHas,
array_has,
first_array second_array,
"Returns true, if the element appears in the first array, otherwise false."
);
scalar_expr!(
ArrayHasAll,
array_has_all,
first_array second_array,
"Returns true if each element of the second array appears in the first array; otherwise, it returns false."
);
scalar_expr!(
ArrayHasAny,
array_has_any,
Expand Down
58 changes: 49 additions & 9 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,7 @@ macro_rules! non_list_contains {
for (arr, elem) in $ARRAY.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(elem)) = (arr, elem) {
let arr = downcast_arg!(arr, $ARRAY_TYPE);
let res = arr.iter().flatten().any(|x| x == elem);
let res = arr.iter().dedup().flatten().any(|x| x == elem);
boolean_builder.append_value(res);
}
}
Expand All @@ -1539,7 +1539,7 @@ pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
for (arr, elem) in array.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(elem)) = (arr, elem) {
let list_arr = arr.as_list::<i32>();
let res = list_arr.iter().flatten().any(|x| *x == *elem);
let res = list_arr.iter().dedup().flatten().any(|x| *x == *elem);
boolean_builder.append_value(res);
}
}
Expand Down Expand Up @@ -1567,21 +1567,61 @@ pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// Evaluates to `true` when every element of `$SUB_ARRAY` occurs in `$ARRAY`.
///
/// Both expressions are downcast to `$ARRAY_TYPE` via `downcast_arg!`, so the
/// caller must already have checked that both hold that primitive type.
///
/// * An empty `$SUB_ARRAY` yields `true`.
/// * Null entries of `$ARRAY` are ignored when searching for a match.
/// * A null entry in `$SUB_ARRAY` never matches, so the result is `false`
///   (previously this case could panic through `expect`).
macro_rules! array_has_all_non_list_check {
    ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{
        let arr = downcast_arg!($ARRAY, $ARRAY_TYPE);
        let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE);

        // `all` stops at the first missing element; no de-duplication is
        // needed because `any` gives the same answer with or without repeats.
        sub_arr.iter().all(|elem| match elem {
            Some(elem) => arr.iter().flatten().any(|x| x == elem),
            None => false,
        })
    }};
}

/// Array_has_all SQL function
///
/// Expects exactly two arguments and, for each row, reports whether every
/// element of the second list is found in the first list.
///
/// NOTE(review): this span comes from a rendered diff and interleaves two
/// implementations: one built on `flatten_list_array` / `boolean_array` /
/// `contains_internal`, and one built on `as_list` / `boolean_builder` and the
/// `match` on element data types. Presumably the former lines are the removed
/// ones; confirm against the actual file before relying on this text.
pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
// Panics (rather than returning an error) when the argument count is not 2.
assert_eq!(args.len(), 2);
let array = flatten_list_array::<i32>(args[0].clone())?;
// TODO: Don't need to flatten rhs array
let sub_array = flatten_list_array::<i32>(args[1].clone())?;
let mut boolean_array = Vec::with_capacity(array.len());

// Both arguments are viewed as list arrays with i32 offsets.
let array = args[0].as_list::<i32>();
let sub_array = args[1].as_list::<i32>();

let mut boolean_builder = BooleanArray::builder(array.len());
// Rows are processed pairwise. A row where either side is null appends
// nothing, so the result can have fewer entries than the input.
// NOTE(review): confirm that dropping such rows is intended.
for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let res = contains_internal(arr.clone(), sub_arr.clone())?;
boolean_array.push(res);
// Dispatch on the element types of the two row values.
match (arr.data_type(), sub_arr.data_type()) {
// Nested lists: each non-null inner list of `sub_arr` must equal some
// non-null inner list of `arr`.
(DataType::List(_), DataType::List(_)) => {
let arr = downcast_arg!(arr, ListArray);
let sub_arr = downcast_arg!(sub_arr, ListArray);

// `res` is AND-ed over all elements; there is no early exit.
// NOTE(review): `dedup()` is presumably itertools' adjacent-duplicate
// removal, which would not change what `any` returns; confirm.
let mut res = true;
for elem in sub_arr.iter().dedup().flatten() {
res &= arr.iter().dedup().flatten().any(|x| *x == *elem);
}
boolean_builder.append_value(res);
}
// Primitive element types are delegated to the shared macro.
(DataType::Int64, DataType::Int64) => {
let res = array_has_all_non_list_check!(arr, sub_arr, Int64Array);
boolean_builder.append_value(res);
}
(DataType::Float64, DataType::Float64) => {
let res = array_has_all_non_list_check!(arr, sub_arr, Float64Array);
boolean_builder.append_value(res);
}
// Any other type combination is reported as not implemented.
_ => Err(DataFusionError::NotImplemented(format!(
"Array_has_all is not implemented for types '{:?}' and '{:?}'.",
arr.data_type(),
sub_arr.data_type()
)))?,
}
}
}
Ok(Arc::new(BooleanArray::from(boolean_array)))
Ok(Arc::new(boolean_builder.finish()))
}

#[cfg(test)]
Expand Down

0 comments on commit b76ac83

Please sign in to comment.