diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index d6a0add9b253..7d9fc1a0a713 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -102,7 +102,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { } } Expr::ScalarUDF(ScalarUDF { fun, .. }) => { - match fun.signature.volatility { + match fun.signature().volatility { Volatility::Immutable => VisitRecursion::Continue, // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index d523c39ee01e..36be388de3a8 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -806,7 +806,7 @@ impl SessionContext { self.state .write() .scalar_functions - .insert(f.name.clone(), Arc::new(f)); + .insert(f.name().to_string(), Arc::new(f)); } /// Registers an aggregate UDF within this context. diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index f941e88f3a36..fb3e0b38e54f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -222,7 +222,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { create_function_physical_name(&func.fun.to_string(), false, &func.args) } Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_physical_name(&fun.name, false, args) + create_function_physical_name(fun.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { create_function_physical_name(&fun.to_string(), false, args) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 239a3188502c..af63f7724d9b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1178,7 +1178,7 @@ impl fmt::Display for Expr { fmt_function(f, &func.fun.to_string(), false, &func.args, true) } Expr::ScalarUDF(ScalarUDF { fun, args }) => { - fmt_function(f, &fun.name, false, args, true) + fmt_function(f, fun.name(), false, args, true) } Expr::WindowFunction(WindowFunction { fun, @@ -1512,7 +1512,7 @@ fn create_name(e: &Expr) -> Result { create_function_name(&func.fun.to_string(), false, &func.args) } Expr::ScalarUDF(ScalarUDF { fun, args }) => { - create_function_name(&fun.name, false, args) + create_function_name(fun.name(), false, args) } Expr::WindowFunction(WindowFunction { fun, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 025b74eb5009..900beaf2c938 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -86,7 +86,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - Ok((fun.return_type)(&data_types)?.as_ref().clone()) + Ok(fun.return_type(&data_types)?) } Expr::ScalarFunction(ScalarFunction { fun, args }) => { let data_types = args diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index be6c90aa5985..a482717c8489 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -15,23 +15,31 @@ // specific language governing permissions and limitations // under the License. -//! Udf module contains foundational types that are used to represent UDFs in DataFusion. +//! [`ScalarUDF`]: Scalar User Defined Functions use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::fmt; use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -/// Logical representation of a UDF. +/// Logical representation of a Scalar User Defined Function. +/// +/// A scalar function produces a single row output for each row of input. +/// +/// This struct contains the information DataFusion needs to plan and invoke +/// functions such name, type signature, return type, and actual implementation. +/// #[derive(Clone)] pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, /// actual implementation /// /// The fn param is the wrapped function but be aware that the function will @@ -40,7 +48,7 @@ pub struct ScalarUDF { /// will be passed. In that case the single element is a null array to indicate /// the batch's row count (so that the generative zero-argument function can know /// the result array size). - pub fun: ScalarFunctionImplementation, + fun: ScalarFunctionImplementation, } impl Debug for ScalarUDF { @@ -89,4 +97,23 @@ impl ScalarUDF { pub fn call(&self, args: Vec) -> Expr { Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args)) } + + /// Returns this function's name + pub fn name(&self) -> &str { + &self.name + } + /// Returns this function's signature + pub fn signature(&self) -> &Signature { + &self.signature + } + /// return the return type of this function given the types of the arguments + pub fn return_type(&self, args: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(args)?; + Ok(res.as_ref().clone()) + } + /// return the implementation of this function + pub fn fun(&self) -> &ScalarFunctionImplementation { + &self.fun + } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index bfdbec390199..ba67e1857c09 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -323,7 +323,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_expr = coerce_arguments_for_signature( args.as_slice(), &self.schema, - &fun.signature, + fun.signature(), )?; Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr))) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 04fdcca0a994..23f78e917d2f 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -350,7 +350,7 @@ impl<'a> ConstEvaluator<'a> { Self::volatility_ok(fun.volatility()) } Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => { - Self::volatility_ok(fun.signature.volatility) + Self::volatility_ok(fun.signature().volatility) } Expr::Literal(_) | Expr::BinaryExpr { .. } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index b66bac41014d..48cba778bbcc 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -211,7 +211,7 @@ pub fn create_physical_expr( &format!("{fun}"), fun_expr, input_phy_exprs.to_vec(), - &data_type, + data_type, monotonicity, ))) } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 5acd5dcf2336..6fc42bc94172 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -77,14 +77,14 @@ impl ScalarFunctionExpr { name: &str, fun: ScalarFunctionImplementation, args: Vec>, - return_type: &DataType, + return_type: DataType, monotonicity: Option, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type: return_type.clone(), + return_type, monotonicity, } } @@ -165,7 +165,7 @@ impl PhysicalExpr for ScalarFunctionExpr { &self.name, self.fun.clone(), children, - self.return_type(), + self.return_type().clone(), self.monotonicity.clone(), ))) } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index af1e77cbf566..0ec1cf3f256b 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -35,10 +35,10 @@ pub fn create_physical_expr( .collect::>>()?; Ok(Arc::new(ScalarFunctionExpr::new( - &fun.name, - fun.fun.clone(), + fun.name(), + fun.fun().clone(), input_phy_exprs.to_vec(), - (fun.return_type)(&input_exprs_types)?.as_ref(), + fun.return_type(&input_exprs_types)?, None, ))) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 687b73cfc886..f8a5d0a67f06 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -754,7 +754,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } Expr::ScalarUDF(ScalarUDF { fun, args }) => Self { expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { - fun_name: fun.name.clone(), + fun_name: fun.name().to_string(), args: args .iter() .map(|expr| expr.try_into()) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index a956eded9032..c619f5f5a570 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -18,7 +18,6 @@ //! Serde code to convert from protocol buffers to Rust data structures. use std::convert::{TryFrom, TryInto}; -use std::ops::Deref; use std::sync::Arc; use arrow::compute::SortOptions; @@ -308,12 +307,12 @@ pub fn parse_physical_expr( &e.name, fun_expr, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, None, )) } ExprType::ScalarUdf(e) => { - let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun; + let scalar_fun = registry.udf(e.name.as_str())?.fun().clone(); let args = e .args @@ -325,7 +324,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun, args, - &convert_required!(e.return_type)?, + convert_required!(e.return_type)?, None, )) } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 01a0916d8cd2..e3cde2563169 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -515,7 +515,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { "acos", fun_expr, vec![col("a", &schema)?], - &DataType::Int64, + DataType::Int64, None, ); @@ -549,7 +549,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", scalar_fn, vec![col("a", &schema)?], - &DataType::Int64, + DataType::Int64, None, );