From 83c102698945e0984f8fa53e75b04478e49e5242 Mon Sep 17 00:00:00 2001
From: Marco Neumann
Date: Sat, 7 Jan 2023 10:19:25 +0100
Subject: [PATCH] fix: account for memory in `RepartitionExec` (#4820)

* refactor: explicit loop instead of (tail) recursion

* test: simplify

* fix: account for memory in `RepartitionExec`

Fixes #4816.

* fix: sorting memory limit test
---
 .../core/src/physical_plan/aggregates/mod.rs |  19 +--
 .../core/src/physical_plan/repartition.rs    | 127 +++++++++++++++---
 datafusion/core/tests/memory_limit.rs        |   6 +-
 3 files changed, 116 insertions(+), 36 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 07f3563bbbc5..8044f4c15781 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -746,7 +746,7 @@ mod tests {
     use crate::{assert_batches_sorted_eq, physical_plan::common};
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-    use arrow::error::{ArrowError, Result as ArrowResult};
+    use arrow::error::Result as ArrowResult;
     use arrow::record_batch::RecordBatch;
     use datafusion_common::{DataFusionError, Result, ScalarValue};
     use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
@@ -1207,18 +1207,11 @@ mod tests {
         let err = common::collect(stream).await.unwrap_err();

         // error root cause traversal is a bit complicated, see #4172.
-        if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {
-            if let Some(err) = err.downcast_ref::<DataFusionError>() {
-                assert!(
-                    matches!(err, DataFusionError::ResourcesExhausted(_)),
-                    "Wrong inner error type: {err}",
-                );
-            } else {
-                panic!("Wrong arrow error type: {err}")
-            }
-        } else {
-            panic!("Wrong outer error type: {err}")
-        }
+        let err = err.find_root();
+        assert!(
+            matches!(err, DataFusionError::ResourcesExhausted(_)),
+            "Wrong error type: {err}",
+        );
     }

     Ok(())

diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs
index ee2e976cecdb..451b0fba4b13 100644
--- a/datafusion/core/src/physical_plan/repartition.rs
+++ b/datafusion/core/src/physical_plan/repartition.rs
@@ -24,6 +24,7 @@ use std::task::{Context, Poll};
 use std::{any::Any, vec};

 use crate::error::{DataFusionError, Result};
+use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use crate::physical_plan::hash_utils::create_hashes;
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,

 use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;

 type MaybeBatch = Option<ArrowResult<RecordBatch>>;
+type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;

 /// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
 struct RepartitionExecState {
     /// Channels for sending batches from input partitions to output partitions.
     /// Key is the partition number.
-    channels:
-        HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
+    channels: HashMap<
+        usize,
+        (
+            UnboundedSender<MaybeBatch>,
+            UnboundedReceiver<MaybeBatch>,
+            SharedMemoryReservation,
+        ),
+    >,

     /// Helper that ensures that the background job is killed once it is no longer needed.
     abort_helper: Arc<AbortOnDropMany<()>>,
@@ -338,7 +346,13 @@ impl ExecutionPlan for RepartitionExec {
             // for this would be to add spill-to-disk capabilities.
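+            // Each output partition gets its own memory reservation below: the
+            // sending task grows it before queueing a batch into the unbounded
+            // channel and the receiving stream shrinks it again, so buffered
+            // batches stay visible to the query's memory pool.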
             let (sender, receiver) =
                 mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
-            state.channels.insert(partition, (sender, receiver));
+            let reservation = Arc::new(Mutex::new(
+                MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+                    .register(context.memory_pool()),
+            ));
+            state
+                .channels
+                .insert(partition, (sender, receiver, reservation));
         }

         // launch one async task per *input* partition

             let txs: HashMap<_, _> = state
                 .channels
                 .iter()
-                .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
+                .map(|(partition, (tx, _rx, reservation))| {
+                    (*partition, (tx.clone(), Arc::clone(reservation)))
+                })
                 .collect();

             let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);

             // (and pass along any errors, including panic!s)
             let join_handle = tokio::spawn(Self::wait_for_task(
                 AbortOnDropSingle::new(input_task),
-                txs,
+                txs.into_iter()
+                    .map(|(partition, (tx, _reservation))| (partition, tx))
+                    .collect(),
             ));
             join_handles.push(join_handle);
         }

         // now return stream for the specified *output* partition which will
         // read from the channel
+        let (_tx, rx, reservation) = state
+            .channels
+            .remove(&partition)
+            .expect("partition not used yet");
         Ok(Box::pin(RepartitionStream {
             num_input_partitions,
             num_input_partitions_processed: 0,
             schema: self.input.schema(),
-            input: UnboundedReceiverStream::new(
-                state.channels.remove(&partition).unwrap().1,
-            ),
+            input: UnboundedReceiverStream::new(rx),
             drop_helper: Arc::clone(&state.abort_helper),
+            reservation,
         }))
     }

     async fn pull_from_input(
         input: Arc<dyn ExecutionPlan>,
         i: usize,
-        mut txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
+        mut txs: HashMap<usize, (UnboundedSender<MaybeBatch>, SharedMemoryReservation)>,
         partitioning: Partitioning,
         r_metrics: RepartitionMetrics,
         context: Arc<TaskContext>,

         partitioner.partition(batch, |partition, partitioned| {
+            let size = partitioned.get_array_memory_size();
+
             let timer = r_metrics.send_time.timer();
             // if there is still a receiver, send to it
-            if let Some(tx) = txs.get_mut(&partition) {
+            if let Some((tx, reservation)) = txs.get_mut(&partition) {
+                reservation.lock().try_grow(size)?;
+
                 if tx.send(Some(Ok(partitioned))).is_err() {
                     // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+                    reservation.lock().shrink(size);
                     txs.remove(&partition);
                 }
             }

 struct RepartitionStream {
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
     drop_helper: Arc<AbortOnDropMany<()>>,
+
+    /// Memory reservation.
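+    /// Grown by the sending side while batches sit in the channel and shrunk
+    /// in `poll_next` once a batch has been handed to the output stream.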
+    reservation: SharedMemoryReservation,
 }

 impl Stream for RepartitionStream {

         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        match self.input.poll_next_unpin(cx) {
-            Poll::Ready(Some(Some(v))) => Poll::Ready(Some(v)),
-            Poll::Ready(Some(None)) => {
-                self.num_input_partitions_processed += 1;
-                if self.num_input_partitions == self.num_input_partitions_processed {
-                    // all input partitions have finished sending batches
-                    Poll::Ready(None)
-                } else {
-                    // other partitions still have data to send
-                    self.poll_next(cx)
+        loop {
+            match self.input.poll_next_unpin(cx) {
+                Poll::Ready(Some(Some(v))) => {
+                    if let Ok(batch) = &v {
+                        self.reservation
+                            .lock()
+                            .shrink(batch.get_array_memory_size());
+                    }
+
+                    return Poll::Ready(Some(v));
+                }
+                Poll::Ready(Some(None)) => {
+                    self.num_input_partitions_processed += 1;
+
+                    if self.num_input_partitions == self.num_input_partitions_processed {
+                        // all input partitions have finished sending batches
+                        return Poll::Ready(None);
+                    } else {
+                        // other partitions still have data to send
+                        continue;
+                    }
+                }
+                Poll::Ready(None) => {
+                    return Poll::Ready(None);
+                }
+                Poll::Pending => {
+                    return Poll::Pending;
                 }
             }
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Pending => Poll::Pending,
         }
     }
 }

 impl RecordBatchStream for RepartitionStream {

 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::execution::context::SessionConfig;
+    use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use crate::from_slice::FromSlice;
     use crate::prelude::SessionContext;
     use crate::test::create_vec_batches;

         assert!(batch0.is_empty() || batch1.is_empty());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn oom() -> Result<()> {
+        // define input partitions
+        let schema = test_schema();
+        let partition = create_vec_batches(&schema, 50);
+        let input_partitions = vec![partition];
+        let partitioning = Partitioning::RoundRobinBatch(4);
+
+        // set up context
+        let session_ctx = SessionContext::with_config_rt(
+            SessionConfig::default(),
+            Arc::new(
+                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
+                    .unwrap(),
+            ),
+        );
+        let task_ctx = session_ctx.task_ctx();
+
+        // create physical plan
+        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
+        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
+
+        // pull partitions
+        for i in 0..exec.partitioning.partition_count() {
+            let mut stream = exec.execute(i, task_ctx.clone())?;
+            let err =
+                DataFusionError::ArrowError(stream.next().await.unwrap().unwrap_err());
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
+        }
+
+        Ok(())
+    }
 }

diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs
index 91d66e884623..170f559038df 100644
--- a/datafusion/core/tests/memory_limit.rs
+++ b/datafusion/core/tests/memory_limit.rs
@@ -95,7 +95,11 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize)

     let runtime = RuntimeEnv::new(rt_config).unwrap();

-    let ctx = SessionContext::with_config_rt(SessionConfig::new(), Arc::new(runtime));
+    let ctx = SessionContext::with_config_rt(
+        // do NOT re-partition (RepartitionExec also has a memory budget, which we'd likely hit first)
+        SessionConfig::new().with_target_partitions(1),
+        Arc::new(runtime),
+    );

     ctx.register_table("t", Arc::new(table))
         .expect("registering table");
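
Note: the pattern the patch applies is "grow the reservation when a batch is
queued into the per-partition channel, shrink it once the batch leaves". Below
is a minimal, self-contained sketch of that pattern against the memory_pool
API imported in repartition.rs above. It is not part of the patch; the
consumer name, pool size, and byte counts are illustrative only.

    use std::sync::Arc;

    use datafusion::error::Result;
    use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

    fn main() -> Result<()> {
        // pool with a small fixed budget, analogous to the tiny limit set via
        // `with_memory_limit(1, 1.0)` in the `oom` test above
        let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));

        // one named consumer per output partition, as in `RepartitionExec::execute`
        let mut reservation = MemoryConsumer::new("RepartitionExec[0]").register(&pool);

        // sending side: account for the batch before queueing it; once the pool
        // is exhausted this returns a `ResourcesExhausted` error instead of
        // letting the unbounded channel grow without limit
        reservation.try_grow(512)?;

        // receiving side: release the memory once the batch leaves the channel
        reservation.shrink(512);

        Ok(())
    }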