diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 4644e15febef..787698c009de 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1587,8 +1587,9 @@ mod tests { use datafusion_common::{Constraint, Constraints}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - cast, count_distinct, create_udf, expr, lit, sum, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + array_agg, cast, count_distinct, create_udf, expr, lit, sum, + BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, + WindowFunctionDefinition, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; @@ -2044,6 +2045,24 @@ mod tests { Ok(()) } + // Test issue: https://github.com/apache/datafusion/issues/10346 + #[tokio::test] + async fn test_select_over_aggregate_schema() -> Result<()> { + let df = test_table() + .await? + .with_column("c", col("c1"))? + .aggregate(vec![], vec![array_agg(col("c")).alias("c")])? + .select(vec![col("c")])?; + + assert_eq!(df.schema().fields().len(), 1); + let field = df.schema().field(0); + // There are two columns named 'c', one from the input of the aggregate and the other from the output. + // Select should return the column from the output of the aggregate, which is a list. + assert!(matches!(field.data_type(), DataType::List(_))); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 64fe98c23b08..4282952a1efc 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use crate::expr::{Alias, Sort, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; -use crate::logical_plan::Aggregate; use crate::signature::{Signature, TypeSignature}; use crate::{ and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, @@ -725,53 +724,14 @@ pub fn from_plan( plan.with_new_exprs(expr.to_vec(), inputs.to_vec()) } -/// Find all columns referenced from an aggregate query -fn agg_cols(agg: &Aggregate) -> Vec { - agg.aggr_expr - .iter() - .chain(&agg.group_expr) - .flat_map(find_columns_referenced_by_expr) - .collect() -} - -fn exprlist_to_fields_aggregate( - exprs: &[Expr], - agg: &Aggregate, -) -> Result, Arc)>> { - let agg_cols = agg_cols(agg); - let mut fields = vec![]; - for expr in exprs { - match expr { - Expr::Column(c) if agg_cols.iter().any(|x| x == c) => { - // resolve against schema of input to aggregate - fields.push(expr.to_field(agg.input.schema())?); - } - _ => fields.push(expr.to_field(&agg.schema)?), - } - } - Ok(fields) -} - /// Create field meta-data from an expression, for use in a result set schema pub fn exprlist_to_fields( exprs: &[Expr], plan: &LogicalPlan, ) -> Result, Arc)>> { - // when dealing with aggregate plans we cannot simply look in the aggregate output schema - // because it will contain columns representing complex expressions (such a column named - // `GROUPING(person.state)` so in order to resolve `person.state` in this case we need to - // look at the input to the aggregate instead. - let fields = match plan { - LogicalPlan::Aggregate(agg) => Some(exprlist_to_fields_aggregate(exprs, agg)), - _ => None, - }; - if let Some(fields) = fields { - fields - } else { - // look for exact match in plan's output schema - let input_schema = &plan.schema(); - exprs.iter().map(|e| e.to_field(input_schema)).collect() - } + // look for exact match in plan's output schema + let input_schema = &plan.schema(); + exprs.iter().map(|e| e.to_field(input_schema)).collect() } /// Convert an expression into Column expression if it's already provided as input plan.