Skip to content

Commit

Permalink
Re-implement ExprIdentifierVisitor::desc_expr to use Expr::Display (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Sep 3, 2022
1 parent e26452a commit 786c319
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 187 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ async fn multiple_or_predicates() -> Result<()> {
let expected =vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND #lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int64(30) AND #part.p_size BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
Expand Down
207 changes: 21 additions & 186 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, DataFusionError, Result};
use datafusion_expr::{
col,
expr::GroupingSet,
expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion},
expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion},
logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window},
utils::from_plan,
Expr, ExprSchemable,
};
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::sync::Arc;

/// A map from expression's identifier to tuple including
Expand Down Expand Up @@ -392,171 +390,7 @@ enum VisitRecord {

impl ExprIdentifierVisitor<'_> {
fn desc_expr(expr: &Expr) -> String {
let mut desc = String::new();
match expr {
Expr::Column(column) => {
desc.push_str("Column-");
desc.push_str(&column.flat_name());
}
Expr::ScalarVariable(_, var_names) => {
desc.push_str("ScalarVariable-");
desc.push_str(&var_names.join("."));
}
Expr::Alias(_, alias) => {
desc.push_str("Alias-");
desc.push_str(alias);
}
Expr::Literal(value) => {
desc.push_str("Literal");
desc.push_str(&value.to_string());
}
Expr::BinaryExpr { op, .. } => {
desc.push_str("BinaryExpr-");
desc.push_str(&op.to_string());
}
Expr::Not(_) => {
desc.push_str("Not-");
}
Expr::IsNotNull(_) => {
desc.push_str("IsNotNull-");
}
Expr::IsNull(_) => {
desc.push_str("IsNull-");
}
Expr::IsTrue(_) => {
desc.push_str("IsTrue-");
}
Expr::IsFalse(_) => {
desc.push_str("IsFalse-");
}
Expr::IsUnknown(_) => {
desc.push_str("IsUnknown-");
}
Expr::IsNotTrue(_) => {
desc.push_str("IsNotTrue-");
}
Expr::IsNotFalse(_) => {
desc.push_str("IsNotFalse-");
}
Expr::IsNotUnknown(_) => {
desc.push_str("IsNotUnknown-");
}
Expr::Negative(_) => {
desc.push_str("Negative-");
}
Expr::Between { negated, .. } => {
desc.push_str("Between-");
desc.push_str(&negated.to_string());
}
Expr::Like { negated, .. } => {
desc.push_str("Like-");
desc.push_str(&negated.to_string());
}
Expr::ILike { negated, .. } => {
desc.push_str("ILike-");
desc.push_str(&negated.to_string());
}
Expr::SimilarTo { negated, .. } => {
desc.push_str("SimilarTo-");
desc.push_str(&negated.to_string());
}
Expr::Case { .. } => {
desc.push_str("Case-");
}
Expr::Cast { data_type, .. } => {
desc.push_str("Cast-");
let _ = write!(desc, "{:?}", data_type);
}
Expr::TryCast { data_type, .. } => {
desc.push_str("TryCast-");
let _ = write!(desc, "{:?}", data_type);
}
Expr::Sort {
asc, nulls_first, ..
} => {
desc.push_str("Sort-");
let _ = write!(desc, "{}{}", asc, nulls_first);
}
Expr::ScalarFunction { fun, .. } => {
desc.push_str("ScalarFunction-");
desc.push_str(&fun.to_string());
}
Expr::ScalarUDF { fun, .. } => {
desc.push_str("ScalarUDF-");
desc.push_str(&fun.name);
}
Expr::WindowFunction {
fun, window_frame, ..
} => {
desc.push_str("WindowFunction-");
desc.push_str(&fun.to_string());
let _ = write!(desc, "{:?}", window_frame);
}
Expr::AggregateFunction { fun, distinct, .. } => {
desc.push_str("AggregateFunction-");
desc.push_str(&fun.to_string());
desc.push_str(&distinct.to_string());
}
Expr::AggregateUDF { fun, .. } => {
desc.push_str("AggregateUDF-");
desc.push_str(&fun.name);
}
Expr::InList { negated, .. } => {
desc.push_str("InList-");
desc.push_str(&negated.to_string());
}
Expr::Exists { negated, .. } => {
desc.push_str("Exists-");
desc.push_str(&negated.to_string());
}
Expr::InSubquery { negated, .. } => {
desc.push_str("InSubquery-");
desc.push_str(&negated.to_string());
}
Expr::ScalarSubquery(_) => {
desc.push_str("ScalarSubquery-");
}
Expr::Wildcard => {
desc.push_str("Wildcard-");
}
Expr::QualifiedWildcard { qualifier } => {
desc.push_str("QualifiedWildcard-");
desc.push_str(qualifier);
}
Expr::GetIndexedField { key, .. } => {
desc.push_str("GetIndexedField-");
desc.push_str(&key.to_string());
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
desc.push_str("Rollup");
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
}
GroupingSet::Cube(exprs) => {
desc.push_str("Cube");
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
}
GroupingSet::GroupingSets(lists_of_exprs) => {
desc.push_str("GroupingSets");
for exprs in lists_of_exprs {
desc.push('(');
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
desc.push(')');
}
}
},
}

desc
format!("{}", expr)
}

/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
Expand Down Expand Up @@ -749,13 +583,13 @@ mod test {
};
use std::iter;

fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) {
let optimizer = CommonSubexprEliminate {};
let optimized_plan = optimizer
.optimize(plan, &mut OptimizerConfig::new())
.expect("failed to optimize plan");
let formatted_plan = format!("{:?}", optimized_plan);
assert_eq!(formatted_plan, expected);
assert_eq!(expected, formatted_plan);
}

#[test]
Expand All @@ -774,19 +608,20 @@ mod test {
expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, DataType::Int64)?;

let expected = vec![
(9, "BinaryExpr-*Literal2BinaryExpr--AggregateFunction-AVGfalseColumn-cAggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"),
(7, "BinaryExpr--AggregateFunction-AVGfalseColumn-cAggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"),
(4, "AggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"), (3, "BinaryExpr-+Literal1Column-a"),
(9, "SUM(#a + Utf8(\"1\")) - AVG(#c) * Int32(2)Int32(2)SUM(#a + Utf8(\"1\")) - AVG(#c)AVG(#c)#cSUM(#a + Utf8(\"1\"))#a + Utf8(\"1\")Utf8(\"1\")#a"),
(7, "SUM(#a + Utf8(\"1\")) - AVG(#c)AVG(#c)#cSUM(#a + Utf8(\"1\"))#a + Utf8(\"1\")Utf8(\"1\")#a"),
(4, "SUM(#a + Utf8(\"1\"))#a + Utf8(\"1\")Utf8(\"1\")#a"),
(3, "#a + Utf8(\"1\")Utf8(\"1\")#a"),
(1, ""),
(2, ""),
(6, "AggregateFunction-AVGfalseColumn-c"),
(6, "AVG(#c)#c"),
(5, ""),
(8, ""),
(8, "")
]
.into_iter()
.map(|(number, id)| (number, id.into()))
.collect::<Vec<_>>();
assert_eq!(id_array, expected);
assert_eq!(expected, id_array);

Ok(())
}
Expand Down Expand Up @@ -825,11 +660,11 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a AS test.a * Int32(1) - test.b), SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a AS test.a * Int32(1) - test.b * Int32(1) + #test.c)]]\
\n Projection: #test.a * Int32(1) - #test.b AS BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a, #test.a, #test.b, #test.c\
let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(##test.a * Int32(1) - #test.bInt32(1) - #test.b#test.bInt32(1)#test.a AS test.a * Int32(1) - test.b), SUM(##test.a * Int32(1) - #test.bInt32(1) - #test.b#test.bInt32(1)#test.a AS test.a * Int32(1) - test.b * Int32(1) + #test.c)]]\
\n Projection: #test.a * Int32(1) - #test.b AS #test.a * Int32(1) - #test.bInt32(1) - #test.b#test.bInt32(1)#test.a, #test.a, #test.b, #test.c\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
assert_optimized_plan_eq(expected, &plan);

Ok(())
}
Expand All @@ -848,11 +683,11 @@ mod test {
)?
.build()?;

let expected = "Aggregate: groupBy=[[]], aggr=[[Int32(1) + #AggregateFunction-AVGfalseColumn-test.a AS AVG(test.a), Int32(1) - #AggregateFunction-AVGfalseColumn-test.a AS AVG(test.a)]]\
\n Projection: AVG(#test.a) AS AggregateFunction-AVGfalseColumn-test.a, #test.a, #test.b, #test.c\
let expected = "Aggregate: groupBy=[[]], aggr=[[Int32(1) + #AVG(#test.a)#test.a AS AVG(test.a), Int32(1) - #AVG(#test.a)#test.a AS AVG(test.a)]]\
\n Projection: AVG(#test.a) AS AVG(#test.a)#test.a, #test.a, #test.b, #test.c\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
assert_optimized_plan_eq(expected, &plan);

Ok(())
}
Expand All @@ -868,11 +703,11 @@ mod test {
])?
.build()?;

let expected = "Projection: #BinaryExpr-+Column-test.aLiteral1 AS Int32(1) + test.a AS first, #BinaryExpr-+Column-test.aLiteral1 AS Int32(1) + test.a AS second\
\n Projection: Int32(1) + #test.a AS BinaryExpr-+Column-test.aLiteral1, #test.a, #test.b, #test.c\
let expected = "Projection: #Int32(1) + #test.a#test.aInt32(1) AS Int32(1) + test.a AS first, #Int32(1) + #test.a#test.aInt32(1) AS Int32(1) + test.a AS second\
\n Projection: Int32(1) + #test.a AS Int32(1) + #test.a#test.aInt32(1), #test.a, #test.b, #test.c\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
assert_optimized_plan_eq(expected, &plan);

Ok(())
}
Expand All @@ -891,7 +726,7 @@ mod test {
let expected = "Projection: Int32(1) + #test.a, #test.a + Int32(1)\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
assert_optimized_plan_eq(expected, &plan);

Ok(())
}
Expand All @@ -909,7 +744,7 @@ mod test {
\n Projection: Int32(1) + #test.a\
\n TableScan: test";

assert_optimized_plan_eq(&plan, expected);
assert_optimized_plan_eq(expected, &plan);

Ok(())
}
Expand Down

0 comments on commit 786c319

Please sign in to comment.