From 786c319d62ee76ce7acfab9b5e38a4749e3aae24 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 3 Sep 2022 06:36:18 -0600 Subject: [PATCH] Re-implement ExprIdentifierVisitor::desc_expr to use Expr::Display (#3339) --- datafusion/core/tests/sql/predicates.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 207 ++---------------- 2 files changed, 22 insertions(+), 187 deletions(-) diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 93aa8c3fba83..f7bdc41a93a2 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -427,7 +427,7 @@ async fn multiple_or_predicates() -> Result<()> { let expected =vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: #lineitem.l_partkey [l_partkey:Int64]", - " Projection: #part.p_partkey = #lineitem.l_partkey AS BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]", + " Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]", " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND #lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int64(30) AND #part.p_size BETWEEN Int64(1) 
AND Int64(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]", " TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]", diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 2d148a087449..239939f81d66 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -22,7 +22,6 @@ use arrow::datatypes::DataType; use datafusion_common::{DFField, DFSchema, DataFusionError, Result}; use datafusion_expr::{ col, - expr::GroupingSet, expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}, expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}, logical_plan::{Aggregate, Filter, LogicalPlan, Projection, Sort, Window}, @@ -30,7 +29,6 @@ use datafusion_expr::{ Expr, ExprSchemable, }; use std::collections::{HashMap, HashSet}; -use std::fmt::Write; use std::sync::Arc; /// A map from expression's identifier to tuple including @@ -392,171 +390,7 @@ enum VisitRecord { impl ExprIdentifierVisitor<'_> { fn desc_expr(expr: &Expr) -> String { - let mut desc = String::new(); - match expr { - Expr::Column(column) => { - desc.push_str("Column-"); - desc.push_str(&column.flat_name()); - } - Expr::ScalarVariable(_, var_names) => { - desc.push_str("ScalarVariable-"); - desc.push_str(&var_names.join(".")); - } - Expr::Alias(_, alias) => { - desc.push_str("Alias-"); - desc.push_str(alias); - } - Expr::Literal(value) => { - desc.push_str("Literal"); - desc.push_str(&value.to_string()); - } - Expr::BinaryExpr { op, .. 
} => { - desc.push_str("BinaryExpr-"); - desc.push_str(&op.to_string()); - } - Expr::Not(_) => { - desc.push_str("Not-"); - } - Expr::IsNotNull(_) => { - desc.push_str("IsNotNull-"); - } - Expr::IsNull(_) => { - desc.push_str("IsNull-"); - } - Expr::IsTrue(_) => { - desc.push_str("IsTrue-"); - } - Expr::IsFalse(_) => { - desc.push_str("IsFalse-"); - } - Expr::IsUnknown(_) => { - desc.push_str("IsUnknown-"); - } - Expr::IsNotTrue(_) => { - desc.push_str("IsNotTrue-"); - } - Expr::IsNotFalse(_) => { - desc.push_str("IsNotFalse-"); - } - Expr::IsNotUnknown(_) => { - desc.push_str("IsNotUnknown-"); - } - Expr::Negative(_) => { - desc.push_str("Negative-"); - } - Expr::Between { negated, .. } => { - desc.push_str("Between-"); - desc.push_str(&negated.to_string()); - } - Expr::Like { negated, .. } => { - desc.push_str("Like-"); - desc.push_str(&negated.to_string()); - } - Expr::ILike { negated, .. } => { - desc.push_str("ILike-"); - desc.push_str(&negated.to_string()); - } - Expr::SimilarTo { negated, .. } => { - desc.push_str("SimilarTo-"); - desc.push_str(&negated.to_string()); - } - Expr::Case { .. } => { - desc.push_str("Case-"); - } - Expr::Cast { data_type, .. } => { - desc.push_str("Cast-"); - let _ = write!(desc, "{:?}", data_type); - } - Expr::TryCast { data_type, .. } => { - desc.push_str("TryCast-"); - let _ = write!(desc, "{:?}", data_type); - } - Expr::Sort { - asc, nulls_first, .. - } => { - desc.push_str("Sort-"); - let _ = write!(desc, "{}{}", asc, nulls_first); - } - Expr::ScalarFunction { fun, .. } => { - desc.push_str("ScalarFunction-"); - desc.push_str(&fun.to_string()); - } - Expr::ScalarUDF { fun, .. } => { - desc.push_str("ScalarUDF-"); - desc.push_str(&fun.name); - } - Expr::WindowFunction { - fun, window_frame, .. - } => { - desc.push_str("WindowFunction-"); - desc.push_str(&fun.to_string()); - let _ = write!(desc, "{:?}", window_frame); - } - Expr::AggregateFunction { fun, distinct, .. 
} => { - desc.push_str("AggregateFunction-"); - desc.push_str(&fun.to_string()); - desc.push_str(&distinct.to_string()); - } - Expr::AggregateUDF { fun, .. } => { - desc.push_str("AggregateUDF-"); - desc.push_str(&fun.name); - } - Expr::InList { negated, .. } => { - desc.push_str("InList-"); - desc.push_str(&negated.to_string()); - } - Expr::Exists { negated, .. } => { - desc.push_str("Exists-"); - desc.push_str(&negated.to_string()); - } - Expr::InSubquery { negated, .. } => { - desc.push_str("InSubquery-"); - desc.push_str(&negated.to_string()); - } - Expr::ScalarSubquery(_) => { - desc.push_str("ScalarSubquery-"); - } - Expr::Wildcard => { - desc.push_str("Wildcard-"); - } - Expr::QualifiedWildcard { qualifier } => { - desc.push_str("QualifiedWildcard-"); - desc.push_str(qualifier); - } - Expr::GetIndexedField { key, .. } => { - desc.push_str("GetIndexedField-"); - desc.push_str(&key.to_string()); - } - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => { - desc.push_str("Rollup"); - for expr in exprs { - desc.push('-'); - desc.push_str(&Self::desc_expr(expr)); - } - } - GroupingSet::Cube(exprs) => { - desc.push_str("Cube"); - for expr in exprs { - desc.push('-'); - desc.push_str(&Self::desc_expr(expr)); - } - } - GroupingSet::GroupingSets(lists_of_exprs) => { - desc.push_str("GroupingSets"); - for exprs in lists_of_exprs { - desc.push('('); - for expr in exprs { - desc.push('-'); - desc.push_str(&Self::desc_expr(expr)); - } - desc.push(')'); - } - } - }, - } - - desc + format!("{}", expr) } /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` @@ -749,13 +583,13 @@ mod test { }; use std::iter; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) { let optimizer = CommonSubexprEliminate {}; let optimized_plan = optimizer .optimize(plan, &mut OptimizerConfig::new()) .expect("failed to optimize plan"); let formatted_plan = 
format!("{:?}", optimized_plan); - assert_eq!(formatted_plan, expected); + assert_eq!(expected, formatted_plan); } #[test] @@ -774,19 +608,20 @@ mod test { expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, DataType::Int64)?; let expected = vec![ - (9, "BinaryExpr-*Literal2BinaryExpr--AggregateFunction-AVGfalseColumn-cAggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"), - (7, "BinaryExpr--AggregateFunction-AVGfalseColumn-cAggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"), - (4, "AggregateFunction-SUMfalseBinaryExpr-+Literal1Column-a"), (3, "BinaryExpr-+Literal1Column-a"), + (9, "SUM(#a + Utf8(\"1\")) - AVG(#c) * Int32(2)Int32(2)SUM(#a + Utf8(\"1\")) - AVG(#c)AVG(#c)#cSUM(#a + Utf8(\"1\"))#a + Utf8(\"1\")Utf8(\"1\")#a"), + (7, "SUM(#a + Utf8(\"1\")) - AVG(#c)AVG(#c)#cSUM(#a + Utf8(\"1\"))#a + Utf8(\"1\")Utf8(\"1\")#a"), + (4, "SUM(#a + Utf8(\"1\"))#a + Utf8(\"1\")Utf8(\"1\")#a"), + (3, "#a + Utf8(\"1\")Utf8(\"1\")#a"), (1, ""), (2, ""), - (6, "AggregateFunction-AVGfalseColumn-c"), + (6, "AVG(#c)#c"), (5, ""), - (8, ""), + (8, "") ] .into_iter() .map(|(number, id)| (number, id.into())) .collect::<Vec<_>>(); - assert_eq!(id_array, expected); + assert_eq!(expected, id_array); Ok(()) } @@ -825,11 +660,11 @@ mod test { )? 
.build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a AS test.a * Int32(1) - test.b), SUM(#BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a AS test.a * Int32(1) - test.b * Int32(1) + #test.c)]]\ - \n Projection: #test.a * Int32(1) - #test.b AS BinaryExpr-*BinaryExpr--Column-test.bLiteral1Column-test.a, #test.a, #test.b, #test.c\ + let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(##test.a * Int32(1) - #test.bInt32(1) - #test.b#test.bInt32(1)#test.a AS test.a * Int32(1) - test.b), SUM(##test.a * Int32(1) - #test.bInt32(1) - #test.b#test.bInt32(1)#test.a AS test.a * Int32(1) - test.b * Int32(1) + #test.c)]]\ + \n Projection: #test.a * Int32(1) - #test.b AS #test.a * Int32(1) - #test.bInt32(1) - #test.b#test.bInt32(1)#test.a, #test.a, #test.b, #test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(expected, &plan); Ok(()) } @@ -848,11 +683,11 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[Int32(1) + #AggregateFunction-AVGfalseColumn-test.a AS AVG(test.a), Int32(1) - #AggregateFunction-AVGfalseColumn-test.a AS AVG(test.a)]]\ - \n Projection: AVG(#test.a) AS AggregateFunction-AVGfalseColumn-test.a, #test.a, #test.b, #test.c\ + let expected = "Aggregate: groupBy=[[]], aggr=[[Int32(1) + #AVG(#test.a)#test.a AS AVG(test.a), Int32(1) - #AVG(#test.a)#test.a AS AVG(test.a)]]\ + \n Projection: AVG(#test.a) AS AVG(#test.a)#test.a, #test.a, #test.b, #test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(expected, &plan); Ok(()) } @@ -868,11 +703,11 @@ mod test { ])? 
.build()?; - let expected = "Projection: #BinaryExpr-+Column-test.aLiteral1 AS Int32(1) + test.a AS first, #BinaryExpr-+Column-test.aLiteral1 AS Int32(1) + test.a AS second\ - \n Projection: Int32(1) + #test.a AS BinaryExpr-+Column-test.aLiteral1, #test.a, #test.b, #test.c\ + let expected = "Projection: #Int32(1) + #test.a#test.aInt32(1) AS Int32(1) + test.a AS first, #Int32(1) + #test.a#test.aInt32(1) AS Int32(1) + test.a AS second\ + \n Projection: Int32(1) + #test.a AS Int32(1) + #test.a#test.aInt32(1), #test.a, #test.b, #test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(expected, &plan); Ok(()) } @@ -891,7 +726,7 @@ mod test { let expected = "Projection: Int32(1) + #test.a, #test.a + Int32(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(expected, &plan); Ok(()) } @@ -909,7 +744,7 @@ mod test { \n Projection: Int32(1) + #test.a\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(expected, &plan); Ok(()) }