Extract ReceiverStreamBuilder #7817

Merged 6 commits on Oct 16, 2023
232 changes: 132 additions & 100 deletions datafusion/physical-plan/src/stream.rs
@@ -38,6 +38,124 @@ use tokio::task::JoinSet;
use super::metrics::BaselineMetrics;
use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};

/// Creates a stream from a collection of producing tasks, routing panics to the stream.
///
/// Note that this is similar to [`ReceiverStream` from tokio-stream], with the differences being:
///
/// 1. Methods to bound and "detach" tasks (`spawn()` and `spawn_blocking()`).
///
/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
///
/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
///
/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html

pub(crate) struct ReceiverStreamBuilder<O> {
    tx: Sender<Result<O>>,
    rx: Receiver<Result<O>>,
    join_set: JoinSet<Result<()>>,
}

Contributor Author: I made this pub(crate), but it could easily be made public should we wish to do so.

impl<O: Send + 'static> ReceiverStreamBuilder<O> {
    /// create new channels with the specified buffer size
    pub fn new(capacity: usize) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        Self {
            tx,
            rx,
            join_set: JoinSet::new(),
        }
    }

    /// Get a handle for sending data to the output
    pub fn tx(&self) -> Sender<Result<O>> {
        self.tx.clone()
    }

    /// Spawn a task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    pub fn spawn<F>(&mut self, task: F)
    where
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn(task);
    }

    /// Spawn a blocking task that will be aborted if this builder (or the stream
    /// built from it) is dropped
    ///
    /// This is often used to spawn tasks that write to the sender
    /// retrieved from `Self::tx`
    pub fn spawn_blocking<F>(&mut self, f: F)
    where
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking(f);
    }

    /// Create a stream of all data written to `tx`
    pub fn build(self) -> BoxStream<'static, Result<O>> {
        let Self {
            tx,
            rx,
            mut join_set,
        } = self;

        // don't need tx
        drop(tx);

        // future that checks the result of the join set, and propagates panic if seen

Contributor Author: This is the interesting logic that needs DRYing.

        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // nothing to report
                            Ok(_) => continue,
                            // This means a blocking task error
                            Err(e) => {
                                return Some(exec_err!("Spawned Task error: {e}"));
                            }
                        }
                    }
                    // This means a tokio task error, likely a panic
                    Err(e) => {
                        if e.is_panic() {
                            // resume on the main thread
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // This should only occur if the task is
                            // cancelled, which would only occur if
                            // the JoinSet were aborted, which in turn
                            // would imply that the receiver has been
                            // dropped and this code is not running
                            return Some(internal_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            None
        };

        let check_stream = futures::stream::once(check)
            // unwrap Option / only return the error
            .filter_map(|item| async move { item });

        // Convert the receiver into a stream
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });

        // Merge the streams together so whichever is ready first
        // produces the batch
        futures::stream::select(rx_stream, check_stream).boxed()
    }
}
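For orientation, here is a minimal usage sketch of the builder described above; it is not part of this diff. It assumes DataFusion's `Result` alias and the crate-internal API shown in this file, so code like this would live inside the `datafusion-physical-plan` crate; the function name is illustrative.

// Minimal sketch (not part of this diff): produce values on a spawned task
// and consume them through the stream returned by `build()`.
use futures::StreamExt;

async fn receiver_stream_example() -> Result<()> {
    let mut builder = ReceiverStreamBuilder::<i32>::new(2);
    let tx = builder.tx();
    builder.spawn(async move {
        for i in 0..3 {
            // A send error means the receiving stream was dropped; stop producing.
            if tx.send(Ok(i)).await.is_err() {
                break;
            }
        }
        Ok(())
    });

    let mut stream = builder.build();
    while let Some(item) = stream.next().await {
        let value = item?;
        println!("got {value}");
    }
    Ok(())
}

If one of the spawned tasks panics, the panic is resumed on the thread polling the stream, matching the behavior of the `check` future above.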

/// Builder for [`RecordBatchReceiverStream`] that propagates errors
/// and panics correctly.
///
@@ -47,28 +165,22 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
///
/// This also handles propagating panics and canceling the tasks.
pub struct RecordBatchReceiverStreamBuilder {
    tx: Sender<Result<RecordBatch>>,
    rx: Receiver<Result<RecordBatch>>,
    schema: SchemaRef,
    join_set: JoinSet<Result<()>>,
    inner: ReceiverStreamBuilder<RecordBatch>,
}

impl RecordBatchReceiverStreamBuilder {
    /// create new channels with the specified buffer size
    pub fn new(schema: SchemaRef, capacity: usize) -> Self {
        let (tx, rx) = tokio::sync::mpsc::channel(capacity);

        Self {
            tx,
            rx,
            schema,
            join_set: JoinSet::new(),
            inner: ReceiverStreamBuilder::new(capacity),
        }
    }

Contributor Author: It seemed unnecessary / undesirable to burden ReceiverStreamBuilder with a notion of Schema.

Contributor: I agree.

    /// Get a handle for sending [`RecordBatch`]es to the output
    /// Get a handle for sending [`RecordBatch`] to the output
    pub fn tx(&self) -> Sender<Result<RecordBatch>> {
        self.tx.clone()
        self.inner.tx()
    }

    /// Spawn task that will be aborted if this builder (or the stream
@@ -81,7 +193,7 @@ impl RecordBatchReceiverStreamBuilder {
        F: Future<Output = Result<()>>,
        F: Send + 'static,
    {
        self.join_set.spawn(task);
        self.inner.spawn(task)
    }

    /// Spawn a blocking task that will be aborted if this builder (or the stream
@@ -94,7 +206,7 @@ impl RecordBatchReceiverStreamBuilder {
        F: FnOnce() -> Result<()>,
        F: Send + 'static,
    {
        self.join_set.spawn_blocking(f);
        self.inner.spawn_blocking(f)
    }

    /// runs the input_partition of the `input` ExecutionPlan on the
@@ -110,7 +222,7 @@ impl RecordBatchReceiverStreamBuilder {
    ) {
        let output = self.tx();

        self.spawn(async move {
        self.inner.spawn(async move {

Contributor Author: This logic, by comparison, is rather ExecutionPlan specific, and I don't think it's valuable to DRY.

            let mut stream = match input.execute(partition, context) {
                Err(e) => {
                    // If send fails, the plan being torn down, there
@@ -155,80 +267,17 @@ impl RecordBatchReceiverStreamBuilder {
        });
    }
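Because `run_input` feeds one partition of an input plan into the builder, a caller can fan every partition of a plan into a single output stream, much as `CoalescePartitionsExec` does. The following is a hedged sketch, not part of this diff; it assumes `run_input`'s full signature (partially elided in this view) and the usual `Arc`, `TaskContext`, and `ExecutionPlan` imports, and the function name is illustrative.

// Sketch (not part of this diff): read every partition of `input` through one
// RecordBatchReceiverStream. Assumes `run_input(input, partition, context)`.
fn merge_all_partitions(
    input: Arc<dyn ExecutionPlan>,
    context: Arc<TaskContext>,
) -> SendableRecordBatchStream {
    let num_partitions = input.output_partitioning().partition_count();
    let mut builder = RecordBatchReceiverStream::builder(input.schema(), num_partitions);
    for partition in 0..num_partitions {
        // Each call spawns a task that drives one partition and sends its
        // batches (or the first error) to the shared channel.
        builder.run_input(input.clone(), partition, context.clone());
    }
    builder.build()
}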

    /// Create a stream of all `RecordBatch`es written to `tx`
    /// Create a stream of all [`RecordBatch`] written to `tx`
    pub fn build(self) -> SendableRecordBatchStream {
        let Self {
            tx,
            rx,
            schema,
            mut join_set,
        } = self;

        // don't need tx
        drop(tx);

        // future that checks the result of the join set, and propagates panic if seen
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                match result {
                    Ok(task_result) => {
                        match task_result {
                            // nothing to report
                            Ok(_) => continue,
                            // This means a blocking task error
                            Err(e) => {
                                return Some(exec_err!("Spawned Task error: {e}"));
                            }
                        }
                    }
                    // This means a tokio task error, likely a panic
                    Err(e) => {
                        if e.is_panic() {
                            // resume on the main thread
                            std::panic::resume_unwind(e.into_panic());
                        } else {
                            // This should only occur if the task is
                            // cancelled, which would only occur if
                            // the JoinSet were aborted, which in turn
                            // would imply that the receiver has been
                            // dropped and this code is not running
                            return Some(internal_err!("Non Panic Task error: {e}"));
                        }
                    }
                }
            }
            None
        };

        let check_stream = futures::stream::once(check)
            // unwrap Option / only return the error
            .filter_map(|item| async move { item });

        // Convert the receiver into a stream
        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
            let next_item = rx.recv().await;
            next_item.map(|next_item| (next_item, rx))
        });

        // Merge the streams together so whichever is ready first
        // produces the batch
        let inner = futures::stream::select(rx_stream, check_stream).boxed();

        Box::pin(RecordBatchReceiverStream { schema, inner })
        Box::pin(RecordBatchStreamAdapter::new(
            self.schema,
            self.inner.build(),
        ))
    }
}

/// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs,
/// on new tokio Tasks, increasing the potential parallelism.
///
/// This structure also handles propagating panics and cancelling the
/// underlying tasks correctly.
///
/// Use [`Self::builder`] to construct one.
pub struct RecordBatchReceiverStream {
    schema: SchemaRef,
    inner: BoxStream<'static, Result<RecordBatch>>,
}
#[doc(hidden)]
pub struct RecordBatchReceiverStream {}

Contributor Author: This type doesn't really make a lot of sense, given it isn't actually what the builder returns.

impl RecordBatchReceiverStream {
    /// Create a builder with an internal buffer of capacity batches.

Contributor: Yeah, this method documentation also seems incorrect at this time.

@@ -240,23 +289,6 @@ impl RecordBatchReceiverStream {
    }
}
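As a complement to the partition-merging sketch above, here is a hedged sketch (not part of this diff) of the manual path: building a `SendableRecordBatchStream` from a hand-written producer task via `RecordBatchReceiverStream::builder`, `tx`, `spawn`, and `build`, all of which appear in this file. The function name is illustrative.

// Sketch (not part of this diff): an `ExecutionPlan::execute`-style function
// that returns a stream fed by a hand-written producer task.
fn execute_single_batch(
    schema: SchemaRef,
    batch: RecordBatch,
) -> SendableRecordBatchStream {
    let mut builder = RecordBatchReceiverStream::builder(schema, 2);
    let tx = builder.tx();
    builder.spawn(async move {
        // Ignore a send error: it only means the receiving stream was dropped.
        let _ = tx.send(Ok(batch)).await;
        Ok(())
    });
    builder.build()
}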

impl Stream for RecordBatchReceiverStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        self.inner.poll_next_unpin(cx)
    }
}

impl RecordBatchStream for RecordBatchReceiverStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}

pin_project! {
    /// Combines a [`Stream`] with a [`SchemaRef`] implementing
    /// [`RecordBatchStream`] for the combination
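Finally, a hedged sketch (not part of this diff) of wrapping an ordinary stream in the `RecordBatchStreamAdapter` described above, which is exactly what the new `build()` does with `self.inner.build()`; the function name is illustrative.

// Sketch (not part of this diff): attach a schema to a plain stream of
// RecordBatches so it implements RecordBatchStream.
fn adapt_stream(
    schema: SchemaRef,
    batches: Vec<Result<RecordBatch>>,
) -> SendableRecordBatchStream {
    let stream = futures::stream::iter(batches);
    // RecordBatchStreamAdapter pairs the stream with the schema, as used by
    // RecordBatchReceiverStreamBuilder::build above.
    Box::pin(RecordBatchStreamAdapter::new(schema, stream))
}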