diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index a80c4b94d999..632ef8d287d2 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -26,7 +26,6 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use parking_lot::RwLock; use crate::datasource::{TableProvider, TableType}; use crate::error::{DataFusionError, Result}; @@ -41,7 +40,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning}; #[derive(Debug)] pub struct MemTable { schema: SchemaRef, - batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>, + batches: Vec<Vec<RecordBatch>>, } impl MemTable { @@ -54,7 +53,7 @@ { Ok(Self { schema, - batches: Arc::new(RwLock::new(partitions)), + batches: partitions, }) } else { Err(DataFusionError::Plan( @@ -118,11 +117,6 @@ impl MemTable { } MemTable::try_new(schema.clone(), data) } - - /// Get record batches in MemTable - pub fn get_batches(&self) -> Arc<RwLock<Vec<Vec<RecordBatch>>>> { - self.batches.clone() - } } #[async_trait] @@ -146,9 +140,8 @@ impl TableProvider for MemTable { _filters: &[Expr], _limit: Option<usize>, ) -> Result<Arc<dyn ExecutionPlan>> { - let batches = self.batches.read(); Ok(Arc::new(MemoryExec::try_new( - &(*batches).clone(), + &self.batches.clone(), self.schema(), projection.cloned(), )?)) diff --git a/datafusion/core/tests/sqllogictests/src/error.rs b/datafusion/core/tests/sqllogictests/src/error.rs index 5324e8f88550..0b073870df8a 100644 --- a/datafusion/core/tests/sqllogictests/src/error.rs +++ b/datafusion/core/tests/sqllogictests/src/error.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::error::ArrowError; use datafusion_common::DataFusionError; use sqllogictest::TestError; use sqlparser::parser::ParserError; @@ -32,6 +33,8 @@ pub enum DFSqlLogicTestError { DataFusion(DataFusionError), /// Error returned when SQL is syntactically incorrect.
Sql(ParserError), + /// Error from arrow-rs + Arrow(ArrowError), } impl From<TestError> for DFSqlLogicTestError { @@ -52,6 +55,12 @@ impl From<ParserError> for DFSqlLogicTestError { } } +impl From<ArrowError> for DFSqlLogicTestError { + fn from(value: ArrowError) -> Self { + DFSqlLogicTestError::Arrow(value) + } +} + impl Display for DFSqlLogicTestError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -64,6 +73,7 @@ impl Display for DFSqlLogicTestError { write!(f, "DataFusion error: {}", error) } DFSqlLogicTestError::Sql(error) => write!(f, "SQL Parser error: {}", error), + DFSqlLogicTestError::Arrow(error) => write!(f, "Arrow error: {}", error), } } } diff --git a/datafusion/core/tests/sqllogictests/src/insert/mod.rs b/datafusion/core/tests/sqllogictests/src/insert/mod.rs index 025015f5f0e8..a8f24a051601 100644 --- a/datafusion/core/tests/sqllogictests/src/insert/mod.rs +++ b/datafusion/core/tests/sqllogictests/src/insert/mod.rs @@ -19,6 +19,7 @@ mod util; use crate::error::Result; use crate::insert::util::LogicTestContextProvider; +use arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::{DFSchema, DataFusionError}; @@ -26,6 +27,7 @@ use datafusion_expr::Expr as DFExpr; use datafusion_sql::planner::SqlToRel; use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement}; use std::collections::HashMap; +use std::sync::Arc; pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<String> { // First, use sqlparser to get table name and insert values @@ -52,19 +54,14 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result< _ => unreachable!(), } - // Second, get table by table name - // Here we assume table must be in memory table.
- let table_provider = ctx.table_provider(table_name.as_str())?; - let table_batches = table_provider - .as_any() - .downcast_ref::<MemTable>() - .unwrap() - .get_batches(); + // Second, get batches in table and destroy the old table + let mut origin_batches = ctx.table(table_name.as_str())?.collect().await?; + let schema = ctx.table_provider(table_name.as_str())?.schema(); + ctx.deregister_table(table_name.as_str())?; // Third, transfer insert values to `RecordBatch` // Attention: schema info can be ignored. (insert values don't contain schema info) let sql_to_rel = SqlToRel::new(&LogicTestContextProvider {}); - let mut insert_batches = Vec::with_capacity(insert_values.len()); for row in insert_values.into_iter() { let logical_exprs = row .into_iter() .map(|expr| sql_to_rel.sql_to_rex(expr, &DFSchema::empty(), &mut HashMap::new())) .collect::<Result<Vec<DFExpr>, DataFusionError>>()?; // Directly use `select` to get `RecordBatch` let dataframe = ctx.read_empty()?; - insert_batches.push(dataframe.select(logical_exprs)?.collect().await?) + origin_batches.extend(dataframe.select(logical_exprs)?.collect().await?) } - // Final, append the `RecordBatch` to memtable's batches - let mut table_batches = table_batches.write(); - table_batches.extend(insert_batches); + // Replace new batches schema to old schema + for batch in origin_batches.iter_mut() { + *batch = RecordBatch::try_new(schema.clone(), batch.columns().to_vec())?; + } + + // Final, create new memtable with same schema. + let new_provider = MemTable::try_new(schema, vec![origin_batches])?; + ctx.register_table(table_name.as_str(), Arc::new(new_provider))?; Ok("".to_string()) }