diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 679c81f46aca6..c343e2576c041 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -61,12 +61,13 @@ impl OptimizerRule for EliminateOneUnion { #[cfg(test)] mod tests { use super::*; - use crate::eliminate_filter::EliminateFilter; - use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::test::*; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; - use datafusion_expr::{logical_plan::table_scan, Expr}; + use datafusion_common::ToDFSchema; + use datafusion_expr::{ + expr_rewriter::coerce_plan_expr_for_schema, + logical_plan::{table_scan, Union}, + }; use std::sync::Arc; fn schema() -> Schema { @@ -79,11 +80,7 @@ mod tests { fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_with_rules( - vec![ - Arc::new(EliminateFilter::new()), - Arc::new(PropagateEmptyRelation::new()), - Arc::new(EliminateOneUnion::new()), - ], + vec![Arc::new(EliminateOneUnion::new())], plan, expected, ) @@ -107,19 +104,17 @@ mod tests { #[test] fn eliminate_nested_union() -> Result<()> { - let plan_builder = table_scan(Some("table"), &schema(), None)?; - - let plan = plan_builder - .clone() - .union( - plan_builder - .clone() - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? - .build()?, - )? - .build()?; + let table_plan = coerce_plan_expr_for_schema( + &table_scan(Some("table"), &schema(), None)?.build()?, + &schema().to_dfschema()?, + )?; + let schema = table_plan.schema().clone(); + let single_union_plan = LogicalPlan::Union(Union { + inputs: vec![Arc::new(table_plan)], + schema, + }); let expected = "TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(&single_union_plan, expected) } }