Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding complex expressions projections for Subquery #9719

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 189 additions & 26 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

//! Eliminate common sub-expression.

use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

use crate::utils::is_volatile_expression;
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_expr::{Expr, Operator};
use std::collections::{BTreeSet, HashMap};
use std::ops::Deref;

use std::sync::Arc;

use arrow::datatypes::DataType;
use datafusion_common::tree_node::{
Expand All @@ -31,9 +33,10 @@ use datafusion_common::tree_node::{
use datafusion_common::{
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::expr::{Alias, WindowFunction};
use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window};
use datafusion_expr::{col, Expr, ExprSchemable};
use datafusion_expr::{col, BinaryExpr, Cast, ExprSchemable};
use hashbrown::HashSet;

/// A map from expression's identifier to tuple including
/// - the expression itself (cloned)
Expand Down Expand Up @@ -120,14 +123,15 @@ impl CommonSubexprEliminate {
.iter()
.zip(arrays_list.iter())
.map(|(exprs, arrays)| {
exprs
let res = exprs
.iter()
.cloned()
.zip(arrays.iter())
.map(|(expr, id_array)| {
replace_common_expr(expr, id_array, expr_set, affected_id)
})
.collect::<Result<Vec<_>>>()
.collect::<Result<Vec<_>>>();
res
})
.collect::<Result<Vec<_>>>()
}
Expand All @@ -146,12 +150,11 @@ impl CommonSubexprEliminate {
self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?;

let mut new_input = self
.try_optimize(input, config)?
.common_optimize(input, config)?
.unwrap_or_else(|| input.clone());
if !affected_id.is_empty() {
new_input = build_common_expr_project_plan(new_input, affected_id, expr_set)?;
}

Ok((rewrite_exprs, new_input))
}

Expand Down Expand Up @@ -187,7 +190,6 @@ impl CommonSubexprEliminate {
window_exprs.push(window_expr);
arrays_per_window.push(arrays);
}

let mut window_exprs = window_exprs
.iter()
.map(|expr| expr.as_slice())
Expand All @@ -196,7 +198,6 @@ impl CommonSubexprEliminate {
.iter()
.map(|arrays| arrays.as_slice())
.collect::<Vec<_>>();

assert_eq!(window_exprs.len(), arrays_per_window.len());
let (mut new_expr, new_input) = self.rewrite_expr(
&window_exprs,
Expand All @@ -206,14 +207,12 @@ impl CommonSubexprEliminate {
config,
)?;
assert_eq!(window_exprs.len(), new_expr.len());

// Construct consecutive window operator, with their corresponding new window expressions.
plan = new_input;
while let Some(new_window_expr) = new_expr.pop() {
// Since `new_expr` and `window_exprs` length are same. We can safely `.unwrap` here.
let orig_window_expr = window_exprs.pop().unwrap();
assert_eq!(new_window_expr.len(), orig_window_expr.len());

// Rename new re-written window expressions with original name (by giving alias)
// Otherwise we may receive schema error, in subsequent operators.
let new_window_expr = new_window_expr
Expand All @@ -226,7 +225,6 @@ impl CommonSubexprEliminate {
.collect::<Result<Vec<_>>>()?;
plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?);
}

Ok(plan)
}

Expand Down Expand Up @@ -360,16 +358,13 @@ impl CommonSubexprEliminate {

// Visit the expr list and build a map from expr identifier to occurrence count (`expr_set`).
let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?;

let (mut new_expr, new_input) =
self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?;

plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])
}
}

impl OptimizerRule for CommonSubexprEliminate {
fn try_optimize(
impl CommonSubexprEliminate {
fn common_optimize(
&self,
plan: &LogicalPlan,
config: &dyn OptimizerConfig,
Expand Down Expand Up @@ -413,8 +408,8 @@ impl OptimizerRule for CommonSubexprEliminate {

let original_schema = plan.schema().clone();
match optimized_plan {
Some(LogicalPlan::Projection(_)) => Ok(optimized_plan),
Some(optimized_plan) if optimized_plan.schema() != &original_schema => {
// add an additional projection if the output schema changed.
Ok(Some(build_recover_project_plan(
&original_schema,
optimized_plan,
Expand All @@ -423,6 +418,33 @@ impl OptimizerRule for CommonSubexprEliminate {
plan => Ok(plan),
}
}
/// Currently the implementation is not optimal: it does a top-down traversal over all the plan nodes to collect complex binary expressions, then adds the extra projections on the way back up.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment seems to be incomplete?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for that, I'll complete it

///
/// Rewrites a clone of `plan` with a fresh [`ProjectionAdder`] and returns the
/// rewritten plan.
///
/// The result is always `Some(..)` on success, whether or not the rewriter
/// changed anything; errors raised during the rewrite are propagated.
fn add_extra_projection(&self, plan: &LogicalPlan) -> Result<Option<LogicalPlan>> {
    // Start from an empty state: nothing collected yet, and we are above the root.
    let mut adder = ProjectionAdder {
        data_type_map: HashMap::new(),
        depth_map: HashMap::new(),
        depth: 0,
    };
    let rewritten = plan.clone().rewrite(&mut adder)?;
    Ok(Some(rewritten.data))
}
}
impl OptimizerRule for CommonSubexprEliminate {
/// Runs common-subexpression elimination on `plan`, then the extra-projection
/// pass on the result.
///
/// When `common_optimize` reports no change (`None`), the extra-projection pass
/// runs on a clone of the original `plan` instead.
fn try_optimize(
    &self,
    plan: &LogicalPlan,
    config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
    let cse_plan = self
        .common_optimize(plan, config)?
        .unwrap_or_else(|| plan.clone());
    self.add_extra_projection(&cse_plan)
}

fn name(&self) -> &str {
"common_sub_expression_eliminate"
Expand Down Expand Up @@ -454,7 +476,8 @@ fn to_arrays(
expr_set: &mut ExprSet,
expr_mask: ExprMask,
) -> Result<Vec<Vec<(usize, String)>>> {
expr.iter()
let res = expr
.iter()
.map(|e| {
let mut id_array = vec![];
expr_to_identifier(
Expand All @@ -467,7 +490,8 @@ fn to_arrays(

Ok(id_array)
})
.collect::<Result<Vec<_>>>()
.collect::<Result<Vec<_>>>();
res
}

/// Build the "intermediate" projection plan that evaluates the extracted common expressions.
Expand Down Expand Up @@ -498,7 +522,6 @@ fn build_common_expr_project_plan(
project_exprs.push(Expr::Column(field.qualified_column()));
}
}

Ok(LogicalPlan::Projection(Projection::try_new(
project_exprs,
Arc::new(input),
Expand Down Expand Up @@ -693,7 +716,6 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));

let data_type = expr.get_type(&self.input_schema)?;

self.expr_set
.entry(desc)
.or_insert_with(|| (expr.clone(), 0, data_type))
Expand Down Expand Up @@ -829,6 +851,147 @@ fn replace_common_expr(
.data()
}

/// Plan rewriter state used to add projections for complex expressions.
///
/// Expressions are collected per tree depth on the way down (`f_down`) and
/// consulted on the way back up (`f_up`).
struct ProjectionAdder {
/// Complex expressions collected so far, keyed by the depth at which they were seen.
depth_map: HashMap<u128, HashSet<Expr>>,
/// Current depth in the plan tree; incremented in `f_down`, decremented in `f_up`.
depth: u128,
/// Data type recorded for each collected expression (populated in `f_down`).
data_type_map: HashMap<Expr, DataType>,
}
pub fn is_not_complex(op: &Operator) -> bool {
matches!(
op,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sugg:

  • Operator::LtEq and Operator::GtEq can also be considered not complex
  • Operator is Copy, so changing the function parameter to op: Operator seems better

&Operator::Eq | &Operator::NotEq | &Operator::Lt | &Operator::Gt | &Operator::And
)
}
impl ProjectionAdder {
// TODO: support more expression kinds for subqueries; currently only simple binary expressions are supported
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a tracking issue for this that we could refer to in the comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add that

fn get_complex_expressions(
exprs: Vec<Expr>,
schema: DFSchemaRef,
) -> (HashSet<Expr>, HashMap<Expr, DataType>) {
let mut res = HashSet::new();
let mut expr_data_type: HashMap<Expr, DataType> = HashMap::new();
for expr in exprs {
match expr {
Expr::BinaryExpr(BinaryExpr {
left: ref l_box,
op,
right: ref r_box,
}) if !is_not_complex(&op) => {
if let (Expr::Column(l), Expr::Column(_r)) = (&**l_box, &**r_box) {
let l_field = schema
.field_from_column(l)
.expect("Field not found for left column");

// res.insert(DFField::new_unqualified(
// &expr.to_string(),
// l_field.data_type().clone(),
// true,
// ));
expr_data_type
.entry(expr.clone())
.or_insert(l_field.data_type().clone());
res.insert(expr.clone());
}
}
Expr::Cast(Cast { expr, data_type: _ }) => {
let (expr_set, type_data_map) =
Self::get_complex_expressions(vec![*expr], schema.clone());
res.extend(expr_set);
expr_data_type.extend(type_data_map);
}

Expr::WindowFunction(WindowFunction { fun: _, args, .. }) => {
let (expr_set, type_map) =
Self::get_complex_expressions(args, schema.clone());
res.extend(expr_set);
expr_data_type.extend(type_map);
}
_ => {}
}
}
(res, expr_data_type)
}
}
impl TreeNodeRewriter for ProjectionAdder {
type Node = LogicalPlan;
/// currently we just collect the complex bianryOP

fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
// use depth to trace where we are in the LogicalPlan tree
self.depth += 1;
// extract all expressions + check whether it contains in depth_sets
let exprs = node.expressions();
let depth_set = self.depth_map.entry(self.depth).or_default();
let mut schema = node.schema().deref().clone();
for ip in node.inputs() {
schema.merge(ip.schema());
}
let (extended_set, data_map) =
Self::get_complex_expressions(exprs, Arc::new(schema));
depth_set.extend(extended_set);
self.data_type_map.extend(data_map);
Ok(Transformed::no(node))
}
fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
let current_depth_schema =
self.depth_map.get(&self.depth).cloned().unwrap_or_default();

// get the intersected part
let added_expr = self
.depth_map
.iter()
.filter(|(&depth, _)| depth < self.depth)
.fold(current_depth_schema, |acc, (_, expr)| {
acc.intersection(expr).cloned().collect()
});
self.depth -= 1;
// do not do extra things
if added_expr.is_empty() {
return Ok(Transformed::no(node));
}
match node {
// do not add for Projections
LogicalPlan::Projection(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Join(_) => Ok(Transformed::no(node)),
_ => {
// avoid recursive add projections
if added_expr.iter().any(|expr| {
node.inputs()[0]
.schema()
.has_column_with_unqualified_name(&expr.to_string())
}) {
return Ok(Transformed::no(node));
}

let mut field_set = HashSet::new();
let mut project_exprs = vec![];
for expr in added_expr {
let f = DFField::new_unqualified(
&expr.to_string(),
self.data_type_map[&expr].clone(),
true,
);
field_set.insert(f.name().to_owned());
project_exprs.push(expr.clone().alias(expr.to_string()));
}
for field in node.inputs()[0].schema().fields() {
if field_set.insert(field.qualified_name()) {
project_exprs.push(Expr::Column(field.qualified_column()));
}
}
// adding new plan here
let new_plan = LogicalPlan::Projection(Projection::try_new(
project_exprs,
Arc::new(node.inputs()[0].clone()),
)?);
Ok(Transformed::yes(
node.with_new_exprs(node.expressions(), [new_plan].to_vec())?,
))
}
}
}
}
#[cfg(test)]
mod test {
use std::iter;
Expand All @@ -853,7 +1016,7 @@ mod test {
fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) {
let optimizer = CommonSubexprEliminate {};
let optimized_plan = optimizer
.try_optimize(plan, &OptimizerContext::new())
.common_optimize(plan, &OptimizerContext::new())
.unwrap()
.expect("failed to optimize plan");
let formatted_plan = format!("{optimized_plan:?}");
Expand Down Expand Up @@ -1272,7 +1435,7 @@ mod test {
.unwrap();
let rule = CommonSubexprEliminate {};
let optimized_plan = rule
.try_optimize(&plan, &OptimizerContext::new())
.common_optimize(&plan, &OptimizerContext::new())
.unwrap()
.unwrap();

Expand Down
Loading
Loading