Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support non-tuple expression for in-subquery to join #4826

Merged
merged 14 commits into from
Jan 15, 2023
275 changes: 275 additions & 0 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2868,3 +2868,278 @@ async fn test_cross_join_to_groupby_with_different_key_ordering() -> Result<()>

Ok(())
}

#[tokio::test]
async fn subquery_to_join_with_both_side_expr() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in (select t2.t2_id + 1 from t2)";

// assert logical plan
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 the plan looks good

" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 11 | a | 1 |",
"| 33 | c | 3 |",
"| 44 | d | 4 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);
Comment on lines +2900 to +2911
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now we already use sqllogicaltest to support Data Driven Tests #4460.
I suggest add a file join.slt | subquery.slt to cover some case of join and subquery.

In UT and integration-test, we just focus on correctness of Plan, data-test is derived by sqllogicaltest

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

File a ticket #4870.


Ok(())
}

#[tokio::test]
async fn subquery_to_join_with_muti_filter() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
(select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t2.t2_int > 0)";

// assert logical plan
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]",
" Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]",
" Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 11 | a | 1 |",
"| 33 | c | 3 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn three_projection_exprs_subquery_to_join() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a minor point, I am not sure what the extra coverage the multiple predicates in the where clause are adding

I wonder if we need all these tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more tests -- multiple predicates in outer and subquery where clause.

let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
(select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name and t2.t2_int > 0)";

// assert logical plan
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
" Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
" Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 11 | a | 1 |",
"| 33 | c | 3 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
(select t2.t2_id + 1 from t2 where t1.t1_int > 0)";

// assert logical plan
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();

// The `t1.t1_int > UInt32(0)` should be pushdown by `filter push down rule`.
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case show that the special case t1.t1_int > 0 is not pushed down after optimizing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that something you would like to fix in this PR or is it something you would like to fix in a follow on?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean next pr 🤣.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a issue: #4413


let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 11 | a | 1 |",
"| 33 | c | 3 |",
"| 44 | d | 4 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn in_subquery_to_join_with_outer_filter() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
(select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name) and t1.t1_id > 0";

// assert logical plan
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
" Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 11 | a | 1 |",
"| 33 | c | 3 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn two_in_subquery_to_join_with_outer_filter() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in
(select t2.t2_id + 1 from t2)
and t1.t1_int in(select t2.t2_int + 1 from t2)
and t1.t1_id > 0";

// assert logical plan
let msg = format!("Creating logical plan for '{sql}'");
let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg);
let plan = dataframe.into_optimized_plan().unwrap();

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_int AS Int64) = __correlated_sq_2.CAST(t2_int AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
" SubqueryAlias: __correlated_sq_2 [CAST(t2_int AS Int64) + Int64(1):Int64;N]",
" Projection: CAST(t2.t2_int AS Int64) + Int64(1) AS CAST(t2_int AS Int64) + Int64(1) [CAST(t2_int AS Int64) + Int64(1):Int64;N]",
" TableScan: t2 projection=[t2_int] [t2_int:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let expected = vec![
"+-------+---------+--------+",
"| t1_id | t1_name | t1_int |",
"+-------+---------+--------+",
"| 44 | d | 4 |",
"+-------+---------+--------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}
13 changes: 7 additions & 6 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,13 @@ where o_orderstatus in (
let dataframe = ctx.sql(sql).await.unwrap();
let plan = dataframe.into_optimized_plan().unwrap();
let actual = format!("{}", plan.display_indent());
let expected = r#"Projection: orders.o_orderkey
LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey
TableScan: orders projection=[o_orderkey, o_orderstatus]
SubqueryAlias: __correlated_sq_1
Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey AS l_orderkey
TableScan: lineitem projection=[l_orderkey, l_linestatus]"#;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is just whitespace, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not only whitespace, the alias except the first expression are removed in the projection.

before:

SubqueryAlias: __correlated_sq_1
      Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey AS l_orderkey

after:

 \n    SubqueryAlias: __correlated_sq_1\
    \n      Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey\

The first expression of the projection is the right side of in-equijoin, it may be a non-column expression, so always give it an alias, but others do not(they are alway column).

let expected = "Projection: orders.o_orderkey\
\n LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey\
\n TableScan: orders projection=[o_orderkey, o_orderstatus]\
\n SubqueryAlias: __correlated_sq_1\
\n Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey\
\n TableScan: lineitem projection=[l_orderkey, l_linestatus]";
assert_eq!(actual, expected);

// assert data
Expand Down
5 changes: 4 additions & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,10 @@ pub fn can_hash(data_type: &DataType) -> bool {
}

/// Check whether all columns are from the schema.
fn check_all_column_from_schema(columns: &HashSet<Column>, schema: DFSchemaRef) -> bool {
pub fn check_all_column_from_schema(
columns: &HashSet<Column>,
schema: DFSchemaRef,
) -> bool {
columns
.iter()
.all(|column| schema.index_of_column(column).is_ok())
Expand Down
Loading