Skip to content

Commit

Permalink
Minor: Signature check for UDAF (#10147)
Browse files Browse the repository at this point in the history
* add sig for udaf

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Apr 21, 2024
1 parent 0e21281 commit 600b815
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 15 deletions.
3 changes: 2 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ pub use logical_plan::*;
pub use operator::Operator;
pub use partition_evaluator::PartitionEvaluator;
pub use signature::{
FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
ArrayFunctionSignature, FuncMonotonicity, Signature, TypeSignature, Volatility,
TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl};
Expand Down
21 changes: 9 additions & 12 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub fn coerce_types(
) -> Result<Vec<DataType>> {
use DataType::*;
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun, input_types, &signature.type_signature)?;
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Expand Down Expand Up @@ -323,17 +323,16 @@ pub fn coerce_types(
/// This method DOES NOT validate the argument types - only that (at least one,
/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
/// number of input types.
fn check_arg_count(
agg_fun: &AggregateFunction,
pub fn check_arg_count(
func_name: &str,
input_types: &[DataType],
signature: &TypeSignature,
) -> Result<()> {
match signature {
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
if input_types.len() != *agg_count {
return plan_err!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
"The function {func_name} expects {:?} arguments, but {:?} were provided",
agg_count,
input_types.len()
);
Expand All @@ -342,8 +341,7 @@ fn check_arg_count(
TypeSignature::Exact(types) => {
if types.len() != input_types.len() {
return plan_err!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
"The function {func_name} expects {:?} arguments, but {:?} were provided",
types.len(),
input_types.len()
);
Expand All @@ -352,19 +350,18 @@ fn check_arg_count(
TypeSignature::OneOf(variants) => {
let ok = variants
.iter()
.any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
.any(|v| check_arg_count(func_name, input_types, v).is_ok());
if !ok {
return plan_err!(
"The function {:?} does not accept {:?} function arguments.",
agg_fun,
"The function {func_name} does not accept {:?} function arguments.",
input_types.len()
);
}
}
TypeSignature::VariadicAny => {
if input_types.is_empty() {
return plan_err!(
"The function {agg_fun:?} expects at least one argument"
"The function {func_name} expects at least one argument"
);
}
}
Expand Down Expand Up @@ -594,7 +591,7 @@ mod tests {
let input_types = vec![DataType::Int64, DataType::Int32];
let signature = fun.signature();
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());
assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());

// test input args is invalid data type for sum or avg
let fun = AggregateFunction::Sum;
Expand Down
14 changes: 12 additions & 2 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use datafusion_common::{
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr, Signature, Volatility};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature,
TypeSignature, Volatility,
};
use datafusion_physical_expr_common::aggregate::utils::{
down_cast_any_ref, get_sort_options, ordering_fields,
};
Expand Down Expand Up @@ -73,7 +76,14 @@ impl FirstValue {
pub fn new() -> Self {
Self {
aliases: vec![String::from("FIRST_VALUE")],
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
signature: Signature::one_of(
vec![
// TODO: we can introduce a stricter signature that only allows arrays of numeric types
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Uniform(1, NUMERICS.to_vec()),
],
Volatility::Immutable,
),
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod utils;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
};
Expand Down Expand Up @@ -46,6 +47,12 @@ pub fn create_aggregate_expr(
.map(|arg| arg.data_type(schema))
.collect::<Result<Vec<_>>>()?;

check_arg_count(
fun.name(),
&input_exprs_types,
&fun.signature().type_signature,
)?;

let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(schema))
Expand Down

0 comments on commit 600b815

Please sign in to comment.