Skip to content

Commit

Permalink
Unparsing optimized (> 2 inputs) unions (#14031)
Browse files Browse the repository at this point in the history
* tests and optimizer in testing queries

* unparse optimized unions

* format Cargo.toml

* format Cargo.toml

* revert test

* rewrite test to avoid cyclic dep

* remove old test

* cleanup

* comments and error handling

* handle union with lt 2 inputs
  • Loading branch information
MohamedAbdeen21 authored Jan 9, 2025
1 parent ad5a04f commit 5955860
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 17 deletions.
29 changes: 16 additions & 13 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,6 @@ impl Unparser<'_> {
Ok(())
}
LogicalPlan::Union(union) => {
if union.inputs.len() != 2 {
return not_impl_err!(
"UNION ALL expected 2 inputs, but found {}",
union.inputs.len()
);
}

// Covers cases where the UNION is a subquery and the projection is at the top level
if select.already_projected() {
return self.derive_with_dialect_alias(
Expand All @@ -729,12 +722,22 @@ impl Unparser<'_> {
.map(|input| self.select_to_sql_expr(input, query))
.collect::<Result<Vec<_>>>()?;

let union_expr = SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(input_exprs[0].clone()),
right: Box::new(input_exprs[1].clone()),
};
if input_exprs.len() < 2 {
return internal_err!("UNION operator requires at least 2 inputs");
}

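// Build the union expression tree bottom-up by folding the inputs in reverse order.
// Because of the `rev`, the accumulator `a` holds the inputs that come later, so it is
// placed on the right and the next input `b` on the left, keeping the original order.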
let union_expr = input_exprs
.into_iter()
.rev()
.reduce(|a, b| SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(b),
right: Box::new(a),
})
.unwrap();

let Some(query) = query.as_mut() else {
return internal_err!(
Expand Down
56 changes: 52 additions & 4 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::*;
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference};
use datafusion_expr::test::function_stub::{
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
};
use datafusion_expr::{
col, lit, table_scan, wildcard, Expr, Extension, LogicalPlan, LogicalPlanBuilder,
UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan,
LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
};
use datafusion_functions::unicode;
use datafusion_functions_aggregate::grouping::grouping_udaf;
Expand All @@ -42,7 +42,7 @@ use std::{fmt, vec};

use crate::common::{MockContextProvider, MockSessionState};
use datafusion_expr::builder::{
table_scan_with_filter_and_fetch, table_scan_with_filters,
project, table_scan_with_filter_and_fetch, table_scan_with_filters,
};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_nested::extract::array_element_udf;
Expand Down Expand Up @@ -1615,3 +1615,51 @@ fn test_unparse_extension_to_sql() -> Result<()> {
}
Ok(())
}

/// Checks that a `Union` plan with four inputs unparses to a chain of
/// `UNION ALL` selects, and that a `Union` with a single input is rejected
/// with an error instead of producing SQL.
#[test]
fn test_unparse_optimized_multi_union() -> Result<()> {
    let unparser = Unparser::default();

    let fields = vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Utf8, false),
    ];
    let dfschema = Arc::new(DFSchema::try_from(Schema::new(fields))?);

    // One-row relation that every projected branch below is built on.
    let empty = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: true,
        schema: dfschema.clone(),
    });

    // Each (x, y) pair becomes one `SELECT <x> AS x, '<y>' AS y` input of the union.
    let inputs = [(1, "a"), (1, "b"), (2, "a"), (2, "c")]
        .into_iter()
        .map(|(x, y)| {
            project(empty.clone(), vec![lit(x).alias("x"), lit(y).alias("y")])
                .map(Arc::new)
        })
        .collect::<Result<Vec<_>>>()?;

    let multi_input = LogicalPlan::Union(Union {
        inputs,
        schema: dfschema.clone(),
    });

    let expected = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y";
    let actual = unparser.plan_to_sql(&multi_input)?.to_string();
    assert_eq!(actual, expected);

    // A union with only one input cannot be unparsed and must surface an error.
    let only_branch = project(empty, vec![lit(1).alias("x"), lit("a").alias("y")])?;
    let single_input = LogicalPlan::Union(Union {
        inputs: vec![Arc::new(only_branch)],
        schema: dfschema,
    });

    match plan_to_sql(&single_input) {
        Err(err) => {
            assert_contains!(err.to_string(), "UNION operator requires at least 2 inputs");
        }
        Ok(_) => panic!("Expected error"),
    }

    Ok(())
}

0 comments on commit 5955860

Please sign in to comment.