Skip to content

Commit

Permalink
refactor code and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Mar 22, 2024
1 parent cc1d412 commit 1c64d23
Show file tree
Hide file tree
Showing 5 changed files with 265 additions and 63 deletions.
2 changes: 0 additions & 2 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ impl TreeNode for Expr {
| Expr::Wildcard {..}
| Expr::Placeholder (_) => vec![],
Expr::BinaryExpr(BinaryExpr { left, right, .. }) => {
println!("left is {:?}, right is {:?}", left, right);
vec![left.as_ref(), right.as_ref()]
}
Expr::Like(Like { expr, pattern, .. })
Expand Down Expand Up @@ -123,7 +122,6 @@ impl TreeNode for Expr {
let mut expr_vec = args.iter().collect::<Vec<_>>();
expr_vec.extend(partition_by);
expr_vec.extend(order_by);
println!("expressions vector is {:?}", expr_vec);
expr_vec
}
Expr::InList(InList { expr, list, .. }) => {
Expand Down
168 changes: 115 additions & 53 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::utils::is_volatile_expression;
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_expr::Expr;
use std::collections::{BTreeSet, HashMap};
use std::ops::Deref;
use std::slice::Windows;
use std::sync::Arc;

Expand Down Expand Up @@ -486,8 +487,8 @@ impl CommonSubexprEliminate {
original_schema
);
match optimized_plan {
Some(LogicalPlan::Projection(_)) => Ok(optimized_plan),
Some(optimized_plan) if optimized_plan.schema() != &original_schema => {
// add an additional projection if the output schema changed.
println!(
"************** \n optimized schema {:?} \n ************",
optimized_plan.schema(),
Expand All @@ -510,6 +511,7 @@ impl CommonSubexprEliminate {
let res = plan
.clone()
.rewrite(&mut ProjectionAdder {
data_type_map: HashMap::new(),
depth_map: HashMap::new(),
depth: 0,
})
Expand All @@ -525,11 +527,15 @@ impl OptimizerRule for CommonSubexprEliminate {
config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
let optimized_plan_option = self.common_optimize(plan, config)?;

// println!("optimized plan option is {:?}", optimized_plan_option);
let plan = match optimized_plan_option {
Some(plan) => plan,
_ => plan.clone(),
};
self.add_extra_projection(&plan)
let res = self.add_extra_projection(&plan);
println!("res is {:?}", res);
res
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -600,6 +606,7 @@ fn build_common_expr_project_plan(
"\n ********** expr is {:?} \n \n and data type is {:?} \n **************** \n",
expr, data_type
);
println!("********* \n {:?} ***********\n", expr);
let field = DFField::new_unqualified(&id, data_type.clone(), true);
fields_set.insert(field.name().to_owned());
project_exprs.push(expr.clone().alias(&id));
Expand All @@ -615,7 +622,11 @@ fn build_common_expr_project_plan(
project_exprs.push(Expr::Column(field.qualified_column()));
}
}

println!(
"************** \n exprs are {:?} \n ************** \n",
project_exprs
);
println!("********** \n input are {:?} \n ************** \n", input);
Ok(LogicalPlan::Projection(Projection::try_new(
project_exprs,
Arc::new(input),
Expand Down Expand Up @@ -974,54 +985,65 @@ fn replace_common_expr(
}

struct ProjectionAdder {
depth_map: HashMap<u128, HashSet<DFField>>,
depth_map: HashMap<u128, HashSet<Expr>>,
depth: u128,
data_type_map: HashMap<Expr, DataType>,
}

impl ProjectionAdder {
// TODO: support more expression kinds for subqueries; currently only simple binary expressions are supported
fn get_complex_expressions(
exprs: Vec<Expr>,
schema: DFSchemaRef,
) -> HashSet<DFField> {
) -> (HashSet<Expr>, HashMap<Expr, DataType>) {
let mut res = HashSet::new();
let mut expr_data_type: HashMap<Expr, DataType> = HashMap::new();
println!(
"********** \n current schema is {:?} \n ******** \n",
schema
);
for expr in exprs {
println!("current expr is {:?}", expr);
match expr {
Expr::BinaryExpr(BinaryExpr {
left: ref l_box,
op: _,
right: ref r_box,
}) => {
match (&**l_box, &**r_box) {
(Expr::Column(l), Expr::Column(r)) => {
// Both `left` and `right` are `Expr::Column`, so we push to `res`
if schema.has_column(l) && schema.has_column(r) {
res.insert(DFField::new_unqualified(
&expr.to_string(),
schema
.field_from_column(l)
.unwrap()
.data_type()
.clone(),
true,
));
}
}
// If they are not both `Expr::Column`, you can handle other cases or do nothing
_ => {}
}) => match (&**l_box, &**r_box) {
(Expr::Column(l), Expr::Column(r)) => {
let l_field = schema
.field_from_column(l)
.expect("Field not found for left column");

// res.insert(DFField::new_unqualified(
// &expr.to_string(),
// l_field.data_type().clone(),
// true,
// ));
expr_data_type
.entry(expr.clone())
.or_insert(l_field.data_type().clone());
res.insert(expr.clone());
}
}
Expr::Cast(_) => {
res.extend(Self::get_complex_expressions(vec![expr], schema.clone()))
_ => {}
},
Expr::Cast(Cast { expr, data_type }) => {
let (expr_set, type_map) =
Self::get_complex_expressions(vec![*expr], schema.clone());
res.extend(expr_set);
expr_data_type.extend(type_map);
}

Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
res.extend(Self::get_complex_expressions(args, schema.clone()))
let (expr_set, type_map) =
Self::get_complex_expressions(args, schema.clone());
res.extend(expr_set);
expr_data_type.extend(type_map);
}
_ => {}
}
}
res
(res, expr_data_type)
}
}
impl TreeNodeRewriter for ProjectionAdder {
Expand All @@ -1037,57 +1059,97 @@ impl TreeNodeRewriter for ProjectionAdder {
println!("********* \n exprs are {:?} \n *********** \n", exprs);
let depth_set = self.depth_map.entry(self.depth).or_default();
println!("self.input schema is {:?}", node.schema());
depth_set.extend(Self::get_complex_expressions(exprs, node.schema().clone()));
let mut schema = node.schema().deref().clone();
for ip in node.inputs() {
schema.merge(ip.schema());
}
let (extended_set, data_map) =
Self::get_complex_expressions(exprs, Arc::new(schema));
println!(
"*********** \n extened set is {:?} \n ********** \n",
extended_set
);
depth_set.extend(extended_set);
self.data_type_map.extend(data_map);
Ok(Transformed::no(node))
}
fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
self.depth -= 1;
println!("*********** \ncur plan is {:?} \n*************\n", node);
println!("*********** \n cur plan is {:?} \n*************\n", node);

let current_depth_schema =
self.depth_map.get(&self.depth).cloned().unwrap_or_default();

// get the intersected part
let added_schema = self
let added_expr = self
.depth_map
.iter()
.filter(|(&depth, _)| depth < self.depth)
.fold(current_depth_schema, |acc, (_, fields)| {
acc.intersection(fields).cloned().collect()
.fold(current_depth_schema, |acc, (_, expr)| {
acc.intersection(expr).cloned().collect()
});

println!(
"******** \n intersected parts are {:?} \n ***********",
added_expr
);
println!(
"******** \n depth map is {:?} \n ***********",
self.depth_map
);
println!(
"************ \n data type map is {:?} \n *********** \n",
self.data_type_map
);
self.depth -= 1;
// do not do extra things
if added_schema.is_empty() {
if added_expr.is_empty() {
return Ok(Transformed::no(node));
}

println!("\n*************\n{:?}\n*************\n", added_schema);
println!("\n*************\n{:?}\n*************\n", added_expr);

match node {
// do not add for Projections
LogicalPlan::Projection(_) => Ok(Transformed::no(node)),
LogicalPlan::Projection(_) | LogicalPlan::TableScan(_) => {
Ok(Transformed::no(node))
}
_ => {
let mut col_exprs: Vec<Expr> = node
.schema()
.fields()
.iter()
.map(|field| Expr::Column(field.qualified_column()))
.collect();
// avoid recursively adding projections
if added_expr.iter().any(|expr| {
node.inputs()[0]
.schema()
.has_column_with_unqualified_name(&expr.to_string())
}) {
return Ok(Transformed::no(node));
}

col_exprs.extend(
added_schema
.iter()
.map(|field| Expr::Column(field.qualified_column())),
let mut field_set = HashSet::new();
let mut project_exprs = vec![];
for expr in added_expr {
let f = DFField::new_unqualified(
&expr.to_string(),
self.data_type_map[&expr].clone(),
true,
);
field_set.insert(f.name().to_owned());
project_exprs.push(expr.clone().alias(expr.to_string()));
}
for field in node.inputs()[0].schema().fields() {
if field_set.insert(field.qualified_name()) {
project_exprs.push(Expr::Column(field.qualified_column()));
}
}
println!(
"************* \n project expressions are {:?} \n ********** \n",
project_exprs
);

// adding new plan here
let new_plan = LogicalPlan::Projection(Projection::try_new(
col_exprs.clone(),
Arc::new(node.clone()),
project_exprs,
Arc::new(node.inputs()[0].clone()),
)?);

println!("new plan is {:?}", new_plan);
Ok(Transformed::yes(
node.with_new_exprs(col_exprs, [new_plan].to_vec())?,
node.with_new_exprs(node.expressions(), [new_plan].to_vec())?,
))
}
}
Expand Down
Loading

0 comments on commit 1c64d23

Please sign in to comment.