diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 6ab4507f949c..acd2535bf7bf 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -52,6 +52,7 @@ use arrow::{ UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; +use arrow_array::cast::as_list_array; use arrow_array::{ArrowNativeTypeOp, Scalar}; pub use struct_builder::ScalarStructBuilder; @@ -2137,67 +2138,28 @@ impl ScalarValue { /// Retrieve ScalarValue for each row in `array` /// - /// Example 1: Array (ScalarValue::Int32) + /// Example /// ``` /// use datafusion_common::ScalarValue; /// use arrow::array::ListArray; /// use arrow::datatypes::{DataType, Int32Type}; /// - /// // Equivalent to [[1,2,3], [4,5]] /// let list_arr = ListArray::from_iter_primitive::(vec![ /// Some(vec![Some(1), Some(2), Some(3)]), + /// None, /// Some(vec![Some(4), Some(5)]) /// ]); /// - /// // Convert the array into Scalar Values for each row /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); /// /// let expected = vec![ - /// vec![ + /// vec![ /// ScalarValue::Int32(Some(1)), /// ScalarValue::Int32(Some(2)), /// ScalarValue::Int32(Some(3)), - /// ], - /// vec![ - /// ScalarValue::Int32(Some(4)), - /// ScalarValue::Int32(Some(5)), - /// ], - /// ]; - /// - /// assert_eq!(scalar_vec, expected); - /// ``` - /// - /// Example 2: Nested array (ScalarValue::List) - /// ``` - /// use datafusion_common::ScalarValue; - /// use arrow::array::ListArray; - /// use arrow::datatypes::{DataType, Int32Type}; - /// use datafusion_common::utils::array_into_list_array; - /// use std::sync::Arc; - /// - /// let list_arr = ListArray::from_iter_primitive::(vec![ - /// Some(vec![Some(1), Some(2), Some(3)]), - /// Some(vec![Some(4), Some(5)]) - /// ]); - /// - /// // Wrap into another layer of list, we got nested array as [ [[1,2,3], [4,5]] ] - /// let list_arr = array_into_list_array(Arc::new(list_arr)); - /// - /// // Convert the array into Scalar Values for each row, we got 1D arrays in this example - /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); - /// - /// let l1 = ListArray::from_iter_primitive::(vec![ - /// Some(vec![Some(1), Some(2), Some(3)]), - /// ]); - /// let l2 = ListArray::from_iter_primitive::(vec![ - /// Some(vec![Some(4), Some(5)]), - /// ]); - /// - /// let expected = vec![ - /// vec![ - /// ScalarValue::List(Arc::new(l1)), - /// ScalarValue::List(Arc::new(l2)), /// ], + /// vec![], + /// vec![ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5))] /// ]; /// /// assert_eq!(scalar_vec, expected); @@ -2206,13 +2168,27 @@ impl ScalarValue { let mut scalars = Vec::with_capacity(array.len()); for index in 0..array.len() { - let nested_array = array.as_list::().value(index); - let scalar_values = (0..nested_array.len()) - .map(|i| ScalarValue::try_from_array(&nested_array, i)) - .collect::>>()?; + let scalar_values = match array.data_type() { + DataType::List(_) => { + let list_array = as_list_array(array); + match list_array.is_null(index) { + true => Vec::new(), + false => { + let nested_array = list_array.value(index); + ScalarValue::convert_array_to_scalar_vec(&nested_array)? + .into_iter() + .flatten() + .collect() + } + } + } + _ => { + let scalar = ScalarValue::try_from_array(array, index)?; + vec![scalar] + } + }; scalars.push(scalar_values); } - Ok(scalars) } diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 84b791a3de05..af6d0d5f4e24 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -44,9 +44,9 @@ async fn csv_query_array_agg_distinct() -> Result<()> { // We should have 1 row containing a list let column = actual[0].column(0); assert_eq!(column.len(), 1); + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?; let mut scalars = scalar_vec[0].clone(); - // workaround lack of Ord of ScalarValue let cmp = |a: &ScalarValue, b: &ScalarValue| { a.partial_cmp(b).expect("Can compare ScalarValues") diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 8e7b9d91ee49..4cd7a469b7a1 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -24,7 +24,6 @@ use std::sync::Arc; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; @@ -139,10 +138,9 @@ impl Accumulator for DistinctArrayAggAccumulator { assert_eq!(values.len(), 1, "batch input should only include 1 column!"); let array = &values[0]; - - for i in 0..array.len() { - let scalar = ScalarValue::try_from_array(&array, i)?; - self.values.insert(scalar); + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(array)?; + for scalars in scalar_vec { + self.values.extend(scalars); } Ok(()) @@ -153,12 +151,7 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - let array = &states[0]; - - assert_eq!(array.len(), 1, "state array should only include 1 row!"); - // Unwrap outer ListArray then do update batch - let inner_array = array.as_list::().value(0); - self.update_batch(&[inner_array]) + self.update_batch(states) } fn evaluate(&mut self) -> Result { @@ -189,54 +182,46 @@ mod tests { use arrow_array::ListArray; use arrow_buffer::OffsetBuffer; use datafusion_common::internal_err; + use datafusion_common::utils::array_into_list_array; + + // arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray. + fn sort_list_inner(arr: ScalarValue) -> ScalarValue { + let arr = match arr { + ScalarValue::List(arr) => arr.value(0), + _ => { + panic!("Expected ScalarValue::List, got {:?}", arr) + } + }; - // arrow::compute::sort can't sort nested ListArray directly, so we compare the scalar values pair-wise. - fn compare_list_contents( - expected: Vec, - actual: ScalarValue, - ) -> Result<()> { - let array = actual.to_array()?; - let list_array = array.as_list::(); - let inner_array = list_array.value(0); - let mut actual_scalars = vec![]; - for index in 0..inner_array.len() { - let sv = ScalarValue::try_from_array(&inner_array, index)?; - actual_scalars.push(sv); - } - - if actual_scalars.len() != expected.len() { - return internal_err!( - "Expected and actual list lengths differ: expected={}, actual={}", - expected.len(), - actual_scalars.len() - ); - } + let arr = arrow::compute::sort(&arr, None).unwrap(); + let list_arr = array_into_list_array(arr); + ScalarValue::List(Arc::new(list_arr)) + } - let mut seen = vec![false; expected.len()]; - for v in expected { - let mut found = false; - for (i, sv) in actual_scalars.iter().enumerate() { - if sv == &v { - seen[i] = true; - found = true; - break; + fn compare_list_contents(expected: ScalarValue, actual: ScalarValue) -> Result<()> { + let actual = sort_list_inner(actual); + + match (&expected, &actual) { + (ScalarValue::List(arr1), ScalarValue::List(arr2)) => { + if arr1.eq(arr2) { + Ok(()) + } else { + internal_err!( + "Actual value {:?} not found in expected values {:?}", + actual, + expected + ) } } - if !found { - return internal_err!( - "Expected value {:?} not found in actual values {:?}", - v, - actual_scalars - ); + _ => { + internal_err!("Expected scalar lists as inputs") } } - - Ok(()) } fn check_distinct_array_agg( input: ArrayRef, - expected: Vec, + expected: ScalarValue, datatype: DataType, ) -> Result<()> { let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); @@ -249,13 +234,14 @@ mod tests { true, )); let actual = aggregate(&batch, agg)?; + compare_list_contents(expected, actual) } fn check_merge_distinct_array_agg( input1: ArrayRef, input2: ArrayRef, - expected: Vec, + expected: ScalarValue, datatype: DataType, ) -> Result<()> { let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); @@ -276,20 +262,23 @@ mod tests { accum1.merge_batch(&[array])?; let actual = accum1.evaluate()?; + compare_list_contents(expected, actual) } #[test] fn distinct_array_agg_i32() -> Result<()> { let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ]; + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(4), + Some(5), + Some(7), + ])]), + )); check_distinct_array_agg(col, expected, DataType::Int32) } @@ -299,15 +288,18 @@ mod tests { let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4])); - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(8)), - ]; + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(7), + Some(8), + ])]), + )); check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32) } @@ -359,16 +351,23 @@ mod tests { let l2 = ScalarValue::List(Arc::new(l2)); let l3 = ScalarValue::List(Arc::new(l3)); - // Duplicate l1 and l3 in the input array and check that it is deduped in the output. - let array = ScalarValue::iter_to_array(vec![ - l1.clone(), - l2.clone(), - l3.clone(), - l3.clone(), - l1.clone(), - ]) - .unwrap(); - let expected = vec![l1, l2, l3]; + // Duplicate l1 in the input array and check that it is deduped in the output. + let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap(); + + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ])]), + )); check_distinct_array_agg( array, @@ -427,10 +426,22 @@ mod tests { let l3 = ScalarValue::List(Arc::new(l3)); // Duplicate l1 in the input array and check that it is deduped in the output. - let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2.clone()]).unwrap(); - let input2 = ScalarValue::iter_to_array(vec![l1.clone(), l3.clone()]).unwrap(); - - let expected = vec![l1, l2, l3]; + let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2]).unwrap(); + let input2 = ScalarValue::iter_to_array(vec![l1, l3]).unwrap(); + + let expected = + ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + ])]), + )); check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32) } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 71782fcc5f9b..52afd82d0326 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -26,7 +26,6 @@ use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field, TimeUnit}; -use arrow_array::cast::AsArray; use arrow_array::types::{ Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, @@ -251,10 +250,11 @@ impl Accumulator for DistinctCountAccumulator { return Ok(()); } assert_eq!(states.len(), 1, "array_agg states must be singleton!"); - let array = &states[0]; - let list_array = array.as_list::(); - let inner_array = list_array.value(0); - self.update_batch(&[inner_array]) + let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; + for scalars in scalar_vec.into_iter() { + self.values.extend(scalars); + } + Ok(()) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index b78c6287746c..9a6a1ea55f8f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -139,61 +139,6 @@ AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1)] --------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true -# test array_agg_order with list data type -statement ok -CREATE TABLE array_agg_order_list_table AS VALUES - ('w', 2, [1,2,3], 10), - ('w', 1, [9,5,2], 20), - ('w', 1, [3,2,5], 30), - ('b', 2, [4,5,6], 20), - ('b', 1, [7,8,9], 30) -; - -query T? rowsort -select column1, array_agg(column3 order by column2, column4 desc) from array_agg_order_list_table group by column1; ----- -b [[7, 8, 9], [4, 5, 6]] -w [[3, 2, 5], [9, 5, 2], [1, 2, 3]] - -query T?? rowsort -select column1, first_value(column3 order by column2, column4 desc), last_value(column3 order by column2, column4 desc) from array_agg_order_list_table group by column1; ----- -b [7, 8, 9] [4, 5, 6] -w [3, 2, 5] [1, 2, 3] - -query T? rowsort -select column1, nth_value(column3, 2 order by column2, column4 desc) from array_agg_order_list_table group by column1; ----- -b [4, 5, 6] -w [9, 5, 2] - -statement ok -drop table array_agg_order_list_table; - -# test array_agg_distinct with list data type -statement ok -CREATE TABLE array_agg_distinct_list_table AS VALUES - ('w', [0,1]), - ('w', [0,1]), - ('w', [1,0]), - ('b', [1,0]), - ('b', [1,0]), - ('b', [1,0]), - ('b', [0,1]) -; - -# Apply array_sort to have determinisitic result, higher dimension nested array also works but not for array sort, -# so they are covered in `datafusion/physical-expr/src/aggregate/array_agg_distinct.rs` -query ?? -select array_sort(c1), array_sort(c2) from ( - select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table -); ----- -[b, w] [[0, 1], [1, 0]] - -statement ok -drop table array_agg_distinct_list_table; - statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 @@ -3322,4 +3267,4 @@ SELECT CAST(a AS INT) FROM t GROUP BY t.a; 3 statement ok -DROP TABLE t; +DROP TABLE t; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index 3f139ede8c77..bd8f00e04158 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -212,6 +212,3 @@ b 0 -2 a -1 -1 NULL 0 0 c 1 2 - -statement ok -drop table traces;