Skip to content

Commit

Permalink
Merge commit '600b815097b2b1ebbde3f48d071852bd6364a30a' into chunchun…
Browse files Browse the repository at this point in the history
…/update-df-apr-week-3
  • Loading branch information
appletreeisyellow committed Apr 26, 2024
2 parents 8a15641 + 600b815 commit 49bccd2
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 26 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ pub use logical_plan::*;
pub use operator::Operator;
pub use partition_evaluator::PartitionEvaluator;
pub use signature::{
FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
ArrayFunctionSignature, FuncMonotonicity, Signature, TypeSignature, Volatility,
TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl};
Expand Down
21 changes: 9 additions & 12 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub fn coerce_types(
) -> Result<Vec<DataType>> {
use DataType::*;
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun, input_types, &signature.type_signature)?;
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Expand Down Expand Up @@ -323,17 +323,16 @@ pub fn coerce_types(
/// This method DOES NOT validate the argument types - only that (at least one,
/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
/// number of input types.
fn check_arg_count(
agg_fun: &AggregateFunction,
pub fn check_arg_count(
func_name: &str,
input_types: &[DataType],
signature: &TypeSignature,
) -> Result<()> {
match signature {
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
if input_types.len() != *agg_count {
return plan_err!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
"The function {func_name} expects {:?} arguments, but {:?} were provided",
agg_count,
input_types.len()
);
Expand All @@ -342,8 +341,7 @@ fn check_arg_count(
TypeSignature::Exact(types) => {
if types.len() != input_types.len() {
return plan_err!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
"The function {func_name} expects {:?} arguments, but {:?} were provided",
types.len(),
input_types.len()
);
Expand All @@ -352,19 +350,18 @@ fn check_arg_count(
TypeSignature::OneOf(variants) => {
let ok = variants
.iter()
.any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
.any(|v| check_arg_count(func_name, input_types, v).is_ok());
if !ok {
return plan_err!(
"The function {:?} does not accept {:?} function arguments.",
agg_fun,
"The function {func_name} does not accept {:?} function arguments.",
input_types.len()
);
}
}
TypeSignature::VariadicAny => {
if input_types.is_empty() {
return plan_err!(
"The function {agg_fun:?} expects at least one argument"
"The function {func_name} expects at least one argument"
);
}
}
Expand Down Expand Up @@ -594,7 +591,7 @@ mod tests {
let input_types = vec![DataType::Int64, DataType::Int32];
let signature = fun.signature();
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());
assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());

// test input args is invalid data type for sum or avg
let fun = AggregateFunction::Sum;
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ datafusion-expr = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
log = { workspace = true }
paste = "1.0.14"
sqlparser = { workspace = true }
16 changes: 13 additions & 3 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use datafusion_common::{
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr, Signature, Volatility};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature,
TypeSignature, Volatility,
};
use datafusion_physical_expr_common::aggregate::utils::{
down_cast_any_ref, get_sort_options, ordering_fields,
};
Expand All @@ -36,14 +39,14 @@ use datafusion_physical_expr_common::expressions;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_physical_expr_common::utils::reverse_order_bys;
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

make_udaf_function!(
FirstValue,
first_value,
value,
"Returns the first value in a group of values.",
first_value_udaf
);
Expand Down Expand Up @@ -73,7 +76,14 @@ impl FirstValue {
pub fn new() -> Self {
Self {
aliases: vec![String::from("FIRST_VALUE")],
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
signature: Signature::one_of(
vec![
// TODO: we can introduce more strict signature that only numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Uniform(1, NUMERICS.to_vec()),
],
Volatility::Immutable,
),
}
}
}
Expand Down
21 changes: 13 additions & 8 deletions datafusion/functions-aggregate/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,24 @@
// under the License.

macro_rules! make_udaf_function {
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
paste::paste! {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN($($arg: Expr),*) -> Expr {
pub fn $EXPR_FN(
args: Vec<Expr>,
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>
) -> Expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
vec![$($arg),*],
// TODO: Support arguments for `expr` API
false,
None,
None,
None,
args,
distinct,
filter,
order_by,
null_treatment,
))
}

Expand Down
7 changes: 7 additions & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod utils;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
};
Expand Down Expand Up @@ -46,6 +47,12 @@ pub fn create_aggregate_expr(
.map(|arg| arg.data_type(schema))
.collect::<Result<Vec<_>>>()?;

check_arg_count(
fun.name(),
&input_exprs_types,
&fun.signature().type_signature,
)?;

let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(schema))
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ async fn roundtrip_expr_api() -> Result<()> {
lit(1),
),
array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)),
first_value(lit(1)),
first_value(vec![lit(1)], false, None, None, None),
];

// ensure expressions created with the expr api can be round tripped
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ itertools = { workspace = true }
object_store = { workspace = true }
prost = "0.12"
prost-types = "0.12"
substrait = "0.30.0"
substrait = "0.31.0"

[dev-dependencies]
tokio = { workspace = true }
Expand Down

0 comments on commit 49bccd2

Please sign in to comment.