From 30812598d3c74a5e5feb8957fe93110504986b45 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 10 Nov 2023 17:03:19 +0800 Subject: [PATCH 1/3] feat: support UDAF in substrait producer/consumer Signed-off-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 45 +++++++++---- .../substrait/src/logical_plan/producer.rs | 29 +++++++++ .../tests/cases/roundtrip_logical_plan.rs | 64 ++++++++++++++++++- 3 files changed, 121 insertions(+), 17 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a15121652452..597fb633494a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -19,6 +19,7 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef}; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ aggregate_function, window_function::find_df_window_func, BinaryExpr, BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, @@ -359,6 +360,7 @@ pub async fn from_substrait_rel( _ => false, }; from_substrait_agg_func( + ctx, f, input.schema(), extensions, @@ -654,6 +656,7 @@ pub async fn from_substriat_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( + ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, extensions: &HashMap, @@ -674,23 +677,37 @@ pub async fn from_substrait_agg_func( args.push(arg_expr?.as_ref().clone()); } - let fun = match extensions.get(&f.function_reference) { - Some(function_name) => { - aggregate_function::AggregateFunction::from_str(function_name) - } - None => not_impl_err!( - "Aggregated function not found: function anchor = {:?}", + let Some(function_name) = extensions.get(&f.function_reference) else { + return plan_err!( + "Aggregated function not registered: function anchor 
= {:?}", f.function_reference - ), + ); }; - Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { - fun: fun.unwrap(), - args, - distinct, - filter, - order_by, - }))) + // try udaf first, then built-in aggr fn. + if let Ok(fun) = ctx.udaf(function_name) { + Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF { + fun, + args, + filter, + order_by, + }))) + } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) + { + Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { + fun, + args, + distinct, + filter, + order_by, + }))) + } else { + not_impl_err!( + "Aggregated function {} is not supported: function anchor = {:?}", + function_name, + f.function_reference + ) + } } /// Convert Substrait Rex to DataFusion Expr diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e3c6f94d43d5..3edd606dd406 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -610,6 +610,35 @@ pub fn to_substrait_agg_measure( } }) } + Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{ + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? 
+ } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_name = fun.name.to_lowercase(); + let function_anchor = _register_function(function_name, extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: AggregationInvocation::All as i32, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) + }, Expr::Alias(Alias{expr,..})=> { to_substrait_agg_measure(expr, schema, extension_info) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index ca2b4d48c460..f04bff2e8a52 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. 
+use datafusion::arrow::array::ArrayRef; +use datafusion::physical_plan::Accumulator; +use datafusion::scalar::ScalarValue; use datafusion_substrait::logical_plan::{ consumer::from_substrait_plan, producer::to_substrait_plan, }; @@ -28,7 +31,9 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; -use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode}; +use datafusion::logical_expr::{ + Extension, LogicalPlan, UserDefinedLogicalNode, Volatility, +}; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -626,6 +631,56 @@ async fn extension_logical_plan() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_aggregate_udf() -> Result<()> { + #[derive(Debug)] + struct Dummy {} + + impl Accumulator for Dummy { + fn state(&self) -> datafusion::error::Result> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + vec![DataType::Int64], + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Int64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(Dummy {}))), + // This is the description of the state. 
`state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let ctx = create_context().await?; + ctx.register_udaf(dummy_agg); + + roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await +} + fn check_post_join_filters(rel: &Rel) -> Result<()> { // search for target_rel and field value in proto match &rel.rel_type { @@ -762,8 +817,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { Ok(()) } -async fn roundtrip(sql: &str) -> Result<()> { - let ctx = create_context().await?; +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; @@ -779,6 +833,10 @@ async fn roundtrip(sql: &str) -> Result<()> { Ok(()) } +async fn roundtrip(sql: &str) -> Result<()> { + roundtrip_with_ctx(sql, create_context().await?).await +} + async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; From 62efa7077dec49fa4d788bcdd42a015420447ff6 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sun, 12 Nov 2023 16:49:03 +0800 Subject: [PATCH 2/3] Update datafusion/substrait/src/logical_plan/consumer.rs Co-authored-by: Andrew Lamb --- datafusion/substrait/src/logical_plan/consumer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 597fb633494a..e9eb307e6881 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -679,7 +679,7 @@ pub async fn from_substrait_agg_func( let Some(function_name) = extensions.get(&f.function_reference) else { return plan_err!( - "Aggregated function not registered: function anchor = {:?}", + "Aggregate function not registered: function anchor = {:?}", f.function_reference ); }; From 
61e819555fee6f8f3a58f83f322d5f7db160717f Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sun, 12 Nov 2023 19:14:25 +0800 Subject: [PATCH 3/3] remove redundant to_lowercase Signed-off-by: Ruihang Xia --- .../substrait/src/logical_plan/producer.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 3edd606dd406..969e013c3958 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -588,8 +588,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -620,8 +619,7 @@ pub fn to_substrait_agg_measure( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); } - let function_name = fun.name.to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.name.clone(), extension_info); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -732,8 +730,8 @@ pub fn make_binary_op_scalar_func( HashMap, ), ) -> Expression { - let function_name = operator_to_name(op).to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = + _register_function(operator_to_name(op).to_string(), extension_info); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -836,8 +834,7 @@ 
pub fn to_substrait_rex( )?)), }); } - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -1002,8 +999,7 @@ pub fn to_substrait_rex( window_frame, }) => { // function reference - let function_name = fun.to_string().to_lowercase(); - let function_anchor = _register_function(function_name, extension_info); + let function_anchor = _register_function(fun.to_string(), extension_info); // arguments let mut arguments: Vec = vec![]; for arg in args {