Feature/pr9719 exp #1

Merged on Mar 29, 2024 (13 commits)
291 changes: 159 additions & 132 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -123,15 +123,14 @@ impl CommonSubexprEliminate {
.iter()
.zip(arrays_list.iter())
.map(|(exprs, arrays)| {
let res = exprs
exprs
.iter()
.cloned()
.zip(arrays.iter())
.map(|(expr, id_array)| {
replace_common_expr(expr, id_array, expr_set, affected_id)
})
.collect::<Result<Vec<_>>>();
res
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()
}
@@ -421,13 +420,15 @@ impl CommonSubexprEliminate {
/// Currently the implementation is not optimal; it simply does a top-down iteration over all the
/// plan nodes.
fn add_extra_projection(&self, plan: &LogicalPlan) -> Result<Option<LogicalPlan>> {
plan.clone()
let result = plan
.clone()
.rewrite(&mut ProjectionAdder {
data_type_map: HashMap::new(),
depth_map: HashMap::new(),
insertion_point_map: HashMap::new(),
depth: 0,
})
.map(|transformed| Some(transformed.data))
complex_exprs: HashMap::new(),
})?
.data;
Ok(Some(result))
}
}
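For context, this is an illustrative note rather than part of the diff: DataFusion's TreeNode API returns a Transformed value from rewrite(), whose data field carries the (possibly) rewritten node, so the ? followed by .data above is equivalent to the previous .map(|transformed| Some(transformed.data)) chain. A minimal sketch of the calling pattern, assuming the datafusion_common::tree_node types (run_adder is a hypothetical helper name):

    use datafusion_common::tree_node::TreeNode;
    use datafusion_common::Result;
    use datafusion_expr::LogicalPlan;

    fn run_adder(plan: LogicalPlan, adder: &mut ProjectionAdder) -> Result<LogicalPlan> {
        // rewrite() drives the rewriter's f_down/f_up over the whole plan;
        // Transformed::data is the resulting plan, Transformed::transformed
        // records whether anything actually changed.
        Ok(plan.rewrite(adder)?.data)
    }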
impl OptimizerRule for CommonSubexprEliminate {
@@ -475,8 +476,7 @@ fn to_arrays(
expr_set: &mut ExprSet,
expr_mask: ExprMask,
) -> Result<Vec<Vec<(usize, String)>>> {
let res = expr
.iter()
expr.iter()
.map(|e| {
let mut id_array = vec![];
expr_to_identifier(
Expand All @@ -489,8 +489,7 @@ fn to_arrays(

Ok(id_array)
})
.collect::<Result<Vec<_>>>();
res
.collect::<Result<Vec<_>>>()
}

/// Build the "intermediate" projection plan that evaluates the extracted common expressions.
@@ -851,24 +850,28 @@ fn replace_common_expr(
}

struct ProjectionAdder {
depth_map: HashMap<u128, HashSet<Expr>>,
depth: u128,
data_type_map: HashMap<Expr, DataType>,
// Keeps track of cumulative usage of common expressions (with their corresponding data types)
// across the plan, keyed by depth at the 'unsafe' nodes where cumulative tracking is invalidated.
insertion_point_map: HashMap<usize, HashMap<Expr, (DataType, u32)>>,
depth: usize,
// Keeps track of cumulative usage of common expressions (with their corresponding data types)
// between 'safe' nodes.
complex_exprs: HashMap<Expr, (DataType, u32)>,
}
pub fn is_not_complex(op: &Operator) -> bool {
matches!(
op,
&Operator::Eq | &Operator::NotEq | &Operator::Lt | &Operator::Gt | &Operator::And
)
}
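Illustratively (not part of this diff), and assuming the intent is that the cheap comparison/logical operators listed above are "not complex" while anything else, such as arithmetic, becomes a candidate for pre-computation:

    #[test]
    fn operator_complexity_examples() {
        use datafusion_expr::Operator;

        // Comparisons and AND are classified as cheap ...
        assert!(is_not_complex(&Operator::Eq));
        assert!(is_not_complex(&Operator::And));
        // ... while arithmetic operators fall through to "complex".
        assert!(!is_not_complex(&Operator::Multiply));
        assert!(!is_not_complex(&Operator::Plus));
    }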

impl ProjectionAdder {
// TODO: add support for more expression kinds (e.g. subqueries); currently only simple binary expressions are handled
fn get_complex_expressions(
exprs: Vec<Expr>,
schema: DFSchemaRef,
) -> (HashSet<Expr>, HashMap<Expr, DataType>) {
) -> HashSet<(Expr, DataType)> {
let mut res = HashSet::new();
let mut expr_data_type: HashMap<Expr, DataType> = HashMap::new();
for expr in exprs {
match expr {
Expr::BinaryExpr(BinaryExpr {
@@ -880,153 +883,177 @@ impl ProjectionAdder {
let l_field = schema
.field_from_column(l)
.expect("Field not found for left column");

expr_data_type
.entry(expr.clone())
.or_insert(l_field.data_type().clone());
res.insert(expr.clone());
res.insert((expr.clone(), l_field.data_type().clone()));
}
}
Expr::Cast(Cast { expr, data_type: _ }) => {
let (expr_set, type_data_map) =
let exprs_with_type =
Self::get_complex_expressions(vec![*expr], schema.clone());
res.extend(expr_set);
expr_data_type.extend(type_data_map);
res.extend(exprs_with_type);
}
Expr::Alias(Alias {
expr,
relation: _,
name: _,
}) => {
let exprs_with_type =
Self::get_complex_expressions(vec![*expr], schema.clone());
res.extend(exprs_with_type);
}

Expr::WindowFunction(WindowFunction { fun: _, args, .. }) => {
let (expr_set, type_map) =
let exprs_with_type =
Self::get_complex_expressions(args, schema.clone());
res.extend(expr_set);
expr_data_type.extend(type_map);
res.extend(exprs_with_type);
}
_ => {}
}
}
(res, expr_data_type)
res
}
}
impl TreeNodeRewriter for ProjectionAdder {
type Node = LogicalPlan;
/// Currently we only collect complex binary expressions.

fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
fn update_expr_with_available_columns(
expr: &mut Expr,
available_columns: &[Column],
) -> Result<()> {
match expr {
Expr::BinaryExpr(_) => {
for available_col in available_columns {
if available_col.flat_name() == expr.display_name()? {
*expr = Expr::Column(available_col.clone());
}
}
}
Expr::WindowFunction(WindowFunction { fun: _, args, .. }) => {
args.iter_mut().try_for_each(|arg| {
Self::update_expr_with_available_columns(arg, available_columns)
})?
}
Expr::Cast(Cast { expr, .. }) => {
Self::update_expr_with_available_columns(expr, available_columns)?
}
Expr::Alias(alias) => {
Self::update_expr_with_available_columns(
&mut alias.expr,
available_columns,
)?;
}
_ => {
// cannot rewrite
}
}
Ok(())
}
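As a rough illustration of the intended effect (a hypothetical sketch with invented column names, not taken from this diff): once the child plan exposes a column literally named "a + b", an a + b expression in the parent is swapped for a plain column reference. Assuming datafusion_expr's expression builders:

    use datafusion_common::{Column, Result};
    use datafusion_expr::{col, Expr};

    fn rewrite_example() -> Result<Expr> {
        // A binary expression whose display name is "a + b".
        let mut expr = col("a") + col("b");
        // A column with that literal name, as produced by the added projection.
        let available = vec![Column::from_name("a + b")];
        ProjectionAdder::update_expr_with_available_columns(&mut expr, &available)?;
        // expr should now be a column reference to the pre-computed value.
        Ok(expr)
    }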

// Assumes operators do not modify the names of the fields;
// otherwise this operation is not safe.
fn extend_with_exprs(&mut self, node: &LogicalPlan) {
// use depth to trace where we are in the LogicalPlan tree
self.depth += 1;
// extract all expressions and check whether they are contained in depth_sets
// extract all expressions and record the complex ones in the cumulative usage map
let exprs = node.expressions();
let depth_set = self.depth_map.entry(self.depth).or_default();
let mut schema = node.schema().deref().clone();
for ip in node.inputs() {
schema.merge(ip.schema());
}
let (extended_set, data_map) =
Self::get_complex_expressions(exprs, Arc::new(schema));
depth_set.extend(extended_set);
self.data_type_map.extend(data_map);
Ok(Transformed::no(node))
let expr_with_type = Self::get_complex_expressions(exprs, Arc::new(schema));
for (expr, dtype) in expr_with_type {
let (_, count) = self.complex_exprs.entry(expr).or_insert_with(|| (dtype, 0));
*count += 1;
}
}
fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
let current_depth_schema =
self.depth_map.get(&self.depth).cloned().unwrap_or_default();
}
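For reference, the counting in extend_with_exprs is the usual HashMap entry-API pattern. A small self-contained sketch, with strings standing in for Expr and DataType (not part of this diff):

    use std::collections::HashMap;

    fn count_usages(seen: &[(&str, &str)]) -> HashMap<String, (String, u32)> {
        let mut counts: HashMap<String, (String, u32)> = HashMap::new();
        for (expr, dtype) in seen {
            // The first sighting stores the data type with a count of 0 ...
            let (_, count) = counts
                .entry(expr.to_string())
                .or_insert_with(|| (dtype.to_string(), 0));
            // ... and every sighting bumps the count; f_up later adds a
            // projection only when some count exceeds 1.
            *count += 1;
        }
        counts
    }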
impl TreeNodeRewriter for ProjectionAdder {
type Node = LogicalPlan;
/// Currently we only collect complex binary expressions.

// get the intersected part looking up
let added_expr = self
.depth_map
.iter()
.filter(|(&depth, _)| depth < self.depth)
.fold(current_depth_schema.clone(), |acc, (_, expr)| {
acc.intersection(expr).cloned().collect()
});
// in projection, we are trying to get intersect with lower one and if some column is not used in upper one, we just abandon them
if let LogicalPlan::Projection(_) = node.clone() {
// we get the expressions that has intersect with deeper layer since those expressions could be rewrite
let mut cross_expr_deeper = HashSet::new();

self.depth_map
.iter()
.filter(|(&depth, _)| depth > self.depth)
.for_each(|(_, exprs)| {
// get intersection
let intersection = current_depth_schema
.intersection(exprs)
.cloned()
.collect::<HashSet<_>>();
cross_expr_deeper =
cross_expr_deeper.union(&intersection).cloned().collect();
});
let mut project_exprs = vec![];
if !cross_expr_deeper.is_empty() {
let current_expressions = node.expressions();
for expr in current_expressions {
if cross_expr_deeper.contains(&expr) {
let f = DFField::new_unqualified(
&expr.to_string(),
self.data_type_map[&expr].clone(),
true,
);
project_exprs.push(Expr::Column(f.qualified_column()));
cross_expr_deeper.remove(&expr);
} else {
project_exprs.push(expr);
}
}
let new_projection = LogicalPlan::Projection(Projection::try_new(
project_exprs,
Arc::new(node.inputs()[0].clone()),
)?);
self.depth -= 1;
return Ok(Transformed::yes(new_projection));
fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
// Track depth and record insertion points at nodes that end cumulative tracking.
self.depth += 1;
match node {
LogicalPlan::TableScan(_) => {
// Stop tracking cumulative usage at the source.
let complex_exprs = std::mem::take(&mut self.complex_exprs);
self.insertion_point_map
.insert(self.depth - 1, complex_exprs);
Ok(Transformed::no(node))
}
LogicalPlan::Sort(_) | LogicalPlan::Filter(_) | LogicalPlan::Window(_) => {
// These are safe operators, where expression identity is preserved during the operation.
self.extend_with_exprs(&node);
Ok(Transformed::no(node))
}
LogicalPlan::Projection(_) => {
// Stop tracking cumulative usage at the projection since it may invalidate expression identity.
let complex_exprs = std::mem::take(&mut self.complex_exprs);
self.insertion_point_map
.insert(self.depth - 1, complex_exprs);
// Start tracking common expressions again from here on, including the projection's own expressions.
self.extend_with_exprs(&node);
Ok(Transformed::no(node))
}
_ => {
// Unsupported operators: reset cumulative tracking.
self.complex_exprs.clear();
Ok(Transformed::no(node))
}
}
}

fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
let cached_exprs = self
.insertion_point_map
.get(&self.depth)
.cloned()
.unwrap_or_default();
self.depth -= 1;
// do not do extra things
if added_expr.is_empty() {
let should_add_projection =
cached_exprs.iter().any(|(_expr, (_, count))| *count > 1);

let children = node.inputs();
if children.len() != 1 {
// Can only rewrite nodes with a single child
return Ok(Transformed::no(node));
}
match node {
// do not add for Projections
LogicalPlan::Projection(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Join(_) => Ok(Transformed::no(node)),

_ => {
// avoid recursive add projections
if added_expr.iter().any(|expr| {
node.inputs()[0]
.schema()
.has_column_with_unqualified_name(&expr.to_string())
}) {
return Ok(Transformed::no(node));
}

let mut field_set = HashSet::new();
let mut project_exprs = vec![];
// here we can deduce that
for expr in added_expr {
let f = DFField::new_unqualified(
&expr.to_string(),
self.data_type_map[&expr].clone(),
true,
);
let child = children[0].clone();
let child = if should_add_projection {
let mut field_set = HashSet::new();
let mut project_exprs = vec![];
for (expr, (dtype, count)) in &cached_exprs {
if *count > 1 {
let f =
DFField::new_unqualified(&expr.to_string(), dtype.clone(), true);
field_set.insert(f.name().to_owned());
project_exprs.push(expr.clone().alias(expr.to_string()));
}
for field in node.inputs()[0].schema().fields() {
if field_set.insert(field.qualified_name()) {
project_exprs.push(Expr::Column(field.qualified_column()));
}
}
// Do not lose fields in the child.
for field in child.schema().fields() {
if field_set.insert(field.qualified_name()) {
project_exprs.push(Expr::Column(field.qualified_column()));
}
// adding new plan here
let new_plan = LogicalPlan::Projection(Projection::try_new(
project_exprs,
Arc::new(node.inputs()[0].clone()),
)?);
Ok(Transformed::yes(
node.with_new_exprs(node.expressions(), [new_plan].to_vec())?,
))
}
}

// adding new plan here
LogicalPlan::Projection(Projection::try_new(
project_exprs,
Arc::new(node.inputs()[0].clone()),
)?)
} else {
child
};
let mut expressions = node.expressions();
let available_columns = child
.schema()
.fields()
.iter()
.map(|field| field.qualified_column())
.collect::<Vec<_>>();
// Replace expressions with their pre-computed variants if available.
expressions.iter_mut().try_for_each(|expr| {
Self::update_expr_with_available_columns(expr, &available_columns)
})?;
let new_node = node.with_new_exprs(expressions, [child].to_vec())?;
Ok(Transformed::yes(new_node))
}
}
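Putting the pieces together, the intended effect is roughly the following (a hypothetical plan sketch with invented table and column names, not an actual plan produced by this PR): a complex expression such as a * b used by more than one "safe" operator is evaluated once in an injected projection, and the operators above read the pre-computed column instead of recomputing it.

    Before (conceptually):
        WindowAggr: avg(a * b)
          WindowAggr: sum(a * b)
            TableScan: t

    After (conceptually):
        WindowAggr: avg("a * b")
          WindowAggr: sum("a * b")
            Projection: a * b AS "a * b", t.a, t.b
              TableScan: t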
#[cfg(test)]