-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Support spilling for hash aggregation #7400
Changes from 18 commits
9961558
d6a849f
5dfa345
927cb85
ea38ad5
0863e81
9068e5f
4b195f6
985a90c
d9a77c8
aa1fc50
6646776
b636a5e
b7d1a59
a7f4fee
ff53e06
25a9e5d
2ead8b3
113b263
9a65b27
c8b4a10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1291,6 +1291,7 @@ mod tests { | |||||||||||||||||||||||||
use std::sync::Arc; | ||||||||||||||||||||||||||
use std::task::{Context, Poll}; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
use datafusion_execution::config::SessionConfig; | ||||||||||||||||||||||||||
use futures::{FutureExt, Stream}; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// Generate a schema which consists of 5 columns (a, b, c, d, e) | ||||||||||||||||||||||||||
|
@@ -1461,7 +1462,22 @@ mod tests { | |||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> { | ||||||||||||||||||||||||||
fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> { | ||||||||||||||||||||||||||
let session_config = SessionConfig::new().with_batch_size(batch_size); | ||||||||||||||||||||||||||
let runtime = Arc::new( | ||||||||||||||||||||||||||
RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(max_memory, 1.0)) | ||||||||||||||||||||||||||
.unwrap(), | ||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||
let task_ctx = TaskContext::default() | ||||||||||||||||||||||||||
.with_session_config(session_config) | ||||||||||||||||||||||||||
.with_runtime(runtime); | ||||||||||||||||||||||||||
Arc::new(task_ctx) | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
async fn check_grouping_sets( | ||||||||||||||||||||||||||
input: Arc<dyn ExecutionPlan>, | ||||||||||||||||||||||||||
spill: bool, | ||||||||||||||||||||||||||
) -> Result<()> { | ||||||||||||||||||||||||||
let input_schema = input.schema(); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let grouping_set = PhysicalGroupBy { | ||||||||||||||||||||||||||
|
@@ -1486,7 +1502,11 @@ mod tests { | |||||||||||||||||||||||||
DataType::Int64, | ||||||||||||||||||||||||||
))]; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let task_ctx = Arc::new(TaskContext::default()); | ||||||||||||||||||||||||||
let task_ctx = if spill { | ||||||||||||||||||||||||||
new_spill_ctx(4, 1000) | ||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||
Arc::new(TaskContext::default()) | ||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let partial_aggregate = Arc::new(AggregateExec::try_new( | ||||||||||||||||||||||||||
AggregateMode::Partial, | ||||||||||||||||||||||||||
|
@@ -1501,24 +1521,53 @@ mod tests { | |||||||||||||||||||||||||
let result = | ||||||||||||||||||||||||||
common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let expected = vec![ | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
"| a | b | COUNT(1)[count] |", | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
"| | 1.0 | 2 |", | ||||||||||||||||||||||||||
"| | 2.0 | 2 |", | ||||||||||||||||||||||||||
"| | 3.0 | 2 |", | ||||||||||||||||||||||||||
"| | 4.0 | 2 |", | ||||||||||||||||||||||||||
"| 2 | | 2 |", | ||||||||||||||||||||||||||
"| 2 | 1.0 | 2 |", | ||||||||||||||||||||||||||
"| 3 | | 3 |", | ||||||||||||||||||||||||||
"| 3 | 2.0 | 2 |", | ||||||||||||||||||||||||||
"| 3 | 3.0 | 1 |", | ||||||||||||||||||||||||||
"| 4 | | 3 |", | ||||||||||||||||||||||||||
"| 4 | 3.0 | 1 |", | ||||||||||||||||||||||||||
"| 4 | 4.0 | 2 |", | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
]; | ||||||||||||||||||||||||||
let expected = if spill { | ||||||||||||||||||||||||||
vec![ | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @kazuyukitanimura Hi, do you remember why the result of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the reason is that we are comparing the output of a partial aggregate and when datafusion/datafusion/physical-plan/src/aggregates/row_hash.rs Lines 934 to 945 in 5b6b404
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's right. When spilling happens, it means we don't have enough memory to hold original batch so we will do partial aggregation on smaller batch, different partial aggregation result will be but I think it doesn't change final aggregation result. |
||||||||||||||||||||||||||
"| a | b | COUNT(1)[count] |", | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
"| | 1.0 | 1 |", | ||||||||||||||||||||||||||
"| | 1.0 | 1 |", | ||||||||||||||||||||||||||
"| | 2.0 | 1 |", | ||||||||||||||||||||||||||
"| | 2.0 | 1 |", | ||||||||||||||||||||||||||
"| | 3.0 | 1 |", | ||||||||||||||||||||||||||
"| | 3.0 | 1 |", | ||||||||||||||||||||||||||
"| | 4.0 | 1 |", | ||||||||||||||||||||||||||
"| | 4.0 | 1 |", | ||||||||||||||||||||||||||
"| 2 | | 1 |", | ||||||||||||||||||||||||||
"| 2 | | 1 |", | ||||||||||||||||||||||||||
"| 2 | 1.0 | 1 |", | ||||||||||||||||||||||||||
"| 2 | 1.0 | 1 |", | ||||||||||||||||||||||||||
"| 3 | | 1 |", | ||||||||||||||||||||||||||
"| 3 | | 2 |", | ||||||||||||||||||||||||||
"| 3 | 2.0 | 2 |", | ||||||||||||||||||||||||||
"| 3 | 3.0 | 1 |", | ||||||||||||||||||||||||||
"| 4 | | 1 |", | ||||||||||||||||||||||||||
"| 4 | | 2 |", | ||||||||||||||||||||||||||
"| 4 | 3.0 | 1 |", | ||||||||||||||||||||||||||
"| 4 | 4.0 | 2 |", | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||
vec![ | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
"| a | b | COUNT(1)[count] |", | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
"| | 1.0 | 2 |", | ||||||||||||||||||||||||||
"| | 2.0 | 2 |", | ||||||||||||||||||||||||||
"| | 3.0 | 2 |", | ||||||||||||||||||||||||||
"| | 4.0 | 2 |", | ||||||||||||||||||||||||||
"| 2 | | 2 |", | ||||||||||||||||||||||||||
"| 2 | 1.0 | 2 |", | ||||||||||||||||||||||||||
"| 3 | | 3 |", | ||||||||||||||||||||||||||
"| 3 | 2.0 | 2 |", | ||||||||||||||||||||||||||
"| 3 | 3.0 | 1 |", | ||||||||||||||||||||||||||
"| 4 | | 3 |", | ||||||||||||||||||||||||||
"| 4 | 3.0 | 1 |", | ||||||||||||||||||||||||||
"| 4 | 4.0 | 2 |", | ||||||||||||||||||||||||||
"+---+-----+-----------------+", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||
assert_batches_sorted_eq!(expected, &result); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let groups = partial_aggregate.group_expr().expr().to_vec(); | ||||||||||||||||||||||||||
|
@@ -1532,6 +1581,12 @@ mod tests { | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let final_grouping_set = PhysicalGroupBy::new_single(final_group); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let task_ctx = if spill { | ||||||||||||||||||||||||||
new_spill_ctx(4, 3160) | ||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||
task_ctx | ||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let merged_aggregate = Arc::new(AggregateExec::try_new( | ||||||||||||||||||||||||||
AggregateMode::Final, | ||||||||||||||||||||||||||
final_grouping_set, | ||||||||||||||||||||||||||
|
@@ -1577,7 +1632,7 @@ mod tests { | |||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
/// build the aggregates on the data from some_data() and check the results | ||||||||||||||||||||||||||
async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> { | ||||||||||||||||||||||||||
async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> { | ||||||||||||||||||||||||||
let input_schema = input.schema(); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let grouping_set = PhysicalGroupBy { | ||||||||||||||||||||||||||
|
@@ -1592,7 +1647,11 @@ mod tests { | |||||||||||||||||||||||||
DataType::Float64, | ||||||||||||||||||||||||||
))]; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let task_ctx = Arc::new(TaskContext::default()); | ||||||||||||||||||||||||||
let task_ctx = if spill { | ||||||||||||||||||||||||||
new_spill_ctx(2, 2144) | ||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||
Arc::new(TaskContext::default()) | ||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let partial_aggregate = Arc::new(AggregateExec::try_new( | ||||||||||||||||||||||||||
AggregateMode::Partial, | ||||||||||||||||||||||||||
|
@@ -1607,15 +1666,29 @@ mod tests { | |||||||||||||||||||||||||
let result = | ||||||||||||||||||||||||||
common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let expected = [ | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
"| a | AVG(b)[count] | AVG(b)[sum] |", | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
"| 2 | 2 | 2.0 |", | ||||||||||||||||||||||||||
"| 3 | 3 | 7.0 |", | ||||||||||||||||||||||||||
"| 4 | 3 | 11.0 |", | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
]; | ||||||||||||||||||||||||||
let expected = if spill { | ||||||||||||||||||||||||||
vec![ | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
"| a | AVG(b)[count] | AVG(b)[sum] |", | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
"| 2 | 1 | 1.0 |", | ||||||||||||||||||||||||||
"| 2 | 1 | 1.0 |", | ||||||||||||||||||||||||||
"| 3 | 1 | 2.0 |", | ||||||||||||||||||||||||||
"| 3 | 2 | 5.0 |", | ||||||||||||||||||||||||||
"| 4 | 3 | 11.0 |", | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||
vec![ | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
"| a | AVG(b)[count] | AVG(b)[sum] |", | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
"| 2 | 2 | 2.0 |", | ||||||||||||||||||||||||||
"| 3 | 3 | 7.0 |", | ||||||||||||||||||||||||||
"| 4 | 3 | 11.0 |", | ||||||||||||||||||||||||||
"+---+---------------+-------------+", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||
assert_batches_sorted_eq!(expected, &result); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); | ||||||||||||||||||||||||||
|
@@ -1658,7 +1731,13 @@ mod tests { | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let metrics = merged_aggregate.metrics().unwrap(); | ||||||||||||||||||||||||||
let output_rows = metrics.output_rows().unwrap(); | ||||||||||||||||||||||||||
assert_eq!(3, output_rows); | ||||||||||||||||||||||||||
if spill { | ||||||||||||||||||||||||||
// When spilling, the output rows metrics become partial output size + final output size | ||||||||||||||||||||||||||
// This is because final aggregation starts while partial aggregation is still emitting | ||||||||||||||||||||||||||
assert_eq!(8, output_rows); | ||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||
assert_eq!(3, output_rows); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
Ok(()) | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
@@ -1779,31 +1858,63 @@ mod tests { | |||||||||||||||||||||||||
let input: Arc<dyn ExecutionPlan> = | ||||||||||||||||||||||||||
Arc::new(TestYieldingExec { yield_first: false }); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
check_aggregates(input).await | ||||||||||||||||||||||||||
check_aggregates(input, false).await | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#[tokio::test] | ||||||||||||||||||||||||||
async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> { | ||||||||||||||||||||||||||
let input: Arc<dyn ExecutionPlan> = | ||||||||||||||||||||||||||
Arc::new(TestYieldingExec { yield_first: false }); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
check_grouping_sets(input).await | ||||||||||||||||||||||||||
check_grouping_sets(input, false).await | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#[tokio::test] | ||||||||||||||||||||||||||
async fn aggregate_source_with_yielding() -> Result<()> { | ||||||||||||||||||||||||||
let input: Arc<dyn ExecutionPlan> = | ||||||||||||||||||||||||||
Arc::new(TestYieldingExec { yield_first: true }); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
check_aggregates(input).await | ||||||||||||||||||||||||||
check_aggregates(input, false).await | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#[tokio::test] | ||||||||||||||||||||||||||
async fn aggregate_grouping_sets_with_yielding() -> Result<()> { | ||||||||||||||||||||||||||
let input: Arc<dyn ExecutionPlan> = | ||||||||||||||||||||||||||
Arc::new(TestYieldingExec { yield_first: true }); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
check_grouping_sets(input).await | ||||||||||||||||||||||||||
check_grouping_sets(input, false).await | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#[tokio::test] | ||||||||||||||||||||||||||
async fn aggregate_source_not_yielding_with_spill() -> Result<()> { | ||||||||||||||||||||||||||
let input: Arc<dyn ExecutionPlan> = | ||||||||||||||||||||||||||
Arc::new(TestYieldingExec { yield_first: false }); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
check_aggregates(input, true).await | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#[tokio::test] | ||||||||||||||||||||||||||
async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> { | ||||||||||||||||||||||||||
let input: Arc<dyn ExecutionPlan> = | ||||||||||||||||||||||||||
Arc::new(TestYieldingExec { yield_first: false }); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
check_grouping_sets(input, true).await | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#[tokio::test] | ||||||||||||||||||||||||||
async fn aggregate_source_with_yielding_with_spill() -> Result<()> { | ||||||||||||||||||||||||||
let input: Arc<dyn ExecutionPlan> = | ||||||||||||||||||||||||||
Arc::new(TestYieldingExec { yield_first: true }); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
check_aggregates(input, true).await | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#[tokio::test] | ||||||||||||||||||||||||||
async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> { | ||||||||||||||||||||||||||
let input: Arc<dyn ExecutionPlan> = | ||||||||||||||||||||||||||
Arc::new(TestYieldingExec { yield_first: true }); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
check_grouping_sets(input, true).await | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#[tokio::test] | ||||||||||||||||||||||||||
|
@@ -1971,7 +2082,10 @@ mod tests { | |||||||||||||||||||||||||
async fn run_first_last_multi_partitions() -> Result<()> { | ||||||||||||||||||||||||||
for use_coalesce_batches in [false, true] { | ||||||||||||||||||||||||||
for is_first_acc in [false, true] { | ||||||||||||||||||||||||||
first_last_multi_partitions(use_coalesce_batches, is_first_acc).await? | ||||||||||||||||||||||||||
for spill in [false, true] { | ||||||||||||||||||||||||||
first_last_multi_partitions(use_coalesce_batches, is_first_acc, spill) | ||||||||||||||||||||||||||
.await? | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
Ok(()) | ||||||||||||||||||||||||||
|
@@ -1997,8 +2111,13 @@ mod tests { | |||||||||||||||||||||||||
async fn first_last_multi_partitions( | ||||||||||||||||||||||||||
use_coalesce_batches: bool, | ||||||||||||||||||||||||||
is_first_acc: bool, | ||||||||||||||||||||||||||
spill: bool, | ||||||||||||||||||||||||||
) -> Result<()> { | ||||||||||||||||||||||||||
let task_ctx = Arc::new(TaskContext::default()); | ||||||||||||||||||||||||||
let task_ctx = if spill { | ||||||||||||||||||||||||||
new_spill_ctx(2, 2812) | ||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||
Arc::new(TaskContext::default()) | ||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
let (schema, data) = some_data_v2(); | ||||||||||||||||||||||||||
let partition1 = data[0].clone(); | ||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: "Clears the contents".
Also it's a bit confusing why this method takes a
RecordBatch
- perhaps we can just pass the number of rows asnew_capacity
or something?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The intention of passing
RecordBatch
is to use other information than number of rows for reserving the new capacity such as data (binary) capacity. But it will be separate future PRs.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh cool, perhaps its worth pointing out why
RecordBatch
is required in the comments too.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated the comment