Skip to content

Commit

Permalink
refactor: remove lifetime from DynComparator (#542)
Browse files Browse the repository at this point in the history
This commit removes the need for an explicit lifetime on the `DynComparator`.

The rationale behind this change is that callers may wish to share this comparator amongst threads and the explicit lifetime makes this harder to achieve.

As a nice side-effect, performance of the sort kernel seems to have improved:

```
$ critcmp master pr

group                                          master                          pr
-----                                          ------                          --
bool sort 2^12                                 1.03    310.8±1.34µs            1.00    302.8±7.78µs
bool sort nulls 2^12                           1.01    287.4±2.22µs            1.00    284.0±3.23µs
sort 2^10                                      1.04     98.7±3.58µs            1.00     94.6±0.50µs
sort 2^12                                      1.05    510.7±5.56µs            1.00    486.2±9.94µs
sort 2^12 limit 10                             1.05     48.1±0.38µs            1.00     45.6±0.30µs
sort 2^12 limit 100                            1.04     52.8±0.37µs            1.00     50.6±0.41µs
sort 2^12 limit 1000                           1.06    141.1±0.94µs            1.00    132.7±0.95µs
sort 2^12 limit 2^12                           1.03    501.2±4.01µs            1.00    486.5±4.87µs
sort nulls 2^10                                1.02     70.9±0.72µs            1.00     69.4±0.51µs
sort nulls 2^12                                1.02    369.7±3.51µs            1.00   363.0±18.52µs
sort nulls 2^12 limit 10                       1.01     70.6±1.22µs            1.00     70.0±1.27µs
sort nulls 2^12 limit 100                      1.00     71.7±0.82µs            1.00     71.8±1.60µs
sort nulls 2^12 limit 1000                     1.01     80.5±1.55µs            1.00     79.4±1.41µs
sort nulls 2^12 limit 2^12                     1.05    375.4±4.78µs            1.00    356.1±3.04µs
```
  • Loading branch information
e-dard authored Jul 14, 2021
1 parent cdcf013 commit fde79a2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 32 deletions.
48 changes: 19 additions & 29 deletions arrow/src/array/ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::error::{ArrowError, Result};
use num::Float;

/// Compare the values at two arbitrary indices in two arrays.
pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;

/// compares two floats, placing NaNs at last
fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
Expand All @@ -39,60 +39,50 @@ fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
}
}

fn compare_primitives<'a, T: ArrowPrimitiveType>(
left: &'a Array,
right: &'a Array,
) -> DynComparator<'a>
fn compare_primitives<T: ArrowPrimitiveType>(left: &Array, right: &Array) -> DynComparator
where
T::Native: Ord,
{
let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let left: PrimitiveArray<T> = PrimitiveArray::from(left.data().clone());
let right: PrimitiveArray<T> = PrimitiveArray::from(right.data().clone());
Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
}

fn compare_boolean<'a>(left: &'a Array, right: &'a Array) -> DynComparator<'a> {
let left = left.as_any().downcast_ref::<BooleanArray>().unwrap();
let right = right.as_any().downcast_ref::<BooleanArray>().unwrap();
fn compare_boolean(left: &Array, right: &Array) -> DynComparator {
let left: BooleanArray = BooleanArray::from(left.data().clone());
let right: BooleanArray = BooleanArray::from(right.data().clone());

Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
}

fn compare_float<'a, T: ArrowPrimitiveType>(
left: &'a Array,
right: &'a Array,
) -> DynComparator<'a>
fn compare_float<T: ArrowPrimitiveType>(left: &Array, right: &Array) -> DynComparator
where
T::Native: Float,
{
let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let left: PrimitiveArray<T> = PrimitiveArray::from(left.data().clone());
let right: PrimitiveArray<T> = PrimitiveArray::from(right.data().clone());
Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
}

fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
fn compare_string<T>(left: &Array, right: &Array) -> DynComparator
where
T: StringOffsetSizeTrait,
{
let left = left
.as_any()
.downcast_ref::<GenericStringArray<T>>()
.unwrap();
let right = right
.as_any()
.downcast_ref::<GenericStringArray<T>>()
.unwrap();
let left: StringArray = StringArray::from(left.data().clone());
let right: StringArray = StringArray::from(right.data().clone());

Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
}

fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
fn compare_dict_string<T>(left: &Array, right: &Array) -> DynComparator
where
T: ArrowDictionaryKeyType,
{
let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
let left_keys = left.keys();
let right_keys = right.keys();

let left_keys: PrimitiveArray<T> = PrimitiveArray::from(left.keys().data().clone());
let right_keys: PrimitiveArray<T> = PrimitiveArray::from(right.keys().data().clone());
let left_values = StringArray::from(left.values().data().clone());
let right_values = StringArray::from(right.values().data().clone());

Expand Down Expand Up @@ -125,7 +115,7 @@ where
/// ```
// This is a factory of comparisons.
// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
pub fn build_compare(left: &Array, right: &Array) -> Result<DynComparator> {
use DataType::*;
use IntervalUnit::*;
use TimeUnit::*;
Expand Down
6 changes: 3 additions & 3 deletions arrow/src/compute/kernels/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -886,9 +886,9 @@ where
}

type LexicographicalCompareItem<'a> = (
&'a ArrayData, // data
Box<dyn Fn(usize, usize) -> Ordering + 'a>, // comparator
SortOptions, // sort_option
&'a ArrayData, // data
DynComparator, // comparator
SortOptions, // sort_option
);

/// A lexicographical comparator that wraps given array data (columns) and can lexicographically compare data
Expand Down

0 comments on commit fde79a2

Please sign in to comment.