Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgeny Maruschenko committed Oct 9, 2023
1 parent ace9c6d commit bd6c4d4
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 54 deletions.
125 changes: 95 additions & 30 deletions datafusion/optimizer/src/eliminate_nested_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@
//! Optimizer rule to replace nested unions to single union.
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::Result;
use datafusion_expr::{
builder::project_with_column_index,
expr_rewriter::coerce_plan_expr_for_schema,
logical_plan::{LogicalPlan, Projection, Union},
};
use datafusion_expr::logical_plan::{LogicalPlan, Union};

use crate::optimizer::ApplyOrder;
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
use std::sync::Arc;

#[derive(Default)]
Expand All @@ -44,36 +41,27 @@ impl OptimizerRule for EliminateNestedUnion {
plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
// TODO: Add optimization for nested distinct unions.
match plan {
LogicalPlan::Union(union) => {
let Union { inputs, schema } = union;

let union_schema = schema.clone();

LogicalPlan::Union(Union { inputs, schema }) => {
let inputs = inputs
.into_iter()
.flat_map(|plan| match Arc::as_ref(plan) {
LogicalPlan::Union(Union { inputs, .. }) => inputs.clone(),
_ => vec![Arc::clone(plan)],
})
.map(|plan| {
let plan = coerce_plan_expr_for_schema(&plan, &union_schema)?;
match plan {
LogicalPlan::Projection(Projection {
expr, input, ..
}) => Ok(Arc::new(project_with_column_index(
expr,
input,
union_schema.clone(),
)?)),
_ => Ok(Arc::new(plan)),
}
.flat_map(|plan| match plan.as_ref() {
LogicalPlan::Union(Union { inputs, schema }) => inputs
.into_iter()
.map(|plan| {
Arc::new(
coerce_plan_expr_for_schema(plan, schema).unwrap(),
)
})
.collect::<Vec<_>>(),
_ => vec![plan.clone()],
})
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();

Ok(Some(LogicalPlan::Union(Union {
inputs,
schema: union_schema,
schema: schema.clone(),
})))
}
_ => Ok(None),
Expand All @@ -94,13 +82,13 @@ mod tests {
use super::*;
use crate::test::*;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{col, logical_plan::table_scan};

fn schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Int32, false),
Field::new("value", DataType::Float64, false),
])
}

Expand Down Expand Up @@ -143,4 +131,81 @@ mod tests {
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

// We don't need to use project_with_column_index in logical optimizer,
// after LogicalPlanBuilder::union, we already have all equal expression aliases
#[test]
fn eliminate_nested_union_with_projection() -> Result<()> {
let plan_builder = table_scan(Some("table"), &schema(), None)?;

let plan = plan_builder
.clone()
.union(
plan_builder
.clone()
.project(vec![col("id").alias("table_id"), col("key"), col("value")])?
.build()?,
)?
.union(
plan_builder
.clone()
.project(vec![col("id").alias("_id"), col("key"), col("value")])?
.build()?,
)?
.build()?;

let expected = "Union\
\n TableScan: table\
\n Projection: table.id AS id, table.key, table.value\
\n TableScan: table\
\n Projection: table.id AS id, table.key, table.value\
\n TableScan: table";
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn eliminate_nested_union_with_type_cast_projection() -> Result<()> {
let table_1 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]),
None,
)?;

let table_2 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float32, false),
]),
None,
)?;

let table_3 = table_scan(
Some("table_1"),
&Schema::new(vec![
Field::new("id", DataType::Int16, false),
Field::new("key", DataType::Utf8, false),
Field::new("value", DataType::Float32, false),
]),
None,
)?;

let plan = table_1
.union(table_2.build()?)?
.union(table_3.build()?)?
.build()?;

let expected = "Union\
\n TableScan: table_1\
\n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
\n TableScan: table_1\
\n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
\n TableScan: table_1";
assert_optimized_plan_equal(&plan, expected)
}
}
6 changes: 2 additions & 4 deletions datafusion/optimizer/src/eliminate_one_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ impl OptimizerRule for EliminateOneUnion {
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
match plan {
LogicalPlan::Union(union) if union.inputs.len() == 1 => {
let Union { inputs, schema: _ } = union;

LogicalPlan::Union(Union { inputs, .. }) if inputs.len() == 1 => {
Ok(inputs.first().map(|input| input.as_ref().clone()))
}
_ => Ok(None),
Expand Down Expand Up @@ -103,7 +101,7 @@ mod tests {
}

#[test]
fn eliminate_nested_union() -> Result<()> {
fn eliminate_one_union() -> Result<()> {
let table_plan = coerce_plan_expr_for_schema(
&table_scan(Some("table"), &schema(), None)?.build()?,
&schema().to_dfschema()?,
Expand Down
5 changes: 3 additions & 2 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,14 @@ impl Optimizer {
/// Create a new optimizer using the recommended list of rules
pub fn new() -> Self {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
Arc::new(EliminateNestedUnion::new()),
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(ReplaceDistinctWithAggregate::new()),
Arc::new(EliminateJoin::new()),
Arc::new(DecorrelatePredicateSubquery::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Arc::new(ExtractEquijoinPredicate::new()),
Arc::new(EliminateNestedUnion::new()),
// simplify expressions does not simplify expressions in subqueries, so we
// run it again after running the optimizations that potentially converted
// subqueries to joins
Expand All @@ -242,9 +242,10 @@ impl Optimizer {
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateLimit::new()),
Arc::new(PropagateEmptyRelation::new()),
// Must be after PropagateEmptyRelation
Arc::new(EliminateOneUnion::new()),
Arc::new(FilterNullJoinKeys::default()),
Arc::new(EliminateOuterJoin::new()),
Arc::new(EliminateOneUnion::new()),
// Filters can't be pushed down past Limits, we should do PushDownFilter after PushDownLimit
Arc::new(PushDownLimit::new()),
Arc::new(PushDownFilter::new()),
Expand Down
18 changes: 0 additions & 18 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2053,24 +2053,6 @@ fn union_all() {
quick_test(sql, expected);
}

#[test]
fn union_4_combined_in_one() {
let sql = "SELECT order_id from orders
UNION ALL SELECT order_id FROM orders
UNION ALL SELECT order_id FROM orders
UNION ALL SELECT order_id FROM orders";
let expected = "Union\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id\
\n TableScan: orders";
quick_test(sql, expected);
}

#[test]
fn union_with_different_column_names() {
let sql = "SELECT order_id from orders UNION ALL SELECT customer_id FROM orders";
Expand Down
4 changes: 4 additions & 0 deletions datafusion/sqllogictest/test_files/explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ logical_plan after inline_table_scan SAME TEXT AS ABOVE
logical_plan after type_coercion SAME TEXT AS ABOVE
logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
analyzed_logical_plan SAME TEXT AS ABOVE
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
Expand All @@ -200,6 +201,7 @@ logical_plan after eliminate_cross_join SAME TEXT AS ABOVE
logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE
logical_plan after eliminate_limit SAME TEXT AS ABOVE
logical_plan after propagate_empty_relation SAME TEXT AS ABOVE
logical_plan after eliminate_one_union SAME TEXT AS ABOVE
logical_plan after filter_null_join_keys SAME TEXT AS ABOVE
logical_plan after eliminate_outer_join SAME TEXT AS ABOVE
logical_plan after push_down_limit SAME TEXT AS ABOVE
Expand All @@ -213,6 +215,7 @@ Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c
--TableScan: simple_explain_test projection=[a, b, c]
logical_plan after eliminate_projection TableScan: simple_explain_test projection=[a, b, c]
logical_plan after push_down_limit SAME TEXT AS ABOVE
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
Expand All @@ -229,6 +232,7 @@ logical_plan after eliminate_cross_join SAME TEXT AS ABOVE
logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE
logical_plan after eliminate_limit SAME TEXT AS ABOVE
logical_plan after propagate_empty_relation SAME TEXT AS ABOVE
logical_plan after eliminate_one_union SAME TEXT AS ABOVE
logical_plan after filter_null_join_keys SAME TEXT AS ABOVE
logical_plan after eliminate_outer_join SAME TEXT AS ABOVE
logical_plan after push_down_limit SAME TEXT AS ABOVE
Expand Down

0 comments on commit bd6c4d4

Please sign in to comment.