diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs index 7e8d600542fd..a9e512213057 100644 --- a/datafusion/core/src/physical_plan/sorts/cursor.rs +++ b/datafusion/core/src/physical_plan/sorts/cursor.rs @@ -19,6 +19,8 @@ use crate::physical_plan::sorts::sort::SortOptions; use arrow::buffer::ScalarBuffer; use arrow::datatypes::ArrowNativeTypeOp; use arrow::row::{Row, Rows}; +use arrow_array::types::ByteArrayType; +use arrow_array::{Array, ArrowPrimitiveType, GenericByteArray, PrimitiveArray}; use std::cmp::Ordering; /// A [`Cursor`] for [`Rows`] @@ -97,12 +99,90 @@ impl Cursor for RowCursor { } } -/// A cursor over sorted, nullable [`ArrowNativeTypeOp`] +/// An [`Array`] that can be converted into [`FieldValues`] +pub trait FieldArray: Array + 'static { + type Values: FieldValues; + + fn values(&self) -> Self::Values; +} + +/// A comparable set of non-nullable values +pub trait FieldValues { + type Value: ?Sized; + + fn len(&self) -> usize; + + fn compare(a: &Self::Value, b: &Self::Value) -> Ordering; + + fn value(&self, idx: usize) -> &Self::Value; +} + +impl FieldArray for PrimitiveArray { + type Values = PrimitiveValues; + + fn values(&self) -> Self::Values { + PrimitiveValues(self.values().clone()) + } +} + +#[derive(Debug)] +pub struct PrimitiveValues(ScalarBuffer); + +impl FieldValues for PrimitiveValues { + type Value = T; + + fn len(&self) -> usize { + self.0.len() + } + + #[inline] + fn compare(a: &Self::Value, b: &Self::Value) -> Ordering { + T::compare(*a, *b) + } + + #[inline] + fn value(&self, idx: usize) -> &Self::Value { + &self.0[idx] + } +} + +impl FieldArray for GenericByteArray { + type Values = Self; + + fn values(&self) -> Self::Values { + // Once https://github.com/apache/arrow-rs/pull/4048 is released + // Could potentially destructure array into buffers to reduce codegen, + // in a similar vein to what is done for PrimitiveArray + self.clone() + } +} + +impl FieldValues for GenericByteArray { + type Value = T::Native; + + fn len(&self) -> usize { + Array::len(self) + } + + #[inline] + fn compare(a: &Self::Value, b: &Self::Value) -> Ordering { + let a: &[u8] = a.as_ref(); + let b: &[u8] = b.as_ref(); + a.cmp(b) + } + + #[inline] + fn value(&self, idx: usize) -> &Self::Value { + self.value(idx) + } +} + +/// A cursor over sorted, nullable [`FieldValues`] /// /// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering #[derive(Debug)] -pub struct PrimitiveCursor { - values: ScalarBuffer, +pub struct FieldCursor { + values: T, offset: usize, // If nulls first, the first non-null index // Otherwise, the first null index @@ -110,18 +190,16 @@ pub struct PrimitiveCursor { options: SortOptions, } -impl PrimitiveCursor { - /// Create a new [`PrimitiveCursor`] from the provided `values` sorted according to `options` - pub fn new(options: SortOptions, values: ScalarBuffer, null_count: usize) -> Self { - assert!(null_count <= values.len()); - +impl FieldCursor { + /// Create a new [`FieldCursor`] from the provided `values` sorted according to `options` + pub fn new>(options: SortOptions, array: &A) -> Self { let null_threshold = match options.nulls_first { - true => null_count, - false => values.len() - null_count, + true => array.null_count(), + false => array.len() - array.null_count(), }; Self { - values, + values: array.values(), offset: 0, null_threshold, options, @@ -131,26 +209,22 @@ impl PrimitiveCursor { fn is_null(&self) -> bool { (self.offset < self.null_threshold) == self.options.nulls_first } - - fn value(&self) -> T { - self.values[self.offset] - } } -impl PartialEq for PrimitiveCursor { +impl PartialEq for FieldCursor { fn eq(&self, other: &Self) -> bool { self.cmp(other).is_eq() } } -impl Eq for PrimitiveCursor {} -impl PartialOrd for PrimitiveCursor { +impl Eq for FieldCursor {} +impl PartialOrd for FieldCursor { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for PrimitiveCursor { +impl Ord for FieldCursor { fn cmp(&self, other: &Self) -> Ordering { match (self.is_null(), other.is_null()) { (true, true) => Ordering::Equal, @@ -163,19 +237,19 @@ impl Ord for PrimitiveCursor { false => Ordering::Less, }, (false, false) => { - let s_v = self.value(); - let o_v = other.value(); + let s_v = self.values.value(self.offset); + let o_v = other.values.value(other.offset); match self.options.descending { - true => o_v.compare(s_v), - false => s_v.compare(o_v), + true => T::compare(o_v, s_v), + false => T::compare(s_v, o_v), } } } } } -impl Cursor for PrimitiveCursor { +impl Cursor for FieldCursor { fn is_finished(&self) -> bool { self.offset == self.values.len() } @@ -191,6 +265,24 @@ impl Cursor for PrimitiveCursor { mod tests { use super::*; + fn new_primitive( + options: SortOptions, + values: ScalarBuffer, + null_count: usize, + ) -> FieldCursor> { + let null_threshold = match options.nulls_first { + true => null_count, + false => values.len() - null_count, + }; + + FieldCursor { + offset: 0, + values: PrimitiveValues(values), + null_threshold, + options, + } + } + #[test] fn test_primitive_nulls_first() { let options = SortOptions { @@ -199,9 +291,9 @@ mod tests { }; let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]); - let mut a = PrimitiveCursor::new(options, buffer, 1); + let mut a = new_primitive(options, buffer, 1); let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]); - let mut b = PrimitiveCursor::new(options, buffer, 2); + let mut b = new_primitive(options, buffer, 2); // NULL == NULL assert_eq!(a.cmp(&b), Ordering::Equal); @@ -243,9 +335,9 @@ mod tests { }; let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]); - let mut a = PrimitiveCursor::new(options, buffer, 2); + let mut a = new_primitive(options, buffer, 2); let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]); - let mut b = PrimitiveCursor::new(options, buffer, 2); + let mut b = new_primitive(options, buffer, 2); // 0 > -1 assert_eq!(a.cmp(&b), Ordering::Greater); @@ -269,9 +361,9 @@ mod tests { }; let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]); - let mut a = PrimitiveCursor::new(options, buffer, 3); + let mut a = new_primitive(options, buffer, 3); let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]); - let mut b = PrimitiveCursor::new(options, buffer, 2); + let mut b = new_primitive(options, buffer, 2); // 6 > 67 assert_eq!(a.cmp(&b), Ordering::Greater); @@ -299,9 +391,9 @@ mod tests { }; let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]); - let mut a = PrimitiveCursor::new(options, buffer, 2); + let mut a = new_primitive(options, buffer, 2); let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]); - let mut b = PrimitiveCursor::new(options, buffer, 1); + let mut b = new_primitive(options, buffer, 1); // NULL == NULL assert_eq!(a.cmp(&b), Ordering::Equal); diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs index 1ea89b9a8194..7e2d986e9d94 100644 --- a/datafusion/core/src/physical_plan/sorts/merge.rs +++ b/datafusion/core/src/physical_plan/sorts/merge.rs @@ -20,21 +20,27 @@ use crate::physical_plan::metrics::MemTrackingMetrics; use crate::physical_plan::sorts::builder::BatchBuilder; use crate::physical_plan::sorts::cursor::Cursor; use crate::physical_plan::sorts::stream::{ - PartitionedStream, PrimitiveCursorStream, RowCursorStream, + FieldCursorStream, PartitionedStream, RowCursorStream, }; use crate::physical_plan::{ PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream, }; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_array::downcast_primitive; +use arrow_array::*; use futures::Stream; use std::pin::Pin; use std::task::{ready, Context, Poll}; macro_rules! primitive_merge_helper { + ($t:ty, $($v:ident),+) => { + merge_helper!(PrimitiveArray<$t>, $($v),+) + }; +} + +macro_rules! merge_helper { ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{ - let streams = PrimitiveCursorStream::<$t>::new($sort, $streams); + let streams = FieldCursorStream::<$t>::new($sort, $streams); return Ok(Box::pin(SortPreservingMergeStream::new( Box::new(streams), $schema, @@ -58,6 +64,10 @@ pub(crate) fn streaming_merge( let data_type = sort.expr.data_type(schema.as_ref())?; downcast_primitive! { data_type => (primitive_merge_helper, sort, streams, schema, tracking_metrics, batch_size), + DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, tracking_metrics, batch_size) + DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, tracking_metrics, batch_size) + DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, tracking_metrics, batch_size) + DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, tracking_metrics, batch_size) _ => {} } } diff --git a/datafusion/core/src/physical_plan/sorts/stream.rs b/datafusion/core/src/physical_plan/sorts/stream.rs index d6f49e58b5b6..9de6e260dbc3 100644 --- a/datafusion/core/src/physical_plan/sorts/stream.rs +++ b/datafusion/core/src/physical_plan/sorts/stream.rs @@ -16,14 +16,13 @@ // under the License. use crate::common::Result; -use crate::physical_plan::sorts::cursor::{PrimitiveCursor, RowCursor}; +use crate::physical_plan::sorts::cursor::{FieldArray, FieldCursor, RowCursor}; use crate::physical_plan::SendableRecordBatchStream; use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr}; -use arrow::array::{Array, ArrowPrimitiveType}; +use arrow::array::Array; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; -use datafusion_common::cast::as_primitive_array; use futures::stream::{Fuse, StreamExt}; use std::marker::PhantomData; use std::sync::Arc; @@ -144,7 +143,7 @@ impl PartitionedStream for RowCursorStream { } /// Specialized stream for sorts on single primitive columns -pub struct PrimitiveCursorStream { +pub struct FieldCursorStream { /// The physical expressions to sort by sort: PhysicalSortExpr, /// Input streams @@ -152,16 +151,15 @@ pub struct PrimitiveCursorStream { phantom: PhantomData T>, } -impl std::fmt::Debug for PrimitiveCursorStream { +impl std::fmt::Debug for FieldCursorStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PrimitiveCursorStream") - .field("data_type", &T::DATA_TYPE) .field("num_streams", &self.streams) .finish() } } -impl PrimitiveCursorStream { +impl FieldCursorStream { pub fn new(sort: PhysicalSortExpr, streams: Vec) -> Self { let streams = streams.into_iter().map(|s| s.fuse()).collect(); Self { @@ -171,24 +169,16 @@ impl PrimitiveCursorStream { } } - fn convert_batch( - &mut self, - batch: &RecordBatch, - ) -> Result> { + fn convert_batch(&mut self, batch: &RecordBatch) -> Result> { let value = self.sort.expr.evaluate(batch)?; let array = value.into_array(batch.num_rows()); - let array = as_primitive_array::(array.as_ref())?; - - Ok(PrimitiveCursor::new( - self.sort.options, - array.values().clone(), - array.null_count(), - )) + let array = array.as_any().downcast_ref::().expect("field values"); + Ok(FieldCursor::new(self.sort.options, array)) } } -impl PartitionedStream for PrimitiveCursorStream { - type Output = Result<(PrimitiveCursor, RecordBatch)>; +impl PartitionedStream for FieldCursorStream { + type Output = Result<(FieldCursor, RecordBatch)>; fn partitions(&self) -> usize { self.streams.0.len()