Skip to content

Commit

Permalink
Extract result of find_common_exprs into a struct (#4)
Browse files Browse the repository at this point in the history
* Extract the result of find_common_exprs into a struct

* Make naming consistent
  • Loading branch information
alamb authored Aug 8, 2024
1 parent 6a62811 commit 5fa5457
Showing 1 changed file with 82 additions and 44 deletions.
126 changes: 82 additions & 44 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,22 @@ pub struct CommonSubexprEliminate {
random_state: RandomState,
}

type FindCommonExprResult = Option<(Vec<(Expr, String)>, Vec<Vec<Expr>>)>;
/// The result of potentially rewriting a list of expressions to eliminate common
/// subexpressions.
#[derive(Debug)]
enum FoundCommonExprs {
/// No common expressions were found
No { original_exprs_list: Vec<Vec<Expr>> },
/// Common expressions were found
Yes {
/// extracted common expressions
common_exprs: Vec<(Expr, String)>,
/// new expressions with common subexpressions replaced
new_exprs_list: Vec<Vec<Expr>>,
/// original expressions
original_exprs_list: Vec<Vec<Expr>>,
},
}

impl CommonSubexprEliminate {
pub fn new() -> Self {
Expand Down Expand Up @@ -242,16 +257,13 @@ impl CommonSubexprEliminate {

/// Extracts common sub-expressions and rewrites `exprs_list`.
///
/// Returns a tuple of:
/// 1. The rewritten expressions
/// 2. An optional tuple that contains the extracted common sub-expressions and the
/// original `exprs_list`.
/// Returns `FoundCommonExprs` recording the result of the extraction
fn find_common_exprs(
&self,
exprs_list: Vec<Vec<Expr>>,
config: &dyn OptimizerConfig,
expr_mask: ExprMask,
) -> Result<Transformed<(Vec<Vec<Expr>>, FindCommonExprResult)>> {
) -> Result<Transformed<FoundCommonExprs>> {
let mut found_common = false;
let mut expr_stats = ExprStats::new();
let id_arrays_list = exprs_list
Expand Down Expand Up @@ -279,12 +291,15 @@ impl CommonSubexprEliminate {
)?;
assert!(!common_exprs.is_empty());

Ok(Transformed::yes((
Ok(Transformed::yes(FoundCommonExprs::Yes {
common_exprs: common_exprs.into_values().collect(),
new_exprs_list,
Some((common_exprs.into_values().collect(), exprs_list)),
)))
original_exprs_list: exprs_list,
}))
} else {
Ok(Transformed::no((exprs_list, None)))
Ok(Transformed::no(FoundCommonExprs::No {
original_exprs_list: exprs_list,
}))
}
}

Expand Down Expand Up @@ -356,17 +371,22 @@ impl CommonSubexprEliminate {

// Extract common sub-expressions from the list.
self.find_common_exprs(window_expr_list, config, ExprMask::Normal)?
.map_data(|(new_window_expr_list, common)| match common {
.map_data(|common| match common {
// If there are common sub-expressions, then the insert a projection node
// with the common expressions between the new window nodes and the
// original input.
Some((common_exprs, window_expr_list)) => {
FoundCommonExprs::Yes {
common_exprs,
new_exprs_list,
original_exprs_list,
} => {
build_common_expr_project_plan(input, common_exprs).map(|new_input| {
(new_window_expr_list, new_input, Some(window_expr_list))
(new_exprs_list, new_input, Some(original_exprs_list))
})
}

None => Ok((new_window_expr_list, input, None)),
FoundCommonExprs::No {
original_exprs_list,
} => Ok((original_exprs_list, input, None)),
})?
// Recurse into the new input.
// (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.)
Expand Down Expand Up @@ -441,19 +461,22 @@ impl CommonSubexprEliminate {
let input = unwrap_arc(input);
// Extract common sub-expressions from the aggregate and grouping expressions.
self.find_common_exprs(vec![group_expr, aggr_expr], config, ExprMask::Normal)?
.map_data(|(mut new_expr_list, common)| {
let new_aggr_expr = new_expr_list.pop().unwrap();
let new_group_expr = new_expr_list.pop().unwrap();

.map_data(|common| {
match common {
// If there are common sub-expressions, then insert a projection node
// with the common expressions between the new aggregate node and the
// original input.
Some((common_exprs, mut expr_list)) => {
FoundCommonExprs::Yes {
common_exprs,
mut new_exprs_list,
mut original_exprs_list,
} => {
let new_aggr_expr = new_exprs_list.pop().unwrap();
let new_group_expr = new_exprs_list.pop().unwrap();

build_common_expr_project_plan(input, common_exprs).map(
|new_input| {
let aggr_expr = expr_list.pop().unwrap();

let aggr_expr = original_exprs_list.pop().unwrap();
(
new_aggr_expr,
new_group_expr,
Expand All @@ -464,7 +487,14 @@ impl CommonSubexprEliminate {
)
}

None => Ok((new_aggr_expr, new_group_expr, input, None)),
FoundCommonExprs::No {
mut original_exprs_list,
} => {
let new_aggr_expr = original_exprs_list.pop().unwrap();
let new_group_expr = original_exprs_list.pop().unwrap();

Ok((new_aggr_expr, new_group_expr, input, None))
}
}
})?
// Recurse into the new input.
Expand All @@ -487,16 +517,17 @@ impl CommonSubexprEliminate {
config,
ExprMask::NormalAndAggregates,
)?
.map_data(|(mut new_aggr_list, common)| {
let rewritten_aggr_expr = new_aggr_list.pop().unwrap();

.map_data(|common| {
match common {
// If there are common aggregate sub-expressions, then insert a
// projection above the new rebuilt aggregate node.
Some((common_aggr_exprs, mut aggr_list)) => {
let new_aggr_expr = aggr_list.pop().unwrap();
FoundCommonExprs::Yes {
common_exprs,
mut new_exprs_list,
mut original_exprs_list,
} => {
let rewritten_aggr_expr = new_exprs_list.pop().unwrap();
let new_aggr_expr = original_exprs_list.pop().unwrap();

let mut agg_exprs = common_aggr_exprs
let mut agg_exprs = common_exprs
.into_iter()
.map(|(expr, expr_alias)| expr.alias(expr_alias))
.collect::<Vec<_>>();
Expand Down Expand Up @@ -552,7 +583,11 @@ impl CommonSubexprEliminate {

// If there aren't any common aggregate sub-expressions, then just
// rebuild the aggregate node.
None => {
FoundCommonExprs::No {
mut original_exprs_list,
} => {
let rewritten_aggr_expr = original_exprs_list.pop().unwrap();

// If there were common expressions extracted, then we need to
// make sure we restore the original column names.
// TODO: Although `find_common_exprs()` inserts aliases around
Expand Down Expand Up @@ -622,18 +657,21 @@ impl CommonSubexprEliminate {
) -> Result<Transformed<(Vec<Expr>, LogicalPlan)>> {
// Extract common sub-expressions from the expressions.
self.find_common_exprs(vec![exprs], config, ExprMask::Normal)?
.map_data(|(mut new_exprs_list, common)| {
let new_exprs = new_exprs_list.pop().unwrap();

match common {
// If there are common sub-expressions, then insert a projection node
// with the common expressions between the original node and the
// original input.
Some((common_exprs, _)) => {
build_common_expr_project_plan(input, common_exprs)
.map(|new_input| (new_exprs, new_input))
}
None => Ok((new_exprs, input)),
.map_data(|common| match common {
FoundCommonExprs::Yes {
common_exprs,
mut new_exprs_list,
original_exprs_list: _,
} => {
let new_exprs = new_exprs_list.pop().unwrap();
build_common_expr_project_plan(input, common_exprs)
.map(|new_input| (new_exprs, new_input))
}
FoundCommonExprs::No {
mut original_exprs_list,
} => {
let new_exprs = original_exprs_list.pop().unwrap();
Ok((new_exprs, input))
}
})?
// Recurse into the new input.
Expand Down

0 comments on commit 5fa5457

Please sign in to comment.