Skip to content

Commit

Permalink
Improve coerce API so it does not need DFSchema
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed May 1, 2024
1 parent d3237b2 commit 44d1dda
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 69 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ pub fn physical_expr(schema: &Schema, expr: Expr) -> Result<Arc<dyn PhysicalExpr
ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone()));

// apply type coercion here to ensure types match
let expr = simplifier.coerce(expr, df_schema.clone())?;
let expr = simplifier.coerce(expr, &df_schema)?;

create_physical_expr(&expr, df_schema.as_ref(), &props)
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/test_util/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl TestParquetFile {
let parquet_options = ctx.copied_table_options().parquet;
if let Some(filter) = maybe_filter {
let simplifier = ExprSimplifier::new(context);
let filter = simplifier.coerce(filter, df_schema.clone()).unwrap();
let filter = simplifier.coerce(filter, &df_schema).unwrap();
let physical_filter_expr =
create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?;
let parquet_exec = Arc::new(ParquetExec::new(
Expand Down
99 changes: 48 additions & 51 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
DFSchemaRef, DataFusionError, Result, ScalarValue,
DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::{
self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList,
Expand Down Expand Up @@ -99,9 +99,7 @@ fn analyze_internal(
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
schema.merge(external_schema);

let mut expr_rewrite = TypeCoercionRewriter {
schema: Arc::new(schema),
};
let mut expr_rewrite = TypeCoercionRewriter { schema: &schema };

let new_expr = plan
.expressions()
Expand All @@ -116,11 +114,11 @@ fn analyze_internal(
plan.with_new_exprs(new_expr, new_inputs)
}

pub(crate) struct TypeCoercionRewriter {
pub(crate) schema: DFSchemaRef,
pub(crate) struct TypeCoercionRewriter<'a> {
pub(crate) schema: &'a DFSchema,
}

impl TreeNodeRewriter for TypeCoercionRewriter {
impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
type Node = Expr;

fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
Expand All @@ -132,14 +130,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
outer_ref_columns,
}) => {
let new_plan = analyze_internal(&self.schema, &subquery)?;
let new_plan = analyze_internal(self.schema, &subquery)?;
Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})))
}
Expr::Exists(Exists { subquery, negated }) => {
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
Ok(Transformed::yes(Expr::Exists(Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
Expand All @@ -153,8 +151,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
negated,
}) => {
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
let expr_type = expr.get_type(&self.schema)?;
let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
let expr_type = expr.get_type(self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
"expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
Expand All @@ -165,32 +163,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
outer_ref_columns: subquery.outer_ref_columns,
};
Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
Box::new(expr.cast_to(&common_type, &self.schema)?),
Box::new(expr.cast_to(&common_type, self.schema)?),
cast_subquery(new_subquery, &common_type)?,
negated,
))))
}
Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
*expr,
&self.schema,
self.schema,
)?))),
Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::Like(Like {
negated,
Expand All @@ -199,8 +197,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
escape_char,
case_insensitive,
}) => {
let left_type = expr.get_type(&self.schema)?;
let right_type = pattern.get_type(&self.schema)?;
let left_type = expr.get_type(self.schema)?;
let right_type = pattern.get_type(self.schema)?;
let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
let op_name = if case_insensitive {
"ILIKE"
Expand All @@ -211,8 +209,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
"There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
)
})?;
let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?);
let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?);
let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?);
let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
Ok(Transformed::yes(Expr::Like(Like::new(
negated,
expr,
Expand All @@ -223,14 +221,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let (left_type, right_type) = get_input_types(
&left.get_type(&self.schema)?,
&left.get_type(self.schema)?,
&op,
&right.get_type(&self.schema)?,
&right.get_type(self.schema)?,
)?;
Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left.cast_to(&left_type, &self.schema)?),
Box::new(left.cast_to(&left_type, self.schema)?),
op,
Box::new(right.cast_to(&right_type, &self.schema)?),
Box::new(right.cast_to(&right_type, self.schema)?),
))))
}
Expr::Between(Between {
Expand All @@ -239,15 +237,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
low,
high,
}) => {
let expr_type = expr.get_type(&self.schema)?;
let low_type = low.get_type(&self.schema)?;
let expr_type = expr.get_type(self.schema)?;
let low_type = low.get_type(self.schema)?;
let low_coerced_type = comparison_coercion(&expr_type, &low_type)
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
))
})?;
let high_type = high.get_type(&self.schema)?;
let high_type = high.get_type(self.schema)?;
let high_coerced_type = comparison_coercion(&expr_type, &low_type)
.ok_or_else(|| {
DataFusionError::Internal(format!(
Expand All @@ -262,21 +260,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
))
})?;
Ok(Transformed::yes(Expr::Between(Between::new(
Box::new(expr.cast_to(&coercion_type, &self.schema)?),
Box::new(expr.cast_to(&coercion_type, self.schema)?),
negated,
Box::new(low.cast_to(&coercion_type, &self.schema)?),
Box::new(high.cast_to(&coercion_type, &self.schema)?),
Box::new(low.cast_to(&coercion_type, self.schema)?),
Box::new(high.cast_to(&coercion_type, self.schema)?),
))))
}
Expr::InList(InList {
expr,
list,
negated,
}) => {
let expr_data_type = expr.get_type(&self.schema)?;
let expr_data_type = expr.get_type(self.schema)?;
let list_data_types = list
.iter()
.map(|list_expr| list_expr.get_type(&self.schema))
.map(|list_expr| list_expr.get_type(self.schema))
.collect::<Result<Vec<_>>>()?;
let result_type =
get_coerce_type_for_list(&expr_data_type, &list_data_types);
Expand All @@ -286,11 +284,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
),
Some(coerced_type) => {
// find the coerced type
let cast_expr = expr.cast_to(&coerced_type, &self.schema)?;
let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
let cast_list_expr = list
.into_iter()
.map(|list_expr| {
list_expr.cast_to(&coerced_type, &self.schema)
list_expr.cast_to(&coerced_type, self.schema)
})
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(Expr::InList(InList ::new(
Expand All @@ -302,18 +300,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
}
Expr::Case(case) => {
let case = coerce_case_expression(case, &self.schema)?;
let case = coerce_case_expression(case, self.schema)?;
Ok(Transformed::yes(Expr::Case(case)))
}
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
ScalarFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args,
&self.schema,
self.schema,
fun.signature(),
)?;
let new_expr =
coerce_arguments_for_fun(new_expr, &self.schema, &fun)?;
let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?;
Ok(Transformed::yes(Expr::ScalarFunction(
ScalarFunction::new_udf(fun, new_expr),
)))
Expand All @@ -331,7 +328,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
let new_expr = coerce_agg_exprs_for_signature(
&fun,
args,
&self.schema,
self.schema,
&fun.signature(),
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
Expand All @@ -348,7 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
AggregateFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args,
&self.schema,
self.schema,
fun.signature(),
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
Expand All @@ -375,14 +372,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
null_treatment,
}) => {
let window_frame =
coerce_window_frame(window_frame, &self.schema, &order_by)?;
coerce_window_frame(window_frame, self.schema, &order_by)?;

let args = match &fun {
expr::WindowFunctionDefinition::AggregateFunction(fun) => {
coerce_agg_exprs_for_signature(
fun,
args,
&self.schema,
self.schema,
&fun.signature(),
)?
}
Expand Down Expand Up @@ -495,7 +492,7 @@ fn coerce_frame_bound(
// For example, ROWS and GROUPS frames use `UInt64` during calculations.
fn coerce_window_frame(
window_frame: WindowFrame,
schema: &DFSchemaRef,
schema: &DFSchema,
expressions: &[Expr],
) -> Result<WindowFrame> {
let mut window_frame = window_frame;
Expand Down Expand Up @@ -531,7 +528,7 @@ fn coerce_window_frame(

// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
// The above op will be rewrite to the binary op when creating the physical op.
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
let left_type = expr.get_type(schema)?;
get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
expr.cast_to(&DataType::Boolean, schema)
Expand Down Expand Up @@ -615,7 +612,7 @@ fn coerce_agg_exprs_for_signature(
.collect()
}

fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Case> {
fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
// Given expressions like:
//
// CASE a1
Expand Down Expand Up @@ -1238,7 +1235,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
let mut rewriter = TypeCoercionRewriter { schema };
let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).gt(lit(13i64)));
let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
Expand All @@ -1249,7 +1246,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
let mut rewriter = TypeCoercionRewriter { schema };
let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).eq(lit(13i64)));
let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
Expand All @@ -1260,7 +1257,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
let mut rewriter = TypeCoercionRewriter { schema };
let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).lt(lit(13i64)));
let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
Expand Down
20 changes: 4 additions & 16 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ use datafusion_common::{
cast::{as_large_list_array, as_list_array},
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{
internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{InList, InSubquery};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
Expand Down Expand Up @@ -208,14 +206,8 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
///
/// See the [type coercion module](datafusion_expr::type_coercion)
/// documentation for more details on type coercion
///
// Would be nice if this API could use the SimplifyInfo
// rather than creating an DFSchemaRef coerces rather than doing
// it manually.
// https://github.com/apache/datafusion/issues/3793
pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result<Expr> {
pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
let mut expr_rewrite = TypeCoercionRewriter { schema };

expr.rewrite(&mut expr_rewrite).data()
}

Expand Down Expand Up @@ -1686,7 +1678,7 @@ mod tests {
sync::Arc,
};

use datafusion_common::{assert_contains, ToDFSchema};
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};

use crate::simplify_expressions::SimplifyContext;
Expand Down Expand Up @@ -1721,11 +1713,7 @@ mod tests {
// should fully simplify to 3 < i (though i has been coerced to i64)
let expected = lit(3i64).lt(col("i"));

// Would be nice if this API could use the SimplifyInfo
// rather than creating an DFSchemaRef coerces rather than doing
// it manually.
// https://github.com/apache/datafusion/issues/3793
let expr = simplifier.coerce(expr, schema).unwrap();
let expr = simplifier.coerce(expr, &schema).unwrap();

assert_eq!(expected, simplifier.simplify(expr).unwrap());
}
Expand Down

0 comments on commit 44d1dda

Please sign in to comment.