-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding complex expressions projections for Subquery #9719
Changes from 11 commits
f495921
cc1d412
1c64d23
3e96839
71b5bd7
30d29a0
81523d1
1c4010b
4add9cb
251c48d
117e490
113e232
d4ac68f
01b23a3
ee962f1
40869a4
4396682
83002ce
b77efab
59e0385
c1a16f9
fb9e4d6
bdf626a
2c70a15
31cc2d7
e7f4b49
313cccd
759226b
4596f90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,13 @@ | |
|
||
//! Eliminate common sub-expression. | ||
|
||
use std::collections::{BTreeSet, HashMap}; | ||
use std::sync::Arc; | ||
|
||
use crate::utils::is_volatile_expression; | ||
use crate::{utils, OptimizerConfig, OptimizerRule}; | ||
use datafusion_expr::{Expr, Operator}; | ||
use std::collections::{BTreeSet, HashMap}; | ||
use std::ops::Deref; | ||
|
||
use std::sync::Arc; | ||
|
||
use arrow::datatypes::DataType; | ||
use datafusion_common::tree_node::{ | ||
|
@@ -31,9 +33,10 @@ use datafusion_common::tree_node::{ | |
use datafusion_common::{ | ||
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, | ||
}; | ||
use datafusion_expr::expr::Alias; | ||
use datafusion_expr::expr::{Alias, WindowFunction}; | ||
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; | ||
use datafusion_expr::{col, Expr, ExprSchemable}; | ||
use datafusion_expr::{col, BinaryExpr, Cast, ExprSchemable}; | ||
use hashbrown::HashSet; | ||
|
||
/// A map from expression's identifier to tuple including | ||
/// - the expression itself (cloned) | ||
|
@@ -120,14 +123,15 @@ impl CommonSubexprEliminate { | |
.iter() | ||
.zip(arrays_list.iter()) | ||
.map(|(exprs, arrays)| { | ||
exprs | ||
let res = exprs | ||
.iter() | ||
.cloned() | ||
.zip(arrays.iter()) | ||
.map(|(expr, id_array)| { | ||
replace_common_expr(expr, id_array, expr_set, affected_id) | ||
}) | ||
.collect::<Result<Vec<_>>>() | ||
.collect::<Result<Vec<_>>>(); | ||
res | ||
}) | ||
.collect::<Result<Vec<_>>>() | ||
} | ||
|
@@ -146,12 +150,11 @@ impl CommonSubexprEliminate { | |
self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?; | ||
|
||
let mut new_input = self | ||
.try_optimize(input, config)? | ||
.common_optimize(input, config)? | ||
.unwrap_or_else(|| input.clone()); | ||
if !affected_id.is_empty() { | ||
new_input = build_common_expr_project_plan(new_input, affected_id, expr_set)?; | ||
} | ||
|
||
Ok((rewrite_exprs, new_input)) | ||
} | ||
|
||
|
@@ -187,7 +190,6 @@ impl CommonSubexprEliminate { | |
window_exprs.push(window_expr); | ||
arrays_per_window.push(arrays); | ||
} | ||
|
||
let mut window_exprs = window_exprs | ||
.iter() | ||
.map(|expr| expr.as_slice()) | ||
|
@@ -196,7 +198,6 @@ impl CommonSubexprEliminate { | |
.iter() | ||
.map(|arrays| arrays.as_slice()) | ||
.collect::<Vec<_>>(); | ||
|
||
assert_eq!(window_exprs.len(), arrays_per_window.len()); | ||
let (mut new_expr, new_input) = self.rewrite_expr( | ||
&window_exprs, | ||
|
@@ -206,14 +207,12 @@ impl CommonSubexprEliminate { | |
config, | ||
)?; | ||
assert_eq!(window_exprs.len(), new_expr.len()); | ||
|
||
// Construct consecutive window operator, with their corresponding new window expressions. | ||
plan = new_input; | ||
while let Some(new_window_expr) = new_expr.pop() { | ||
// Since `new_expr` and `window_exprs` length are same. We can safely `.unwrap` here. | ||
let orig_window_expr = window_exprs.pop().unwrap(); | ||
assert_eq!(new_window_expr.len(), orig_window_expr.len()); | ||
|
||
// Rename new re-written window expressions with original name (by giving alias) | ||
// Otherwise we may receive schema error, in subsequent operators. | ||
let new_window_expr = new_window_expr | ||
|
@@ -226,7 +225,6 @@ impl CommonSubexprEliminate { | |
.collect::<Result<Vec<_>>>()?; | ||
plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?); | ||
} | ||
|
||
Ok(plan) | ||
} | ||
|
||
|
@@ -360,16 +358,13 @@ impl CommonSubexprEliminate { | |
|
||
// Visit expr list and build expr identifier to occuring count map (`expr_set`). | ||
let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?; | ||
|
||
let (mut new_expr, new_input) = | ||
self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?; | ||
|
||
plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) | ||
} | ||
} | ||
|
||
impl OptimizerRule for CommonSubexprEliminate { | ||
fn try_optimize( | ||
impl CommonSubexprEliminate { | ||
fn common_optimize( | ||
&self, | ||
plan: &LogicalPlan, | ||
config: &dyn OptimizerConfig, | ||
|
@@ -413,8 +408,8 @@ impl OptimizerRule for CommonSubexprEliminate { | |
|
||
let original_schema = plan.schema().clone(); | ||
match optimized_plan { | ||
Some(LogicalPlan::Projection(_)) => Ok(optimized_plan), | ||
Some(optimized_plan) if optimized_plan.schema() != &original_schema => { | ||
// add an additional projection if the output schema changed. | ||
Ok(Some(build_recover_project_plan( | ||
&original_schema, | ||
optimized_plan, | ||
|
@@ -423,6 +418,33 @@ impl OptimizerRule for CommonSubexprEliminate { | |
plan => Ok(plan), | ||
} | ||
} | ||
/// currently the implemention is not optimal, Basically I just do a top-down iteration over all the | ||
/// | ||
fn add_extra_projection(&self, plan: &LogicalPlan) -> Result<Option<LogicalPlan>> { | ||
plan.clone() | ||
.rewrite(&mut ProjectionAdder { | ||
data_type_map: HashMap::new(), | ||
depth_map: HashMap::new(), | ||
depth: 0, | ||
}) | ||
.map(|transformed| Some(transformed.data)) | ||
} | ||
} | ||
impl OptimizerRule for CommonSubexprEliminate { | ||
fn try_optimize( | ||
&self, | ||
plan: &LogicalPlan, | ||
config: &dyn OptimizerConfig, | ||
) -> Result<Option<LogicalPlan>> { | ||
let optimized_plan_option = self.common_optimize(plan, config)?; | ||
|
||
// println!("optimized plan option is {:?}", optimized_plan_option); | ||
let plan = match optimized_plan_option { | ||
Some(plan) => plan, | ||
_ => plan.clone(), | ||
}; | ||
self.add_extra_projection(&plan) | ||
} | ||
|
||
fn name(&self) -> &str { | ||
"common_sub_expression_eliminate" | ||
|
@@ -454,7 +476,8 @@ fn to_arrays( | |
expr_set: &mut ExprSet, | ||
expr_mask: ExprMask, | ||
) -> Result<Vec<Vec<(usize, String)>>> { | ||
expr.iter() | ||
let res = expr | ||
.iter() | ||
.map(|e| { | ||
let mut id_array = vec![]; | ||
expr_to_identifier( | ||
|
@@ -467,7 +490,8 @@ fn to_arrays( | |
|
||
Ok(id_array) | ||
}) | ||
.collect::<Result<Vec<_>>>() | ||
.collect::<Result<Vec<_>>>(); | ||
res | ||
} | ||
|
||
/// Build the "intermediate" projection plan that evaluates the extracted common expressions. | ||
|
@@ -498,7 +522,6 @@ fn build_common_expr_project_plan( | |
project_exprs.push(Expr::Column(field.qualified_column())); | ||
} | ||
} | ||
|
||
Ok(LogicalPlan::Projection(Projection::try_new( | ||
project_exprs, | ||
Arc::new(input), | ||
|
@@ -693,7 +716,6 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { | |
self.visit_stack.push(VisitRecord::ExprItem(desc.clone())); | ||
|
||
let data_type = expr.get_type(&self.input_schema)?; | ||
|
||
self.expr_set | ||
.entry(desc) | ||
.or_insert_with(|| (expr.clone(), 0, data_type)) | ||
|
@@ -829,6 +851,147 @@ fn replace_common_expr( | |
.data() | ||
} | ||
|
||
struct ProjectionAdder { | ||
depth_map: HashMap<u128, HashSet<Expr>>, | ||
depth: u128, | ||
data_type_map: HashMap<Expr, DataType>, | ||
} | ||
pub fn is_not_complex(op: &Operator) -> bool { | ||
matches!( | ||
op, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sugg:
|
||
&Operator::Eq | &Operator::NotEq | &Operator::Lt | &Operator::Gt | &Operator::And | ||
) | ||
} | ||
impl ProjectionAdder { | ||
// TODO: adding more expressions for sub query, currently only support for Simple Binary Expressions | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a tracking issue for this that we could refer to in the comment? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll add that |
||
fn get_complex_expressions( | ||
exprs: Vec<Expr>, | ||
schema: DFSchemaRef, | ||
) -> (HashSet<Expr>, HashMap<Expr, DataType>) { | ||
let mut res = HashSet::new(); | ||
let mut expr_data_type: HashMap<Expr, DataType> = HashMap::new(); | ||
for expr in exprs { | ||
match expr { | ||
Expr::BinaryExpr(BinaryExpr { | ||
left: ref l_box, | ||
op, | ||
right: ref r_box, | ||
}) if !is_not_complex(&op) => { | ||
if let (Expr::Column(l), Expr::Column(_r)) = (&**l_box, &**r_box) { | ||
let l_field = schema | ||
.field_from_column(l) | ||
.expect("Field not found for left column"); | ||
|
||
// res.insert(DFField::new_unqualified( | ||
// &expr.to_string(), | ||
// l_field.data_type().clone(), | ||
// true, | ||
// )); | ||
expr_data_type | ||
.entry(expr.clone()) | ||
.or_insert(l_field.data_type().clone()); | ||
res.insert(expr.clone()); | ||
} | ||
} | ||
Expr::Cast(Cast { expr, data_type: _ }) => { | ||
let (expr_set, type_data_map) = | ||
Self::get_complex_expressions(vec![*expr], schema.clone()); | ||
res.extend(expr_set); | ||
expr_data_type.extend(type_data_map); | ||
} | ||
|
||
Expr::WindowFunction(WindowFunction { fun: _, args, .. }) => { | ||
let (expr_set, type_map) = | ||
Self::get_complex_expressions(args, schema.clone()); | ||
res.extend(expr_set); | ||
expr_data_type.extend(type_map); | ||
} | ||
_ => {} | ||
} | ||
} | ||
(res, expr_data_type) | ||
} | ||
} | ||
impl TreeNodeRewriter for ProjectionAdder { | ||
type Node = LogicalPlan; | ||
/// currently we just collect the complex bianryOP | ||
|
||
fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> { | ||
// use depth to trace where we are in the LogicalPlan tree | ||
self.depth += 1; | ||
// extract all expressions + check whether it contains in depth_sets | ||
let exprs = node.expressions(); | ||
let depth_set = self.depth_map.entry(self.depth).or_default(); | ||
let mut schema = node.schema().deref().clone(); | ||
for ip in node.inputs() { | ||
schema.merge(ip.schema()); | ||
} | ||
let (extended_set, data_map) = | ||
Self::get_complex_expressions(exprs, Arc::new(schema)); | ||
depth_set.extend(extended_set); | ||
self.data_type_map.extend(data_map); | ||
Ok(Transformed::no(node)) | ||
} | ||
fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> { | ||
let current_depth_schema = | ||
self.depth_map.get(&self.depth).cloned().unwrap_or_default(); | ||
|
||
// get the intersected part | ||
let added_expr = self | ||
.depth_map | ||
.iter() | ||
.filter(|(&depth, _)| depth < self.depth) | ||
.fold(current_depth_schema, |acc, (_, expr)| { | ||
acc.intersection(expr).cloned().collect() | ||
}); | ||
self.depth -= 1; | ||
// do not do extra things | ||
if added_expr.is_empty() { | ||
return Ok(Transformed::no(node)); | ||
} | ||
match node { | ||
// do not add for Projections | ||
LogicalPlan::Projection(_) | ||
| LogicalPlan::TableScan(_) | ||
| LogicalPlan::Join(_) => Ok(Transformed::no(node)), | ||
_ => { | ||
// avoid recursive add projections | ||
if added_expr.iter().any(|expr| { | ||
node.inputs()[0] | ||
.schema() | ||
.has_column_with_unqualified_name(&expr.to_string()) | ||
}) { | ||
return Ok(Transformed::no(node)); | ||
} | ||
|
||
let mut field_set = HashSet::new(); | ||
let mut project_exprs = vec![]; | ||
for expr in added_expr { | ||
let f = DFField::new_unqualified( | ||
&expr.to_string(), | ||
self.data_type_map[&expr].clone(), | ||
true, | ||
); | ||
field_set.insert(f.name().to_owned()); | ||
project_exprs.push(expr.clone().alias(expr.to_string())); | ||
} | ||
for field in node.inputs()[0].schema().fields() { | ||
if field_set.insert(field.qualified_name()) { | ||
project_exprs.push(Expr::Column(field.qualified_column())); | ||
} | ||
} | ||
// adding new plan here | ||
let new_plan = LogicalPlan::Projection(Projection::try_new( | ||
project_exprs, | ||
Arc::new(node.inputs()[0].clone()), | ||
)?); | ||
Ok(Transformed::yes( | ||
node.with_new_exprs(node.expressions(), [new_plan].to_vec())?, | ||
)) | ||
} | ||
} | ||
} | ||
} | ||
#[cfg(test)] | ||
mod test { | ||
use std::iter; | ||
|
@@ -853,7 +1016,7 @@ mod test { | |
fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) { | ||
let optimizer = CommonSubexprEliminate {}; | ||
let optimized_plan = optimizer | ||
.try_optimize(plan, &OptimizerContext::new()) | ||
.common_optimize(plan, &OptimizerContext::new()) | ||
.unwrap() | ||
.expect("failed to optimize plan"); | ||
let formatted_plan = format!("{optimized_plan:?}"); | ||
|
@@ -1272,7 +1435,7 @@ mod test { | |
.unwrap(); | ||
let rule = CommonSubexprEliminate {}; | ||
let optimized_plan = rule | ||
.try_optimize(&plan, &OptimizerContext::new()) | ||
.common_optimize(&plan, &OptimizerContext::new()) | ||
.unwrap() | ||
.unwrap(); | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment seems to be incomplete?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for that, I'll complete it