Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move min and max to user defined aggregate function, remove AggregateFunction / AggregateFunctionDefinition::BuiltIn #11013

Merged
merged 8 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-examples/examples/dataframe_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::sync::Arc;

use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg;
use datafusion::functions_aggregate::min_max::max;
use datafusion::prelude::*;
use datafusion::test_util::arrow_test_data;
use datafusion_common::ScalarValue;
Expand Down
8 changes: 6 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ use datafusion_common::{
};
use datafusion_expr::{case, is_null, lit};
use datafusion_expr::{
max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_functions_aggregate::expr_fn::{
avg, count, max, median, min, stddev, sum,
};
use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum};

use async_trait::async_trait;
use datafusion_catalog::Session;
Expand Down Expand Up @@ -144,6 +146,7 @@ impl Default for DataFrameWriteOptions {
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # use datafusion::functions_aggregate::expr_fn::min;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
Expand Down Expand Up @@ -407,6 +410,7 @@ impl DataFrame {
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # use datafusion::functions_aggregate::expr_fn::min;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use datafusion_common::{
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator};
use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
use datafusion_physical_plan::metrics::MetricsSet;

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use super::listing::PartitionedFile;
use crate::arrow::datatypes::{Schema, SchemaRef};
use crate::error::Result;
use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator};
use crate::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics};
use arrow_schema::DataType;

Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ where
///
/// ```
/// use datafusion::prelude::*;
/// # use datafusion::functions_aggregate::expr_fn::min;
/// # use datafusion::{error::Result, assert_batches_eq};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
//! ```rust
//! # use datafusion::prelude::*;
//! # use datafusion::error::Result;
//! # use datafusion::functions_aggregate::expr_fn::min;
//! # use datafusion::arrow::record_batch::RecordBatch;
//!
//! # #[tokio::main]
Expand Down
15 changes: 2 additions & 13 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,39 +272,28 @@ fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool {
return true;
}
}

false
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
fn is_min(agg_expr: &dyn AggregateExpr) -> bool {
if agg_expr.as_any().is::<expressions::Min>() {
return true;
}

if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "min" {
if agg_expr.fun().name().to_lowercase() == "min" {
return true;
}
}

false
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
fn is_max(agg_expr: &dyn AggregateExpr) -> bool {
if agg_expr.as_any().is::<expressions::Max>() {
return true;
}

if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "max" {
if agg_expr.fun().name().to_lowercase() == "max" {
return true;
}
}

false
}

Expand Down
28 changes: 3 additions & 25 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ use crate::physical_plan::unnest::UnnestExec;
use crate::physical_plan::values::ValuesExec;
use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
use crate::physical_plan::{
aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan,
ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr,
displayable, udaf, windows, AggregateExpr, ExecutionPlan, ExecutionPlanProperties,
InputOrderMode, Partitioning, PhysicalExpr, WindowExpr,
};

use arrow::compute::SortOptions;
Expand Down Expand Up @@ -1812,7 +1812,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
e: &Expr,
name: impl Into<String>,
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
_physical_input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<AggregateExprWithOptionalArgs> {
match e {
Expand Down Expand Up @@ -1840,28 +1840,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
== NullTreatment::IgnoreNulls;

let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
&physical_args,
&ordering_reqs,
physical_input_schema,
name,
ignore_nulls,
)?;
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::UDF(fun) => {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery,
when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum};
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, max, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
21 changes: 6 additions & 15 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ use datafusion::physical_plan::{collect, InputOrderMode};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::aggregates::coerce_types;
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::{
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
Expand Down Expand Up @@ -361,14 +361,14 @@ fn get_random_function(
window_fn_map.insert(
"min",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
WindowFunctionDefinition::AggregateUDF(min_udaf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"max",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
WindowFunctionDefinition::AggregateUDF(max_udaf()),
vec![arg.clone()],
),
);
Expand Down Expand Up @@ -465,16 +465,7 @@ fn get_random_function(
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
let (window_fn, args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
let mut args = args.clone();
if let WindowFunctionDefinition::AggregateFunction(f) = window_fn {
if !args.is_empty() {
// Do type coercion first argument
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let sig = f.signature();
let coerced = coerce_types(f, &[dt], &sig).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}
} else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
if !args.is_empty() {
// Do type coercion first argument
let a = args[0].clone();
Expand Down
156 changes: 0 additions & 156 deletions datafusion/expr/src/aggregate_function.rs

This file was deleted.

Loading