
Add new physical rule CombinePartialFinalAggregate #5837

Merged: 12 commits, Apr 12, 2023
2 changes: 2 additions & 0 deletions datafusion/core/src/execution/context.rs
@@ -103,6 +103,7 @@ use datafusion_sql::planner::object_name_to_table_reference;
use uuid::Uuid;

// backwards compatibility
use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate;
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;

@@ -1296,6 +1297,7 @@ impl SessionState {
// Enforce sort before PipelineFixer
Arc::new(EnforceDistribution::new()),
Arc::new(EnforceSorting::new()),
Arc::new(CombinePartialFinalAggregate::new()),
// If the query is processing infinite inputs, the PipelineFixer rule applies the
// necessary transformations to make the query runnable (if it is not already runnable).
// If the query can not be made runnable, the rule emits an error with a diagnostic message.
120 changes: 120 additions & 0 deletions datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
@@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! CombinePartialFinalAggregate optimizer rule checks adjacent Partial and Final AggregateExecs
//! and tries to combine them when possible
use crate::error::Result;
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
use crate::physical_plan::ExecutionPlan;
use datafusion_common::config::ConfigOptions;
use std::sync::Arc;

use datafusion_common::tree_node::{Transformed, TreeNode};

/// CombinePartialFinalAggregate optimizer rule combines adjacent Partial and Final AggregateExecs
/// into a Single AggregateExec if their grouping expressions and aggregate expressions are equal.
///
/// This rule should be applied after the EnforceDistribution and EnforceSorting rules
///
#[derive(Default)]
pub struct CombinePartialFinalAggregate {}

impl CombinePartialFinalAggregate {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}

impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_down(&|plan| {
let transformed = plan.as_any().downcast_ref::<AggregateExec>().and_then(
|AggregateExec {
mode: final_mode,
input: final_input,
group_by: final_group_by,
aggr_expr: final_aggr_expr,
..
}| {
if matches!(
final_mode,
AggregateMode::Final | AggregateMode::FinalPartitioned
) {
final_input
.as_any()
.downcast_ref::<AggregateExec>()
Contributor: Since there is no RepartitionExec between them, the distribution of the Final-mode AggregateExec and the Partial-mode AggregateExec is the same, so there is no need for two-phase aggregation.
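
For illustration, the rewrite targets plan shapes like the following (an EXPLAIN-style sketch with assumed operator details, not copied from this PR's tests):

Before:
  AggregateExec: mode=Final, gby=[a], aggr=[COUNT(1)]
    AggregateExec: mode=Partial, gby=[a], aggr=[COUNT(1)]
      ParquetExec: ...

After:
  AggregateExec: mode=Single, gby=[a], aggr=[COUNT(1)]
    ParquetExec: ...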

Contributor: Thanks @mingmwang for introducing this rule, which will significantly improve query performance for the SQL patterns shown in the UTs.

Author (@mingmwang): Actually, the performance improvement will not be that significant, because usually the Final aggregation step is not that heavy.

.and_then(
|AggregateExec {
mode: input_mode,
input: partial_input,
group_by: input_group_by,
aggr_expr: input_aggr_expr,
input_schema,
..
}| {
if matches!(input_mode, AggregateMode::Partial)
&& final_group_by.eq(input_group_by)
&& final_aggr_expr.len() == input_aggr_expr.len()
&& final_aggr_expr
.iter()
.zip(input_aggr_expr.iter())
.all(|(final_expr, partial_expr)| {
final_expr.eq(partial_expr)
})
{
AggregateExec::try_new(
AggregateMode::Single,
input_group_by.clone(),
input_aggr_expr.to_vec(),
partial_input.clone(),
input_schema.clone(),
)
.ok()
.map(Arc::new)
} else {
None
}
},
)
} else {
None
}
},
);

Ok(if let Some(transformed) = transformed {
Transformed::Yes(transformed)
} else {
Transformed::No(plan)
})
})
}

fn name(&self) -> &str {
"CombinePartialFinalAggregate"
}

fn schema_check(&self) -> bool {
true
}
}
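As a usage sketch, the rule can be driven directly like any other PhysicalOptimizerRule. This is illustrative only (the `physical_plan` variable is assumed to be an Arc<dyn ExecutionPlan> built elsewhere); in practice SessionState runs the rule as part of the pipeline registered in context.rs above:

use datafusion_common::config::ConfigOptions;

// Apply the rule manually to an already-built physical plan.
let rule = CombinePartialFinalAggregate::new();
let optimized = rule.optimize(physical_plan, &ConfigOptions::new())?;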
1 change: 1 addition & 0 deletions datafusion/core/src/physical_optimizer/mod.rs
@@ -20,6 +20,7 @@

pub mod aggregate_statistics;
pub mod coalesce_batches;
pub mod combine_partial_final_agg;
pub mod dist_enforcement;
pub mod global_sort_selection;
pub mod join_selection;
40 changes: 33 additions & 7 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -65,6 +65,8 @@ pub enum AggregateMode {
/// with Hash repartitioning on the group keys. If a group key is
/// duplicated, duplicate groups would be produced
FinalPartitioned,
/// Single aggregate is a combination of Partial and Final aggregate mode
Contributor suggested change:
- /// Single aggregate is a combination of Partial and Final aggregate mode
+ /// Applies the entire logical aggregation operation in a single operator,
+ /// as opposed to Partial / Final modes which apply the logical aggregation using
+ /// two operators.

Single,
}

/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
@@ -147,6 +149,24 @@ impl PhysicalGroupBy {
}
}

impl PartialEq for PhysicalGroupBy {
fn eq(&self, other: &PhysicalGroupBy) -> bool {
self.expr.len() == other.expr.len()
&& self
.expr
.iter()
.zip(other.expr.iter())
.all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
Contributor: I wondered why this needed to be implemented manually, so I tried removing it and got this error:

error[E0369]: binary operation `==` cannot be applied to type `Vec<(Arc<dyn PhysicalExpr>, std::string::String)>`
  --> datafusion/core/src/physical_plan/aggregates/mod.rs:91:5
   |
88 | #[derive(Clone, Debug, Default, PartialEq)]
   |                                 --------- in this derive macro expansion
...
91 |     expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   |
   = note: this error originates in the derive macro `PartialEq` (in Nightly builds, run with -Z macro-backtrace for more info)

Author (@mingmwang, Apr 11, 2023): It looks like if a struct contains any boxed trait object, we cannot use the PartialEq derive macro.

rust-lang/rust#39128

(A minimal standalone sketch of this limitation follows the impl below.)

&& self.null_expr.len() == other.null_expr.len()
&& self
.null_expr
.iter()
.zip(other.null_expr.iter())
.all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2)
&& self.groups == other.groups
}
}
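
For illustration, here is a minimal standalone reproduction of the limitation discussed in the review thread above; the trait `Expr` and its `eq_dyn` method are hypothetical stand-ins for `PhysicalExpr` and its dynamic comparison:

use std::sync::Arc;

// Hypothetical stand-in for a type-erased expression trait.
trait Expr {
    fn eq_dyn(&self, other: &dyn Expr) -> bool;
}

// #[derive(PartialEq)] // would fail to compile: `dyn Expr` has no PartialEq
//                      // impl, so Vec<(Arc<dyn Expr>, String)> is not comparable.
struct GroupBy {
    exprs: Vec<(Arc<dyn Expr>, String)>,
}

// The workaround is the manual element-wise comparison used above.
impl PartialEq for GroupBy {
    fn eq(&self, other: &Self) -> bool {
        self.exprs.len() == other.exprs.len()
            && self
                .exprs
                .iter()
                .zip(other.exprs.iter())
                .all(|((e1, n1), (e2, n2))| e1.eq_dyn(e2.as_ref()) && n1 == n2)
    }
}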

enum StreamType {
AggregateStream(AggregateStream),
GroupedHashAggregateStream(GroupedHashAggregateStream),
@@ -316,8 +336,8 @@ impl ExecutionPlan for AggregateExec {
/// Get the output partitioning of this plan
fn output_partitioning(&self) -> Partitioning {
match &self.mode {
AggregateMode::Partial => {
// Partial Aggregation will not change the output partitioning but need to respect the Alias
AggregateMode::Partial | AggregateMode::Single => {
// Partial and Single Aggregation will not change the output partitioning but need to respect the Alias
let input_partition = self.input.output_partitioning();
match input_partition {
Partitioning::Hash(exprs, part) => {
@@ -360,7 +380,9 @@

fn required_input_distribution(&self) -> Vec<Distribution> {
match &self.mode {
AggregateMode::Partial => vec![Distribution::UnspecifiedDistribution],
AggregateMode::Partial | AggregateMode::Single => {
vec![Distribution::UnspecifiedDistribution]
}
AggregateMode::FinalPartitioned => {
vec![Distribution::HashPartitioned(self.output_group_expr())]
}
@@ -528,7 +550,9 @@ fn create_schema(
fields.extend(expr.state_fields()?.iter().cloned())
}
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single => {
// in final mode, the field with the final result of the accumulator
for expr in aggr_expr {
fields.push(expr.field()?)
@@ -554,7 +578,7 @@ fn aggregate_expressions(
col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
match mode {
AggregateMode::Partial => {
AggregateMode::Partial | AggregateMode::Single => {
Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect())
}
// in this mode, we build the merge expressions of the aggregation
@@ -617,7 +641,7 @@ fn create_row_accumulators(
}

/// returns a vector of ArrayRefs, where each entry corresponds to either the
/// final value (mode = Final) or states (mode = Partial)
/// final value (mode = Final, FinalPartitioned, or Single) or states (mode = Partial)
fn finalize_aggregation(
accumulators: &[AccumulatorItem],
mode: &AggregateMode,
@@ -636,7 +660,9 @@
.collect::<Result<Vec<_>>>()?;
Ok(a.iter().flatten().cloned().collect::<Vec<_>>())
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single => {
// merge the state to the final value
accumulators
.iter()
4 changes: 3 additions & 1 deletion datafusion/core/src/physical_plan/aggregates/no_grouping.rs
@@ -194,7 +194,9 @@ fn aggregate_batch(
// 1.3
let size_pre = accum.size();
let res = match mode {
AggregateMode::Partial => accum.update_batch(values),
AggregateMode::Partial | AggregateMode::Single => {
accum.update_batch(values)
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
accum.merge_batch(values)
}
26 changes: 14 additions & 12 deletions datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -454,7 +454,7 @@ impl GroupedHashAggregateStream {
group_state.aggregation_buffer.as_mut_slice(),
);
match self.mode {
AggregateMode::Partial => {
AggregateMode::Partial | AggregateMode::Single => {
accumulator.update_batch(&values, &mut state_accessor)
}
AggregateMode::FinalPartitioned
@@ -486,7 +486,7 @@
.try_for_each(|(accumulator, values)| {
let size_pre = accumulator.size();
let res = match self.mode {
AggregateMode::Partial => {
AggregateMode::Partial | AggregateMode::Single => {
accumulator.update_batch(&values)
}
AggregateMode::FinalPartitioned
@@ -594,7 +594,9 @@
AggregateMode::Partial => {
read_as_batch(&state_buffers, &self.row_aggr_schema, RowType::WordAligned)
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single => {
let mut results = vec![];
for (idx, acc) in self.row_accumulators.iter().enumerate() {
let mut state_accessor =
@@ -636,15 +638,15 @@
.expect("Unexpected accumulator state in hash aggregate")
}),
),
AggregateMode::Final | AggregateMode::FinalPartitioned => {
ScalarValue::iter_to_array(group_state_chunk.iter().map(
|row_group_state| {
row_group_state.accumulator_set[idx].evaluate().expect(
"Unexpected accumulator state in hash aggregate",
)
},
))
}
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single => ScalarValue::iter_to_array(
group_state_chunk.iter().map(|row_group_state| {
row_group_state.accumulator_set[idx]
.evaluate()
.expect("Unexpected accumulator state in hash aggregate")
}),
),
}?;
// Cast output if needed (e.g. for types like Dictionary where
// the intermediate GroupByScalar type was not the same as the
20 changes: 20 additions & 0 deletions datafusion/core/src/physical_plan/udaf.rs
@@ -31,6 +31,7 @@ use crate::error::Result;
use crate::physical_plan::PhysicalExpr;
pub use datafusion_expr::AggregateUDF;

use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
use std::sync::Arc;

/// Creates a physical expression of the UDAF, that includes all necessary type coercion.
@@ -102,3 +103,22 @@ impl AggregateExpr for AggregateFunctionExpr {
&self.name
}
}

impl PartialEq<dyn Any> for AggregateFunctionExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.data_type == x.data_type
&& self.fun == x.fun
&& self.args.len() == x.args.len()
&& self
.args
.iter()
.zip(x.args.iter())
.all(|(this_arg, other_arg)| this_arg.eq(other_arg))
})
.unwrap_or(false)
}
}
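
The PartialEq<dyn Any> impls added in this PR all follow the same down_cast_any_ref pattern; below is a simplified generic sketch of the idea (the real helper in datafusion_physical_expr additionally sees through Arc indirection):

use std::any::Any;

// Compare a concrete value against a type-erased reference by downcasting;
// values of a different concrete type simply compare as not equal.
fn compare_as<T: PartialEq + 'static>(this: &T, other: &dyn Any) -> bool {
    other
        .downcast_ref::<T>()
        .map(|other| this == other)
        .unwrap_or(false)
}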
14 changes: 14 additions & 0 deletions datafusion/physical-expr/src/aggregate/approx_distinct.rs
@@ -18,6 +18,7 @@
//! Defines physical expressions that can be evaluated at runtime during query execution

use super::hyperloglog::HyperLogLog;
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::array::{
@@ -114,6 +115,19 @@ impl AggregateExpr for ApproxDistinct {
}
}

impl PartialEq<dyn Any> for ApproxDistinct {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.input_data_type == x.input_data_type
&& self.expr.eq(&x.expr)
})
.unwrap_or(false)
}
}

#[derive(Debug)]
struct BinaryHLLAccumulator<T>
where
15 changes: 15 additions & 0 deletions datafusion/physical-expr/src/aggregate/approx_median.rs
@@ -17,6 +17,7 @@

//! Defines physical expressions for APPROX_MEDIAN that can be evaluated at runtime during query execution

use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::{lit, ApproxPercentileCont};
use crate::{AggregateExpr, PhysicalExpr};
use arrow::{datatypes::DataType, datatypes::Field};
@@ -82,3 +83,17 @@ impl AggregateExpr for ApproxMedian {
&self.name
}
}

impl PartialEq<dyn Any> for ApproxMedian {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.data_type == x.data_type
&& self.expr.eq(&x.expr)
&& self.approx_percentile == x.approx_percentile
})
.unwrap_or(false)
}
}