From fa2bb6c4e80d80e3d1d26ce85f7c21232f036dd5 Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com>
Date: Mon, 16 Oct 2023 10:02:05 +0100
Subject: [PATCH] Extract ReceiverStreamBuilder (#7817)

* Extract ReceiverStreamBuilder

* Docs and format

* Update datafusion/physical-plan/src/stream.rs

* fmt

* Undo changes to testing pin

---------

Co-authored-by: Andrew Lamb
---
 datafusion/physical-plan/src/stream.rs | 232 ++++++++++++++-----------
 1 file changed, 132 insertions(+), 100 deletions(-)
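
Illustrative usage sketch (not part of the applied diff): the snippet below shows
one way the builder API refactored by this patch might be driven end to end, going
through the public `RecordBatchReceiverStream::builder` wrapper rather than the
crate-private `ReceiverStreamBuilder` it now delegates to. The crate paths, the
single-batch producer, and the `main` scaffolding are assumptions made for the
example, not code from this commit.

use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_physical_plan::stream::RecordBatchReceiverStream;
use futures::StreamExt;

#[tokio::main]
async fn main() -> datafusion_common::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // Buffer at most two in-flight batches between the producer and the consumer.
    let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 2);

    // Each producer task gets its own sender; an error returned (or a panic raised)
    // inside the task surfaces as an item on the stream instead of being lost.
    let tx = builder.tx();
    builder.spawn(async move {
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // A send error only means the receiving stream was dropped.
        tx.send(Ok(batch)).await.ok();
        Ok(())
    });

    // Dropping this stream aborts any producer tasks that are still running.
    let mut stream = builder.build();
    while let Some(batch) = stream.next().await {
        println!("{} rows", batch?.num_rows());
    }
    Ok(())
}

The same shape applies to the extracted `ReceiverStreamBuilder` itself; it simply
drops the schema handling and works for any output type, not only `RecordBatch`.
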
diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs
index a3fb856c326d..fdf32620ca50 100644
--- a/datafusion/physical-plan/src/stream.rs
+++ b/datafusion/physical-plan/src/stream.rs
@@ -38,6 +38,124 @@ use tokio::task::JoinSet;
 use super::metrics::BaselineMetrics;
 use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 
+/// Creates a stream from a collection of producing tasks, routing panics to the stream.
+///
+/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
+///
+/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
+///
+/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
+///
+/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
+///
+/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
+
+pub(crate) struct ReceiverStreamBuilder<O> {
+    tx: Sender<Result<O>>,
+    rx: Receiver<Result<O>>,
+    join_set: JoinSet<Result<()>>,
+}
+
+impl<O: Send + 'static> ReceiverStreamBuilder<O> {
+    /// create new channels with the specified buffer size
+    pub fn new(capacity: usize) -> Self {
+        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
+
+        Self {
+            tx,
+            rx,
+            join_set: JoinSet::new(),
+        }
+    }
+
+    /// Get a handle for sending data to the output
+    pub fn tx(&self) -> Sender<Result<O>> {
+        self.tx.clone()
+    }
+
+    /// Spawn task that will be aborted if this builder (or the stream
+    /// built from it) are dropped
+    pub fn spawn<F>(&mut self, task: F)
+    where
+        F: Future<Output = Result<()>>,
+        F: Send + 'static,
+    {
+        self.join_set.spawn(task);
+    }
+
+    /// Spawn a blocking task that will be aborted if this builder (or the stream
+    /// built from it) are dropped
+    ///
+    /// this is often used to spawn tasks that write to the sender
+    /// retrieved from `Self::tx`
+    pub fn spawn_blocking<F>(&mut self, f: F)
+    where
+        F: FnOnce() -> Result<()>,
+        F: Send + 'static,
+    {
+        self.join_set.spawn_blocking(f);
+    }
+
+    /// Create a stream of all data written to `tx`
+    pub fn build(self) -> BoxStream<'static, Result<O>> {
+        let Self {
+            tx,
+            rx,
+            mut join_set,
+        } = self;
+
+        // don't need tx
+        drop(tx);
+
+        // future that checks the result of the join set, and propagates panic if seen
+        let check = async move {
+            while let Some(result) = join_set.join_next().await {
+                match result {
+                    Ok(task_result) => {
+                        match task_result {
+                            // nothing to report
+                            Ok(_) => continue,
+                            // This means a blocking task error
+                            Err(e) => {
+                                return Some(exec_err!("Spawned Task error: {e}"));
+                            }
+                        }
+                    }
+                    // This means a tokio task error, likely a panic
+                    Err(e) => {
+                        if e.is_panic() {
+                            // resume on the main thread
+                            std::panic::resume_unwind(e.into_panic());
+                        } else {
+                            // This should only occur if the task is
+                            // cancelled, which would only occur if
+                            // the JoinSet were aborted, which in turn
+                            // would imply that the receiver has been
+                            // dropped and this code is not running
+                            return Some(internal_err!("Non Panic Task error: {e}"));
+                        }
+                    }
+                }
+            }
+            None
+        };
+
+        let check_stream = futures::stream::once(check)
+            // unwrap Option / only return the error
+            .filter_map(|item| async move { item });
+
+        // Convert the receiver into a stream
+        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
+            let next_item = rx.recv().await;
+            next_item.map(|next_item| (next_item, rx))
+        });
+
+        // Merge the streams together so whichever is ready first
+        // produces the batch
+        futures::stream::select(rx_stream, check_stream).boxed()
+    }
+}
+
 /// Builder for [`RecordBatchReceiverStream`] that propagates errors
 /// and panic's correctly.
 ///
@@ -47,28 +165,22 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 ///
 /// This also handles propagating panic`s and canceling the tasks.
 pub struct RecordBatchReceiverStreamBuilder {
-    tx: Sender<Result<RecordBatch>>,
-    rx: Receiver<Result<RecordBatch>>,
     schema: SchemaRef,
-    join_set: JoinSet<Result<()>>,
+    inner: ReceiverStreamBuilder<RecordBatch>,
 }
 
 impl RecordBatchReceiverStreamBuilder {
     /// create new channels with the specified buffer size
     pub fn new(schema: SchemaRef, capacity: usize) -> Self {
-        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
-
         Self {
-            tx,
-            rx,
             schema,
-            join_set: JoinSet::new(),
+            inner: ReceiverStreamBuilder::new(capacity),
         }
     }
 
-    /// Get a handle for sending [`RecordBatch`]es to the output
+    /// Get a handle for sending [`RecordBatch`] to the output
     pub fn tx(&self) -> Sender<Result<RecordBatch>> {
-        self.tx.clone()
+        self.inner.tx()
     }
 
     /// Spawn task that will be aborted if this builder (or the stream
@@ -81,7 +193,7 @@ impl RecordBatchReceiverStreamBuilder {
         F: Future<Output = Result<()>>,
         F: Send + 'static,
     {
-        self.join_set.spawn(task);
+        self.inner.spawn(task)
     }
 
     /// Spawn a blocking task that will be aborted if this builder (or the stream
@@ -94,7 +206,7 @@ impl RecordBatchReceiverStreamBuilder {
         F: FnOnce() -> Result<()>,
         F: Send + 'static,
     {
-        self.join_set.spawn_blocking(f);
+        self.inner.spawn_blocking(f)
     }
 
     /// runs the input_partition of the `input` ExecutionPlan on the
@@ -110,7 +222,7 @@ impl RecordBatchReceiverStreamBuilder {
     ) {
         let output = self.tx();
 
-        self.spawn(async move {
+        self.inner.spawn(async move {
             let mut stream = match input.execute(partition, context) {
                 Err(e) => {
                     // If send fails, the plan being torn down, there
@@ -155,80 +267,17 @@ impl RecordBatchReceiverStreamBuilder {
         });
     }
 
-    /// Create a stream of all `RecordBatch`es written to `tx`
+    /// Create a stream of all [`RecordBatch`] written to `tx`
     pub fn build(self) -> SendableRecordBatchStream {
-        let Self {
-            tx,
-            rx,
-            schema,
-            mut join_set,
-        } = self;
-
-        // don't need tx
-        drop(tx);
-
-        // future that checks the result of the join set, and propagates panic if seen
-        let check = async move {
-            while let Some(result) = join_set.join_next().await {
-                match result {
-                    Ok(task_result) => {
-                        match task_result {
-                            // nothing to report
-                            Ok(_) => continue,
-                            // This means a blocking task error
-                            Err(e) => {
-                                return Some(exec_err!("Spawned Task error: {e}"));
-                            }
-                        }
-                    }
-                    // This means a tokio task error, likely a panic
-                    Err(e) => {
-                        if e.is_panic() {
-                            // resume on the main thread
-                            std::panic::resume_unwind(e.into_panic());
-                        } else {
-                            // This should only occur if the task is
-                            // cancelled, which would only occur if
-                            // the JoinSet were aborted, which in turn
-                            // would imply that the receiver has been
-                            // dropped and this code is not running
-                            return Some(internal_err!("Non Panic Task error: {e}"));
-                        }
-                    }
-                }
-            }
-            None
-        };
-
-        let check_stream = futures::stream::once(check)
-            // unwrap Option / only return the error
-            .filter_map(|item| async move { item });
-
-        // Convert the receiver into a stream
-        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
-            let next_item = rx.recv().await;
-            next_item.map(|next_item| (next_item, rx))
-        });
-
-        // Merge the streams together so whichever is ready first
-        // produces the batch
-        let inner = futures::stream::select(rx_stream, check_stream).boxed();
-
-        Box::pin(RecordBatchReceiverStream { schema, inner })
+        Box::pin(RecordBatchStreamAdapter::new(
+            self.schema,
+            self.inner.build(),
+        ))
     }
 }
 
-/// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs,
-/// on new tokio Tasks, increasing the potential parallelism.
-///
-/// This structure also handles propagating panics and cancelling the
-/// underlying tasks correctly.
-///
-/// Use [`Self::builder`] to construct one.
-pub struct RecordBatchReceiverStream {
-    schema: SchemaRef,
-    inner: BoxStream<'static, Result<RecordBatch>>,
-}
+#[doc(hidden)]
+pub struct RecordBatchReceiverStream {}
 
 impl RecordBatchReceiverStream {
     /// Create a builder with an internal buffer of capacity batches.
@@ -240,23 +289,6 @@ impl RecordBatchReceiverStream {
     }
 }
 
-impl Stream for RecordBatchReceiverStream {
-    type Item = Result<RecordBatch>;
-
-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        self.inner.poll_next_unpin(cx)
-    }
-}
-
-impl RecordBatchStream for RecordBatchReceiverStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
 pin_project! {
     /// Combines a [`Stream`] with a [`SchemaRef`] implementing
    /// [`RecordBatchStream`] for the combination