From d5a8c934fa818e52cc021ad250cb1044e29ff9df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com>
Date: Fri, 28 Apr 2023 11:26:38 +0300
Subject: [PATCH] MemoryExec INSERT INTO refactor to use ExecutionPlan (#6049)

* MemoryExec insert into refactor
* Merge leftovers
* Set target partition
* Comment and formatting improvements
* Comments on state.
* Leftover comments
* After merge corrections
* Correction after merge

---------

Co-authored-by: Mehmet Ozan Kabak
---
 datafusion/core/src/datasource/datasource.rs |   4 +-
 datafusion/core/src/datasource/memory.rs     | 283 ++++----
 datafusion/core/src/execution/context.rs     |  21 +-
 datafusion/core/src/physical_plan/memory.rs  | 646 ++++++++++++++++++-
 datafusion/core/src/physical_plan/planner.rs |  19 +-
 5 files changed, 800 insertions(+), 173 deletions(-)

diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs
index 8db075a30a79..4560b3820cd9 100644
--- a/datafusion/core/src/datasource/datasource.rs
+++ b/datafusion/core/src/datasource/datasource.rs
@@ -102,8 +102,8 @@ pub trait TableProvider: Sync + Send {
     async fn insert_into(
         &self,
         _state: &SessionState,
-        _input: &LogicalPlan,
-    ) -> Result<()> {
+        _input: Arc<dyn ExecutionPlan>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
         let msg = "Insertion not implemented for this table".to_owned();
         Err(DataFusionError::NotImplemented(msg))
     }
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index ca083aebe33a..f41f8cb1bd48 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -24,20 +24,22 @@ use std::sync::Arc;
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
-use datafusion_expr::LogicalPlan;
 use tokio::sync::RwLock;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
-use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use crate::physical_plan::common;
 use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::memory::MemoryExec;
+use crate::physical_plan::memory::MemoryWriteExec;
 use crate::physical_plan::ExecutionPlan;
-use crate::physical_plan::{collect_partitioned, common};
 use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
 
+/// Type alias for partition data
+pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
+
 /// In-memory data source for presenting a `Vec<RecordBatch>` as a
 /// data source that can be queried by DataFusion. This allows data to
 /// be pre-loaded into memory and then repeatedly queried without
@@ -45,7 +47,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
 #[derive(Debug)]
 pub struct MemTable {
     schema: SchemaRef,
-    batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
+    pub(crate) batches: Vec<PartitionData>,
 }
 
 impl MemTable {
@@ -58,7 +60,10 @@ impl MemTable {
         {
             Ok(Self {
                 schema,
-                batches: Arc::new(RwLock::new(partitions)),
+                batches: partitions
+                    .into_iter()
+                    .map(|e| Arc::new(RwLock::new(e)))
+                    .collect::<Vec<_>>(),
             })
         } else {
             Err(DataFusionError::Plan(
@@ -147,71 +152,62 @@ impl TableProvider for MemTable {
         _filters: &[Expr],
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let batches = &self.batches.read().await;
-        Ok(Arc::new(MemoryExec::try_new(
-            batches,
+        let mut partitions = vec![];
+        for arc_inner_vec in self.batches.iter() {
+            let inner_vec = arc_inner_vec.read().await;
+            partitions.push(inner_vec.clone())
+        }
+        Ok(Arc::new(MemoryExec::try_new_owned_data(
+            partitions,
             self.schema(),
             projection.cloned(),
         )?))
     }
 
-    /// Inserts the execution results of a given [LogicalPlan] into this [MemTable].
-    /// The `LogicalPlan` must have the same schema as this `MemTable`.
+    /// Inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
+    /// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
     ///
     /// # Arguments
     ///
-    /// * `state` - The [SessionState] containing the context for executing the plan.
-    /// * `input` - The [LogicalPlan] to execute and insert.
+    /// * `state` - The [`SessionState`] containing the context for executing the plan.
+    /// * `input` - The [`ExecutionPlan`] to execute and insert.
    ///
     /// # Returns
     ///
     /// * A `Result` indicating success or failure.
-    async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> {
+    async fn insert_into(
+        &self,
+        _state: &SessionState,
+        input: Arc<dyn ExecutionPlan>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
         // Create a physical plan from the logical plan.
-        let plan = state.create_physical_plan(input).await?;
-
-        // Check that the schema of the plan matches the schema of this table.
-        if !plan.schema().eq(&self.schema) {
+        if !input.schema().eq(&self.schema) {
             return Err(DataFusionError::Plan(
                 "Inserting query must have the same schema with the table.".to_string(),
             ));
         }
-        // Get the number of partitions in the plan and the table.
-        let plan_partition_count = plan.output_partitioning().partition_count();
-        let table_partition_count = self.batches.read().await.len();
+        if self.batches.is_empty() {
+            return Err(DataFusionError::Plan(
+                "The table must have partitions.".to_string(),
+            ));
+        }
 
-        // Adjust the plan as necessary to match the number of partitions in the table.
-        let plan: Arc<dyn ExecutionPlan> = if plan_partition_count
-            == table_partition_count
-            || table_partition_count == 0
-        {
-            plan
-        } else if table_partition_count == 1 {
-            // If the table has only one partition, coalesce the partitions in the plan.
-            Arc::new(CoalescePartitionsExec::new(plan))
-        } else {
-            // Otherwise, repartition the plan using a round-robin partitioning scheme.
+        let input = if self.batches.len() > 1 {
             Arc::new(RepartitionExec::try_new(
-                plan,
-                Partitioning::RoundRobinBatch(table_partition_count),
+                input,
+                Partitioning::RoundRobinBatch(self.batches.len()),
             )?)
-        };
-
-        let results = collect_partitioned(plan, state.task_ctx()).await?;
-
-        // Write the results into the table.
- let mut all_batches = self.batches.write().await; - - if all_batches.is_empty() { - *all_batches = results } else { - for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) { - batches.extend(result); - } - } + input + }; - Ok(()) + Ok(Arc::new(MemoryWriteExec::try_new( + input, + self.batches.clone(), + self.schema.clone(), + )?)) } } @@ -220,6 +216,7 @@ mod tests { use super::*; use crate::datasource::provider_as_source; use crate::from_slice::FromSlice; + use crate::physical_plan::collect; use crate::prelude::SessionContext; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; @@ -455,21 +452,48 @@ mod tests { Ok(()) } - fn create_mem_table_scan( + async fn experiment( schema: SchemaRef, - data: Vec>, - ) -> Result> { - // Convert the table into a provider so that it can be used in a query - let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?)); - // Create a table scan logical plan to read from the table - Ok(Arc::new( - LogicalPlanBuilder::scan("source", provider, None)?.build()?, - )) - } - - fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> { + initial_data: Vec>, + inserted_data: Vec>, + ) -> Result>> { // Create a new session context let session_ctx = SessionContext::new(); + // Create and register the initial table with the provided schema and data + let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?); + session_ctx.register_table("t", initial_table.clone())?; + // Create and register the source table with the provided schema and inserted data + let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?); + session_ctx.register_table("source", source_table.clone())?; + // Convert the source table into a provider so that it can be used in a query + let source = provider_as_source(source_table); + // Create a table scan logical plan to read from the source table + let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; + // Create an insert plan to insert the source data into the initial table + let insert_into_table = + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema)?.build()?; + // Create a physical plan from the insert plan + let plan = session_ctx + .state() + .create_physical_plan(&insert_into_table) + .await?; + + // Execute the physical plan and collect the results + let res = collect(plan, session_ctx.task_ctx()).await?; + // Ensure the result is empty after the insert operation + assert!(res.is_empty()); + // Read the data from the initial table and store it in a vector of partitions + let mut partitions = vec![]; + for partition in initial_table.batches.iter() { + let part = partition.read().await.clone(); + partitions.push(part); + } + Ok(partitions) + } + + // Test inserting a single batch of data into a single partition + #[tokio::test] + async fn test_insert_into_single_partition() -> Result<()> { // Create a new schema with one field called "a" of type Int32 let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); @@ -478,111 +502,84 @@ mod tests { schema.clone(), vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], )?; - Ok((session_ctx, schema, batch)) - } - - #[tokio::test] - async fn test_insert_into_single_partition() -> Result<()> { - let (session_ctx, schema, batch) = create_initial_ctx()?; - let initial_table = Arc::new(MemTable::try_new( - schema.clone(), - vec![vec![batch.clone()]], - )?); - // Create a table scan logical plan to read from the table - let 
single_partition_table_scan = - create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?; - // Insert the data from the provider into the table - initial_table - .insert_into(&session_ctx.state(), &single_partition_table_scan) - .await?; + // Run the experiment and obtain the resulting data in the table + let resulting_data_in_table = + experiment(schema, vec![vec![batch.clone()]], vec![vec![batch.clone()]]) + .await?; // Ensure that the table now contains two batches of data in the same partition - assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); - - // Create a new provider with 2 partitions - let multi_partition_table_scan = create_mem_table_scan( - schema.clone(), - vec![vec![batch.clone()], vec![batch]], - )?; - - // Insert the data from the provider into the table. We expect coalescing partitions. - initial_table - .insert_into(&session_ctx.state(), &multi_partition_table_scan) - .await?; - // Ensure that the table now contains 4 batches of data with only 1 partition - assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4); - assert_eq!(initial_table.batches.read().await.len(), 1); + assert_eq!(resulting_data_in_table[0].len(), 2); Ok(()) } + // Test inserting multiple batches of data into a single partition #[tokio::test] - async fn test_insert_into_multiple_partition() -> Result<()> { - let (session_ctx, schema, batch) = create_initial_ctx()?; - // create a memory table with two partitions, each having one batch with the same data - let initial_table = Arc::new(MemTable::try_new( - schema.clone(), - vec![vec![batch.clone()], vec![batch.clone()]], - )?); + async fn test_insert_into_single_partition_with_multi_partition() -> Result<()> { + // Create a new schema with one field called "a" of type Int32 + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - // scan a data source provider from a memory table with a single partition - let single_partition_table_scan = create_mem_table_scan( + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( schema.clone(), - vec![vec![batch.clone(), batch.clone()]], + vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], )?; - - // insert the data from the 1 partition data source provider into the initial table - initial_table - .insert_into(&session_ctx.state(), &single_partition_table_scan) - .await?; - - // We expect round robin repartition here, each partition gets 1 batch. - assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); - assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2); - - // scan a data source provider from a memory table with 2 partition - let multi_partition_table_scan = create_mem_table_scan( - schema.clone(), + // Run the experiment and obtain the resulting data in the table + let resulting_data_in_table = experiment( + schema, + vec![vec![batch.clone()]], vec![vec![batch.clone()], vec![batch]], - )?; - // We expect one-to-one partition mapping. - initial_table - .insert_into(&session_ctx.state(), &multi_partition_table_scan) - .await?; - // Ensure that the table now contains 3 batches of data with 2 partitions. 
- assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3); - assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3); + ) + .await?; + // Ensure that the table now contains three batches of data in the same partition + assert_eq!(resulting_data_in_table[0].len(), 3); Ok(()) } + // Test inserting multiple batches of data into multiple partitions #[tokio::test] - async fn test_insert_into_empty_table() -> Result<()> { - let (session_ctx, schema, batch) = create_initial_ctx()?; - // create empty memory table - let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?); + async fn test_insert_into_multi_partition_with_multi_partition() -> Result<()> { + // Create a new schema with one field called "a" of type Int32 + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - // scan a data source provider from a memory table with a single partition - let single_partition_table_scan = create_mem_table_scan( + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( schema.clone(), - vec![vec![batch.clone(), batch.clone()]], + vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], )?; + // Run the experiment and obtain the resulting data in the table + let resulting_data_in_table = experiment( + schema, + vec![vec![batch.clone()], vec![batch.clone()]], + vec![ + vec![batch.clone(), batch.clone()], + vec![batch.clone(), batch], + ], + ) + .await?; + // Ensure that each partition in the table now contains three batches of data + assert_eq!(resulting_data_in_table[0].len(), 3); + assert_eq!(resulting_data_in_table[1].len(), 3); + Ok(()) + } - // insert the data from the 1 partition data source provider into the initial table - initial_table - .insert_into(&session_ctx.state(), &single_partition_table_scan) - .await?; - - assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2); + #[tokio::test] + async fn test_insert_from_empty_table() -> Result<()> { + // Create a new schema with one field called "a" of type Int32 + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); - // scan a data source provider from a memory table with 2 partition - let single_partition_table_scan = create_mem_table_scan( + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( schema.clone(), - vec![vec![batch.clone()], vec![batch]], + vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], )?; - // We expect coalesce partitions here. - initial_table - .insert_into(&session_ctx.state(), &single_partition_table_scan) - .await?; - // Ensure that the table now contains 3 batches of data with 2 partitions. 
- assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4); + // Run the experiment and obtain the resulting data in the table + let resulting_data_in_table = experiment( + schema, + vec![vec![batch.clone(), batch.clone()]], + vec![vec![]], + ) + .await?; + // Ensure that the table now contains two batches of data in the same partition + assert_eq!(resulting_data_in_table[0].len(), 2); Ok(()) } } diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index bb6d58fb908e..dce8a1c42408 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -33,7 +33,7 @@ use crate::{ }; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, - DescribeTable, DmlStatement, StringifiedPlan, WriteOp, + DescribeTable, StringifiedPlan, }; pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; @@ -369,23 +369,6 @@ impl SessionContext { /// Execute the [`LogicalPlan`], return a [`DataFrame`] pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result { match plan { - LogicalPlan::Dml(DmlStatement { - table_name, - op: WriteOp::Insert, - input, - .. - }) => { - if self.table_exist(&table_name)? { - let name = table_name.table(); - let provider = self.table_provider(name).await?; - provider.insert_into(&self.state(), &input).await?; - } else { - return Err(DataFusionError::Execution(format!( - "Table '{table_name}' does not exist" - ))); - } - self.return_empty_dataframe() - } LogicalPlan::Ddl(ddl) => match ddl { DdlStatement::CreateExternalTable(cmd) => { self.create_external_table(&cmd).await @@ -1475,7 +1458,7 @@ impl SessionState { .resolve(&catalog.default_catalog, &catalog.default_schema) } - fn schema_for_ref<'a>( + pub(crate) fn schema_for_ref<'a>( &'a self, table_ref: impl Into>, ) -> Result> { diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs index f0cd48fa4f9d..12a37c65c8c8 100644 --- a/datafusion/core/src/physical_plan/memory.rs +++ b/datafusion/core/src/physical_plan/memory.rs @@ -17,11 +17,6 @@ //! Execution plan for reading in-memory batches of data -use core::fmt; -use std::any::Any; -use std::sync::Arc; -use std::task::{Context, Poll}; - use super::expressions::PhysicalSortExpr; use super::{ common, project_schema, DisplayFormatType, ExecutionPlan, Partitioning, @@ -30,10 +25,20 @@ use super::{ use crate::error::Result; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use core::fmt; +use futures::FutureExt; +use futures::StreamExt; +use std::any::Any; +use std::sync::Arc; +use std::task::{Context, Poll}; +use crate::datasource::memory::PartitionData; use crate::execution::context::TaskContext; +use crate::physical_plan::Distribution; use datafusion_common::DataFusionError; -use futures::Stream; +use futures::{ready, Stream}; +use std::mem; +use tokio::sync::{OwnedRwLockWriteGuard, RwLock}; /// Execution plan for reading in-memory batches of data pub struct MemoryExec { @@ -150,6 +155,23 @@ impl MemoryExec { }) } + /// Create a new execution plan for reading in-memory record batches + /// The provided `schema` should not have the projection applied. 
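+    /// Unlike [`Self::try_new`], this constructor takes the record batches by value
+    /// rather than by slice reference.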
+ pub fn try_new_owned_data( + partitions: Vec>, + schema: SchemaRef, + projection: Option>, + ) -> Result { + let projected_schema = project_schema(&schema, projection.as_ref())?; + Ok(Self { + partitions, + schema, + projected_schema, + projection, + sort_information: None, + }) + } + /// Set sort information pub fn with_sort_information( mut self, @@ -223,15 +245,365 @@ impl RecordBatchStream for MemoryStream { } } +/// Execution plan for writing record batches to an in-memory table. +pub struct MemoryWriteExec { + /// Input plan that produces the record batches to be written. + input: Arc, + /// Reference to the MemTable's partition data. + batches: Vec, + /// Schema describing the structure of the data. + schema: SchemaRef, +} + +impl fmt::Debug for MemoryWriteExec { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "schema: {:?}", self.schema) + } +} + +impl ExecutionPlan for MemoryWriteExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning( + self.input.output_partitioning().partition_count(), + ) + } + + fn benefits_from_input_partitioning(&self) -> bool { + false + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.input.output_ordering() + } + + fn required_input_distribution(&self) -> Vec { + // If the partition count of the MemTable is one, we want to require SinglePartition + // since it would induce better plans in plan optimizer. + if self.batches.len() == 1 { + vec![Distribution::SinglePartition] + } else { + vec![Distribution::UnspecifiedDistribution] + } + } + + fn maintains_input_order(&self) -> Vec { + // In theory, if MemTable partition count equals the input plans output partition count, + // the Execution plan can preserve the order inside the partitions. + vec![self.batches.len() == self.input.output_partitioning().partition_count()] + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(MemoryWriteExec::try_new( + children[0].clone(), + self.batches.clone(), + self.schema.clone(), + )?)) + } + + /// Execute the plan and return a stream of record batches for the specified partition. + /// Depending on the number of input partitions and MemTable partitions, it will choose + /// either a less lock acquiring or a locked implementation. + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let batch_count = self.batches.len(); + let data = self.input.execute(partition, context)?; + if batch_count >= self.input.output_partitioning().partition_count() { + // If the number of input partitions matches the number of MemTable partitions, + // use a lightweight implementation that doesn't utilize as many locks. + let table_partition = self.batches[partition].clone(); + Ok(Box::pin(MemorySinkOneToOneStream::try_new( + table_partition, + data, + self.schema.clone(), + )?)) + } else { + // Otherwise, use the locked implementation. 
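+            // With fewer table partitions than input partitions, several input
+            // partitions can map to the same table partition (note the modulo
+            // below), so the partition's write lock is re-acquired for every batch.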
+ let table_partition = self.batches[partition % batch_count].clone(); + Ok(Box::pin(MemorySinkStream::try_new( + table_partition, + data, + self.schema.clone(), + )?)) + } + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + write!( + f, + "MemoryWriteExec: partitions={}, input_partition={}", + self.batches.len(), + self.input.output_partitioning().partition_count() + ) + } + } + } + + fn statistics(&self) -> Statistics { + Statistics::default() + } +} + +impl MemoryWriteExec { + /// Create a new execution plan for reading in-memory record batches + /// The provided `schema` should not have the projection applied. + pub fn try_new( + plan: Arc, + batches: Vec>>>, + schema: SchemaRef, + ) -> Result { + Ok(Self { + input: plan, + batches, + schema, + }) + } +} + +/// This object encodes the different states of the [`MemorySinkStream`] when +/// processing record batches. +enum MemorySinkStreamState { + /// The stream is pulling data from the input. + Pull, + /// The stream is writing data to the table partition. + Write { maybe_batch: Option }, +} + +/// A stream that saves record batches in memory-backed storage. +/// Can work even when multiple input partitions map to the same table +/// partition, achieves buffer exclusivity by locking before writing. +struct MemorySinkStream { + /// Stream of record batches to be inserted into the memory table. + data: SendableRecordBatchStream, + /// Memory table partition that stores the record batches. + table_partition: PartitionData, + /// Schema representing the structure of the data. + schema: SchemaRef, + /// State of the iterator when processing multiple polls. + state: MemorySinkStreamState, +} + +impl MemorySinkStream { + /// Create a new `MemorySinkStream` with the provided parameters. + pub fn try_new( + table_partition: PartitionData, + data: SendableRecordBatchStream, + schema: SchemaRef, + ) -> Result { + Ok(Self { + table_partition, + data, + schema, + state: MemorySinkStreamState::Pull, + }) + } + + /// Implementation of the `poll_next` method. Continuously polls the record + /// batch stream, switching between the Pull and Write states. In case of + /// an error, returns the error immediately. + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + match &mut self.state { + MemorySinkStreamState::Pull => { + // Pull data from the input stream. + if let Some(result) = ready!(self.data.as_mut().poll_next(cx)) { + match result { + Ok(batch) => { + // Switch to the Write state with the received batch. + self.state = MemorySinkStreamState::Write { + maybe_batch: Some(batch), + } + } + Err(e) => return Poll::Ready(Some(Err(e))), // Return the error immediately. + } + } else { + return Poll::Ready(None); // If the input stream is exhausted, return None. + } + } + MemorySinkStreamState::Write { maybe_batch } => { + // Acquire a write lock on the table partition. + let mut partition = + ready!(self.table_partition.write().boxed().poll_unpin(cx)); + if let Some(b) = mem::take(maybe_batch) { + partition.push(b); // Insert the batch into the table partition. + } + self.state = MemorySinkStreamState::Pull; // Switch back to the Pull state. 
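+                    // The write guard (`partition`) is dropped when this arm ends,
+                    // so the lock is held only while a single batch is appended.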
+ } + } + } + } +} + +impl Stream for MemorySinkStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +impl RecordBatchStream for MemorySinkStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// This object encodes the different states of the [`MemorySinkOneToOneStream`] +/// when processing record batches. +enum MemorySinkOneToOneStreamState { + /// The `Acquire` variant represents the state where the [`MemorySinkOneToOneStream`] + /// is waiting to acquire the write lock on the shared partition to store the record batches. + Acquire, + + /// The `Pull` variant represents the state where the [`MemorySinkOneToOneStream`] has + /// acquired the write lock on the shared partition and can pull record batches from + /// the input stream to store in the partition. + Pull { + /// The `partition` field contains an [`OwnedRwLockWriteGuard`] which wraps the + /// shared partition, providing exclusive write access to the underlying `Vec`. + partition: OwnedRwLockWriteGuard>, + }, +} + +/// A stream that saves record batches in memory-backed storage. +/// Assumes that every table partition has at most one corresponding input +/// partition, so it locks the table partition only once. +struct MemorySinkOneToOneStream { + /// Stream of record batches to be inserted into the memory table. + data: SendableRecordBatchStream, + /// Memory table partition that stores the record batches. + table_partition: PartitionData, + /// Schema representing the structure of the data. + schema: SchemaRef, + /// State of the iterator when processing multiple polls. + state: MemorySinkOneToOneStreamState, +} + +impl MemorySinkOneToOneStream { + /// Create a new `MemorySinkOneToOneStream` with the provided parameters. + pub fn try_new( + table_partition: Arc>>, + data: SendableRecordBatchStream, + schema: SchemaRef, + ) -> Result { + Ok(Self { + table_partition, + data, + schema, + state: MemorySinkOneToOneStreamState::Acquire, + }) + } + + /// Implementation of the `poll_next` method. Continuously polls the record + /// batch stream and pushes batches to their corresponding table partition, + /// which are lock-acquired only once. In case of an error, returns the + /// error immediately. + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + match &mut self.state { + MemorySinkOneToOneStreamState::Acquire => { + // Acquire a write lock on the table partition. + self.state = MemorySinkOneToOneStreamState::Pull { + partition: ready!(self + .table_partition + .clone() + .write_owned() + .boxed() + .poll_unpin(cx)), + }; + } + MemorySinkOneToOneStreamState::Pull { partition } => { + // Iterate over the batches in the input data stream. + while let Some(result) = ready!(self.data.poll_next_unpin(cx)) { + match result { + Ok(batch) => { + partition.push(batch); + } // Insert the batch into the table partition. + Err(e) => return Poll::Ready(Some(Err(e))), // Return the error immediately. + } + } + // If the input stream is exhausted, return None to indicate the end of the stream. + return Poll::Ready(None); + } + } + } + } +} + +impl Stream for MemorySinkOneToOneStream { + type Item = Result; + + /// Poll the stream for the next item. 
+ fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +impl RecordBatchStream for MemorySinkOneToOneStream { + /// Get the schema of the record batches in the stream. + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + #[cfg(test)] mod tests { use super::*; + use crate::datasource::streaming::PartitionStream; + use crate::datasource::{MemTable, TableProvider}; use crate::from_slice::FromSlice; + use crate::physical_plan::stream::RecordBatchStreamAdapter; + use crate::physical_plan::streaming::StreamingTableExec; use crate::physical_plan::ColumnStatistics; - use crate::prelude::SessionContext; + use crate::physical_plan::{collect, displayable, SendableRecordBatchStream}; + use crate::prelude::{CsvReadOptions, SessionContext}; + use crate::test_util; use arrow::array::Int32Array; - use arrow::datatypes::{DataType, Field, Schema}; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use arrow::record_batch::RecordBatch; + use datafusion_common::Result; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::TaskContext; use futures::StreamExt; + use std::sync::Arc; fn mock_data() -> Result<(SchemaRef, RecordBatch)> { let schema = Arc::new(Schema::new(vec![ @@ -340,4 +712,262 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_insert_into() -> Result<()> { + // Create session context + let config = SessionConfig::new().with_target_partitions(8); + let ctx = SessionContext::with_config(config); + let testdata = test_util::arrow_test_data(); + let schema = test_util::aggr_test_schema(); + ctx.register_csv( + "aggregate_test_100", + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new().schema(&schema), + ) + .await?; + ctx.sql( + "CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL)", + ) + .await?; + + let sql = "INSERT INTO table_without_values SELECT + SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), + COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) + FROM aggregate_test_100 + ORDER by c1 + "; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "MemoryWriteExec: partitions=1, input_partition=1", + " ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2]", + " SortPreservingMergeExec: [c1@2 ASC NULLS LAST]", + " ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(UInt8(1)), c1@0 as c1]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), 
end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted]", + " SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + Ok(()) + } + + #[tokio::test] + async fn test_insert_into_as_select_multi_partitioned() -> Result<()> { + // Create session context + let config = SessionConfig::new().with_target_partitions(8); + let ctx = SessionContext::with_config(config); + let testdata = test_util::arrow_test_data(); + let schema = test_util::aggr_test_schema(); + ctx.register_csv( + "aggregate_test_100", + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new().schema(&schema), + ) + .await?; + ctx.sql( + "CREATE TABLE table_without_values(field1 BIGINT NULL, field2 BIGINT NULL)", + ) + .await?; + + let sql = "INSERT INTO table_without_values SELECT + SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, + COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 + FROM aggregate_test_100"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "MemoryWriteExec: partitions=1, input_partition=1", + " CoalescePartitionsExec", + " ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted]", + " SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + 
Ok(()) + } + + // TODO: The generated plan is suboptimal since SortExec is in global state. + #[tokio::test] + async fn test_insert_into_as_select_single_partition() -> Result<()> { + // Create session context + let config = SessionConfig::new().with_target_partitions(8); + let ctx = SessionContext::with_config(config); + let testdata = test_util::arrow_test_data(); + let schema = test_util::aggr_test_schema(); + ctx.register_csv( + "aggregate_test_100", + &format!("{testdata}/csv/aggregate_test_100.csv"), + CsvReadOptions::new().schema(&schema), + ) + .await?; + ctx.sql("CREATE TABLE table_without_values AS SELECT + SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, + COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 + FROM aggregate_test_100") + .await?; + + let sql = "INSERT INTO table_without_values SELECT + SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, + COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 + FROM aggregate_test_100 + ORDER BY c1"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "MemoryWriteExec: partitions=8, input_partition=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " ProjectionExec: expr=[a1@0 as a1, a2@1 as a2]", + " SortPreservingMergeExec: [c1@2 ASC NULLS LAST]", + " ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1]", + " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted]", + " SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=8192", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + Ok(()) + } + + // DummyPartition is a simple implementation of the PartitionStream trait. + // It produces a stream of record batches with a fixed schema and the same content. + struct DummyPartition { + schema: SchemaRef, + batch: RecordBatch, + num_batches: usize, + } + + impl PartitionStream for DummyPartition { + // Return a reference to the schema of this partition. 
+ fn schema(&self) -> &SchemaRef { + &self.schema + } + + // Execute the partition stream, producing a stream of record batches. + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + let batches = itertools::repeat_n(self.batch.clone(), self.num_batches); + Box::pin(RecordBatchStreamAdapter::new( + self.schema.clone(), + futures::stream::iter(batches).map(Ok), + )) + } + } + + // Test the less-lock mode by inserting a large number of batches into a table. + #[tokio::test] + async fn test_one_to_one_mode() -> Result<()> { + let num_batches = 10000; + // Create a new session context + let session_ctx = SessionContext::new(); + // Create a new schema with one field called "a" of type Int32 + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], + )?; + let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); + + let single_partition = Arc::new(DummyPartition { + schema: schema.clone(), + batch, + num_batches, + }); + let input = Arc::new(StreamingTableExec::try_new( + schema.clone(), + vec![single_partition], + None, + false, + )?); + let plan = initial_table + .insert_into(&session_ctx.state(), input) + .await?; + let res = collect(plan, session_ctx.task_ctx()).await?; + assert!(res.is_empty()); + // Ensure that the table now contains two batches of data in the same partition + assert_eq!(initial_table.batches[0].read().await.len(), num_batches); + Ok(()) + } + + // Test the locked mode by inserting a large number of batches into a table. It tests + // where the table partition count is not equal to the input's output partition count. 
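+    // Here, three identical input partitions all feed the table's single partition,
+    // so every batch goes through the lock-per-batch MemorySinkStream path.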
+ #[tokio::test] + async fn test_locked_mode() -> Result<()> { + let num_batches = 10000; + // Create a new session context + let session_ctx = SessionContext::new(); + // Create a new schema with one field called "a" of type Int32 + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a new batch of data to insert into the table + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from_slice([1, 2, 3]))], + )?; + let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![]])?); + + let single_partition = Arc::new(DummyPartition { + schema: schema.clone(), + batch, + num_batches, + }); + let input = Arc::new(StreamingTableExec::try_new( + schema.clone(), + vec![ + single_partition.clone(), + single_partition.clone(), + single_partition, + ], + None, + false, + )?); + let plan = initial_table + .insert_into(&session_ctx.state(), input) + .await?; + let res = collect(plan, session_ctx.task_ctx()).await?; + assert!(res.is_empty()); + // Ensure that the table now contains two batches of data in the same partition + assert_eq!(initial_table.batches[0].read().await.len(), num_batches * 3); + Ok(()) + } } diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 782dcf13352c..7f68d5a39ad2 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -67,7 +67,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; -use datafusion_expr::{logical_plan, StringifiedPlan}; +use datafusion_expr::{logical_plan, DmlStatement, StringifiedPlan, WriteOp}; use datafusion_expr::{WindowFrame, WindowFrameBound}; use datafusion_optimizer::utils::unalias; use datafusion_physical_expr::expressions::Literal; @@ -489,6 +489,23 @@ impl DefaultPhysicalPlanner { let unaliased: Vec = filters.into_iter().map(unalias).collect(); source.scan(session_state, projection.as_ref(), &unaliased, *fetch).await } + LogicalPlan::Dml(DmlStatement { + table_name, + op: WriteOp::Insert, + input, + .. + }) => { + let name = table_name.table(); + let schema = session_state.schema_for_ref(table_name)?; + if let Some(provider) = schema.table(name).await { + let input_exec = self.create_initial_plan(input, session_state).await?; + provider.insert_into(session_state, input_exec).await + } else { + return Err(DataFusionError::Execution(format!( + "Table '{table_name}' does not exist" + ))); + } + } LogicalPlan::Values(Values { values, schema,
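For context, a rough end-to-end usage sketch (not part of the patch): the planner hunk above is where an INSERT statement is now lowered into a provider-supplied write plan instead of being executed eagerly inside SessionContext. The snippet below assumes a DataFusion build that includes this change; the table names t1/t2 and column a are placeholders.

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Each CREATE TABLE registers an empty, single-partition MemTable.
    ctx.sql("CREATE TABLE t1(a BIGINT NULL)").await?;
    ctx.sql("CREATE TABLE t2(a BIGINT NULL)").await?;

    // Planning the INSERT goes through DefaultPhysicalPlanner, which calls
    // TableProvider::insert_into and gets back a MemoryWriteExec sink plan.
    let df = ctx.sql("INSERT INTO t2 SELECT a FROM t1").await?;

    // Executing that plan performs the write; it streams no batches back out.
    let results = df.collect().await?;
    assert!(results.is_empty());
    Ok(())
}

Keeping the write as an ExecutionPlan lets it participate in normal physical planning, which is how the RepartitionExec and CoalescePartitionsExec nodes in the expected plans of the tests above are introduced.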