From ec8fd44594cada9cb0189f56ddf586ec48175ce0 Mon Sep 17 00:00:00 2001
From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com>
Date: Tue, 26 Dec 2023 01:01:10 +0300
Subject: [PATCH] [MINOR]: Add new test for filter pushdown into cross join
 (#8648)

* Initial commit

* Minor changes

* Simplifications

* Update UDF example

* Address review

---------

Co-authored-by: Mehmet Ozan Kabak
---
 .../optimizer/src/eliminate_cross_join.rs    |  1 +
 datafusion/optimizer/src/push_down_filter.rs | 12 +++-
 datafusion/sqllogictest/src/test_context.rs  | 61 ++++++++++++++-----
 datafusion/sqllogictest/test_files/joins.slt | 22 +++++++
 4 files changed, 78 insertions(+), 18 deletions(-)

diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index 7c866950a622..d9e96a9f2543 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -45,6 +45,7 @@ impl EliminateCrossJoin {
 /// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);'
 /// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
 /// or (a.x = b.y and b.xx = 200 and a.z=c.z);'
+/// 'select ... from a, b where a.x > b.y'
 /// For above queries, the join predicate is available in filters and they are moved to
 /// join nodes appropriately
 /// This fix helps to improve the performance of TPCH Q19. issue#78
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 4eed39a08941..9d277d18d2f7 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -965,11 +965,11 @@ impl PushDownFilter {
     }
 }
 
-/// Convert cross join to join by pushing down filter predicate to the join condition
+/// Converts the given cross join to an inner join with an empty equality
+/// predicate and an empty filter condition.
 fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
     let CrossJoin { left, right, .. } = cross_join;
     let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?;
-    // predicate is given
     Ok(Join {
         left,
         right,
@@ -982,7 +982,8 @@ fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
     })
 }
 
-/// Converts the inner join with empty equality predicate and empty filter condition to the cross join
+/// Converts the given inner join with an empty equality predicate and an
+/// empty filter condition to a cross join.
 fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result<LogicalPlan> {
     if let LogicalPlan::Join(join) = &plan {
         // Can be converted back to cross join
@@ -991,6 +992,11 @@ fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result<LogicalPlan> {
                 .cross_join(join.right.as_ref().clone())?
                 .build();
         }
+    } else if let LogicalPlan::Filter(filter) = &plan {
+        let new_input =
+            convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?;
+        return Filter::try_new(filter.predicate.clone(), Arc::new(new_input))
+            .map(LogicalPlan::Filter);
     }
     Ok(plan)
 }
diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs
index 941dcb69d2f4..a5ce7ccb9fe0 100644
--- a/datafusion/sqllogictest/src/test_context.rs
+++ b/datafusion/sqllogictest/src/test_context.rs
@@ -15,31 +15,33 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use async_trait::async_trait;
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::Write;
+use std::path::Path;
+use std::sync::Arc;
+
+use arrow::array::{
+    ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray,
+    StringArray, TimestampNanosecondArray,
+};
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
+use arrow::record_batch::RecordBatch;
 use datafusion::execution::context::SessionState;
-use datafusion::logical_expr::Expr;
+use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility};
+use datafusion::physical_expr::functions::make_scalar_function;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionConfig;
 use datafusion::{
-    arrow::{
-        array::{
-            BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray,
-            StringArray, TimestampNanosecondArray,
-        },
-        datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
-        record_batch::RecordBatch,
-    },
     catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider},
     datasource::{MemTable, TableProvider, TableType},
     prelude::{CsvReadOptions, SessionContext},
 };
+use datafusion_common::cast::as_float64_array;
 use datafusion_common::DataFusionError;
+
+use async_trait::async_trait;
 use log::info;
-use std::collections::HashMap;
-use std::fs::File;
-use std::io::Write;
-use std::path::Path;
-use std::sync::Arc;
 use tempfile::TempDir;
 
 /// Context for running tests
@@ -102,6 +104,8 @@ impl TestContext {
         }
         "joins.slt" => {
             info!("Registering partition table tables");
+            let example_udf = create_example_udf();
+            test_ctx.ctx.register_udf(example_udf);
             register_partition_table(&mut test_ctx).await;
         }
         "metadata.slt" => {
@@ -348,3 +352,30 @@ pub async fn register_metadata_tables(ctx: &SessionContext) {
 
     ctx.register_batch("table_with_metadata", batch).unwrap();
 }
+
+/// Create a UDF function named "example". See the `sample_udf.rs` example
+/// file for an explanation of the API.
+fn create_example_udf() -> ScalarUDF {
+    let adder = make_scalar_function(|args: &[ArrayRef]| {
+        let lhs = as_float64_array(&args[0]).expect("cast failed");
+        let rhs = as_float64_array(&args[1]).expect("cast failed");
+        let array = lhs
+            .iter()
+            .zip(rhs.iter())
+            .map(|(lhs, rhs)| match (lhs, rhs) {
+                (Some(lhs), Some(rhs)) => Some(lhs + rhs),
+                _ => None,
+            })
+            .collect::<Float64Array>();
+        Ok(Arc::new(array) as ArrayRef)
+    });
+    create_udf(
+        "example",
+        // Expects two f64 values:
+        vec![DataType::Float64, DataType::Float64],
+        // Returns an f64 value:
+        Arc::new(DataType::Float64),
+        Volatility::Immutable,
+        adder,
+    )
+}
diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt
index eee213811f44..9a349f600091 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -3483,6 +3483,28 @@ NestedLoopJoinExec: join_type=Inner, filter=a@0 > a@1
 ----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
 --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
 
+# Currently datafusion cannot pushdown filter conditions with scalar UDF into
+# cross join.
+query TT
+EXPLAIN SELECT *
+FROM annotated_data as t1, annotated_data as t2
+WHERE EXAMPLE(t1.a, t2.a) > 3
+----
+logical_plan
+Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3)
+--CrossJoin:
+----SubqueryAlias: t1
+------TableScan: annotated_data projection=[a0, a, b, c, d]
+----SubqueryAlias: t2
+------TableScan: annotated_data projection=[a0, a, b, c, d]
+physical_plan
+CoalesceBatchesExec: target_batch_size=2
+--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3
+----CrossJoinExec
+------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
+------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
+
 ####
 # Config teardown
 ####
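For reference, below is a minimal standalone sketch (not part of the patch) of the behavior the new joins.slt case records. It assumes a binary that depends on the datafusion and tokio crates, reuses the `create_example_udf` helper added above (that helper is private to the sqllogictest crate, so the call is illustrative only), and reads the same window_2.csv file via a path relative to the repository root. A plain comparison predicate on the comma join is rewritten into an inner join that carries the predicate, while the scalar UDF predicate stays in a Filter above the CrossJoin, matching the recorded EXPLAIN output.

use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Register the adder UDF defined in test_context.rs and the CSV-backed table.
    ctx.register_udf(create_example_udf());
    ctx.register_csv(
        "annotated_data",
        "datafusion/core/tests/data/window_2.csv",
        CsvReadOptions::new().has_header(true),
    )
    .await?;

    // Pushed down: the cross join becomes an inner join with `t1.a > t2.a`
    // as its join filter (a NestedLoopJoinExec at the physical level).
    let pushed = ctx
        .sql("SELECT * FROM annotated_data t1, annotated_data t2 WHERE t1.a > t2.a")
        .await?;
    println!("{}", pushed.into_optimized_plan()?.display_indent());

    // Not pushed down: the UDF predicate remains a Filter over a CrossJoin,
    // as asserted by the new test above.
    let not_pushed = ctx
        .sql("SELECT * FROM annotated_data t1, annotated_data t2 WHERE example(t1.a, t2.a) > 3")
        .await?;
    println!("{}", not_pushed.into_optimized_plan()?.display_indent());

    Ok(())
}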