diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 54452e3653a8..4a02e1c68140 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1085,100 +1085,149 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { Ok(res) } -macro_rules! general_remove { - ($ARRAY:expr, $ELEMENT:expr, $MAX:expr, $ARRAY_TYPE:ident) => {{ - let mut offsets: Vec = vec![0]; - let mut values = - downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences +/// of `element_array[i]`. +/// +/// The type of each **element** in `list_array` must be the same as the type of +/// `element_array`. This function also handles nested arrays +/// ([`ListArray`] of [`ListArray`]s) +/// +/// For example, when called to remove a list array (where each element is a +/// list of int32s, the second argument are int32 arrays, and the +/// third argument is the number of occurrences to remove +/// +/// ```text +/// general_remove( +/// [1, 2, 3, 2], 2, 1 ==> [1, 3, 2] (only the first 2 is removed) +/// [4, 5, 6, 5], 5, 2 ==> [4, 6] (both 5s are removed) +/// ) +/// ``` +fn general_remove( + list_array: &GenericListArray, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + let data_type = list_array.value_type(); + let mut new_values = vec![]; + // Build up the offsets for the final output array + let mut offsets = Vec::::with_capacity(arr_n.len() + 1); + offsets.push(OffsetSize::zero()); - let element = downcast_arg!($ELEMENT, $ARRAY_TYPE); - for ((arr, el), max) in $ARRAY.iter().zip(element.iter()).zip($MAX.iter()) { - let last_offset: i32 = offsets.last().copied().ok_or_else(|| { - DataFusionError::Internal(format!("offsets should not be empty")) - })?; - match arr { - Some(arr) => { - let child_array = downcast_arg!(arr, $ARRAY_TYPE); - let mut counter = 0; - let max = if max < Some(1) { 1 } else { max.unwrap() }; + // n is the number of elements to remove in this row + for (row_index, (list_array_row, n)) in + list_array.iter().zip(arr_n.iter()).enumerate() + { + match list_array_row { + Some(list_array_row) => { + let indices = UInt32Array::from(vec![row_index as u32]); + let element_array_row = + arrow::compute::take(element_array, &indices, None)?; - let filter_array = child_array + let eq_array = match element_array_row.data_type() { + // arrow_ord::cmp::distinct does not support ListArray, so we need to compare it by loop + DataType::List(_) => { + // compare each element of the from array + let element_array_row_inner = + as_list_array(&element_array_row)?.value(0); + let list_array_row_inner = as_list_array(&list_array_row)?; + + list_array_row_inner + .iter() + // compare element by element the current row of list_array + .map(|row| row.map(|row| row.ne(&element_array_row_inner))) + .collect::() + } + _ => { + let from_arr = Scalar::new(element_array_row); + // use distinct so Null = Null is false + arrow_ord::cmp::distinct(&list_array_row, &from_arr)? + } + }; + + // We need to keep at most first n elements as `false`, which represent the elements to remove. + let eq_array = if eq_array.false_count() < *n as usize { + eq_array + } else { + let mut count = 0; + eq_array .iter() - .map(|element| { - if counter != max && element == el { - counter += 1; - Some(false) + .map(|e| { + // Keep first n `false` elements, and reverse other elements to `true`. + if let Some(false) = e { + if count < *n { + count += 1; + e + } else { + Some(true) + } } else { - Some(true) + e } }) - .collect::(); + .collect::() + }; - let filtered_array = compute::filter(&child_array, &filter_array)?; - values = downcast_arg!( - compute::concat(&[&values, &filtered_array,])?.clone(), - $ARRAY_TYPE - ) - .clone(); - offsets.push(last_offset + filtered_array.len() as i32); - } - None => offsets.push(last_offset), + let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?; + offsets.push( + offsets[row_index] + OffsetSize::usize_as(filtered_array.len()), + ); + new_values.push(filtered_array); + } + None => { + // Null element results in a null row (no new offsets) + offsets.push(offsets[row_index]); } } + } - let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + let values = if new_values.is_empty() { + new_empty_array(&data_type) + } else { + let new_values = new_values.iter().map(|x| x.as_ref()).collect::>(); + arrow::compute::concat(&new_values)? + }; - Arc::new(ListArray::try_new( - field, - OffsetBuffer::new(offsets.into()), - Arc::new(values), - None, - )?) - }}; + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::new(offsets.into()), + values, + list_array.nulls().cloned(), + )?)) } -macro_rules! array_removement_function { - ($FUNC:ident, $MAX_FUNC:expr, $DOC:expr) => { - #[doc = $DOC] - pub fn $FUNC(args: &[ArrayRef]) -> Result { - let arr = as_list_array(&args[0])?; - let element = &args[1]; - let max = $MAX_FUNC(args)?; - - check_datatypes(stringify!($FUNC), &[arr.values(), element])?; - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - general_remove!(arr, element, max, $ARRAY_TYPE) - }; - } - let res = call_array_function!(arr.value_type(), true); - - Ok(res) +fn array_remove_internal( + array: &ArrayRef, + element_array: &ArrayRef, + arr_n: Vec, +) -> Result { + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_remove::(list_array, element_array, arr_n) } - }; + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_remove::(list_array, element_array, arr_n) + } + _ => internal_err!("array_remove_all expects a list array"), + } } -fn remove_one(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(1, args[0].len())) +pub fn array_remove_all(args: &[ArrayRef]) -> Result { + let arr_n = vec![i64::MAX; args[0].len()]; + array_remove_internal(&args[0], &args[1], arr_n) } -fn remove_n(args: &[ArrayRef]) -> Result { - as_int64_array(&args[2]).cloned() +pub fn array_remove(args: &[ArrayRef]) -> Result { + let arr_n = vec![1; args[0].len()]; + array_remove_internal(&args[0], &args[1], arr_n) } -fn remove_all(args: &[ArrayRef]) -> Result { - Ok(Int64Array::from_value(i64::MAX, args[0].len())) +pub fn array_remove_n(args: &[ArrayRef]) -> Result { + let arr_n = as_int64_array(&args[2])?.values().to_vec(); + array_remove_internal(&args[0], &args[1], arr_n) } -// array removement functions -array_removement_function!(array_remove, remove_one, "Array_remove SQL function"); -array_removement_function!(array_remove_n, remove_n, "Array_remove_n SQL function"); -array_removement_function!( - array_remove_all, - remove_all, - "Array_remove_all SQL function" -); - /// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences /// of `from_array[i]`, `to_array[i]`. /// @@ -2608,173 +2657,6 @@ mod tests { ); } - #[test] - fn test_array_remove() { - // array_remove([3, 1, 2, 3, 2, 3], 3) = [1, 2, 3, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_remove(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_remove"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove() { - // array_remove( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // ) = [[5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove(&[list_array, element_array]) - .expect("failed to initialize function array_remove"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(1), Some(2), Some(3), Some(4)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_remove_n() { - // array_remove_n([3, 1, 2, 3, 2, 3], 3, 2) = [1, 2, 2, 3] - let list_array = return_array_with_repeating_elements(); - let array = array_remove_n(&[ - list_array, - Arc::new(Int64Array::from_value(3, 1)), - Arc::new(Int64Array::from_value(2, 1)), - ]) - .expect("failed to initialize function array_remove_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove_n"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove_n() { - // array_remove_n( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // 3, - // ) = [[5, 6, 7, 8], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove_n(&[ - list_array, - element_array, - Arc::new(Int64Array::from_value(3, 1)), - ]) - .expect("failed to initialize function array_remove_n"); - let result = - as_list_array(&array).expect("failed to initialize function array_remove_n"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - - #[test] - fn test_array_remove_all() { - // array_remove_all([3, 1, 2, 3, 2, 3], 3) = [1, 2, 2] - let list_array = return_array_with_repeating_elements(); - let array = - array_remove_all(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) - .expect("failed to initialize function array_remove_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_remove_all"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 2], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_array_remove_all() { - // array_remove_all( - // [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]], - // [1, 2, 3, 4], - // ) = [[5, 6, 7, 8], [9, 10, 11, 12], [5, 6, 7, 8]] - let list_array = return_nested_array_with_repeating_elements(); - let element_array = return_array(); - let array = array_remove_all(&[list_array, element_array]) - .expect("failed to initialize function array_remove_all"); - let result = as_list_array(&array) - .expect("failed to initialize function array_remove_all"); - - assert_eq!(result.len(), 1); - let data = vec![ - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - Some(vec![Some(9), Some(10), Some(11), Some(12)]), - Some(vec![Some(5), Some(6), Some(7), Some(8)]), - ]; - let expected = ListArray::from_iter_primitive::(data); - assert_eq!( - expected, - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .clone() - ); - } - #[test] fn test_array_to_string() { // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 @@ -3059,63 +2941,4 @@ mod tests { make_array(&[arr1, arr2]).expect("failed to initialize function array") } - - fn return_array_with_repeating_elements() -> ArrayRef { - // Returns: [3, 1, 2, 3, 2, 3] - let args = [ - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - ]; - make_array(&args).expect("failed to initialize function array") - } - - fn return_nested_array_with_repeating_elements() -> ArrayRef { - // Returns: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [9, 10, 11, 12], [5, 6, 7, 8]] - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - let arr1 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(8)])) as ArrayRef, - ]; - let arr2 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(1)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(2)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(3)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(4)])) as ArrayRef, - ]; - let arr3 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(9)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(10)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(11)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(12)])) as ArrayRef, - ]; - let arr4 = make_array(&args).expect("failed to initialize function array"); - - let args = [ - Arc::new(Int64Array::from(vec![Some(5)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(6)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(7)])) as ArrayRef, - Arc::new(Int64Array::from(vec![Some(8)])) as ArrayRef, - ]; - let arr5 = make_array(&args).expect("failed to initialize function array"); - - make_array(&[arr1, arr2, arr3, arr4, arr5]) - .expect("failed to initialize function array") - } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ad81f37e0764..9207f0f0e359 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2039,6 +2039,20 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, ---- [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] +query ??? +select + array_remove(make_array(1, null, 2, 3), 2), + array_remove(make_array(1.1, null, 2.2, 3.3), 1.1), + array_remove(make_array('a', null, 'bc'), 'a'); +---- +[1, , 3] [, 2.2, 3.3] [, bc] + +# TODO: https://github.com/apache/arrow-datafusion/issues/7142 +# query +# select +# array_remove(make_array(1, null, 2), null), +# array_remove(make_array(1, null, 2, null), null); + # array_remove scalar function #2 (element is list) query ?? select array_remove(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_remove(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]);