diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7fa97dad7aa6..7ccf58af832d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -524,11 +524,33 @@ pub fn array_except(args: &[ArrayRef]) -> Result { /// /// See test cases in `array.slt` for more details. pub fn array_slice(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let from_array = as_int64_array(&args[1])?; - let to_array = as_int64_array(&args[2])?; + let array_data_type = args[0].data_type(); + match array_data_type { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let from_array = as_int64_array(&args[1])?; + let to_array = as_int64_array(&args[2])?; + general_array_slice::(array, from_array, to_array) + } + _ => not_impl_err!("array_slice does not support type: {:?}", array_data_type), + } +} - let values = list_array.values(); +fn general_array_slice( + array: &GenericListArray, + from_array: &Int64Array, + to_array: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -539,72 +561,98 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { // We have the slice syntax compatible with DuckDB v0.8.1. // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. - fn adjusted_from_index(index: i64, len: usize) -> Option { + fn adjusted_from_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { // 0 ~ len - 1 let adjusted_zero_index = if index < 0 { - index + len as i64 + if let Ok(index) = index.try_into() { + index + len + } else { + return exec_err!("array_slice got invalid index: {}", index); + } } else { // array_slice(arr, 1, to) is the same as array_slice(arr, 0, to) - std::cmp::max(index - 1, 0) + if let Ok(index) = index.try_into() { + std::cmp::max(index - O::usize_as(1), O::usize_as(0)) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - fn adjusted_to_index(index: i64, len: usize) -> Option { + fn adjusted_to_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { // 0 ~ len - 1 let adjusted_zero_index = if index < 0 { // array_slice in duckdb with negative to_index is python-like, so index itself is exclusive - index + len as i64 - 1 + if let Ok(index) = index.try_into() { + index + len - O::usize_as(1) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } } else { // array_slice(arr, from, len + 1) is the same as array_slice(arr, from, len) - std::cmp::min(index - 1, len as i64 - 1) + if let Ok(index) = index.try_into() { + std::cmp::min(index - O::usize_as(1), len - O::usize_as(1)) + } else { + return exec_err!("array_slice got invalid index: {}", index); + } }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - let mut offsets = vec![0]; + let mut offsets = vec![O::usize_as(0)]; - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; let len = end - start; // len 0 indicate array is null, return empty array in this row. - if len == 0 { + if len == O::usize_as(0) { offsets.push(offsets[row_index]); continue; } // If index is null, we consider it as the minimum / maximum index of the array. let from_index = if from_array.is_null(row_index) { - Some(0) + Some(O::usize_as(0)) } else { - adjusted_from_index(from_array.value(row_index), len) + adjusted_from_index::(from_array.value(row_index), len)? }; let to_index = if to_array.is_null(row_index) { - Some(len as i64 - 1) + Some(len - O::usize_as(1)) } else { - adjusted_to_index(to_array.value(row_index), len) + adjusted_to_index::(to_array.value(row_index), len)? }; if let (Some(from), Some(to)) = (from_index, to_index) { if from <= to { - assert!(start + to as usize <= end); - mutable.extend(0, start + from as usize, start + to as usize + 1); - offsets.push(offsets[row_index] + (to - from + 1) as i32); + assert!(start + to <= end); + mutable.extend( + 0, + (start + from).to_usize().unwrap(), + (start + to + O::usize_as(1)).to_usize().unwrap(), + ); + offsets.push(offsets[row_index] + (to - from + O::usize_as(1))); } else { // invalid range, return empty array offsets.push(offsets[row_index]); @@ -617,9 +665,9 @@ pub fn array_slice(args: &[ArrayRef]) -> Result { let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new("item", list_array.value_type(), true)), - OffsetBuffer::new(offsets.into()), + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", array.value_type(), true)), + OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), None, )?)) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 1202a2b1e99d..210739aa51da 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -912,128 +912,235 @@ select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice scalar function #2 (with positive indexes; full array) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); ---- [1, 2, 3, 4, 5] [h, e, l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + # array_slice scalar function #3 (with positive indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); ---- [4] [l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 4, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 3); +---- +[4] [l] + # array_slice scalar function #4 (with positive indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 4, 1); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 4, 1); +---- +[] [] + # array_slice scalar function #5 (with positive indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 7); ---- [2, 3, 4, 5] [l, l, o] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, 7); +---- +[2, 3, 4, 5] [l, l, o] + # array_slice scalar function #6 (with positive indexes; nested array) query ? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1, 1); ---- [[1, 2, 3, 4, 5]] +query ? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1, 1); +---- +[[1, 2, 3, 4, 5]] + # array_slice scalar function #7 (with zero and positive number) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); ---- [1, 2, 3, 4] [h, e, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 3); +---- +[1, 2, 3, 4] [h, e, l] + # array_slice scalar function #8 (with NULL and positive number) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, 3); + # array_slice scalar function #9 (with positive number and NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, NULL); + # array_slice scalar function #10 (with zero-zero) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 0); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, 0), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, 0); +---- +[] [] + # array_slice scalar function #11 (with NULL-NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + + # array_slice scalar function #12 (with zero and negative number) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); ---- [1] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0, -3); +---- +[1] [h, e] + # array_slice scalar function #13 (with negative number and NULL) query error select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, NULL); + # array_slice scalar function #14 (with NULL and negative number) query error select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); +query error +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, -3); + # array_slice scalar function #15 (with negative indexes) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); ---- [2, 3, 4] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -1); +---- +[2, 3, 4] [l, l] + # array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); ---- [1, 2, 3, 4] [h, e, l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -5, -1), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -5, -1); +---- +[1, 2, 3, 4] [h, e, l, l] + # array_slice scalar function #17 (with negative indexes; first index = second index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -3); +---- +[] [] + # array_slice scalar function #18 (with negative indexes; first index > second_index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -4, -6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -6); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4, -6), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, -6); +---- +[] [] + # array_slice scalar function #19 (with negative indexes; out of bounds) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); ---- [] [] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7, -2), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7, -3); +---- +[] [] + # array_slice scalar function #20 (with negative indexes; nested array) query ?? select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1), array_slice(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), -1, -1); ---- [[1, 2, 3, 4, 5]] [] +query ?? +select array_slice(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), -2, -1), array_slice(arrow_cast(make_array(make_array(1, 2, 3), make_array(6, 7, 8)), 'LargeList(List(Int64))'), -1, -1); +---- +[[1, 2, 3, 4, 5]] [] + + # array_slice scalar function #21 (with first positive index and last negative index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2); ---- [2] [e, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, -3), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, -2); +---- +[2] [e, l] + # array_slice scalar function #22 (with first negative index and last positive index) query ?? select array_slice(make_array(1, 2, 3, 4, 5), -2, 5), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, 4); ---- [4, 5] [l, l] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, 5), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, 4); +---- +[4, 5] [l, l] + # list_slice scalar function #23 (function alias `array_slice`) query ?? select list_slice(make_array(1, 2, 3, 4, 5), 2, 4), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); ---- [2, 3, 4] [h, e] +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 1, 2); +---- +[2, 3, 4] [h, e] + # array_slice with columns query ? select array_slice(column1, column2, column3) from slices; @@ -1046,6 +1153,17 @@ select array_slice(column1, column2, column3) from slices; [41, 42, 43, 44, 45, 46] [55, 56, 57, 58, 59, 60] +query ? +select array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from slices; +---- +[] +[12, 13, 14, 15, 16] +[] +[] +[] +[41, 42, 43, 44, 45, 46] +[55, 56, 57, 58, 59, 60] + # TODO: support NULLS in output instead of `[]` # array_slice with columns and scalars query ??? @@ -1059,6 +1177,17 @@ select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(col [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] +query ??? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), 3, column3), array_slice(arrow_cast(column1, 'LargeList(Int64)'), column2, 5) from slices; +---- +[1] [] [, 2, 3, 4, 5] +[] [13, 14, 15, 16] [12, 13, 14, 15] +[] [] [21, 22, 23, , 25] +[] [33] [] +[4, 5] [] [] +[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] +[5] [, 54, 55, 56, 57, 58, 59, 60] [55] + # make_array with nulls query ??????? select make_array(make_array('a','b'), null),