diff --git a/rust/datafusion/src/physical_plan/union.rs b/rust/datafusion/src/physical_plan/union.rs
index cbab728a8428b..0f3da9c10d2a6 100644
--- a/rust/datafusion/src/physical_plan/union.rs
+++ b/rust/datafusion/src/physical_plan/union.rs
@@ -60,15 +60,31 @@ impl ExecutionPlan for UnionExec {
     /// Output of the union is the combination of all output partitions of the inputs
     fn output_partitioning(&self) -> Partitioning {
-        // Sums all the output partitions
-        let num_partitions = self
-            .inputs
+        let initial: Option<Partitioning> = None;
+        self.inputs
             .iter()
-            .map(|plan| plan.output_partitioning().partition_count())
-            .sum();
-        // TODO: this loses partitioning info in case of same partitioning scheme (for example `Partitioning::Hash`)
-        // https://issues.apache.org/jira/browse/ARROW-11991
-        Partitioning::UnknownPartitioning(num_partitions)
+            .fold(initial, |acc, input| {
+                match (acc, input.output_partitioning()) {
+                    (None, partition) => Some(partition),
+                    (
+                        Some(Partitioning::Hash(mut vector_acc, size_acc)),
+                        Partitioning::Hash(vector, size),
+                    ) => {
+                        vector_acc.append(&mut vector.clone());
+                        Some(Partitioning::Hash(vector_acc, size_acc + size))
+                    }
+                    (
+                        Some(Partitioning::RoundRobinBatch(size_acc)),
+                        Partitioning::RoundRobinBatch(size),
+                    ) => Some(Partitioning::RoundRobinBatch(size_acc + size)),
+                    (Some(partition_acc), partition) => {
+                        Some(Partitioning::UnknownPartitioning(
+                            partition_acc.partition_count() + partition.partition_count(),
+                        ))
+                    }
+                }
+            })
+            .unwrap()
     }
 
     fn with_new_children(
@@ -99,17 +115,17 @@ impl ExecutionPlan for UnionExec {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::physical_plan::expressions::Column;
     use crate::physical_plan::{
         collect,
         csv::{CsvExec, CsvReadOptions},
+        repartition::RepartitionExec,
     };
     use crate::test;
     use arrow::record_batch::RecordBatch;
 
-    #[tokio::test]
-    async fn test_union_partitions() -> Result<()> {
+    fn get_csv_exec() -> Result<(CsvExec, CsvExec)> {
         let schema = test::aggr_test_schema();
-
         // Create csv's with different partitioning
         let path = test::create_partitioned_csv("aggregate_test_100.csv", 4)?;
         let path2 = test::create_partitioned_csv("aggregate_test_100.csv", 5)?;
@@ -129,15 +145,105 @@ mod tests {
             1024,
             None,
         )?;
+        Ok((csv, csv2))
+    }
+
+    #[tokio::test]
+    async fn test_union_partitions_unknown() -> Result<()> {
+        let (csv, csv2) = get_csv_exec()?;
 
         let union_exec = Arc::new(UnionExec::new(vec![Arc::new(csv), Arc::new(csv2)]));
 
         // Should have 9 partitions and 9 output batches
-        assert_eq!(union_exec.output_partitioning().partition_count(), 9);
+        assert!(matches!(
+            union_exec.output_partitioning(),
+            Partitioning::UnknownPartitioning(9)
+        ));
+
+        let result: Vec<RecordBatch> = collect(union_exec).await?;
+        assert_eq!(result.len(), 9);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_union_partitions_hash() -> Result<()> {
+        let (csv, csv2) = get_csv_exec()?;
+        let repartition = RepartitionExec::try_new(
+            Arc::new(csv),
+            Partitioning::Hash(vec![Arc::new(Column::new("c1"))], 5),
+        )?;
+        let repartition2 = RepartitionExec::try_new(
+            Arc::new(csv2),
+            Partitioning::Hash(vec![Arc::new(Column::new("c2"))], 5),
+        )?;
+
+        let union_exec = Arc::new(UnionExec::new(vec![
+            Arc::new(repartition),
+            Arc::new(repartition2),
+        ]));
+
+        // should be hash, have 10 partitions and 45 output batches
+        assert!(matches!(
+            union_exec.output_partitioning(),
+            Partitioning::Hash(_, 10)
+        ));
+
+        let result: Vec<RecordBatch> = collect(union_exec).await?;
+        assert_eq!(result.len(), 45);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_union_partitions_round_robin() -> Result<()> {
+        let (csv, csv2) = get_csv_exec()?;
+        let repartition =
+            RepartitionExec::try_new(Arc::new(csv), Partitioning::RoundRobinBatch(4))?;
+        let repartition2 =
+            RepartitionExec::try_new(Arc::new(csv2), Partitioning::RoundRobinBatch(6))?;
+
+        let union_exec = Arc::new(UnionExec::new(vec![
+            Arc::new(repartition),
+            Arc::new(repartition2),
+        ]));
+
+        // should be round robin, have 10 partitions and 9 output batches
+        assert!(matches!(
+            union_exec.output_partitioning(),
+            Partitioning::RoundRobinBatch(10)
+        ));
 
         let result: Vec<RecordBatch> = collect(union_exec).await?;
         assert_eq!(result.len(), 9);
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_union_partitions_mix() -> Result<()> {
+        let (csv, csv2) = get_csv_exec()?;
+        let repartition = RepartitionExec::try_new(
+            Arc::new(csv),
+            Partitioning::Hash(vec![Arc::new(Column::new("c1"))], 5),
+        )?;
+        let repartition2 =
+            RepartitionExec::try_new(Arc::new(csv2), Partitioning::RoundRobinBatch(6))?;
+
+        let union_exec = Arc::new(UnionExec::new(vec![
+            Arc::new(repartition),
+            Arc::new(repartition2),
+        ]));
+
+        // should be unknown partitioning, have 11 partitions and 25 output batches
+        assert!(matches!(
+            union_exec.output_partitioning(),
+            Partitioning::UnknownPartitioning(11)
+        ));
+
+        let result: Vec<RecordBatch> = collect(union_exec).await?;
+        assert_eq!(result.len(), 25);
+
+        Ok(())
+    }
 }
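
For reference, here is a minimal standalone sketch of the combining rule the new `output_partitioning` folds over its inputs: matching `Hash` schemes concatenate their hash expressions and sum their partition counts, matching `RoundRobinBatch` schemes sum their counts, and any mismatch collapses to `UnknownPartitioning` with only the total count preserved. The `Part` enum, `combine` helper, and string-based "hash expressions" below are simplified, hypothetical stand-ins for illustration, not DataFusion's `Partitioning` API.

```rust
// Simplified stand-in for DataFusion's Partitioning enum (assumption: hash
// expressions are reduced to plain strings for the sake of the example).
#[derive(Debug, Clone, PartialEq)]
enum Part {
    RoundRobinBatch(usize),
    Hash(Vec<String>, usize),
    Unknown(usize),
}

impl Part {
    fn partition_count(&self) -> usize {
        match self {
            Part::RoundRobinBatch(n) | Part::Hash(_, n) | Part::Unknown(n) => *n,
        }
    }
}

// Fold the partitioning of all union inputs into one scheme, mirroring the
// match arms in the patch above: merge like schemes, otherwise fall back to
// Unknown with the summed partition count.
fn combine(inputs: Vec<Part>) -> Option<Part> {
    inputs.into_iter().fold(None, |acc, input| match (acc, input) {
        // First input: adopt its partitioning as-is.
        (None, p) => Some(p),
        // Hash + Hash: concatenate expressions, sum partition counts.
        (Some(Part::Hash(mut exprs, n)), Part::Hash(mut other, m)) => {
            exprs.append(&mut other);
            Some(Part::Hash(exprs, n + m))
        }
        // RoundRobin + RoundRobin: sum partition counts.
        (Some(Part::RoundRobinBatch(n)), Part::RoundRobinBatch(m)) => {
            Some(Part::RoundRobinBatch(n + m))
        }
        // Mismatched schemes: only the total partition count survives.
        (Some(a), b) => Some(Part::Unknown(a.partition_count() + b.partition_count())),
    })
}

fn main() {
    // Two round-robin inputs merge into one round-robin scheme (4 + 6 = 10),
    // as in test_union_partitions_round_robin.
    assert_eq!(
        combine(vec![Part::RoundRobinBatch(4), Part::RoundRobinBatch(6)]),
        Some(Part::RoundRobinBatch(10))
    );
    // Mixed schemes collapse to Unknown (5 + 6 = 11), as in test_union_partitions_mix.
    assert_eq!(
        combine(vec![
            Part::Hash(vec!["c1".into()], 5),
            Part::RoundRobinBatch(6)
        ]),
        Some(Part::Unknown(11))
    );
}
```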