diff --git a/datafusion/core/src/physical_plan/sorts/builder.rs b/datafusion/core/src/physical_plan/sorts/builder.rs
new file mode 100644
index 0000000000000..9a95f2926c910
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/builder.rs
@@ -0,0 +1,158 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::sorts::index::RowIndex;
+use arrow::array::{make_array, MutableArrayData};
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use std::collections::VecDeque;
+
+#[derive(Debug)]
+pub struct BatchBuilder {
+    /// The schema of the RecordBatches yielded by this stream
+    schema: SchemaRef,
+
+    /// For each input stream maintain a deque of RecordBatches
+    ///
+    /// Exhausted batches will be popped off the front once all
+    /// their rows have been yielded to the output
+    batches: Vec<VecDeque<RecordBatch>>,
+
+    /// The accumulated row indexes for the next record batch
+    indices: Vec<RowIndex>,
+}
+
+impl BatchBuilder {
+    pub fn new(schema: SchemaRef, stream_count: usize, batch_size: usize) -> Self {
+        let batches = (0..stream_count).map(|_| VecDeque::new()).collect();
+
+        Self {
+            schema,
+            batches,
+            indices: Vec::with_capacity(batch_size),
+        }
+    }
+
+    pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) {
+        self.batches[stream_idx].push_back(batch)
+    }
+
+    pub fn push_row(&mut self, stream_idx: usize, row_idx: usize) {
+        let batch_idx = self.batches[stream_idx].len() - 1;
+        self.indices.push(RowIndex {
+            stream_idx,
+            batch_idx,
+            row_idx,
+        });
+    }
+
+    pub fn len(&self) -> usize {
+        self.indices.len()
+    }
+
+    pub fn schema(&self) -> &SchemaRef {
+        &self.schema
+    }
+
+    /// Drains the accumulated row indexes, and builds a new RecordBatch from them
+    ///
+    /// Will then drop any batches for which all rows have been yielded to the output
+    pub fn build_record_batch(&mut self) -> Result<Option<RecordBatch>> {
+        if self.indices.is_empty() {
+            return Ok(None);
+        }
+
+        // Mapping from stream index to the index of the first buffer from that stream
+        let mut buffer_idx = 0;
+        let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len());
+
+        for batches in &self.batches {
+            stream_to_buffer_idx.push(buffer_idx);
+            buffer_idx += batches.len();
+        }
+
+        let columns = self
+            .schema
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(column_idx, field)| {
+                let arrays = self
+                    .batches
+                    .iter()
+                    .flat_map(|batch| {
+                        batch.iter().map(|batch| batch.column(column_idx).data())
+                    })
+                    .collect();
+
+                let mut array_data = MutableArrayData::new(
+                    arrays,
+                    field.is_nullable(),
+                    self.indices.len(),
+                );
+
+                let first = &self.indices[0];
+                let mut buffer_idx =
+                    stream_to_buffer_idx[first.stream_idx] + first.batch_idx;
+                let mut start_row_idx = first.row_idx;
+                let mut end_row_idx = start_row_idx + 1;
+
+                for row_index in self.indices.iter().skip(1) {
+                    let next_buffer_idx =
+                        stream_to_buffer_idx[row_index.stream_idx] + row_index.batch_idx;
+
+                    if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx {
+                        // subsequent row in same batch
+                        end_row_idx += 1;
+                        continue;
+                    }
+
+                    // emit current batch of rows for current buffer
+                    array_data.extend(buffer_idx, start_row_idx, end_row_idx);
+
+                    // start new batch of rows
+                    buffer_idx = next_buffer_idx;
+                    start_row_idx = row_index.row_idx;
+                    end_row_idx = start_row_idx + 1;
+                }
+
+                // emit final batch of rows
+                array_data.extend(buffer_idx, start_row_idx, end_row_idx);
+                make_array(array_data.freeze())
+            })
+            .collect();
+
+        self.indices.clear();
+
+        // New cursors are only created once the previous cursor for the stream
+        // is finished. This means all remaining rows from all but the last batch
+        // for each stream have been yielded to the newly created record batch
+        //
+        // Additionally as `indices` has been drained, there are no longer
+        // any RowIndex entries reliant on the batch indexes
+        //
+        // We can therefore drop all but the last batch for each stream
+        for batches in &mut self.batches {
+            if batches.len() > 1 {
+                // Drain all but the last batch
+                batches.drain(0..(batches.len() - 1));
+            }
+        }
+
+        Ok(Some(RecordBatch::try_new(self.schema.clone(), columns)?))
+    }
+}
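The core of `build_record_batch` is its run-coalescing loop: consecutive `RowIndex` entries that land in the same buffer collapse into a single `MutableArrayData::extend` call, so copying happens per run rather than per row. A minimal standalone sketch of that loop (the `coalesce_runs` helper is illustrative, not part of this change):

/// `indices` lists `(buffer_idx, row_idx)` pairs in output order;
/// returns `(buffer_idx, start_row, end_row)` runs, end exclusive.
fn coalesce_runs(indices: &[(usize, usize)]) -> Vec<(usize, usize, usize)> {
    let mut runs = Vec::new();
    let (first, rest) = match indices.split_first() {
        Some(x) => x,
        None => return runs,
    };

    let (mut buf, mut start, mut end) = (first.0, first.1, first.1 + 1);
    for &(next_buf, row) in rest {
        if next_buf == buf && row == end {
            // subsequent row in the same buffer: grow the current run
            end += 1;
            continue;
        }
        // emit the current run and start a new one
        runs.push((buf, start, end));
        buf = next_buf;
        start = row;
        end = row + 1;
    }
    // emit the final run
    runs.push((buf, start, end));
    runs
}

fn main() {
    // rows 0 and 1 from buffer 0, row 5 from buffer 1, then row 2 from buffer 0
    let runs = coalesce_runs(&[(0, 0), (0, 1), (1, 5), (0, 2)]);
    assert_eq!(runs, vec![(0, 0, 2), (1, 5, 6), (0, 2, 3)]);
}

Sorted inputs tend to yield long runs from the same batch, which is what makes this amortization worthwhile.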
diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs
index e52544cf50022..86d99c1ac3dab 100644
--- a/datafusion/core/src/physical_plan/sorts/cursor.rs
+++ b/datafusion/core/src/physical_plan/sorts/cursor.rs
@@ -110,3 +110,19 @@ impl Ord for SortKeyCursor {
         }
     }
 }
+
+pub trait Cursor: Ord {
+    fn is_finished(&self) -> bool;
+
+    fn advance(&mut self) -> Option<usize>;
+}
+
+impl Cursor for SortKeyCursor {
+    fn is_finished(&self) -> bool {
+        self.is_finished()
+    }
+
+    fn advance(&mut self) -> Option<usize> {
+        (!self.is_finished()).then(|| self.advance())
+    }
+}
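The `Cursor` trait added above is what the merge loop programs against: `advance` returns the index of the row it just consumed (or `None` once exhausted), and `Ord` compares cursors by the row they currently point at. A minimal sketch against a sorted vector, with exhausted cursors sorting last (the `VecCursor` type is hypothetical, not part of this change):

use std::cmp::Ordering;

trait Cursor: Ord {
    fn is_finished(&self) -> bool;
    fn advance(&mut self) -> Option<usize>;
}

#[derive(Debug)]
struct VecCursor {
    values: Vec<u64>, // sorted ascending
    offset: usize,
}

impl VecCursor {
    fn current(&self) -> Option<&u64> {
        self.values.get(self.offset)
    }
}

impl PartialEq for VecCursor {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for VecCursor {}

impl PartialOrd for VecCursor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for VecCursor {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare the rows the cursors currently point at;
        // an exhausted cursor sorts after everything else
        match (self.current(), other.current()) {
            (None, None) => Ordering::Equal,
            (None, Some(_)) => Ordering::Greater,
            (Some(_), None) => Ordering::Less,
            (Some(a), Some(b)) => a.cmp(b),
        }
    }
}

impl Cursor for VecCursor {
    fn is_finished(&self) -> bool {
        self.offset >= self.values.len()
    }

    fn advance(&mut self) -> Option<usize> {
        (!self.is_finished()).then(|| {
            let row = self.offset;
            self.offset += 1;
            row
        })
    }
}

fn main() {
    let mut a = VecCursor { values: vec![1, 3], offset: 0 };
    let b = VecCursor { values: vec![2], offset: 0 };
    assert!(a < b); // 1 < 2, so stream `a` is the current winner
    assert_eq!(a.advance(), Some(0)); // consumed row 0 of `a`
    assert!(b < a); // now 2 < 3
}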
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs
new file mode 100644
index 0000000000000..02a34e7294e17
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -0,0 +1,258 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::metrics::MemTrackingMetrics;
+use crate::physical_plan::sorts::builder::BatchBuilder;
+use crate::physical_plan::sorts::cursor::Cursor;
+use crate::physical_plan::sorts::stream::{PartitionedStream, SortKeyCursorStream};
+use crate::physical_plan::{
+    PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream,
+};
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use futures::Stream;
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+
+/// Perform a streaming merge of [`SendableRecordBatchStream`]
+pub(crate) fn streaming_merge(
+    streams: Vec<SendableRecordBatchStream>,
+    schema: SchemaRef,
+    expressions: &[PhysicalSortExpr],
+    tracking_metrics: MemTrackingMetrics,
+    batch_size: usize,
+) -> Result<SendableRecordBatchStream> {
+    let streams = SortKeyCursorStream::try_new(schema.as_ref(), expressions, streams)?;
+
+    Ok(Box::pin(SortPreservingMergeStream::new(
+        Box::new(streams),
+        schema,
+        tracking_metrics,
+        batch_size,
+    )))
+}
+
+/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`]
+type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>;
+
+#[derive(Debug)]
+struct SortPreservingMergeStream<C> {
+    in_progress: BatchBuilder,
+
+    /// The sorted input streams to merge together
+    streams: CursorStream<C>,
+
+    /// used to record execution metrics
+    tracking_metrics: MemTrackingMetrics,
+
+    /// If the stream has encountered an error
+    aborted: bool,
+
+    /// A loser tree that always produces the minimum cursor
+    ///
+    /// Node 0 stores the top winner, Nodes 1..num_streams store
+    /// the loser nodes
+    ///
+    /// This implements a "Tournament Tree" (aka Loser Tree) to keep
+    /// track of the current smallest element at the top. When the top
+    /// record is taken, the tree structure is not modified, and only
+    /// the path from bottom to top is visited, keeping the number of
+    /// comparisons close to the theoretical limit of `log(S)`.
+    ///
+    /// reference: <https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
+    loser_tree: Vec<usize>,
+
+    /// If the most recently yielded overall winner has been replaced
+    /// within the loser tree. A value of `false` indicates that the
+    /// overall winner has been yielded but the loser tree has not
+    /// been updated
+    loser_tree_adjusted: bool,
+
+    /// target batch size
+    batch_size: usize,
+
+    /// Vector that holds cursors for each non-exhausted input partition
+    cursors: Vec<Option<C>>,
+}
+
+impl<C: Cursor> SortPreservingMergeStream<C> {
+    fn new(
+        streams: CursorStream<C>,
+        schema: SchemaRef,
+        tracking_metrics: MemTrackingMetrics,
+        batch_size: usize,
+    ) -> Self {
+        let stream_count = streams.partitions();
+
+        Self {
+            in_progress: BatchBuilder::new(schema, stream_count, batch_size),
+            streams,
+            tracking_metrics,
+            aborted: false,
+            cursors: (0..stream_count).map(|_| None).collect(),
+            loser_tree: vec![],
+            loser_tree_adjusted: false,
+            batch_size,
+        }
+    }
+
+    /// If the stream at the given index is not exhausted, and the last cursor for the
+    /// stream is finished, poll the stream for the next RecordBatch and create a new
+    /// cursor for the stream from the returned result
+    fn maybe_poll_stream(
+        &mut self,
+        cx: &mut Context<'_>,
+        idx: usize,
+    ) -> Poll<Result<()>> {
+        if self.cursors[idx]
+            .as_ref()
+            .map(|cursor| !cursor.is_finished())
+            .unwrap_or(false)
+        {
+            // Cursor is not finished - don't need a new RecordBatch yet
+            return Poll::Ready(Ok(()));
+        }
+
+        match futures::ready!(self.streams.poll_next(cx, idx)) {
+            None => Poll::Ready(Ok(())),
+            Some(Err(e)) => Poll::Ready(Err(e)),
+            Some(Ok((cursor, batch))) => {
+                self.cursors[idx] = Some(cursor);
+                self.in_progress.push_batch(idx, batch);
+                Poll::Ready(Ok(()))
+            }
+        }
+    }
+
+    fn poll_next_inner(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        if self.aborted {
+            return Poll::Ready(None);
+        }
+        // try to initialize the loser tree
+        if self.loser_tree.is_empty() {
+            // Ensure all non-exhausted streams have a cursor from which
+            // rows can be pulled
+            for i in 0..self.streams.partitions() {
+                if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
+                    self.aborted = true;
+                    return Poll::Ready(Some(Err(e)));
+                }
+            }
+            self.init_loser_tree();
+        }
+
+        // NB timer records time taken on drop, so there are no
+        // calls to `timer.done()` below.
+        let elapsed_compute = self.tracking_metrics.elapsed_compute().clone();
+        let _timer = elapsed_compute.timer();
+
+        loop {
+            // Adjust the loser tree if necessary, returning control if needed
+            if !self.loser_tree_adjusted {
+                let winner = self.loser_tree[0];
+                if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
+                    self.aborted = true;
+                    return Poll::Ready(Some(Err(e)));
+                }
+                self.update_loser_tree();
+            }
+
+            let stream_idx = self.loser_tree[0];
+            let cursor = self.cursors[stream_idx].as_mut();
+            if let Some(row_idx) = cursor.and_then(|c| c.advance()) {
+                self.loser_tree_adjusted = false;
+                self.in_progress.push_row(stream_idx, row_idx);
+                if self.in_progress.len() < self.batch_size {
+                    continue;
+                }
+            }
+
+            return Poll::Ready(self.in_progress.build_record_batch().transpose());
+        }
+    }
+
+    /// Returns `true` if the cursor at index `a` is greater than the cursor at index `b`
+    #[inline]
+    fn is_gt(&self, a: usize, b: usize) -> bool {
+        match (&self.cursors[a], &self.cursors[b]) {
+            (None, _) => true,
+            (_, None) => false,
+            (Some(a), Some(b)) => b < a,
+        }
+    }
+
+    /// Attempts to initialize the loser tree with one value from each
+    /// non-exhausted input, if possible
+    fn init_loser_tree(&mut self) {
+        // Init loser tree
+        self.loser_tree = vec![usize::MAX; self.cursors.len()];
+        for i in 0..self.cursors.len() {
+            let mut winner = i;
+            let mut cmp_node = (self.cursors.len() + i) / 2;
+            while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
+                let challenger = self.loser_tree[cmp_node];
+                if self.is_gt(winner, challenger) {
+                    self.loser_tree[cmp_node] = winner;
+                    winner = challenger;
+                }
+
+                cmp_node /= 2;
+            }
+            self.loser_tree[cmp_node] = winner;
+        }
+        self.loser_tree_adjusted = true;
+    }
+
+    /// Attempts to update the loser tree, if possible
+    fn update_loser_tree(&mut self) {
+        let mut winner = self.loser_tree[0];
+        // Replace overall winner by walking tree of losers
+        let mut cmp_node = (self.cursors.len() + winner) / 2;
+        while cmp_node != 0 {
+            let challenger = self.loser_tree[cmp_node];
+            if self.is_gt(winner, challenger) {
+                self.loser_tree[cmp_node] = winner;
+                winner = challenger;
+            }
+            cmp_node /= 2;
+        }
+        self.loser_tree[0] = winner;
+        self.loser_tree_adjusted = true;
+    }
+}
+
+impl<C: Cursor + Unpin> Stream for SortPreservingMergeStream<C> {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let poll = self.poll_next_inner(cx);
+        self.tracking_metrics.record_poll(poll)
+    }
+}
+
+impl<C: Cursor + Unpin> RecordBatchStream for SortPreservingMergeStream<C> {
+    fn schema(&self) -> SchemaRef {
+        self.in_progress.schema().clone()
+    }
+}
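The loser tree in `init_loser_tree` and `update_loser_tree` is independent of cursors and record batches. The following self-contained sketch drives the same initialization and update-path walk over plain integer streams, with exhausted inputs sorting last exactly as in `is_gt` (the `LoserTree` type here is illustrative, not part of this change):

// Node 0 holds the overall winner; nodes 1..n hold the losers of the
// matches on the path from each leaf to the root.
struct LoserTree {
    heads: Vec<Option<u64>>, // current head of each stream (None = exhausted)
    inputs: Vec<Vec<u64>>,   // remaining values per stream
    tree: Vec<usize>,
}

impl LoserTree {
    fn new(mut inputs: Vec<Vec<u64>>) -> Self {
        // reverse so Vec::pop yields each stream's values in ascending order
        for v in &mut inputs {
            v.reverse()
        }
        let heads: Vec<_> = inputs.iter_mut().map(|v| v.pop()).collect();
        let n = heads.len();
        let mut this = Self { heads, inputs, tree: vec![usize::MAX; n] };
        // mirrors init_loser_tree: play each new entrant up the tree
        for i in 0..n {
            let mut winner = i;
            let mut cmp_node = (n + i) / 2;
            while cmp_node != 0 && this.tree[cmp_node] != usize::MAX {
                let challenger = this.tree[cmp_node];
                if this.is_gt(winner, challenger) {
                    this.tree[cmp_node] = winner;
                    winner = challenger;
                }
                cmp_node /= 2;
            }
            this.tree[cmp_node] = winner;
        }
        this
    }

    // `None` sorts last, mirroring is_gt over Option<Cursor>
    fn is_gt(&self, a: usize, b: usize) -> bool {
        match (&self.heads[a], &self.heads[b]) {
            (None, _) => true,
            (_, None) => false,
            (Some(a), Some(b)) => b < a,
        }
    }

    fn pop(&mut self) -> Option<u64> {
        // mirrors update_loser_tree: after taking the winner's value,
        // only the path from the winner's leaf to the root is revisited
        let mut winner = self.tree[0];
        let value = self.heads[winner].take()?;
        self.heads[winner] = self.inputs[winner].pop();
        let n = self.heads.len();
        let mut cmp_node = (n + winner) / 2;
        while cmp_node != 0 {
            let challenger = self.tree[cmp_node];
            if self.is_gt(winner, challenger) {
                self.tree[cmp_node] = winner;
                winner = challenger;
            }
            cmp_node /= 2;
        }
        self.tree[0] = winner;
        Some(value)
    }
}

fn main() {
    let mut t = LoserTree::new(vec![vec![1, 4, 7], vec![2, 5, 8], vec![3, 6, 9]]);
    let mut out = vec![];
    while let Some(v) = t.pop() {
        out.push(v)
    }
    assert_eq!(out, (1..=9).collect::<Vec<_>>());
}

Each pop performs roughly log2(S) comparisons for S streams, which is the property the doc comment above refers to.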
diff --git a/datafusion/core/src/physical_plan/sorts/mod.rs b/datafusion/core/src/physical_plan/sorts/mod.rs
index db6ab5c604e2b..cd5dae27dcc7e 100644
--- a/datafusion/core/src/physical_plan/sorts/mod.rs
+++ b/datafusion/core/src/physical_plan/sorts/mod.rs
@@ -17,30 +17,14 @@
 
 //! Sort functionalities
 
-use crate::physical_plan::SendableRecordBatchStream;
-use std::fmt::{Debug, Formatter};
-
+mod builder;
 mod cursor;
 mod index;
+mod merge;
 pub mod sort;
 pub mod sort_preserving_merge;
+mod stream;
 
 pub use cursor::SortKeyCursor;
 pub use index::RowIndex;
-
-pub(crate) struct SortedStream {
-    stream: SendableRecordBatchStream,
-    mem_used: usize,
-}
-
-impl Debug for SortedStream {
-    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
-        write!(f, "InMemSorterStream")
-    }
-}
-
-impl SortedStream {
-    pub(crate) fn new(stream: SendableRecordBatchStream, mem_used: usize) -> Self {
-        Self { stream, mem_used }
-    }
-}
+pub(crate) use merge::streaming_merge;
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index c3fc06206ca15..e428b77511b2d 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -30,8 +30,7 @@ use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{
     BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet,
 };
-use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream;
-use crate::physical_plan::sorts::SortedStream;
+use crate::physical_plan::sorts::merge::streaming_merge;
 use crate::physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter};
 use crate::physical_plan::{
     DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, Partitioning,
@@ -169,37 +168,39 @@ impl ExternalSorter {
         let batch_size = self.session_config.batch_size();
 
         if self.spilled_before() {
-            let tracking_metrics = self
+            let intermediate_metrics = self
                 .metrics_set
                 .new_intermediate_tracking(self.partition_id, &self.runtime.memory_pool);
 
-            let mut streams: Vec<SortedStream> = vec![];
+            let mut merge_metrics = self
+                .metrics_set
+                .new_final_tracking(self.partition_id, &self.runtime.memory_pool);
+
+            let mut streams = vec![];
             if !self.in_mem_batches.is_empty() {
                 let in_mem_stream = in_mem_partial_sort(
                     &mut self.in_mem_batches,
                     self.schema.clone(),
                     &self.expr,
                     batch_size,
-                    tracking_metrics,
+                    intermediate_metrics,
                     self.fetch,
                 )?;
-                let prev_used = self.reservation.free();
-                streams.push(SortedStream::new(in_mem_stream, prev_used));
+                merge_metrics.init_mem_used(self.reservation.free());
+                streams.push(in_mem_stream);
             }
 
             for spill in self.spills.drain(..) {
                 let stream = read_spill_as_stream(spill, self.schema.clone())?;
-                streams.push(SortedStream::new(stream, 0));
+                streams.push(stream);
            }
-            let tracking_metrics = self
-                .metrics_set
-                .new_final_tracking(self.partition_id, &self.runtime.memory_pool);
-            Ok(Box::pin(SortPreservingMergeStream::new_from_streams(
+
+            streaming_merge(
                 streams,
                 self.schema.clone(),
                 &self.expr,
-                tracking_metrics,
+                merge_metrics,
                 self.session_config.batch_size(),
-            )?))
+            )
         } else if !self.in_mem_batches.is_empty() {
             let tracking_metrics = self
                 .metrics_set
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 14204ef3b4b55..b22199c67bb77 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -18,19 +18,10 @@
 //! Defines the sort preserving merge plan
 
 use std::any::Any;
-use std::collections::VecDeque;
-use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
 
-use arrow::row::{RowConverter, SortField};
-use arrow::{
-    array::{make_array as make_arrow_array, MutableArrayData},
-    datatypes::SchemaRef,
-    record_batch::RecordBatch,
-};
-use futures::stream::{Fuse, FusedStream};
-use futures::{ready, Stream, StreamExt};
+use arrow::array::Array;
+use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
 use log::debug;
 use tokio::sync::mpsc;
@@ -39,12 +30,11 @@ use crate::execution::context::TaskContext;
 use crate::physical_plan::metrics::{
     ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet,
 };
-use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream};
+use crate::physical_plan::sorts::streaming_merge;
 use crate::physical_plan::stream::RecordBatchReceiverStream;
 use crate::physical_plan::{
     common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType,
-    Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
-    SendableRecordBatchStream, Statistics,
+    Distribution, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
 };
 use datafusion_physical_expr::{
     make_sort_requirements_from_exprs, EquivalenceProperties, PhysicalSortRequirement,
@@ -206,34 +196,27 @@ impl ExecutionPlan for SortPreservingMergeExec {
                             context.clone(),
                         );
 
-                        SortedStream::new(
-                            RecordBatchReceiverStream::create(
-                                &schema,
-                                receiver,
-                                join_handle,
-                            ),
-                            0,
+                        RecordBatchReceiverStream::create(
+                            &schema,
+                            receiver,
+                            join_handle,
                         )
                     })
                     .collect(),
                 Err(_) => (0..input_partitions)
-                    .map(|partition| {
-                        let stream =
-                            self.input.execute(partition, context.clone())?;
-                        Ok(SortedStream::new(stream, 0))
-                    })
+                    .map(|partition| self.input.execute(partition, context.clone()))
                    .collect::<Result<_>>()?,
             };
 
             debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");
 
-            let result = Box::pin(SortPreservingMergeStream::new_from_streams(
+            let result = streaming_merge(
                 receivers,
                 schema,
                 &self.expr,
                 tracking_metrics,
                 context.session_config().batch_size(),
-            )?);
+            )?;
 
             debug!("Got stream result from SortPreservingMergeStream::new_from_receivers");
@@ -264,445 +247,6 @@ impl ExecutionPlan for SortPreservingMergeExec {
     }
 }
 
-struct MergingStreams {
-    /// The sorted input streams to merge together
-    streams: Vec<Fuse<SendableRecordBatchStream>>,
-    /// number of streams
-    num_streams: usize,
-}
-
-impl std::fmt::Debug for MergingStreams {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("MergingStreams")
-            .field("num_streams", &self.num_streams)
-            .finish()
-    }
-}
-
-impl MergingStreams {
-    fn new(input_streams: Vec<Fuse<SendableRecordBatchStream>>) -> Self {
-        Self {
-            num_streams: input_streams.len(),
-            streams: input_streams,
-        }
-    }
-
-    fn num_streams(&self) -> usize {
-        self.num_streams
-    }
-}
-
-#[derive(Debug)]
-pub(crate) struct SortPreservingMergeStream {
-    /// The schema of the RecordBatches yielded by this stream
-    schema: SchemaRef,
-
-    /// The sorted input streams to merge together
-    streams: MergingStreams,
-
-    /// For each input stream maintain a dequeue of RecordBatches
-    ///
-    /// Exhausted batches will be popped off the front once all
-    /// their rows have been yielded to the output
-    batches: Vec<VecDeque<RecordBatch>>,
-
-    /// The accumulated row indexes for the next record batch
-    in_progress: Vec<RowIndex>,
-
-    /// The physical expressions to sort by
-    column_expressions: Vec<Arc<dyn PhysicalExpr>>,
-
-    /// used to record execution metrics
-    tracking_metrics: MemTrackingMetrics,
-
-    /// If the stream has encountered an error
-    aborted: bool,
-
-    /// Vector that holds all [`SortKeyCursor`]s
-    cursors: Vec<Option<SortKeyCursor>>,
-
-    /// A loser tree that always produces the minimum cursor
-    ///
-    /// Node 0 stores the top winner, Nodes 1..num_streams store
-    /// the loser nodes
-    ///
-    /// This implements a "Tournament Tree" (aka Loser Tree) to keep
-    /// track of the current smallest element at the top. When the top
-    /// record is taken, the tree structure is not modified, and only
-    /// the path from bottom to top is visited, keeping the number of
-    /// comparisons close to the theoretical limit of `log(S)`.
-    ///
-    /// reference: <https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
-    loser_tree: Vec<usize>,
-
-    /// If the most recently yielded overall winner has been replaced
-    /// within the loser tree. A value of `false` indicates that the
-    /// overall winner has been yielded but the loser tree has not
-    /// been updated
-    loser_tree_adjusted: bool,
-
-    /// target batch size
-    batch_size: usize,
-
-    /// row converter
-    row_converter: RowConverter,
-}
-
-impl SortPreservingMergeStream {
-    pub(crate) fn new_from_streams(
-        streams: Vec<SortedStream>,
-        schema: SchemaRef,
-        expressions: &[PhysicalSortExpr],
-        mut tracking_metrics: MemTrackingMetrics,
-        batch_size: usize,
-    ) -> Result<Self> {
-        let stream_count = streams.len();
-        let batches = (0..stream_count).map(|_| VecDeque::new()).collect();
-        tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum());
-        let wrappers = streams.into_iter().map(|s| s.stream.fuse()).collect();
-
-        let sort_fields = expressions
-            .iter()
-            .map(|expr| {
-                let data_type = expr.expr.data_type(&schema)?;
-                Ok(SortField::new_with_options(data_type, expr.options))
-            })
-            .collect::<Result<Vec<_>>>()?;
-        let row_converter = RowConverter::new(sort_fields)?;
-
-        Ok(Self {
-            schema,
-            batches,
-            streams: MergingStreams::new(wrappers),
-            column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
-            tracking_metrics,
-            aborted: false,
-            in_progress: vec![],
-            cursors: (0..stream_count).map(|_| None).collect(),
-            loser_tree: Vec::with_capacity(stream_count),
-            loser_tree_adjusted: false,
-            batch_size,
-            row_converter,
-        })
-    }
-
-    /// If the stream at the given index is not exhausted, and the last cursor for the
-    /// stream is finished, poll the stream for the next RecordBatch and create a new
-    /// cursor for the stream from the returned result
-    fn maybe_poll_stream(
-        &mut self,
-        cx: &mut Context<'_>,
-        idx: usize,
-    ) -> Poll<Result<()>> {
-        if self.cursors[idx]
-            .as_ref()
-            .map(|cursor| !cursor.is_finished())
-            .unwrap_or(false)
-        {
-            // Cursor is not finished - don't need a new RecordBatch yet
-            return Poll::Ready(Ok(()));
-        }
-        let mut empty_batch = false;
-        {
-            let stream = &mut self.streams.streams[idx];
-            if stream.is_terminated() {
-                return Poll::Ready(Ok(()));
-            }
-
-            // Fetch a new input record and create a cursor from it
-            match futures::ready!(stream.poll_next_unpin(cx)) {
-                None => return Poll::Ready(Ok(())),
-                Some(Err(e)) => {
-                    return Poll::Ready(Err(e));
-                }
-                Some(Ok(batch)) => {
-                    if batch.num_rows() > 0 {
-                        let cols = self
-                            .column_expressions
-                            .iter()
-                            .map(|expr| {
-                                Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))
-                            })
-                            .collect::<Result<Vec<_>>>()?;
-
-                        let rows = match self.row_converter.convert_columns(&cols) {
-                            Ok(rows) => rows,
-                            Err(e) => {
-                                return Poll::Ready(Err(DataFusionError::ArrowError(e)));
-                            }
-                        };
-
-                        self.cursors[idx] = Some(SortKeyCursor::new(idx, rows));
-                        self.batches[idx].push_back(batch)
-                    } else {
-                        empty_batch = true;
-                    }
-                }
-            }
-        }
-
-        if empty_batch {
-            self.maybe_poll_stream(cx, idx)
-        } else {
-            Poll::Ready(Ok(()))
-        }
-    }
-
-    /// Drains the in_progress row indexes, and builds a new RecordBatch from them
-    ///
-    /// Will then drop any batches for which all rows have been yielded to the output
-    fn build_record_batch(&mut self) -> Result<RecordBatch> {
-        // Mapping from stream index to the index of the first buffer from that stream
-        let mut buffer_idx = 0;
-        let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len());
-
-        for batches in &self.batches {
-            stream_to_buffer_idx.push(buffer_idx);
-            buffer_idx += batches.len();
-        }
-
-        let columns = self
-            .schema
-            .fields()
-            .iter()
-            .enumerate()
-            .map(|(column_idx, field)| {
-                let arrays = self
-                    .batches
-                    .iter()
-                    .flat_map(|batch| {
-                        batch.iter().map(|batch| batch.column(column_idx).data())
-                    })
-                    .collect();
-
-                let mut array_data = MutableArrayData::new(
-                    arrays,
-                    field.is_nullable(),
-                    self.in_progress.len(),
-                );
-
-                if self.in_progress.is_empty() {
-                    return make_arrow_array(array_data.freeze());
-                }
-
-                let first = &self.in_progress[0];
-                let mut buffer_idx =
-                    stream_to_buffer_idx[first.stream_idx] + first.batch_idx;
-                let mut start_row_idx = first.row_idx;
-                let mut end_row_idx = start_row_idx + 1;
-
-                for row_index in self.in_progress.iter().skip(1) {
-                    let next_buffer_idx =
-                        stream_to_buffer_idx[row_index.stream_idx] + row_index.batch_idx;
-
-                    if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx {
-                        // subsequent row in same batch
-                        end_row_idx += 1;
-                        continue;
-                    }
-
-                    // emit current batch of rows for current buffer
-                    array_data.extend(buffer_idx, start_row_idx, end_row_idx);
-
-                    // start new batch of rows
-                    buffer_idx = next_buffer_idx;
-                    start_row_idx = row_index.row_idx;
-                    end_row_idx = start_row_idx + 1;
-                }
-
-                // emit final batch of rows
-                array_data.extend(buffer_idx, start_row_idx, end_row_idx);
-                make_arrow_array(array_data.freeze())
-            })
-            .collect();
-
-        self.in_progress.clear();
-
-        // New cursors are only created once the previous cursor for the stream
-        // is finished. This means all remaining rows from all but the last batch
-        // for each stream have been yielded to the newly created record batch
-        //
-        // Additionally as `in_progress` has been drained, there are no longer
-        // any RowIndex's reliant on the batch indexes
-        //
-        // We can therefore drop all but the last batch for each stream
-        for batches in &mut self.batches {
-            if batches.len() > 1 {
-                // Drain all but the last batch
-                batches.drain(0..(batches.len() - 1));
-            }
-        }
-
-        RecordBatch::try_new(self.schema.clone(), columns).map_err(Into::into)
-    }
-}
-
-impl Stream for SortPreservingMergeStream {
-    type Item = Result<RecordBatch>;
-
-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        let poll = self.poll_next_inner(cx);
-        self.tracking_metrics.record_poll(poll)
-    }
-}
-
-impl SortPreservingMergeStream {
-    #[inline]
-    fn poll_next_inner(
-        self: &mut Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Result<RecordBatch>>> {
-        if self.aborted {
-            return Poll::Ready(None);
-        }
-        // try to initialize the loser tree
-        if let Err(e) = ready!(self.init_loser_tree(cx)) {
-            return Poll::Ready(Some(Err(e)));
-        }
-
-        // NB timer records time taken on drop, so there are no
-        // calls to `timer.done()` below.
-        let elapsed_compute = self.tracking_metrics.elapsed_compute().clone();
-        let _timer = elapsed_compute.timer();
-
-        loop {
-            // Adjust the loser tree if necessary, returning control if needed
-            if let Err(e) = ready!(self.update_loser_tree(cx)) {
-                return Poll::Ready(Some(Err(e)));
-            }
-
-            let min_cursor_idx = self.loser_tree[0];
-            let next = self.cursors[min_cursor_idx]
-                .as_mut()
-                .filter(|cursor| !cursor.is_finished())
-                .map(|cursor| (cursor.stream_idx(), cursor.advance()));
-
-            if let Some((stream_idx, row_idx)) = next {
-                self.loser_tree_adjusted = false;
-                let batch_idx = self.batches[stream_idx].len() - 1;
-                self.in_progress.push(RowIndex {
-                    stream_idx,
-                    batch_idx,
-                    row_idx,
-                });
-                if self.in_progress.len() == self.batch_size {
-                    return Poll::Ready(Some(self.build_record_batch()));
-                }
-            } else if !self.in_progress.is_empty() {
-                return Poll::Ready(Some(self.build_record_batch()));
-            } else {
-                return Poll::Ready(None);
-            }
-        }
-    }
-
-    /// Attempts to initialize the loser tree with one value from each
-    /// non exhausted input, if possible.
-    ///
-    /// Returns
-    /// * Poll::Pending when more data is needed
-    /// * Poll::Ready(Ok()) on success
-    /// * Poll::Ready(Err..) if any of the inputs errored
-    #[inline]
-    fn init_loser_tree(
-        self: &mut Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<()>> {
-        let num_streams = self.streams.num_streams();
-
-        if !self.loser_tree.is_empty() {
-            return Poll::Ready(Ok(()));
-        }
-
-        // Ensure all non-exhausted streams have a cursor from which
-        // rows can be pulled
-        for i in 0..num_streams {
-            if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
-                self.aborted = true;
-                return Poll::Ready(Err(e));
-            }
-        }
-
-        // Init loser tree
-        self.loser_tree.resize(num_streams, usize::MAX);
-        for i in 0..num_streams {
-            let mut winner = i;
-            let mut cmp_node = (num_streams + i) / 2;
-            while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
-                let challenger = self.loser_tree[cmp_node];
-                let challenger_win =
-                    match (&self.cursors[winner], &self.cursors[challenger]) {
-                        (None, _) => true,
-                        (_, None) => false,
-                        (Some(winner), Some(challenger)) => challenger < winner,
-                    };
-
-                if challenger_win {
-                    self.loser_tree[cmp_node] = winner;
-                    winner = challenger;
-                }
-
-                cmp_node /= 2;
-            }
-            self.loser_tree[cmp_node] = winner;
-        }
-        self.loser_tree_adjusted = true;
-        Poll::Ready(Ok(()))
-    }
-
-    /// Attempts to updated the loser tree, if possible
-    ///
-    /// Returns
-    /// * Poll::Pending when the winning unput was not ready
-    /// * Poll::Ready(Ok()) on success
-    /// * Poll::Ready(Err..) if any of the winning input erroed
-    #[inline]
-    fn update_loser_tree(
-        self: &mut Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<()>> {
-        if self.loser_tree_adjusted {
-            return Poll::Ready(Ok(()));
-        }
-
-        let num_streams = self.streams.num_streams();
-        let mut winner = self.loser_tree[0];
-        if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
-            self.aborted = true;
-            return Poll::Ready(Err(e));
-        }
-
-        // Replace overall winner by walking tree of losers
-        let mut cmp_node = (num_streams + winner) / 2;
-        while cmp_node != 0 {
-            let challenger = self.loser_tree[cmp_node];
-            let challenger_win = match (&self.cursors[winner], &self.cursors[challenger])
-            {
-                (None, _) => true,
-                (_, None) => false,
-                (Some(winner), Some(challenger)) => challenger < winner,
-            };
-            if challenger_win {
-                self.loser_tree[cmp_node] = winner;
-                winner = challenger;
-            }
-            cmp_node /= 2;
-        }
-        self.loser_tree[0] = winner;
-        self.loser_tree_adjusted = true;
-        Poll::Ready(Ok(()))
-    }
-}
-
-impl RecordBatchStream for SortPreservingMergeStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::iter::FromIterator;
@@ -1284,9 +828,10 @@
             }
         });
 
-        streams.push(SortedStream::new(
-            RecordBatchReceiverStream::create(&schema, receiver, join_handle),
-            0,
+        streams.push(RecordBatchReceiverStream::create(
+            &schema,
+            receiver,
+            join_handle,
         ));
     }
 
@@ -1294,7 +839,7 @@
         let tracking_metrics =
             MemTrackingMetrics::new(&metrics, task_ctx.memory_pool(), 0);
 
-        let merge_stream = SortPreservingMergeStream::new_from_streams(
+        let merge_stream = streaming_merge(
             streams,
             batches.schema(),
             sort.as_slice(),
@@ -1303,7 +848,7 @@
         )
         .unwrap();
 
-        let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap();
+        let mut merged = common::collect(merge_stream).await.unwrap();
 
         assert_eq!(merged.len(), 1);
         let merged = merged.remove(0);
diff --git a/datafusion/core/src/physical_plan/sorts/stream.rs b/datafusion/core/src/physical_plan/sorts/stream.rs
new file mode 100644
index 0000000000000..5c4e989c19ffa
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/stream.rs
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::sorts::cursor::SortKeyCursor;
+use crate::physical_plan::SendableRecordBatchStream;
+use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr};
+use arrow::datatypes::Schema;
+use arrow::record_batch::RecordBatch;
+use arrow::row::{RowConverter, SortField};
+use futures::stream::{Fuse, StreamExt};
+use std::sync::Arc;
+use std::task::{ready, Context, Poll};
+
+/// A [`Stream`](futures::Stream) that has multiple partitions that can
+/// be polled separately but not concurrently
+pub trait PartitionedStream: std::fmt::Debug + Send {
+    type Output;
+
+    /// Returns the number of partitions
+    fn partitions(&self) -> usize;
+
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Self::Output>>;
+}
+
+/// A newtype wrapper around a set of fused [`SendableRecordBatchStream`]
+/// that implements debug, and skips over empty [`RecordBatch`]
+struct FusedStreams(Vec<Fuse<SendableRecordBatchStream>>);
+
+impl std::fmt::Debug for FusedStreams {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("FusedStreams")
+            .field("num_streams", &self.0.len())
+            .finish()
+    }
+}
+
+impl FusedStreams {
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            match ready!(self.0[stream_idx].poll_next_unpin(cx)) {
+                Some(Ok(b)) if b.num_rows() == 0 => continue,
+                r => return Poll::Ready(r),
+            }
+        }
+    }
+}
+
+/// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`]
+/// and computes [`SortKeyCursor`] based on the provided [`PhysicalSortExpr`]
+#[derive(Debug)]
+pub(crate) struct SortKeyCursorStream {
+    /// Converter to convert output of physical expressions
+    converter: RowConverter,
+    /// The physical expressions to sort by
+    column_expressions: Vec<Arc<dyn PhysicalExpr>>,
+    /// Input streams
+    streams: FusedStreams,
+}
+
+impl SortKeyCursorStream {
+    pub(crate) fn try_new(
+        schema: &Schema,
+        expressions: &[PhysicalSortExpr],
+        streams: Vec<SendableRecordBatchStream>,
+    ) -> Result<Self> {
+        let sort_fields = expressions
+            .iter()
+            .map(|expr| {
+                let data_type = expr.expr.data_type(schema)?;
+                Ok(SortField::new_with_options(data_type, expr.options))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let streams = streams.into_iter().map(|s| s.fuse()).collect();
+        let converter = RowConverter::new(sort_fields)?;
+        Ok(Self {
+            converter,
+            column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
+            streams: FusedStreams(streams),
+        })
+    }
+
+    fn convert_batch(
+        &mut self,
+        batch: &RecordBatch,
+        stream_idx: usize,
+    ) -> Result<SortKeyCursor> {
+        let cols = self
+            .column_expressions
+            .iter()
+            .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows())))
+            .collect::<Result<Vec<_>>>()?;
+
+        let rows = self.converter.convert_columns(&cols)?;
+        Ok(SortKeyCursor::new(stream_idx, rows))
+    }
+}
+
+impl PartitionedStream for SortKeyCursorStream {
+    type Output = Result<(SortKeyCursor, RecordBatch)>;
+
+    fn partitions(&self) -> usize {
+        self.streams.0.len()
+    }
+
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Self::Output>> {
+        let r = match ready!(self.streams.poll_next(cx, stream_idx)) {
+            Some(r) => Some(r.and_then(|batch| {
+                let cursor = self.convert_batch(&batch, stream_idx)?;
+                Ok((cursor, batch))
+            })),
+            None => None,
+        };
+        Poll::Ready(r)
+    }
+}
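The `PartitionedStream` contract can be exercised in isolation: one logical stream exposes several partitions that are polled separately, and only the polled partition's state is touched. A minimal sketch over always-ready in-memory queues (`VecPartitions` and the use of `futures::task::noop_waker` are illustrative assumptions, not part of this change; a real implementation would register `cx.waker()` and return `Poll::Pending` when data is not yet available):

use futures::task::noop_waker;
use std::collections::VecDeque;
use std::task::{Context, Poll};

// Local copy of the trait shape from stream.rs for a standalone demo
trait PartitionedStream: std::fmt::Debug + Send {
    type Output;
    fn partitions(&self) -> usize;
    fn poll_next(
        &mut self,
        cx: &mut Context<'_>,
        stream_idx: usize,
    ) -> Poll<Option<Self::Output>>;
}

#[derive(Debug)]
struct VecPartitions(Vec<VecDeque<u64>>);

impl PartitionedStream for VecPartitions {
    type Output = u64;

    fn partitions(&self) -> usize {
        self.0.len()
    }

    // In-memory data is always ready, so this never returns Poll::Pending
    fn poll_next(
        &mut self,
        _cx: &mut Context<'_>,
        stream_idx: usize,
    ) -> Poll<Option<u64>> {
        Poll::Ready(self.0[stream_idx].pop_front())
    }
}

fn main() {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);

    let mut stream =
        VecPartitions(vec![VecDeque::from([1, 3]), VecDeque::from([2, 4])]);
    assert_eq!(stream.partitions(), 2);

    // Partitions are polled independently, as the loser tree in merge.rs does
    assert_eq!(stream.poll_next(&mut cx, 1), Poll::Ready(Some(2)));
    assert_eq!(stream.poll_next(&mut cx, 0), Poll::Ready(Some(1)));
}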