Skip to content

Commit

Permalink
Rewrite array operator to function in parser (apache#11101)
Browse files Browse the repository at this point in the history
* rewrite func

Signed-off-by: jayzhan211 <[email protected]>

* remove rule in analyzer

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored and findepi committed Jul 16, 2024
1 parent 6bd99fc commit e0d259a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 130 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async fn test_parameter_invalid_types() -> Result<()> {
.await;
assert_eq!(
results.unwrap_err().strip_backtrace(),
"Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })"
"type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32"
);
Ok(())
}
24 changes: 9 additions & 15 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,21 +889,18 @@ fn dictionary_coercion(
/// 2. Data type of the other side should be able to cast to string type
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
string_coercion(lhs_type, rhs_type)
.or_else(|| list_coercion(lhs_type, rhs_type))
.or(match (lhs_type, rhs_type) {
(Utf8, from_type) | (from_type, Utf8) => {
string_concat_internal_coercion(from_type, &Utf8)
}
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
_ => None,
})
string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
(Utf8, from_type) | (from_type, Utf8) => {
string_concat_internal_coercion(from_type, &Utf8)
}
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
_ => None,
})
}

fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
// TODO: cast between array elements (#6558)
if lhs_type.equals_datatype(rhs_type) {
Some(lhs_type.to_owned())
} else {
Expand Down Expand Up @@ -952,10 +949,7 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
// TODO: cast between array elements (#6558)
(List(_), List(_)) => Some(lhs_type.clone()),
(List(_), _) => Some(lhs_type.clone()),
(_, List(_)) => Some(rhs_type.clone()),
_ => None,
}
}
Expand Down
109 changes: 2 additions & 107 deletions datafusion/functions-array/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
//! Rewrites for using Array Functions
use crate::array_has::array_has_all;
use crate::concat::{array_append, array_concat, array_prepend};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::Transformed;
use datafusion_common::utils::list_ndims;
use datafusion_common::DFSchema;
use datafusion_common::Result;
use datafusion_common::{Column, DFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::{BinaryExpr, Expr, Operator};
Expand All @@ -39,7 +37,7 @@ impl FunctionRewrite for ArrayFunctionRewriter {
fn rewrite(
&self,
expr: Expr,
schema: &DFSchema,
_schema: &DFSchema,
_config: &ConfigOptions,
) -> Result<Transformed<Expr>> {
let transformed = match expr {
Expand All @@ -61,91 +59,6 @@ impl FunctionRewrite for ArrayFunctionRewriter {
Transformed::yes(array_has_all(*right, *left))
}

// Column cases:
// 1) array_prepend/append/concat || column
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& is_one_of_func(
&left,
&["array_append", "array_prepend", "array_concat"],
)
&& as_col(&right).is_some() =>
{
let c = as_col(&right).unwrap();
let d = schema.field_from_column(c)?.data_type();
let ndim = list_ndims(d);
match ndim {
0 => Transformed::yes(array_append(*left, *right)),
_ => Transformed::yes(array_concat(vec![*left, *right])),
}
}
// 2) select column1 || column2
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& as_col(&left).is_some()
&& as_col(&right).is_some() =>
{
let c1 = as_col(&left).unwrap();
let c2 = as_col(&right).unwrap();
let d1 = schema.field_from_column(c1)?.data_type();
let d2 = schema.field_from_column(c2)?.data_type();
let ndim1 = list_ndims(d1);
let ndim2 = list_ndims(d2);
match (ndim1, ndim2) {
(0, _) => Transformed::yes(array_prepend(*left, *right)),
(_, 0) => Transformed::yes(array_append(*left, *right)),
_ => Transformed::yes(array_concat(vec![*left, *right])),
}
}

// Chain concat operator (a || b) || array,
// (array_concat, array_append, array_prepend) || array -> array concat
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& is_one_of_func(
&left,
&["array_append", "array_prepend", "array_concat"],
)
&& is_func(&right, "make_array") =>
{
Transformed::yes(array_concat(vec![*left, *right]))
}

// Chain concat operator (a || b) || scalar,
// (array_concat, array_append, array_prepend) || scalar -> array append
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& is_one_of_func(
&left,
&["array_append", "array_prepend", "array_concat"],
) =>
{
Transformed::yes(array_append(*left, *right))
}

// array || array -> array concat
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat
&& is_func(&left, "make_array")
&& is_func(&right, "make_array") =>
{
Transformed::yes(array_concat(vec![*left, *right]))
}

// array || scalar -> array append
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat && is_func(&left, "make_array") =>
{
Transformed::yes(array_append(*left, *right))
}

// scalar || array -> array prepend
Expr::BinaryExpr(BinaryExpr { left, op, right })
if op == Operator::StringConcat && is_func(&right, "make_array") =>
{
Transformed::yes(array_prepend(*left, *right))
}

_ => Transformed::no(expr),
};
Ok(transformed)
Expand All @@ -161,21 +74,3 @@ fn is_func(expr: &Expr, func_name: &str) -> bool {

func.name() == func_name
}

/// Reports whether `expr` is a scalar function call whose name is listed in
/// `func_names`.
///
/// Any expression that is not an `Expr::ScalarFunction` yields `false`.
fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
    match expr {
        Expr::ScalarFunction(ScalarFunction { func, .. }) => {
            let name = func.name();
            func_names.iter().any(|candidate| *candidate == name)
        }
        _ => false,
    }
}

/// Borrows the inner `Column` when `expr` is an `Expr::Column`; returns `None`
/// for every other expression variant.
fn as_col(expr: &Expr) -> Option<&Column> {
    match expr {
        Expr::Column(column) => Some(column),
        _ => None,
    }
}
72 changes: 65 additions & 7 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use arrow_schema::DataType;
use arrow_schema::TimeUnit;
use datafusion_common::utils::list_ndims;
use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value};

use datafusion_common::{
Expand Down Expand Up @@ -86,13 +87,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
StackEntry::Operator(op) => {
let right = eval_stack.pop().unwrap();
let left = eval_stack.pop().unwrap();

let expr = Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
op,
Box::new(right),
));

let expr = self.build_logical_expr(op, left, right, schema)?;
eval_stack.push(expr);
}
}
Expand All @@ -103,6 +98,69 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(expr)
}

/// Builds the logical expression for the binary operation `left op right`.
///
/// The string concat operator (`Operator::StringConcat`) is rewritten to an
/// array function call when at least one operand is a list. The target
/// function is chosen by comparing the list dimensions (`list_ndims`) of the
/// two operand types:
///
/// * same dimensions  -> `array_concat(left, right)`
/// * left has more    -> `array_append(left, right)`
/// * right has more   -> `array_prepend(left, right)`
///
/// Every other operator, and string concat of two non-list operands, produces
/// a plain `Expr::BinaryExpr`.
///
/// # Errors
///
/// Propagates the error from `Expr::get_type` when an operand's type cannot be
/// resolved against `schema`, and returns an internal error naming the array
/// function when the context provider does not have it registered.
fn build_logical_expr(
    &self,
    op: Operator,
    left: Expr,
    right: Expr,
    schema: &DFSchema,
) -> Result<Expr> {
    if op == Operator::StringConcat {
        let left_type = left.get_type(schema)?;
        let right_type = right.get_type(schema)?;
        let left_list_ndims = list_ndims(&left_type);
        let right_list_ndims = list_ndims(&right_type);

        // The target function is determined from the list dimensions only. This
        // check is not exact but it is sufficient: precise validation happens
        // inside the array function itself, so rewriting e.g. a 3-d list
        // appended with a 1-d list is fine here.
        //
        // TODO: when neither side is a list we could rewrite to `concat()`, but
        // the `concat` function ignores nulls while the string concat operator
        // takes them into account. Revisit once `concat` can be configured to
        // behave like the operator.
        let has_list_operand = left_list_ndims + right_list_ndims != 0;
        if has_list_operand {
            let func_name = if left_list_ndims == right_list_ndims {
                "array_concat"
            } else if left_list_ndims > right_list_ndims {
                "array_append"
            } else {
                "array_prepend"
            };

            // Single lookup keyed by the chosen name, so the error message
            // always names the function that is actually missing.
            let Some(udf) = self.context_provider.get_function_meta(func_name) else {
                return internal_err!("{} not found", func_name);
            };
            return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
                udf,
                vec![left, right],
            )));
        }
    }

    Ok(Expr::BinaryExpr(BinaryExpr::new(
        Box::new(left),
        op,
        Box::new(right),
    )))
}

/// Generate a relational expression from a SQL expression
pub fn sql_to_expr(
&self,
Expand Down

0 comments on commit e0d259a

Please sign in to comment.