From 10bafa9cabb80a31c469c8d1b6889918fe86d418 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Thu, 5 Jan 2023 01:24:22 -0500 Subject: [PATCH 01/12] Support non-tuple expression for in-subquery to join --- datafusion/core/tests/sql/subqueries.rs | 6 + datafusion/expr/src/utils.rs | 2 +- .../optimizer/src/decorrelate_where_in.rs | 282 ++++++++++++------ datafusion/optimizer/src/test/mod.rs | 1 + 4 files changed, 201 insertions(+), 90 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 3fff5ba3e80c..960f2340b1ef 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -100,6 +100,12 @@ where o_orderstatus in ( SubqueryAlias: __correlated_sq_1 Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey AS l_orderkey TableScan: lineitem projection=[l_orderkey, l_linestatus]"#; + let expected = "Projection: orders.o_orderkey\ + \n LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey\ + \n TableScan: orders projection=[o_orderkey, o_orderstatus]\ + \n SubqueryAlias: __correlated_sq_1\ + \n Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey\ + \n TableScan: lineitem projection=[l_orderkey, l_linestatus]"; assert_eq!(actual, expected); // assert data diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8cf79c612e32..6857c46bc801 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -951,7 +951,7 @@ pub fn can_hash(data_type: &DataType) -> bool { } /// Check whether all columns are from the schema. -fn check_all_column_from_schema(columns: &HashSet, schema: DFSchemaRef) -> bool { +pub fn check_all_column_from_schema(columns: &HashSet, schema: DFSchemaRef) -> bool { columns .iter() .all(|column| schema.index_of_column(column).is_ok()) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 1aa976ce8ca7..7d793ef8a190 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -17,15 +17,15 @@ use crate::alias::AliasGenerator; use crate::optimizer::ApplyOrder; -use crate::utils::{ - alias_cols, conjunction, exprs_to_join_cols, find_join_exprs, merge_cols, - only_or_err, split_conjunction, swap_table, verify_not_disjunction, -}; +use crate::utils::{conjunction, only_or_err, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{context, Result}; +use datafusion_common::{context, Column, Result}; +use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col}; use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; +use datafusion_expr::utils::check_all_column_from_schema; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use log::debug; +use std::collections::HashMap; use std::sync::Arc; #[derive(Default)] @@ -129,61 +129,132 @@ fn optimize_where_in( ) -> Result { let proj = Projection::try_from_plan(&query_info.query.subquery) .map_err(|e| context!("a projection is required", e))?; - let mut subqry_input = proj.input.clone(); + let subqry_input = proj.input.clone(); let proj = only_or_err(proj.expr.as_slice()) .map_err(|e| context!("single expression projection required", e))?; - let subquery_col = proj - .try_into_col() - .map_err(|e| context!("single column projection required", e))?; - let outer_col = query_info - .where_in_expr - .try_into_col() - .map_err(|e| context!("column comparison required", e))?; - - // If subquery is correlated, grab necessary information - let mut subqry_cols = vec![]; - let mut outer_cols = vec![]; - let mut join_filters = None; - let mut other_subqry_exprs = vec![]; - if let LogicalPlan::Filter(subqry_filter) = (*subqry_input).clone() { - // split into filters + let subquery_expr = proj; + + // build subquery side of join - the thing the subquery was querying + let subqry_alias = alias.next("__correlated_sq"); + let unormalize_subquery_expr = unnormalize_col(subquery_expr.clone()); + let right_join_col = Expr::Column(Column::from_qualified_name(format!( + "{subqry_alias}.{unormalize_subquery_expr:?}" + ))); + + // split join filters and subquery filters + let (join_filters, right_input) = if let LogicalPlan::Filter(subqry_filter) = + (*subqry_input).clone() + { + let input_schema = subqry_filter.input.schema(); let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate); - verify_not_disjunction(&subqry_filter_exprs)?; - - // Grab column names to join on - let (col_exprs, other_exprs) = - find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema()) - .map_err(|e| context!("column correlation not found", e))?; - if !col_exprs.is_empty() { - // it's correlated - subqry_input = subqry_filter.input.clone(); - (outer_cols, subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false) - .map_err(|e| context!("column correlation not found", e))?; - other_subqry_exprs = other_exprs; + + let mut join_filters: Vec = vec![]; + let mut subquery_filters: Vec = vec![]; + for expr in subqry_filter_exprs { + let referenced_columns = expr.to_columns()?; + if check_all_column_from_schema(&referenced_columns, input_schema.clone()) { + subquery_filters.push(expr.clone()); + } else { + join_filters.push(expr.clone()) + } } - } - let (subqry_cols, outer_cols) = - merge_cols((&[subquery_col], &subqry_cols), (&[outer_col], &outer_cols)); + let mut plan = LogicalPlanBuilder::from((*subqry_filter.input).clone()); + if let Some(expr) = conjunction(subquery_filters) { + // if the subquery had additional expressions, restore them + plan = plan.filter(expr)? + } - // build subquery side of join - the thing the subquery was querying - let subqry_alias = alias.next("__correlated_sq"); - let mut subqry_plan = LogicalPlanBuilder::from((*subqry_input).clone()); - if let Some(expr) = conjunction(other_subqry_exprs) { - // if the subquery had additional expressions, restore them - subqry_plan = subqry_plan.filter(expr)? - } - let projection = alias_cols(&subqry_cols); + (join_filters, plan.build()?) + } else { + (vec![], subqry_input.as_ref().clone()) + }; + + // merge same filter + let in_filter = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone()); + let join_filters = join_filters + .into_iter() + .filter(|filter| { + if filter == &in_filter { + return false; + } + + !match (filter, &in_filter) { + (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => { + (a_expr.op == b_expr.op) + && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) + || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) + } + _ => false, + } + }) + .collect::>(); + + // replace qualified name to alias. + let join_filter = join_filters.into_iter().reduce(Expr::and); + let input_schema = right_input.schema(); + let inner_and_outer_cols = + join_filter.as_ref().map_or(Result::Ok(vec![]), |filter| { + let mut referenced_cols = filter + .to_columns()? + .into_iter() + .filter_map(|col| { + if input_schema.field_from_column(&col).is_ok() { + let outer_col = Column::from_qualified_name(format!( + "{}.{}", + subqry_alias, col.name + )); + Some((col, outer_col)) + } else { + None + } + }) + .collect::>(); + + referenced_cols.dedup(); + Ok(referenced_cols) + })?; + + let join_cols_replace_map: HashMap<&Column, &Column> = inner_and_outer_cols + .iter() + .map(|cols| (&cols.0, &cols.1)) + .collect(); + let aliased_join_filter = join_filter.map_or(Ok(None), |filter| { + replace_col(filter, &join_cols_replace_map).map(Option::Some) + })?; + + // merge same projection item + let (subquery_cols, _): (Vec, Vec) = + inner_and_outer_cols.into_iter().unzip(); + let subquery_cols = if let Expr::Column(in_col) = subquery_expr { + subquery_cols + .into_iter() + .filter(|col| col != in_col) + .collect() + } else { + subquery_cols + }; + + // build projection + let projection = [subquery_expr + .clone() + .alias(format!("{unormalize_subquery_expr:?}"))] + .into_iter() + .chain(subquery_cols.into_iter().map(Expr::Column)) + .collect::>(); + + // merge in predicate to join filter + let in_predicate = Expr::eq(query_info.where_in_expr.clone(), right_join_col); + let join_filter = aliased_join_filter + .map(|filter| in_predicate.clone().and(filter)) + .unwrap_or_else(|| in_predicate); + + // projection + let subqry_plan = LogicalPlanBuilder::from(right_input); let subqry_plan = subqry_plan .project(projection)? .alias(&subqry_alias)? .build()?; - debug!("subquery plan:\n{}", subqry_plan.display_indent()); - - // qualify the join columns for outside the subquery - let subqry_cols = swap_table(&subqry_alias, &subqry_cols); - let join_keys = (outer_cols, subqry_cols); // join our sub query into the main plan let join_type = match query_info.negated { @@ -193,8 +264,8 @@ fn optimize_where_in( let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join( subqry_plan, join_type, - join_keys, - join_filters, + (Vec::::new(), Vec::::new()), + Some(join_filter), )?; if let Some(expr) = conjunction(outer_other_exprs.to_vec()) { new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them @@ -263,8 +334,8 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq_1.c AS c [c:UInt32]\ @@ -272,7 +343,6 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_2.c AS c [c:UInt32]\ \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) } @@ -293,7 +363,7 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq.c AS c [c:UInt32]\ @@ -347,7 +417,7 @@ mod tests { \n Subquery: [c:UInt32]\ \n Projection: sq1.c [c:UInt32]\ \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq2.c AS c [c:UInt32]\ @@ -372,11 +442,11 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ \n Projection: sq.a AS a [a:UInt32]\ - \n LeftSemi Join: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_nested.c AS c [c:UInt32]\ @@ -401,14 +471,14 @@ mod tests { .project(vec![col("b")])? .build()?; - let expected = "Projection: wrapped.b [b:UInt32]\ + let expected = "Projection: wrapped.b [b:UInt32]\ \n Filter: wrapped.b < UInt32(30) OR wrapped.c IN () [b:UInt32, c:UInt32]\ \n Subquery: [c:UInt32]\ \n Projection: sq_outer.c [c:UInt32]\ \n TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: wrapped [b:UInt32, c:UInt32]\ \n Projection: test.b, test.c [b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq_inner.c AS c [c:UInt32]\ @@ -443,14 +513,16 @@ mod tests { debug!("plan to optimize:\n{}", plan.display_indent()); let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ + \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, @@ -486,11 +558,11 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey AS l_orderkey [l_orderkey:Int64]\ @@ -524,7 +596,7 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ @@ -554,14 +626,12 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // Query will fail, but we can still transform the plan let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ - \n Filter: customer.c_custkey = customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), @@ -587,7 +657,7 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ @@ -618,7 +688,7 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ @@ -647,13 +717,27 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // can't optimize on arbitrary expressions (yet) - assert_optimizer_err( + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, - "column correlation not found", + expected, ); Ok(()) + + // // can't optimize on arbitrary expressions (yet) + // assert_optimizer_err( + // Arc::new(DecorrelateWhereIn::new()), + // &plan, + // "column correlation not found", + // ); + // Ok(()) } /// Test for correlated IN subquery filter with subquery disjunction @@ -675,11 +759,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimizer_err( + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\ + \n Projection: orders.o_custkey AS o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, - "Optimizing disjunctions not supported!", + expected, ); + Ok(()) } @@ -721,11 +813,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // TODO: support join on expression - assert_optimizer_err( + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, - "column comparison required", + expected, ); Ok(()) } @@ -745,11 +843,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // TODO: support join on expressions? - assert_optimizer_err( + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey + Int32(1):Int64, o_custkey:Int64]\ + \n Projection: orders.o_custkey + Int32(1) AS o_custkey + Int32(1), orders.o_custkey [o_custkey + Int32(1):Int64, o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, - "single column projection required", + expected, ); Ok(()) } @@ -800,7 +904,7 @@ mod tests { let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ @@ -865,10 +969,10 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c, test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq.c AS c, sq.a AS a [c:UInt32, a:UInt32]\ + \n Projection: sq.c AS c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq_display_indent( @@ -889,7 +993,7 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq.c AS c [c:UInt32]\ @@ -913,7 +1017,7 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq.c AS c [c:UInt32]\ diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index bd8c1f98a0c5..3f0d5578aafb 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -137,6 +137,7 @@ pub fn assert_optimized_plan_eq_display_indent( .unwrap_or_else(|| plan.clone()); let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); assert_eq!(formatted_plan, expected); + } pub fn assert_optimizer_err( From 6b929df75f032f37db49ce9235c82d06706258d7 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sat, 7 Jan 2023 03:03:52 -0500 Subject: [PATCH 02/12] add tests --- datafusion/core/tests/sql/joins.rs | 134 ++++++++ datafusion/core/tests/sql/subqueries.rs | 7 +- .../optimizer/src/decorrelate_where_in.rs | 325 +++++++++++------- 3 files changed, 345 insertions(+), 121 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 4cc9628d15ae..e0017f4688fd 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2810,3 +2810,137 @@ async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn subquery_to_join_with_both_side_expr() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in (select t2.t2_id + 1 from t2)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn subquery_to_join_with_muti_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t2.t2_int > 0)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]", + " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn three_projection_exprs_subquery_to_join() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name and t2.t2_int > 0)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", + " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 960f2340b1ef..a5a99a7af7df 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -94,12 +94,7 @@ where o_orderstatus in ( let dataframe = ctx.sql(sql).await.unwrap(); let plan = dataframe.into_optimized_plan().unwrap(); let actual = format!("{}", plan.display_indent()); - let expected = r#"Projection: orders.o_orderkey - LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey - TableScan: orders projection=[o_orderkey, o_orderstatus] - SubqueryAlias: __correlated_sq_1 - Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey AS l_orderkey - TableScan: lineitem projection=[l_orderkey, l_linestatus]"#; + let expected = "Projection: orders.o_orderkey\ \n LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey\ \n TableScan: orders projection=[o_orderkey, o_orderstatus]\ diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 7d793ef8a190..fb8af8afe003 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -25,7 +25,7 @@ use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; use datafusion_expr::utils::check_all_column_from_schema; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use log::debug; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; #[derive(Default)] @@ -123,63 +123,126 @@ impl OptimizerRule for DecorrelateWhereIn { fn optimize_where_in( query_info: &SubqueryInfo, - outer_input: &LogicalPlan, + left: &LogicalPlan, outer_other_exprs: &[Expr], alias: &AliasGenerator, ) -> Result { let proj = Projection::try_from_plan(&query_info.query.subquery) .map_err(|e| context!("a projection is required", e))?; - let subqry_input = proj.input.clone(); - let proj = only_or_err(proj.expr.as_slice()) + let subquery_input = proj.input.clone(); + let subquery_expr = only_or_err(proj.expr.as_slice()) .map_err(|e| context!("single expression projection required", e))?; - let subquery_expr = proj; - - // build subquery side of join - the thing the subquery was querying - let subqry_alias = alias.next("__correlated_sq"); - let unormalize_subquery_expr = unnormalize_col(subquery_expr.clone()); - let right_join_col = Expr::Column(Column::from_qualified_name(format!( - "{subqry_alias}.{unormalize_subquery_expr:?}" - ))); - - // split join filters and subquery filters - let (join_filters, right_input) = if let LogicalPlan::Filter(subqry_filter) = - (*subqry_input).clone() - { - let input_schema = subqry_filter.input.schema(); - let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate); + + // extract join filters + let (join_filters, right_input) = extract_join_filters(subquery_input.as_ref())?; + + // in_predicate may be also include in the join filters, remove it from the join filters. + let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone()); + let join_filters = remove_duplicate_filter(join_filters, in_predicate); + + // replace qualified name with subquery alias. + let subquery_alias = alias.next("__correlated_sq"); + let input_schema = right_input.schema(); + let mut subquery_cols = + join_filters + .iter() + .try_fold(BTreeSet::::new(), |mut cols, expr| { + let using_cols: Vec = expr + .to_columns()? + .into_iter() + .filter(|col| input_schema.field_from_column(col).is_ok()) + .collect::<_>(); + + cols.extend(using_cols); + Result::Ok(cols) + })?; + let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| { + replace_qualify_name(filter, &subquery_cols, &subquery_alias).map(Option::Some) + })?; + + // add projection + if let Expr::Column(col) = subquery_expr { + subquery_cols.remove(col); + } + let subquery_expr_name = format!("{:?}", unnormalize_col(subquery_expr.clone())); + let first_expr = subquery_expr.clone().alias(subquery_expr_name.clone()); + let projection_exprs = [first_expr] + .into_iter() + .chain(subquery_cols.into_iter().map(Expr::Column)) + .collect::>(); + + let right = LogicalPlanBuilder::from(right_input) + .project(projection_exprs)? + .alias(&subquery_alias)? + .build()?; + + // join our sub query into the main plan + let join_type = match query_info.negated { + true => JoinType::LeftAnti, + false => JoinType::LeftSemi, + }; + let right_join_col = Column::new(Some(subquery_alias), subquery_expr_name); + let in_predicate = Expr::eq( + query_info.where_in_expr.clone(), + Expr::Column(right_join_col), + ); + let join_filter = join_filter + .map(|filter| in_predicate.clone().and(filter)) + .unwrap_or_else(|| in_predicate); + + let mut new_plan = LogicalPlanBuilder::from(left.clone()).join( + right, + join_type, + (Vec::::new(), Vec::::new()), + Some(join_filter), + )?; + if let Some(expr) = conjunction(outer_other_exprs.to_vec()) { + new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them + } + let new_plan = new_plan.build()?; + + debug!("where in optimized:\n{}", new_plan.display_indent()); + Ok(new_plan) +} + +fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec, LogicalPlan)> { + if let LogicalPlan::Filter(plan_filter) = maybe_filter { + let input_schema = plan_filter.input.schema(); + let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); let mut join_filters: Vec = vec![]; let mut subquery_filters: Vec = vec![]; - for expr in subqry_filter_exprs { - let referenced_columns = expr.to_columns()?; - if check_all_column_from_schema(&referenced_columns, input_schema.clone()) { + for expr in subquery_filter_exprs { + let cols = expr.to_columns()?; + if check_all_column_from_schema(&cols, input_schema.clone()) { subquery_filters.push(expr.clone()); } else { join_filters.push(expr.clone()) } } - let mut plan = LogicalPlanBuilder::from((*subqry_filter.input).clone()); + // if the subquery still has filter expressions, restore them. + let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone()); if let Some(expr) = conjunction(subquery_filters) { - // if the subquery had additional expressions, restore them plan = plan.filter(expr)? } - (join_filters, plan.build()?) + Ok((join_filters, plan.build()?)) } else { - (vec![], subqry_input.as_ref().clone()) - }; + Ok((vec![], maybe_filter.clone())) + } +} - // merge same filter - let in_filter = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone()); - let join_filters = join_filters +fn remove_duplicate_filter(filters: Vec, in_predicate: Expr) -> Vec { + filters .into_iter() .filter(|filter| { - if filter == &in_filter { + if filter == &in_predicate { return false; } - !match (filter, &in_filter) { + // ignore the binary order + !match (filter, &in_predicate) { (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => { (a_expr.op == b_expr.op) && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) @@ -188,92 +251,24 @@ fn optimize_where_in( _ => false, } }) - .collect::>(); + .collect::>() +} - // replace qualified name to alias. - let join_filter = join_filters.into_iter().reduce(Expr::and); - let input_schema = right_input.schema(); - let inner_and_outer_cols = - join_filter.as_ref().map_or(Result::Ok(vec![]), |filter| { - let mut referenced_cols = filter - .to_columns()? - .into_iter() - .filter_map(|col| { - if input_schema.field_from_column(&col).is_ok() { - let outer_col = Column::from_qualified_name(format!( - "{}.{}", - subqry_alias, col.name - )); - Some((col, outer_col)) - } else { - None - } - }) - .collect::>(); - - referenced_cols.dedup(); - Ok(referenced_cols) - })?; - - let join_cols_replace_map: HashMap<&Column, &Column> = inner_and_outer_cols +fn replace_qualify_name( + expr: Expr, + cols: &BTreeSet, + subquery_alias: &str, +) -> Result { + let alias_cols = cols .iter() - .map(|cols| (&cols.0, &cols.1)) - .collect(); - let aliased_join_filter = join_filter.map_or(Ok(None), |filter| { - replace_col(filter, &join_cols_replace_map).map(Option::Some) - })?; - - // merge same projection item - let (subquery_cols, _): (Vec, Vec) = - inner_and_outer_cols.into_iter().unzip(); - let subquery_cols = if let Expr::Column(in_col) = subquery_expr { - subquery_cols - .into_iter() - .filter(|col| col != in_col) - .collect() - } else { - subquery_cols - }; - - // build projection - let projection = [subquery_expr - .clone() - .alias(format!("{unormalize_subquery_expr:?}"))] - .into_iter() - .chain(subquery_cols.into_iter().map(Expr::Column)) - .collect::>(); - - // merge in predicate to join filter - let in_predicate = Expr::eq(query_info.where_in_expr.clone(), right_join_col); - let join_filter = aliased_join_filter - .map(|filter| in_predicate.clone().and(filter)) - .unwrap_or_else(|| in_predicate); - - // projection - let subqry_plan = LogicalPlanBuilder::from(right_input); - let subqry_plan = subqry_plan - .project(projection)? - .alias(&subqry_alias)? - .build()?; - - // join our sub query into the main plan - let join_type = match query_info.negated { - true => JoinType::LeftAnti, - false => JoinType::LeftSemi, - }; - let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join( - subqry_plan, - join_type, - (Vec::::new(), Vec::::new()), - Some(join_filter), - )?; - if let Some(expr) = conjunction(outer_other_exprs.to_vec()) { - new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them - } - let new_plan = new_plan.build()?; + .map(|col| { + Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) + }) + .collect::>(); + let replace_map: HashMap<&Column, &Column> = + cols.iter().zip(alias_cols.iter()).collect(); - debug!("where in optimized:\n{}", new_plan.display_indent()); - Ok(new_plan) + replace_col(expr, &replace_map) } struct SubqueryInfo { @@ -1030,4 +1025,104 @@ mod tests { ); Ok(()) } + + #[test] + fn in_subquery_both_side_expr() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32]\ + \n Projection: sq.c * UInt32(2) AS c * UInt32(2) [c * UInt32(2):UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } + + #[test] + fn in_subquery_join_filter_and_inner_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter( + col("test.a") + .eq(col("sq.a")) + .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))), + )? + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a [c * UInt32(2):UInt32, a:UInt32]\ + \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } + + #[test] + fn in_subquery_muti_project_subquery_cols() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter( + col("test.a") + .add(col("test.b")) + .eq(col("sq.a").add(col("sq.b"))) + .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))), + )? + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ + \n Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a, sq.b [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ + \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } } From 7767bb141672100420c698c89c3561dad3413ced Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sat, 7 Jan 2023 03:38:41 -0500 Subject: [PATCH 03/12] add comment and fix cargo fmt --- datafusion/core/tests/sql/subqueries.rs | 2 +- datafusion/expr/src/utils.rs | 5 ++++- datafusion/optimizer/src/decorrelate_where_in.rs | 15 +++++++++++++++ datafusion/optimizer/src/test/mod.rs | 1 - 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index a5a99a7af7df..cb8aa602a769 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -100,7 +100,7 @@ where o_orderstatus in ( \n TableScan: orders projection=[o_orderkey, o_orderstatus]\ \n SubqueryAlias: __correlated_sq_1\ \n Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey\ - \n TableScan: lineitem projection=[l_orderkey, l_linestatus]"; + \n TableScan: lineitem projection=[l_orderkey, l_linestatus]"; assert_eq!(actual, expected); // assert data diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 6857c46bc801..97716961c3a6 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -951,7 +951,10 @@ pub fn can_hash(data_type: &DataType) -> bool { } /// Check whether all columns are from the schema. -pub fn check_all_column_from_schema(columns: &HashSet, schema: DFSchemaRef) -> bool { +pub fn check_all_column_from_schema( + columns: &HashSet, + schema: DFSchemaRef, +) -> bool { columns .iter() .all(|column| schema.index_of_column(column).is_ok()) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index fb8af8afe003..225ae02b8003 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -121,6 +121,21 @@ impl OptimizerRule for DecorrelateWhereIn { } } +/// Optimize the where in subquery to left-enti/left-semi join. +/// If the subquery is a correlated subquery, we need extract the join predicate from the subquery. +/// +/// For example, given a query like: +/// `select t1.a, t1.b from t1 where t1 in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)` +/// +/// The optimized plan will be: +/// +/// Projection: t1.a, t1.b | +/// LeftSemi Join: Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c | +/// TableScan: t1 | +/// SubqueryAlias: __correlated_sq_1 | +/// Projection: t2.a AS a, t2.b, t2.c | +/// TableScan: t2 | +/// fn optimize_where_in( query_info: &SubqueryInfo, left: &LogicalPlan, diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 3f0d5578aafb..bd8c1f98a0c5 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -137,7 +137,6 @@ pub fn assert_optimized_plan_eq_display_indent( .unwrap_or_else(|| plan.clone()); let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); assert_eq!(formatted_plan, expected); - } pub fn assert_optimizer_err( From de0fdf31e2460ec9857cb58a5dd47f52f56bb2d4 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sat, 7 Jan 2023 04:16:56 -0500 Subject: [PATCH 04/12] fix comment --- datafusion/optimizer/src/decorrelate_where_in.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 225ae02b8003..5d11399a510f 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -121,7 +121,7 @@ impl OptimizerRule for DecorrelateWhereIn { } } -/// Optimize the where in subquery to left-enti/left-semi join. +/// Optimize the where in subquery to left-anti/left-semi join. /// If the subquery is a correlated subquery, we need extract the join predicate from the subquery. /// /// For example, given a query like: @@ -149,7 +149,7 @@ fn optimize_where_in( .map_err(|e| context!("single expression projection required", e))?; // extract join filters - let (join_filters, right_input) = extract_join_filters(subquery_input.as_ref())?; + let (join_filters, subquery_input) = extract_join_filters(subquery_input.as_ref())?; // in_predicate may be also include in the join filters, remove it from the join filters. let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone()); @@ -157,7 +157,7 @@ fn optimize_where_in( // replace qualified name with subquery alias. let subquery_alias = alias.next("__correlated_sq"); - let input_schema = right_input.schema(); + let input_schema = subquery_input.schema(); let mut subquery_cols = join_filters .iter() @@ -186,7 +186,7 @@ fn optimize_where_in( .chain(subquery_cols.into_iter().map(Expr::Column)) .collect::>(); - let right = LogicalPlanBuilder::from(right_input) + let right = LogicalPlanBuilder::from(subquery_input) .project(projection_exprs)? .alias(&subquery_alias)? .build()?; From d641b51926154d38818049e00f0802c14adcb3f4 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sat, 7 Jan 2023 22:18:21 -0500 Subject: [PATCH 05/12] clean unused comment --- .../optimizer/src/decorrelate_where_in.rs | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 5d11399a510f..255966f6fbed 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -142,10 +142,10 @@ fn optimize_where_in( outer_other_exprs: &[Expr], alias: &AliasGenerator, ) -> Result { - let proj = Projection::try_from_plan(&query_info.query.subquery) + let projection = Projection::try_from_plan(&query_info.query.subquery) .map_err(|e| context!("a projection is required", e))?; - let subquery_input = proj.input.clone(); - let subquery_expr = only_or_err(proj.expr.as_slice()) + let subquery_input = projection.input.clone(); + let subquery_expr = only_or_err(projection.expr.as_slice()) .map_err(|e| context!("single expression projection required", e))?; // extract join filters @@ -153,12 +153,12 @@ fn optimize_where_in( // in_predicate may be also include in the join filters, remove it from the join filters. let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone()); - let join_filters = remove_duplicate_filter(join_filters, in_predicate); + let join_filters = remove_duplicated_filter(join_filters, in_predicate); // replace qualified name with subquery alias. let subquery_alias = alias.next("__correlated_sq"); let input_schema = subquery_input.schema(); - let mut subquery_cols = + let mut subquery_cols: BTreeSet = join_filters .iter() .try_fold(BTreeSet::::new(), |mut cols, expr| { @@ -181,10 +181,10 @@ fn optimize_where_in( } let subquery_expr_name = format!("{:?}", unnormalize_col(subquery_expr.clone())); let first_expr = subquery_expr.clone().alias(subquery_expr_name.clone()); - let projection_exprs = [first_expr] + let projection_exprs: Vec = [first_expr] .into_iter() .chain(subquery_cols.into_iter().map(Expr::Column)) - .collect::>(); + .collect(); let right = LogicalPlanBuilder::from(subquery_input) .project(projection_exprs)? @@ -248,7 +248,7 @@ fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec, Logica } } -fn remove_duplicate_filter(filters: Vec, in_predicate: Expr) -> Vec { +fn remove_duplicated_filter(filters: Vec, in_predicate: Expr) -> Vec { filters .into_iter() .filter(|filter| { @@ -274,12 +274,12 @@ fn replace_qualify_name( cols: &BTreeSet, subquery_alias: &str, ) -> Result { - let alias_cols = cols + let alias_cols: Vec = cols .iter() .map(|col| { Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) }) - .collect::>(); + .collect(); let replace_map: HashMap<&Column, &Column> = cols.iter().zip(alias_cols.iter()).collect(); @@ -740,14 +740,6 @@ mod tests { expected, ); Ok(()) - - // // can't optimize on arbitrary expressions (yet) - // assert_optimizer_err( - // Arc::new(DecorrelateWhereIn::new()), - // &plan, - // "column correlation not found", - // ); - // Ok(()) } /// Test for correlated IN subquery filter with subquery disjunction From 5d25827c706945d5b257cec49effcea6334fcb8e Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 13 Jan 2023 16:05:34 +0800 Subject: [PATCH 06/12] Update datafusion/optimizer/src/decorrelate_where_in.rs Co-authored-by: Andrew Lamb --- datafusion/optimizer/src/decorrelate_where_in.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 255966f6fbed..6689bc3e5737 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -129,13 +129,14 @@ impl OptimizerRule for DecorrelateWhereIn { /// /// The optimized plan will be: /// -/// Projection: t1.a, t1.b | -/// LeftSemi Join: Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c | -/// TableScan: t1 | -/// SubqueryAlias: __correlated_sq_1 | -/// Projection: t2.a AS a, t2.b, t2.c | -/// TableScan: t2 | -/// +/// ```text +/// Projection: t1.a, t1.b +/// LeftSemi Join: Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c +/// TableScan: t1 +/// SubqueryAlias: __correlated_sq_1 +/// Projection: t2.a AS a, t2.b, t2.c +/// TableScan: t2 +/// ``` fn optimize_where_in( query_info: &SubqueryInfo, left: &LogicalPlan, From e0055c8567ff5fb1e385aa93362c105f6201a287 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 13 Jan 2023 16:05:48 +0800 Subject: [PATCH 07/12] Update datafusion/optimizer/src/decorrelate_where_in.rs Co-authored-by: Andrew Lamb --- datafusion/optimizer/src/decorrelate_where_in.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 6689bc3e5737..40002b2dcac2 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -162,7 +162,7 @@ fn optimize_where_in( let mut subquery_cols: BTreeSet = join_filters .iter() - .try_fold(BTreeSet::::new(), |mut cols, expr| { + .try_fold(BTreeSet::new(), |mut cols, expr| { let using_cols: Vec = expr .to_columns()? .into_iter() From 08a66bb2072746986b637fc2fe43bff6f6d8b80a Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 13 Jan 2023 16:05:58 +0800 Subject: [PATCH 08/12] Update datafusion/optimizer/src/decorrelate_where_in.rs Co-authored-by: Andrew Lamb --- datafusion/optimizer/src/decorrelate_where_in.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 40002b2dcac2..95b42b9353e7 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -270,7 +270,7 @@ fn remove_duplicated_filter(filters: Vec, in_predicate: Expr) -> Vec .collect::>() } -fn replace_qualify_name( +fn replace_qualified_name( expr: Expr, cols: &BTreeSet, subquery_alias: &str, From d4187f82bd08a4a919198e616dd02cbfce244003 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Fri, 13 Jan 2023 03:15:21 -0500 Subject: [PATCH 09/12] fix comment --- datafusion/optimizer/src/decorrelate_where_in.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 95b42b9353e7..a504a239fd38 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -173,7 +173,7 @@ fn optimize_where_in( Result::Ok(cols) })?; let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| { - replace_qualify_name(filter, &subquery_cols, &subquery_alias).map(Option::Some) + replace_qualified_name(filter, &subquery_cols, &subquery_alias).map(Option::Some) })?; // add projection From 747957a6741a38ec81e34416dc1c87d8d9341253 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sun, 15 Jan 2023 02:17:37 -0500 Subject: [PATCH 10/12] fix cargo fmt --- datafusion/core/tests/sql/joins.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 119b580483e9..2d601c942a25 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -3002,5 +3002,3 @@ async fn three_projection_exprs_subquery_to_join() -> Result<()> { Ok(()) } - - From 96d26bd9e9ed55eb3e40501876f309c2b202ac32 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sun, 15 Jan 2023 03:58:43 -0500 Subject: [PATCH 11/12] add tests --- datafusion/core/tests/sql/joins.rs | 142 ++++++++++++++++++ .../optimizer/src/decorrelate_where_in.rs | 51 +++++++ 2 files changed, 193 insertions(+) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 2d601c942a25..7954b884be00 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -3002,3 +3002,145 @@ async fn three_projection_exprs_subquery_to_join() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2 where t1.t1_int > 0)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + // The `t1.t1_int > UInt32(0)` should be pushdown by `filter push down rule`. + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn in_subquery_to_join_with_outer_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name) and t1.t1_id > 0"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn two_in_subquery_to_join_with_outer_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2) + and t1.t1_int in(select t2.t2_int + 1 from t2) + and t1.t1_id > 0"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_int AS Int64) = __correlated_sq_2.CAST(t2_int AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + " SubqueryAlias: __correlated_sq_2 [CAST(t2_int AS Int64) + Int64(1):Int64;N]", + " Projection: CAST(t2.t2_int AS Int64) + Int64(1) AS CAST(t2_int AS Int64) + Int64(1) [CAST(t2_int AS Int64) + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_int] [t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index a504a239fd38..13e3acf78876 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -96,6 +96,7 @@ impl OptimizerRule for DecorrelateWhereIn { return Ok(None); } + // iterate through all exists clauses in predicate, turning each into a join // iterate through all exists clauses in predicate, turning each into a join let mut cur_input = filter.input.as_ref().clone(); for subquery in subqueries { @@ -212,6 +213,7 @@ fn optimize_where_in( (Vec::::new(), Vec::::new()), Some(join_filter), )?; + if let Some(expr) = conjunction(outer_other_exprs.to_vec()) { new_plan = new_plan.filter(expr)? // if the main query had additional expressions, restore them } @@ -1133,4 +1135,53 @@ mod tests { ); Ok(()) } + + #[test] + fn two_in_subquery_with_outer_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan1 = test_table_scan_with_name("sq1")?; + let subquery_scan2 = test_table_scan_with_name("sq2")?; + + let subquery1 = LogicalPlanBuilder::from(subquery_scan1) + .filter(col("test.a").gt(col("sq1.a")))? + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let subquery2 = LogicalPlanBuilder::from(subquery_scan2) + .filter(col("test.a").gt(col("sq2.a")))? + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and( + in_subquery(col("c") * lit(2u32), Arc::new(subquery2)) + .and(col("test.c").gt(lit(1u32))), + ), + )? + .project(vec![col("test.b")])? + .build()?; + + // Filter: test.c > UInt32(1) happen twice. + // issue: https://github.com/apache/arrow-datafusion/issues/4914 + let expected = "Projection: test.b [b:UInt32]\ + \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\ + \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_2 [c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq2.c * UInt32(2) AS c * UInt32(2), sq2.a [c * UInt32(2):UInt32, a:UInt32]\ + \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } } From f035229843ad1684debb9d5aee1d158b888b399f Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sun, 15 Jan 2023 04:55:19 -0500 Subject: [PATCH 12/12] fix cargo fmt --- datafusion/core/tests/sql/joins.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 7954b884be00..c20c66e1016f 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -3143,4 +3143,3 @@ async fn two_in_subquery_to_join_with_outer_filter() -> Result<()> { Ok(()) } -