diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4c5cd3ab2855..8c2eb96a48d8 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -15,13 +15,14 @@ //! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan use crate::optimizer::ApplyOrder; -use crate::utils::{conjunction, split_conjunction}; +use crate::utils::{conjunction, split_conjunction, split_conjunction_owned}; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, }; use datafusion_expr::expr::Alias; +use datafusion_expr::Volatility; use datafusion_expr::{ and, expr_rewriter::replace_col, @@ -652,32 +653,60 @@ impl OptimizerRule for PushDownFilter { child_plan.with_new_inputs(&[new_filter])? } LogicalPlan::Projection(projection) => { - // A projection is filter-commutable, but re-writes all predicate expressions + // A projection is filter-commutable if it do not contain volatile predicates or contain volatile + // predicates that are not used in the filter. However, we should re-writes all predicate expressions. // collect projection. - let replace_map = projection - .schema - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - (field.qualified_name(), expr) - }) - .collect::>(); + let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = + projection + .schema + .fields() + .iter() + .enumerate() + .map(|(i, field)| { + // strip alias, as they should not be part of filters + let expr = match &projection.expr[i] { + Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), + expr => expr.clone(), + }; + + (field.qualified_name(), expr) + }) + .partition(|(_, value)| is_volatile_expression(value)); - // re-write all filters based on this projection - // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - let new_filter = LogicalPlan::Filter(Filter::try_new( - replace_cols_by_name(filter.predicate.clone(), &replace_map)?, - projection.input.clone(), - )?); + let mut push_predicates = vec![]; + let mut keep_predicates = vec![]; + for expr in split_conjunction_owned(filter.predicate.clone()).into_iter() + { + if contain(&expr, &volatile_map) { + keep_predicates.push(expr); + } else { + push_predicates.push(expr); + } + } - child_plan.with_new_inputs(&[new_filter])? + match conjunction(push_predicates) { + Some(expr) => { + // re-write all filters based on this projection + // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" + let new_filter = LogicalPlan::Filter(Filter::try_new( + replace_cols_by_name(expr, &non_volatile_map)?, + projection.input.clone(), + )?); + + match conjunction(keep_predicates) { + None => child_plan.with_new_inputs(&[new_filter])?, + Some(keep_predicate) => { + let child_plan = + child_plan.with_new_inputs(&[new_filter])?; + LogicalPlan::Filter(Filter::try_new( + keep_predicate, + Arc::new(child_plan), + )?) + } + } + } + None => return Ok(None), + } } LogicalPlan::Union(union) => { let mut inputs = Vec::with_capacity(union.inputs.len()); @@ -881,6 +910,42 @@ pub fn replace_cols_by_name( }) } +/// check whether the expression is volatile predicates +fn is_volatile_expression(e: &Expr) -> bool { + let mut is_volatile = false; + e.apply(&mut |expr| { + Ok(match expr { + Expr::ScalarFunction(f) if f.fun.volatility() == Volatility::Volatile => { + is_volatile = true; + VisitRecursion::Stop + } + _ => VisitRecursion::Continue, + }) + }) + .unwrap(); + is_volatile +} + +/// check whether the expression uses the columns in `check_map`. +fn contain(e: &Expr, check_map: &HashMap) -> bool { + let mut is_contain = false; + e.apply(&mut |expr| { + Ok(if let Expr::Column(c) = &expr { + match check_map.get(&c.flat_name()) { + Some(_) => { + is_contain = true; + VisitRecursion::Stop + } + None => VisitRecursion::Continue, + } + } else { + VisitRecursion::Continue + }) + }) + .unwrap(); + is_contain +} + #[cfg(test)] mod tests { use super::*; @@ -893,9 +958,9 @@ mod tests { use datafusion_common::{DFSchema, DFSchemaRef}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr, - Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType, - UserDefinedLogicalNodeCore, + and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, random, sum, + BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource, + TableType, UserDefinedLogicalNodeCore, }; use std::fmt::{Debug, Formatter}; use std::sync::Arc; @@ -2712,4 +2777,79 @@ Projection: a, b \n TableScan: test2"; assert_optimized_plan_eq(&plan, expected) } + + #[test] + fn test_push_down_volatile_function_in_aggregate() -> Result<()> { + // SELECT t.a, t.r FROM (SELECT a, SUM(b), random()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5; + let table_scan = test_table_scan_with_name("test1")?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .project(vec![ + col("a"), + sum(col("b")), + add(random(), lit(1)).alias("r"), + ])? + .alias("t")? + .filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))? + .project(vec![col("t.a"), col("t.r")])? + .build()?; + + let expected_before = "Projection: t.a, t.r\ + \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\ + \n SubqueryAlias: t\ + \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ + \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ + \n TableScan: test1"; + assert_eq!(format!("{plan:?}"), expected_before); + + let expected_after = "Projection: t.a, t.r\ + \n SubqueryAlias: t\ + \n Filter: r > Float64(0.5)\ + \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ + \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ + \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; + assert_optimized_plan_eq(&plan, expected_after) + } + + #[test] + fn test_push_down_volatile_function_in_join() -> Result<()> { + // SELECT t.a, t.r FROM (SELECT test1.a AS a, random() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5; + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan).build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan).build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::Inner, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .project(vec![col("test1.a").alias("a"), random().alias("r")])? + .alias("t")? + .filter(col("t.r").gt(lit(0.8)))? + .project(vec![col("t.a"), col("t.r")])? + .build()?; + + let expected_before = "Projection: t.a, t.r\ + \n Filter: t.r > Float64(0.8)\ + \n SubqueryAlias: t\ + \n Projection: test1.a AS a, random() AS r\ + \n Inner Join: test1.a = test2.a\ + \n TableScan: test1\ + \n TableScan: test2"; + assert_eq!(format!("{plan:?}"), expected_before); + + let expected = "Projection: t.a, t.r\ + \n SubqueryAlias: t\ + \n Filter: r > Float64(0.8)\ + \n Projection: test1.a AS a, random() AS r\ + \n Inner Join: test1.a = test2.a\ + \n TableScan: test1\ + \n TableScan: test2"; + assert_optimized_plan_eq(&plan, expected) + } }