From 720ed8235bdfea25b9bc49ddf8050574c16f9988 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Tue, 7 Nov 2023 22:27:00 -0800 Subject: [PATCH] cleanup return_type() --- datafusion/expr/src/built_in_function.rs | 38 +++++-------------- datafusion/expr/src/expr_schema.rs | 19 ++++++++-- .../expr/src/type_coercion/functions.rs | 11 +++++- datafusion/physical-expr/src/functions.rs | 15 ++++---- 4 files changed, 42 insertions(+), 41 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 16187572c521..ef8e1f25e752 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -28,14 +28,12 @@ use crate::signature::TIMEZONE_WILDCARD; use crate::type_coercion::binary::get_wider_type; use crate::type_coercion::functions::data_types; use crate::{ - conditional_expressions, struct_expressions, utils, FuncMonotonicity, Signature, + conditional_expressions, struct_expressions, FuncMonotonicity, Signature, TypeSignature, Volatility, }; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; -use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, -}; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -483,6 +481,13 @@ impl BuiltinScalarFunction { } /// Returns the output [`DataType`] of this function + /// + /// This method should be invoked only after `input_expr_types` have been validated + /// against the function's `TypeSignature` using `type_coercion::functions::data_types()`. + /// + /// This method will: + /// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation. + /// 2. Deduce the output `DataType` based on the provided `input_expr_types`. pub fn return_type(self, input_expr_types: &[DataType]) -> Result { use DataType::*; use TimeUnit::*; @@ -490,31 +495,6 @@ impl BuiltinScalarFunction { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - if input_expr_types.is_empty() - && !self.signature().type_signature.supports_zero_argument() - { - return plan_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types - ) - ); - } - - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &self.signature()).map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 025b74eb5009..2889fac8c1ee 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -23,7 +23,8 @@ use crate::expr::{ }; use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; -use crate::{LogicalPlan, Projection, Subquery}; +use crate::type_coercion::functions::data_types; +use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ @@ -89,12 +90,24 @@ impl ExprSchemable for Expr { Ok((fun.return_type)(&data_types)?.as_ref().clone()) } Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let data_types = args + let arg_data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - fun.return_type(&data_types) + // verify that input data types is consistent with function's `TypeSignature` + data_types(&arg_data_types, &fun.signature()).map_err(|_| { + plan_datafusion_err!( + "{}", + utils::generate_signature_error_msg( + &format!("{fun}"), + fun.signature(), + &arg_data_types, + ) + ) + })?; + + fun.return_type(&arg_data_types) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { let data_types = args diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index b49bf37d6754..79b574238495 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -35,8 +35,17 @@ pub fn data_types( signature: &Signature, ) -> Result> { if current_types.is_empty() { - return Ok(vec![]); + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!( + "Coercion from {:?} to the signature {:?} failed.", + current_types, + &signature.type_signature + ); + } } + let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index b66bac41014d..f14bad093ac7 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -47,7 +47,8 @@ use arrow::{ use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; pub use datafusion_expr::FuncMonotonicity; use datafusion_expr::{ - BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, + type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, + ScalarFunctionImplementation, }; use std::ops::Neg; use std::sync::Arc; @@ -65,6 +66,9 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; + // verify that input data types is consistent with function's `TypeSignature` + data_types(&input_expr_types, &fun.signature())?; + let data_type = fun.return_type(&input_expr_types)?; let fun_expr: ScalarFunctionImplementation = match fun { @@ -2952,13 +2956,8 @@ mod tests { "Builtin scalar function {fun} does not support empty arguments" ); } - Err(DataFusionError::Plan(err)) => { - if !err - .contains("No function matches the given name and argument types") - { - return plan_err!( - "Builtin scalar function {fun} didn't got the right error message with empty arguments"); - } + Err(DataFusionError::Plan(_)) => { + // Continue the loop } Err(..) => { return internal_err!(