From a71a76a996a32a0f068370940ebe475ec237b4ff Mon Sep 17 00:00:00 2001 From: Eduard Karacharov <13005055+korowa@users.noreply.github.com> Date: Mon, 18 Dec 2023 11:53:26 +0200 Subject: [PATCH] refactor: `HashJoinStream` state machine (#8538) * hash join state machine * StreamJoinStateResult to StatefulStreamResult * doc comments & naming & fmt * suggestions from code review Co-authored-by: Andrew Lamb * more review comments addressed * post-merge fixes --------- Co-authored-by: Andrew Lamb --- .../physical-plan/src/joins/hash_join.rs | 431 ++++++++++++------ .../src/joins/stream_join_utils.rs | 127 ++---- .../src/joins/symmetric_hash_join.rs | 25 +- datafusion/physical-plan/src/joins/utils.rs | 83 ++++ 4 files changed, 420 insertions(+), 246 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 4846d0a5e046..13ac06ee301c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -28,7 +28,6 @@ use crate::joins::utils::{ calculate_join_output_ordering, get_final_indices_from_bit_map, need_produce_result_in_final, JoinHashMap, JoinHashMapType, }; -use crate::DisplayAs; use crate::{ coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec, @@ -38,12 +37,13 @@ use crate::{ joins::utils::{ adjust_right_output_partitioning, build_join_schema, check_join_is_valid, estimate_join_statistics, partitioned_join_output_partitioning, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StatefulStreamResult, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::{handle_state, DisplayAs}; use super::{ utils::{OnceAsync, OnceFut}, @@ -618,15 +618,14 @@ impl ExecutionPlan for HashJoinExec { on_right, filter: self.filter.clone(), join_type: self.join_type, - left_fut, - visited_left_side: None, right: right_stream, column_indices: self.column_indices.clone(), random_state: self.random_state.clone(), join_metrics, null_equals_null: self.null_equals_null, - is_exhausted: false, reservation, + state: HashJoinStreamState::WaitBuildSide, + build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), })) } @@ -789,6 +788,104 @@ where Ok(()) } +/// Represents build-side of hash join. +enum BuildSide { + /// Indicates that build-side not collected yet + Initial(BuildSideInitialState), + /// Indicates that build-side data has been collected + Ready(BuildSideReadyState), +} + +/// Container for BuildSide::Initial related data +struct BuildSideInitialState { + /// Future for building hash table from build-side input + left_fut: OnceFut, +} + +/// Container for BuildSide::Ready related data +struct BuildSideReadyState { + /// Collected build-side data + left_data: Arc, + /// Which build-side rows have been matched while creating output. + /// For some OUTER joins, we need to know which rows have not been matched + /// to produce the correct output. + visited_left_side: BooleanBufferBuilder, +} + +impl BuildSide { + /// Tries to extract BuildSideInitialState from BuildSide enum. + /// Returns an error if state is not Initial. + fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { + match self { + BuildSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready(&self) -> Result<&BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { + match self { + BuildSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +/// Represents state of HashJoinStream +/// +/// Expected state transitions performed by HashJoinStream are: +/// +/// ```text +/// +/// WaitBuildSide +/// │ +/// ▼ +/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed +/// │ │ +/// │ ▼ +/// └─ ProcessProbeBatch +/// +/// ``` +enum HashJoinStreamState { + /// Initial state for HashJoinStream indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for fetching probe-side + FetchProbeBatch, + /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed + ProcessProbeBatch(ProcessProbeBatchState), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that HashJoinStream execution is completed + Completed, +} + +/// Container for HashJoinStreamState::ProcessProbeBatch related data +struct ProcessProbeBatchState { + /// Current probe-side batch + batch: RecordBatch, +} + +impl HashJoinStreamState { + /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. + /// Returns an error if state is not ProcessProbeBatchState. + fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> { + match self { + HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), + } + } +} + /// [`Stream`] for [`HashJoinExec`] that does the actual join. /// /// This stream: @@ -808,20 +905,10 @@ struct HashJoinStream { filter: Option, /// type of the join (left, right, semi, etc) join_type: JoinType, - /// future which builds hash table from left side - left_fut: OnceFut, - /// Which left (probe) side rows have been matches while creating output. - /// For some OUTER joins, we need to know which rows have not been matched - /// to produce the correct output. - visited_left_side: Option, /// right (probe) input right: SendableRecordBatchStream, /// Random state used for hashing initialization random_state: RandomState, - /// The join output is complete. For outer joins, this is used to - /// distinguish when the input stream is exhausted and when any unmatched - /// rows are output. - is_exhausted: bool, /// Metrics join_metrics: BuildProbeJoinMetrics, /// Information of index and left / right placement of columns @@ -830,6 +917,10 @@ struct HashJoinStream { null_equals_null: bool, /// Memory reservation reservation: MemoryReservation, + /// State of the stream + state: HashJoinStreamState, + /// Build side + build_side: BuildSide, } impl RecordBatchStream for HashJoinStream { @@ -1069,19 +1160,44 @@ impl HashJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { + loop { + return match self.state { + HashJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + HashJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + HashJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + HashJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + HashJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + /// Collects build-side data by polling `OnceFut` future from initialized build-side + /// + /// Updates build-side to `Ready`, and state to `FetchProbeSide` + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); // build hash table from left (build) side, if not yet done - let left_data = match ready!(self.left_fut.get(cx)) { - Ok(left_data) => left_data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + let left_data = ready!(self + .build_side + .try_as_initial_mut()? + .left_fut + .get_shared(cx))?; build_timer.done(); // Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet // and join_type requires to store it - if self.visited_left_side.is_none() - && need_produce_result_in_final(self.join_type) - { + if need_produce_result_in_final(self.join_type) { // TODO: Replace `ceil` wrapper with stable `div_cell` after // https://github.com/rust-lang/rust/issues/88581 let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8); @@ -1089,124 +1205,167 @@ impl HashJoinStream { self.join_metrics.build_mem_used.add(visited_bitmap_size); } - let visited_left_side = self.visited_left_side.get_or_insert_with(|| { + let visited_left_side = if need_produce_result_in_final(self.join_type) { let num_rows = left_data.num_rows(); - if need_produce_result_in_final(self.join_type) { - // Some join types need to track which row has be matched or unmatched: - // `left semi` join: need to use the bitmap to produce the matched row in the left side - // `left` join: need to use the bitmap to produce the unmatched row in the left side with null - // `left anti` join: need to use the bitmap to produce the unmatched row in the left side - // `full` join: need to use the bitmap to produce the unmatched row in the left side with null - let mut buffer = BooleanBufferBuilder::new(num_rows); - buffer.append_n(num_rows, false); - buffer - } else { - BooleanBufferBuilder::new(0) - } + // Some join types need to track which row has be matched or unmatched: + // `left semi` join: need to use the bitmap to produce the matched row in the left side + // `left` join: need to use the bitmap to produce the unmatched row in the left side with null + // `left anti` join: need to use the bitmap to produce the unmatched row in the left side + // `full` join: need to use the bitmap to produce the unmatched row in the left side with null + let mut buffer = BooleanBufferBuilder::new(num_rows); + buffer.append_n(num_rows, false); + buffer + } else { + BooleanBufferBuilder::new(0) + }; + + self.state = HashJoinStreamState::FetchProbeBatch; + self.build_side = BuildSide::Ready(BuildSideReadyState { + left_data, + visited_left_side, }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, + /// otherwise updates state to `ExhaustedProbeSide` + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.right.poll_next_unpin(cx)) { + None => { + self.state = HashJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(batch)) => { + self.state = + HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { + batch, + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with matched output + /// + /// Updates state to `FetchProbeBatch` + fn process_probe_batch( + &mut self, + ) -> Result>> { + let state = self.state.try_as_process_probe_batch()?; + let build_side = self.build_side.try_as_ready_mut()?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(state.batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + let mut hashes_buffer = vec![]; - // get next right (probe) input batch - self.right - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - // one right batch in the join loop - Some(Ok(batch)) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - // get the matched two indices for the on condition - let left_right_indices = build_equal_condition_join_indices( - left_data.hash_map(), - left_data.batch(), - &batch, - &self.on_left, - &self.on_right, - &self.random_state, - self.null_equals_null, - &mut hashes_buffer, - self.filter.as_ref(), - JoinSide::Left, - None, - ); - - let result = match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only left, full, left semi, left anti need the left bitmap - if need_produce_result_in_final(self.join_type) { - left_side.iter().flatten().for_each(|x| { - visited_left_side.set_bit(x as usize, true); - }); - } - - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - batch.num_rows(), - self.join_type, - ); - - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - Some(result) - } - Err(err) => Some(exec_err!( - "Fail to build join indices in HashJoinExec, error:{err}" - )), - }; - timer.done(); - result - } - None => { - let timer = self.join_metrics.join_time.timer(); - if need_produce_result_in_final(self.join_type) && !self.is_exhausted - { - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_bit_map( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - timer.done(); - self.is_exhausted = true; - Some(result) - } else { - // end of the join loop - None - } + // get the matched two indices for the on condition + let left_right_indices = build_equal_condition_join_indices( + build_side.left_data.hash_map(), + build_side.left_data.batch(), + &state.batch, + &self.on_left, + &self.on_right, + &self.random_state, + self.null_equals_null, + &mut hashes_buffer, + self.filter.as_ref(), + JoinSide::Left, + None, + ); + + let result = match left_right_indices { + Ok((left_side, right_side)) => { + // set the left bitmap + // and only left, full, left semi, left anti need the left bitmap + if need_produce_result_in_final(self.join_type) { + left_side.iter().flatten().for_each(|x| { + build_side.visited_left_side.set_bit(x as usize, true); + }); } - Some(err) => Some(err), - }) + + // adjust the two side indices base on the join type + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + state.batch.num_rows(), + self.join_type, + ); + + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &state.batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(state.batch.num_rows()); + result + } + Err(err) => { + exec_err!("Fail to build join indices in HashJoinExec, error:{err}") + } + }; + timer.done(); + + self.state = HashJoinStreamState::FetchProbeBatch; + + Ok(StatefulStreamResult::Ready(Some(result?))) + } + + /// Processes unmatched build-side rows for certain join types and produces output batch + /// + /// Updates state to `Completed` + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let timer = self.join_metrics.join_time.timer(); + + if !need_produce_result_in_final(self.join_type) { + self.state = HashJoinStreamState::Completed; + + return Ok(StatefulStreamResult::Continue); + } + + let build_side = self.build_side.try_as_ready()?; + + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_bit_map(&build_side.visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + timer.done(); + + self.state = HashJoinStreamState::Completed; + + Ok(StatefulStreamResult::Ready(Some(result?))) } } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 2f74bd1c4bb2..64a976a1e39f 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -23,9 +23,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::usize; -use crate::joins::utils::{JoinFilter, JoinHashMapType}; +use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{handle_async_state, metrics}; +use crate::{handle_async_state, handle_state, metrics}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; @@ -624,73 +624,6 @@ pub fn record_visited_indices( } } -/// The `handle_state` macro is designed to process the result of a state-changing -/// operation, typically encountered in implementations of `EagerJoinStream`. It -/// operates on a `StreamJoinStateResult` by matching its variants and executing -/// corresponding actions. This macro is used to streamline code that deals with -/// state transitions, reducing boilerplate and improving readability. -/// -/// # Cases -/// -/// - `Ok(StreamJoinStateResult::Continue)`: Continues the loop, indicating the -/// stream join operation should proceed to the next step. -/// - `Ok(StreamJoinStateResult::Ready(result))`: Returns a `Poll::Ready` with the -/// result, either yielding a value or indicating the stream is awaiting more -/// data. -/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue -/// during the stream join operation. -/// -/// # Arguments -/// -/// * `$match_case`: An expression that evaluates to a `Result>`. -#[macro_export] -macro_rules! handle_state { - ($match_case:expr) => { - match $match_case { - Ok(StreamJoinStateResult::Continue) => continue, - Ok(StreamJoinStateResult::Ready(result)) => { - Poll::Ready(Ok(result).transpose()) - } - Err(e) => Poll::Ready(Some(Err(e))), - } - }; -} - -/// The `handle_async_state` macro adapts the `handle_state` macro for use in -/// asynchronous operations, particularly when dealing with `Poll` results within -/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing -/// function using `poll_unpin` and then passes the result to `handle_state` for -/// further processing. -/// -/// # Arguments -/// -/// * `$state_func`: An async function or future that returns a -/// `Result>`. -/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. -/// -#[macro_export] -macro_rules! handle_async_state { - ($state_func:expr, $cx:expr) => { - $crate::handle_state!(ready!($state_func.poll_unpin($cx))) - }; -} - -/// Represents the result of a stateful operation on `EagerJoinStream`. -/// -/// This enumueration indicates whether the state produced a result that is -/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). -/// -/// Variants: -/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. -/// - `Continue`: Indicates that the operation is not yet complete and requires further -/// processing or more data. When this variant is returned, it typically means that the -/// current invocation of the state did not produce a final result, and the operation -/// should be invoked again later with more data and possibly with a different state. -pub enum StreamJoinStateResult { - Ready(T), - Continue, -} - /// Represents the various states of an eager join stream operation. /// /// This enum is used to track the current state of streaming during a join @@ -819,14 +752,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after pulling the batch. + /// * `Result>>` - The state result after pulling the batch. async fn fetch_next_from_right_stream( &mut self, - ) -> Result>> { + ) -> Result>> { match self.right_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.set_state(EagerJoinStreamState::PullLeft); @@ -835,7 +768,7 @@ pub trait EagerJoinStream { Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::RightExhausted); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -848,14 +781,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after pulling the batch. + /// * `Result>>` - The state result after pulling the batch. async fn fetch_next_from_left_stream( &mut self, - ) -> Result>> { + ) -> Result>> { match self.left_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.set_state(EagerJoinStreamState::PullRight); self.process_batch_from_left(batch) @@ -863,7 +796,7 @@ pub trait EagerJoinStream { Some(Err(e)) => Err(e), None => { self.set_state(EagerJoinStreamState::LeftExhausted); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -877,14 +810,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after checking the exhaustion state. + /// * `Result>>` - The state result after checking the exhaustion state. async fn handle_right_stream_end( &mut self, - ) -> Result>> { + ) -> Result>> { match self.left_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.process_batch_after_right_end(batch) } @@ -893,7 +826,7 @@ pub trait EagerJoinStream { self.set_state(EagerJoinStreamState::BothExhausted { final_result: false, }); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -907,14 +840,14 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after checking the exhaustion state. + /// * `Result>>` - The state result after checking the exhaustion state. async fn handle_left_stream_end( &mut self, - ) -> Result>> { + ) -> Result>> { match self.right_stream().next().await { Some(Ok(batch)) => { if batch.num_rows() == 0 { - return Ok(StreamJoinStateResult::Continue); + return Ok(StatefulStreamResult::Continue); } self.process_batch_after_left_end(batch) } @@ -923,7 +856,7 @@ pub trait EagerJoinStream { self.set_state(EagerJoinStreamState::BothExhausted { final_result: false, }); - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } } } @@ -936,10 +869,10 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after both streams are exhausted. + /// * `Result>>` - The state result after both streams are exhausted. fn prepare_for_final_results_after_exhaustion( &mut self, - ) -> Result>> { + ) -> Result>> { self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); self.process_batches_before_finalization() } @@ -952,11 +885,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after processing the batch. + /// * `Result>>` - The state result after processing the batch. fn process_batch_from_right( &mut self, batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles a pulled batch from the left stream. /// @@ -966,11 +899,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after processing the batch. + /// * `Result>>` - The state result after processing the batch. fn process_batch_from_left( &mut self, batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the situation when only the left stream is exhausted. /// @@ -980,11 +913,11 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after the left stream is exhausted. + /// * `Result>>` - The state result after the left stream is exhausted. fn process_batch_after_left_end( &mut self, right_batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the situation when only the right stream is exhausted. /// @@ -994,20 +927,20 @@ pub trait EagerJoinStream { /// /// # Returns /// - /// * `Result>>` - The state result after the right stream is exhausted. + /// * `Result>>` - The state result after the right stream is exhausted. fn process_batch_after_right_end( &mut self, left_batch: RecordBatch, - ) -> Result>>; + ) -> Result>>; /// Handles the final state after both streams are exhausted. /// /// # Returns /// - /// * `Result>>` - The final state result after processing. + /// * `Result>>` - The final state result after processing. fn process_batches_before_finalization( &mut self, - ) -> Result>>; + ) -> Result>>; /// Provides mutable access to the right stream. /// diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 00a7f23ebae7..b9101b57c3e5 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -38,12 +38,11 @@ use crate::joins::stream_join_utils::{ convert_sort_expr_with_filter_schema, get_pruning_anti_indices, get_pruning_semi_indices, record_visited_indices, EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, - StreamJoinStateResult, }; use crate::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, JoinFilter, - JoinOn, + JoinOn, StatefulStreamResult, }; use crate::{ expressions::{Column, PhysicalSortExpr}, @@ -956,13 +955,13 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_from_right( &mut self, batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.perform_join_for_given_side(batch, JoinSide::Right) .map(|maybe_batch| { if maybe_batch.is_some() { - StreamJoinStateResult::Ready(maybe_batch) + StatefulStreamResult::Ready(maybe_batch) } else { - StreamJoinStateResult::Continue + StatefulStreamResult::Continue } }) } @@ -970,13 +969,13 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_from_left( &mut self, batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.perform_join_for_given_side(batch, JoinSide::Left) .map(|maybe_batch| { if maybe_batch.is_some() { - StreamJoinStateResult::Ready(maybe_batch) + StatefulStreamResult::Ready(maybe_batch) } else { - StreamJoinStateResult::Continue + StatefulStreamResult::Continue } }) } @@ -984,20 +983,20 @@ impl EagerJoinStream for SymmetricHashJoinStream { fn process_batch_after_left_end( &mut self, right_batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.process_batch_from_right(right_batch) } fn process_batch_after_right_end( &mut self, left_batch: RecordBatch, - ) -> Result>> { + ) -> Result>> { self.process_batch_from_left(left_batch) } fn process_batches_before_finalization( &mut self, - ) -> Result>> { + ) -> Result>> { // Get the left side results: let left_result = build_side_determined_results( &self.left, @@ -1025,9 +1024,9 @@ impl EagerJoinStream for SymmetricHashJoinStream { // Update the metrics: self.metrics.output_batches.add(1); self.metrics.output_rows.add(batch.num_rows()); - return Ok(StreamJoinStateResult::Ready(result)); + return Ok(StatefulStreamResult::Ready(result)); } - Ok(StreamJoinStateResult::Continue) + Ok(StatefulStreamResult::Continue) } fn right_stream(&mut self) -> &mut SendableRecordBatchStream { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 5e01ca227cf5..eae65ce9c26b 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -849,6 +849,22 @@ impl OnceFut { ), } } + + /// Get shared reference to the result of the computation if it is ready, without consuming it + pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll>> { + if let OnceFutState::Pending(fut) = &mut self.state { + let r = ready!(fut.poll_unpin(cx)); + self.state = OnceFutState::Ready(r); + } + + match &self.state { + OnceFutState::Pending(_) => unreachable!(), + OnceFutState::Ready(r) => Poll::Ready( + r.clone() + .map_err(|e| DataFusionError::External(Box::new(e))), + ), + } + } } /// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and @@ -1277,6 +1293,73 @@ pub fn prepare_sorted_exprs( Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph)) } +/// The `handle_state` macro is designed to process the result of a state-changing +/// operation, encountered e.g. in implementations of `EagerJoinStream`. It +/// operates on a `StatefulStreamResult` by matching its variants and executing +/// corresponding actions. This macro is used to streamline code that deals with +/// state transitions, reducing boilerplate and improving readability. +/// +/// # Cases +/// +/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the +/// stream join operation should proceed to the next step. +/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the +/// result, either yielding a value or indicating the stream is awaiting more +/// data. +/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue +/// during the stream join operation. +/// +/// # Arguments +/// +/// * `$match_case`: An expression that evaluates to a `Result>`. +#[macro_export] +macro_rules! handle_state { + ($match_case:expr) => { + match $match_case { + Ok(StatefulStreamResult::Continue) => continue, + Ok(StatefulStreamResult::Ready(result)) => { + Poll::Ready(Ok(result).transpose()) + } + Err(e) => Poll::Ready(Some(Err(e))), + } + }; +} + +/// The `handle_async_state` macro adapts the `handle_state` macro for use in +/// asynchronous operations, particularly when dealing with `Poll` results within +/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing +/// function using `poll_unpin` and then passes the result to `handle_state` for +/// further processing. +/// +/// # Arguments +/// +/// * `$state_func`: An async function or future that returns a +/// `Result>`. +/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. +/// +#[macro_export] +macro_rules! handle_async_state { + ($state_func:expr, $cx:expr) => { + $crate::handle_state!(ready!($state_func.poll_unpin($cx))) + }; +} + +/// Represents the result of an operation on stateful join stream. +/// +/// This enumueration indicates whether the state produced a result that is +/// ready for use (`Ready`) or if the operation requires continuation (`Continue`). +/// +/// Variants: +/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`. +/// - `Continue`: Indicates that the operation is not yet complete and requires further +/// processing or more data. When this variant is returned, it typically means that the +/// current invocation of the state did not produce a final result, and the operation +/// should be invoked again later with more data and possibly with a different state. +pub enum StatefulStreamResult { + Ready(T), + Continue, +} + #[cfg(test)] mod tests { use std::pin::Pin;