diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 75510bf2a2b9..e0638c2b6c01 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -123,15 +123,14 @@ impl CommonSubexprEliminate {
             .iter()
             .zip(arrays_list.iter())
             .map(|(exprs, arrays)| {
-                let res = exprs
+                exprs
                     .iter()
                     .cloned()
                     .zip(arrays.iter())
                     .map(|(expr, id_array)| {
                         replace_common_expr(expr, id_array, expr_set, affected_id)
                     })
-                    .collect::<Result<Vec<_>>>();
-                res
+                    .collect::<Result<Vec<_>>>()
             })
             .collect::<Result<Vec<_>>>()
     }
@@ -421,13 +420,15 @@ impl CommonSubexprEliminate {
     /// Currently the implementation is not optimal: it simply does a top-down
     /// iteration over all the plan nodes.
     fn add_extra_projection(&self, plan: &LogicalPlan) -> Result<Option<LogicalPlan>> {
-        plan.clone()
+        let result = plan
+            .clone()
             .rewrite(&mut ProjectionAdder {
-                data_type_map: HashMap::new(),
-                depth_map: HashMap::new(),
+                insertion_point_map: HashMap::new(),
                 depth: 0,
-            })
-            .map(|transformed| Some(transformed.data))
+                complex_exprs: HashMap::new(),
+            })?
+            .data;
+        Ok(Some(result))
     }
 }
 impl OptimizerRule for CommonSubexprEliminate {
@@ -475,8 +476,7 @@ fn to_arrays(
     expr_set: &mut ExprSet,
     expr_mask: ExprMask,
 ) -> Result<Vec<Vec<(usize, String)>>> {
-    let res = expr
-        .iter()
+    expr.iter()
         .map(|e| {
             let mut id_array = vec![];
             expr_to_identifier(
@@ -489,8 +489,7 @@ fn to_arrays(
 
             Ok(id_array)
         })
-        .collect::<Result<Vec<_>>>();
-    res
+        .collect::<Result<Vec<_>>>()
 }
 
 /// Build the "intermediate" projection plan that evaluates the extracted common expressions.
@@ -851,9 +850,13 @@ fn replace_common_expr(
 }
 
 struct ProjectionAdder {
-    depth_map: HashMap<u128, HashSet<Expr>>,
-    depth: u128,
-    data_type_map: HashMap<Expr, DataType>,
+    // Common expressions seen so far, with their data types and usage counts, keyed
+    // by the depth of the unsafe node at which cumulative tracking was invalidated.
+    insertion_point_map: HashMap<usize, HashMap<Expr, (DataType, usize)>>,
+    depth: usize,
+    // Keeps track of cumulative usage of common expressions, with their data types,
+    // between safe nodes.
+    complex_exprs: HashMap<Expr, (DataType, usize)>,
 }
 pub fn is_not_complex(op: &Operator) -> bool {
     matches!(
@@ -861,14 +864,14 @@ pub fn is_not_complex(op: &Operator) -> bool {
         &Operator::Eq | &Operator::NotEq | &Operator::Lt | &Operator::Gt | &Operator::And
     )
 }
+
 impl ProjectionAdder {
     // TODO: adding more expressions for sub query, currently only support for Simple Binary Expressions
     fn get_complex_expressions(
         exprs: Vec<Expr>,
         schema: DFSchemaRef,
-    ) -> (HashSet<Expr>, HashMap<Expr, DataType>) {
+    ) -> HashSet<(Expr, DataType)> {
         let mut res = HashSet::new();
-        let mut expr_data_type: HashMap<Expr, DataType> = HashMap::new();
         for expr in exprs {
             match expr {
                 Expr::BinaryExpr(BinaryExpr {
@@ -880,153 +883,177 @@ impl ProjectionAdder {
                         let l_field = schema
                             .field_from_column(l)
                             .expect("Field not found for left column");
-
-                        expr_data_type
-                            .entry(expr.clone())
-                            .or_insert(l_field.data_type().clone());
-                        res.insert(expr.clone());
+                        res.insert((expr.clone(), l_field.data_type().clone()));
                     }
                 }
                 Expr::Cast(Cast { expr, data_type: _ }) => {
-                    let (expr_set, type_data_map) =
+                    let exprs_with_type =
                         Self::get_complex_expressions(vec![*expr], schema.clone());
-                    res.extend(expr_set);
-                    expr_data_type.extend(type_data_map);
+                    res.extend(exprs_with_type);
+                }
+                Expr::Alias(Alias {
+                    expr,
+                    relation: _,
+                    name: _,
+                }) => {
+                    let exprs_with_type =
+                        Self::get_complex_expressions(vec![*expr], schema.clone());
+                    res.extend(exprs_with_type);
                 }
-
                 Expr::WindowFunction(WindowFunction { fun: _, args, .. }) => {
-                    let (expr_set, type_map) =
+                    let exprs_with_type =
                         Self::get_complex_expressions(args, schema.clone());
-                    res.extend(expr_set);
-                    expr_data_type.extend(type_map);
+                    res.extend(exprs_with_type);
                 }
                 _ => {}
             }
         }
-        (res, expr_data_type)
+        res
     }
-}
-impl TreeNodeRewriter for ProjectionAdder {
-    type Node = LogicalPlan;
-    /// currently we just collect the complex bianryOP
-    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+
+    fn update_expr_with_available_columns(
+        expr: &mut Expr,
+        available_columns: &[Column],
+    ) -> Result<()> {
+        match expr {
+            Expr::BinaryExpr(_) => {
+                for available_col in available_columns {
+                    if available_col.flat_name() == expr.display_name()? {
+                        *expr = Expr::Column(available_col.clone());
+                    }
+                }
+            }
+            Expr::WindowFunction(WindowFunction { fun: _, args, .. }) => {
+                args.iter_mut().try_for_each(|arg| {
+                    Self::update_expr_with_available_columns(arg, available_columns)
+                })?
+            }
+            Expr::Cast(Cast { expr, .. }) => {
+                Self::update_expr_with_available_columns(expr, available_columns)?
+            }
+            Expr::Alias(alias) => {
+                Self::update_expr_with_available_columns(
+                    &mut alias.expr,
+                    available_columns,
+                )?;
+            }
+            _ => {
+                // cannot rewrite
+            }
+        }
+        Ok(())
+    }
+
+    // Assumes operators don't modify field names.
+    // Otherwise this operation is not safe.
+    fn extend_with_exprs(&mut self, node: &LogicalPlan) {
         // use depth to trace where we are in the LogicalPlan tree
-        self.depth += 1;
-        // extract all expressions and check whether it contains in depth_sets
+        // extract all complex expressions and update their usage counts
        let exprs = node.expressions();
-        let depth_set = self.depth_map.entry(self.depth).or_default();
         let mut schema = node.schema().deref().clone();
         for ip in node.inputs() {
             schema.merge(ip.schema());
         }
-        let (extended_set, data_map) =
-            Self::get_complex_expressions(exprs, Arc::new(schema));
-        depth_set.extend(extended_set);
-        self.data_type_map.extend(data_map);
-        Ok(Transformed::no(node))
+        let expr_with_type = Self::get_complex_expressions(exprs, Arc::new(schema));
+        for (expr, dtype) in expr_with_type {
+            let (_, count) = self.complex_exprs.entry(expr).or_insert_with(|| (dtype, 0));
+            *count += 1;
+        }
     }
-    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
-        let current_depth_schema =
-            self.depth_map.get(&self.depth).cloned().unwrap_or_default();
+}
+
+impl TreeNodeRewriter for ProjectionAdder {
+    type Node = LogicalPlan;
+    /// Currently we only collect complex binary ops.
-
-        // get the intersected part looking up
-        let added_expr = self
-            .depth_map
-            .iter()
-            .filter(|(&depth, _)| depth < self.depth)
-            .fold(current_depth_schema.clone(), |acc, (_, expr)| {
-                acc.intersection(expr).cloned().collect()
-            });
-        // in projection, we are trying to get intersect with lower one and if some column is not used in upper one, we just abandon them
-        if let LogicalPlan::Projection(_) = node.clone() {
-            // we get the expressions that has intersect with deeper layer since those expressions could be rewrite
-            let mut cross_expr_deeper = HashSet::new();
-
-            self.depth_map
-                .iter()
-                .filter(|(&depth, _)| depth > self.depth)
-                .for_each(|(_, exprs)| {
-                    // get intersection
-                    let intersection = current_depth_schema
-                        .intersection(exprs)
-                        .cloned()
-                        .collect::<HashSet<_>>();
-                    cross_expr_deeper =
-                        cross_expr_deeper.union(&intersection).cloned().collect();
-                });
-            let mut project_exprs = vec![];
-            if !cross_expr_deeper.is_empty() {
-                let current_expressions = node.expressions();
-                for expr in current_expressions {
-                    if cross_expr_deeper.contains(&expr) {
-                        let f = DFField::new_unqualified(
-                            &expr.to_string(),
-                            self.data_type_map[&expr].clone(),
-                            true,
-                        );
-                        project_exprs.push(Expr::Column(f.qualified_column()));
-                        cross_expr_deeper.remove(&expr);
-                    } else {
-                        project_exprs.push(expr);
-                    }
-                }
-                let new_projection = LogicalPlan::Projection(Projection::try_new(
-                    project_exprs,
-                    Arc::new(node.inputs()[0].clone()),
-                )?);
-                self.depth -= 1;
-                return Ok(Transformed::yes(new_projection));
+    fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+        // Record insertion points for the nodes below that end cumulative tracking.
+        self.depth += 1;
+        match node {
+            LogicalPlan::TableScan(_) => {
+                // Stop tracking cumulative usage at the source.
+                let complex_exprs = std::mem::take(&mut self.complex_exprs);
+                self.insertion_point_map
+                    .insert(self.depth - 1, complex_exprs);
+                Ok(Transformed::no(node))
+            }
+            LogicalPlan::Sort(_) | LogicalPlan::Filter(_) | LogicalPlan::Window(_) => {
+                // These are safe operators: expression identity is preserved across them.
+                self.extend_with_exprs(&node);
+                Ok(Transformed::no(node))
+            }
+            LogicalPlan::Projection(_) => {
+                // Stop tracking cumulative usage at the projection since it may invalidate expression identity.
+                let complex_exprs = std::mem::take(&mut self.complex_exprs);
+                self.insertion_point_map
+                    .insert(self.depth - 1, complex_exprs);
+                // Start tracking common expressions from here on, including the projection itself.
+                self.extend_with_exprs(&node);
+                Ok(Transformed::no(node))
+            }
+            _ => {
+                // Unsupported operators: reset tracking.
+                self.complex_exprs.clear();
+                Ok(Transformed::no(node))
             }
         }
+    }
+
+    fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
+        let cached_exprs = self
+            .insertion_point_map
+            .get(&self.depth)
+            .cloned()
+            .unwrap_or_default();
         self.depth -= 1;
         // do not do extra things
-        if added_expr.is_empty() {
+        let should_add_projection =
+            cached_exprs.iter().any(|(_expr, (_, count))| *count > 1);
+
+        let children = node.inputs();
+        if children.len() != 1 {
+            // We can only rewrite nodes with a single child.
             return Ok(Transformed::no(node));
         }
-        match node {
-            // do not add for Projections
-            LogicalPlan::Projection(_)
-            | LogicalPlan::TableScan(_)
-            | LogicalPlan::Join(_) => Ok(Transformed::no(node)),
-
-            _ => {
-                // avoid recursive add projections
-                if added_expr.iter().any(|expr| {
-                    node.inputs()[0]
-                        .schema()
-                        .has_column_with_unqualified_name(&expr.to_string())
-                }) {
-                    return Ok(Transformed::no(node));
-                }
-
-                let mut field_set = HashSet::new();
-                let mut project_exprs = vec![];
-                // here we can deduce that
-                for expr in added_expr {
-                    let f = DFField::new_unqualified(
-                        &expr.to_string(),
-                        self.data_type_map[&expr].clone(),
-                        true,
-                    );
+        let child = children[0].clone();
+        let child = if should_add_projection {
+            let mut field_set = HashSet::new();
+            let mut project_exprs = vec![];
+            for (expr, (dtype, count)) in &cached_exprs {
+                if *count > 1 {
+                    let f =
+                        DFField::new_unqualified(&expr.to_string(), dtype.clone(), true);
                     field_set.insert(f.name().to_owned());
                     project_exprs.push(expr.clone().alias(expr.to_string()));
                 }
-                for field in node.inputs()[0].schema().fields() {
-                    if field_set.insert(field.qualified_name()) {
-                        project_exprs.push(Expr::Column(field.qualified_column()));
-                    }
+            }
+            // Do not lose fields in the child.
+            for field in child.schema().fields() {
+                if field_set.insert(field.qualified_name()) {
+                    project_exprs.push(Expr::Column(field.qualified_column()));
                 }
-                // adding new plan here
-                let new_plan = LogicalPlan::Projection(Projection::try_new(
-                    project_exprs,
-                    Arc::new(node.inputs()[0].clone()),
-                )?);
-                Ok(Transformed::yes(
-                    node.with_new_exprs(node.expressions(), [new_plan].to_vec())?,
-                ))
             }
-        }
+
+            // adding new plan here
+            LogicalPlan::Projection(Projection::try_new(
+                project_exprs,
+                Arc::new(node.inputs()[0].clone()),
+            )?)
+        } else {
+            child
+        };
+        let mut expressions = node.expressions();
+        let available_columns = child
+            .schema()
+            .fields()
+            .iter()
+            .map(|field| field.qualified_column())
+            .collect::<Vec<_>>();
+        // Replace expressions with their pre-computed variants where available.
+        expressions.iter_mut().try_for_each(|expr| {
+            Self::update_expr_with_available_columns(expr, &available_columns)
+        })?;
+        let new_node = node.with_new_exprs(expressions, [child].to_vec())?;
+        Ok(Transformed::yes(new_node))
     }
 }
 
 #[cfg(test)]
diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs
index 8ad773d89e13..39dc56d2af93 100644
--- a/datafusion/optimizer/src/optimize_projections.rs
+++ b/datafusion/optimizer/src/optimize_projections.rs
@@ -26,7 +26,6 @@
 use std::collections::HashSet;
 use std::sync::Arc;
 
-use crate::common_subexpr_eliminate::is_not_complex;
 use crate::optimizer::ApplyOrder;
 use crate::{OptimizerConfig, OptimizerRule};
 
@@ -694,26 +693,10 @@ fn indices_referred_by_expr(
     // TODO: Support more Expressions
     let mut cols = expr.to_columns()?;
     outer_columns(expr, &mut cols);
-    let mut res_vec: Vec<usize> = cols
+    Ok(cols
         .iter()
         .flat_map(|col| input_schema.index_of_column(col))
-        .collect();
-    match expr {
-        Expr::BinaryExpr(BinaryExpr { op, .. }) if !is_not_complex(op) => {
-            if let Some(index) =
-                input_schema.index_of_column_by_name(None, &expr.to_string())?
-            {
-                match res_vec.binary_search(&index) {
-                    Ok(_) => {}
-                    Err(pos) => {
-                        res_vec.insert(pos, index);
-                    }
-                }
-            }
-        }
-        _ => {}
-    }
-    Ok(res_vec)
+        .collect())
 }
 
 /// Gets all required indices for the input; i.e. those required by the parent
diff --git a/datafusion/sqllogictest/test_files/project_complex_sub_query.slt b/datafusion/sqllogictest/test_files/project_complex_sub_query.slt
index 5ba1f493fa37..56e728325b79 100644
--- a/datafusion/sqllogictest/test_files/project_complex_sub_query.slt
+++ b/datafusion/sqllogictest/test_files/project_complex_sub_query.slt
@@ -36,17 +36,15 @@
 explain SELECT c3+c4, SUM(c3+c4) OVER() FROM t;
 ----
 logical_plan
-Projection: t.c3 + t.c4, SUM(t.c3 + t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
---WindowAggr: windowExpr=[[SUM(CAST(t.c3 + t.c4 AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
-----Projection: t.c3 + t.c4 AS t.c3 + t.c4, t.c3, t.c4
-------TableScan: t projection=[c3, c4]
+WindowAggr: windowExpr=[[SUM(CAST(t.c3 + t.c4 AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
+--Projection: t.c3 + t.c4 AS t.c3 + t.c4
+----TableScan: t projection=[c3, c4]
 physical_plan
-ProjectionExec: expr=[t.c3 + t.c4@0 as t.c3 + t.c4, SUM(t.c3 + t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as SUM(t.c3 + t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]
---WindowAggExec: wdw=[SUM(t.c3 + t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t.c3 + t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
-----CoalescePartitionsExec
-------ProjectionExec: expr=[c3@0 + c4@1 as t.c3 + t.c4, c3@0 as c3, c4@1 as c4]
---------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/project_complex_expression.csv]]}, projection=[c3, c4], has_header=true
+WindowAggExec: wdw=[SUM(t.c3 + t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t.c3 + t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
+--CoalescePartitionsExec
+----ProjectionExec: expr=[c3@0 + c4@1 as t.c3 + t.c4]
+------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/project_complex_expression.csv]]}, projection=[c3, c4], has_header=true
 
 query TT
 explain SELECT c3+c4, SUM(c3+c4) OVER(order by c3+c4)
@@ -91,17 +89,15 @@
 explain SELECT c3-c4, SUM(c3-c4) OVER() FROM t;
 ----
 logical_plan
-Projection: t.c3 - t.c4, SUM(t.c3 - t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
---WindowAggr: windowExpr=[[SUM(CAST(t.c3 - t.c4 AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
-----Projection: t.c3 - t.c4 AS t.c3 - t.c4, t.c3, t.c4
-------TableScan: t projection=[c3, c4]
+WindowAggr: windowExpr=[[SUM(CAST(t.c3 - t.c4 AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
+--Projection: t.c3 - t.c4 AS t.c3 - t.c4
+----TableScan: t projection=[c3, c4]
 physical_plan
-ProjectionExec: expr=[t.c3 - t.c4@0 as t.c3 - t.c4, SUM(t.c3 - t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as SUM(t.c3 - t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]
---WindowAggExec: wdw=[SUM(t.c3 - t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t.c3 - t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
-----CoalescePartitionsExec
-------ProjectionExec: expr=[c3@0 - c4@1 as t.c3 - t.c4, c3@0 as c3, c4@1 as c4]
---------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/project_complex_expression.csv]]}, projection=[c3, c4], has_header=true
+WindowAggExec: wdw=[SUM(t.c3 - t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t.c3 - t.c4) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }]
+--CoalescePartitionsExec
+----ProjectionExec: expr=[c3@0 - c4@1 as t.c3 - t.c4]
+------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/project_complex_expression.csv]]}, projection=[c3, c4], has_header=true
 
 query II
 SELECT c3+c4, SUM(c3+c4) OVER()
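
Reviewer note on the `ProjectionAdder` rewrite above: the mechanism is easier to see in isolation. Between two "unsafe" boundaries (a `TableScan` or a `Projection`), `f_down` accumulates a usage count per complex expression; when a boundary is hit, the accumulated map is parked in `insertion_point_map` keyed by depth, and `f_up` later materializes a pre-computing projection only where some expression was used more than once. The sketch below is a minimal, self-contained model of that counting scheme — `PlanNode`, `InsertionPoints`, and `walk` are hypothetical stand-ins invented for illustration, not DataFusion's actual `LogicalPlan`/`TreeNodeRewriter` API.

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-in for a logical plan node; DataFusion's real
// `LogicalPlan` is far richer. Each node lists the "complex" expressions it
// uses, rendered as strings for brevity.
enum PlanNode {
    // Safe: expression identity is preserved (e.g. Sort / Filter / Window).
    Safe { exprs: Vec<String>, input: Box<PlanNode> },
    // Unsafe boundary: cumulative tracking must restart (e.g. Projection / TableScan).
    Boundary,
}

#[derive(Default)]
struct InsertionPoints {
    depth: usize,
    // usage counts of complex expressions between boundaries
    complex_exprs: HashMap<String, usize>,
    // depth of the boundary node -> expressions (with counts) usable above it
    insertion_point_map: HashMap<usize, HashMap<String, usize>>,
}

impl InsertionPoints {
    // Top-down walk mirroring `f_down`: accumulate usage counts across safe
    // nodes and flush them whenever a boundary invalidates identity tracking.
    fn walk(&mut self, node: &PlanNode) {
        self.depth += 1;
        match node {
            PlanNode::Safe { exprs, input } => {
                for e in exprs {
                    *self.complex_exprs.entry(e.clone()).or_insert(0) += 1;
                }
                self.walk(input);
            }
            PlanNode::Boundary => {
                let flushed = std::mem::take(&mut self.complex_exprs);
                self.insertion_point_map.insert(self.depth - 1, flushed);
            }
        }
        self.depth -= 1;
    }
}

fn main() {
    // WindowAggr(c3 + c4) over Sort(c3 + c4) over a scan: the expression is
    // used twice between boundaries, so a projection pays off above the scan.
    let plan = PlanNode::Safe {
        exprs: vec!["c3 + c4".into()],
        input: Box::new(PlanNode::Safe {
            exprs: vec!["c3 + c4".into()],
            input: Box::new(PlanNode::Boundary),
        }),
    };
    let mut ip = InsertionPoints::default();
    ip.walk(&plan);
    for (depth, exprs) in &ip.insertion_point_map {
        for (expr, count) in exprs {
            println!("depth {depth}: {expr} used {count}x -> project = {}", *count > 1);
        }
    }
}
```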
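The `f_up` half of the rewriter then relies on `update_expr_with_available_columns` to replace each use of a pre-computed expression with a column reference to the alias produced by the inserted projection, recursing through casts, aliases, and window-function arguments, which preserve expression identity. Below is a minimal sketch of that substitution over a toy string-named expression type — `Expr`, `display_name`, and `rewrite` here are illustrative stand-ins, not the real `datafusion_expr` types.

```rust
// Toy expression tree; DataFusion's `Expr` is the real target.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Column(String),
    BinaryOp { left: Box<Expr>, op: char, right: Box<Expr> },
    Cast(Box<Expr>),
}

impl Expr {
    // Display name used to match an expression against a projected alias,
    // analogous to `Expr::display_name()` in DataFusion.
    fn display_name(&self) -> String {
        match self {
            Expr::Column(c) => c.clone(),
            Expr::BinaryOp { left, op, right } => {
                format!("{} {} {}", left.display_name(), op, right.display_name())
            }
            Expr::Cast(inner) => inner.display_name(),
        }
    }
}

// Mirror of `update_expr_with_available_columns`: if a binary subtree's name
// matches a column the child already provides, collapse it to that column.
fn rewrite(expr: &mut Expr, available: &[String]) {
    if let Expr::BinaryOp { .. } = expr {
        if let Some(col) = available.iter().find(|c| **c == expr.display_name()) {
            *expr = Expr::Column(col.clone());
            return;
        }
    }
    // Otherwise recurse into children that preserve expression identity.
    match expr {
        Expr::BinaryOp { left, right, .. } => {
            rewrite(left, available);
            rewrite(right, available);
        }
        Expr::Cast(inner) => rewrite(inner, available),
        Expr::Column(_) => {}
    }
}

fn main() {
    // CAST(c3 + c4): the cast is transparent, so the inner binary op collapses.
    let mut e = Expr::Cast(Box::new(Expr::BinaryOp {
        left: Box::new(Expr::Column("c3".into())),
        op: '+',
        right: Box::new(Expr::Column("c4".into())),
    }));
    // The inserted projection exposes "c3 + c4" as a column.
    rewrite(&mut e, &["c3 + c4".to_string()]);
    assert_eq!(e, Expr::Cast(Box::new(Expr::Column("c3 + c4".into()))));
    println!("{e:?}");
}
```

This also makes the design trade-off visible: matching by display name (rather than structural equality against the alias's source expression) is what forces the patch to assume that the operators between the projection and the use site do not rename fields.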