Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: don't push down volatile predicates in projection #7909

Merged
merged 11 commits into from
Oct 25, 2023
194 changes: 167 additions & 27 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan

use crate::optimizer::ApplyOrder;
use crate::utils::{conjunction, split_conjunction};
use crate::utils::{conjunction, split_conjunction, split_conjunction_owned};
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::{
internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::Volatility;
use datafusion_expr::{
and,
expr_rewriter::replace_col,
Expand Down Expand Up @@ -652,32 +653,60 @@ impl OptimizerRule for PushDownFilter {
child_plan.with_new_inputs(&[new_filter])?
}
LogicalPlan::Projection(projection) => {
// A projection is filter-commutable, but re-writes all predicate expressions
// A projection is filter-commutable if it do not contain volatile predicates or contain volatile
// predicates that are not used in the filter. However, we should re-writes all predicate expressions.
// collect projection.
let replace_map = projection
.schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
// strip alias, as they should not be part of filters
let expr = match &projection.expr[i] {
Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(),
expr => expr.clone(),
};

(field.qualified_name(), expr)
})
.collect::<HashMap<_, _>>();
let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) =
projection
.schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
// strip alias, as they should not be part of filters
let expr = match &projection.expr[i] {
Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(),
expr => expr.clone(),
};

(field.qualified_name(), expr)
})
.partition(|(_, value)| is_volatile_expression(value));

// re-write all filters based on this projection
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
let new_filter = LogicalPlan::Filter(Filter::try_new(
replace_cols_by_name(filter.predicate.clone(), &replace_map)?,
projection.input.clone(),
)?);
let mut push_predicates = vec![];
let mut keep_predicates = vec![];
for expr in split_conjunction_owned(filter.predicate.clone()).into_iter()
{
if contain(&expr, &volatile_map) {
keep_predicates.push(expr);
} else {
push_predicates.push(expr);
}
}

child_plan.with_new_inputs(&[new_filter])?
match conjunction(push_predicates) {
Some(expr) => {
// re-write all filters based on this projection
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
let new_filter = LogicalPlan::Filter(Filter::try_new(
replace_cols_by_name(expr, &non_volatile_map)?,
projection.input.clone(),
)?);

match conjunction(keep_predicates) {
None => child_plan.with_new_inputs(&[new_filter])?,
Some(keep_predicate) => {
let child_plan =
child_plan.with_new_inputs(&[new_filter])?;
LogicalPlan::Filter(Filter::try_new(
keep_predicate,
Arc::new(child_plan),
)?)
}
}
}
None => return Ok(None),
}
}
LogicalPlan::Union(union) => {
let mut inputs = Vec::with_capacity(union.inputs.len());
Expand Down Expand Up @@ -881,6 +910,42 @@ pub fn replace_cols_by_name(
})
}

/// check whether the expression is volatile predicates
fn is_volatile_expression(e: &Expr) -> bool {
let mut is_volatile = false;
e.apply(&mut |expr| {
Ok(match expr {
Expr::ScalarFunction(f) if f.fun.volatility() == Volatility::Volatile => {
is_volatile = true;
VisitRecursion::Stop
}
_ => VisitRecursion::Continue,
})
})
.unwrap();
is_volatile
}

/// check whether the expression uses the columns in `check_map`.
fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
let mut is_contain = false;
e.apply(&mut |expr| {
Ok(if let Expr::Column(c) = &expr {
match check_map.get(&c.flat_name()) {
Some(_) => {
is_contain = true;
VisitRecursion::Stop
}
None => VisitRecursion::Continue,
}
} else {
VisitRecursion::Continue
})
})
.unwrap();
is_contain
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -893,9 +958,9 @@ mod tests {
use datafusion_common::{DFSchema, DFSchemaRef};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType,
UserDefinedLogicalNodeCore,
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, random, sum,
BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource,
TableType, UserDefinedLogicalNodeCore,
};
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
Expand Down Expand Up @@ -2712,4 +2777,79 @@ Projection: a, b
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}

#[test]
fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
// SELECT t.a, t.r FROM (SELECT a, SUM(b), random()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
let table_scan = test_table_scan_with_name("test1")?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b"))])?
.project(vec![
col("a"),
sum(col("b")),
add(random(), lit(1)).alias("r"),
])?
.alias("t")?
.filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
.project(vec![col("t.a"), col("t.r")])?
.build()?;

let expected_before = "Projection: t.a, t.r\
\n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\
\n SubqueryAlias: t\
\n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\
\n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
\n TableScan: test1";
assert_eq!(format!("{plan:?}"), expected_before);

let expected_after = "Projection: t.a, t.r\
\n SubqueryAlias: t\
\n Filter: r > Float64(0.5)\
\n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\
\n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
\n TableScan: test1, full_filters=[test1.a > Int32(5)]";
assert_optimized_plan_eq(&plan, expected_after)
}

#[test]
fn test_push_down_volatile_function_in_join() -> Result<()> {
// SELECT t.a, t.r FROM (SELECT test1.a AS a, random() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
let table_scan = test_table_scan_with_name("test1")?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan).build()?;
let plan = LogicalPlanBuilder::from(left)
.join(
right,
JoinType::Inner,
(
vec![Column::from_qualified_name("test1.a")],
vec![Column::from_qualified_name("test2.a")],
),
None,
)?
.project(vec![col("test1.a").alias("a"), random().alias("r")])?
.alias("t")?
.filter(col("t.r").gt(lit(0.8)))?
.project(vec![col("t.a"), col("t.r")])?
.build()?;

let expected_before = "Projection: t.a, t.r\
\n Filter: t.r > Float64(0.8)\
\n SubqueryAlias: t\
\n Projection: test1.a AS a, random() AS r\
\n Inner Join: test1.a = test2.a\
\n TableScan: test1\
\n TableScan: test2";
assert_eq!(format!("{plan:?}"), expected_before);

let expected = "Projection: t.a, t.r\
\n SubqueryAlias: t\
\n Filter: r > Float64(0.8)\
\n Projection: test1.a AS a, random() AS r\
\n Inner Join: test1.a = test2.a\
\n TableScan: test1\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
}