Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support spilling for hash aggregation #7400

Merged
merged 21 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow::record_batch::RecordBatch;
use arrow_array::{downcast_primitive, ArrayRef};
use arrow_schema::SchemaRef;
use datafusion_common::Result;
Expand Down Expand Up @@ -42,6 +43,9 @@ pub trait GroupValues: Send {

/// Emits the group values
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;

/// Clear the contents and shrink the capacity to the size of the batch (free up memory usage)
fn clear_shrink(&mut self, batch: &RecordBatch);
}

pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use ahash::RandomState;
use arrow::array::BooleanBufferBuilder;
use arrow::buffer::NullBuffer;
use arrow::datatypes::i256;
use arrow::record_batch::RecordBatch;
use arrow_array::cast::AsArray;
use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray};
use arrow_schema::DataType;
Expand Down Expand Up @@ -206,4 +207,12 @@ where
};
Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
}

fn clear_shrink(&mut self, batch: &RecordBatch) {
let count = batch.num_rows();
self.values.clear();
self.values.shrink_to(count);
self.map.clear();
self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
}
}
12 changes: 12 additions & 0 deletions datafusion/core/src/physical_plan/aggregates/group_values/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use crate::physical_plan::aggregates::group_values::GroupValues;
use ahash::RandomState;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::ArrayRef;
use arrow_schema::SchemaRef;
Expand Down Expand Up @@ -181,4 +182,15 @@ impl GroupValues for GroupValuesRows {
}
})
}

fn clear_shrink(&mut self, batch: &RecordBatch) {
let count = batch.num_rows();
// FIXME: there is no good way to clear_shrink for self.group_values
self.group_values = self.row_converter.empty_rows(count, 0);
self.map.clear();
self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>();
self.hashes_buffer.clear();
self.hashes_buffer.shrink_to(count);
}
}
195 changes: 157 additions & 38 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ mod tests {
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion_execution::config::SessionConfig;
use futures::{FutureExt, Stream};

// Generate a schema which consists of 5 columns (a, b, c, d, e)
Expand Down Expand Up @@ -1466,7 +1467,22 @@ mod tests {
)
}

async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> {
fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
let session_config = SessionConfig::new().with_batch_size(batch_size);
let runtime = Arc::new(
RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(max_memory, 1.0))
.unwrap(),
);
let task_ctx = TaskContext::default()
.with_session_config(session_config)
.with_runtime(runtime);
Arc::new(task_ctx)
}

async fn check_grouping_sets(
input: Arc<dyn ExecutionPlan>,
spill: bool,
) -> Result<()> {
let input_schema = input.schema();

let grouping_set = PhysicalGroupBy {
Expand All @@ -1491,7 +1507,11 @@ mod tests {
DataType::Int64,
))];

let task_ctx = Arc::new(TaskContext::default());
let task_ctx = if spill {
new_spill_ctx(4, 1000)
} else {
Arc::new(TaskContext::default())
};

let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
Expand All @@ -1506,24 +1526,53 @@ mod tests {
let result =
common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;

let expected = vec![
"+---+-----+-----------------+",
"| a | b | COUNT(1)[count] |",
"+---+-----+-----------------+",
"| | 1.0 | 2 |",
"| | 2.0 | 2 |",
"| | 3.0 | 2 |",
"| | 4.0 | 2 |",
"| 2 | | 2 |",
"| 2 | 1.0 | 2 |",
"| 3 | | 3 |",
"| 3 | 2.0 | 2 |",
"| 3 | 3.0 | 1 |",
"| 4 | | 3 |",
"| 4 | 3.0 | 1 |",
"| 4 | 4.0 | 2 |",
"+---+-----+-----------------+",
];
let expected = if spill {
vec![
"+---+-----+-----------------+",
Copy link
Contributor

@jayzhan211 jayzhan211 Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kazuyukitanimura Hi, do you remember why the result of spill is different from the non-spill one?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason is that we are comparing the output of a partial aggregate and when spill we also have a lower desired batch size and hit the emit early logic:

fn emit_early_if_necessary(&mut self) -> Result<()> {
if self.group_values.len() >= self.batch_size
&& matches!(self.group_ordering, GroupOrdering::None)
&& matches!(self.mode, AggregateMode::Partial)
&& self.update_memory_reservation().is_err()
{
let n = self.group_values.len() / self.batch_size * self.batch_size;
let batch = self.emit(EmitTo::First(n), false)?;
self.exec_state = ExecutionState::ProducingOutput(batch);
}
Ok(())
}

Copy link
Member

@viirya viirya Sep 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right. When spilling happens, it means we don't have enough memory to hold original batch so we will do partial aggregation on smaller batch, different partial aggregation result will be but I think it doesn't change final aggregation result.

"| a | b | COUNT(1)[count] |",
"+---+-----+-----------------+",
"| | 1.0 | 1 |",
"| | 1.0 | 1 |",
"| | 2.0 | 1 |",
"| | 2.0 | 1 |",
"| | 3.0 | 1 |",
"| | 3.0 | 1 |",
"| | 4.0 | 1 |",
"| | 4.0 | 1 |",
"| 2 | | 1 |",
"| 2 | | 1 |",
"| 2 | 1.0 | 1 |",
"| 2 | 1.0 | 1 |",
"| 3 | | 1 |",
"| 3 | | 2 |",
"| 3 | 2.0 | 2 |",
"| 3 | 3.0 | 1 |",
"| 4 | | 1 |",
"| 4 | | 2 |",
"| 4 | 3.0 | 1 |",
"| 4 | 4.0 | 2 |",
"+---+-----+-----------------+",
]
} else {
vec![
"+---+-----+-----------------+",
"| a | b | COUNT(1)[count] |",
"+---+-----+-----------------+",
"| | 1.0 | 2 |",
"| | 2.0 | 2 |",
"| | 3.0 | 2 |",
"| | 4.0 | 2 |",
"| 2 | | 2 |",
"| 2 | 1.0 | 2 |",
"| 3 | | 3 |",
"| 3 | 2.0 | 2 |",
"| 3 | 3.0 | 1 |",
"| 4 | | 3 |",
"| 4 | 3.0 | 1 |",
"| 4 | 4.0 | 2 |",
"+---+-----+-----------------+",
]
};
assert_batches_sorted_eq!(expected, &result);

let groups = partial_aggregate.group_expr().expr().to_vec();
Expand All @@ -1537,6 +1586,12 @@ mod tests {

let final_grouping_set = PhysicalGroupBy::new_single(final_group);

let task_ctx = if spill {
new_spill_ctx(4, 3160)
} else {
task_ctx
};

let merged_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
final_grouping_set,
Expand Down Expand Up @@ -1582,7 +1637,7 @@ mod tests {
}

/// build the aggregates on the data from some_data() and check the results
async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
let input_schema = input.schema();

let grouping_set = PhysicalGroupBy {
Expand All @@ -1597,7 +1652,11 @@ mod tests {
DataType::Float64,
))];

let task_ctx = Arc::new(TaskContext::default());
let task_ctx = if spill {
new_spill_ctx(2, 2144)
} else {
Arc::new(TaskContext::default())
};

let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
Expand All @@ -1612,15 +1671,29 @@ mod tests {
let result =
common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;

let expected = [
"+---+---------------+-------------+",
"| a | AVG(b)[count] | AVG(b)[sum] |",
"+---+---------------+-------------+",
"| 2 | 2 | 2.0 |",
"| 3 | 3 | 7.0 |",
"| 4 | 3 | 11.0 |",
"+---+---------------+-------------+",
];
let expected = if spill {
vec![
"+---+---------------+-------------+",
"| a | AVG(b)[count] | AVG(b)[sum] |",
"+---+---------------+-------------+",
"| 2 | 1 | 1.0 |",
"| 2 | 1 | 1.0 |",
"| 3 | 1 | 2.0 |",
"| 3 | 2 | 5.0 |",
"| 4 | 3 | 11.0 |",
"+---+---------------+-------------+",
]
} else {
vec![
"+---+---------------+-------------+",
"| a | AVG(b)[count] | AVG(b)[sum] |",
"+---+---------------+-------------+",
"| 2 | 2 | 2.0 |",
"| 3 | 3 | 7.0 |",
"| 4 | 3 | 11.0 |",
"+---+---------------+-------------+",
]
};
assert_batches_sorted_eq!(expected, &result);

let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
Expand Down Expand Up @@ -1663,7 +1736,13 @@ mod tests {

let metrics = merged_aggregate.metrics().unwrap();
let output_rows = metrics.output_rows().unwrap();
assert_eq!(3, output_rows);
if spill {
// When spilling, the output rows metrics become partial output size + final output size
// This is because final aggregation starts while partial aggregation is still emitting
assert_eq!(8, output_rows);
} else {
assert_eq!(3, output_rows);
}

Ok(())
}
Expand Down Expand Up @@ -1784,31 +1863,63 @@ mod tests {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: false });

check_aggregates(input).await
check_aggregates(input, false).await
}

#[tokio::test]
async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: false });

check_grouping_sets(input).await
check_grouping_sets(input, false).await
}

#[tokio::test]
async fn aggregate_source_with_yielding() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: true });

check_aggregates(input).await
check_aggregates(input, false).await
}

#[tokio::test]
async fn aggregate_grouping_sets_with_yielding() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: true });

check_grouping_sets(input).await
check_grouping_sets(input, false).await
}

#[tokio::test]
async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: false });

check_aggregates(input, true).await
}

#[tokio::test]
async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: false });

check_grouping_sets(input, true).await
}

#[tokio::test]
async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: true });

check_aggregates(input, true).await
}

#[tokio::test]
async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: true });

check_grouping_sets(input, true).await
}

#[tokio::test]
Expand Down Expand Up @@ -1976,7 +2087,10 @@ mod tests {
async fn run_first_last_multi_partitions() -> Result<()> {
for use_coalesce_batches in [false, true] {
for is_first_acc in [false, true] {
first_last_multi_partitions(use_coalesce_batches, is_first_acc).await?
for spill in [false, true] {
first_last_multi_partitions(use_coalesce_batches, is_first_acc, spill)
.await?
}
}
}
Ok(())
Expand All @@ -2002,8 +2116,13 @@ mod tests {
async fn first_last_multi_partitions(
use_coalesce_batches: bool,
is_first_acc: bool,
spill: bool,
) -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let task_ctx = if spill {
new_spill_ctx(2, 2812)
} else {
Arc::new(TaskContext::default())
};

let (schema, data) = some_data_v2();
let partition1 = data[0].clone();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ impl GroupOrderingPartial {
Ok(())
}

/// Return the size of memor allocated by this structure
/// Return the size of memory allocated by this structure
pub(crate) fn size(&self) -> usize {
std::mem::size_of::<Self>()
+ self.order_indices.allocated_size()
Expand Down
Loading