Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support UDAF in substrait producer/consumer #8119

Merged
merged 3 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};

use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, window_function::find_df_window_func, BinaryExpr,
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
Expand Down Expand Up @@ -359,6 +360,7 @@ pub async fn from_substrait_rel(
_ => false,
};
from_substrait_agg_func(
ctx,
f,
input.schema(),
extensions,
Expand Down Expand Up @@ -654,6 +656,7 @@ pub async fn from_substriat_func_args(

/// Convert Substrait AggregateFunction to DataFusion Expr
pub async fn from_substrait_agg_func(
ctx: &SessionContext,
f: &AggregateFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
Expand All @@ -674,23 +677,37 @@ pub async fn from_substrait_agg_func(
args.push(arg_expr?.as_ref().clone());
}

let fun = match extensions.get(&f.function_reference) {
Some(function_name) => {
aggregate_function::AggregateFunction::from_str(function_name)
}
None => not_impl_err!(
"Aggregated function not found: function anchor = {:?}",
let Some(function_name) = extensions.get(&f.function_reference) else {
return plan_err!(
"Aggregate function not registered: function anchor = {:?}",
f.function_reference
),
);
};

Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
fun: fun.unwrap(),
args,
distinct,
filter,
order_by,
})))
// try udaf first, then built-in aggr fn.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is basically the same function resolution logic that we have for sql. I wonder if we should consolidate it somewhere 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was considering putting this logic into SessionContext, as well as resolving scalar and window functions. However sql_function_to_expr uses ContextProvider to do this resolving. Maybe we have to do some refactors on ContextProvider at first.

By the way, I noticed that ContextProvider has lots of mock impl but those function resolving methods are only used in sql_function_to_expr. It also looks like an indication that refactoring is needed to me.

if let Ok(fun) = ctx.udaf(function_name) {
Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF {
fun,
args,
filter,
order_by,
})))
} else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
{
Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
fun,
args,
distinct,
filter,
order_by,
})))
} else {
not_impl_err!(
"Aggregated function {} is not supported: function anchor = {:?}",
function_name,
f.function_reference
)
}
}

/// Convert Substrait Rex to DataFusion Expr
Expand Down
41 changes: 33 additions & 8 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,7 @@ pub fn to_substrait_agg_measure(
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
}
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
let function_anchor = _register_function(fun.to_string(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
Expand All @@ -610,6 +609,34 @@ pub fn to_substrait_agg_measure(
}
})
}
Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{
let sorts = if let Some(order_by) = order_by {
order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
} else {
vec![]
};
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
}
let function_anchor = _register_function(fun.name.clone(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
arguments,
sorts,
output_type: None,
invocation: AggregationInvocation::All as i32,
phase: AggregationPhase::Unspecified as i32,
args: vec![],
options: vec![],
}),
filter: match filter {
Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
None => None
}
})
},
Expr::Alias(Alias{expr,..})=> {
to_substrait_agg_measure(expr, schema, extension_info)
}
Expand Down Expand Up @@ -703,8 +730,8 @@ pub fn make_binary_op_scalar_func(
HashMap<String, u32>,
),
) -> Expression {
let function_name = operator_to_name(op).to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
let function_anchor =
_register_function(operator_to_name(op).to_string(), extension_info);
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -807,8 +834,7 @@ pub fn to_substrait_rex(
)?)),
});
}
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
let function_anchor = _register_function(fun.to_string(), extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
Expand Down Expand Up @@ -973,8 +999,7 @@ pub fn to_substrait_rex(
window_frame,
}) => {
// function reference
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
let function_anchor = _register_function(fun.to_string(), extension_info);
// arguments
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
Expand Down
64 changes: 61 additions & 3 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::arrow::array::ArrayRef;
use datafusion::physical_plan::Accumulator;
use datafusion::scalar::ScalarValue;
use datafusion_substrait::logical_plan::{
consumer::from_substrait_plan, producer::to_substrait_plan,
};
Expand All @@ -28,7 +31,9 @@ use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
use datafusion::logical_expr::{
Extension, LogicalPlan, UserDefinedLogicalNode, Volatility,
};
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
use datafusion::prelude::*;

Expand Down Expand Up @@ -626,6 +631,56 @@ async fn extension_logical_plan() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_aggregate_udf() -> Result<()> {
    // No-op accumulator used only so a UDAF can be registered; this test
    // compares logical plans after a substrait producer/consumer roundtrip
    // and never actually executes the aggregate.
    #[derive(Debug)]
    struct Dummy {}

    impl Accumulator for Dummy {
        // Returns no intermediate state.
        // NOTE(review): the UDAF below declares state types
        // [Float64, UInt32], which this empty vec does not match — harmless
        // here since the accumulator never runs, but worth confirming.
        fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
            Ok(vec![])
        }

        // Ignores incoming batches; nothing to accumulate.
        fn update_batch(
            &mut self,
            _values: &[ArrayRef],
        ) -> datafusion::error::Result<()> {
            Ok(())
        }

        // Ignores partial states from other partitions.
        fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> {
            Ok(())
        }

        // Always yields a null Float64.
        // NOTE(review): the UDAF's declared return type is Int64 — mismatch is
        // inert in a plan-only roundtrip, but confirm it is intentional.
        fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
            Ok(ScalarValue::Float64(None))
        }

        fn size(&self) -> usize {
            std::mem::size_of_val(self)
        }
    }

    let dummy_agg = create_udaf(
        // the name; used to represent it in plan descriptions and in the registry, to use in SQL.
        "dummy_agg",
        // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
        vec![DataType::Int64],
        // the return type; DataFusion expects this to match the type returned by `evaluate`.
        Arc::new(DataType::Int64),
        Volatility::Immutable,
        // This is the accumulator factory; DataFusion uses it to create new accumulators.
        Arc::new(|_| Ok(Box::new(Dummy {}))),
        // This is the description of the state. `state()` must match the types here.
        Arc::new(vec![DataType::Float64, DataType::UInt32]),
    );

    // Register the UDAF on a fresh context, then verify that a query using it
    // survives the substrait producer -> consumer roundtrip unchanged.
    let ctx = create_context().await?;
    ctx.register_udaf(dummy_agg);

    roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await
}

fn check_post_join_filters(rel: &Rel) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
Expand Down Expand Up @@ -762,8 +817,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> {
Ok(())
}

async fn roundtrip(sql: &str) -> Result<()> {
let ctx = create_context().await?;
async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> {
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
Expand All @@ -779,6 +833,10 @@ async fn roundtrip(sql: &str) -> Result<()> {
Ok(())
}

/// Round-trips `sql` through the substrait producer/consumer using a freshly
/// created default test context.
async fn roundtrip(sql: &str) -> Result<()> {
    let ctx = create_context().await?;
    roundtrip_with_ctx(sql, ctx).await
}

async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
Expand Down