Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor lexico sort for future code reuse #423

Merged
merged 1 commit into from
Jun 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 69 additions & 44 deletions arrow/src/compute/kernels/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@

//! Defines sort kernel for `ArrayRef`

use std::cmp::Ordering;

use crate::array::*;
use crate::buffer::MutableBuffer;
use crate::compute::take;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};

use std::cmp::Ordering;
use TimeUnit::*;

/// Sort the `ArrayRef` using `SortOptions`.
Expand Down Expand Up @@ -817,26 +815,55 @@ pub fn lexsort_to_indices(
));
};

// map to data and DynComparator
let flat_columns = columns
.iter()
.map(
|column| -> Result<(&ArrayData, DynComparator, SortOptions)> {
// flatten and convert build comparators
// use ArrayData for is_valid checks later to avoid dynamic call
let values = column.values.as_ref();
let data = values.data_ref();
Ok((
data,
build_compare(values, values)?,
column.options.unwrap_or_default(),
))
},
)
.collect::<Result<Vec<(&ArrayData, DynComparator, SortOptions)>>>()?;
let mut value_indices = (0..row_count).collect::<Vec<usize>>();
let mut len = value_indices.len();

if let Some(limit) = limit {
len = limit.min(len);
}

let lexicographical_comparator = LexicographicalComparator::try_new(columns)?;
sort_by(&mut value_indices, len, |a, b| {
lexicographical_comparator.compare(a, b)
});

Ok(UInt32Array::from(
(&value_indices)[0..len]
.iter()
.map(|i| *i as u32)
.collect::<Vec<u32>>(),
))
}

/// It's unstable_sort, may not preserve the order of equal elements
pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut is_less: F)
where
F: FnMut(&T, &T) -> Ordering,
{
let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less);
before.sort_unstable_by(is_less);
}

type LexicographicalCompareItem<'a> = (
&'a ArrayData, // data
Box<dyn Fn(usize, usize) -> Ordering + 'a>, // comparator
SortOptions, // sort_option
);

/// A lexicographical comparator that wraps given array data (columns) and can lexicographically compare data
/// at given two indices. The lifetime is the same at the data wrapped.
pub(super) struct LexicographicalComparator<'a> {
compare_items: Vec<LexicographicalCompareItem<'a>>,
}

let lex_comparator = |a_idx: &usize, b_idx: &usize| -> Ordering {
for (data, comparator, sort_option) in flat_columns.iter() {
impl LexicographicalComparator<'_> {
/// lexicographically compare values at the wrapped columns with given indices.
pub(super) fn compare<'a, 'b>(
&'a self,
a_idx: &'b usize,
b_idx: &'b usize,
) -> Ordering {
for (data, comparator, sort_option) in &self.compare_items {
match (data.is_valid(*a_idx), data.is_valid(*b_idx)) {
(true, true) => {
match (comparator)(*a_idx, *b_idx) {
Expand Down Expand Up @@ -871,31 +898,29 @@ pub fn lexsort_to_indices(
}

Ordering::Equal
};

let mut value_indices = (0..row_count).collect::<Vec<usize>>();
let mut len = value_indices.len();

if let Some(limit) = limit {
len = limit.min(len);
}
sort_by(&mut value_indices, len, lex_comparator);

Ok(UInt32Array::from(
(&value_indices)[0..len]
/// Create a new lex comparator that will wrap the given sort columns and give comparison
/// results with two indices.
pub(super) fn try_new(
columns: &[SortColumn],
) -> Result<LexicographicalComparator<'_>> {
let compare_items = columns
.iter()
.map(|i| *i as u32)
.collect::<Vec<u32>>(),
))
}

/// It's unstable_sort, may not preserve the order of equal elements
pub fn partial_sort<T, F>(v: &mut [T], limit: usize, mut is_less: F)
where
F: FnMut(&T, &T) -> Ordering,
{
let (before, _mid, _after) = v.select_nth_unstable_by(limit, &mut is_less);
before.sort_unstable_by(is_less);
.map(|column| {
// flatten and convert build comparators
// use ArrayData for is_valid checks later to avoid dynamic call
let values = column.values.as_ref();
let data = values.data_ref();
Ok((
data,
build_compare(values, values)?,
column.options.unwrap_or_default(),
))
})
.collect::<Result<Vec<_>>>()?;
Ok(LexicographicalComparator { compare_items })
}
}

#[cfg(test)]
Expand Down