Skip to content

Commit

Permalink
Refactor function argument handling in
Browse files Browse the repository at this point in the history
ScalarFunctionDefinition
  • Loading branch information
Weijun-H committed Dec 1, 2023
1 parent e19c669 commit 2ce2c81
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 76 deletions.
15 changes: 5 additions & 10 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,12 @@ impl ExprSchemable for Expr {
Expr::Cast(Cast { data_type, .. })
| Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
let arg_data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
let arg_data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

// verify that input data types is consistent with function's `TypeSignature`
data_types(&arg_data_types, &fun.signature()).map_err(|_| {
plan_datafusion_err!(
Expand All @@ -105,11 +104,7 @@ impl ExprSchemable for Expr {
fun.return_type(&arg_data_types)
}
ScalarFunctionDefinition::UDF(fun) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok(fun.return_type(&data_types)?)
Ok(fun.return_type(&arg_data_types)?)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
Expand Down
66 changes: 27 additions & 39 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,50 +348,38 @@ pub fn create_physical_expr(
)))
}

Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
let physical_args = args
.iter()
.map(|e| {
create_physical_expr(
e,
input_dfschema,
input_schema,
execution_props,
)
})
.collect::<Result<Vec<_>>>()?;
functions::create_physical_expr(
fun,
&physical_args,
input_schema,
execution_props,
)
}
ScalarFunctionDefinition::UDF(fun) => {
let mut physical_args = vec![];
for e in args {
physical_args.push(create_physical_expr(
e,
input_dfschema,
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
let mut physical_args = args
.iter()
.map(|e| {
create_physical_expr(e, input_dfschema, input_schema, execution_props)
})
.collect::<Result<Vec<_>>>()?;
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
functions::create_physical_expr(
fun,
&physical_args,
input_schema,
execution_props,
)?);
)
}
ScalarFunctionDefinition::UDF(fun) => {
// udfs with zero params expect null array as input
if args.is_empty() {
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
}
udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
input_schema,
)
}
// udfs with zero params expect null array as input
if args.is_empty() {
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
input_schema,
)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
},
}
Expr::Between(Between {
expr,
negated,
Expand Down
53 changes: 26 additions & 27 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,40 +792,39 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
.to_string(),
))
}
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
let fun: protobuf::ScalarFunction = fun.try_into()?;
let args: Vec<Self> = args
.iter()
.map(|e| e.try_into())
.collect::<Result<Vec<Self>, Error>>()?;
Self {
expr_type: Some(ExprType::ScalarFunction(
protobuf::ScalarFunctionNode {
fun: fun.into(),
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
let args = args
.iter()
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, Error>>()?;
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
let fun: protobuf::ScalarFunction = fun.try_into()?;
Self {
expr_type: Some(ExprType::ScalarFunction(
protobuf::ScalarFunctionNode {
fun: fun.into(),
args,
},
)),
}
}
ScalarFunctionDefinition::UDF(fun) => Self {
expr_type: Some(ExprType::ScalarUdfExpr(
protobuf::ScalarUdfExprNode {
fun_name: fun.name().to_string(),
args,
},
)),
}
}
ScalarFunctionDefinition::UDF(fun) => Self {
expr_type: Some(ExprType::ScalarUdfExpr(
protobuf::ScalarUdfExprNode {
fun_name: fun.name().to_string(),
args: args
.iter()
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, Error>>()?,
},
)),
},
ScalarFunctionDefinition::Name(_) => {
return Err(Error::NotImplemented(
},
ScalarFunctionDefinition::Name(_) => {
return Err(Error::NotImplemented(
"Proto serialization error: Trying to serialize a unresolved function"
.to_string(),
));
}
}
},
}
Expr::Not(expr) => {
let expr = Box::new(protobuf::Not {
expr: Some(Box::new(expr.as_ref().try_into()?)),
Expand Down

0 comments on commit 2ce2c81

Please sign in to comment.