From 796f5f5a2a78d32e59884c632d8315440e83d92b Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Tue, 18 Jul 2023 15:58:23 +0800 Subject: [PATCH 1/6] fix: incorrect simplification of case expr (#7006) * fix: incorrect simplification of case expr * Use pattern to match against None --- .../tests/sqllogictests/test_files/scalar.slt | 12 +++++++ .../simplify_expressions/expr_simplifier.rs | 33 ++++++++++--------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt b/datafusion/core/tests/sqllogictests/test_files/scalar.slt index 2fe99fab192e..8c5c399c390e 100644 --- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt +++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt @@ -1090,6 +1090,18 @@ FROM t1 999 999 +# issue: https://github.com/apache/arrow-datafusion/issues/7004 +query B +select case c1 + when 'foo' then TRUE + when 'bar' then FALSE +end from t1 +---- +NULL +NULL +NULL +NULL + statement ok drop table t1 diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index d88459bcf48a..447585e70aa1 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -34,8 +34,8 @@ use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter} use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{InList, InSubquery, ScalarFunction}; use datafusion_expr::{ - and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Like, - Volatility, + and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, + Like, Volatility, }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -1069,17 +1069,20 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // // Note: the rationale for this rewrite is that the expr can then be further // simplified using the existing rules for AND/OR - Expr::Case(case) - if !case.when_then_expr.is_empty() - && case.when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number - && info.is_boolean_type(&case.when_then_expr[0].1)? => + Expr::Case(Case { + expr: None, + when_then_expr, + else_expr, + }) if !when_then_expr.is_empty() + && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number + && info.is_boolean_type(&when_then_expr[0].1)? 
=> { // The disjunction of all the when predicates encountered so far let mut filter_expr = lit(false); // The disjunction of all the cases let mut out_expr = lit(false); - for (when, then) in case.when_then_expr { + for (when, then) in when_then_expr { let case_expr = when .as_ref() .clone() @@ -1090,7 +1093,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { filter_expr = filter_expr.or(*when); } - if let Some(else_expr) = case.else_expr { + if let Some(else_expr) = else_expr { let case_expr = filter_expr.not().and(*else_expr); out_expr = out_expr.or(case_expr); } @@ -2819,9 +2822,9 @@ mod tests { #[test] fn simplify_expr_case_when_then_else() { - // CASE WHERE c2 != false THEN "ok" == "not_ok" ELSE c2 == true + // CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true // --> - // CASE WHERE c2 THEN false ELSE c2 + // CASE WHEN c2 THEN false ELSE c2 // --> // false assert_eq!( @@ -2836,9 +2839,9 @@ mod tests { col("c2").not().and(col("c2")) // #1716 ); - // CASE WHERE c2 != false THEN "ok" == "ok" ELSE c2 + // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2 // --> - // CASE WHERE c2 THEN true ELSE c2 + // CASE WHEN c2 THEN true ELSE c2 // --> // c2 // @@ -2856,7 +2859,7 @@ mod tests { col("c2").or(col("c2").not().and(col("c2"))) // #1716 ); - // CASE WHERE ISNULL(c2) THEN true ELSE c2 + // CASE WHEN ISNULL(c2) THEN true ELSE c2 // --> // ISNULL(c2) OR c2 // @@ -2873,7 +2876,7 @@ mod tests { .or(col("c2").is_not_null().and(col("c2"))) ); - // CASE WHERE c1 then true WHERE c2 then false ELSE true + // CASE WHEN c1 then true WHEN c2 then false ELSE true // --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE) // --> c1 OR (NOT(c1) AND NOT(c2)) // --> c1 OR NOT(c2) @@ -2892,7 +2895,7 @@ mod tests { col("c1").or(col("c1").not().and(col("c2").not())) ); - // CASE WHERE c1 then true WHERE c2 then true ELSE false + // CASE WHEN c1 then true WHEN c2 then true ELSE false // --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE) // --> c1 OR (NOT(c1) AND c2) // --> c1 OR c2 From 63c2943d9c05c9a12e00a30d52e02cd2112938bb Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 18 Jul 2023 03:59:24 -0400 Subject: [PATCH 2/6] Minor: Add String/Binary aggregate tests (#6962) --- .../sqllogictests/test_files/aggregate.slt | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt index 72b9e8400b61..c8b8960fd971 100644 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt @@ -1889,6 +1889,121 @@ drop table t_source; statement ok drop table t; + +# aggregates on strings +statement ok +create table t_source +as values + ('Foo', 1), + ('Bar', 2), + (null, 2), + ('Baz', 1); + +statement ok +create table t as +select + arrow_cast(column1, 'Utf8') as utf8, + arrow_cast(column1, 'LargeUtf8') as largeutf8, + column2 as tag +from t_source; + +# No groupy +query TTITTI +SELECT + min(utf8), + max(utf8), + count(utf8), + min(largeutf8), + max(largeutf8), + count(largeutf8) +FROM t +---- +Bar Foo 3 Bar Foo 3 + + +# with groupby +query TTITTI +SELECT + min(utf8), + max(utf8), + count(utf8), + min(largeutf8), + max(largeutf8), + count(largeutf8) +FROM t +GROUP BY tag +ORDER BY tag +---- +Baz Foo 2 Baz Foo 2 +Bar Bar 1 Bar Bar 1 + + +statement ok +drop table t_source; + +statement ok +drop table t; + + +# aggregates on binary +statement ok +create table t_source 
+as values + ('Foo', 1), + ('Bar', 2), + (null, 2), + ('Baz', 1); + +statement ok +create table t as +select + arrow_cast(column1, 'Binary') as binary, + arrow_cast(column1, 'LargeBinary') as largebinary, + column2 as tag +from t_source; + +# No groupy +query error DataFusion error: Internal error: Min/Max accumulator not implemented for type Binary\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +SELECT + min(binary), + max(binary), + count(binary), + min(largebinary), + max(largebinary), + count(largebinary) +FROM t + + +# with groupby +query error DataFusion error: External error: Internal error: Min/Max accumulator not implemented for type Binary\. This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker +SELECT + min(binary), + max(binary), + count(binary), + min(largebinary), + max(largebinary), + count(largebinary) +FROM t +GROUP BY tag +ORDER BY tag + + + +statement ok +drop table t_source; + +statement ok +drop table t; + + + + +statement error DataFusion error: Execution error: Table 't_source' doesn't exist\. +drop table t_source; + +statement error DataFusion error: Execution error: Table 't' doesn't exist\. +drop table t; + query I select median(a) from (select 1 as a where 1=0); ---- From 4d93b6a3802151865b68967bdc4c7d7ef425b49a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Tue, 18 Jul 2023 12:46:35 +0300 Subject: [PATCH 3/6] [MINOR] Supporting repartition joins conf in SHJ (#6998) * Initial * Update mod.rs * Add fmt_as and refactor tests --- .../src/physical_optimizer/pipeline_fixer.rs | 5 ++- .../core/src/physical_plan/joins/mod.rs | 9 ++++ .../joins/symmetric_hash_join.rs | 41 ++++++++++++------- 3 files changed, 40 insertions(+), 15 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index caae7743450d..7db3e99c3920 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -27,7 +27,9 @@ use crate::error::Result; use crate::physical_optimizer::join_selection::swap_hash_join; use crate::physical_optimizer::pipeline_checker::PipelineStatePropagator; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SymmetricHashJoinExec}; +use crate::physical_plan::joins::{ + HashJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, +}; use crate::physical_plan::ExecutionPlan; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::DataFusionError; @@ -101,6 +103,7 @@ fn hash_join_convert_symmetric_subrule( hash_join.filter().cloned(), hash_join.join_type(), hash_join.null_equals_null(), + StreamJoinPartitionMode::Partitioned, ) .map(|exec| { input.plan = Arc::new(exec) as _; diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs index 0a1bc147b80c..fd805fa2018a 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/core/src/physical_plan/joins/mod.rs @@ -42,3 +42,12 @@ pub enum PartitionMode { /// It will also consider swapping the left and right inputs for the Join Auto, } + +/// Partitioning mode to use for symmetric hash join +#[derive(Hash, Clone, Copy, Debug, PartialEq, Eq)] +pub enum StreamJoinPartitionMode { 
+    /// Left/right children are partitioned using the left and right keys
+    Partitioned,
+    /// Both sides will be collected into one partition
+    SinglePartition,
+}
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index dfd38a20c087..10c9ae2c08e2 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -57,6 +57,7 @@ use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval, IntervalB
 
 use crate::physical_plan::common::SharedMemoryReservation;
 use crate::physical_plan::joins::hash_join_utils::convert_sort_expr_with_filter_schema;
+use crate::physical_plan::joins::StreamJoinPartitionMode;
 use crate::physical_plan::DisplayAs;
 use crate::physical_plan::{
     expressions::Column,
@@ -192,6 +193,8 @@ pub struct SymmetricHashJoinExec {
     column_indices: Vec<ColumnIndex>,
     /// If null_equals_null is true, null == null else null != null
     pub(crate) null_equals_null: bool,
+    /// Partition Mode
+    mode: StreamJoinPartitionMode,
 }
 
 struct IntervalCalculatorInnerState {
@@ -280,6 +283,7 @@ impl SymmetricHashJoinExec {
         filter: Option<JoinFilter>,
         join_type: &JoinType,
         null_equals_null: bool,
+        mode: StreamJoinPartitionMode,
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
@@ -324,6 +328,7 @@ impl SymmetricHashJoinExec {
             metrics: ExecutionPlanMetricsSet::new(),
             column_indices,
             null_equals_null,
+            mode,
         })
     }
 
@@ -402,8 +407,8 @@ impl DisplayAs for SymmetricHashJoinExec {
             .join(", ");
         write!(
             f,
-            "SymmetricHashJoinExec: join_type={:?}, on=[{}]{}",
-            self.join_type, on, display_filter
+            "SymmetricHashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}",
+            self.mode, self.join_type, on, display_filter
         )
     }
 }
@@ -428,16 +433,22 @@ impl ExecutionPlan for SymmetricHashJoinExec {
     }
 
     fn required_input_distribution(&self) -> Vec<Distribution> {
-        let (left_expr, right_expr) = self
-            .on
-            .iter()
-            .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
-            .unzip();
-        // TODO: This will change when we extend collected executions.
- vec![ - Distribution::HashPartitioned(left_expr), - Distribution::HashPartitioned(right_expr), - ] + match self.mode { + StreamJoinPartitionMode::Partitioned => { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + StreamJoinPartitionMode::SinglePartition => { + vec![Distribution::SinglePartition, Distribution::SinglePartition] + } + } } fn output_partitioning(&self) -> Partitioning { @@ -482,6 +493,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { self.filter.clone(), &self.join_type, self.null_equals_null, + self.mode, )?)) } @@ -1818,6 +1830,7 @@ mod tests { filter, join_type, null_equals_null, + StreamJoinPartitionMode::Partitioned, )?; let mut batches = vec![]; @@ -2636,7 +2649,7 @@ mod tests { let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let expected = { [ - "SymmetricHashJoinExec: join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", + "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1", // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", @@ -2689,7 +2702,7 @@ mod tests { let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let expected = { [ - "SymmetricHashJoinExec: join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", + "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", " CoalesceBatchesExec: target_batch_size=8192", " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=1", // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", From 58ad10c07a3ae7abc19b3b0899f19f1f069030c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Tue, 18 Jul 2023 23:56:16 +0300 Subject: [PATCH 4/6] [MINOR] Code refactor on hash join utils (#6999) * Code refactor * Remove mode * Update test_utils.rs --- .../physical_plan/joins/hash_join_utils.rs | 332 +++- .../core/src/physical_plan/joins/mod.rs | 3 + .../joins/symmetric_hash_join.rs | 1498 ++++------------- .../src/physical_plan/joins/test_utils.rs | 513 ++++++ 4 files changed, 1199 insertions(+), 1147 deletions(-) create mode 100644 datafusion/core/src/physical_plan/joins/test_utils.rs diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs index 1b9cbd543d73..37790e6bb8a6 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -22,17 +22,25 @@ use std::collections::HashMap; use std::sync::Arc; use std::{fmt, usize}; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{ArrowNativeType, SchemaRef}; +use arrow::compute::concat_batches; +use arrow_array::builder::BooleanBufferBuilder; +use 
arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch};
 use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{DataFusionError, ScalarValue};
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::intervals::Interval;
+use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval, IntervalBound};
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 use hashbrown::raw::RawTable;
+use hashbrown::HashSet;
+use parking_lot::Mutex;
 use smallvec::SmallVec;
+use std::fmt::{Debug, Formatter};
 
 use crate::physical_plan::joins::utils::{JoinFilter, JoinSide};
+use crate::physical_plan::ExecutionPlan;
 use datafusion_common::Result;
 
 // Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value.
@@ -280,6 +288,105 @@ fn convert_filter_columns(
     })
 }
 
+#[derive(Default)]
+pub struct IntervalCalculatorInnerState {
+    /// Expression graph for interval calculations
+    graph: Option<ExprIntervalGraph>,
+    sorted_exprs: Vec<Option<SortedFilterExpr>>,
+    calculated: bool,
+}
+
+impl Debug for IntervalCalculatorInnerState {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(f, "Exprs({:?})", self.sorted_exprs)
+    }
+}
+
+pub fn build_filter_expression_graph(
+    interval_state: &Arc<Mutex<IntervalCalculatorInnerState>>,
+    left: &Arc<dyn ExecutionPlan>,
+    right: &Arc<dyn ExecutionPlan>,
+    filter: &JoinFilter,
+) -> Result<(
+    Option<SortedFilterExpr>,
+    Option<SortedFilterExpr>,
+    Option<ExprIntervalGraph>,
+)> {
+    // Lock the mutex of the interval state:
+    let mut filter_state = interval_state.lock();
+    // If this is the first partition to be invoked, then we need to initialize our state
+    // (the expression graph for pruning, sorted filter expressions etc.)
+    if !filter_state.calculated {
+        // Interval calculations require each column to exhibit monotonicity
+        // independently. However, a `PhysicalSortExpr` object defines a
+        // lexicographical ordering, so we can only use their first elements.
+        // when deducing column monotonicities.
+        // TODO: Extend the `PhysicalSortExpr` mechanism to express independent
+        // (i.e. simultaneous) ordering properties of columns.
+ + // Build sorted filter expressions for the left and right join side: + let join_sides = [JoinSide::Left, JoinSide::Right]; + let children = [left, right]; + for (join_side, child) in join_sides.iter().zip(children.iter()) { + let sorted_expr = child + .output_ordering() + .and_then(|orders| { + build_filter_input_order( + *join_side, + filter, + &child.schema(), + &orders[0], + ) + .transpose() + }) + .transpose()?; + + filter_state.sorted_exprs.push(sorted_expr); + } + + // Collect available sorted filter expressions: + let sorted_exprs_size = filter_state.sorted_exprs.len(); + let mut sorted_exprs = filter_state + .sorted_exprs + .iter_mut() + .flatten() + .collect::>(); + + // Create the expression graph if we can create sorted filter expressions for both children: + filter_state.graph = if sorted_exprs.len() == sorted_exprs_size { + let mut graph = ExprIntervalGraph::try_new(filter.expression().clone())?; + + // Gather filter expressions: + let filter_exprs = sorted_exprs + .iter() + .map(|sorted_expr| sorted_expr.filter_expr().clone()) + .collect::>(); + + // Gather node indices of converted filter expressions in `SortedFilterExpr`s + // using the filter columns vector: + let child_node_indices = graph.gather_node_indices(&filter_exprs); + + // Update SortedFilterExpr instances with the corresponding node indices: + for (sorted_expr, (_, index)) in + sorted_exprs.iter_mut().zip(child_node_indices.iter()) + { + sorted_expr.set_node_index(*index); + } + + Some(graph) + } else { + None + }; + filter_state.calculated = true; + } + // Return the sorted filter expressions for both sides along with the expression graph: + Ok(( + filter_state.sorted_exprs[0].clone(), + filter_state.sorted_exprs[1].clone(), + filter_state.graph.as_ref().cloned(), + )) +} + /// The [SortedFilterExpr] object represents a sorted filter expression. It /// contains the following information: The origin expression, the filter /// expression, an interval encapsulating expression bounds, and a stable @@ -341,6 +448,227 @@ impl SortedFilterExpr { } } +/// Calculate the filter expression intervals. +/// +/// This function updates the `interval` field of each `SortedFilterExpr` based +/// on the first or the last value of the expression in `build_input_buffer` +/// and `probe_batch`. +/// +/// # Arguments +/// +/// * `build_input_buffer` - The [RecordBatch] on the build side of the join. +/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. +/// * `probe_batch` - The `RecordBatch` on the probe side of the join. +/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. +/// +/// ### Note +/// ```text +/// +/// Interval arithmetic is used to calculate viable join ranges for build-side +/// pruning. This is done by first creating an interval for join filter values in +/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the +/// ordering (descending/ascending) of the filter expression. Here, FV denotes the +/// first value on the build side. This range is then compared with the probe side +/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering +/// (ascending/descending) of the probe side. Here, LV denotes the last value on +/// the probe side. +/// +/// As a concrete example, consider the following query: +/// +/// SELECT * FROM left_table, right_table +/// WHERE +/// left_key = right_key AND +/// a > b - 3 AND +/// a < b + 10 +/// +/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// respectively. 
When a new `RecordBatch` arrives at the right side, the +/// condition a > b - 3 will possibly indicate a prunable range for the left +/// side. Conversely, when a new `RecordBatch` arrives at the left side, the +/// condition a < b + 10 will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new RecordBatch` arrives at the right +/// side (i.e. when the left side is the build side): +/// +/// Build Probe +/// +-------+ +-------+ +/// | a | z | | b | y | +/// |+--|--+| |+--|--+| +/// | 1 | 2 | | 4 | 3 | +/// |+--|--+| |+--|--+| +/// | 3 | 1 | | 4 | 3 | +/// |+--|--+| |+--|--+| +/// | 5 | 7 | | 6 | 1 | +/// |+--|--+| |+--|--+| +/// | 7 | 1 | | 6 | 3 | +/// +-------+ +-------+ +/// +/// In this case, the interval representing viable (i.e. joinable) values for +/// column "a" is [1, ∞], and the interval representing possible future values +/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// intervals for the whole filter expression and propagate join constraint by +/// traversing the expression graph. +/// ``` +pub fn calculate_filter_expr_intervals( + build_input_buffer: &RecordBatch, + build_sorted_filter_expr: &mut SortedFilterExpr, + probe_batch: &RecordBatch, + probe_sorted_filter_expr: &mut SortedFilterExpr, +) -> Result<()> { + // If either build or probe side has no data, return early: + if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { + return Ok(()); + } + // Calculate the interval for the build side filter expression (if present): + update_filter_expr_interval( + &build_input_buffer.slice(0, 1), + build_sorted_filter_expr, + )?; + // Calculate the interval for the probe side filter expression (if present): + update_filter_expr_interval( + &probe_batch.slice(probe_batch.num_rows() - 1, 1), + probe_sorted_filter_expr, + ) +} + +/// This is a subroutine of the function [`calculate_filter_expr_intervals`]. +/// It constructs the current interval using the given `batch` and updates +/// the filter expression (i.e. `sorted_expr`) with this interval. +pub fn update_filter_expr_interval( + batch: &RecordBatch, + sorted_expr: &mut SortedFilterExpr, +) -> Result<()> { + // Evaluate the filter expression and convert the result to an array: + let array = sorted_expr + .origin_sorted_expr() + .expr + .evaluate(batch)? + .into_array(1); + // Convert the array to a ScalarValue: + let value = ScalarValue::try_from_array(&array, 0)?; + // Create a ScalarValue representing positive or negative infinity for the same data type: + let unbounded = IntervalBound::make_unbounded(value.get_datatype())?; + // Update the interval with lower and upper bounds based on the sort option: + let interval = if sorted_expr.origin_sorted_expr().options.descending { + Interval::new(unbounded, IntervalBound::new(value, false)) + } else { + Interval::new(IntervalBound::new(value, false), unbounded) + }; + // Set the calculated interval for the sorted filter expression: + sorted_expr.set_interval(interval); + Ok(()) +} + +/// Get the anti join indices from the visited hash set. +/// +/// This method returns the indices from the original input that were not present in the visited hash set. +/// +/// # Arguments +/// +/// * `prune_length` - The length of the pruned record batch. +/// * `deleted_offset` - The offset to the indices. +/// * `visited_rows` - The hash set of visited indices. +/// +/// # Returns +/// +/// A `PrimitiveArray` of the anti join indices. 
+pub fn get_pruning_anti_indices<T: ArrowPrimitiveType>(
+    prune_length: usize,
+    deleted_offset: usize,
+    visited_rows: &HashSet<usize>,
+) -> PrimitiveArray<T>
+where
+    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
+{
+    let mut bitmap = BooleanBufferBuilder::new(prune_length);
+    bitmap.append_n(prune_length, false);
+    // mark the indices as true if they are present in the visited hash set
+    for v in 0..prune_length {
+        let row = v + deleted_offset;
+        bitmap.set_bit(v, visited_rows.contains(&row));
+    }
+    // get the anti index
+    (0..prune_length)
+        .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
+        .collect()
+}
+
+/// This method creates a boolean buffer from the visited rows hash set
+/// and the indices of the pruned record batch slice.
+///
+/// It gets the indices from the original input that were present in the visited hash set.
+///
+/// # Arguments
+///
+/// * `prune_length` - The length of the pruned record batch.
+/// * `deleted_offset` - The offset to the indices.
+/// * `visited_rows` - The hash set of visited indices.
+///
+/// # Returns
+///
+/// A [PrimitiveArray] of the specified type T, containing the semi indices.
+pub fn get_pruning_semi_indices<T: ArrowPrimitiveType>(
+    prune_length: usize,
+    deleted_offset: usize,
+    visited_rows: &HashSet<usize>,
+) -> PrimitiveArray<T>
+where
+    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
+{
+    let mut bitmap = BooleanBufferBuilder::new(prune_length);
+    bitmap.append_n(prune_length, false);
+    // mark the indices as true if they are present in the visited hash set
+    (0..prune_length).for_each(|v| {
+        let row = &(v + deleted_offset);
+        bitmap.set_bit(v, visited_rows.contains(row));
+    });
+    // get the semi index
+    (0..prune_length)
+        .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx)))
+        .collect::<PrimitiveArray<T>>()
+}
+
+pub fn combine_two_batches(
+    output_schema: &SchemaRef,
+    left_batch: Option<RecordBatch>,
+    right_batch: Option<RecordBatch>,
+) -> Result<Option<RecordBatch>> {
+    match (left_batch, right_batch) {
+        (Some(batch), None) | (None, Some(batch)) => {
+            // If only one of the batches are present, return it:
+            Ok(Some(batch))
+        }
+        (Some(left_batch), Some(right_batch)) => {
+            // If both batches are present, concatenate them:
+            concat_batches(output_schema, &[left_batch, right_batch])
+                .map_err(DataFusionError::ArrowError)
+                .map(Some)
+        }
+        (None, None) => {
+            // If neither is present, return an empty batch:
+            Ok(None)
+        }
+    }
+}
+
+/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`.
+/// This function will insert the indices (offset by `offset`) into the `visited` hash set.
+///
+/// # Arguments
+///
+/// * `visited` - A hash set to store the visited indices.
+/// * `offset` - An offset to the indices in the `PrimitiveArray`.
+/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded.
+/// +pub fn record_visited_indices( + visited: &mut HashSet, + offset: usize, + indices: &PrimitiveArray, +) { + for i in indices.values() { + visited.insert(i.as_usize() + offset); + } +} + #[cfg(test)] pub mod tests { use super::*; diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs index fd805fa2018a..19f10d06e1ef 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/core/src/physical_plan/joins/mod.rs @@ -31,6 +31,9 @@ mod sort_merge_join; mod symmetric_hash_join; pub mod utils; +#[cfg(test)] +pub mod test_utils; + #[derive(Clone, Copy, Debug, PartialEq, Eq)] /// Partitioning mode to use for hash join pub enum PartitionMode { diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 10c9ae2c08e2..1818e4b91c1b 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -27,19 +27,16 @@ use std::collections::{HashMap, VecDeque}; use std::fmt; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::sync::Arc; use std::task::Poll; use std::vec; use std::{any::Any, usize}; use ahash::RandomState; -use arrow::array::{ - ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, - PrimitiveBuilder, -}; +use arrow::array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, PrimitiveBuilder}; use arrow::compute::concat_batches; -use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; +use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_array::builder::{UInt32BufferBuilder, UInt64BufferBuilder}; use arrow_array::{UInt32Array, UInt64Array}; @@ -51,19 +48,22 @@ use hashbrown::HashSet; use parking_lot::Mutex; use smallvec::smallvec; -use datafusion_common::{utils::bisect, ScalarValue}; use datafusion_execution::memory_pool::MemoryConsumer; -use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval, IntervalBound}; +use datafusion_physical_expr::intervals::ExprIntervalGraph; use crate::physical_plan::common::SharedMemoryReservation; -use crate::physical_plan::joins::hash_join_utils::convert_sort_expr_with_filter_schema; +use crate::physical_plan::joins::hash_join_utils::{ + build_filter_expression_graph, calculate_filter_expr_intervals, combine_two_batches, + convert_sort_expr_with_filter_schema, get_pruning_anti_indices, + get_pruning_semi_indices, record_visited_indices, IntervalCalculatorInnerState, +}; use crate::physical_plan::joins::StreamJoinPartitionMode; use crate::physical_plan::DisplayAs; use crate::physical_plan::{ expressions::Column, expressions::PhysicalSortExpr, joins::{ - hash_join_utils::{build_filter_input_order, SortedFilterExpr}, + hash_join_utils::SortedFilterExpr, utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, combine_join_equivalence_properties, partitioned_join_output_partitioning, @@ -74,6 +74,7 @@ use crate::physical_plan::{ DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use datafusion_common::utils::bisect; use datafusion_common::JoinType; use datafusion_common::{DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -197,19 +198,6 @@ pub struct SymmetricHashJoinExec { mode: StreamJoinPartitionMode, } -struct IntervalCalculatorInnerState { - /// Expression graph for interval calculations - graph: 
Option, - sorted_exprs: Vec>, - calculated: bool, -} - -impl Debug for IntervalCalculatorInnerState { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "Exprs({:?})", self.sorted_exprs) - } -} - #[derive(Debug)] struct SymmetricHashJoinSideMetrics { /// Number of batches consumed by this operator @@ -306,11 +294,7 @@ impl SymmetricHashJoinExec { let random_state = RandomState::with_seeds(0, 0, 0, 0); let filter_state = if filter.is_some() { - let inner_state = IntervalCalculatorInnerState { - graph: None, - sorted_exprs: vec![], - calculated: false, - }; + let inner_state = IntervalCalculatorInnerState::default(); Some(Arc::new(Mutex::new(inner_state))) } else { None @@ -523,83 +507,12 @@ impl ExecutionPlan for SymmetricHashJoinExec { // for both sides, and build an expression graph if one is not already built. let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match (&self.filter_state, &self.filter) { - (Some(interval_state), Some(filter)) => { - // Lock the mutex of the interval state: - let mut filter_state = interval_state.lock(); - // If this is the first partition to be invoked, then we need to initialize our state - // (the expression graph for pruning, sorted filter expressions etc.) - if !filter_state.calculated { - // Interval calculations require each column to exhibit monotonicity - // independently. However, a `PhysicalSortExpr` object defines a - // lexicographical ordering, so we can only use their first elements. - // when deducing column monotonicities. - // TODO: Extend the `PhysicalSortExpr` mechanism to express independent - // (i.e. simultaneous) ordering properties of columns. - - // Build sorted filter expressions for the left and right join side: - let join_sides = [JoinSide::Left, JoinSide::Right]; - let children = [&self.left, &self.right]; - for (join_side, child) in join_sides.iter().zip(children.iter()) { - let sorted_expr = child - .output_ordering() - .and_then(|orders| { - build_filter_input_order( - *join_side, - filter, - &child.schema(), - &orders[0], - ) - .transpose() - }) - .transpose()?; - - filter_state.sorted_exprs.push(sorted_expr); - } - - // Collect available sorted filter expressions: - let sorted_exprs_size = filter_state.sorted_exprs.len(); - let mut sorted_exprs = filter_state - .sorted_exprs - .iter_mut() - .flatten() - .collect::>(); - - // Create the expression graph if we can create sorted filter expressions for both children: - filter_state.graph = if sorted_exprs.len() == sorted_exprs_size { - let mut graph = - ExprIntervalGraph::try_new(filter.expression().clone())?; - - // Gather filter expressions: - let filter_exprs = sorted_exprs - .iter() - .map(|sorted_expr| sorted_expr.filter_expr().clone()) - .collect::>(); - - // Gather node indices of converted filter expressions in `SortedFilterExpr`s - // using the filter columns vector: - let child_node_indices = - graph.gather_node_indices(&filter_exprs); - - // Update SortedFilterExpr instances with the corresponding node indices: - for (sorted_expr, (_, index)) in - sorted_exprs.iter_mut().zip(child_node_indices.iter()) - { - sorted_expr.set_node_index(*index); - } - - Some(graph) - } else { - None - }; - filter_state.calculated = true; - } - // Return the sorted filter expressions for both sides along with the expression graph: - ( - filter_state.sorted_exprs[0].clone(), - filter_state.sorted_exprs[1].clone(), - filter_state.graph.as_ref().cloned(), - ) - } + (Some(interval_state), Some(filter)) => build_filter_expression_graph( + interval_state, + 
&self.left, + &self.right, + filter, + )?, // If `filter_state` or `filter` is not present, then return None for all three values: (_, _) => (None, None, None), }; @@ -742,116 +655,6 @@ fn prune_hash_values( Ok(()) } -/// Calculate the filter expression intervals. -/// -/// This function updates the `interval` field of each `SortedFilterExpr` based -/// on the first or the last value of the expression in `build_input_buffer` -/// and `probe_batch`. -/// -/// # Arguments -/// -/// * `build_input_buffer` - The [RecordBatch] on the build side of the join. -/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. -/// * `probe_batch` - The `RecordBatch` on the probe side of the join. -/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. -/// -/// ### Note -/// ```text -/// -/// Interval arithmetic is used to calculate viable join ranges for build-side -/// pruning. This is done by first creating an interval for join filter values in -/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the -/// ordering (descending/ascending) of the filter expression. Here, FV denotes the -/// first value on the build side. This range is then compared with the probe side -/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering -/// (ascending/descending) of the probe side. Here, LV denotes the last value on -/// the probe side. -/// -/// As a concrete example, consider the following query: -/// -/// SELECT * FROM left_table, right_table -/// WHERE -/// left_key = right_key AND -/// a > b - 3 AND -/// a < b + 10 -/// -/// where columns "a" and "b" come from tables "left_table" and "right_table", -/// respectively. When a new `RecordBatch` arrives at the right side, the -/// condition a > b - 3 will possibly indicate a prunable range for the left -/// side. Conversely, when a new `RecordBatch` arrives at the left side, the -/// condition a < b + 10 will possibly indicate prunability for the right side. -/// Let’s inspect what happens when a new RecordBatch` arrives at the right -/// side (i.e. when the left side is the build side): -/// -/// Build Probe -/// +-------+ +-------+ -/// | a | z | | b | y | -/// |+--|--+| |+--|--+| -/// | 1 | 2 | | 4 | 3 | -/// |+--|--+| |+--|--+| -/// | 3 | 1 | | 4 | 3 | -/// |+--|--+| |+--|--+| -/// | 5 | 7 | | 6 | 1 | -/// |+--|--+| |+--|--+| -/// | 7 | 1 | | 6 | 3 | -/// +-------+ +-------+ -/// -/// In this case, the interval representing viable (i.e. joinable) values for -/// column "a" is [1, ∞], and the interval representing possible future values -/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate -/// intervals for the whole filter expression and propagate join constraint by -/// traversing the expression graph. 
-/// ``` -fn calculate_filter_expr_intervals( - build_input_buffer: &RecordBatch, - build_sorted_filter_expr: &mut SortedFilterExpr, - probe_batch: &RecordBatch, - probe_sorted_filter_expr: &mut SortedFilterExpr, -) -> Result<()> { - // If either build or probe side has no data, return early: - if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { - return Ok(()); - } - // Calculate the interval for the build side filter expression (if present): - update_filter_expr_interval( - &build_input_buffer.slice(0, 1), - build_sorted_filter_expr, - )?; - // Calculate the interval for the probe side filter expression (if present): - update_filter_expr_interval( - &probe_batch.slice(probe_batch.num_rows() - 1, 1), - probe_sorted_filter_expr, - ) -} - -/// This is a subroutine of the function [`calculate_filter_expr_intervals`]. -/// It constructs the current interval using the given `batch` and updates -/// the filter expression (i.e. `sorted_expr`) with this interval. -fn update_filter_expr_interval( - batch: &RecordBatch, - sorted_expr: &mut SortedFilterExpr, -) -> Result<()> { - // Evaluate the filter expression and convert the result to an array: - let array = sorted_expr - .origin_sorted_expr() - .expr - .evaluate(batch)? - .into_array(1); - // Convert the array to a ScalarValue: - let value = ScalarValue::try_from_array(&array, 0)?; - // Create a ScalarValue representing positive or negative infinity for the same data type: - let unbounded = IntervalBound::make_unbounded(value.get_datatype())?; - // Update the interval with lower and upper bounds based on the sort option: - let interval = if sorted_expr.origin_sorted_expr().options.descending { - Interval::new(unbounded, IntervalBound::new(value, false)) - } else { - Interval::new(IntervalBound::new(value, false), unbounded) - }; - // Set the calculated interval for the sorted filter expression: - sorted_expr.set_interval(interval); - Ok(()) -} - /// Determine the pruning length for `buffer`. /// /// This function evaluates the build side filter expression, converts the @@ -919,93 +722,6 @@ fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> } } -/// Get the anti join indices from the visited hash set. -/// -/// This method returns the indices from the original input that were not present in the visited hash set. -/// -/// # Arguments -/// -/// * `prune_length` - The length of the pruned record batch. -/// * `deleted_offset` - The offset to the indices. -/// * `visited_rows` - The hash set of visited indices. -/// -/// # Returns -/// -/// A `PrimitiveArray` of the anti join indices. -fn get_anti_indices( - prune_length: usize, - deleted_offset: usize, - visited_rows: &HashSet, -) -> PrimitiveArray -where - NativeAdapter: From<::Native>, -{ - let mut bitmap = BooleanBufferBuilder::new(prune_length); - bitmap.append_n(prune_length, false); - // mark the indices as true if they are present in the visited hash set - for v in 0..prune_length { - let row = v + deleted_offset; - bitmap.set_bit(v, visited_rows.contains(&row)); - } - // get the anti index - (0..prune_length) - .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect() -} - -/// This method creates a boolean buffer from the visited rows hash set -/// and the indices of the pruned record batch slice. -/// -/// It gets the indices from the original input that were present in the visited hash set. -/// -/// # Arguments -/// -/// * `prune_length` - The length of the pruned record batch. 
-/// * `deleted_offset` - The offset to the indices. -/// * `visited_rows` - The hash set of visited indices. -/// -/// # Returns -/// -/// A [PrimitiveArray] of the specified type T, containing the semi indices. -fn get_semi_indices( - prune_length: usize, - deleted_offset: usize, - visited_rows: &HashSet, -) -> PrimitiveArray -where - NativeAdapter: From<::Native>, -{ - let mut bitmap = BooleanBufferBuilder::new(prune_length); - bitmap.append_n(prune_length, false); - // mark the indices as true if they are present in the visited hash set - (0..prune_length).for_each(|v| { - let row = &(v + deleted_offset); - bitmap.set_bit(v, visited_rows.contains(row)); - }); - // get the semi index - (0..prune_length) - .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) - .collect::>() -} -/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`. -/// This function will insert the indices (offset by `offset`) into the `visited` hash set. -/// -/// # Arguments -/// -/// * `visited` - A hash set to store the visited indices. -/// * `offset` - An offset to the indices in the `PrimitiveArray`. -/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded. -/// -fn record_visited_indices( - visited: &mut HashSet, - offset: usize, - indices: &PrimitiveArray, -) { - for i in indices.values() { - visited.insert(i.as_usize() + offset); - } -} - /// Calculate indices by join type. /// /// This method returns a tuple of two arrays: build and probe indices. @@ -1040,7 +756,7 @@ where | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) | (_, JoinType::Full) => { let build_unmatched_indices = - get_anti_indices(prune_length, deleted_offset, visited_rows); + get_pruning_anti_indices(prune_length, deleted_offset, visited_rows); let mut builder = PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); builder.append_nulls(build_unmatched_indices.len()); @@ -1050,7 +766,7 @@ where // In the case of `LeftSemi` or `RightSemi` join, get the semi indices (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => { let build_unmatched_indices = - get_semi_indices(prune_length, deleted_offset, visited_rows); + get_pruning_semi_indices(prune_length, deleted_offset, visited_rows); let mut builder = PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); builder.append_nulls(build_unmatched_indices.len()); @@ -1063,25 +779,301 @@ where Ok(result) } -struct OneSideHashJoiner { +/// This function produces unmatched record results based on the build side, +/// join type and other parameters. +/// +/// The method uses first `prune_length` rows from the build side input buffer +/// to produce results. +/// +/// # Arguments +/// +/// * `output_schema` - The schema of the final output record batch. +/// * `prune_length` - The length of the determined prune length. +/// * `probe_schema` - The schema of the probe [RecordBatch]. +/// * `join_type` - The type of join to be performed. +/// * `column_indices` - Indices of columns that are being joined. +/// +/// # Returns +/// +/// * `Option` - The final output record batch if required, otherwise [None]. 
+pub(crate) fn build_side_determined_results( + build_hash_joiner: &OneSideHashJoiner, + output_schema: &SchemaRef, + prune_length: usize, + probe_schema: SchemaRef, + join_type: JoinType, + column_indices: &[ColumnIndex], +) -> Result> { + // Check if we need to produce a result in the final output: + if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { + // Calculate the indices for build and probe sides based on join type and build side: + let (build_indices, probe_indices) = calculate_indices_by_join_type( + build_hash_joiner.build_side, + prune_length, + &build_hash_joiner.visited_rows, + build_hash_joiner.deleted_offset, + join_type, + )?; + + // Create an empty probe record batch: + let empty_probe_batch = RecordBatch::new_empty(probe_schema); + // Build the final result from the indices of build and probe sides: + build_batch_from_indices( + output_schema.as_ref(), + &build_hash_joiner.input_buffer, + &empty_probe_batch, + &build_indices, + &probe_indices, + column_indices, + build_hash_joiner.build_side, + ) + .map(|batch| (batch.num_rows() > 0).then_some(batch)) + } else { + // If we don't need to produce a result, return None + Ok(None) + } +} + +/// Gets build and probe indices which satisfy the on condition (including +/// the equality condition and the join filter) in the join. +#[allow(clippy::too_many_arguments)] +pub fn build_join_indices( + probe_batch: &RecordBatch, + build_hashmap: &SymmetricJoinHashMap, + build_input_buffer: &RecordBatch, + on_build: &[Column], + on_probe: &[Column], + filter: Option<&JoinFilter>, + random_state: &RandomState, + null_equals_null: bool, + hashes_buffer: &mut Vec, + offset: Option, + build_side: JoinSide, +) -> Result<(UInt64Array, UInt32Array)> { + // Get the indices that satisfy the equality condition, like `left.a1 = right.a2` + let (build_indices, probe_indices) = build_equal_condition_join_indices( + build_hashmap, + build_input_buffer, + probe_batch, + on_build, + on_probe, + random_state, + null_equals_null, + hashes_buffer, + offset, + )?; + if let Some(filter) = filter { + // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` + apply_join_filter_to_indices( + build_input_buffer, + probe_batch, + build_indices, + probe_indices, + filter, + build_side, + ) + } else { + Ok((build_indices, probe_indices)) + } +} + +// Returns build/probe indices satisfying the equality condition. 
+// On LEFT.b1 = RIGHT.b2 +// LEFT Table: +// a1 b1 c1 +// 1 1 10 +// 3 3 30 +// 5 5 50 +// 7 7 70 +// 9 8 90 +// 11 8 110 +// 13 10 130 +// RIGHT Table: +// a2 b2 c2 +// 2 2 20 +// 4 4 40 +// 6 6 60 +// 8 8 80 +// 10 10 100 +// 12 10 120 +// The result is +// "+----+----+-----+----+----+-----+", +// "| a1 | b1 | c1 | a2 | b2 | c2 |", +// "+----+----+-----+----+----+-----+", +// "| 11 | 8 | 110 | 8 | 8 | 80 |", +// "| 13 | 10 | 130 | 10 | 10 | 100 |", +// "| 13 | 10 | 130 | 12 | 10 | 120 |", +// "| 9 | 8 | 90 | 8 | 8 | 80 |", +// "+----+----+-----+----+----+-----+" +// And the result of build and probe indices are: +// Build indices: 5, 6, 6, 4 +// Probe indices: 3, 4, 5, 3 +#[allow(clippy::too_many_arguments)] +pub fn build_equal_condition_join_indices( + build_hashmap: &SymmetricJoinHashMap, + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_on: &[Column], + probe_on: &[Column], + random_state: &RandomState, + null_equals_null: bool, + hashes_buffer: &mut Vec, + offset: Option, +) -> Result<(UInt64Array, UInt32Array)> { + let keys_values = probe_on + .iter() + .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) + .collect::>>()?; + let build_join_values = build_on + .iter() + .map(|c| { + Ok(c.evaluate(build_input_buffer)? + .into_array(build_input_buffer.num_rows())) + }) + .collect::>>()?; + hashes_buffer.clear(); + hashes_buffer.resize(probe_batch.num_rows(), 0); + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; + // Using a buffer builder to avoid slower normal builder + let mut build_indices = UInt64BufferBuilder::new(0); + let mut probe_indices = UInt32BufferBuilder::new(0); + let offset_value = offset.unwrap_or(0); + // Visit all of the probe rows + for (row, hash_value) in hash_values.iter().enumerate() { + // Get the hash and find it in the build index + // For every item on the build and probe we check if it matches + // This possibly contains rows with hash collisions, + // So we have to check here whether rows are equal or not + if let Some((_, indices)) = build_hashmap + .0 + .get(*hash_value, |(hash, _)| *hash_value == *hash) + { + for &i in indices { + // Check hash collisions + let offset_build_index = i as usize - offset_value; + // Check hash collisions + if equal_rows( + offset_build_index, + row, + &build_join_values, + &keys_values, + null_equals_null, + )? { + build_indices.append(offset_build_index as u64); + probe_indices.append(row as u32); + } + } + } + } + + Ok(( + PrimitiveArray::new(build_indices.finish().into(), None), + PrimitiveArray::new(probe_indices.finish().into(), None), + )) +} + +/// This method performs a join between the build side input buffer and the probe side batch. +/// +/// # Arguments +/// +/// * `build_hash_joiner` - Build side hash joiner +/// * `probe_hash_joiner` - Probe side hash joiner +/// * `schema` - A reference to the schema of the output record batch. +/// * `join_type` - The type of join to be performed. +/// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join. +/// * `filter` - An optional filter on the join condition. +/// * `probe_batch` - The second record batch to be joined. +/// * `column_indices` - An array of columns to be selected for the result of the join. +/// * `random_state` - The random state for the join. +/// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. 
+/// +/// # Returns +/// +/// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`. +/// If the join type is one of the above four, the function will return [None]. +#[allow(clippy::too_many_arguments)] +pub(crate) fn join_with_probe_batch( + build_hash_joiner: &mut OneSideHashJoiner, + probe_hash_joiner: &mut OneSideHashJoiner, + schema: &SchemaRef, + join_type: JoinType, + filter: Option<&JoinFilter>, + probe_batch: &RecordBatch, + column_indices: &[ColumnIndex], + random_state: &RandomState, + null_equals_null: bool, +) -> Result> { + if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { + return Ok(None); + } + let (build_indices, probe_indices) = build_join_indices( + probe_batch, + &build_hash_joiner.hashmap, + &build_hash_joiner.input_buffer, + &build_hash_joiner.on, + &probe_hash_joiner.on, + filter, + random_state, + null_equals_null, + &mut build_hash_joiner.hashes_buffer, + Some(build_hash_joiner.deleted_offset), + build_hash_joiner.build_side, + )?; + if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { + record_visited_indices( + &mut build_hash_joiner.visited_rows, + build_hash_joiner.deleted_offset, + &build_indices, + ); + } + if need_to_produce_result_in_final(build_hash_joiner.build_side.negate(), join_type) { + record_visited_indices( + &mut probe_hash_joiner.visited_rows, + probe_hash_joiner.offset, + &probe_indices, + ); + } + if matches!( + join_type, + JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + ) { + Ok(None) + } else { + build_batch_from_indices( + schema, + &build_hash_joiner.input_buffer, + probe_batch, + &build_indices, + &probe_indices, + column_indices, + build_hash_joiner.build_side, + ) + .map(|batch| (batch.num_rows() > 0).then_some(batch)) + } +} + +pub struct OneSideHashJoiner { /// Build side build_side: JoinSide, /// Input record batch buffer - input_buffer: RecordBatch, + pub input_buffer: RecordBatch, /// Columns from the side - on: Vec, + pub(crate) on: Vec, /// Hashmap - hashmap: SymmetricJoinHashMap, + pub(crate) hashmap: SymmetricJoinHashMap, /// To optimize hash deleting in case of pruning, we hold them in memory row_hash_values: VecDeque, /// Reuse the hashes buffer - hashes_buffer: Vec, + pub(crate) hashes_buffer: Vec, /// Matched rows - visited_rows: HashSet, + pub(crate) visited_rows: HashSet, /// Offset - offset: usize, + pub(crate) offset: usize, /// Deleted offset - deleted_offset: usize, + pub(crate) deleted_offset: usize, } impl OneSideHashJoiner { @@ -1156,7 +1148,7 @@ impl OneSideHashJoiner { /// # Returns /// /// Returns a [Result] encapsulating any intermediate errors. - fn update_internal_state( + pub(crate) fn update_internal_state( &mut self, batch: &RecordBatch, random_state: &RandomState, @@ -1180,280 +1172,6 @@ impl OneSideHashJoiner { Ok(()) } - /// Gets build and probe indices which satisfy the on condition (including - /// the equality condition and the join filter) in the join. 
- #[allow(clippy::too_many_arguments)] - pub fn build_join_indices( - probe_batch: &RecordBatch, - build_hashmap: &SymmetricJoinHashMap, - build_input_buffer: &RecordBatch, - on_build: &[Column], - on_probe: &[Column], - filter: Option<&JoinFilter>, - random_state: &RandomState, - null_equals_null: bool, - hashes_buffer: &mut Vec, - offset: Option, - build_side: JoinSide, - ) -> Result<(UInt64Array, UInt32Array)> { - // Get the indices that satisfy the equality condition, like `left.a1 = right.a2` - let (build_indices, probe_indices) = Self::build_equal_condition_join_indices( - build_hashmap, - build_input_buffer, - probe_batch, - on_build, - on_probe, - random_state, - null_equals_null, - hashes_buffer, - offset, - )?; - if let Some(filter) = filter { - // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10` - apply_join_filter_to_indices( - build_input_buffer, - probe_batch, - build_indices, - probe_indices, - filter, - build_side, - ) - } else { - Ok((build_indices, probe_indices)) - } - } - - // Returns build/probe indices satisfying the equality condition. - // On LEFT.b1 = RIGHT.b2 - // LEFT Table: - // a1 b1 c1 - // 1 1 10 - // 3 3 30 - // 5 5 50 - // 7 7 70 - // 9 8 90 - // 11 8 110 - // 13 10 130 - // RIGHT Table: - // a2 b2 c2 - // 2 2 20 - // 4 4 40 - // 6 6 60 - // 8 8 80 - // 10 10 100 - // 12 10 120 - // The result is - // "+----+----+-----+----+----+-----+", - // "| a1 | b1 | c1 | a2 | b2 | c2 |", - // "+----+----+-----+----+----+-----+", - // "| 11 | 8 | 110 | 8 | 8 | 80 |", - // "| 13 | 10 | 130 | 10 | 10 | 100 |", - // "| 13 | 10 | 130 | 12 | 10 | 120 |", - // "| 9 | 8 | 90 | 8 | 8 | 80 |", - // "+----+----+-----+----+----+-----+" - // And the result of build and probe indices are: - // Build indices: 5, 6, 6, 4 - // Probe indices: 3, 4, 5, 3 - #[allow(clippy::too_many_arguments)] - pub fn build_equal_condition_join_indices( - build_hashmap: &SymmetricJoinHashMap, - build_input_buffer: &RecordBatch, - probe_batch: &RecordBatch, - build_on: &[Column], - probe_on: &[Column], - random_state: &RandomState, - null_equals_null: bool, - hashes_buffer: &mut Vec, - offset: Option, - ) -> Result<(UInt64Array, UInt32Array)> { - let keys_values = probe_on - .iter() - .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) - .collect::>>()?; - let build_join_values = build_on - .iter() - .map(|c| { - Ok(c.evaluate(build_input_buffer)? - .into_array(build_input_buffer.num_rows())) - }) - .collect::>>()?; - hashes_buffer.clear(); - hashes_buffer.resize(probe_batch.num_rows(), 0); - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - // Using a buffer builder to avoid slower normal builder - let mut build_indices = UInt64BufferBuilder::new(0); - let mut probe_indices = UInt32BufferBuilder::new(0); - let offset_value = offset.unwrap_or(0); - // Visit all of the probe rows - for (row, hash_value) in hash_values.iter().enumerate() { - // Get the hash and find it in the build index - // For every item on the build and probe we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some((_, indices)) = build_hashmap - .0 - .get(*hash_value, |(hash, _)| *hash_value == *hash) - { - for &i in indices { - // Check hash collisions - let offset_build_index = i as usize - offset_value; - // Check hash collisions - if equal_rows( - offset_build_index, - row, - &build_join_values, - &keys_values, - null_equals_null, - )? 
{ - build_indices.append(offset_build_index as u64); - probe_indices.append(row as u32); - } - } - } - } - - Ok(( - PrimitiveArray::new(build_indices.finish().into(), None), - PrimitiveArray::new(probe_indices.finish().into(), None), - )) - } - - /// This method performs a join between the build side input buffer and the probe side batch. - /// - /// # Arguments - /// - /// * `schema` - A reference to the schema of the output record batch. - /// * `join_type` - The type of join to be performed. - /// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join. - /// * `filter` - An optional filter on the join condition. - /// * `probe_batch` - The second record batch to be joined. - /// * `probe_visited` - A hash set to store the visited indices from the probe batch. - /// * `probe_offset` - The offset of the probe side for visited indices calculations. - /// * `column_indices` - An array of columns to be selected for the result of the join. - /// * `random_state` - The random state for the join. - /// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. - /// - /// # Returns - /// - /// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`. - /// If the join type is one of the above four, the function will return [None]. - #[allow(clippy::too_many_arguments)] - fn join_with_probe_batch( - &mut self, - schema: &SchemaRef, - join_type: JoinType, - on_probe: &[Column], - filter: Option<&JoinFilter>, - probe_batch: &RecordBatch, - probe_visited: &mut HashSet, - probe_offset: usize, - column_indices: &[ColumnIndex], - random_state: &RandomState, - null_equals_null: bool, - ) -> Result> { - if self.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { - return Ok(None); - } - let (build_indices, probe_indices) = Self::build_join_indices( - probe_batch, - &self.hashmap, - &self.input_buffer, - &self.on, - on_probe, - filter, - random_state, - null_equals_null, - &mut self.hashes_buffer, - Some(self.deleted_offset), - self.build_side, - )?; - if need_to_produce_result_in_final(self.build_side, join_type) { - record_visited_indices( - &mut self.visited_rows, - self.deleted_offset, - &build_indices, - ); - } - if need_to_produce_result_in_final(self.build_side.negate(), join_type) { - record_visited_indices(probe_visited, probe_offset, &probe_indices); - } - if matches!( - join_type, - JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftSemi - | JoinType::RightSemi - ) { - Ok(None) - } else { - build_batch_from_indices( - schema, - &self.input_buffer, - probe_batch, - &build_indices, - &probe_indices, - column_indices, - self.build_side, - ) - .map(|batch| (batch.num_rows() > 0).then_some(batch)) - } - } - - /// This function produces unmatched record results based on the build side, - /// join type and other parameters. - /// - /// The method uses first `prune_length` rows from the build side input buffer - /// to produce results. - /// - /// # Arguments - /// - /// * `output_schema` - The schema of the final output record batch. - /// * `prune_length` - The length of the determined prune length. - /// * `probe_schema` - The schema of the probe [RecordBatch]. - /// * `join_type` - The type of join to be performed. - /// * `column_indices` - Indices of columns that are being joined. - /// - /// # Returns - /// - /// * `Option` - The final output record batch if required, otherwise [None]. 
- fn build_side_determined_results( - &self, - output_schema: &SchemaRef, - prune_length: usize, - probe_schema: SchemaRef, - join_type: JoinType, - column_indices: &[ColumnIndex], - ) -> Result> { - // Check if we need to produce a result in the final output: - if need_to_produce_result_in_final(self.build_side, join_type) { - // Calculate the indices for build and probe sides based on join type and build side: - let (build_indices, probe_indices) = calculate_indices_by_join_type( - self.build_side, - prune_length, - &self.visited_rows, - self.deleted_offset, - join_type, - )?; - - // Create an empty probe record batch: - let empty_probe_batch = RecordBatch::new_empty(probe_schema); - // Build the final result from the indices of build and probe sides: - build_batch_from_indices( - output_schema.as_ref(), - &self.input_buffer, - &empty_probe_batch, - &build_indices, - &probe_indices, - column_indices, - self.build_side, - ) - .map(|batch| (batch.num_rows() > 0).then_some(batch)) - } else { - // If we don't need to produce a result, return None - Ok(None) - } - } - /// Prunes the internal buffer. /// /// Argument `probe_batch` is used to update the intervals of the sorted @@ -1475,7 +1193,7 @@ impl OneSideHashJoiner { /// /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. /// Otherwise, returns `Ok(None)`. - fn calculate_prune_length_with_probe_batch( + pub(crate) fn calculate_prune_length_with_probe_batch( &mut self, build_side_sorted_filter_expr: &mut SortedFilterExpr, probe_side_sorted_filter_expr: &mut SortedFilterExpr, @@ -1508,22 +1226,7 @@ impl OneSideHashJoiner { determine_prune_length(&self.input_buffer, build_side_sorted_filter_expr) } - fn prune_internal_state_and_build_anti_result( - &mut self, - prune_length: usize, - schema: &SchemaRef, - probe_batch: &RecordBatch, - join_type: JoinType, - column_indices: &[ColumnIndex], - ) -> Result> { - // Compute the result and perform pruning if there are rows to prune: - let result = self.build_side_determined_results( - schema, - prune_length, - probe_batch.schema(), - join_type, - column_indices, - ); + pub(crate) fn prune_internal_state(&mut self, prune_length: usize) -> Result<()> { // Prune the hash values: prune_hash_values( prune_length, @@ -1541,30 +1244,7 @@ impl OneSideHashJoiner { .slice(prune_length, self.input_buffer.num_rows() - prune_length); // Increment the deleted offset: self.deleted_offset += prune_length; - result - } -} - -fn combine_two_batches( - output_schema: &SchemaRef, - left_batch: Option, - right_batch: Option, -) -> Result> { - match (left_batch, right_batch) { - (Some(batch), None) | (None, Some(batch)) => { - // If only one of the batches are present, return it: - Ok(Some(batch)) - } - (Some(left_batch), Some(right_batch)) => { - // If both batches are present, concatenate them: - concat_batches(output_schema, &[left_batch, right_batch]) - .map_err(DataFusionError::ArrowError) - .map(Some) - } - (None, None) => { - // If neither is present, return an empty batch: - Ok(None) - } + Ok(()) } } @@ -1634,14 +1314,13 @@ impl SymmetricHashJoinStream { probe_hash_joiner .update_internal_state(&probe_batch, &self.random_state)?; // Join the two sides: - let equal_result = build_hash_joiner.join_with_probe_batch( + let equal_result = join_with_probe_batch( + build_hash_joiner, + probe_hash_joiner, &self.schema, self.join_type, - &probe_hash_joiner.on, self.filter.as_ref(), &probe_batch, - &mut probe_hash_joiner.visited_rows, - probe_hash_joiner.offset, 
&self.column_indices, &self.random_state, self.null_equals_null, @@ -1673,13 +1352,16 @@ impl SymmetricHashJoinStream { )?; if prune_length > 0 { - build_hash_joiner.prune_internal_state_and_build_anti_result( - prune_length, + let res = build_side_determined_results( + build_hash_joiner, &self.schema, - &probe_batch, + prune_length, + probe_batch.schema(), self.join_type, &self.column_indices, - )? + )?; + build_hash_joiner.prune_internal_state(prune_length)?; + res } else { None } @@ -1708,7 +1390,8 @@ impl SymmetricHashJoinStream { } self.final_result = true; // Get the left side results: - let left_result = self.left.build_side_determined_results( + let left_result = build_side_determined_results( + &self.left, &self.schema, self.left.input_buffer.num_rows(), self.right.input_buffer.schema(), @@ -1716,7 +1399,8 @@ impl SymmetricHashJoinStream { &self.column_indices, )?; // Get the right side results: - let right_result = self.right.build_side_determined_results( + let right_result = build_side_determined_results( + &self.right, &self.schema, self.right.input_buffer.num_rows(), self.left.input_buffer.schema(), @@ -1746,509 +1430,34 @@ impl SymmetricHashJoinStream { mod tests { use std::fs::File; - use arrow::array::{ArrayRef, Float64Array, IntervalDayTimeArray}; - use arrow::array::{Int32Array, TimestampMillisecondArray}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; - use arrow::util::pretty::pretty_format_batches; use rstest::*; use tempfile::TempDir; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, Column}; - use datafusion_physical_expr::intervals::test_utils::{ - gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr, - }; - use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numerical_expr; - use crate::physical_plan::joins::{ - hash_join_utils::tests::complicated_filter, HashJoinExec, PartitionMode, - }; - use crate::physical_plan::{ - common, displayable, memory::MemoryExec, repartition::RepartitionExec, - }; + use crate::physical_plan::displayable; + use crate::physical_plan::joins::hash_join_utils::tests::complicated_filter; use crate::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use crate::test_util::register_unbounded_file_with_ordering; - use super::*; - - const TABLE_SIZE: i32 = 100; - - fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { - // compare - let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); - let second_formatted = pretty_format_batches(collected_2).unwrap().to_string(); - - let mut first_formatted_sorted: Vec<&str> = - first_formatted.trim().lines().collect(); - first_formatted_sorted.sort_unstable(); - - let mut second_formatted_sorted: Vec<&str> = - second_formatted.trim().lines().collect(); - second_formatted_sorted.sort_unstable(); - - for (i, (first_line, second_line)) in first_formatted_sorted - .iter() - .zip(&second_formatted_sorted) - .enumerate() - { - assert_eq!((i, first_line), (i, second_line)); - } - } - - async fn partitioned_sym_join_with_filter( - left: Arc, - right: Arc, - on: JoinOn, - filter: Option, - join_type: &JoinType, - null_equals_null: bool, - context: Arc, - ) -> Result> { - let partition_count = 4; - - let left_expr = on - .iter() - .map(|(l, _)| Arc::new(l.clone()) as _) - .collect::>(); - - let right_expr = on - .iter() - .map(|(_, r)| Arc::new(r.clone()) as _) - .collect::>(); - - let join = 
SymmetricHashJoinExec::try_new( - Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(left_expr, partition_count), - )?), - Arc::new(RepartitionExec::try_new( - right, - Partitioning::Hash(right_expr, partition_count), - )?), - on, - filter, - join_type, - null_equals_null, - StreamJoinPartitionMode::Partitioned, - )?; - - let mut batches = vec![]; - for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; - let more_batches = common::collect(stream).await?; - batches.extend( - more_batches - .into_iter() - .filter(|b| b.num_rows() > 0) - .collect::>(), - ); - } - - Ok(batches) - } - - async fn partitioned_hash_join_with_filter( - left: Arc, - right: Arc, - on: JoinOn, - filter: Option, - join_type: &JoinType, - null_equals_null: bool, - context: Arc, - ) -> Result> { - let partition_count = 4; - - let (left_expr, right_expr) = on - .iter() - .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) - .unzip(); - - let join = HashJoinExec::try_new( - Arc::new(RepartitionExec::try_new( - left, - Partitioning::Hash(left_expr, partition_count), - )?), - Arc::new(RepartitionExec::try_new( - right, - Partitioning::Hash(right_expr, partition_count), - )?), - on, - filter, - join_type, - PartitionMode::Partitioned, - null_equals_null, - )?; - - let mut batches = vec![]; - for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; - let more_batches = common::collect(stream).await?; - batches.extend( - more_batches - .into_iter() - .filter(|b| b.num_rows() > 0) - .collect::>(), - ); - } - - Ok(batches) - } - - pub fn split_record_batches( - batch: &RecordBatch, - batch_size: usize, - ) -> Result> { - let row_num = batch.num_rows(); - let number_of_batch = row_num / batch_size; - let mut sizes = vec![batch_size; number_of_batch]; - sizes.push(row_num - (batch_size * number_of_batch)); - let mut result = vec![]; - for (i, size) in sizes.iter().enumerate() { - result.push(batch.slice(i * batch_size, *size)); - } - Ok(result) - } - - // It creates join filters for different type of fields for testing. - macro_rules! 
join_expr_tests { - ($func_name:ident, $type:ty, $SCALAR:ident) => { - fn $func_name( - expr_id: usize, - left_col: Arc, - right_col: Arc, - ) -> Arc { - match expr_id { - // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10 - 0 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Plus, - Operator::Plus, - Operator::Plus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 - 1 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Plus, - Operator::Plus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 - 2 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Plus, - Operator::Minus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10 - 3 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(10 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 10 > right_col - 5 AND left_col - 30 < right_col - 3 - 4 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Minus, - ), - ScalarValue::$SCALAR(Some(10 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(30 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - (Operator::Gt, Operator::Lt), - ), - // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 - 5 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Minus, - Operator::Plus, - Operator::Plus, - Operator::Minus, - ), - ScalarValue::$SCALAR(Some(2 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), - ScalarValue::$SCALAR(Some(7 as $type)), - ScalarValue::$SCALAR(Some(3 as $type)), - (Operator::GtEq, Operator::LtEq), - ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 - 6 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Plus, - Operator::Minus, - Operator::Plus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(28 as $type)), - ScalarValue::$SCALAR(Some(11 as $type)), - ScalarValue::$SCALAR(Some(21 as $type)), - ScalarValue::$SCALAR(Some(39 as $type)), - (Operator::Gt, Operator::LtEq), - ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 - 7 => gen_conjunctive_numerical_expr( - left_col, - right_col, - ( - Operator::Plus, - Operator::Minus, - Operator::Minus, - Operator::Plus, - ), - ScalarValue::$SCALAR(Some(28 as $type)), - ScalarValue::$SCALAR(Some(11 as $type)), - ScalarValue::$SCALAR(Some(21 as $type)), - ScalarValue::$SCALAR(Some(39 as $type)), - (Operator::GtEq, 
Operator::Lt), - ), - _ => panic!("No case"), - } - } - }; - } - - join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32); - join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64); - - use rand::rngs::StdRng; - use rand::{Rng, SeedableRng}; + use crate::physical_plan::joins::test_utils::{ + build_sides_record_batches, compare_batches, create_memory_table, + join_expr_tests_fixture_f64, join_expr_tests_fixture_i32, + join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter, + partitioned_sym_join_with_filter, + }; + use datafusion_common::ScalarValue; use std::iter::Iterator; - struct AscendingRandomFloatIterator { - prev: f64, - max: f64, - rng: StdRng, - } - - impl AscendingRandomFloatIterator { - fn new(min: f64, max: f64) -> Self { - let mut rng = StdRng::seed_from_u64(42); - let initial = rng.gen_range(min..max); - AscendingRandomFloatIterator { - prev: initial, - max, - rng, - } - } - } - - impl Iterator for AscendingRandomFloatIterator { - type Item = f64; - - fn next(&mut self) -> Option { - let value = self.rng.gen_range(self.prev..self.max); - self.prev = value; - Some(value) - } - } - - fn join_expr_tests_fixture_temporal( - expr_id: usize, - left_col: Arc, - right_col: Arc, - schema: &Schema, - ) -> Result> { - match expr_id { - // constructs ((left_col - INTERVAL '100ms') > (right_col - INTERVAL '200ms')) AND ((left_col - INTERVAL '450ms') < (right_col - INTERVAL '300ms')) - 0 => gen_conjunctive_temporal_expr( - left_col, - right_col, - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Minus, - ScalarValue::new_interval_dt(0, 100), // 100 ms - ScalarValue::new_interval_dt(0, 200), // 200 ms - ScalarValue::new_interval_dt(0, 450), // 450 ms - ScalarValue::new_interval_dt(0, 300), // 300 ms - schema, - ), - // constructs ((left_col - TIMESTAMP '2023-01-01:12.00.03') > (right_col - TIMESTAMP '2023-01-01:12.00.01')) AND ((left_col - TIMESTAMP '2023-01-01:12.00.00') < (right_col - TIMESTAMP '2023-01-01:12.00.02')) - 1 => gen_conjunctive_temporal_expr( - left_col, - right_col, - Operator::Minus, - Operator::Minus, - Operator::Minus, - Operator::Minus, - ScalarValue::TimestampMillisecond(Some(1672574403000), None), // 2023-01-01:12.00.03 - ScalarValue::TimestampMillisecond(Some(1672574401000), None), // 2023-01-01:12.00.01 - ScalarValue::TimestampMillisecond(Some(1672574400000), None), // 2023-01-01:12.00.00 - ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02 - schema, - ), - _ => unreachable!(), - } - } - fn build_sides_record_batches( - table_size: i32, - key_cardinality: (i32, i32), - ) -> Result<(RecordBatch, RecordBatch)> { - let null_ratio: f64 = 0.4; - let initial_range = 0..table_size; - let index = (table_size as f64 * null_ratio).round() as i32; - let rest_of = index..table_size; - let ordered: ArrayRef = Arc::new(Int32Array::from_iter( - initial_range.clone().collect::>(), - )); - let ordered_des = Arc::new(Int32Array::from_iter( - initial_range.clone().rev().collect::>(), - )); - let cardinality = Arc::new(Int32Array::from_iter( - initial_range.clone().map(|x| x % 4).collect::>(), - )); - let cardinality_key_left = Arc::new(Int32Array::from_iter( - initial_range - .clone() - .map(|x| x % key_cardinality.0) - .collect::>(), - )); - let cardinality_key_right = Arc::new(Int32Array::from_iter( - initial_range - .clone() - .map(|x| x % key_cardinality.1) - .collect::>(), - )); - let ordered_asc_null_first = Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) - 
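-                // The first `index` entries (the null_ratio share of the table) are NULL,
-                // and the remaining values ascend from `index` to `table_size - 1`.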
.chain(rest_of.clone().map(Some)) - .collect::>>() - })); - let ordered_asc_null_last = Arc::new(Int32Array::from_iter({ - rest_of - .clone() - .map(Some) - .chain(std::iter::repeat(None).take(index as usize)) - .collect::>>() - })); - - let ordered_desc_null_first = Arc::new(Int32Array::from_iter({ - std::iter::repeat(None) - .take(index as usize) - .chain(rest_of.rev().map(Some)) - .collect::>>() - })); - - let time = Arc::new(TimestampMillisecondArray::from( - initial_range - .clone() - .map(|x| x as i64 + 1672531200000) // x + 2023-01-01:00.00.00 - .collect::>(), - )); - let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from( - initial_range - .map(|x| x as i64 * 100) // x * 100ms - .collect::>(), - )); - - let float_asc = Arc::new(Float64Array::from_iter_values( - AscendingRandomFloatIterator::new(0., table_size as f64) - .take(table_size as usize), - )); - - let left = RecordBatch::try_from_iter(vec![ - ("la1", ordered.clone()), - ("lb1", cardinality.clone()), - ("lc1", cardinality_key_left), - ("lt1", time.clone()), - ("la2", ordered.clone()), - ("la1_des", ordered_des.clone()), - ("l_asc_null_first", ordered_asc_null_first.clone()), - ("l_asc_null_last", ordered_asc_null_last.clone()), - ("l_desc_null_first", ordered_desc_null_first.clone()), - ("li1", interval_time.clone()), - ("l_float", float_asc.clone()), - ])?; - let right = RecordBatch::try_from_iter(vec![ - ("ra1", ordered.clone()), - ("rb1", cardinality), - ("rc1", cardinality_key_right), - ("rt1", time), - ("ra2", ordered), - ("ra1_des", ordered_des), - ("r_asc_null_first", ordered_asc_null_first), - ("r_asc_null_last", ordered_asc_null_last), - ("r_desc_null_first", ordered_desc_null_first), - ("ri1", interval_time), - ("r_float", float_asc), - ])?; - Ok((left, right)) - } + use super::*; - fn create_memory_table( - left_batch: RecordBatch, - right_batch: RecordBatch, - left_sorted: Option>, - right_sorted: Option>, - batch_size: usize, - ) -> Result<(Arc, Arc)> { - let mut left = MemoryExec::try_new( - &[split_record_batches(&left_batch, batch_size)?], - left_batch.schema(), - None, - )?; - if let Some(sorted) = left_sorted { - left = left.with_sort_information(sorted); - } - let mut right = MemoryExec::try_new( - &[split_record_batches(&right_batch, batch_size)?], - right_batch.schema(), - None, - )?; - if let Some(sorted) = right_sorted { - right = right.with_sort_information(sorted); - } - Ok((Arc::new(left), Arc::new(right))) - } + const TABLE_SIZE: i32 = 100; - async fn experiment( + pub async fn experiment( left: Arc, right: Arc, filter: Option, @@ -2289,10 +1498,10 @@ mod tests { )] join_type: JoinType, #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), + (4, 5), + (11, 21), + (31, 71), + (99, 12), )] cardinality: (i32, i32), ) -> Result<()> { @@ -2370,10 +1579,10 @@ mod tests { )] join_type: JoinType, #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), + (4, 5), + (11, 21), + (31, 71), + (99, 12), )] cardinality: (i32, i32), #[values(0, 1, 2, 3, 4, 5, 6, 7)] case_expr: usize, @@ -2536,10 +1745,10 @@ mod tests { )] join_type: JoinType, #[values( - (4, 5), - (11, 21), - (31, 71), - (99, 12), + (4, 5), + (11, 21), + (31, 71), + (99, 12), )] cardinality: (i32, i32), #[values(0, 1, 2, 3, 4, 5, 6)] case_expr: usize, @@ -3125,14 +2334,13 @@ mod tests { initial_right_batch.num_rows() ); - left_side_joiner.join_with_probe_batch( + join_with_probe_batch( + &mut left_side_joiner, + &mut right_side_joiner, &join_schema, join_type, - &right_side_joiner.on, Some(&filter), &initial_right_batch, - &mut 
right_side_joiner.visited_rows, - right_side_joiner.offset, &join_column_indices, &random_state, false, @@ -3155,9 +2363,9 @@ mod tests { )] join_type: JoinType, #[values( - (4, 5), - (99, 12), - )] + (4, 5), + (99, 12), + )] cardinality: (i32, i32), #[values(0, 1)] case_expr: usize, ) -> Result<()> { diff --git a/datafusion/core/src/physical_plan/joins/test_utils.rs b/datafusion/core/src/physical_plan/joins/test_utils.rs new file mode 100644 index 000000000000..e786fb5eb5df --- /dev/null +++ b/datafusion/core/src/physical_plan/joins/test_utils.rs @@ -0,0 +1,513 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file has test utils for hash joins + +use crate::physical_plan::joins::utils::{JoinFilter, JoinOn}; +use crate::physical_plan::joins::{ + HashJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, +}; +use crate::physical_plan::memory::MemoryExec; +use crate::physical_plan::repartition::RepartitionExec; +use crate::physical_plan::{common, ExecutionPlan, Partitioning}; +use arrow::util::pretty::pretty_format_batches; +use arrow_array::{ + ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch, + TimestampMillisecondArray, +}; +use arrow_schema::Schema; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_execution::TaskContext; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::intervals::test_utils::{ + gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr, +}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; +use std::usize; + +pub fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { + // compare + let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); + let second_formatted = pretty_format_batches(collected_2).unwrap().to_string(); + + let mut first_formatted_sorted: Vec<&str> = first_formatted.trim().lines().collect(); + first_formatted_sorted.sort_unstable(); + + let mut second_formatted_sorted: Vec<&str> = + second_formatted.trim().lines().collect(); + second_formatted_sorted.sort_unstable(); + + for (i, (first_line, second_line)) in first_formatted_sorted + .iter() + .zip(&second_formatted_sorted) + .enumerate() + { + assert_eq!((i, first_line), (i, second_line)); + } +} + +pub async fn partitioned_sym_join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, +) -> Result> { + let partition_count = 4; + + let left_expr = on + .iter() + .map(|(l, _)| Arc::new(l.clone()) as _) + .collect::>(); + + let right_expr = on + .iter() + .map(|(_, r)| Arc::new(r.clone()) as _) + .collect::>(); + + let 
join = SymmetricHashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + filter, + join_type, + null_equals_null, + StreamJoinPartitionMode::Partitioned, + )?; + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, context.clone())?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok(batches) +} + +pub async fn partitioned_hash_join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, +) -> Result> { + let partition_count = 4; + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + + let join = Arc::new(HashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + filter, + join_type, + PartitionMode::Partitioned, + null_equals_null, + )?); + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, context.clone())?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok(batches) +} + +pub fn split_record_batches( + batch: &RecordBatch, + batch_size: usize, +) -> Result> { + let row_num = batch.num_rows(); + let number_of_batch = row_num / batch_size; + let mut sizes = vec![batch_size; number_of_batch]; + sizes.push(row_num - (batch_size * number_of_batch)); + let mut result = vec![]; + for (i, size) in sizes.iter().enumerate() { + result.push(batch.slice(i * batch_size, *size)); + } + Ok(result) +} + +struct AscendingRandomFloatIterator { + prev: f64, + max: f64, + rng: StdRng, +} + +impl AscendingRandomFloatIterator { + fn new(min: f64, max: f64) -> Self { + let mut rng = StdRng::seed_from_u64(42); + let initial = rng.gen_range(min..max); + AscendingRandomFloatIterator { + prev: initial, + max, + rng, + } + } +} + +impl Iterator for AscendingRandomFloatIterator { + type Item = f64; + + fn next(&mut self) -> Option { + let value = self.rng.gen_range(self.prev..self.max); + self.prev = value; + Some(value) + } +} + +pub fn join_expr_tests_fixture_temporal( + expr_id: usize, + left_col: Arc, + right_col: Arc, + schema: &Schema, +) -> Result> { + match expr_id { + // constructs ((left_col - INTERVAL '100ms') > (right_col - INTERVAL '200ms')) AND ((left_col - INTERVAL '450ms') < (right_col - INTERVAL '300ms')) + 0 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::new_interval_dt(0, 100), // 100 ms + ScalarValue::new_interval_dt(0, 200), // 200 ms + ScalarValue::new_interval_dt(0, 450), // 450 ms + ScalarValue::new_interval_dt(0, 300), // 300 ms + schema, + ), + // constructs ((left_col - TIMESTAMP '2023-01-01:12.00.03') > (right_col - TIMESTAMP '2023-01-01:12.00.01')) AND ((left_col - TIMESTAMP '2023-01-01:12.00.00') < (right_col - TIMESTAMP '2023-01-01:12.00.02')) + 1 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + 
Operator::Minus, + ScalarValue::TimestampMillisecond(Some(1672574403000), None), // 2023-01-01:12.00.03 + ScalarValue::TimestampMillisecond(Some(1672574401000), None), // 2023-01-01:12.00.01 + ScalarValue::TimestampMillisecond(Some(1672574400000), None), // 2023-01-01:12.00.00 + ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02 + schema, + ), + _ => unreachable!(), + } +} + +// It creates join filters for different type of fields for testing. +macro_rules! join_expr_tests { + ($func_name:ident, $type:ty, $SCALAR:ident) => { + pub fn $func_name( + expr_id: usize, + left_col: Arc, + right_col: Arc, + ) -> Arc { + match expr_id { + // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10 + 0 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 + 1 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 + 2 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(1 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10 + 3 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(10 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 10 > right_col - 5 AND left_col - 30 < right_col - 3 + 4 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ), + ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(30 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + (Operator::Gt, Operator::Lt), + ), + // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 + 5 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Minus, + ), + ScalarValue::$SCALAR(Some(2 as $type)), + ScalarValue::$SCALAR(Some(5 as $type)), + ScalarValue::$SCALAR(Some(7 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + (Operator::GtEq, Operator::LtEq), + ), + // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 + 6 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Minus, + Operator::Plus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(28 as $type)), + ScalarValue::$SCALAR(Some(11 as $type)), + ScalarValue::$SCALAR(Some(21 as $type)), + 
ScalarValue::$SCALAR(Some(39 as $type)), + (Operator::Gt, Operator::LtEq), + ), + // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 + 7 => gen_conjunctive_numerical_expr( + left_col, + right_col, + ( + Operator::Plus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + ), + ScalarValue::$SCALAR(Some(28 as $type)), + ScalarValue::$SCALAR(Some(11 as $type)), + ScalarValue::$SCALAR(Some(21 as $type)), + ScalarValue::$SCALAR(Some(39 as $type)), + (Operator::GtEq, Operator::Lt), + ), + _ => panic!("No case"), + } + } + }; +} + +join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32); +join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64); + +pub fn build_sides_record_batches( + table_size: i32, + key_cardinality: (i32, i32), +) -> Result<(RecordBatch, RecordBatch)> { + let null_ratio: f64 = 0.4; + let initial_range = 0..table_size; + let index = (table_size as f64 * null_ratio).round() as i32; + let rest_of = index..table_size; + let ordered: ArrayRef = Arc::new(Int32Array::from_iter( + initial_range.clone().collect::>(), + )); + let ordered_des = Arc::new(Int32Array::from_iter( + initial_range.clone().rev().collect::>(), + )); + let cardinality = Arc::new(Int32Array::from_iter( + initial_range.clone().map(|x| x % 4).collect::>(), + )); + let cardinality_key_left = Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.0) + .collect::>(), + )); + let cardinality_key_right = Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.1) + .collect::>(), + )); + let ordered_asc_null_first = Arc::new(Int32Array::from_iter({ + std::iter::repeat(None) + .take(index as usize) + .chain(rest_of.clone().map(Some)) + .collect::>>() + })); + let ordered_asc_null_last = Arc::new(Int32Array::from_iter({ + rest_of + .clone() + .map(Some) + .chain(std::iter::repeat(None).take(index as usize)) + .collect::>>() + })); + + let ordered_desc_null_first = Arc::new(Int32Array::from_iter({ + std::iter::repeat(None) + .take(index as usize) + .chain(rest_of.rev().map(Some)) + .collect::>>() + })); + + let time = Arc::new(TimestampMillisecondArray::from( + initial_range + .clone() + .map(|x| x as i64 + 1672531200000) // x + 2023-01-01:00.00.00 + .collect::>(), + )); + let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from( + initial_range + .map(|x| x as i64 * 100) // x * 100ms + .collect::>(), + )); + + let float_asc = Arc::new(Float64Array::from_iter_values( + AscendingRandomFloatIterator::new(0., table_size as f64) + .take(table_size as usize), + )); + + let left = RecordBatch::try_from_iter(vec![ + ("la1", ordered.clone()), + ("lb1", cardinality.clone()), + ("lc1", cardinality_key_left), + ("lt1", time.clone()), + ("la2", ordered.clone()), + ("la1_des", ordered_des.clone()), + ("l_asc_null_first", ordered_asc_null_first.clone()), + ("l_asc_null_last", ordered_asc_null_last.clone()), + ("l_desc_null_first", ordered_desc_null_first.clone()), + ("li1", interval_time.clone()), + ("l_float", float_asc.clone()), + ])?; + let right = RecordBatch::try_from_iter(vec![ + ("ra1", ordered.clone()), + ("rb1", cardinality), + ("rc1", cardinality_key_right), + ("rt1", time), + ("ra2", ordered), + ("ra1_des", ordered_des), + ("r_asc_null_first", ordered_asc_null_first), + ("r_asc_null_last", ordered_asc_null_last), + ("r_desc_null_first", ordered_desc_null_first), + ("ri1", interval_time), + ("r_float", float_asc), + ])?; + Ok((left, right)) +} + +pub fn create_memory_table( + left_batch: RecordBatch, + right_batch: 
RecordBatch, + left_sorted: Option>, + right_sorted: Option>, + batch_size: usize, +) -> Result<(Arc, Arc)> { + let mut left = MemoryExec::try_new( + &[split_record_batches(&left_batch, batch_size)?], + left_batch.schema(), + None, + )?; + if let Some(sorted) = left_sorted { + left = left.with_sort_information(sorted); + } + let mut right = MemoryExec::try_new( + &[split_record_batches(&right_batch, batch_size)?], + right_batch.schema(), + None, + )?; + if let Some(sorted) = right_sorted { + right = right.with_sort_information(sorted); + } + Ok((Arc::new(left), Arc::new(right))) +} From 0624378160b1d19de12a29b1374dea1930c9faaa Mon Sep 17 00:00:00 2001 From: Igor Izvekov Date: Wed, 19 Jul 2023 00:13:45 +0300 Subject: [PATCH 5/6] feat: array functions treat an array as an element (#6986) --- .../tests/sqllogictests/test_files/array.slt | 196 +++++++++++++----- .../physical-expr/src/array_expressions.rs | 59 +++++- 2 files changed, 205 insertions(+), 50 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index f0f50ccc9340..1e9b32414bba 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -55,6 +55,13 @@ AS VALUES (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) ; +statement ok +CREATE TABLE nested_arrays +AS VALUES + (make_array(make_array(1, 2, 3), make_array(2, 9, 1), make_array(7, 8, 9), make_array(1, 2, 3), make_array(1, 7, 4), make_array(4, 5, 6)), make_array(7, 8, 9), 2, make_array([[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]])), + (make_array(make_array(4, 5, 6), make_array(10, 11, 12), make_array(4, 9, 8), make_array(7, 8, 9), make_array(10, 11, 12), make_array(1, 8, 7)), make_array(10, 11, 12), 3, make_array([[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]])) +; + statement ok CREATE TABLE arrays_values AS VALUES @@ -100,6 +107,13 @@ NULL [13.3, 14.4, 15.5] [a, m, e, t] [[11, 12], [13, 14]] NULL [,] [[15, 16], [, 18]] [16.6, 17.7, 18.8] NULL +# nested_arrays table +query ??I? +select column1, column2, column3, column4 from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [7, 8, 9] 2 [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [10, 11, 12] 3 [[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]]] + # values table query IIIRT select a, b, c, d, e from values; @@ -292,7 +306,13 @@ select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3 ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] -# array_append with columns +# array_append scalar function #4 (element is list) +query ??? +select array_append(make_array([1], [2], [3]), make_array(4)), array_append(make_array([1.0], [2.0], [3.0]), make_array(4.0)), array_append(make_array(['h'], ['e'], ['l'], ['l']), make_array('o')); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +# array_append with columns #1 query ? select array_append(column1, column2) from arrays_values; ---- @@ -305,7 +325,14 @@ select array_append(column1, column2) from arrays_values; [51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] -# array_append with columns and scalars +# array_append with columns #2 (element is list) +query ? 
+select array_append(column1, column2) from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [7, 8, 9]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [10, 11, 12]] + +# array_append with columns and scalars #1 query ?? select array_append(column2, 100.1), array_append(column3, '.') from arrays; ---- @@ -317,6 +344,13 @@ select array_append(column2, 100.1), array_append(column3, '.') from arrays; [100.1] [,, .] [16.6, 17.7, 18.8, 100.1] [.] +# array_append with columns and scalars #2 +query ?? +select array_append(column1, make_array(1, 11, 111)), array_append(make_array(make_array(1, 2, 3), make_array(11, 12, 13)), column2) from nested_arrays; +---- +[[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [7, 8, 9]] +[[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7], [1, 11, 111]] [[1, 2, 3], [11, 12, 13], [10, 11, 12]] + ## array_prepend # array_prepend scalar function #1 @@ -337,7 +371,13 @@ select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] -# array_prepend with columns +# array_prepend scalar function #4 (element is list) +query ??? +select array_prepend(make_array(1), make_array(make_array(2), make_array(3), make_array(4))), array_prepend(make_array(1.0), make_array([2.0], [3.0], [4.0])), array_prepend(make_array('h'), make_array(['e'], ['l'], ['l'], ['o'])); +---- +[[1], [2], [3], [4]] [[1.0], [2.0], [3.0], [4.0]] [[h], [e], [l], [l], [o]] + +# array_prepend with columns #1 query ? select array_prepend(column2, column1) from arrays_values; ---- @@ -350,7 +390,14 @@ select array_prepend(column2, column1) from arrays_values; [55, 51, 52, , 54, 55, 56, 57, 58, 59, 60] [66, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] -# array_prepend with columns and scalars +# array_prepend with columns #2 (element is list) +query ? +select array_prepend(column2, column1) from nested_arrays; +---- +[[7, 8, 9], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] +[[10, 11, 12], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] + +# array_prepend with columns and scalars #1 query ?? select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; ---- @@ -362,6 +409,13 @@ select array_prepend(100.1, column2), array_prepend('.', column3) from arrays; [100.1] [., ,] [100.1, 16.6, 17.7, 18.8] [.] +# array_prepend with columns and scalars #2 (element is list) +query ?? +select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, make_array(make_array(1, 2, 3), make_array(11, 12, 13))) from nested_arrays; +---- +[[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] +[[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] + ## array_fill # array_fill scalar function #1 @@ -473,19 +527,6 @@ select array_concat(make_array(column2), make_array(column3)) from arrays_values # array_concat column-wise #4 query ? 
-select array_concat(column1, column2) from arrays_values; ----- -[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1] -[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12] -[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23] -[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34] -[44] -[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ] -[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55] -[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66] - -# array_concat column-wise #5 -query ? select array_concat(make_array(column2), make_array(0)) from arrays_values; ---- [1, 0] @@ -497,7 +538,7 @@ select array_concat(make_array(column2), make_array(0)) from arrays_values; [55, 0] [66, 0] -# array_concat column-wise #6 +# array_concat column-wise #5 query ??? select array_concat(column1, column1), array_concat(column2, column2), array_concat(column3, column3) from arrays; ---- @@ -509,7 +550,7 @@ NULL [13.3, 14.4, 15.5, 13.3, 14.4, 15.5] [a, m, e, t, a, m, e, t] [[11, 12], [13, 14], [11, 12], [13, 14]] NULL [,, ,] [[15, 16], [, 18], [15, 16], [, 18]] [16.6, 17.7, 18.8, 16.6, 17.7, 18.8] NULL -# array_concat column-wise #7 +# array_concat column-wise #6 query ?? select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))), array_concat(column2, make_array(1.1, 2.2, 3.3)) from arrays; ---- @@ -521,7 +562,7 @@ select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))), ar [[11, 12], [13, 14], [1, 2], [3, 4]] [1.1, 2.2, 3.3] [[15, 16], [, 18], [1, 2], [3, 4]] [16.6, 17.7, 18.8, 1.1, 2.2, 3.3] -# array_concat column-wise #8 +# array_concat column-wise #7 query ? select array_concat(column3, make_array('.', '.', '.')) from arrays; ---- @@ -543,7 +584,7 @@ select array_concat(column3, make_array('.', '.', '.')) from arrays; # [11, 12] NULL NULL NULL # NULL NULL NULL NULL -# array_concat column-wise #9 (1D + 1D) +# array_concat column-wise #8 (1D + 1D) query ? select array_concat(column1, column2) from arrays_values_v2; ---- @@ -554,28 +595,36 @@ select array_concat(column1, column2) from arrays_values_v2; [11, 12] NULL -# TODO: Concat columns with different dimensions fails -# array_concat column-wise #10 (1D + 2D) -# query error DataFusion error: Arrow error: Invalid argument error: column types must match schema types, expected List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) but found List\(Field \{ name: "item", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) at column index 0 -# select array_concat(make_array(column3), column4) from arrays_values_v2; +# array_concat column-wise #9 (2D + 1D) +query ? +select array_concat(column4, make_array(column3)) from arrays_values_v2; +---- +[[30, 40, 50], [12]] +[[, , 60], [13]] +[[70, , ], [14]] +[[]] +[[]] +[[]] + +# array_concat column-wise #10 (3D + 2D + 1D) +query ? +select array_concat(column4, column1, column2) from nested_arrays; +---- +[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]], [[1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]], [[7, 8, 9]]] +[[[11, 12, 13], [14, 15, 16]], [[17, 18, 19], [20, 21, 22]], [[4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]], [[10, 11, 12]]] -# array_concat column-wise #11 (1D + Integers) +# array_concat column-wise #11 (2D + 1D) query ? 
-select array_concat(column2, column3) from arrays_values_v2; +select array_concat(column4, column1) from arrays_values_v2; ---- -[4, 5, , 12] -[7, , 8, 13] -[14] -[, 21, ] +[[30, 40, 50], [, 2, 3]] +[[, , 60], ] +[[70, , ], [9, , 10]] +[[, 1]] +[[11, 12]] [] -[] - -# TODO: Panic at 'range end index 3 out of range for slice of length 2' -# array_concat column-wise #12 (2D + 1D) -# query -# select array_concat(column4, column1) from arrays_values_v2; -# array_concat column-wise #13 (1D + 1D + 1D) +# array_concat column-wise #12 (1D + 1D + 1D) query ? select array_concat(make_array(column3), column1, column2) from arrays_values_v2; ---- @@ -594,13 +643,25 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, ---- 3 5 1 -# array_position scalar function #2 +# array_position scalar function #2 (with optional argument) query III select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2); ---- 4 5 2 -# array_position with columns +# array_position scalar function #3 (element is list) +query II +select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]); +---- +2 2 + +# array_position scalar function #4 (element in list; with optional argument) +query II +select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6], 3), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4], 3); +---- +4 3 + +# array_position with columns #1 query II select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls; ---- @@ -609,24 +670,44 @@ select array_position(column1, column2), array_position(column1, column2, column 3 3 4 4 -# array_position with columns and scalars +# array_position with columns #2 (element is list) query II -select array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls; +select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays; ---- -3 NULL -NULL NULL -NULL NULL -NULL NULL +3 3 +2 5 + +# array_position with columns and scalars #1 +query III +select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls; +---- +1 3 NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + +# array_position with columns and scalars #2 (element is list) +query III +select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays; +---- +NULL 6 4 +NULL 1 NULL ## array_positions -# array_positions scalar function +# array_positions scalar function #1 query ??? select array_positions(['h', 'e', 'l', 'l', 'o'], 'l'), array_positions([1, 2, 3, 4, 5], 5), array_positions([1, 1, 1], 1); ---- [3, 4] [5] [1, 2, 3] -# array_positions with columns +# array_positions scalar function #2 +query ? +select array_positions(make_array([1, 2, 3], [2, 1, 3], [1, 5, 6], [2, 1, 3], [4, 5, 6]), [2, 1, 3]); +---- +[2, 4] + +# array_positions with columns #1 query ? 
select array_positions(column1, column2) from arrays_values_without_nulls; ---- @@ -635,7 +716,14 @@ select array_positions(column1, column2) from arrays_values_without_nulls; [3] [4] -# array_positions with columns and scalars +# array_positions with columns #2 (element is list) +query ? +select array_positions(column1, column2) from nested_arrays; +---- +[3] +[2, 5] + +# array_positions with columns and scalars #1 query ?? select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], column2) from arrays_values_without_nulls; ---- @@ -644,6 +732,13 @@ select array_positions(column1, 4), array_positions(array[1, 2, 23, 13, 33, 45], [] [3] [] [] +# array_positions with columns and scalars #2 (element is list) +query ?? +select array_positions(column1, make_array(4, 5, 6)), array_positions(make_array([1, 2, 3], [11, 12, 13], [4, 5, 6]), column2) from nested_arrays; +---- +[6] [] +[1] [] + ## array_replace # array_replace scalar function @@ -1053,6 +1148,9 @@ select make_array(f0) from fixed_size_list_array statement ok drop table values; +statement ok +drop table nested_arrays; + statement ok drop table arrays; diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 104d49e1c876..b16432b50531 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -410,6 +410,7 @@ pub fn array_append(args: &[ArrayRef]) -> Result { let element = &args[1]; let res = match (arr.value_type(), element.data_type()) { + (DataType::List(_), DataType::List(_)) => concat_internal(args)?, (DataType::Utf8, DataType::Utf8) => append!(arr, element, StringArray), (DataType::LargeUtf8, DataType::LargeUtf8) => append!(arr, element, LargeStringArray), (DataType::Boolean, DataType::Boolean) => append!(arr, element, BooleanArray), @@ -499,6 +500,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result { let arr = as_list_array(&args[1])?; let res = match (arr.value_type(), element.data_type()) { + (DataType::List(_), DataType::List(_)) => concat_internal(args)?, (DataType::Utf8, DataType::Utf8) => prepend!(arr, element, StringArray), (DataType::LargeUtf8, DataType::LargeUtf8) => prepend!(arr, element, LargeStringArray), (DataType::Boolean, DataType::Boolean) => prepend!(arr, element, BooleanArray), @@ -543,7 +545,18 @@ fn align_array_dimensions(args: Vec) -> Result> { let mut aligned_array = array.clone(); for _ in 0..(max_ndim - ndim) { let data_type = aligned_array.as_ref().data_type().clone(); - aligned_array = array_array(&[aligned_array], data_type)?; + let offsets: Vec = + (0..downcast_arg!(aligned_array, ListArray).offsets().len()) + .map(|i| i as i32) + .collect(); + let field = Arc::new(Field::new("item", data_type, true)); + + aligned_array = Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(aligned_array.clone()), + None, + )?) 
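+            // The offsets are consecutive (0, 1, 2, ...), so each row of the new list
+            // wraps exactly one row of the previous `aligned_array`; every pass through
+            // this loop therefore raises the nesting depth by one without copying values.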
             }
             Ok(aligned_array)
         } else {
@@ -761,6 +774,7 @@ pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
 
     let res = match arr.data_type() {
         DataType::List(field) => match field.data_type() {
+            DataType::List(_) => position!(arr, element, index, ListArray),
             DataType::Utf8 => position!(arr, element, index, StringArray),
             DataType::LargeUtf8 => position!(arr, element, index, LargeStringArray),
             DataType::Boolean => position!(arr, element, index, BooleanArray),
@@ -846,6 +860,7 @@ pub fn array_positions(args: &[ArrayRef]) -> Result<ArrayRef> {
 
     let res = match arr.data_type() {
         DataType::List(field) => match field.data_type() {
+            DataType::List(_) => positions!(arr, element, ListArray),
             DataType::Utf8 => positions!(arr, element, StringArray),
             DataType::LargeUtf8 => positions!(arr, element, LargeStringArray),
             DataType::Boolean => positions!(arr, element, BooleanArray),
@@ -1617,6 +1632,48 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_nested_array_concat() {
+        // array_concat([1, 2, 3, 4], [1, 2, 3, 4]) = [1, 2, 3, 4, 1, 2, 3, 4]
+        let list_array = return_array().into_array(1);
+        let arr = array_concat(&[list_array.clone(), list_array.clone()])
+            .expect("failed to initialize function array_concat");
+        let result =
+            as_list_array(&arr).expect("failed to initialize function array_concat");
+
+        assert_eq!(
+            &[1, 2, 3, 4, 1, 2, 3, 4],
+            result
+                .value(0)
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .unwrap()
+                .values()
+        );
+
+        // array_concat([[1, 2, 3, 4], [5, 6, 7, 8]], [1, 2, 3, 4]) = [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4]]
+        let list_nested_array = return_nested_array().into_array(1);
+        let list_array = return_array().into_array(1);
+        let arr = array_concat(&[list_nested_array, list_array])
+            .expect("failed to initialize function array_concat");
+        let result =
+            as_list_array(&arr).expect("failed to initialize function array_concat");
+
+        assert_eq!(
+            &[1, 2, 3, 4],
+            result
+                .value(0)
+                .as_any()
+                .downcast_ref::<ListArray>()
+                .unwrap()
+                .value(2)
+                .as_any()
+                .downcast_ref::<Int64Array>()
+                .unwrap()
+                .values()
+        );
+    }
+
     #[test]
     fn test_array_fill() {
         // array_fill(4, [5]) = [4, 4, 4, 4, 4]
From 07ffebbfeffb72842de45f3183a7a7cb153f07c0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com>
Date: Wed, 19 Jul 2023 09:52:42 +0300
Subject: [PATCH 6/6] Code replacement (#7009)

---
 .../physical_optimizer/sort_enforcement.rs    | 230 +-----------------
 .../core/src/physical_optimizer/test_utils.rs | 226 +++++++++++++++++
 2 files changed, 234 insertions(+), 222 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
index cdd3b39f0d70..c9da83d86b34 100644
--- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs
@@ -972,22 +972,15 @@ pub(crate) fn unbounded_output(plan: &Arc<dyn ExecutionPlan>) -> bool {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::datasource::listing::PartitionedFile;
-    use crate::datasource::object_store::ObjectStoreUrl;
-    use crate::datasource::physical_plan::{FileScanConfig, ParquetExec};
     use crate::physical_optimizer::dist_enforcement::EnforceDistribution;
-    use crate::physical_plan::aggregates::PhysicalGroupBy;
-    use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
-    use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
-    use crate::physical_plan::filter::FilterExec;
-    use crate::physical_plan::joins::utils::{JoinFilter, JoinOn};
-    use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
-    use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
-    use crate::physical_plan::memory::MemoryExec;
+    use crate::physical_optimizer::test_utils::{
+        aggregate_exec, bounded_window_exec, coalesce_batches_exec,
+        coalesce_partitions_exec, filter_exec, get_plan_string, global_limit_exec,
+        hash_join_exec, limit_exec, local_limit_exec, memory_exec, parquet_exec,
+        parquet_exec_sorted, repartition_exec, sort_exec, sort_expr, sort_expr_options,
+        sort_merge_join_exec, sort_preserving_merge_exec, union_exec,
+    };
     use crate::physical_plan::repartition::RepartitionExec;
-    use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
-    use crate::physical_plan::union::UnionExec;
-    use crate::physical_plan::windows::create_window_expr;
     use crate::physical_plan::windows::PartitionSearchMode::{
         Linear, PartiallySorted, Sorted,
     };
@@ -996,9 +989,8 @@ mod tests {
     use crate::test::csv_exec_sorted;
     use arrow::compute::SortOptions;
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-    use datafusion_common::{Result, Statistics};
+    use datafusion_common::Result;
     use datafusion_expr::JoinType;
-    use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction};
     use datafusion_physical_expr::expressions::Column;
     use datafusion_physical_expr::expressions::{col, NotExpr};
     use datafusion_physical_expr::PhysicalSortExpr;
@@ -1030,13 +1022,6 @@ mod tests {
         Ok(schema)
     }
 
-    // Util function to get string representation of a physical plan
-    fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
-        let formatted = displayable(plan.as_ref()).indent(true).to_string();
-        let actual: Vec<&str> = formatted.trim().lines().collect();
-        actual.iter().map(|elem| elem.to_string()).collect()
-    }
-
     #[tokio::test]
     async fn test_is_column_aligned_nullable() -> Result<()> {
         let schema = create_test_schema()?;
@@ -2828,203 +2813,4 @@ mod tests {
         assert_optimized!(expected_input, expected_optimized, physical_plan);
         Ok(())
     }
-
-    /// make PhysicalSortExpr with default options
-    fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr {
-        sort_expr_options(name, schema, SortOptions::default())
-    }
-
-    /// PhysicalSortExpr with specified options
-    fn sort_expr_options(
-        name: &str,
-        schema: &Schema,
-        options: SortOptions,
-    ) -> PhysicalSortExpr {
-        PhysicalSortExpr {
-            expr: col(name, schema).unwrap(),
-            options,
-        }
-    }
-
-    fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
-        Arc::new(MemoryExec::try_new(&[], schema.clone(), None).unwrap())
-    }
-
-    fn sort_exec(
-        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
-        input: Arc<dyn ExecutionPlan>,
-    ) -> Arc<dyn ExecutionPlan> {
-        let sort_exprs = sort_exprs.into_iter().collect();
-        Arc::new(SortExec::new(sort_exprs, input))
-    }
-
-    fn hash_join_exec(
-        left: Arc<dyn ExecutionPlan>,
-        right: Arc<dyn ExecutionPlan>,
-        on: JoinOn,
-        filter: Option<JoinFilter>,
-        join_type: &JoinType,
-    ) -> Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(HashJoinExec::try_new(
-            left,
-            right,
-            on,
-            filter,
-            join_type,
-            PartitionMode::Partitioned,
-            true,
-        )?))
-    }
-
-    fn sort_preserving_merge_exec(
-        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
-        input: Arc<dyn ExecutionPlan>,
-    ) -> Arc<dyn ExecutionPlan> {
-        let sort_exprs = sort_exprs.into_iter().collect();
-        Arc::new(SortPreservingMergeExec::new(sort_exprs, input))
-    }
-
-    fn filter_exec(
-        predicate: Arc<dyn PhysicalExpr>,
-        input: Arc<dyn ExecutionPlan>,
-    ) -> Arc<dyn ExecutionPlan> {
-        Arc::new(FilterExec::try_new(predicate, input).unwrap())
-    }
-
-    fn bounded_window_exec(
-        col_name: &str,
-        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
-        input: Arc<dyn ExecutionPlan>,
-    ) -> Arc<dyn ExecutionPlan> {
-        let sort_exprs: Vec<_> = sort_exprs.into_iter().collect();
-        let schema = input.schema();
-
-        Arc::new(
-            BoundedWindowAggExec::try_new(
-                vec![create_window_expr(
-                    &WindowFunction::AggregateFunction(AggregateFunction::Count),
-                    "count".to_owned(),
-                    &[col(col_name, &schema).unwrap()],
-                    &[],
-                    &sort_exprs,
-                    Arc::new(WindowFrame::new(true)),
-                    schema.as_ref(),
-                )
-                .unwrap()],
-                input.clone(),
-                input.schema(),
-                vec![],
-                Sorted,
-            )
-            .unwrap(),
-        )
-    }
-
-    /// Create a non sorted parquet exec
-    fn parquet_exec(schema: &SchemaRef) -> Arc<ParquetExec> {
-        Arc::new(ParquetExec::new(
-            FileScanConfig {
-                object_store_url: ObjectStoreUrl::parse("test:///").unwrap(),
-                file_schema: schema.clone(),
-                file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]],
-                statistics: Statistics::default(),
-                projection: None,
-                limit: None,
-                table_partition_cols: vec![],
-                output_ordering: vec![],
-                infinite_source: false,
-            },
-            None,
-            None,
-        ))
-    }
-
-    // Created a sorted parquet exec
-    fn parquet_exec_sorted(
-        schema: &SchemaRef,
-        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
-    ) -> Arc<dyn ExecutionPlan> {
-        let sort_exprs = sort_exprs.into_iter().collect();
-
-        Arc::new(ParquetExec::new(
-            FileScanConfig {
-                object_store_url: ObjectStoreUrl::parse("test:///").unwrap(),
-                file_schema: schema.clone(),
-                file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]],
-                statistics: Statistics::default(),
-                projection: None,
-                limit: None,
-                table_partition_cols: vec![],
-                output_ordering: vec![sort_exprs],
-                infinite_source: false,
-            },
-            None,
-            None,
-        ))
-    }
-
-    fn union_exec(input: Vec<Arc<dyn ExecutionPlan>>) -> Arc<dyn ExecutionPlan> {
-        Arc::new(UnionExec::new(input))
-    }
-
-    fn limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        global_limit_exec(local_limit_exec(input))
-    }
-
-    fn local_limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        Arc::new(LocalLimitExec::new(input, 100))
-    }
-
-    fn global_limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        Arc::new(GlobalLimitExec::new(input, 0, Some(100)))
-    }
-
-    fn repartition_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        Arc::new(
-            RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap(),
-        )
-    }
-
-    fn coalesce_partitions_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        Arc::new(CoalescePartitionsExec::new(input))
-    }
-
-    fn aggregate_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        let schema = input.schema();
-        Arc::new(
-            AggregateExec::try_new(
-                AggregateMode::Final,
-                PhysicalGroupBy::default(),
-                vec![],
-                vec![],
-                vec![],
-                input,
-                schema,
-            )
-            .unwrap(),
-        )
-    }
-
-    fn sort_merge_join_exec(
-        left: Arc<dyn ExecutionPlan>,
-        right: Arc<dyn ExecutionPlan>,
-        join_on: &JoinOn,
-        join_type: &JoinType,
-    ) -> Arc<dyn ExecutionPlan> {
-        Arc::new(
-            SortMergeJoinExec::try_new(
-                left,
-                right,
-                join_on.clone(),
-                *join_type,
-                vec![SortOptions::default(); join_on.len()],
-                false,
-            )
-            .unwrap(),
-        )
-    }
-
-    fn coalesce_batches_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        Arc::new(CoalesceBatchesExec::new(input, 128))
-    }
 }
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs
index 8689b016b01c..0386bcd3d054 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -17,9 +17,31 @@
 //! Collection of testing utility functions that are leveraged by the query optimizer rules
 
+use crate::datasource::listing::PartitionedFile;
+use crate::datasource::physical_plan::{FileScanConfig, ParquetExec};
 use crate::error::Result;
+use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
+use crate::physical_plan::coalesce_batches::CoalesceBatchesExec;
+use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use crate::physical_plan::filter::FilterExec;
+use crate::physical_plan::joins::utils::{JoinFilter, JoinOn};
+use crate::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
+use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
+use crate::physical_plan::memory::MemoryExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
+use crate::physical_plan::union::UnionExec;
+use crate::physical_plan::windows::create_window_expr;
+use crate::physical_plan::{displayable, ExecutionPlan, Partitioning};
 use crate::prelude::{CsvReadOptions, SessionContext};
+use arrow_schema::{Schema, SchemaRef, SortOptions};
 use async_trait::async_trait;
+use datafusion_common::{JoinType, Statistics};
+use datafusion_execution::object_store::ObjectStoreUrl;
+use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction};
+use datafusion_physical_expr::expressions::col;
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 use std::sync::Arc;
 
 async fn register_current_csv(
@@ -130,3 +152,207 @@ impl QueryCase {
         Ok(())
     }
 }
+
+pub fn sort_merge_join_exec(
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    join_on: &JoinOn,
+    join_type: &JoinType,
+) -> Arc<dyn ExecutionPlan> {
+    Arc::new(
+        SortMergeJoinExec::try_new(
+            left,
+            right,
+            join_on.clone(),
+            *join_type,
+            vec![SortOptions::default(); join_on.len()],
+            false,
+        )
+        .unwrap(),
+    )
+}
+
+/// make PhysicalSortExpr with default options
+pub fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr {
+    sort_expr_options(name, schema, SortOptions::default())
+}
+
+/// PhysicalSortExpr with specified options
+pub fn sort_expr_options(
+    name: &str,
+    schema: &Schema,
+    options: SortOptions,
+) -> PhysicalSortExpr {
+    PhysicalSortExpr {
+        expr: col(name, schema).unwrap(),
+        options,
+    }
+}
+
+pub fn coalesce_partitions_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
+    Arc::new(CoalescePartitionsExec::new(input))
+}
+
+pub(crate) fn memory_exec(schema: &SchemaRef) -> Arc<dyn ExecutionPlan> {
+    Arc::new(MemoryExec::try_new(&[], schema.clone(), None).unwrap())
+}
+
+pub fn hash_join_exec(
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    on: JoinOn,
+    filter: Option<JoinFilter>,
+    join_type: &JoinType,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    Ok(Arc::new(HashJoinExec::try_new(
+        left,
+        right,
+        on,
+        filter,
+        join_type,
+        PartitionMode::Partitioned,
+        true,
+    )?))
+}
+
+pub fn bounded_window_exec(
+    col_name: &str,
+    sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+    input: Arc<dyn ExecutionPlan>,
+) -> Arc<dyn ExecutionPlan> {
+    let sort_exprs: Vec<_> = sort_exprs.into_iter().collect();
+    let schema = input.schema();
+
+    Arc::new(
+        crate::physical_plan::windows::BoundedWindowAggExec::try_new(
+            vec![create_window_expr(
+                &WindowFunction::AggregateFunction(AggregateFunction::Count),
+                "count".to_owned(),
+                &[col(col_name, &schema).unwrap()],
+                &[],
+                &sort_exprs,
+                Arc::new(WindowFrame::new(true)),
+                schema.as_ref(),
+            )
+            .unwrap()],
+            input.clone(),
+            input.schema(),
+            vec![],
+            crate::physical_plan::windows::PartitionSearchMode::Sorted,
+        )
+        .unwrap(),
+    )
+}
+
+pub fn filter_exec(
+    predicate: Arc<dyn PhysicalExpr>,
+    input: Arc<dyn ExecutionPlan>,
+) -> Arc<dyn ExecutionPlan> {
+    Arc::new(FilterExec::try_new(predicate, input).unwrap())
+}
+
+// Util function to get string representation of a physical plan
+pub fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
+    let formatted = displayable(plan.as_ref()).indent(true).to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    actual.iter().map(|elem| elem.to_string()).collect()
+}
+
+pub fn sort_preserving_merge_exec(
+    sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+    input: Arc<dyn ExecutionPlan>,
+) -> Arc<dyn ExecutionPlan> {
+    let sort_exprs = sort_exprs.into_iter().collect();
+    Arc::new(SortPreservingMergeExec::new(sort_exprs, input))
+}
+
+/// Create a non sorted parquet exec
+pub fn parquet_exec(schema: &SchemaRef) -> Arc<ParquetExec> {
+    Arc::new(ParquetExec::new(
+        FileScanConfig {
+            object_store_url: ObjectStoreUrl::parse("test:///").unwrap(),
+            file_schema: schema.clone(),
+            file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]],
+            statistics: Statistics::default(),
+            projection: None,
+            limit: None,
+            table_partition_cols: vec![],
+            output_ordering: vec![],
+            infinite_source: false,
+        },
+        None,
+        None,
+    ))
+}
+
+// Created a sorted parquet exec
+pub fn parquet_exec_sorted(
+    schema: &SchemaRef,
+    sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+) -> Arc<dyn ExecutionPlan> {
+    let sort_exprs = sort_exprs.into_iter().collect();
+
+    Arc::new(ParquetExec::new(
+        FileScanConfig {
+            object_store_url: ObjectStoreUrl::parse("test:///").unwrap(),
+            file_schema: schema.clone(),
+            file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]],
+            statistics: Statistics::default(),
+            projection: None,
+            limit: None,
+            table_partition_cols: vec![],
+            output_ordering: vec![sort_exprs],
+            infinite_source: false,
+        },
+        None,
+        None,
+    ))
+}
+
+pub fn union_exec(input: Vec<Arc<dyn ExecutionPlan>>) -> Arc<dyn ExecutionPlan> {
+    Arc::new(UnionExec::new(input))
+}
+
+pub fn limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
+    global_limit_exec(local_limit_exec(input))
+}
+
+pub fn local_limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
+    Arc::new(LocalLimitExec::new(input, 100))
+}
+
+pub fn global_limit_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
+    Arc::new(GlobalLimitExec::new(input, 0, Some(100)))
+}
+
+pub fn repartition_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
+    Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap())
+}
+
+pub fn aggregate_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
+    let schema = input.schema();
+    Arc::new(
+        AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![],
+            vec![],
+            vec![],
+            input,
+            schema,
+        )
+        .unwrap(),
+    )
+}
+
+pub fn coalesce_batches_exec(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
+    Arc::new(CoalesceBatchesExec::new(input, 128))
+}
+
+pub fn sort_exec(
+    sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
+    input: Arc<dyn ExecutionPlan>,
+) -> Arc<dyn ExecutionPlan> {
+    let sort_exprs = sort_exprs.into_iter().collect();
+    Arc::new(SortExec::new(sort_exprs, input))
+}