diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
index 49dbb1b4c480..3672d78c2de7 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -43,6 +43,7 @@ use datafusion::arrow::io::ipc::read::FileReader;
use datafusion::arrow::io::ipc::write::FileWriter;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
+use datafusion::physical_plan::common::IPCWriterWrapper;
use datafusion::physical_plan::hash_utils::create_hashes;
use datafusion::physical_plan::metrics::{
self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -198,7 +199,7 @@ impl ShuffleWriterExec {
// we won't necessarily produce output for every possible partition, so we
// create writers on demand
- let mut writers: Vec<Option<ShuffleWriter>> = vec![];
+ let mut writers: Vec<Option<IPCWriterWrapper>> = vec![];
for _ in 0..num_output_partitions {
writers.push(None);
}
@@ -268,8 +269,10 @@ impl ShuffleWriterExec {
let path = path.to_str().unwrap();
info!("Writing results to {}", path);
- let mut writer =
- ShuffleWriter::new(path, stream.schema().as_ref())?;
+ let mut writer = IPCWriterWrapper::new(
+ path,
+ stream.schema().as_ref(),
+ )?;
writer.write(&output_batch)?;
writers[output_partition] = Some(writer);
@@ -434,56 +437,6 @@ fn result_schema() -> SchemaRef {
]))
}
-struct ShuffleWriter {
- path: String,
- writer: FileWriter<BufWriter<File>>,
- num_batches: u64,
- num_rows: u64,
- num_bytes: u64,
-}
-
-impl ShuffleWriter {
- fn new(path: &str, schema: &Schema) -> Result<Self> {
- let file = File::create(path)
- .map_err(|e| {
- BallistaError::General(format!(
- "Failed to create partition file at {}: {:?}",
- path, e
- ))
- })
- .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
- let buffer_writer = std::io::BufWriter::new(file);
- Ok(Self {
- num_batches: 0,
- num_rows: 0,
- num_bytes: 0,
- path: path.to_owned(),
- writer: FileWriter::try_new(buffer_writer, schema, WriteOptions::default())?,
- })
- }
-
- fn write(&mut self, batch: &RecordBatch) -> Result<()> {
- self.writer.write(batch)?;
- self.num_batches += 1;
- self.num_rows += batch.num_rows() as u64;
- let num_bytes: usize = batch
- .columns()
- .iter()
- .map(|array| estimated_bytes_size(array.as_ref()))
- .sum();
- self.num_bytes += num_bytes as u64;
- Ok(())
- }
-
- fn finish(&mut self) -> Result<()> {
- self.writer.finish().map_err(DataFusionError::ArrowError)
- }
-
- fn path(&self) -> &str {
- &self.path
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 4f4f72eca74b..b66844fcf352 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -71,7 +71,8 @@ use datafusion::physical_plan::{
limit::{GlobalLimitExec, LocalLimitExec},
projection::ProjectionExec,
repartition::RepartitionExec,
- sort::{SortExec, SortOptions},
+ sorts::sort::SortExec,
+ sorts::SortOptions,
Partitioning,
};
use datafusion::physical_plan::{
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 23826605b797..27750f7efc14 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -36,7 +36,7 @@ mod roundtrip_tests {
hash_aggregate::{AggregateMode, HashAggregateExec},
hash_join::{HashJoinExec, PartitionMode},
limit::{GlobalLimitExec, LocalLimitExec},
- sort::SortExec,
+ sorts::sort::SortExec,
AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning,
PhysicalExpr,
},
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 41484db57a7b..930f0757e202 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -29,7 +29,7 @@ use std::{
use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::projection::ProjectionExec;
-use datafusion::physical_plan::sort::SortExec;
+use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::{cross_join::CrossJoinExec, ColumnStatistics};
use datafusion::physical_plan::{
expressions::{
diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs
index 15857678bf01..3cb8950a878b 100644
--- a/ballista/rust/core/src/utils.rs
+++ b/ballista/rust/core/src/utils.rs
@@ -60,7 +60,7 @@ use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::hash_aggregate::HashAggregateExec;
use datafusion::physical_plan::hash_join::HashJoinExec;
use datafusion::physical_plan::projection::ProjectionExec;
-use datafusion::physical_plan::sort::SortExec;
+use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::{
metrics, AggregateExpr, ExecutionPlan, Metric, PhysicalExpr, RecordBatchStream,
};
diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs
index 3291a62abe64..3d3884fd5021 100644
--- a/ballista/rust/scheduler/src/planner.rs
+++ b/ballista/rust/scheduler/src/planner.rs
@@ -254,7 +254,7 @@ mod test {
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
use datafusion::physical_plan::hash_join::HashJoinExec;
- use datafusion::physical_plan::sort::SortExec;
+ use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::{
coalesce_partitions::CoalescePartitionsExec, projection::ProjectionExec,
};
diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index 48ecb49ac2f3..1277cc3ed163 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -77,6 +77,8 @@ rand = "0.8"
avro-rs = { version = "0.13", features = ["snappy"], optional = true }
num-traits = { version = "0.2", optional = true }
pyo3 = { version = "0.14", optional = true }
+uuid = { version = "0.8", features = ["v4"] }
+tempfile = "3"
[dependencies.arrow]
package = "arrow2"
@@ -89,7 +91,6 @@ features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc
[dev-dependencies]
criterion = "0.3"
-tempfile = "3"
doc-comment = "0.3"
parquet-format-async-temp = "0"
diff --git a/datafusion/benches/aggregate_query_sql.rs b/datafusion/benches/aggregate_query_sql.rs
index e580f4a63507..aeb226facf40 100644
--- a/datafusion/benches/aggregate_query_sql.rs
+++ b/datafusion/benches/aggregate_query_sql.rs
@@ -132,5 +132,7 @@ fn criterion_benchmark(c: &mut Criterion) {
});
}
-criterion_group!(benches, criterion_benchmark);
+criterion_group!(name = benches;
+ config = Criterion::default().measurement_time(std::time::Duration::from_secs(30));
+ targets = criterion_benchmark);
criterion_main!(benches);
diff --git a/datafusion/src/error.rs b/datafusion/src/error.rs
index a47bfac8b622..d9ac067d6e3f 100644
--- a/datafusion/src/error.rs
+++ b/datafusion/src/error.rs
@@ -61,6 +61,9 @@ pub enum DataFusionError {
/// Error returned during execution of the query.
/// Examples include files not found, errors in parsing certain types.
Execution(String),
+ /// This error is thrown when a consumer cannot acquire memory from the Memory Manager;
+ /// in that case, the execution of the partition can simply be cancelled.
+ ResourcesExhausted(String),
}
impl DataFusionError {
@@ -129,6 +132,9 @@ impl Display for DataFusionError {
DataFusionError::Execution(ref desc) => {
write!(f, "Execution error: {}", desc)
}
+ DataFusionError::ResourcesExhausted(ref desc) => {
+ write!(f, "Resources exhausted: {}", desc)
+ }
}
}
}
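
Reviewer note: downstream code can match the new variant to cancel just the affected partition. A minimal sketch (the `handle` function is hypothetical, not part of the patch):

```rust
use datafusion::error::DataFusionError;

fn handle(result: Result<(), DataFusionError>) {
    match result {
        // a consumer could not acquire memory even after spilling: cancel
        // this partition's execution instead of aborting the whole process
        Err(DataFusionError::ResourcesExhausted(msg)) => {
            eprintln!("partition cancelled, resources exhausted: {}", msg)
        }
        Err(other) => eprintln!("execution failed: {}", other),
        Ok(()) => {}
    }
}
```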
diff --git a/datafusion/src/execution/disk_manager.rs b/datafusion/src/execution/disk_manager.rs
new file mode 100644
index 000000000000..80cc1506ae0e
--- /dev/null
+++ b/datafusion/src/execution/disk_manager.rs
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Manages files generated during query execution; files are
+//! hashed among the directories listed in `RuntimeConfig::local_dirs`.
+
+use crate::error::{DataFusionError, Result};
+use std::collections::hash_map::DefaultHasher;
+use std::fs;
+use std::fs::File;
+use std::hash::{Hash, Hasher};
+use std::path::{Path, PathBuf};
+use uuid::Uuid;
+
+/// Manages files generated during query execution, e.g. spill files generated
+/// while processing dataset larger than available memory.
+pub struct DiskManager {
+ local_dirs: Vec<String>,
+}
+
+impl DiskManager {
+ /// Create local dirs inside user provided dirs through conf
+ pub fn new(conf_dirs: &[String]) -> Result<Self> {
+ Ok(Self {
+ local_dirs: create_local_dirs(conf_dirs)?,
+ })
+ }
+
+ /// Create a file in one of the configured dirs, chosen by hashing a random name, and return its path
+ pub fn create_tmp_file(&self) -> Result<String> {
+ create_tmp_file(&self.local_dirs)
+ }
+
+ #[allow(dead_code)]
+ fn cleanup_resource(&mut self) -> Result<()> {
+ for dir in self.local_dirs.drain(..) {
+ fs::remove_dir(dir)?;
+ }
+ Ok(())
+ }
+}
+
+/// Setup local dirs by creating one new dir in each of the given dirs
+fn create_local_dirs(local_dir: &[String]) -> Result<Vec<String>> {
+ local_dir
+ .iter()
+ .map(|root| create_directory(root, "datafusion"))
+ .collect()
+}
+
+const MAX_DIR_CREATION_ATTEMPTS: i32 = 10;
+
+fn create_directory(root: &str, prefix: &str) -> Result<String> {
+ let mut attempt = 0;
+ while attempt < MAX_DIR_CREATION_ATTEMPTS {
+ let mut path = PathBuf::from(root);
+ path.push(format!("{}-{}", prefix, Uuid::new_v4().to_string()));
+ let path = path.as_path();
+ if !path.exists() {
+ fs::create_dir(path)?;
+ return Ok(path.canonicalize().unwrap().to_str().unwrap().to_string());
+ }
+ attempt += 1;
+ }
+ Err(DataFusionError::Execution(format!(
+ "Failed to create a temp dir under {} after {} attempts",
+ root, MAX_DIR_CREATION_ATTEMPTS
+ )))
+}
+
+fn get_file(file_name: &str, local_dirs: &[String]) -> String {
+ let mut hasher = DefaultHasher::new();
+ file_name.hash(&mut hasher);
+ let hash = hasher.finish();
+ let dir = &local_dirs[hash.rem_euclid(local_dirs.len() as u64) as usize];
+ let mut path = PathBuf::new();
+ path.push(dir);
+ path.push(file_name);
+ path.to_str().unwrap().to_string()
+}
+
+fn create_tmp_file(local_dirs: &[String]) -> Result<String> {
+ let name = Uuid::new_v4().to_string();
+ let mut path = get_file(&*name, local_dirs);
+ while Path::new(path.as_str()).exists() {
+ path = get_file(&*Uuid::new_v4().to_string(), local_dirs);
+ }
+ File::create(&path)?;
+ Ok(path)
+}
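
A minimal usage sketch for the new manager (the two root paths are hypothetical and must already exist):

```rust
use datafusion::error::Result;
use datafusion::execution::disk_manager::DiskManager;

fn demo() -> Result<()> {
    // one `datafusion-<uuid>` subdirectory is created under each root
    let dm = DiskManager::new(&["/tmp/spill-a".to_owned(), "/tmp/spill-b".to_owned()])?;
    // the file name is a fresh UUID; its hash picks which directory is used
    let spill_file = dm.create_tmp_file()?;
    println!("spilling to {}", spill_file);
    Ok(())
}
```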
diff --git a/datafusion/src/execution/memory_management/allocation_strategist.rs b/datafusion/src/execution/memory_management/allocation_strategist.rs
new file mode 100644
index 000000000000..6e6e41fbe5af
--- /dev/null
+++ b/datafusion/src/execution/memory_management/allocation_strategist.rs
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution memory pool that enforces a memory allocation strategy
+
+use async_trait::async_trait;
+use hashbrown::HashMap;
+use log::{info, warn};
+use std::cmp::min;
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use tokio::runtime::Handle;
+use tokio::sync::{Notify, RwLock};
+
+#[async_trait]
+pub(crate) trait MemoryAllocationStrategist: Sync + Send + Debug {
+ /// Total memory available, which is pool_size - memory_used()
+ fn memory_available(&self) -> usize;
+ /// Current memory used by all PartitionManagers
+ fn memory_used(&self) -> usize;
+ /// Memory usage for a specific partition
+ fn memory_used_partition(&self, partition_id: usize) -> usize;
+ /// Acquire memory for a partition
+ async fn acquire_memory(&self, required: usize, partition_id: usize) -> usize;
+ /// Update memory usage for a partition
+ async fn update_usage(
+ &self,
+ granted_size: usize,
+ real_size: usize,
+ partition_id: usize,
+ );
+ /// Release memory from a partition
+ async fn release_memory(&self, release_size: usize, partition_id: usize);
+ /// Release all memory acquired by a partition
+ async fn release_all(&self, partition_id: usize) -> usize;
+}
+
+pub(crate) struct DummyAllocationStrategist {
+ pool_size: usize,
+}
+
+impl DummyAllocationStrategist {
+ pub fn new() -> Self {
+ Self {
+ pool_size: usize::MAX,
+ }
+ }
+}
+
+impl Debug for DummyAllocationStrategist {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ f.debug_struct("DummyExecutionMemoryPool")
+ .field("total", &self.pool_size)
+ .finish()
+ }
+}
+
+#[async_trait]
+impl MemoryAllocationStrategist for DummyAllocationStrategist {
+ fn memory_available(&self) -> usize {
+ usize::MAX
+ }
+
+ fn memory_used(&self) -> usize {
+ 0
+ }
+
+ fn memory_used_partition(&self, _partition_id: usize) -> usize {
+ 0
+ }
+
+ async fn acquire_memory(&self, required: usize, _partition_id: usize) -> usize {
+ required
+ }
+
+ async fn update_usage(
+ &self,
+ _granted_size: usize,
+ _real_size: usize,
+ _partition_id: usize,
+ ) {
+ }
+
+ async fn release_memory(&self, _release_size: usize, _partition_id: usize) {}
+
+ async fn release_all(&self, _partition_id: usize) -> usize {
+ usize::MAX
+ }
+}
+
+pub(crate) struct FairStrategist {
+ pool_size: usize,
+ /// memory usage per partition
+ memory_usage: RwLock<HashMap<usize, usize>>,
+ notify: Notify,
+}
+
+impl FairStrategist {
+ pub fn new(size: usize) -> Self {
+ Self {
+ pool_size: size,
+ memory_usage: RwLock::new(HashMap::new()),
+ notify: Notify::new(),
+ }
+ }
+}
+
+impl Debug for FairStrategist {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ConstraintExecutionMemoryPool")
+ .field("total", &self.pool_size)
+ .field("used", &self.memory_used())
+ .finish()
+ }
+}
+
+#[async_trait]
+impl MemoryAllocationStrategist for FairStrategist {
+ fn memory_available(&self) -> usize {
+ self.pool_size - self.memory_used()
+ }
+
+ fn memory_used(&self) -> usize {
+ Handle::current()
+ .block_on(async { self.memory_usage.read().await.values().sum() })
+ }
+
+ fn memory_used_partition(&self, partition_id: usize) -> usize {
+ Handle::current().block_on(async {
+ let partition_usage = self.memory_usage.read().await;
+ match partition_usage.get(&partition_id) {
+ None => 0,
+ Some(v) => *v,
+ }
+ })
+ }
+
+ async fn acquire_memory(&self, required: usize, partition_id: usize) -> usize {
+ assert!(required > 0);
+ {
+ let mut partition_usage = self.memory_usage.write().await;
+ if !partition_usage.contains_key(&partition_id) {
+ partition_usage.entry(partition_id).or_insert(0);
+ // this will later wake up waiting tasks so they re-check the number of active partitions
+ self.notify.notify_waiters();
+ }
+ }
+
+ // Keep looping until we're either sure that we don't want to grant this request (because this
+ // partition would have more than 1 / num_active_partition of the memory) or we have enough free
+ // memory to give it (we always let each partition get at least 1 / (2 * num_active_partition)).
+ loop {
+ let partition_usage = self.memory_usage.read().await;
+ let num_active_partition = partition_usage.len();
+ let current_mem = *partition_usage.get(&partition_id).unwrap();
+
+ let max_memory_per_partition = self.pool_size / num_active_partition;
+ let min_memory_per_partition = self.pool_size / (2 * num_active_partition);
+
+ // How much we can grant this partition; keep its share within 0 <= X <= 1 / num_active_partition
+ let max_grant = match max_memory_per_partition.checked_sub(current_mem) {
+ None => 0,
+ Some(max_available) => min(required, max_available),
+ };
+
+ let total_used: usize = partition_usage.values().sum();
+ let total_available = self.pool_size - total_used;
+ // Only give it as much memory as is free, which might be none if it reached 1 / num_active_partition
+ let to_grant = min(max_grant, total_available);
+
+ // We want to let each partition get at least 1 / (2 * num_active_partition) before blocking;
+ // if we can't give it this much now, wait for other partitions to free up memory
+ // (this happens if older partitions allocated lots of memory before N grew)
+ if to_grant < required && current_mem + to_grant < min_memory_per_partition {
+ info!(
+ "{:?} waiting for at least 1/2N of pool to be free",
+ partition_id
+ );
+ let _ = self.notify.notified().await;
+ } else {
+ drop(partition_usage);
+ let mut partition_usage = self.memory_usage.write().await;
+ *partition_usage.entry(partition_id).or_insert(0) += to_grant;
+ return to_grant;
+ }
+ }
+ }
+
+ async fn update_usage(
+ &self,
+ granted_size: usize,
+ real_size: usize,
+ partition_id: usize,
+ ) {
+ assert!(granted_size > 0);
+ assert!(real_size > 0);
+ if granted_size == real_size {
+ return;
+ } else {
+ let mut partition_usage = self.memory_usage.write().await;
+ if granted_size > real_size {
+ *partition_usage.entry(partition_id).or_insert(0) -=
+ granted_size - real_size;
+ } else {
+ // TODO: this would have caused OOM already if size estimation ahead is much smaller than
+ // that of actual allocation
+ *partition_usage.entry(partition_id).or_insert(0) +=
+ real_size - granted_size;
+ }
+ }
+ }
+
+ async fn release_memory(&self, release_size: usize, partition_id: usize) {
+ let partition_usage = self.memory_usage.read().await;
+ let current_mem = match partition_usage.get(&partition_id) {
+ None => 0,
+ Some(v) => *v,
+ };
+
+ let to_free = if current_mem < release_size {
+ warn!(
+ "Release called to free {} but partition only holds {} from the pool",
+ release_size, current_mem
+ );
+ current_mem
+ } else {
+ release_size
+ };
+ if partition_usage.contains_key(&partition_id) {
+ drop(partition_usage);
+ let mut partition_usage = self.memory_usage.write().await;
+ let entry = partition_usage.entry(partition_id).or_insert(0);
+ *entry -= to_free;
+ if *entry == 0 {
+ partition_usage.remove(&partition_id);
+ }
+ }
+ self.notify.notify_waiters();
+ }
+
+ async fn release_all(&self, partition_id: usize) -> usize {
+ let partition_usage = self.memory_usage.read().await;
+ let mut current_mem = 0;
+ match partition_usage.get(&partition_id) {
+ None => return current_mem,
+ Some(v) => current_mem = *v,
+ }
+
+ drop(partition_usage);
+ let mut partition_usage = self.memory_usage.write().await;
+ partition_usage.remove(&partition_id);
+ self.notify.notify_waiters();
+ return current_mem;
+ }
+}
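
To make the 1/(2N)..1/N policy concrete, a small illustrative sketch (not part of the patch) of the bounds `FairStrategist` enforces:

```rust
/// For `n` active partitions sharing `pool_size` bytes, a partition may hold
/// at most `pool_size / n`, and `acquire_memory` blocks until the partition
/// can reach at least `pool_size / (2 * n)`.
fn fairness_bounds(pool_size: usize, num_active_partitions: usize) -> (usize, usize) {
    let min_before_blocking = pool_size / (2 * num_active_partitions);
    let max_per_partition = pool_size / num_active_partitions;
    (min_before_blocking, max_per_partition)
}

fn main() {
    // a 1 GiB pool shared by 4 partitions: at least 128 MiB guaranteed
    // before blocking, at most 256 MiB granted in total
    assert_eq!(fairness_bounds(1 << 30, 4), (128 << 20, 256 << 20));
}
```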
diff --git a/datafusion/src/execution/memory_management/mod.rs b/datafusion/src/execution/memory_management/mod.rs
new file mode 100644
index 000000000000..d5d55440e566
--- /dev/null
+++ b/datafusion/src/execution/memory_management/mod.rs
@@ -0,0 +1,406 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Manages all available memory during query execution
+
+pub mod allocation_strategist;
+
+use std::cmp::Reverse;
+use crate::error::DataFusionError::ResourcesExhausted;
+use crate::error::{DataFusionError, Result};
+use crate::execution::memory_management::allocation_strategist::{
+ DummyAllocationStrategist, FairStrategist, MemoryAllocationStrategist,
+};
+use async_trait::async_trait;
+use futures::lock::Mutex;
+use hashbrown::HashMap;
+use log::{debug, info};
+use std::fmt;
+use std::fmt::{Debug, Display, Formatter};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, Weak};
+
+static CONSUMER_ID: AtomicUsize = AtomicUsize::new(0);
+
+#[derive(Clone)]
+/// Memory manager that enforces how execution memory is shared between all kinds of memory consumers.
+/// Execution memory refers to that used for computation in sorts, aggregations, joins and shuffles.
+pub struct MemoryManager {
+ strategist: Arc<dyn MemoryAllocationStrategist>,
+ partition_memory_manager: Arc<Mutex<HashMap<usize, PartitionMemoryManager>>>,
+}
+
+impl MemoryManager {
+ /// Create memory manager based on configured execution pool size.
+ pub fn new(exec_pool_size: usize) -> Self {
+ let strategist: Arc<dyn MemoryAllocationStrategist> =
+ if exec_pool_size == usize::MAX {
+ Arc::new(DummyAllocationStrategist::new())
+ } else {
+ Arc::new(FairStrategist::new(exec_pool_size))
+ };
+ Self {
+ strategist,
+ partition_memory_manager: Arc::new(Mutex::new(HashMap::new())),
+ }
+ }
+
+ /// Acquire size of `required` memory from manager
+ pub async fn acquire_exec_memory(
+ self: &Arc<Self>,
+ required: usize,
+ consumer_id: &MemoryConsumerId,
+ ) -> Result<usize> {
+ let partition_id = consumer_id.partition_id;
+ let mut all_managers = self.partition_memory_manager.lock().await;
+ let partition_manager = all_managers
+ .entry(partition_id)
+ .or_insert_with(|| PartitionMemoryManager::new(partition_id, self.clone()));
+ partition_manager
+ .acquire_exec_memory(required, consumer_id)
+ .await
+ }
+
+ /// Register a consumer with the manager, to enable memory tracking and
+ /// spilling by memory used.
+ pub async fn register_consumer(self: &Arc<Self>, consumer: Arc<dyn MemoryConsumer>) {
+ let partition_id = consumer.partition_id();
+ let mut all_managers = self.partition_memory_manager.lock().await;
+ let partition_manager = all_managers
+ .entry(partition_id)
+ .or_insert_with(|| PartitionMemoryManager::new(partition_id, self.clone()));
+ partition_manager.register_consumer(consumer).await;
+ }
+
+ pub(crate) async fn acquire_exec_pool_memory(
+ &self,
+ required: usize,
+ consumer: &MemoryConsumerId,
+ ) -> usize {
+ self.strategist
+ .acquire_memory(required, consumer.partition_id)
+ .await
+ }
+
+ pub(crate) async fn release_exec_pool_memory(
+ &self,
+ release_size: usize,
+ partition_id: usize,
+ ) {
+ self.strategist
+ .release_memory(release_size, partition_id)
+ .await
+ }
+
+ /// Revise pool usage while handling variable length data structure.
+ /// In this case, we may estimate and allocate in advance, and revise the usage
+ /// after the construction of the data structure.
+ #[allow(dead_code)]
+ pub(crate) async fn update_exec_pool_usage(
+ &self,
+ granted_size: usize,
+ real_size: usize,
+ consumer: &MemoryConsumerId,
+ ) {
+ self.strategist
+ .update_usage(granted_size, real_size, consumer.partition_id)
+ .await
+ }
+
+ /// Called during the shutdown procedure of a partition, for memory reclamation.
+ #[allow(dead_code)]
+ pub(crate) async fn release_all_exec_pool_for_partition(
+ &self,
+ partition_id: usize,
+ ) -> usize {
+ self.strategist.release_all(partition_id).await
+ }
+
+ #[allow(dead_code)]
+ pub(crate) fn exec_memory_used(&self) -> usize {
+ self.strategist.memory_used()
+ }
+
+ pub(crate) fn exec_memory_used_for_partition(&self, partition_id: usize) -> usize {
+ self.strategist.memory_used_partition(partition_id)
+ }
+}
+
+fn next_id() -> usize {
+ CONSUMER_ID.fetch_add(1, Ordering::SeqCst)
+}
+
+/// Memory manager that tracks all consumers for a specific partition,
+/// and triggers spills on consumer(s) when memory is insufficient
+pub struct PartitionMemoryManager {
+ memory_manager: Weak<MemoryManager>,
+ partition_id: usize,
+ consumers: Mutex<HashMap<MemoryConsumerId, Arc<dyn MemoryConsumer>>>,
+}
+
+impl PartitionMemoryManager {
+ /// Create manager for a partition
+ pub fn new(partition_id: usize, memory_manager: Arc<MemoryManager>) -> Self {
+ Self {
+ memory_manager: Arc::downgrade(&memory_manager),
+ partition_id,
+ consumers: Mutex::new(HashMap::new()),
+ }
+ }
+
+ /// Register a memory consumer at its first appearance
+ pub async fn register_consumer(&self, consumer: Arc<dyn MemoryConsumer>) {
+ let mut consumers = self.consumers.lock().await;
+ let id = consumer.id().clone();
+ consumers.insert(id, consumer);
+ }
+
+ /// Try to acquire `required` bytes of execution memory for the consumer and return the number of bytes
+ /// obtained, or return `ResourcesExhausted` if not enough memory is available even after possible spills.
+ pub async fn acquire_exec_memory(
+ &self,
+ required: usize,
+ consumer_id: &MemoryConsumerId,
+ ) -> Result<usize> {
+ let mut consumers = self.consumers.lock().await;
+ let memory_manager = self.memory_manager.upgrade().ok_or_else(|| {
+ DataFusionError::Execution("Failed to get MemoryManager".to_string())
+ })?;
+ let mut got = memory_manager
+ .acquire_exec_pool_memory(required, consumer_id)
+ .await;
+ if got < required {
+ // Try to release memory from other consumers first
+ // Sort the consumers according to their memory usage and spill from
+ // consumer that holds the maximum memory, to reduce the total frequency of
+ // spilling
+
+ let mut all_consumers: Vec<Arc<dyn MemoryConsumer>> = vec![];
+ for c in consumers.iter() {
+ all_consumers.push(c.1.clone());
+ }
+ all_consumers.sort_by_key(|b| Reverse(b.get_used()));
+
+ for c in all_consumers.iter_mut() {
+ if c.id() == consumer_id {
+ continue;
+ }
+
+ let released = c.spill(required - got, consumer_id).await?;
+ if released > 0 {
+ debug!(
+ "Partition {} released {} from consumer {}",
+ self.partition_id,
+ released,
+ c.str_repr()
+ );
+ got += memory_manager
+ .acquire_exec_pool_memory(required - got, consumer_id)
+ .await;
+ if got >= required {
+ break;
+ }
+ }
+ }
+ }
+
+ if got < required {
+ // spill itself
+ let consumer = consumers.get_mut(consumer_id).unwrap();
+ let released = consumer.spill(required - got, consumer_id).await?;
+ if released > 0 {
+ debug!(
+ "Partition {} released {} from consumer itself {}",
+ self.partition_id,
+ released,
+ consumer.str_repr()
+ );
+ got += memory_manager
+ .acquire_exec_pool_memory(required - got, consumer_id)
+ .await;
+ }
+ }
+
+ if got < required {
+ return Err(ResourcesExhausted(format!(
+ "Unable to acquire {} bytes of memory, got {}",
+ required, got
+ )));
+ }
+
+ debug!("{} acquired {}", consumer_id, got);
+ Ok(got)
+ }
+
+ /// Log current memory usage for all consumers in this partition
+ pub async fn show_memory_usage(&self) -> Result<()> {
+ info!("Memory usage for partition {}", self.partition_id);
+ let consumers = self.consumers.lock().await;
+ let mut used = 0;
+ for (_, c) in consumers.iter() {
+ let cur_used = c.get_used();
+ used += cur_used;
+ if cur_used > 0 {
+ info!(
+ "Consumer {} acquired {}",
+ c.str_repr(),
+ human_readable_size(cur_used as usize)
+ )
+ }
+ }
+ let no_consumer_size = self
+ .memory_manager
+ .upgrade()
+ .ok_or_else(|| {
+ DataFusionError::Execution("Failed to get MemoryManager".to_string())
+ })?
+ .exec_memory_used_for_partition(self.partition_id)
+ - (used as usize);
+ info!(
+ "{} bytes of memory were used for partition {} without specific consumer",
+ human_readable_size(no_consumer_size),
+ self.partition_id
+ );
+ Ok(())
+ }
+}
+
+#[derive(Clone, Debug, Hash, Eq, PartialEq)]
+/// Id that uniquely identifies a Memory Consumer
+pub struct MemoryConsumerId {
+ /// partition the consumer belongs to
+ pub partition_id: usize,
+ /// unique id
+ pub id: usize,
+}
+
+impl MemoryConsumerId {
+ /// Auto incremented new Id
+ pub fn new(partition_id: usize) -> Self {
+ let id = next_id();
+ Self { partition_id, id }
+ }
+}
+
+impl Display for MemoryConsumerId {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ write!(f, "{}:{}", self.partition_id, self.id)
+ }
+}
+
+#[async_trait]
+/// A memory consumer that supports spilling.
+pub trait MemoryConsumer: Send + Sync + Debug {
+ /// Display name of the consumer
+ fn name(&self) -> String;
+
+ /// Unique id of the consumer
+ fn id(&self) -> &MemoryConsumerId;
+
+ /// Pointer to the `MemoryManager`
+ fn memory_manager(&self) -> Arc<MemoryManager>;
+
+ /// Partition that the consumer belongs to
+ fn partition_id(&self) -> usize {
+ self.id().partition_id
+ }
+
+ /// Try to allocate `required` bytes as needed
+ async fn allocate(&self, required: usize) -> Result<()> {
+ let got = self
+ .memory_manager()
+ .acquire_exec_memory(required, self.id())
+ .await?;
+ self.update_used(got as isize);
+ Ok(())
+ }
+
+ /// Spill at least `size` bytes to disk and update related counters
+ async fn spill(&self, size: usize, trigger: &MemoryConsumerId) -> Result<usize> {
+ let released = self.spill_inner(size, trigger).await?;
+ if released > 0 {
+ self.memory_manager()
+ .release_exec_pool_memory(released, self.id().partition_id)
+ .await;
+ self.update_used(-(released as isize));
+ self.spilled_bytes_add(released);
+ self.spilled_count_increment();
+ }
+ Ok(released)
+ }
+
+ /// Spill at least `size` bytes to disk and frees memory
+ async fn spill_inner(&self, size: usize, trigger: &MemoryConsumerId)
+ -> Result<usize>;
+
+ /// Get current memory usage for the consumer itself
+ fn get_used(&self) -> isize;
+
+ /// Update memory usage
+ fn update_used(&self, delta: isize);
+
+ /// Get total number of spilled bytes so far
+ fn spilled_bytes(&self) -> usize;
+
+ /// Update spilled bytes counter
+ fn spilled_bytes_add(&self, add: usize);
+
+ /// Get total number of triggered spills so far
+ fn spilled_count(&self) -> usize;
+
+ /// Update spilled count
+ fn spilled_count_increment(&self);
+
+ /// String representation for the consumer
+ fn str_repr(&self) -> String {
+ format!("{}({})", self.name(), self.id())
+ }
+
+ #[inline]
+ /// Log during spilling
+ fn log_spill(&self, size: usize) {
+ info!(
+ "{} spilling of {} bytes to disk ({} times so far)",
+ self.str_repr(),
+ size,
+ self.spilled_count()
+ );
+ }
+}
+
+const TB: u64 = 1 << 40;
+const GB: u64 = 1 << 30;
+const MB: u64 = 1 << 20;
+const KB: u64 = 1 << 10;
+
+fn human_readable_size(size: usize) -> String {
+ let size = size as u64;
+ let (value, unit) = {
+ if size >= 2 * TB {
+ (size as f64 / TB as f64, "TB")
+ } else if size >= 2 * GB {
+ (size as f64 / GB as f64, "GB")
+ } else if size >= 2 * MB {
+ (size as f64 / MB as f64, "MB")
+ } else if size >= 2 * KB {
+ (size as f64 / KB as f64, "KB")
+ } else {
+ (size as f64, "B")
+ }
+ };
+ format!("{:.1} {}", value, unit)
+}
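
For reference, a minimal consumer sketch against the trait above; `SketchConsumer` is hypothetical, and `ExternalSorter` later in this patch is the real implementation:

```rust
use async_trait::async_trait;
use datafusion::error::Result;
use datafusion::execution::memory_management::{
    MemoryConsumer, MemoryConsumerId, MemoryManager,
};
use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
use std::sync::Arc;

#[derive(Debug)]
struct SketchConsumer {
    id: MemoryConsumerId,
    manager: Arc<MemoryManager>,
    used: AtomicIsize,
    spilled_bytes: AtomicUsize,
    spilled_count: AtomicUsize,
}

#[async_trait]
impl MemoryConsumer for SketchConsumer {
    fn name(&self) -> String {
        "SketchConsumer".to_owned()
    }
    fn id(&self) -> &MemoryConsumerId {
        &self.id
    }
    fn memory_manager(&self) -> Arc<MemoryManager> {
        self.manager.clone()
    }
    async fn spill_inner(&self, size: usize, _trigger: &MemoryConsumerId) -> Result<usize> {
        // pretend everything we hold can be freed, up to the requested size
        Ok(size.min(self.get_used().max(0) as usize))
    }
    fn get_used(&self) -> isize {
        self.used.load(Ordering::SeqCst)
    }
    fn update_used(&self, delta: isize) {
        self.used.fetch_add(delta, Ordering::SeqCst);
    }
    fn spilled_bytes(&self) -> usize {
        self.spilled_bytes.load(Ordering::SeqCst)
    }
    fn spilled_bytes_add(&self, add: usize) {
        self.spilled_bytes.fetch_add(add, Ordering::SeqCst);
    }
    fn spilled_count(&self) -> usize {
        self.spilled_count.load(Ordering::SeqCst)
    }
    fn spilled_count_increment(&self) {
        self.spilled_count.fetch_add(1, Ordering::SeqCst);
    }
}
```

With this in place, `consumer.allocate(n)` asks the manager for `n` bytes, and the manager may call back into `spill` on the largest consumers of the same partition first.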
diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs
index e353a3160b8d..c7929c179972 100644
--- a/datafusion/src/execution/mod.rs
+++ b/datafusion/src/execution/mod.rs
@@ -19,4 +19,7 @@
pub mod context;
pub mod dataframe_impl;
+pub mod disk_manager;
+pub mod memory_management;
pub mod options;
+pub mod runtime_env;
diff --git a/datafusion/src/execution/runtime_env.rs b/datafusion/src/execution/runtime_env.rs
new file mode 100644
index 000000000000..ae64a3733c11
--- /dev/null
+++ b/datafusion/src/execution/runtime_env.rs
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution runtime environment that tracks memory, disk and various configurations
+//! that are used during physical plan execution.
+
+use crate::error::Result;
+use crate::execution::disk_manager::DiskManager;
+use crate::execution::memory_management::{MemoryConsumer, MemoryManager};
+use lazy_static::lazy_static;
+use std::sync::Arc;
+
+lazy_static! {
+ /// Employ lazy_static temporarily for RuntimeEnv, to avoid plumbing it through
+ /// all `async fn execute(&self, partition: usize, runtime: Arc<RuntimeEnv>)`
+ pub static ref RUNTIME_ENV: Arc<RuntimeEnv> = {
+ let config = RuntimeConfig::new();
+ Arc::new(RuntimeEnv::new(config).unwrap())
+ };
+}
+
+#[derive(Clone)]
+/// Execution runtime environment
+pub struct RuntimeEnv {
+ /// Runtime configuration
+ pub config: RuntimeConfig,
+ /// Runtime memory management
+ pub memory_manager: Arc<MemoryManager>,
+ /// Manage temporary files during query execution
+ pub disk_manager: Arc<DiskManager>,
+}
+
+impl RuntimeEnv {
+ /// Create env based on configuration
+ pub fn new(config: RuntimeConfig) -> Result<Self> {
+ let memory_manager = Arc::new(MemoryManager::new(config.max_memory));
+ let disk_manager = Arc::new(DiskManager::new(&config.local_dirs)?);
+ Ok(Self {
+ config,
+ memory_manager,
+ disk_manager,
+ })
+ }
+
+ /// Get execution batch size based on config
+ pub fn batch_size(&self) -> usize {
+ self.config.batch_size
+ }
+
+ /// Register the consumer to get it tracked
+ pub async fn register_consumer(&self, memory_consumer: Arc<dyn MemoryConsumer>) {
+ self.memory_manager.register_consumer(memory_consumer).await;
+ }
+}
+
+#[derive(Clone)]
+/// Execution runtime configuration
+pub struct RuntimeConfig {
+ /// Default batch size when creating new batches
+ pub batch_size: usize,
+ /// Max execution memory allowed for DataFusion
+ pub max_memory: usize,
+ /// Local dirs to store temporary files during execution
+ pub local_dirs: Vec<String>,
+}
+
+impl RuntimeConfig {
+ /// New with default values
+ pub fn new() -> Self {
+ Default::default()
+ }
+
+ /// Customize batch size
+ pub fn with_batch_size(mut self, n: usize) -> Self {
+ // batch size must be greater than zero
+ assert!(n > 0);
+ self.batch_size = n;
+ self
+ }
+
+ /// Customize max execution memory
+ pub fn with_max_execution_memory(mut self, max_memory: usize) -> Self {
+ assert!(max_memory > 0);
+ self.max_memory = max_memory;
+ self
+ }
+
+ /// Customize the directories used for temporary files
+ pub fn with_local_dirs(mut self, local_dirs: Vec<String>) -> Self {
+ assert!(!local_dirs.is_empty());
+ self.local_dirs = local_dirs;
+ self
+ }
+}
+
+impl Default for RuntimeConfig {
+ fn default() -> Self {
+ let tmp_dir = tempfile::tempdir().unwrap();
+ let path = tmp_dir.path().to_str().unwrap().to_string();
+ // leak the TempDir on purpose so the directory outlives this scope
+ std::mem::forget(tmp_dir);
+
+ Self {
+ batch_size: 8192,
+ max_memory: usize::MAX,
+ local_dirs: vec![path],
+ }
+ }
+}
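
A usage sketch of the builder API (the spill directory is hypothetical and must already exist):

```rust
use datafusion::error::Result;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};

fn demo() -> Result<()> {
    let config = RuntimeConfig::new()
        .with_batch_size(4096)
        .with_max_execution_memory(2 << 30) // 2 GiB execution pool
        .with_local_dirs(vec!["/tmp/df-spill".to_owned()]);
    let runtime = RuntimeEnv::new(config)?;
    assert_eq!(runtime.batch_size(), 4096);
    Ok(())
}
```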
diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs
index 94d53438e736..9099dc50251e 100644
--- a/datafusion/src/physical_plan/common.rs
+++ b/datafusion/src/physical_plan/common.rs
@@ -25,12 +25,14 @@ use arrow::compute::concatenate;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::error::Result as ArrowResult;
+use arrow::compute::aggregate::estimated_bytes_size;
+use arrow::io::ipc::write::{FileWriter, WriteOptions};
use arrow::record_batch::RecordBatch;
use futures::channel::mpsc;
use futures::{Future, SinkExt, Stream, StreamExt, TryStreamExt};
use pin_project_lite::pin_project;
use std::fs;
-use std::fs::metadata;
+use std::fs::{metadata, File};
+use std::io::BufWriter;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::task::JoinHandle;
@@ -275,6 +277,63 @@ impl Drop for AbortOnDropMany {
}
}
+/// Write in Arrow IPC format.
+pub struct IPCWriterWrapper {
+ /// path
+ pub path: String,
+ /// Inner writer
+ pub writer: FileWriter<BufWriter<File>>,
+ /// batches written
+ pub num_batches: u64,
+ /// rows written
+ pub num_rows: u64,
+ /// bytes written
+ pub num_bytes: u64,
+}
+
+impl IPCWriterWrapper {
+ /// Create new writer
+ pub fn new(path: &str, schema: &Schema) -> Result<Self> {
+ let file = File::create(path).map_err(DataFusionError::IoError)?;
+ let buffer_writer = std::io::BufWriter::new(file);
+ Ok(Self {
+ num_batches: 0,
+ num_rows: 0,
+ num_bytes: 0,
+ path: path.to_owned(),
+ writer: FileWriter::try_new(buffer_writer, schema, WriteOptions::default())?,
+ })
+ }
+
+ /// Write one single batch
+ pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
+ self.writer.write(batch)?;
+ self.num_batches += 1;
+ self.num_rows += batch.num_rows() as u64;
+ let num_bytes: usize = batch_memory_size(batch);
+ self.num_bytes += num_bytes as u64;
+ Ok(())
+ }
+
+ /// Finish the writer
+ pub fn finish(&mut self) -> Result<()> {
+ self.writer.finish().map_err(DataFusionError::ArrowError)
+ }
+
+ /// Path written to
+ pub fn path(&self) -> &str {
+ &self.path
+ }
+}
+
+/// Estimate batch memory footprint
+pub fn batch_memory_size(rb: &RecordBatch) -> usize {
+ rb.columns()
+ .iter()
+ .map(|c| estimated_bytes_size(c.as_ref()))
+ .sum()
+}
+
#[cfg(test)]
mod tests {
use super::*;
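
A minimal sketch of the extracted writer in use (the `write_all` helper is hypothetical; it assumes a non-empty batch slice):

```rust
use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::physical_plan::common::IPCWriterWrapper;

fn write_all(path: &str, batches: &[RecordBatch]) -> Result<()> {
    let mut writer = IPCWriterWrapper::new(path, batches[0].schema().as_ref())?;
    for batch in batches {
        // num_batches / num_rows / num_bytes counters update on every write
        writer.write(batch)?;
    }
    writer.finish()?;
    println!(
        "wrote {} batches, {} rows, ~{} bytes to {}",
        writer.num_batches, writer.num_rows, writer.num_bytes, writer.path()
    );
    Ok(())
}
```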
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index 932c76bf894f..417b4695af33 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -34,6 +34,7 @@ use crate::physical_plan::{
};
use crate::{
error::{DataFusionError, Result},
+ execution::memory_management::MemoryConsumerId,
scalar::ScalarValue,
};
@@ -213,8 +214,11 @@ impl ExecutionPlan for HashAggregateExec {
let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+ let streamer_id = MemoryConsumerId::new(partition);
+
if self.group_expr.is_empty() {
Ok(Box::pin(HashAggregateStream::new(
+ streamer_id,
self.mode,
self.schema.clone(),
self.aggr_expr.clone(),
@@ -735,6 +739,7 @@ pin_project! {
/// Special case aggregate with no groups
async fn compute_hash_aggregate(
+ _id: MemoryConsumerId,
mode: AggregateMode,
schema: SchemaRef,
aggr_expr: Vec>,
@@ -771,6 +776,7 @@ async fn compute_hash_aggregate(
impl HashAggregateStream {
/// Create a new HashAggregateStream
pub fn new(
+ id: MemoryConsumerId,
mode: AggregateMode,
schema: SchemaRef,
aggr_expr: Vec>,
@@ -783,6 +789,7 @@ impl HashAggregateStream {
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
let join_handle = tokio::spawn(async move {
let result = compute_hash_aggregate(
+ id,
mode,
schema_clone,
aggr_expr,
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index 769e88bad5a9..277d3f00c6a6 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -641,8 +641,7 @@ pub mod projection;
#[cfg(feature = "regex_expressions")]
pub mod regex_expressions;
pub mod repartition;
-pub mod sort;
-pub mod sort_preserving_merge;
+pub mod sorts;
pub mod stream;
pub mod string_expressions;
pub mod type_coercion;
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 86490b786b06..784c0161bfe1 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -45,7 +45,7 @@ use crate::physical_plan::hash_join::HashJoinExec;
use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::repartition::RepartitionExec;
-use crate::physical_plan::sort::SortExec;
+use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::udf;
use crate::physical_plan::windows::WindowAggExec;
use crate::physical_plan::{join_utils, Partitioning};
diff --git a/datafusion/src/physical_plan/sorts/external_sort.rs b/datafusion/src/physical_plan/sorts/external_sort.rs
new file mode 100644
index 000000000000..2dce542f922a
--- /dev/null
+++ b/datafusion/src/physical_plan/sorts/external_sort.rs
@@ -0,0 +1,711 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the External-Sort plan
+
+use crate::error::{DataFusionError, Result};
+use crate::execution::memory_management::{
+ MemoryConsumer, MemoryConsumerId, MemoryManager,
+};
+use crate::execution::runtime_env::RuntimeEnv;
+use crate::execution::runtime_env::RUNTIME_ENV;
+use crate::physical_plan::common::{
+ batch_memory_size, IPCWriterWrapper, SizedRecordBatchStream,
+};
+use crate::physical_plan::expressions::PhysicalSortExpr;
+use crate::physical_plan::metrics::{
+ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+};
+use crate::physical_plan::sorts::in_mem_sort::InMemSortStream;
+use crate::physical_plan::sorts::sort::sort_batch;
+use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream;
+use crate::physical_plan::sorts::SpillableStream;
+use crate::physical_plan::stream::RecordBatchReceiverStream;
+use crate::physical_plan::{
+ DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
+ SendableRecordBatchStream, Statistics,
+};
+use arrow::datatypes::SchemaRef;
+use arrow::error::Result as ArrowResult;
+use arrow::io::ipc::read::{read_file_metadata, FileReader};
+use arrow::record_batch::RecordBatch;
+use async_trait::async_trait;
+use futures::lock::Mutex;
+use futures::StreamExt;
+use log::{error, info};
+use std::any::Any;
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::fs::File;
+use std::io::BufReader;
+use std::sync::atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering};
+use std::sync::Arc;
+use tokio::sync::mpsc::{Receiver as TKReceiver, Sender as TKSender};
+use tokio::task;
+
+struct ExternalSorter {
+ id: MemoryConsumerId,
+ schema: SchemaRef,
+ in_mem_batches: Mutex<Vec<RecordBatch>>,
+ spills: Mutex<Vec<String>>,
+ /// Sort expressions
+ expr: Vec<PhysicalSortExpr>,
+ runtime: Arc<RuntimeEnv>,
+ metrics: ExecutionPlanMetricsSet,
+ used: AtomicIsize,
+ spilled_bytes: AtomicUsize,
+ spilled_count: AtomicUsize,
+ insert_finished: AtomicBool,
+}
+
+impl ExternalSorter {
+ pub fn new(
+ partition_id: usize,
+ schema: SchemaRef,
+ expr: Vec<PhysicalSortExpr>,
+ runtime: Arc<RuntimeEnv>,
+ ) -> Self {
+ Self {
+ id: MemoryConsumerId::new(partition_id),
+ schema,
+ in_mem_batches: Mutex::new(vec![]),
+ spills: Mutex::new(vec![]),
+ expr,
+ runtime,
+ metrics: ExecutionPlanMetricsSet::new(),
+ used: AtomicIsize::new(0),
+ spilled_bytes: AtomicUsize::new(0),
+ spilled_count: AtomicUsize::new(0),
+ insert_finished: AtomicBool::new(false),
+ }
+ }
+
+ pub(crate) fn finish_insert(&self) {
+ self.insert_finished.store(true, Ordering::SeqCst);
+ }
+
+ async fn spill_while_inserting(&self) -> Result<usize> {
+ info!(
+ "{} spilling sort data of {} to disk while inserting ({} time(s) so far)",
+ self.str_repr(),
+ self.get_used(),
+ self.spilled_count()
+ );
+
+ let partition = self.partition_id();
+ let mut in_mem_batches = self.in_mem_batches.lock().await;
+ // we can always free some memory as long as we are holding some batches
+ if in_mem_batches.is_empty() {
+ return Ok(0);
+ }
+
+ let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+
+ let path = self.runtime.disk_manager.create_tmp_file()?;
+ let stream = in_mem_merge_sort(
+ &mut *in_mem_batches,
+ self.schema.clone(),
+ &*self.expr,
+ self.runtime.batch_size(),
+ baseline_metrics,
+ )
+ .await;
+
+ let total_size = spill(&mut stream?, path.clone(), self.schema.clone()).await?;
+
+ let mut spills = self.spills.lock().await;
+ self.spilled_count.fetch_add(1, Ordering::SeqCst);
+ self.spilled_bytes.fetch_add(total_size, Ordering::SeqCst);
+ spills.push(path);
+ Ok(total_size)
+ }
+
+ async fn insert_batch(&self, input: RecordBatch) -> Result<()> {
+ let size = batch_memory_size(&input);
+ self.allocate(size).await?;
+ // sort each batch as it's inserted, more probably to be cache-resident
+ let sorted_batch = sort_batch(input, self.schema.clone(), &*self.expr)?;
+ let mut in_mem_batches = self.in_mem_batches.lock().await;
+ in_mem_batches.push(sorted_batch);
+ Ok(())
+ }
+
+ /// Merge-sort the in-memory batches as well as the spills into a total order with
+ /// `SortPreservingMergeStream` (SPMS). The in-memory-batch stream always goes to
+ /// idx 0 in the SPMS so that we can still spill it when `spill()` is called on us.
+ async fn sort(&self) -> Result<SendableRecordBatchStream> {
+ let partition = self.partition_id();
+ let mut in_mem_batches = self.in_mem_batches.lock().await;
+ let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+ let mut streams: Vec<SpillableStream> = vec![];
+ let in_mem_stream = in_mem_merge_sort(
+ &mut *in_mem_batches,
+ self.schema.clone(),
+ &self.expr,
+ self.runtime.batch_size(),
+ baseline_metrics,
+ )
+ .await?;
+ streams.push(SpillableStream::new_spillable(in_mem_stream));
+
+ let mut spills = self.spills.lock().await;
+
+ for spill in spills.drain(..) {
+ let stream = read_spill_as_stream(spill, self.schema.clone()).await?;
+ streams.push(SpillableStream::new_unspillable(stream));
+ }
+ let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+
+ Ok(Box::pin(
+ SortPreservingMergeStream::new_from_stream(
+ streams,
+ self.schema.clone(),
+ &self.expr,
+ self.runtime.batch_size(),
+ baseline_metrics,
+ partition,
+ self.runtime.clone(),
+ )
+ .await,
+ ))
+ }
+}
+
+impl Debug for ExternalSorter {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ExternalSorter")
+ .field("id", &self.id())
+ .field("memory_used", &self.get_used())
+ .field("spilled_bytes", &self.spilled_bytes())
+ .field("spilled_count", &self.spilled_count())
+ .finish()
+ }
+}
+
+#[async_trait]
+impl MemoryConsumer for ExternalSorter {
+ fn name(&self) -> String {
+ "ExternalSorter".to_owned()
+ }
+
+ fn id(&self) -> &MemoryConsumerId {
+ &self.id
+ }
+
+ fn memory_manager(&self) -> Arc<MemoryManager> {
+ self.runtime.memory_manager.clone()
+ }
+
+ async fn spill_inner(
+ &self,
+ _size: usize,
+ _trigger: &MemoryConsumerId,
+ ) -> Result<usize> {
+ if !self.insert_finished.load(Ordering::SeqCst) {
+ self.spill_while_inserting().await
+ } else {
+ Ok(0)
+ }
+ }
+
+ fn get_used(&self) -> isize {
+ self.used.load(Ordering::SeqCst)
+ }
+
+ fn update_used(&self, delta: isize) {
+ self.used.fetch_add(delta, Ordering::SeqCst);
+ }
+
+ fn spilled_bytes(&self) -> usize {
+ self.spilled_bytes.load(Ordering::SeqCst)
+ }
+
+ fn spilled_bytes_add(&self, add: usize) {
+ self.spilled_bytes.fetch_add(add, Ordering::SeqCst);
+ }
+
+ fn spilled_count(&self) -> usize {
+ self.spilled_count.load(Ordering::SeqCst)
+ }
+
+ fn spilled_count_increment(&self) {
+ self.spilled_count.fetch_add(1, Ordering::SeqCst);
+ }
+}
+
+/// Consume the `sorted_batches` and merge-sort them into a single stream
+async fn in_mem_merge_sort(
+ sorted_batches: &mut Vec<RecordBatch>,
+ schema: SchemaRef,
+ expressions: &[PhysicalSortExpr],
+ target_batch_size: usize,
+ baseline_metrics: BaselineMetrics,
+) -> Result<SendableRecordBatchStream> {
+ if sorted_batches.len() == 1 {
+ Ok(Box::pin(SizedRecordBatchStream::new(
+ schema,
+ vec![Arc::new(sorted_batches.pop().unwrap())],
+ )))
+ } else {
+ let new = sorted_batches.drain(..).collect();
+ assert_eq!(sorted_batches.len(), 0);
+ Ok(Box::pin(InMemSortStream::new(
+ new,
+ schema,
+ expressions,
+ target_batch_size,
+ baseline_metrics,
+ )?))
+ }
+}
+
+async fn spill(
+ in_mem_stream: &mut SendableRecordBatchStream,
+ path: String,
+ schema: SchemaRef,
+) -> Result<usize> {
+ let (sender, receiver): (
+ TKSender<ArrowResult<RecordBatch>>,
+ TKReceiver<ArrowResult<RecordBatch>>,
+ ) = tokio::sync::mpsc::channel(2);
+ while let Some(item) = in_mem_stream.next().await {
+ sender.send(item).await.ok();
+ }
+ let path_clone = path.clone();
+ let res =
+ task::spawn_blocking(move || write_sorted(receiver, path_clone, schema)).await;
+ match res {
+ Ok(r) => r,
+ Err(e) => Err(DataFusionError::Execution(format!(
+ "Error occurred while spilling {}",
+ e
+ ))),
+ }
+}
+
+async fn read_spill_as_stream(
+ path: String,
+ schema: SchemaRef,
+) -> Result<SendableRecordBatchStream> {
+ let (sender, receiver): (
+ TKSender<ArrowResult<RecordBatch>>,
+ TKReceiver<ArrowResult<RecordBatch>>,
+ ) = tokio::sync::mpsc::channel(2);
+ let path_clone = path.clone();
+ let join_handle = task::spawn_blocking(move || {
+ if let Err(e) = read_spill(sender, path_clone) {
+ error!("Failure while reading spill file: {}. Error: {}", path, e);
+ }
+ });
+ Ok(RecordBatchReceiverStream::create(
+ &schema,
+ receiver,
+ join_handle,
+ ))
+}
+
+pub(crate) async fn convert_stream_disk_based(
+ in_mem_stream: &mut SendableRecordBatchStream,
+ path: String,
+ schema: SchemaRef,
+) -> Result<(SendableRecordBatchStream, usize)> {
+ let size = spill(in_mem_stream, path.clone(), schema.clone()).await?;
+ read_spill_as_stream(path.clone(), schema.clone())
+ .await
+ .map(|s| (s, size))
+}
+
+fn write_sorted(
+ mut receiver: TKReceiver<ArrowResult<RecordBatch>>,
+ path: String,
+ schema: SchemaRef,
+) -> Result<usize> {
+ let mut writer = IPCWriterWrapper::new(path.as_ref(), schema.as_ref())?;
+ while let Some(batch) = receiver.blocking_recv() {
+ writer.write(&batch?)?;
+ }
+ writer.finish()?;
+ info!(
+ "Spilled {} batches of total {} rows to disk, memory released {}",
+ writer.num_batches, writer.num_rows, writer.num_bytes
+ );
+ Ok(writer.num_bytes as usize)
+}
+
+fn read_spill(sender: TKSender<ArrowResult<RecordBatch>>, path: String) -> Result<()> {
+ let mut file = BufReader::new(File::open(&path)?);
+ let file_meta = read_file_metadata(&mut file)?;
+ let reader = FileReader::new(&mut file, file_meta, None);
+ for batch in reader {
+ sender
+ .blocking_send(batch)
+ .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
+ }
+ Ok(())
+}
+
+/// Sort execution plan
+#[derive(Debug)]
+pub struct ExternalSortExec {
+ /// Input plan
+ input: Arc<dyn ExecutionPlan>,
+ /// Sort expressions
+ expr: Vec<PhysicalSortExpr>,
+ /// Execution metrics
+ metrics: ExecutionPlanMetricsSet,
+ /// Preserve partitions of input plan
+ preserve_partitioning: bool,
+}
+
+impl ExternalSortExec {
+ /// Create a new sort execution plan
+ pub fn try_new(
+ expr: Vec<PhysicalSortExpr>,
+ input: Arc<dyn ExecutionPlan>,
+ ) -> Result<Self> {
+ Ok(Self::new_with_partitioning(expr, input, false))
+ }
+
+ /// Create a new sort execution plan with the option to preserve
+ /// the partitioning of the input plan
+ pub fn new_with_partitioning(
+ expr: Vec<PhysicalSortExpr>,
+ input: Arc<dyn ExecutionPlan>,
+ preserve_partitioning: bool,
+ ) -> Self {
+ Self {
+ expr,
+ input,
+ metrics: ExecutionPlanMetricsSet::new(),
+ preserve_partitioning,
+ }
+ }
+
+ /// Input plan
+ pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+ &self.input
+ }
+
+ /// Sort expressions
+ pub fn expr(&self) -> &[PhysicalSortExpr] {
+ &self.expr
+ }
+}
+
+#[async_trait]
+impl ExecutionPlan for ExternalSortExec {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ self.input.schema()
+ }
+
+ fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+ vec![self.input.clone()]
+ }
+
+ /// Get the output partitioning of this plan
+ fn output_partitioning(&self) -> Partitioning {
+ if self.preserve_partitioning {
+ self.input.output_partitioning()
+ } else {
+ Partitioning::UnknownPartitioning(1)
+ }
+ }
+
+ fn required_child_distribution(&self) -> Distribution {
+ if self.preserve_partitioning {
+ Distribution::UnspecifiedDistribution
+ } else {
+ Distribution::SinglePartition
+ }
+ }
+
+ fn with_new_children(
+ &self,
+ children: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ match children.len() {
+ 1 => Ok(Arc::new(ExternalSortExec::try_new(
+ self.expr.clone(),
+ children[0].clone(),
+ )?)),
+ _ => Err(DataFusionError::Internal(
+ "ExternalSortExec wrong number of children".to_string(),
+ )),
+ }
+ }
+
+ async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+ if !self.preserve_partitioning {
+ if 0 != partition {
+ return Err(DataFusionError::Internal(format!(
+ "ExternalSortExec invalid partition {}",
+ partition
+ )));
+ }
+
+ // sort needs to operate on a single partition currently
+ if 1 != self.input.output_partitioning().partition_count() {
+ return Err(DataFusionError::Internal(
+ "SortExec requires a single input partition".to_owned(),
+ ));
+ }
+ }
+
+ let input = self.input.execute(partition).await?;
+ external_sort(input, partition, self.expr.clone(), RUNTIME_ENV.clone()).await
+ }
+
+ fn fmt_as(
+ &self,
+ t: DisplayFormatType,
+ f: &mut std::fmt::Formatter,
+ ) -> std::fmt::Result {
+ match t {
+ DisplayFormatType::Default => {
+ let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
+ write!(f, "SortExec: [{}]", expr.join(","))
+ }
+ }
+ }
+
+ fn metrics(&self) -> Option<MetricsSet> {
+ Some(self.metrics.clone_inner())
+ }
+
+ fn statistics(&self) -> Statistics {
+ self.input.statistics()
+ }
+}
+
+/// Sort based on `ExternalSorter`
+pub async fn external_sort(
+ mut input: SendableRecordBatchStream,
+ partition_id: usize,
+ expr: Vec<PhysicalSortExpr>,
+ runtime: Arc<RuntimeEnv>,
+) -> Result<SendableRecordBatchStream> {
+ let schema = input.schema();
+ let sorter = Arc::new(ExternalSorter::new(
+ partition_id,
+ schema.clone(),
+ expr,
+ runtime.clone(),
+ ));
+ runtime.register_consumer(sorter.clone()).await;
+
+ while let Some(batch) = input.next().await {
+ let batch = batch?;
+ sorter.insert_batch(batch).await?;
+ }
+
+ sorter.finish_insert();
+ sorter.sort().await
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::datasource::object_store::local::LocalFileSystem;
+ use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+ use crate::physical_plan::expressions::col;
+ use crate::physical_plan::memory::MemoryExec;
+ use crate::physical_plan::{
+ collect,
+ file_format::{CsvExec, PhysicalPlanConfig},
+ };
+ use crate::test;
+ use crate::test_util;
+ use arrow::array::*;
+ use arrow::compute::sort::SortOptions;
+ use arrow::datatypes::*;
+
+ #[tokio::test]
+ async fn test_sort() -> Result<()> {
+ let schema = test_util::aggr_test_schema();
+ let partitions = 4;
+ let (_, files) =
+ test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;
+
+ let csv = CsvExec::new(
+ PhysicalPlanConfig {
+ object_store: Arc::new(LocalFileSystem {}),
+ file_schema: Arc::clone(&schema),
+ file_groups: files,
+ statistics: Statistics::default(),
+ projection: None,
+ batch_size: 1024,
+ limit: None,
+ table_partition_cols: vec![],
+ },
+ true,
+ b',',
+ );
+
+ let sort_exec = Arc::new(ExternalSortExec::try_new(
+ vec![
+ // c1 string column
+ PhysicalSortExpr {
+ expr: col("c1", &schema)?,
+ options: SortOptions::default(),
+ },
+ // c2 uint32 column
+ PhysicalSortExpr {
+ expr: col("c2", &schema)?,
+ options: SortOptions::default(),
+ },
+ // c7 uint8 column
+ PhysicalSortExpr {
+ expr: col("c7", &schema)?,
+ options: SortOptions::default(),
+ },
+ ],
+ Arc::new(CoalescePartitionsExec::new(Arc::new(csv))),
+ )?);
+
+ let result: Vec<RecordBatch> = collect(sort_exec).await?;
+ assert_eq!(result.len(), 1);
+
+ let columns = result[0].columns();
+
+ let c1 = columns[0]
+ .as_any()
+ .downcast_ref::<Utf8Array<i32>>()
+ .unwrap();
+ assert_eq!(c1.value(0), "a");
+ assert_eq!(c1.value(c1.len() - 1), "e");
+
+ let c2 = columns[1].as_any().downcast_ref::<UInt32Array>().unwrap();
+ assert_eq!(c2.value(0), 1);
+ assert_eq!(c2.value(c2.len() - 1), 5);
+
+ let c7 = columns[6].as_any().downcast_ref::<UInt8Array>().unwrap();
+ assert_eq!(c7.value(0), 15);
+ assert_eq!(c7.value(c7.len() - 1), 254);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_lex_sort_by_float() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Float32, true),
+ Field::new("b", DataType::Float64, true),
+ ]));
+
+ // define data.
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Float32Array::from(vec![
+ Some(f32::NAN),
+ None,
+ None,
+ Some(f32::NAN),
+ Some(1.0_f32),
+ Some(1.0_f32),
+ Some(2.0_f32),
+ Some(3.0_f32),
+ ])),
+ Arc::new(Float64Array::from(vec![
+ Some(200.0_f64),
+ Some(20.0_f64),
+ Some(10.0_f64),
+ Some(100.0_f64),
+ Some(f64::NAN),
+ None,
+ None,
+ Some(f64::NAN),
+ ])),
+ ],
+ )?;
+
+ let sort_exec = Arc::new(ExternalSortExec::try_new(
+ vec![
+ PhysicalSortExpr {
+ expr: col("a", &schema)?,
+ options: SortOptions {
+ descending: true,
+ nulls_first: true,
+ },
+ },
+ PhysicalSortExpr {
+ expr: col("b", &schema)?,
+ options: SortOptions {
+ descending: false,
+ nulls_first: false,
+ },
+ },
+ ],
+ Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?),
+ )?);
+
+ assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type());
+ assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type());
+
+ let result: Vec<RecordBatch> = collect(sort_exec.clone()).await?;
+ // let metrics = sort_exec.metrics().unwrap();
+ // assert!(metrics.elapsed_compute().unwrap() > 0);
+ // assert_eq!(metrics.output_rows().unwrap(), 8);
+ assert_eq!(result.len(), 1);
+
+ let columns = result[0].columns();
+
+ assert_eq!(DataType::Float32, *columns[0].data_type());
+ assert_eq!(DataType::Float64, *columns[1].data_type());
+
+ let a = columns[0].as_any().downcast_ref::<Float32Array>().unwrap();
+ let b = columns[1].as_any().downcast_ref::<Float64Array>().unwrap();
+
+ // convert result to strings to allow comparing to expected result containing NaN
+ let result: Vec<(Option<String>, Option<String>)> = (0..result[0].num_rows())
+ .map(|i| {
+ let aval = if a.is_valid(i) {
+ Some(a.value(i).to_string())
+ } else {
+ None
+ };
+ let bval = if b.is_valid(i) {
+ Some(b.value(i).to_string())
+ } else {
+ None
+ };
+ (aval, bval)
+ })
+ .collect();
+
+ let expected: Vec<(Option<String>, Option<String>)> = vec![
+ (None, Some("10".to_owned())),
+ (None, Some("20".to_owned())),
+ (Some("NaN".to_owned()), Some("100".to_owned())),
+ (Some("NaN".to_owned()), Some("200".to_owned())),
+ (Some("3".to_owned()), Some("NaN".to_owned())),
+ (Some("2".to_owned()), None),
+ (Some("1".to_owned()), Some("NaN".to_owned())),
+ (Some("1".to_owned()), None),
+ ];
+
+ assert_eq!(expected, result);
+
+ Ok(())
+ }
+}
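Note on the driver above: `external_sort` is the entire control flow — register the sorter as a memory consumer, push every input batch through `insert_batch` (which may spill), then `sort()` to merge in-memory and spilled runs. A minimal sketch of calling it directly, assuming `SizedRecordBatchStream` from `physical_plan::common` as a convenient batch-to-stream wrapper and a column named "a" (both assumptions, not part of this patch):

```rust
use std::sync::Arc;

use arrow::compute::sort::SortOptions;
use arrow::record_batch::RecordBatch;
use futures::StreamExt;

use crate::error::Result;
use crate::execution::runtime_env::RUNTIME_ENV;
use crate::physical_plan::common::SizedRecordBatchStream;
use crate::physical_plan::expressions::{col, PhysicalSortExpr};
use crate::physical_plan::sorts::external_sort::external_sort;
use crate::physical_plan::SendableRecordBatchStream;

/// Sketch: push one in-memory batch through the external sorter (partition 0).
async fn sort_one_batch(batch: RecordBatch) -> Result<Vec<RecordBatch>> {
    let schema = batch.schema().clone();
    let input: SendableRecordBatchStream = Box::pin(SizedRecordBatchStream::new(
        schema.clone(),
        vec![Arc::new(batch)],
    ));
    let expr = vec![PhysicalSortExpr {
        expr: col("a", &schema)?,
        options: SortOptions::default(),
    }];
    // RUNTIME_ENV owns the memory manager that may ask the sorter to spill
    let mut sorted = external_sort(input, 0, expr, RUNTIME_ENV.clone()).await?;
    let mut out = vec![];
    while let Some(batch) = sorted.next().await {
        out.push(batch?);
    }
    Ok(out)
}
```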
diff --git a/datafusion/src/physical_plan/sorts/in_mem_sort.rs b/datafusion/src/physical_plan/sorts/in_mem_sort.rs
new file mode 100644
index 000000000000..4491db2a80f1
--- /dev/null
+++ b/datafusion/src/physical_plan/sorts/in_mem_sort.rs
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::BinaryHeap;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::array::growable::make_growable;
+use arrow::compute::sort::SortOptions;
+use arrow::datatypes::SchemaRef;
+use arrow::error::ArrowError;
+use arrow::error::Result as ArrowResult;
+use arrow::record_batch::RecordBatch;
+use futures::Stream;
+
+use crate::error::Result;
+use crate::physical_plan::metrics::BaselineMetrics;
+use crate::physical_plan::sorts::{RowIndex, SortKeyCursor};
+use crate::physical_plan::{
+ expressions::PhysicalSortExpr, PhysicalExpr, RecordBatchStream,
+};
+
+pub(crate) struct InMemSortStream {
+ /// The schema of the RecordBatches yielded by this stream
+ schema: SchemaRef,
+ /// The sorted input record batches
+ ///
+ /// Each `SortKeyCursor` in `min_heap` references rows in one of
+ /// these batches via its `batch_idx`
+ batches: Vec<Arc<RecordBatch>>,
+ /// The accumulated row indexes for the next record batch
+ in_progress: Vec<RowIndex>,
+ /// The desired RecordBatch size to yield
+ target_batch_size: usize,
+ /// used to record execution metrics
+ baseline_metrics: BaselineMetrics,
+ /// If the stream has encountered an error
+ aborted: bool,
+ /// min heap for record comparison
+ min_heap: BinaryHeap<SortKeyCursor>,
+}
+
+impl InMemSortStream {
+ pub(crate) fn new(
+ sorted_batches: Vec<RecordBatch>,
+ schema: SchemaRef,
+ expressions: &[PhysicalSortExpr],
+ target_batch_size: usize,
+ baseline_metrics: BaselineMetrics,
+ ) -> Result<Self> {
+ let len = sorted_batches.len();
+ let mut batches = Vec::with_capacity(len);
+ let mut min_heap = BinaryHeap::with_capacity(len);
+
+ let column_expressions: Vec<Arc<dyn PhysicalExpr>> =
+ expressions.iter().map(|x| x.expr.clone()).collect();
+
+ // The sort options for each expression
+ let sort_options: Arc<Vec<SortOptions>> =
+ Arc::new(expressions.iter().map(|x| x.options).collect());
+
+ sorted_batches
+ .into_iter()
+ .enumerate()
+ .try_for_each(|(idx, batch)| {
+ let batch = Arc::new(batch);
+ let cursor = SortKeyCursor::new(
+ idx,
+ batch.clone(),
+ &column_expressions,
+ sort_options.clone(),
+ )?;
+ min_heap.push(cursor);
+ batches.insert(idx, batch);
+ Ok(())
+ })?;
+
+ Ok(Self {
+ schema,
+ batches,
+ target_batch_size,
+ baseline_metrics,
+ aborted: false,
+ in_progress: vec![],
+ min_heap,
+ })
+ }
+
+ /// Returns the cursor positioned on the smallest current row, or `None`
+ /// once the cursors for all batches are exhausted
+ fn next_cursor(&mut self) -> Result<Option<SortKeyCursor>> {
+ Ok(self.min_heap.pop())
+ }
+
+ /// Drains the in_progress row indexes and builds a new RecordBatch from them,
+ /// coalescing consecutive rows from the same batch into a single `extend` call
+ fn build_record_batch(&mut self) -> ArrowResult<RecordBatch> {
+ let columns = self
+ .schema
+ .fields()
+ .iter()
+ .enumerate()
+ .map(|(column_idx, _)| {
+ let arrays = self
+ .batches
+ .iter()
+ .map(|batch| batch.column(column_idx).as_ref())
+ .collect::<Vec<_>>();
+
+ let mut array_data =
+ make_growable(&arrays, false, self.in_progress.len());
+
+ if self.in_progress.is_empty() {
+ return array_data.as_arc();
+ }
+
+ let first = &self.in_progress[0];
+ let mut buffer_idx = first.stream_idx;
+ let mut start_row_idx = first.row_idx;
+ let mut end_row_idx = start_row_idx + 1;
+
+ for row_index in self.in_progress.iter().skip(1) {
+ let next_buffer_idx = row_index.stream_idx;
+
+ if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx {
+ // subsequent row in same batch
+ end_row_idx += 1;
+ continue;
+ }
+
+ // emit current batch of rows for current buffer
+ array_data.extend(
+ buffer_idx,
+ start_row_idx,
+ end_row_idx - start_row_idx,
+ );
+
+ // start new batch of rows
+ buffer_idx = next_buffer_idx;
+ start_row_idx = row_index.row_idx;
+ end_row_idx = start_row_idx + 1;
+ }
+
+ // emit final batch of rows
+ array_data.extend(buffer_idx, start_row_idx, end_row_idx - start_row_idx);
+ array_data.as_arc()
+ })
+ .collect();
+
+ self.in_progress.clear();
+ RecordBatch::try_new(self.schema.clone(), columns)
+ }
+
+ #[inline]
+ fn poll_next_inner(
+ self: &mut Pin<&mut Self>,
+ _cx: &mut Context<'_>,
+ ) -> Poll<Option<ArrowResult<RecordBatch>>> {
+ if self.aborted {
+ return Poll::Ready(None);
+ }
+
+ loop {
+ // NB timer records time taken on drop, so there are no
+ // calls to `timer.done()` below.
+ let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+ let _timer = elapsed_compute.timer();
+
+ match self.next_cursor() {
+ Ok(Some(mut cursor)) => {
+ let batch_idx = cursor.batch_idx;
+ let row_idx = cursor.advance();
+
+ // insert the cursor back to min_heap if the record batch is not exhausted
+ if !cursor.is_finished() {
+ self.min_heap.push(cursor);
+ }
+
+ self.in_progress.push(RowIndex {
+ stream_idx: batch_idx,
+ cursor_idx: 0,
+ row_idx,
+ });
+ }
+ Ok(None) if self.in_progress.is_empty() => return Poll::Ready(None),
+ Ok(None) => return Poll::Ready(Some(self.build_record_batch())),
+ Err(e) => {
+ self.aborted = true;
+ return Poll::Ready(Some(Err(ArrowError::External(
+ "".to_string(),
+ Box::new(e),
+ ))));
+ }
+ };
+
+ if self.in_progress.len() == self.target_batch_size {
+ return Poll::Ready(Some(self.build_record_batch()));
+ }
+ }
+ }
+}
+
+impl Stream for InMemSortStream {
+ type Item = ArrowResult<RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let poll = self.poll_next_inner(cx);
+ self.baseline_metrics.record_poll(poll)
+ }
+}
+
+impl RecordBatchStream for InMemSortStream {
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
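`InMemSortStream` is, at its core, a k-way merge: one `SortKeyCursor` per sorted batch, a `BinaryHeap` that always yields the cursor with the smallest current row, and run-coalescing when output is materialized. Stripped of the Arrow types, the heap loop reduces to this std-only sketch (illustrative, not the patch's code):

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Merge already-sorted vectors the way InMemSortStream merges sorted batches.
fn merge_sorted(inputs: &[Vec<i32>]) -> Vec<i32> {
    // Entries are (value, batch_idx, row_idx). `Reverse` turns std's max-heap
    // into a min-heap, the same effect SortKeyCursor achieves by swapping the
    // operands in its `Ord` impl.
    let mut heap = BinaryHeap::new();
    for (batch_idx, batch) in inputs.iter().enumerate() {
        if let Some(&v) = batch.first() {
            heap.push(Reverse((v, batch_idx, 0usize)));
        }
    }
    let mut out = Vec::new();
    while let Some(Reverse((v, batch_idx, row_idx))) = heap.pop() {
        out.push(v);
        // advance this cursor and re-insert it unless its batch is exhausted
        if let Some(&next) = inputs[batch_idx].get(row_idx + 1) {
            heap.push(Reverse((next, batch_idx, row_idx + 1)));
        }
    }
    out
}
```

For example, `merge_sorted(&[vec![1, 3], vec![2, 4]])` yields `[1, 2, 3, 4]`.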
diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs
new file mode 100644
index 000000000000..0a055463c099
--- /dev/null
+++ b/datafusion/src/physical_plan/sorts/mod.rs
@@ -0,0 +1,294 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Sort functionalities
+
+pub mod external_sort;
+mod in_mem_sort;
+pub mod sort;
+pub mod sort_preserving_merge;
+
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream};
+use arrow::array::ord::DynComparator;
+pub use arrow::compute::sort::SortOptions;
+use arrow::record_batch::RecordBatch;
+use arrow::{array::ArrayRef, error::Result as ArrowResult};
+use futures::channel::mpsc;
+use futures::stream::FusedStream;
+use futures::Stream;
+use hashbrown::HashMap;
+use std::cmp::Ordering;
+use std::fmt::{Debug, Formatter};
+use std::pin::Pin;
+use std::sync::{Arc, RwLock};
+use std::task::{Context, Poll};
+
+/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of
+/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys.
+///
+/// Additionally it maintains a row cursor that can be advanced through the rows
+/// of the provided `RecordBatch`
+///
+/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to
+/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores
+/// a row comparator for each other cursor that it is compared to.
+struct SortKeyCursor {
+ columns: Vec<ArrayRef>,
+ cur_row: usize,
+ num_rows: usize,
+
+ // An index uniquely identifying the record batch scanned by this cursor.
+ batch_idx: usize,
+ batch: Arc<RecordBatch>,
+
+ // A collection of comparators that compare rows in this cursor's batch to
+ // the cursors in other batches. Other batches are uniquely identified by
+ // their batch_idx.
+ batch_comparators: RwLock<HashMap<usize, Vec<DynComparator>>>,
+ sort_options: Arc<Vec<SortOptions>>,
+}
+
+impl std::fmt::Debug for SortKeyCursor {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("SortKeyCursor")
+ .field("columns", &self.columns)
+ .field("cur_row", &self.cur_row)
+ .field("num_rows", &self.num_rows)
+ .field("batch_idx", &self.batch_idx)
+ .field("batch", &self.batch)
+ .field("batch_comparators", &"")
+ .finish()
+ }
+}
+
+impl SortKeyCursor {
+ fn new(
+ batch_idx: usize,
+ batch: Arc<RecordBatch>,
+ sort_key: &[Arc<dyn PhysicalExpr>],
+ sort_options: Arc<Vec<SortOptions>>,
+ ) -> Result<Self> {
+ let columns: Vec<ArrayRef> = sort_key
+ .iter()
+ .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
+ .collect::<Result<_>>()?;
+ Ok(Self {
+ cur_row: 0,
+ num_rows: batch.num_rows(),
+ columns,
+ batch,
+ batch_idx,
+ batch_comparators: RwLock::new(HashMap::new()),
+ sort_options,
+ })
+ }
+
+ fn is_finished(&self) -> bool {
+ self.num_rows == self.cur_row
+ }
+
+ fn advance(&mut self) -> usize {
+ assert!(!self.is_finished());
+ let t = self.cur_row;
+ self.cur_row += 1;
+ t
+ }
+
+ /// Compares the sort key pointed to by this instance's row cursor with that of another
+ fn compare(&self, other: &SortKeyCursor) -> Result<Ordering> {
+ if self.columns.len() != other.columns.len() {
+ return Err(DataFusionError::Internal(format!(
+ "SortKeyCursors had inconsistent column counts: {} vs {}",
+ self.columns.len(),
+ other.columns.len()
+ )));
+ }
+
+ if self.columns.len() != self.sort_options.len() {
+ return Err(DataFusionError::Internal(format!(
+ "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}",
+ self.columns.len(),
+ self.sort_options.len()
+ )));
+ }
+
+ let zipped: Vec<((&ArrayRef, &ArrayRef), &SortOptions)> = self
+ .columns
+ .iter()
+ .zip(other.columns.iter())
+ .zip(self.sort_options.iter())
+ .collect::<Vec<_>>();
+
+ self.init_cmp_if_needed(other, &zipped)?;
+
+ let map = self.batch_comparators.read().unwrap();
+ let cmp = map.get(&other.batch_idx).ok_or_else(|| {
+ DataFusionError::Execution(format!(
+ "Failed to find comparator for {} cmp {}",
+ self.batch_idx, other.batch_idx
+ ))
+ })?;
+
+ for (i, ((l, r), sort_options)) in zipped.iter().enumerate() {
+ match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
+ (false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
+ (false, true) => return Ok(Ordering::Greater),
+ (true, false) if sort_options.nulls_first => {
+ return Ok(Ordering::Greater)
+ }
+ (true, false) => return Ok(Ordering::Less),
+ (false, false) => {}
+ (true, true) => match cmp[i](self.cur_row, other.cur_row) {
+ Ordering::Equal => {}
+ o if sort_options.descending => return Ok(o.reverse()),
+ o => return Ok(o),
+ },
+ }
+ }
+
+ Ok(Ordering::Equal)
+ }
+
+ /// Initialize a collection of comparators for comparing
+ /// columnar arrays of this cursor and "other" if needed.
+ fn init_cmp_if_needed(
+ &self,
+ other: &SortKeyCursor,
+ zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)],
+ ) -> Result<()> {
+ let hm = self.batch_comparators.read().unwrap();
+ if !hm.contains_key(&other.batch_idx) {
+ drop(hm);
+ let mut map = self.batch_comparators.write().unwrap();
+ let cmp = map
+ .entry(other.batch_idx)
+ .or_insert_with(|| Vec::with_capacity(other.columns.len()));
+
+ for (i, ((l, r), _)) in zipped.iter().enumerate() {
+ if i >= cmp.len() {
+ // initialise comparators
+ cmp.push(arrow::array::ord::build_compare(l.as_ref(), r.as_ref())?);
+ }
+ }
+ }
+ Ok(())
+ }
+}
+
+/// A `RowIndex` identifies a specific row from those buffered
+/// by a `SortPreservingMergeStream`
+#[derive(Debug, Clone)]
+struct RowIndex {
+ /// The index of the stream
+ stream_idx: usize,
+ /// For sort_preserving_merge, the index of the cursor within the stream's VecDeque.
+ /// For in_mem_sort, which has a single batch per stream, cursor_idx is always 0
+ cursor_idx: usize,
+ /// The row index
+ row_idx: usize,
+}
+
+// `BinaryHeap` is a max-heap; comparing `other` to `self` reverses the order
+// so the heap pops the cursor with the smallest sort key first.
+impl Ord for SortKeyCursor {
+ fn cmp(&self, other: &Self) -> Ordering {
+ other.compare(self).unwrap()
+ }
+}
+
+impl PartialEq for SortKeyCursor {
+ fn eq(&self, other: &Self) -> bool {
+ other.compare(self).unwrap() == Ordering::Equal
+ }
+}
+
+impl Eq for SortKeyCursor {}
+
+impl PartialOrd for SortKeyCursor {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ other.compare(self).ok()
+ }
+}
+
+pub(crate) struct SpillableStream {
+ pub stream: SendableRecordBatchStream,
+ pub spillable: bool,
+}
+
+impl Debug for SpillableStream {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "SpillableStream {}", self.spillable)
+ }
+}
+
+impl SpillableStream {
+ pub(crate) fn new_spillable(stream: SendableRecordBatchStream) -> Self {
+ Self {
+ stream,
+ spillable: true,
+ }
+ }
+
+ pub(crate) fn new_unspillable(stream: SendableRecordBatchStream) -> Self {
+ Self {
+ stream,
+ spillable: false,
+ }
+ }
+}
+
+#[derive(Debug)]
+enum StreamWrapper {
+ Receiver(mpsc::Receiver<ArrowResult<RecordBatch>>),
+ Stream(Option<SpillableStream>),
+}
+
+impl Stream for StreamWrapper {
+ type Item = ArrowResult<RecordBatch>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ match self.get_mut() {
+ StreamWrapper::Receiver(ref mut receiver) => Pin::new(receiver).poll_next(cx),
+ StreamWrapper::Stream(ref mut stream) => {
+ let inner = match stream {
+ None => return Poll::Ready(None),
+ Some(inner) => inner,
+ };
+
+ match Pin::new(&mut inner.stream).poll_next(cx) {
+ Poll::Ready(msg) => {
+ if msg.is_none() {
+ *stream = None
+ }
+ Poll::Ready(msg)
+ }
+ Poll::Pending => Poll::Pending,
+ }
+ }
+ }
+ }
+}
+
+impl FusedStream for StreamWrapper {
+ fn is_terminated(&self) -> bool {
+ match self {
+ StreamWrapper::Receiver(receiver) => receiver.is_terminated(),
+ StreamWrapper::Stream(stream) => stream.is_none(),
+ }
+ }
+}
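The `batch_comparators` cache exists because arrow2's `build_compare` resolves array types once, up front: it returns a `DynComparator` closure over (row, row) indexes, so the per-row loop in `compare` never downcasts. A small sketch of that primitive in isolation (hedged: it assumes arrow2's `arrow::array::ord` module and `from_slice` constructors, as used elsewhere in this patch):

```rust
use std::cmp::Ordering;

use arrow::array::ord::build_compare;
use arrow::array::Int32Array;
use arrow::error::Result;

fn demo() -> Result<()> {
    let left = Int32Array::from_slice(&[1, 3, 5]);
    let right = Int32Array::from_slice(&[2, 3]);
    // Built once per pair of batches; SortKeyCursor caches one such Vec per
    // other batch, keyed by that batch's batch_idx.
    let cmp = build_compare(&left, &right)?;
    assert_eq!(cmp(0, 0), Ordering::Less); // left[0] = 1 < right[0] = 2
    assert_eq!(cmp(1, 1), Ordering::Equal); // left[1] = 3 == right[1] = 3
    Ok(())
}
```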
diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs
similarity index 97%
rename from datafusion/src/physical_plan/sort.rs
rename to datafusion/src/physical_plan/sorts/sort.rs
index bf521bb7c1fc..0a15fb5f0173 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sorts/sort.rs
@@ -17,15 +17,15 @@
//! Defines the SORT plan
-use super::common::AbortOnDropSingle;
-use super::metrics::{
- BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
-};
-use super::{RecordBatchStream, SendableRecordBatchStream, Statistics};
+use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
+use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::expressions::PhysicalSortExpr;
+use crate::physical_plan::metrics::{
+ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
+};
use crate::physical_plan::{
- common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
+ common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, Statistics,
};
pub use arrow::compute::sort::SortOptions;
use arrow::compute::{sort::lexsort_to_indices, take};
@@ -186,7 +186,8 @@ impl ExecutionPlan for SortExec {
}
}
-fn sort_batch(
+/// Sort the record batch on `expr` and reorder its rows according to the sort result.
+pub fn sort_batch(
batch: RecordBatch,
schema: SchemaRef,
expr: &[PhysicalSortExpr],
@@ -198,8 +199,6 @@ fn sort_batch(
.map_err(DataFusionError::into_arrow_external_error)?;
let columns = columns.iter().map(|x| x.into()).collect::<Vec<_>>();
- // sort combined record batch
- // TODO: pushup the limit expression to sort
let indices = lexsort_to_indices::<i32>(&columns, None)?;
// reorder all rows based on sorted indices
@@ -242,6 +241,7 @@ impl SortStream {
// combine all record batches into one for each column
let combined = common::combine_batches(&batches, schema.clone())?;
// sort combined record batch
+ // TODO: push the limit expression down into the sort
let result = combined
.map(|batch| sort_batch(batch, schema, &expr))
.transpose()?
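With `sort_batch` public, the external sorter can sort each buffered batch without standing up a full `SortExec`. A minimal usage sketch (assumptions: the arrow2-flavored `RecordBatch`/`from_slice` APIs used in this file, and that Arrow errors convert into the crate's `Result` via `?`):

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::compute::sort::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

use crate::physical_plan::expressions::{col, PhysicalSortExpr};
use crate::physical_plan::sorts::sort::sort_batch;

fn demo() -> crate::error::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from_slice(&[3, 1, 2]))],
    )?;
    let expr = vec![PhysicalSortExpr {
        expr: col("a", &schema)?,
        options: SortOptions::default(),
    }];
    // rows come back reordered as 1, 2, 3
    let sorted = sort_batch(batch, schema, &expr)?;
    assert_eq!(sorted.num_rows(), 3);
    Ok(())
}
```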
diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
similarity index 83%
rename from datafusion/src/physical_plan/sort_preserving_merge.rs
rename to datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
index ec3ad9f9a34c..37a0d6b83360 100644
--- a/datafusion/src/physical_plan/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -17,8 +17,6 @@
//! Defines the sort preserving merge plan
-use super::common::AbortOnDropMany;
-use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use std::any::Any;
use std::cmp::Ordering;
use std::collections::VecDeque;
@@ -26,8 +24,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
-use arrow::array::ord::DynComparator;
-use arrow::array::{growable::make_growable, ord::build_compare, ArrayRef};
+use arrow::array::growable::make_growable;
use arrow::compute::sort::SortOptions;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
@@ -36,15 +33,29 @@ use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use futures::channel::mpsc;
use futures::stream::FusedStream;
-use futures::{Stream, StreamExt};
-use hashbrown::HashMap;
+use futures::{Future, Stream, StreamExt};
use crate::error::{DataFusionError, Result};
+use crate::execution::memory_management::{
+ MemoryConsumer, MemoryConsumerId, MemoryManager,
+};
+use crate::execution::runtime_env::RuntimeEnv;
+use crate::execution::runtime_env::RUNTIME_ENV;
+use crate::physical_plan::common::AbortOnDropMany;
+use crate::physical_plan::metrics::{
+ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+};
+use crate::physical_plan::sorts::external_sort::convert_stream_disk_based;
+use crate::physical_plan::sorts::{
+ RowIndex, SortKeyCursor, SpillableStream, StreamWrapper,
+};
use crate::physical_plan::{
common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType,
Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
SendableRecordBatchStream, Statistics,
};
+use futures::lock::Mutex;
+use std::fmt::{Debug, Formatter};
/// Sort preserving merge execution plan
///
@@ -151,7 +162,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
self.input.execute(0).await
}
_ => {
- let (receivers, join_handles) = (0..input_partitions)
+ let (streams, join_handles) = (0..input_partitions)
.into_iter()
.map(|part_i| {
let (sender, receiver) = mpsc::channel(1);
@@ -161,14 +172,19 @@ impl ExecutionPlan for SortPreservingMergeExec {
})
.unzip();
- Ok(Box::pin(SortPreservingMergeStream::new(
- receivers,
- AbortOnDropMany(join_handles),
- self.schema(),
- &self.expr,
- self.target_batch_size,
- baseline_metrics,
- )))
+ Ok(Box::pin(
+ SortPreservingMergeStream::new_from_receiver(
+ streams,
+ AbortOnDropMany(join_handles),
+ self.schema(),
+ &self.expr,
+ self.target_batch_size,
+ baseline_metrics,
+ partition,
+ RUNTIME_ENV.clone(),
+ )
+ .await,
+ ))
}
}
}
@@ -195,179 +211,148 @@ impl ExecutionPlan for SortPreservingMergeExec {
}
}
-/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of
-/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys.
-///
-/// Additionally it maintains a row cursor that can be advanced through the rows
-/// of the provided `RecordBatch`
-///
-/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to
-/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores
-/// a row comparator for each other cursor that it is compared to.
-struct SortKeyCursor {
- columns: Vec<ArrayRef>,
- cur_row: usize,
- num_rows: usize,
-
- // An index uniquely identifying the record batch scanned by this cursor.
- batch_idx: usize,
- batch: RecordBatch,
-
- // A collection of comparators that compare rows in this cursor's batch to
- // the cursors in other batches. Other batches are uniquely identified by
- // their batch_idx.
- batch_comparators: HashMap<usize, Vec<DynComparator>>,
+struct MergingStreams {
+ /// ConsumerId
+ id: MemoryConsumerId,
+ /// The sorted input streams to merge together
+ pub(crate) streams: Mutex<Vec<StreamWrapper>>,
+ /// The schema of the RecordBatches yielded by this stream
+ schema: SchemaRef,
+ /// Runtime
+ runtime: Arc<RuntimeEnv>,
}
-impl<'a> std::fmt::Debug for SortKeyCursor {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("SortKeyCursor")
- .field("columns", &self.columns)
- .field("cur_row", &self.cur_row)
- .field("num_rows", &self.num_rows)
- .field("batch_idx", &self.batch_idx)
- .field("batch", &self.batch)
- .field("batch_comparators", &"")
+impl Debug for MergingStreams {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("MergingStreams")
+ .field("id", &self.id())
.finish()
}
}
-impl SortKeyCursor {
- fn new(
- batch_idx: usize,
- batch: RecordBatch,
- sort_key: &[Arc<dyn PhysicalExpr>],
- ) -> Result {
- let columns = sort_key
- .iter()
- .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
- .collect::<Result<Vec<_>>>()?;
- Ok(Self {
- cur_row: 0,
- num_rows: batch.num_rows(),
- columns,
- batch,
- batch_idx,
- batch_comparators: HashMap::new(),
- })
+impl MergingStreams {
+ pub fn new(
+ partition: usize,
+ input_streams: Vec<StreamWrapper>,
+ schema: SchemaRef,
+ runtime: Arc<RuntimeEnv>,
+ ) -> Self {
+ Self {
+ id: MemoryConsumerId::new(partition),
+ streams: Mutex::new(input_streams),
+ schema,
+ runtime,
+ }
}
- fn is_finished(&self) -> bool {
- self.num_rows == self.cur_row
+ async fn spill_underlying_stream(
+ &self,
+ stream_idx: usize,
+ path: String,
+ ) -> Result<usize> {
+ let mut streams = self.streams.lock().await;
+ let origin_stream = &mut streams[stream_idx];
+ match origin_stream {
+ StreamWrapper::Receiver(_) => {
+ Err(DataFusionError::Execution(
+ "Unexpected spilling a receiver stream in SortPreservingMerge"
+ .to_string(),
+ ))
+ }
+ StreamWrapper::Stream(stream) => match stream {
+ None => Ok(0),
+ Some(ref mut stream) => {
+ if stream.spillable {
+ let (disk_stream, spill_size) = convert_stream_disk_based(
+ &mut stream.stream,
+ path,
+ self.schema.clone(),
+ )
+ .await?;
+ streams[stream_idx] = StreamWrapper::Stream(Some(
+ SpillableStream::new_unspillable(disk_stream),
+ ));
+ Ok(spill_size)
+ } else {
+ Ok(0)
+ }
+ }
+ },
+ }
}
+}
- fn advance(&mut self) -> usize {
- assert!(!self.is_finished());
- let t = self.cur_row;
- self.cur_row += 1;
- t
+#[async_trait]
+impl MemoryConsumer for MergingStreams {
+ fn name(&self) -> String {
+ "MergingStreams".to_owned()
}
- /// Compares the sort key pointed to by this instance's row cursor with that of another
- fn compare(
- &mut self,
- other: &SortKeyCursor,
- options: &[SortOptions],
- ) -> Result<Ordering> {
- if self.columns.len() != other.columns.len() {
- return Err(DataFusionError::Internal(format!(
- "SortKeyCursors had inconsistent column counts: {} vs {}",
- self.columns.len(),
- other.columns.len()
- )));
- }
+ fn id(&self) -> &MemoryConsumerId {
+ &self.id
+ }
- if self.columns.len() != options.len() {
- return Err(DataFusionError::Internal(format!(
- "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}",
- self.columns.len(),
- options.len()
- )));
- }
+ fn memory_manager(&self) -> Arc<MemoryManager> {
+ self.runtime.memory_manager.clone()
+ }
- let zipped = self
- .columns
- .iter()
- .zip(other.columns.iter())
- .zip(options.iter());
-
- // Recall or initialise a collection of comparators for comparing
- // columnar arrays of this cursor and "other".
- let cmp = self
- .batch_comparators
- .entry(other.batch_idx)
- .or_insert_with(|| Vec::with_capacity(other.columns.len()));
-
- for (i, ((l, r), sort_options)) in zipped.enumerate() {
- if i >= cmp.len() {
- // initialise comparators as potentially needed
- cmp.push(build_compare(l.as_ref(), r.as_ref())?);
- }
+ async fn spill_inner(
+ &self,
+ _size: usize,
+ _trigger: &MemoryConsumerId,
+ ) -> Result<usize> {
+ let path = self.runtime.disk_manager.create_tmp_file()?;
+ self.spill_underlying_stream(0, path).await
+ }
- match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
- (false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
- (false, true) => return Ok(Ordering::Greater),
- (true, false) if sort_options.nulls_first => {
- return Ok(Ordering::Greater)
- }
- (true, false) => return Ok(Ordering::Less),
- (false, false) => {}
- (true, true) => match cmp[i](self.cur_row, other.cur_row) {
- Ordering::Equal => {}
- o if sort_options.descending => return Ok(o.reverse()),
- o => return Ok(o),
- },
- }
- }
+ fn get_used(&self) -> isize {
+ todo!()
+ }
- Ok(Ordering::Equal)
+ fn update_used(&self, _delta: isize) {
+ todo!()
+ }
+
+ fn spilled_bytes(&self) -> usize {
+ todo!()
}
-}
-/// A `RowIndex` identifies a specific row from those buffered
-/// by a `SortPreservingMergeStream`
-#[derive(Debug, Clone)]
-struct RowIndex {
- /// The index of the stream
- stream_idx: usize,
- /// The index of the cursor within the stream's VecDequeue
- cursor_idx: usize,
- /// The row index
- row_idx: usize,
+ fn spilled_bytes_add(&self, _add: usize) {
+ todo!()
+ }
+
+ fn spilled_count(&self) -> usize {
+ todo!()
+ }
+
+ fn spilled_count_increment(&self) {
+ todo!()
+ }
}
#[derive(Debug)]
-struct SortPreservingMergeStream {
+pub(crate) struct SortPreservingMergeStream {
/// The schema of the RecordBatches yielded by this stream
schema: SchemaRef,
-
/// The sorted input streams to merge together
- receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
-
+ streams: Arc<MergingStreams>,
/// Drop helper for tasks feeding the [`receivers`](Self::receivers)
_drop_helper: AbortOnDropMany<()>,
-
/// For each input stream maintain a dequeue of SortKeyCursor
///
/// Exhausted cursors will be popped off the front once all
/// their rows have been yielded to the output
cursors: Vec<VecDeque<SortKeyCursor>>,
-
/// The accumulated row indexes for the next record batch
in_progress: Vec<RowIndex>,
-
/// The physical expressions to sort by
column_expressions: Vec<Arc<dyn PhysicalExpr>>,
-
/// The sort options for each expression
- sort_options: Vec<SortOptions>,
-
+ sort_options: Arc<Vec<SortOptions>>,
/// The desired RecordBatch size to yield
target_batch_size: usize,
-
/// used to record execution metrics
baseline_metrics: BaselineMetrics,
-
/// If the stream has encountered an error
aborted: bool,
@@ -376,26 +361,82 @@ struct SortPreservingMergeStream {
}
impl SortPreservingMergeStream {
- fn new(
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) async fn new_from_receiver(
receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
_drop_helper: AbortOnDropMany<()>,
schema: SchemaRef,
expressions: &[PhysicalSortExpr],
target_batch_size: usize,
baseline_metrics: BaselineMetrics,
+ partition: usize,
+ runtime: Arc<RuntimeEnv>,
) -> Self {
let cursors = (0..receivers.len())
.into_iter()
.map(|_| VecDeque::new())
.collect();
+ let receivers = receivers
+ .into_iter()
+ .map(StreamWrapper::Receiver)
+ .collect();
+ let streams = Arc::new(MergingStreams::new(
+ partition,
+ receivers,
+ schema.clone(),
+ runtime.clone(),
+ ));
+ runtime.register_consumer(streams.clone()).await;
+
Self {
schema,
cursors,
- receivers,
+ streams,
_drop_helper,
column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
- sort_options: expressions.iter().map(|x| x.options).collect(),
+ sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
+ target_batch_size,
+ baseline_metrics,
+ aborted: false,
+ in_progress: vec![],
+ next_batch_index: 0,
+ }
+ }
+
+ pub(crate) async fn new_from_stream(
+ streams: Vec<SpillableStream>,
+ schema: SchemaRef,
+ expressions: &[PhysicalSortExpr],
+ target_batch_size: usize,
+ baseline_metrics: BaselineMetrics,
+ partition: usize,
+ runtime: Arc<RuntimeEnv>,
+ ) -> Self {
+ let cursors = (0..streams.len())
+ .into_iter()
+ .map(|_| VecDeque::new())
+ .collect();
+
+ let streams = streams
+ .into_iter()
+ .map(|s| StreamWrapper::Stream(Some(s)))
+ .collect::<Vec<StreamWrapper>>();
+ let streams = Arc::new(MergingStreams::new(
+ partition,
+ streams,
+ schema.clone(),
+ runtime.clone(),
+ ));
+ runtime.register_consumer(streams.clone()).await;
+
+ Self {
+ schema,
+ cursors,
+ streams,
+ _drop_helper: AbortOnDropMany(vec![]),
+ column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
+ sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
target_batch_size,
baseline_metrics,
aborted: false,
@@ -419,37 +460,45 @@ impl SortPreservingMergeStream {
}
}
- let stream = &mut self.receivers[idx];
- if stream.is_terminated() {
- return Poll::Ready(Ok(()));
- }
+ let mut streams_future = self.streams.streams.lock();
- // Fetch a new input record and create a cursor from it
- match futures::ready!(stream.poll_next_unpin(cx)) {
- None => return Poll::Ready(Ok(())),
- Some(Err(e)) => {
- return Poll::Ready(Err(e));
- }
- Some(Ok(batch)) => {
- let cursor = match SortKeyCursor::new(
- self.next_batch_index, // assign this batch an ID
- batch,
- &self.column_expressions,
- ) {
- Ok(cursor) => cursor,
- Err(e) => {
- return Poll::Ready(Err(ArrowError::External(
- "".to_string(),
- Box::new(e),
- )));
+ match Pin::new(&mut streams_future).poll(cx) {
+ Poll::Ready(mut streams) => {
+ let stream = &mut streams[idx];
+ if stream.is_terminated() {
+ return Poll::Ready(Ok(()));
+ }
+
+ // Fetch a new input record and create a cursor from it
+ match futures::ready!(stream.poll_next_unpin(cx)) {
+ None => return Poll::Ready(Ok(())),
+ Some(Err(e)) => {
+ return Poll::Ready(Err(e));
+ }
+ Some(Ok(batch)) => {
+ let cursor = match SortKeyCursor::new(
+ self.next_batch_index, // assign this batch an ID
+ Arc::new(batch),
+ &self.column_expressions,
+ self.sort_options.clone(),
+ ) {
+ Ok(cursor) => cursor,
+ Err(e) => {
+ return Poll::Ready(Err(ArrowError::External(
+ "".to_string(),
+ Box::new(e),
+ )));
+ }
+ };
+ self.next_batch_index += 1;
+ self.cursors[idx].push_back(cursor)
}
- };
- self.next_batch_index += 1;
- self.cursors[idx].push_back(cursor)
+ }
+
+ Poll::Ready(Ok(()))
}
+ Poll::Pending => Poll::Pending,
}
-
- Poll::Ready(Ok(()))
}
/// Returns the index of the next stream to pull a row from, or None
@@ -465,9 +514,7 @@ impl SortPreservingMergeStream {
match min_cursor {
None => min_cursor = Some((idx, candidate)),
Some((_, ref mut min)) => {
- if min.compare(candidate, &self.sort_options)?
- == Ordering::Greater
- {
+ if min.compare(candidate)? == Ordering::Greater {
min_cursor = Some((idx, candidate))
}
}
@@ -674,7 +721,7 @@ mod tests {
use crate::physical_plan::expressions::col;
use crate::physical_plan::file_format::{CsvExec, PhysicalPlanConfig};
use crate::physical_plan::memory::MemoryExec;
- use crate::physical_plan::sort::SortExec;
+ use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::{collect, common};
use crate::test::{self, assert_is_pending};
use crate::{assert_batches_eq, test_util};
@@ -682,7 +729,6 @@ mod tests {
use super::*;
use arrow::datatypes::{DataType, Field, Schema};
use futures::{FutureExt, SinkExt};
- use tokio_stream::StreamExt;
#[tokio::test]
async fn test_merge_interleave() {
@@ -1251,15 +1297,17 @@ mod tests {
let metrics = ExecutionPlanMetricsSet::new();
let baseline_metrics = BaselineMetrics::new(&metrics, 0);
- let merge_stream = SortPreservingMergeStream::new(
+ let merge_stream = SortPreservingMergeStream::new_from_receiver(
receivers,
- // Use empty vector since we want to use the join handles ourselves
AbortOnDropMany(vec![]),
batches.schema(),
sort.as_slice(),
1024,
baseline_metrics,
- );
+ 0,
+ RUNTIME_ENV.clone(),
+ )
+ .await;
let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap();
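The spill path in `MergingStreams` swaps a stream in place behind a `futures::lock::Mutex`: a spillable in-memory stream is drained to disk and its slot is overwritten with an unspillable, disk-backed replacement. The ownership pattern in isolation (an illustrative sketch, not the patch's API):

```rust
use futures::lock::Mutex;

enum Slot {
    Memory(Vec<i32>), // stands in for a spillable stream
    Disk(String),     // stands in for its disk-backed replacement
}

/// Replace the in-memory variant with a disk-backed one and report the
/// bytes freed; mirrors the shape of `spill_underlying_stream`.
async fn spill(slots: &Mutex<Vec<Slot>>, idx: usize, path: String) -> usize {
    let mut slots = slots.lock().await;
    let freed = match &slots[idx] {
        Slot::Memory(v) => v.len() * std::mem::size_of::<i32>(),
        // already on disk: nothing left to free
        Slot::Disk(_) => return 0,
    };
    // ... the real code writes the stream's batches to `path` here ...
    slots[idx] = Slot::Disk(path);
    freed
}
```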
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index bc1ff554abfa..60e8e859f62e 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -3087,7 +3087,7 @@ async fn explain_analyze_baseline_metrics() {
fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool {
use datafusion::physical_plan;
- plan.as_any().downcast_ref::<physical_plan::sort::SortExec>().is_some()
+ plan.as_any().downcast_ref::<physical_plan::sorts::sort::SortExec>().is_some()
|| plan.as_any().downcast_ref::().is_some()
// CoalescePartitionsExec doesn't do any work so is not included
|| plan.as_any().downcast_ref::().is_some()