fix: account for memory in RepartitionExec (#4820)
* refactor: explicit loop instead of (tail) recursion

* test: simplify

* fix: account for memory in `RepartitionExec`

Fixes #4816.

* fix: sorting memory limit test
crepererum authored Jan 7, 2023
1 parent 2db3d2e commit 83c1026
Showing 3 changed files with 116 additions and 36 deletions.
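Before the per-file diffs, here is a minimal sketch (not part of this commit) of the reservation pattern the fix wires into `RepartitionExec`: the sending side grows a reservation before it buffers a batch in an output channel, and the receiving stream shrinks it once the batch is consumed. The sketch assumes DataFusion's `memory_pool` API as used in the diff (`MemoryConsumer`, `MemoryReservation`) together with a bounded pool such as `GreedyMemoryPool`; the consumer name and byte sizes are illustrative only.

```rust
use std::sync::Arc;

use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() {
    // A pool with a fixed budget; the new `oom` test below uses a 1-byte limit
    // so that any batch exceeds it.
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));
    let mut reservation = MemoryConsumer::new("RepartitionExec[0]").register(&pool);

    // Sender side: grow the reservation before buffering a batch in the channel.
    reservation.try_grow(512).expect("within budget");

    // Receiver side: shrink once the batch is pulled out of the channel,
    // mirroring `RepartitionStream::poll_next` in the diff.
    reservation.shrink(512);

    // A request larger than the remaining budget is rejected; the error surfaces
    // as `ResourcesExhausted`, which is what the new tests assert on.
    let err = reservation.try_grow(4096).unwrap_err();
    println!("allocation rejected: {err}");
}
```

In the commit itself the reservation is shared as `Arc<Mutex<MemoryReservation>>` (the new `SharedMemoryReservation` alias) so that the input task that sends batches and the output stream that drains them update the same accounting.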
19 changes: 6 additions & 13 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -746,7 +746,7 @@ mod tests {
     use crate::{assert_batches_sorted_eq, physical_plan::common};
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-    use arrow::error::{ArrowError, Result as ArrowResult};
+    use arrow::error::Result as ArrowResult;
     use arrow::record_batch::RecordBatch;
     use datafusion_common::{DataFusionError, Result, ScalarValue};
     use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
@@ -1207,18 +1207,11 @@ mod tests {
             let err = common::collect(stream).await.unwrap_err();

             // error root cause traversal is a bit complicated, see #4172.
-            if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {
-                if let Some(err) = err.downcast_ref::<DataFusionError>() {
-                    assert!(
-                        matches!(err, DataFusionError::ResourcesExhausted(_)),
-                        "Wrong inner error type: {err}",
-                    );
-                } else {
-                    panic!("Wrong arrow error type: {err}")
-                }
-            } else {
-                panic!("Wrong outer error type: {err}")
-            }
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
         }

         Ok(())
127 changes: 105 additions & 22 deletions datafusion/core/src/physical_plan/repartition.rs
@@ -24,6 +24,7 @@ use std::task::{Context, Poll};
 use std::{any::Any, vec};

 use crate::error::{DataFusionError, Result};
+use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use crate::physical_plan::hash_utils::create_hashes;
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
@@ -50,14 +51,21 @@ use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;

 type MaybeBatch = Option<ArrowResult<RecordBatch>>;
+type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;

 /// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
 struct RepartitionExecState {
     /// Channels for sending batches from input partitions to output partitions.
     /// Key is the partition number.
-    channels:
-        HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
+    channels: HashMap<
+        usize,
+        (
+            UnboundedSender<MaybeBatch>,
+            UnboundedReceiver<MaybeBatch>,
+            SharedMemoryReservation,
+        ),
+    >,

     /// Helper that ensures that that background job is killed once it is no longer needed.
     abort_helper: Arc<AbortOnDropMany<()>>,
@@ -338,7 +346,13 @@ impl ExecutionPlan for RepartitionExec {
             // for this would be to add spill-to-disk capabilities.
             let (sender, receiver) =
                 mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
-            state.channels.insert(partition, (sender, receiver));
+            let reservation = Arc::new(Mutex::new(
+                MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+                    .register(context.memory_pool()),
+            ));
+            state
+                .channels
+                .insert(partition, (sender, receiver, reservation));
         }

         // launch one async task per *input* partition
@@ -347,7 +361,9 @@ impl ExecutionPlan for RepartitionExec {
             let txs: HashMap<_, _> = state
                 .channels
                 .iter()
-                .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
+                .map(|(partition, (tx, _rx, reservation))| {
+                    (*partition, (tx.clone(), Arc::clone(reservation)))
+                })
                 .collect();

             let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
@@ -366,7 +382,9 @@ impl ExecutionPlan for RepartitionExec {
             // (and pass along any errors, including panic!s)
             let join_handle = tokio::spawn(Self::wait_for_task(
                 AbortOnDropSingle::new(input_task),
-                txs,
+                txs.into_iter()
+                    .map(|(partition, (tx, _reservation))| (partition, tx))
+                    .collect(),
             ));
             join_handles.push(join_handle);
         }
@@ -381,14 +399,17 @@ impl ExecutionPlan for RepartitionExec {

         // now return stream for the specified *output* partition which will
         // read from the channel
+        let (_tx, rx, reservation) = state
+            .channels
+            .remove(&partition)
+            .expect("partition not used yet");
         Ok(Box::pin(RepartitionStream {
             num_input_partitions,
             num_input_partitions_processed: 0,
             schema: self.input.schema(),
-            input: UnboundedReceiverStream::new(
-                state.channels.remove(&partition).unwrap().1,
-            ),
+            input: UnboundedReceiverStream::new(rx),
             drop_helper: Arc::clone(&state.abort_helper),
+            reservation,
         }))
     }

@@ -439,7 +460,7 @@ impl RepartitionExec {
     async fn pull_from_input(
         input: Arc<dyn ExecutionPlan>,
         i: usize,
-        mut txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
+        mut txs: HashMap<usize, (UnboundedSender<MaybeBatch>, SharedMemoryReservation)>,
         partitioning: Partitioning,
         r_metrics: RepartitionMetrics,
         context: Arc<TaskContext>,
@@ -467,11 +488,16 @@ impl RepartitionExec {
             };

             partitioner.partition(batch, |partition, partitioned| {
+                let size = partitioned.get_array_memory_size();
+
                 let timer = r_metrics.send_time.timer();
                 // if there is still a receiver, send to it
-                if let Some(tx) = txs.get_mut(&partition) {
+                if let Some((tx, reservation)) = txs.get_mut(&partition) {
+                    reservation.lock().try_grow(size)?;
+
                     if tx.send(Some(Ok(partitioned))).is_err() {
                         // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+                        reservation.lock().shrink(size);
                         txs.remove(&partition);
                     }
                 }
@@ -546,6 +572,9 @@ struct RepartitionStream {
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
     drop_helper: Arc<AbortOnDropMany<()>>,
+
+    /// Memory reservation.
+    reservation: SharedMemoryReservation,
 }

 impl Stream for RepartitionStream {
@@ -555,20 +584,35 @@ impl Stream for RepartitionStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        match self.input.poll_next_unpin(cx) {
-            Poll::Ready(Some(Some(v))) => Poll::Ready(Some(v)),
-            Poll::Ready(Some(None)) => {
-                self.num_input_partitions_processed += 1;
-                if self.num_input_partitions == self.num_input_partitions_processed {
-                    // all input partitions have finished sending batches
-                    Poll::Ready(None)
-                } else {
-                    // other partitions still have data to send
-                    self.poll_next(cx)
+        loop {
+            match self.input.poll_next_unpin(cx) {
+                Poll::Ready(Some(Some(v))) => {
+                    if let Ok(batch) = &v {
+                        self.reservation
+                            .lock()
+                            .shrink(batch.get_array_memory_size());
+                    }
+
+                    return Poll::Ready(Some(v));
+                }
+                Poll::Ready(Some(None)) => {
+                    self.num_input_partitions_processed += 1;
+
+                    if self.num_input_partitions == self.num_input_partitions_processed {
+                        // all input partitions have finished sending batches
+                        return Poll::Ready(None);
+                    } else {
+                        // other partitions still have data to send
+                        continue;
+                    }
+                }
+                Poll::Ready(None) => {
+                    return Poll::Ready(None);
+                }
+                Poll::Pending => {
+                    return Poll::Pending;
+                }
             }
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Pending => Poll::Pending,
         }
     }
 }
@@ -583,6 +627,8 @@ impl RecordBatchStream for RepartitionStream {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::execution::context::SessionConfig;
+    use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use crate::from_slice::FromSlice;
     use crate::prelude::SessionContext;
     use crate::test::create_vec_batches;
@@ -1078,4 +1124,41 @@ mod tests {
         assert!(batch0.is_empty() || batch1.is_empty());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn oom() -> Result<()> {
+        // define input partitions
+        let schema = test_schema();
+        let partition = create_vec_batches(&schema, 50);
+        let input_partitions = vec![partition];
+        let partitioning = Partitioning::RoundRobinBatch(4);
+
+        // setup up context
+        let session_ctx = SessionContext::with_config_rt(
+            SessionConfig::default(),
+            Arc::new(
+                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
+                    .unwrap(),
+            ),
+        );
+        let task_ctx = session_ctx.task_ctx();
+
+        // create physical plan
+        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
+        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
+
+        // pull partitions
+        for i in 0..exec.partitioning.partition_count() {
+            let mut stream = exec.execute(i, task_ctx.clone())?;
+            let err =
+                DataFusionError::ArrowError(stream.next().await.unwrap().unwrap_err());
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
+        }
+
+        Ok(())
+    }
 }
6 changes: 5 additions & 1 deletion datafusion/core/tests/memory_limit.rs
@@ -95,7 +95,11 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize)

     let runtime = RuntimeEnv::new(rt_config).unwrap();

-    let ctx = SessionContext::with_config_rt(SessionConfig::new(), Arc::new(runtime));
+    let ctx = SessionContext::with_config_rt(
+        // do NOT re-partition (since RepartitionExec has also has a memory budget which we'll likely hit first)
+        SessionConfig::new().with_target_partitions(1),
+        Arc::new(runtime),
+    );
     ctx.register_table("t", Arc::new(table))
         .expect("registering table");
