From 600b815097b2b1ebbde3f48d071852bd6364a30a Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 21 Apr 2024 11:42:34 +0800 Subject: [PATCH] Minor: Signature check for UDAF (#10147) * add sig for udaf Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/expr/src/lib.rs | 3 ++- .../expr/src/type_coercion/aggregates.rs | 21 ++++++++----------- .../functions-aggregate/src/first_last.rs | 14 +++++++++++-- .../physical-expr-common/src/aggregate/mod.rs | 7 +++++++ 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index a297f2dc7886..36732324eff6 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -79,7 +79,8 @@ pub use logical_plan::*; pub use operator::Operator; pub use partition_evaluator::PartitionEvaluator; pub use signature::{ - FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, + ArrayFunctionSignature, FuncMonotonicity, Signature, TypeSignature, Volatility, + TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateUDF, AggregateUDFImpl}; diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 866aea06b4d4..44f2671f4f99 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -93,7 +93,7 @@ pub fn coerce_types( ) -> Result> { use DataType::*; // Validate input_types matches (at least one of) the func signature. - check_arg_count(agg_fun, input_types, &signature.type_signature)?; + check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; match agg_fun { AggregateFunction::Count | AggregateFunction::ApproxDistinct => { @@ -323,8 +323,8 @@ pub fn coerce_types( /// This method DOES NOT validate the argument types - only that (at least one, /// in the case of [`TypeSignature::OneOf`]) signature matches the desired /// number of input types. -fn check_arg_count( - agg_fun: &AggregateFunction, +pub fn check_arg_count( + func_name: &str, input_types: &[DataType], signature: &TypeSignature, ) -> Result<()> { @@ -332,8 +332,7 @@ fn check_arg_count( TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { if input_types.len() != *agg_count { return plan_err!( - "The function {:?} expects {:?} arguments, but {:?} were provided", - agg_fun, + "The function {func_name} expects {:?} arguments, but {:?} were provided", agg_count, input_types.len() ); @@ -342,8 +341,7 @@ fn check_arg_count( TypeSignature::Exact(types) => { if types.len() != input_types.len() { return plan_err!( - "The function {:?} expects {:?} arguments, but {:?} were provided", - agg_fun, + "The function {func_name} expects {:?} arguments, but {:?} were provided", types.len(), input_types.len() ); @@ -352,11 +350,10 @@ fn check_arg_count( TypeSignature::OneOf(variants) => { let ok = variants .iter() - .any(|v| check_arg_count(agg_fun, input_types, v).is_ok()); + .any(|v| check_arg_count(func_name, input_types, v).is_ok()); if !ok { return plan_err!( - "The function {:?} does not accept {:?} function arguments.", - agg_fun, + "The function {func_name} does not accept {:?} function arguments.", input_types.len() ); } @@ -364,7 +361,7 @@ fn check_arg_count( TypeSignature::VariadicAny => { if input_types.is_empty() { return plan_err!( - "The function {agg_fun:?} expects at least one argument" + "The function {func_name} expects at least one argument" ); } } @@ -594,7 +591,7 @@ mod tests { let input_types = vec![DataType::Int64, DataType::Int32]; let signature = fun.signature(); let result = coerce_types(&fun, &input_types, &signature); - assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace()); + assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace()); // test input args is invalid data type for sum or avg let fun = AggregateFunction::Sum; diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 1a56b23cd26a..f8b388a4f3b1 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -27,7 +27,10 @@ use datafusion_common::{ use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; -use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr, Signature, Volatility}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, + TypeSignature, Volatility, +}; use datafusion_physical_expr_common::aggregate::utils::{ down_cast_any_ref, get_sort_options, ordering_fields, }; @@ -73,7 +76,14 @@ impl FirstValue { pub fn new() -> Self { Self { aliases: vec![String::from("FIRST_VALUE")], - signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::one_of( + vec![ + // TODO: we can introduce more strict signature that only numeric of array types are allowed + TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + TypeSignature::Uniform(1, NUMERICS.to_vec()), + ], + Volatility::Immutable, + ), } } } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 33044fd9beee..448af634176a 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -19,6 +19,7 @@ pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::type_coercion::aggregates::check_arg_count; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, }; @@ -46,6 +47,12 @@ pub fn create_aggregate_expr( .map(|arg| arg.data_type(schema)) .collect::>>()?; + check_arg_count( + fun.name(), + &input_exprs_types, + &fun.signature().type_signature, + )?; + let ordering_types = ordering_req .iter() .map(|e| e.expr.data_type(schema))