Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: common_subexpr_eliminate rule should not apply to short-circuit expression #8928

Merged
merged 7 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ use datafusion_expr::expr::Alias;
use datafusion_expr::logical_plan::{
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
};
use datafusion_expr::{col, Expr, ExprSchemable};
use datafusion_expr::{
col, expr::ScalarFunction, BinaryExpr, BuiltinScalarFunction, Expr, ExprSchemable,
Operator, ScalarFunctionDefinition,
};

/// A map from expression's identifier to tuple including
/// - the expression itself (cloned)
Expand Down Expand Up @@ -616,8 +619,8 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {

fn pre_visit(&mut self, expr: &Expr) -> Result<VisitRecursion> {
// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a case expression, skip it.
if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? {
// If the expr contains a volatile expression or is a short-circuit expression, skip it.
if is_short_circuit_expression(expr) || is_volatile_expression(expr)? {
return Ok(VisitRecursion::Skip);
}
self.visit_stack
Expand Down Expand Up @@ -655,6 +658,20 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
}
}

/// Checks whether `expr` is, at its top level, a short-circuit expression.
///
/// The following are treated as short-circuiting:
/// - a built-in `Coalesce` scalar function call,
/// - a binary expression whose operator is `And` or `Or`,
/// - a `Case` expression.
///
/// Only the variant of `expr` itself is inspected; its children are not
/// traversed. Every other kind of expression yields `false`.
///
/// Short-circuit expressions may leave some of their operands unevaluated, so
/// pulling those operands out as common subexpressions could evaluate something
/// (e.g. a division by zero) that would not actually be evaluated otherwise.
fn is_short_circuit_expression(expr: &Expr) -> bool {
haohuaijin marked this conversation as resolved.
Show resolved Hide resolved
match expr {
// Among built-in scalar functions, only `Coalesce` is considered short-circuiting here.
Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce)
}
// Logical `And` / `Or` binary operators.
Expr::BinaryExpr(BinaryExpr { op, .. }) => {
matches!(op, Operator::And | Operator::Or)
}
// Any `Case` expression counts as short-circuiting.
Expr::Case { .. } => true,
_ => false,
}
}

/// Go through an expression tree and generate identifier for every node in this tree.
fn expr_to_identifier(
expr: &Expr,
Expand Down Expand Up @@ -696,7 +713,13 @@ struct CommonSubexprRewriter<'a> {
impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
type N = Expr;

fn pre_visit(&mut self, _: &Expr) -> Result<RewriteRecursion> {
fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
// The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
// the `id_array`, which records the expr's identifier used to rewrite expr. So if we
// skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
if is_short_circuit_expression(expr) || is_volatile_expression(expr)? {
return Ok(RewriteRecursion::Stop);
}
if self.curr_index >= self.id_array.len()
|| self.max_series_number > self.id_array[self.curr_index].0
{
Expand Down Expand Up @@ -1249,12 +1272,11 @@ mod test {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))?
.filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
.build()?;

let expected = "Projection: test.a, test.b, test.c\
\n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\
\n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\
\n Filter: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a - Int32(10) > Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
Expand Down
30 changes: 30 additions & 0 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1129,5 +1129,35 @@ FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B;
0 0
0 0

query TT
explain select coalesce(1, y/x), coalesce(2, y/x) from t;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of the reason described in #8927, this uses the query plan (EXPLAIN) to test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please add some comments here explaining the rationale and what the expected outputs are (so future readers know if changes are expected)

Something like this perhaps:

# Expressions that short circuit should not be refactored out as that may cause side effects (divide by zero)
# at plan time that would not actually happen during execution

Also, can you please add the tests that now pass (e.g. select coalesce(1, y/x), coalesce(2, y/x) from t;), so that if someone breaks this code by accident, those queries would start failing? That would make it easier to tell quickly that something is incorrect.

----
logical_plan
Projection: coalesce(Int64(1), CAST(t.y / t.x AS Int64)), coalesce(Int64(2), CAST(t.y / t.x AS Int64))
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[coalesce(1, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(1),t.y / t.x), coalesce(2, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(2),t.y / t.x)]
--MemoryExec: partitions=1, partition_sizes=[1]

query TT
EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t;
----
logical_plan
Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x]
--MemoryExec: partitions=1, partition_sizes=[1]

query TT
EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t;
----
logical_plan
Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x
--TableScan: t projection=[x, y]
physical_plan
ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x]
--MemoryExec: partitions=1, partition_sizes=[1]

statement ok
DROP TABLE t;