Skip to content

Commit

Permalink
Specialized Cursor for StringArray and BinaryArray (apache#5964)
Browse files Browse the repository at this point in the history
* Generify

* Specialized cursor for StringArray and BinaryArray

* fix clippy

* Review feedback

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
2 people authored and korowa committed Apr 13, 2023
1 parent 8f3d4de commit 593a241
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 56 deletions.
156 changes: 124 additions & 32 deletions datafusion/core/src/physical_plan/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ use crate::physical_plan::sorts::sort::SortOptions;
use arrow::buffer::ScalarBuffer;
use arrow::datatypes::ArrowNativeTypeOp;
use arrow::row::{Row, Rows};
use arrow_array::types::ByteArrayType;
use arrow_array::{Array, ArrowPrimitiveType, GenericByteArray, PrimitiveArray};
use std::cmp::Ordering;

/// A [`Cursor`] for [`Rows`]
Expand Down Expand Up @@ -97,31 +99,107 @@ impl Cursor for RowCursor {
}
}

/// A cursor over sorted, nullable [`ArrowNativeTypeOp`]
/// An [`Array`] that can be converted into [`FieldValues`]
pub trait FieldArray: Array + 'static {
type Values: FieldValues;

fn values(&self) -> Self::Values;
}

/// A comparable set of non-nullable values
pub trait FieldValues {
type Value: ?Sized;

fn len(&self) -> usize;

fn compare(a: &Self::Value, b: &Self::Value) -> Ordering;

fn value(&self, idx: usize) -> &Self::Value;
}

impl<T: ArrowPrimitiveType> FieldArray for PrimitiveArray<T> {
type Values = PrimitiveValues<T::Native>;

fn values(&self) -> Self::Values {
PrimitiveValues(self.values().clone())
}
}

#[derive(Debug)]
pub struct PrimitiveValues<T: ArrowNativeTypeOp>(ScalarBuffer<T>);

impl<T: ArrowNativeTypeOp> FieldValues for PrimitiveValues<T> {
type Value = T;

fn len(&self) -> usize {
self.0.len()
}

#[inline]
fn compare(a: &Self::Value, b: &Self::Value) -> Ordering {
T::compare(*a, *b)
}

#[inline]
fn value(&self, idx: usize) -> &Self::Value {
&self.0[idx]
}
}

impl<T: ByteArrayType> FieldArray for GenericByteArray<T> {
type Values = Self;

fn values(&self) -> Self::Values {
// Once https://github.com/apache/arrow-rs/pull/4048 is released
// Could potentially destructure array into buffers to reduce codegen,
// in a similar vein to what is done for PrimitiveArray
self.clone()
}
}

impl<T: ByteArrayType> FieldValues for GenericByteArray<T> {
type Value = T::Native;

fn len(&self) -> usize {
Array::len(self)
}

#[inline]
fn compare(a: &Self::Value, b: &Self::Value) -> Ordering {
let a: &[u8] = a.as_ref();
let b: &[u8] = b.as_ref();
a.cmp(b)
}

#[inline]
fn value(&self, idx: usize) -> &Self::Value {
self.value(idx)
}
}

/// A cursor over sorted, nullable [`FieldValues`]
///
/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering
#[derive(Debug)]
pub struct PrimitiveCursor<T: ArrowNativeTypeOp> {
values: ScalarBuffer<T>,
pub struct FieldCursor<T: FieldValues> {
values: T,
offset: usize,
// If nulls first, the first non-null index
// Otherwise, the first null index
null_threshold: usize,
options: SortOptions,
}

impl<T: ArrowNativeTypeOp> PrimitiveCursor<T> {
/// Create a new [`PrimitiveCursor`] from the provided `values` sorted according to `options`
pub fn new(options: SortOptions, values: ScalarBuffer<T>, null_count: usize) -> Self {
assert!(null_count <= values.len());

impl<T: FieldValues> FieldCursor<T> {
/// Create a new [`FieldCursor`] from the provided `values` sorted according to `options`
pub fn new<A: FieldArray<Values = T>>(options: SortOptions, array: &A) -> Self {
let null_threshold = match options.nulls_first {
true => null_count,
false => values.len() - null_count,
true => array.null_count(),
false => array.len() - array.null_count(),
};

Self {
values,
values: array.values(),
offset: 0,
null_threshold,
options,
Expand All @@ -131,26 +209,22 @@ impl<T: ArrowNativeTypeOp> PrimitiveCursor<T> {
fn is_null(&self) -> bool {
(self.offset < self.null_threshold) == self.options.nulls_first
}

fn value(&self) -> T {
self.values[self.offset]
}
}

impl<T: ArrowNativeTypeOp> PartialEq for PrimitiveCursor<T> {
impl<T: FieldValues> PartialEq for FieldCursor<T> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}

impl<T: ArrowNativeTypeOp> Eq for PrimitiveCursor<T> {}
impl<T: ArrowNativeTypeOp> PartialOrd for PrimitiveCursor<T> {
impl<T: FieldValues> Eq for FieldCursor<T> {}
impl<T: FieldValues> PartialOrd for FieldCursor<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl<T: ArrowNativeTypeOp> Ord for PrimitiveCursor<T> {
impl<T: FieldValues> Ord for FieldCursor<T> {
fn cmp(&self, other: &Self) -> Ordering {
match (self.is_null(), other.is_null()) {
(true, true) => Ordering::Equal,
Expand All @@ -163,19 +237,19 @@ impl<T: ArrowNativeTypeOp> Ord for PrimitiveCursor<T> {
false => Ordering::Less,
},
(false, false) => {
let s_v = self.value();
let o_v = other.value();
let s_v = self.values.value(self.offset);
let o_v = other.values.value(other.offset);

match self.options.descending {
true => o_v.compare(s_v),
false => s_v.compare(o_v),
true => T::compare(o_v, s_v),
false => T::compare(s_v, o_v),
}
}
}
}
}

impl<T: ArrowNativeTypeOp> Cursor for PrimitiveCursor<T> {
impl<T: FieldValues> Cursor for FieldCursor<T> {
fn is_finished(&self) -> bool {
self.offset == self.values.len()
}
Expand All @@ -191,6 +265,24 @@ impl<T: ArrowNativeTypeOp> Cursor for PrimitiveCursor<T> {
mod tests {
use super::*;

fn new_primitive(
options: SortOptions,
values: ScalarBuffer<i32>,
null_count: usize,
) -> FieldCursor<PrimitiveValues<i32>> {
let null_threshold = match options.nulls_first {
true => null_count,
false => values.len() - null_count,
};

FieldCursor {
offset: 0,
values: PrimitiveValues(values),
null_threshold,
options,
}
}

#[test]
fn test_primitive_nulls_first() {
let options = SortOptions {
Expand All @@ -199,9 +291,9 @@ mod tests {
};

let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]);
let mut a = PrimitiveCursor::new(options, buffer, 1);
let mut a = new_primitive(options, buffer, 1);
let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]);
let mut b = PrimitiveCursor::new(options, buffer, 2);
let mut b = new_primitive(options, buffer, 2);

// NULL == NULL
assert_eq!(a.cmp(&b), Ordering::Equal);
Expand Down Expand Up @@ -243,9 +335,9 @@ mod tests {
};

let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]);
let mut a = PrimitiveCursor::new(options, buffer, 2);
let mut a = new_primitive(options, buffer, 2);
let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]);
let mut b = PrimitiveCursor::new(options, buffer, 2);
let mut b = new_primitive(options, buffer, 2);

// 0 > -1
assert_eq!(a.cmp(&b), Ordering::Greater);
Expand All @@ -269,9 +361,9 @@ mod tests {
};

let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]);
let mut a = PrimitiveCursor::new(options, buffer, 3);
let mut a = new_primitive(options, buffer, 3);
let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]);
let mut b = PrimitiveCursor::new(options, buffer, 2);
let mut b = new_primitive(options, buffer, 2);

// 6 > 67
assert_eq!(a.cmp(&b), Ordering::Greater);
Expand Down Expand Up @@ -299,9 +391,9 @@ mod tests {
};

let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]);
let mut a = PrimitiveCursor::new(options, buffer, 2);
let mut a = new_primitive(options, buffer, 2);
let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]);
let mut b = PrimitiveCursor::new(options, buffer, 1);
let mut b = new_primitive(options, buffer, 1);

// NULL == NULL
assert_eq!(a.cmp(&b), Ordering::Equal);
Expand Down
18 changes: 14 additions & 4 deletions datafusion/core/src/physical_plan/sorts/merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,27 @@ use crate::physical_plan::metrics::MemTrackingMetrics;
use crate::physical_plan::sorts::builder::BatchBuilder;
use crate::physical_plan::sorts::cursor::Cursor;
use crate::physical_plan::sorts::stream::{
PartitionedStream, PrimitiveCursorStream, RowCursorStream,
FieldCursorStream, PartitionedStream, RowCursorStream,
};
use crate::physical_plan::{
PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream,
};
use arrow::datatypes::SchemaRef;
use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_array::downcast_primitive;
use arrow_array::*;
use futures::Stream;
use std::pin::Pin;
use std::task::{ready, Context, Poll};

macro_rules! primitive_merge_helper {
($t:ty, $($v:ident),+) => {
merge_helper!(PrimitiveArray<$t>, $($v),+)
};
}

macro_rules! merge_helper {
($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
let streams = PrimitiveCursorStream::<$t>::new($sort, $streams);
let streams = FieldCursorStream::<$t>::new($sort, $streams);
return Ok(Box::pin(SortPreservingMergeStream::new(
Box::new(streams),
$schema,
Expand All @@ -58,6 +64,10 @@ pub(crate) fn streaming_merge(
let data_type = sort.expr.data_type(schema.as_ref())?;
downcast_primitive! {
data_type => (primitive_merge_helper, sort, streams, schema, tracking_metrics, batch_size),
DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, tracking_metrics, batch_size)
DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, tracking_metrics, batch_size)
DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, tracking_metrics, batch_size)
DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, tracking_metrics, batch_size)
_ => {}
}
}
Expand Down
30 changes: 10 additions & 20 deletions datafusion/core/src/physical_plan/sorts/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
// under the License.

use crate::common::Result;
use crate::physical_plan::sorts::cursor::{PrimitiveCursor, RowCursor};
use crate::physical_plan::sorts::cursor::{FieldArray, FieldCursor, RowCursor};
use crate::physical_plan::SendableRecordBatchStream;
use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr};
use arrow::array::{Array, ArrowPrimitiveType};
use arrow::array::Array;
use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, SortField};
use datafusion_common::cast::as_primitive_array;
use futures::stream::{Fuse, StreamExt};
use std::marker::PhantomData;
use std::sync::Arc;
Expand Down Expand Up @@ -144,24 +143,23 @@ impl PartitionedStream for RowCursorStream {
}

/// Specialized stream for sorts on single primitive columns
pub struct PrimitiveCursorStream<T: ArrowPrimitiveType> {
pub struct FieldCursorStream<T: FieldArray> {
/// The physical expressions to sort by
sort: PhysicalSortExpr,
/// Input streams
streams: FusedStreams,
phantom: PhantomData<fn(T) -> T>,
}

impl<T: ArrowPrimitiveType> std::fmt::Debug for PrimitiveCursorStream<T> {
impl<T: FieldArray> std::fmt::Debug for FieldCursorStream<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PrimitiveCursorStream")
.field("data_type", &T::DATA_TYPE)
.field("num_streams", &self.streams)
.finish()
}
}

impl<T: ArrowPrimitiveType> PrimitiveCursorStream<T> {
impl<T: FieldArray> FieldCursorStream<T> {
pub fn new(sort: PhysicalSortExpr, streams: Vec<SendableRecordBatchStream>) -> Self {
let streams = streams.into_iter().map(|s| s.fuse()).collect();
Self {
Expand All @@ -171,24 +169,16 @@ impl<T: ArrowPrimitiveType> PrimitiveCursorStream<T> {
}
}

fn convert_batch(
&mut self,
batch: &RecordBatch,
) -> Result<PrimitiveCursor<T::Native>> {
fn convert_batch(&mut self, batch: &RecordBatch) -> Result<FieldCursor<T::Values>> {
let value = self.sort.expr.evaluate(batch)?;
let array = value.into_array(batch.num_rows());
let array = as_primitive_array::<T>(array.as_ref())?;

Ok(PrimitiveCursor::new(
self.sort.options,
array.values().clone(),
array.null_count(),
))
let array = array.as_any().downcast_ref::<T>().expect("field values");
Ok(FieldCursor::new(self.sort.options, array))
}
}

impl<T: ArrowPrimitiveType> PartitionedStream for PrimitiveCursorStream<T> {
type Output = Result<(PrimitiveCursor<T::Native>, RecordBatch)>;
impl<T: FieldArray> PartitionedStream for FieldCursorStream<T> {
type Output = Result<(FieldCursor<T::Values>, RecordBatch)>;

fn partitions(&self) -> usize {
self.streams.0.len()
Expand Down

0 comments on commit 593a241

Please sign in to comment.