Skip to content

Commit

Permalink
refactor: use ExprBuilder to consume substrait expr
Browse files Browse the repository at this point in the history
Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia committed Dec 12, 2023
1 parent 95ba48b commit cf8f6d4
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 116 deletions.
3 changes: 3 additions & 0 deletions datafusion/common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,9 @@ make_error!(not_impl_err, not_impl_datafusion_err, NotImplemented);
// Exposes a macro to create `DataFusionError::Execution`
make_error!(exec_err, exec_datafusion_err, Execution);

// Exposes a macro to create `DataFusionError::Substrait`
make_error!(substrait_err, substrait_datafusion_err, Substrait);

// Exposes a macro to create `DataFusionError::SQL`
#[macro_export]
macro_rules! sql_err {
Expand Down
248 changes: 132 additions & 116 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFField, DFSchema, DFSchemaRef,
};

use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
Expand Down Expand Up @@ -73,16 +75,7 @@ use crate::variation_const::{
enum ScalarFunctionType {
Builtin(BuiltinScalarFunction),
Op(Operator),
/// [Expr::Not]
Not,
/// [Expr::Like] Used for filtering rows based on the given wildcard pattern. Case sensitive
Like,
/// [Expr::Like] Case insensitive operator counterpart of `Like`
ILike,
/// [Expr::IsNull]
IsNull,
/// [Expr::IsNotNull]
IsNotNull,
Expr(BuiltinExprBuilder),
}

pub fn name_to_op(name: &str) -> Result<Operator> {
Expand Down Expand Up @@ -127,14 +120,11 @@ fn scalar_function_type_from_str(name: &str) -> Result<ScalarFunctionType> {
return Ok(ScalarFunctionType::Builtin(fun));
}

match name {
"not" => Ok(ScalarFunctionType::Not),
"like" => Ok(ScalarFunctionType::Like),
"ilike" => Ok(ScalarFunctionType::ILike),
"is_null" => Ok(ScalarFunctionType::IsNull),
"is_not_null" => Ok(ScalarFunctionType::IsNotNull),
others => not_impl_err!("Unsupported function name: {others:?}"),
if let Some(builder) = BuiltinExprBuilder::try_from_name(name) {
return Ok(ScalarFunctionType::Expr(builder));
}

not_impl_err!("Unsupported function name: {name:?}")
}

fn split_eq_and_noneq_join_predicate_with_nulls_equality(
Expand Down Expand Up @@ -881,64 +871,8 @@ pub async fn from_substrait_rex(
),
}
}
ScalarFunctionType::Not => {
let arg = f.arguments.first().ok_or_else(|| {
DataFusionError::Substrait(
"expect one argument for `NOT` expr".to_string(),
)
})?;
match &arg.arg_type {
Some(ArgType::Value(e)) => {
let expr = from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
Ok(Arc::new(Expr::Not(Box::new(expr))))
}
_ => not_impl_err!("Invalid arguments for Not expression"),
}
}
ScalarFunctionType::Like => {
make_datafusion_like(false, f, input_schema, extensions).await
}
ScalarFunctionType::ILike => {
make_datafusion_like(true, f, input_schema, extensions).await
}
ScalarFunctionType::IsNull => {
let arg = f.arguments.first().ok_or_else(|| {
DataFusionError::Substrait(
"expect one argument for `IS NULL` expr".to_string(),
)
})?;
match &arg.arg_type {
Some(ArgType::Value(e)) => {
let expr = from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
Ok(Arc::new(Expr::IsNull(Box::new(expr))))
}
_ => not_impl_err!("Invalid arguments for IS NULL expression"),
}
}
ScalarFunctionType::IsNotNull => {
let arg = f.arguments.first().ok_or_else(|| {
DataFusionError::Substrait(
"expect one argument for `IS NOT NULL` expr".to_string(),
)
})?;
match &arg.arg_type {
Some(ArgType::Value(e)) => {
let expr = from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
Ok(Arc::new(Expr::IsNotNull(Box::new(expr))))
}
_ => {
not_impl_err!("Invalid arguments for IS NOT NULL expression")
}
}
ScalarFunctionType::Expr(builder) => {
builder.build(f, input_schema, extensions).await
}
}
}
Expand Down Expand Up @@ -1341,50 +1275,132 @@ fn from_substrait_null(null_type: &Type) -> Result<ScalarValue> {
}
}

async fn make_datafusion_like(
case_insensitive: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
if f.arguments.len() != 3 {
return not_impl_err!("Expect three arguments for `{fn_name}` expr");
/// Build [`Expr`] from its name and required inputs.
struct BuiltinExprBuilder {
expr_name: String,
}

impl BuiltinExprBuilder {
pub fn try_from_name(name: &str) -> Option<Self> {
match name {
"not" | "like" | "ilike" | "is_null" | "is_not_null" => Some(Self {
expr_name: name.to_string(),
}),
_ => None,
}
}

let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
};
let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
};
let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
};
let escape_char_expr =
from_substrait_rex(escape_char_substrait, input_schema, extensions)
pub async fn build(
self,
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
match self.expr_name.as_str() {
"not" => Self::build_not_expr(f, input_schema, extensions).await,
"like" => Self::build_like_expr(false, f, input_schema, extensions).await,
"ilike" => Self::build_like_expr(true, f, input_schema, extensions).await,
"is_null" => {
Self::build_is_null_expr(false, f, input_schema, extensions).await
}
"is_not_null" => {
Self::build_is_null_expr(true, f, input_schema, extensions).await
}
_ => {
not_impl_err!("Unsupported builtin expression: {}", self.expr_name)
}
}
}

async fn build_not_expr(
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
if f.arguments.len() != 1 {
return not_impl_err!("Expect one argument for `NOT` expr");
}
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
return not_impl_err!("Invalid arguments type for `NOT` expr");
};
let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else {
return Err(DataFusionError::Substrait(format!(
"Expect Utf8 literal for escape char, but found {escape_char_expr:?}",
)));
};
Ok(Arc::new(Expr::Not(Box::new(expr))))
}

Ok(Arc::new(Expr::Like(Like {
negated: false,
expr: Box::new(expr),
pattern: Box::new(pattern),
escape_char: escape_char.map(|c| c.chars().next().unwrap()),
case_insensitive,
})))
async fn build_like_expr(
case_insensitive: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
if f.arguments.len() != 3 {
return not_impl_err!("Expect three arguments for `{fn_name}` expr");
}

let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
};
let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
};
let pattern = from_substrait_rex(pattern_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else {
return not_impl_err!("Invalid arguments type for `{fn_name}` expr");
};
let escape_char_expr =
from_substrait_rex(escape_char_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else {
return substrait_err!(
"Expect Utf8 literal for escape char, but found {escape_char_expr:?}"
);
};

Ok(Arc::new(Expr::Like(Like {
negated: false,
expr: Box::new(expr),
pattern: Box::new(pattern),
escape_char: escape_char.map(|c| c.chars().next().unwrap()),
case_insensitive,
})))
}

async fn build_is_null_expr(
is_not: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
let fn_name = if is_not { "IS NOT NULL" } else { "IS NULL" };
let arg = f.arguments.first().ok_or_else(|| {
substrait_datafusion_err!("expect one argument for `{fn_name}` expr")
})?;
match &arg.arg_type {
Some(ArgType::Value(e)) => {
let expr = from_substrait_rex(e, input_schema, extensions)
.await?
.as_ref()
.clone();
if is_not {
Ok(Arc::new(Expr::IsNotNull(Box::new(expr))))
} else {
Ok(Arc::new(Expr::IsNull(Box::new(expr))))
}
}
_ => substrait_err!("Invalid arguments for `{fn_name}` expression"),
}
}
}

0 comments on commit cf8f6d4

Please sign in to comment.