From 96360056253d72ef3a51856e76c004d398276909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn=20Horstmann?= Date: Wed, 17 Nov 2021 12:58:01 +0100 Subject: [PATCH] Fix primitive sort when input contains more nulls than the given sort limit --- arrow/src/compute/kernels/sort.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 88c7785bc985..6a72224979cd 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -1032,13 +1032,11 @@ fn sort_valids( ) where T: ?Sized + Copy, { - let nulls_len = nulls.len(); + let valids_len = valids.len(); if !descending { - sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| cmp(a.1, b.1)); + sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1)); } else { - sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| { - cmp(a.1, b.1).reverse() - }); + sort_unstable_by(valids, len.min(valids_len), |a, b| cmp(a.1, b.1).reverse()); // reverse to keep a stable ordering nulls.reverse(); } @@ -1050,13 +1048,13 @@ fn sort_valids_array( nulls: &mut [T], len: usize, ) { - let nulls_len = nulls.len(); + let valids_len = valids.len(); if !descending { - sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| { + sort_unstable_by(valids, len.min(valids_len), |a, b| { cmp_array(a.1.as_ref(), b.1.as_ref()) }); } else { - sort_unstable_by(valids, len.saturating_sub(nulls_len), |a, b| { + sort_unstable_by(valids, len.min(valids_len), |a, b| { cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() }); // reverse to keep a stable ordering @@ -1555,6 +1553,19 @@ mod tests { ); } + #[test] + fn test_sort_to_indices_primitive_more_nulls_than_limit() { + test_sort_to_indices_primitive_arrays::( + vec![None, None, Some(3), None, Some(1), None, Some(2)], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![4, 6], + ); + } + #[test] fn test_sort_boolean() { // boolean