diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index f20475df150b..c81e0afa1827 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -63,7 +63,7 @@ use datafusion_expr::{ expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, - Expr, UserDefinedLogicalNode, WindowUDF, + Expr, WindowUDF, }; // backwards compatibility @@ -1682,27 +1682,7 @@ pub enum RegisterFunction { #[derive(Debug)] pub struct EmptySerializerRegistry; -impl SerializerRegistry for EmptySerializerRegistry { - fn serialize_logical_plan( - &self, - node: &dyn UserDefinedLogicalNode, - ) -> Result> { - not_impl_err!( - "Serializing user defined logical plan node `{}` is not supported", - node.name() - ) - } - - fn deserialize_logical_plan( - &self, - name: &str, - _bytes: &[u8], - ) -> Result> { - not_impl_err!( - "Deserializing user defined logical plan node `{name}` is not supported" - ) - } -} +impl SerializerRegistry for EmptySerializerRegistry {} /// Describes which SQL statements can be run. /// diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 4eb49710bcf8..588181b14421 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,7 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; -use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; +use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; use std::collections::HashSet; use std::fmt::Debug; @@ -123,22 +123,52 @@ pub trait FunctionRegistry { } } -/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. +/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode] +/// and custom table providers for which the name alone is meaningless in the target +/// execution context, e.g. UDTFs, manually registered tables etc. pub trait SerializerRegistry: Debug + Send + Sync { /// Serialize this node to a byte array. This serialization should not include /// input plans. fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result>; + ) -> Result> { + not_impl_err!( + "Serializing user defined logical plan node `{}` is not supported", + node.name() + ) + } /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from /// bytes. fn deserialize_logical_plan( &self, name: &str, - bytes: &[u8], - ) -> Result>; + _bytes: &[u8], + ) -> Result> { + not_impl_err!( + "Deserializing user defined logical plan node `{name}` is not supported" + ) + } + + /// Serialized table definition for UDTFs or manually registered table providers that can't be + /// marshaled by reference. Should return some benign error for regular tables that can be + /// found/restored by name in the destination execution context. + fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result> { + not_impl_err!("No custom table support") + } + + /// Deserialize the custom table with the given name. + /// Note: more often than not, the name can't be used as a discriminator if multiple different + /// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true + /// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`). + fn deserialize_custom_table( + &self, + name: &str, + _bytes: &[u8], + ) -> Result> { + not_impl_err!("Deserializing custom table `{name}` is not supported") + } } /// A [`FunctionRegistry`] that uses in memory [`HashMap`]s diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9623f12c88dd..9618bc4a59fc 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension, - LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values, + LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; use substrait::proto::expression as substrait_expression; @@ -86,6 +86,7 @@ use substrait::proto::expression::{ SingularOrList, SwitchExpression, WindowFunction, }; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use substrait::proto::read_rel::ExtensionTable; use substrait::proto::rel_common::{Emit, EmitKind}; use substrait::proto::set_rel::SetOp; use substrait::proto::{ @@ -457,6 +458,22 @@ pub trait SubstraitConsumer: Send + Sync + Sized { user_defined_literal.type_reference ) } + + fn consume_extension_table( + &self, + extension_table: &ExtensionTable, + _schema: &DFSchema, + _projection: &Option, + ) -> Result { + if let Some(ext_detail) = extension_table.detail.as_ref() { + substrait_err!( + "Missing handler for extension table: {}", + &ext_detail.type_url + ) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } } /// Convert Substrait Rel to DataFusion DataFrame @@ -578,6 +595,32 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> { let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; Ok(LogicalPlan::Extension(Extension { node: plan })) } + + fn consume_extension_table( + &self, + extension_table: &ExtensionTable, + schema: &DFSchema, + projection: &Option, + ) -> Result { + if let Some(ext_detail) = &extension_table.detail { + let source = self + .state + .serializer_registry() + .deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?; + let table_name = ext_detail + .type_url + .rsplit_once('/') + .map(|(_, name)| name) + .unwrap_or(&ext_detail.type_url); + let table_scan = TableScan::try_new(table_name, source, None, vec![], None)?; + let plan = LogicalPlan::TableScan(table_scan); + ensure_schema_compatibility(plan.schema(), schema.clone())?; + let schema = apply_masking(schema.clone(), projection)?; + apply_projection(plan, schema) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } } // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which @@ -1467,8 +1510,11 @@ pub async fn from_read_rel( ) .await } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", read.read_type) + Some(ReadType::ExtensionTable(ext)) => { + consumer.consume_extension_table(ext, &substrait_schema, &read.projection) + } + None => { + substrait_err!("Unexpected empty read_type") } } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e501ddf5c698..e5cbcd4dbe66 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -22,11 +22,7 @@ use std::sync::Arc; use substrait::proto::expression_reference::ExprType; use datafusion::arrow::datatypes::{Field, IntervalUnit}; -use datafusion::logical_expr::{ - Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, - Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, - TryCast, Union, Values, Window, WindowFrameUnits, -}; +use datafusion::logical_expr::{Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, TableSource, TryCast, Union, Values, Window, WindowFrameUnits}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -56,7 +52,7 @@ use datafusion::logical_expr::expr::{ }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; -use pbjson_types::Any as ProtoAny; +use pbjson_types::{Any as ProtoAny, Any}; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; use substrait::proto::expression::cast::FailureBehavior; use substrait::proto::expression::field_reference::{RootReference, RootType}; @@ -69,7 +65,7 @@ use substrait::proto::expression::literal::{ use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::expression::ScalarFunction; -use substrait::proto::read_rel::VirtualTable; +use substrait::proto::read_rel::{ExtensionTable, VirtualTable}; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; use substrait::proto::{ @@ -366,6 +362,13 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> Result { from_in_subquery(self, in_subquery, schema) } + + fn handle_extension_table( + &mut self, + _table: &dyn TableSource, + ) -> Result { + not_impl_err!("Not implemented") + } } struct DefaultSubstraitProducer<'a> { @@ -425,6 +428,16 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> { rel_type: Some(rel_type), })) } + + fn handle_extension_table(&mut self, table: &dyn TableSource) -> Result { + let bytes = self.serializer_registry.serialize_custom_table(table)?; + Ok(ExtensionTable { + detail: Some(Any { + type_url: "/substrait.ExtensionTable".into(), + value: bytes.into(), + }) + }) + } } /// Convert DataFusion LogicalPlan to Substrait Plan @@ -539,7 +552,7 @@ pub fn to_substrait_rel( } pub fn from_table_scan( - _producer: &mut impl SubstraitProducer, + producer: &mut impl SubstraitProducer, scan: &TableScan, ) -> Result> { let projection = scan.projection.as_ref().map(|p| { @@ -559,6 +572,18 @@ pub fn from_table_scan( let table_schema = scan.source.schema().to_dfschema_ref()?; let base_schema = to_substrait_named_struct(&table_schema)?; + let table = if let Ok(ext_table) = producer + .handle_extension_table(scan.source.as_ref()) + { + ReadType::ExtensionTable(ext_table) + } else { + ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + }) + }; + + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, @@ -567,10 +592,7 @@ pub fn from_table_scan( best_effort_filter: None, projection, advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), + read_type: Some(table), }))), })) } @@ -2532,8 +2554,8 @@ mod test { use super::*; use crate::logical_plan::consumer::{ from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, - DefaultSubstraitConsumer, + from_substrait_named_struct, from_substrait_plan, + from_substrait_type_without_names, DefaultSubstraitConsumer, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::{ @@ -2541,8 +2563,12 @@ mod test { }; use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; + use datafusion::common::{assert_contains, DFSchema}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::{DefaultTableSource, TableProvider}; use datafusion::execution::{SessionState, SessionStateBuilder}; + use datafusion::logical_expr::registry::SerializerRegistry; + use datafusion::logical_expr::TableSource; use datafusion::prelude::SessionContext; use std::sync::OnceLock; @@ -2879,4 +2905,110 @@ mod test { assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } + + #[tokio::test] + async fn round_trip_extension_table() { + const TABLE_NAME: &str = "custom_table"; + const SERIALIZED: &[u8] = "table definition".as_bytes(); + + fn custom_table() -> Arc { + Arc::new(EmptyTable::new(Arc::new(Schema::new([ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, false)), + ])))) + } + + #[derive(Debug)] + struct Registry; + impl SerializerRegistry for Registry { + fn serialize_custom_table(&self, table: &dyn TableSource) -> Result> { + if table.schema() == custom_table().schema() { + Ok(SERIALIZED.to_vec()) + } else { + Err(DataFusionError::Internal("Not our table".into())) + } + } + fn deserialize_custom_table( + &self, + name: &str, + bytes: &[u8], + ) -> Result> { + if name == TABLE_NAME && bytes == SERIALIZED { + Ok(Arc::new(DefaultTableSource::new(custom_table()))) + } else { + panic!("Unexpected extension table: {name}"); + } + } + } + + async fn round_trip_logical_plans( + local: &SessionContext, + remote: &SessionContext, + ) -> Result<()> { + local.register_table(TABLE_NAME, custom_table())?; + remote.table_provider(TABLE_NAME).await.expect_err( + "The remote context is not supposed to know about custom_table", + ); + let initial_plan = local + .sql(&format!("select id from {TABLE_NAME}")) + .await? + .logical_plan() + .clone(); + + // write substrait locally + let substrait = to_substrait_plan(&initial_plan, &local.state())?; + + // read substrait remotely + // since we know there's no `custom_table` registered in the remote context, this will only succeed + // if our table got encoded as an ExtensionTable and is now decoded back to a table source. + let restored = from_substrait_plan(&remote.state(), &substrait).await?; + assert_contains!( + // confirm that the Substrait plan contains our custom_table as an ExtensionTable + serde_json::to_string(substrait.as_ref()).unwrap(), + format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#) + ); + remote // make sure the restored plan is fully working in the remote context + .execute_logical_plan(restored.clone()) + .await? + .collect() + .await + .expect("Restored plan cannot be executed remotely"); + assert_eq!( + // check that the restored plan is functionally equivalent (and almost identical) to the initial one + initial_plan.to_string(), + restored.to_string().replace( + // substrait will add an explicit full-schema projection if the original table had none + &format!("TableScan: {TABLE_NAME} projection=[id, name]"), + &format!("TableScan: {TABLE_NAME}"), + ) + ); + Ok(()) + } + + // take 1 + let failed_attempt = + round_trip_logical_plans(&SessionContext::new(), &SessionContext::new()) + .await + .expect_err( + "The round trip should fail in the absence of a SerializerRegistry", + ); + assert_contains!( + failed_attempt.message(), + format!("No table named '{TABLE_NAME}'") + ); + + // take 2 + fn proper_context() -> SessionContext { + SessionContext::new_with_state( + SessionStateBuilder::new() + // This will transport our custom_table as a Substrait ExtensionTable + .with_serializer_registry(Arc::new(Registry)) + .build(), + ) + } + + round_trip_logical_plans(&proper_context(), &proper_context()) + .await + .expect("Local plan could not be restored remotely"); + } }