diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index e06f947ad5e7..d6e4490cec4c 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -347,6 +347,22 @@ impl DFSchema {
             .collect()
     }
 
+    /// Find all field indices having the given qualifier
+    pub fn fields_indices_with_qualified(
+        &self,
+        qualifier: &TableReference,
+    ) -> Vec<usize> {
+        self.fields
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, field)| {
+                field
+                    .qualifier()
+                    .and_then(|q| q.eq(qualifier).then_some(idx))
+            })
+            .collect()
+    }
+
     /// Find all fields match the given name
     pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> {
         self.fields
diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs
index b9b45a6c7470..830cd7a07e46 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -64,7 +64,8 @@ impl TableProviderFactory for StreamTableFactory {
             .with_encoding(encoding)
             .with_order(cmd.order_exprs.clone())
             .with_header(cmd.has_header)
-            .with_batch_size(state.config().batch_size());
+            .with_batch_size(state.config().batch_size())
+            .with_constraints(cmd.constraints.clone());
 
         Ok(Arc::new(StreamTable(Arc::new(config))))
     }
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index abdd7f5f57f6..09f4842c9e64 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -32,6 +32,7 @@ use crate::{
 
 use arrow::datatypes::{DataType, TimeUnit};
 use datafusion_common::tree_node::{TreeNode, VisitRecursion};
+use datafusion_common::utils::get_at_indices;
 use datafusion_common::{
     internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef,
     DataFusionError, Result, ScalarValue, TableReference,
@@ -425,18 +426,18 @@ pub fn expand_qualified_wildcard(
     wildcard_options: Option<&WildcardAdditionalOptions>,
 ) -> Result<Vec<Expr>> {
     let qualifier = TableReference::from(qualifier);
-    let qualified_fields: Vec<DFField> = schema
-        .fields_with_qualified(&qualifier)
-        .into_iter()
-        .cloned()
-        .collect();
+    let qualified_indices = schema.fields_indices_with_qualified(&qualifier);
+    let projected_func_dependencies = schema
+        .functional_dependencies()
+        .project_functional_dependencies(&qualified_indices, qualified_indices.len());
+    let qualified_fields = get_at_indices(schema.fields(), &qualified_indices)?;
     if qualified_fields.is_empty() {
         return plan_err!("Invalid qualifier {qualifier}");
     }
     let qualified_schema =
         DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())?
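The utils.rs change above also fixes a subtle bug: the old code attached the full schema's functional dependencies to the qualified sub-schema unchanged, but those dependencies store field indices relative to the full schema, so after restricting to one qualifier's fields they can point at the wrong columns or out of bounds. The following standalone sketch illustrates the index-remapping idea behind project_functional_dependencies; the helper name is hypothetical and does not use the real FunctionalDependencies API.

// Minimal sketch (hypothetical, no DataFusion dependency): remap a dependency's
// full-schema column indices into the projected index space. A dependency that
// references a column outside the projection cannot be preserved.
fn project_dependency(dep: &[usize], proj_indices: &[usize]) -> Option<Vec<usize>> {
    dep.iter()
        .map(|idx| proj_indices.iter().position(|kept| kept == idx))
        .collect() // Option<Vec<usize>>: None as soon as one index was not kept
}

fn main() {
    // Full schema: l.a = 0, l.b = 1, r.c = 2; projecting to qualifier `l` keeps [0, 1].
    assert_eq!(project_dependency(&[0], &[0, 1]), Some(vec![0]));
    // A dependency involving r.c (index 2) does not survive a projection onto `l`.
    assert_eq!(project_dependency(&[2], &[0, 1]), None);
}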
            // We can use the functional dependencies as is, since it only stores indices:
-            .with_functional_dependencies(schema.functional_dependencies().clone())?;
+            .with_functional_dependencies(projected_func_dependencies)?;
     let excluded_columns = if let Some(WildcardAdditionalOptions {
         opt_exclude,
         opt_except,
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index c74c4ac0f821..921de96252f0 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -43,7 +43,7 @@ use datafusion_execution::TaskContext;
 use datafusion_expr::Accumulator;
 use datafusion_physical_expr::{
     aggregate::is_order_sensitive,
-    equivalence::collapse_lex_req,
+    equivalence::{collapse_lex_req, ProjectionMapping},
     expressions::{Column, Max, Min, UnKnownColumn},
     physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties,
     LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
@@ -59,7 +59,6 @@ mod topk;
 mod topk_stream;
 
 pub use datafusion_expr::AggregateFunction;
-use datafusion_physical_expr::equivalence::ProjectionMapping;
 pub use datafusion_physical_expr::expressions::create_aggregate_expr;
 
 /// Hash aggregate modes
@@ -464,7 +463,7 @@ impl AggregateExec {
     pub fn try_new(
         mode: AggregateMode,
         group_by: PhysicalGroupBy,
-        mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
@@ -482,6 +481,37 @@ impl AggregateExec {
             group_by.expr.len(),
         ));
         let original_schema = Arc::new(original_schema);
+        AggregateExec::try_new_with_schema(
+            mode,
+            group_by,
+            aggr_expr,
+            filter_expr,
+            input,
+            input_schema,
+            schema,
+            original_schema,
+        )
+    }
+
+    /// Create a new hash aggregate execution plan with the given schema.
+    /// This constructor isn't part of the public API; it is used internally
+    /// by DataFusion to enforce schema consistency when re-creating
+    /// `AggregateExec`s inside optimization rules. The schema field names of
+    /// an `AggregateExec` depend on the names of its aggregate expressions.
+    /// Since a rule may rewrite aggregate expressions (e.g. reverse them)
+    /// during initialization, field names could change inadvertently if the
+    /// schema were re-created from the rewritten expressions.
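The doc comment on try_new_with_schema above captures the heart of this fix: an AggregateExec derives its output field names from its aggregate expression names, and optimizer rules may rewrite those expressions (for example, replacing an order-sensitive FIRST_VALUE with the equivalent reversed LAST_VALUE). Re-deriving the schema inside with_new_children could therefore rename fields that parent plan nodes still reference. A self-contained sketch of that failure mode, using a hypothetical stand-in rather than DataFusion's types:

// Sketch (hypothetical): deriving the output schema from rewritten expressions
// changes field names even though the computed values are equivalent.
fn reverse_name(name: &str) -> String {
    // e.g. FIRST_VALUE(b ORDER BY b DESC) equals LAST_VALUE(b ORDER BY b ASC)
    // in value, but the derived field name differs.
    name.replace("FIRST_VALUE", "LAST_VALUE")
}

fn main() {
    let before = vec!["FIRST_VALUE(b)".to_string()];
    let after: Vec<String> = before.iter().map(|n| reverse_name(n)).collect();
    // Parent nodes still reference "FIRST_VALUE(b)"; rebuilding the schema from the
    // rewritten expressions would rename the field, hence with_new_children now
    // passes the previously computed schema through try_new_with_schema.
    assert_ne!(before, after);
}

The new test_agg_exec_same_schema test pins exactly this invariant: with_new_children must return a plan whose schema is identical to the original's.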
+    #[allow(clippy::too_many_arguments)]
+    fn try_new_with_schema(
+        mode: AggregateMode,
+        group_by: PhysicalGroupBy,
+        mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
+        input: Arc<dyn ExecutionPlan>,
+        input_schema: SchemaRef,
+        schema: SchemaRef,
+        original_schema: SchemaRef,
+    ) -> Result<Self> {
         // Reset ordering requirement to `None` if aggregator is not order-sensitive
         let mut order_by_expr = aggr_expr
             .iter()
@@ -858,13 +888,15 @@ impl ExecutionPlan for AggregateExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let mut me = AggregateExec::try_new(
+        let mut me = AggregateExec::try_new_with_schema(
             self.mode,
             self.group_by.clone(),
             self.aggr_expr.clone(),
             self.filter_expr.clone(),
             children[0].clone(),
             self.input_schema.clone(),
+            self.schema.clone(),
+            self.original_schema.clone(),
         )?;
         me.limit = self.limit;
         Ok(Arc::new(me))
@@ -2162,4 +2194,56 @@ mod tests {
         assert_eq!(res, common_requirement);
         Ok(())
     }
+
+    #[test]
+    fn test_agg_exec_same_schema() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float32, true),
+            Field::new("b", DataType::Float32, true),
+        ]));
+
+        let col_a = col("a", &schema)?;
+        let col_b = col("b", &schema)?;
+        let option_desc = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+        let sort_expr = vec![PhysicalSortExpr {
+            expr: col_b.clone(),
+            options: option_desc,
+        }];
+        let sort_expr_reverse = reverse_order_bys(&sort_expr);
+        let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);
+
+        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![
+            Arc::new(FirstValue::new(
+                col_b.clone(),
+                "FIRST_VALUE(b)".to_string(),
+                DataType::Float64,
+                sort_expr_reverse.clone(),
+                vec![DataType::Float64],
+            )),
+            Arc::new(LastValue::new(
+                col_b.clone(),
+                "LAST_VALUE(b)".to_string(),
+                DataType::Float64,
+                sort_expr.clone(),
+                vec![DataType::Float64],
+            )),
+        ];
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
+        let aggregate_exec = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            groups,
+            aggregates.clone(),
+            vec![None, None],
+            blocking_exec.clone(),
+            schema,
+        )?);
+        let new_agg = aggregate_exec
+            .clone()
+            .with_new_children(vec![blocking_exec])?;
+        assert_eq!(new_agg.schema(), aggregate_exec.schema());
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/groupby.slt b/datafusion/sqllogictest/test_files/groupby.slt
index 44d30ba0b34c..f1b6a57287b5 100644
--- a/datafusion/sqllogictest/test_files/groupby.slt
+++ b/datafusion/sqllogictest/test_files/groupby.slt
@@ -4280,3 +4280,15 @@ LIMIT 5
 2 0 0
 3 0 0
 4 0 1
+
+
+query ITIPTR rowsort
+SELECT r.*
+FROM sales_global_with_pk as l, sales_global_with_pk as r
+LIMIT 5
+----
+0 GRC 0 2022-01-01T06:00:00 EUR 30
+1 FRA 1 2022-01-01T08:00:00 EUR 50
+1 FRA 3 2022-01-02T12:00:00 EUR 200
+1 TUR 2 2022-01-01T11:30:00 TRY 75
+1 TUR 4 2022-01-03T10:00:00 TRY 100