Bugfix: Add functional dependency check and aggregate try_new schema #8584

Merged · 3 commits · Dec 20, 2023
16 changes: 16 additions & 0 deletions datafusion/common/src/dfschema.rs
@@ -347,6 +347,22 @@ impl DFSchema {
.collect()
}

/// Find the indices of all fields having the given qualifier
pub fn fields_indices_with_qualified(
&self,
qualifier: &TableReference,
) -> Vec<usize> {
self.fields
.iter()
.enumerate()
.filter_map(|(idx, field)| {
field
.qualifier()
.and_then(|q| q.eq(qualifier).then_some(idx))
})
.collect()
}

/// Find all fields that match the given name
pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> {
self.fields
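For illustration, a minimal usage sketch of the new accessor, assuming the `DFSchema`/`DFField` constructors as of this PR (hypothetical schema contents):

```rust
use std::collections::HashMap;

use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result, TableReference};

fn main() -> Result<()> {
    // Hypothetical two-table join schema: fields 0-1 belong to "l",
    // fields 2-3 belong to "r".
    let schema = DFSchema::new_with_metadata(
        vec![
            DFField::new(Some("l"), "a", DataType::Int32, true),
            DFField::new(Some("l"), "b", DataType::Int32, true),
            DFField::new(Some("r"), "a", DataType::Int32, true),
            DFField::new(Some("r"), "b", DataType::Int32, true),
        ],
        HashMap::new(),
    )?;

    // Unlike fields_with_qualified, this returns positions rather than
    // field references, so callers can project index-based metadata such
    // as functional dependencies.
    let qualifier = TableReference::from("r");
    assert_eq!(schema.fields_indices_with_qualified(&qualifier), vec![2, 3]);
    Ok(())
}
```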
3 changes: 2 additions & 1 deletion datafusion/core/src/datasource/stream.rs
@@ -64,7 +64,8 @@ impl TableProviderFactory for StreamTableFactory {
.with_encoding(encoding)
.with_order(cmd.order_exprs.clone())
.with_header(cmd.has_header)
- .with_batch_size(state.config().batch_size());
+ .with_batch_size(state.config().batch_size())
+ .with_constraints(cmd.constraints.clone());

Ok(Arc::new(StreamTable(Arc::new(config))))
}
13 changes: 7 additions & 6 deletions datafusion/expr/src/utils.rs
@@ -32,6 +32,7 @@ use crate::{

use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::utils::get_at_indices;
use datafusion_common::{
internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue, TableReference,
@@ -425,18 +426,18 @@ pub fn expand_qualified_wildcard(
wildcard_options: Option<&WildcardAdditionalOptions>,
) -> Result<Vec<Expr>> {
let qualifier = TableReference::from(qualifier);
- let qualified_fields: Vec<DFField> = schema
- .fields_with_qualified(&qualifier)
- .into_iter()
- .cloned()
- .collect();
+ let qualified_indices = schema.fields_indices_with_qualified(&qualifier);
+ let projected_func_dependencies = schema
+ .functional_dependencies()
+ .project_functional_dependencies(&qualified_indices, qualified_indices.len());
+ let qualified_fields = get_at_indices(schema.fields(), &qualified_indices)?;
if qualified_fields.is_empty() {
return plan_err!("Invalid qualifier {qualifier}");
}
let qualified_schema =
DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())?
- // We can use the functional dependencies as is, since it only stores indices:
- .with_functional_dependencies(schema.functional_dependencies().clone())?;
+ .with_functional_dependencies(projected_func_dependencies)?;
let excluded_columns = if let Some(WildcardAdditionalOptions {
opt_exclude,
opt_except,
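For context, functional dependencies are stored as field indices into the owning schema, so copying them verbatim onto a projected sub-schema (as the old code did) leaves them pointing at positions from the full schema. A standalone sketch of the remapping idea, not the actual `project_functional_dependencies` implementation:

```rust
// Hypothetical helper: remap dependency indices from the full schema into a
// projected schema described by `proj` (the kept field indices, in order).
fn project_dependency_indices(dep_indices: &[usize], proj: &[usize]) -> Option<Vec<usize>> {
    dep_indices
        .iter()
        .map(|idx| proj.iter().position(|p| p == idx))
        .collect() // None if any dependency root was projected away
}

fn main() {
    // SELECT r.* over [l.a, l.b, r.a, r.b] keeps fields 2 and 3.
    let proj = [2usize, 3];
    // A primary key rooted at full-schema index 2 (r.a) becomes index 0.
    assert_eq!(project_dependency_indices(&[2], &proj), Some(vec![0]));
    // A dependency rooted in a dropped field cannot be carried over.
    assert_eq!(project_dependency_indices(&[0], &proj), None);
}
```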
92 changes: 88 additions & 4 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -43,7 +43,7 @@ use datafusion_execution::TaskContext;
use datafusion_expr::Accumulator;
use datafusion_physical_expr::{
aggregate::is_order_sensitive,
- equivalence::collapse_lex_req,
+ equivalence::{collapse_lex_req, ProjectionMapping},
expressions::{Column, Max, Min, UnKnownColumn},
physical_exprs_contains, reverse_order_bys, AggregateExpr, EquivalenceProperties,
LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
@@ -59,7 +59,6 @@ mod topk;
mod topk_stream;

pub use datafusion_expr::AggregateFunction;
- use datafusion_physical_expr::equivalence::ProjectionMapping;
pub use datafusion_physical_expr::expressions::create_aggregate_expr;

/// Hash aggregate modes
@@ -464,7 +463,7 @@ impl AggregateExec {
pub fn try_new(
mode: AggregateMode,
group_by: PhysicalGroupBy,
- mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ aggr_expr: Vec<Arc<dyn AggregateExpr>>,
filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
@@ -482,6 +481,37 @@ impl AggregateExec {
group_by.expr.len(),
));
let original_schema = Arc::new(original_schema);
AggregateExec::try_new_with_schema(
mode,
group_by,
aggr_expr,
filter_expr,
input,
input_schema,
schema,
original_schema,
)
}

/// Create a new hash aggregate execution plan with the given schema.
/// This constructor isn't part of the public API; it is used internally
/// by DataFusion to enforce schema consistency when re-creating
/// `AggregateExec`s inside optimization rules. The schema field names of
/// an `AggregateExec` depend on the names of its aggregate expressions.
/// Since a rule may rewrite aggregate expressions (e.g. reverse them)
/// during initialization, field names may change inadvertently if one
/// re-creates the schema in such cases.
#[allow(clippy::too_many_arguments)]
fn try_new_with_schema(
mode: AggregateMode,
group_by: PhysicalGroupBy,
mut aggr_expr: Vec<Arc<dyn AggregateExpr>>,
filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
schema: SchemaRef,
original_schema: SchemaRef,
) -> Result<Self> {
// Reset ordering requirement to `None` if aggregator is not order-sensitive
let mut order_by_expr = aggr_expr
.iter()
@@ -858,13 +888,15 @@ impl ExecutionPlan for AggregateExec {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
- let mut me = AggregateExec::try_new(
+ let mut me = AggregateExec::try_new_with_schema(
self.mode,
self.group_by.clone(),
self.aggr_expr.clone(),
self.filter_expr.clone(),
children[0].clone(),
self.input_schema.clone(),
+ self.schema.clone(),
+ self.original_schema.clone(),
)?;
me.limit = self.limit;
Ok(Arc::new(me))
@@ -2162,4 +2194,56 @@ mod tests {
assert_eq!(res, common_requirement);
Ok(())
}

#[test]
fn test_agg_exec_same_schema() -> Result<()> {
Contributor (review comment): Is there any way that this can be tested from SQL, or does it require a programmatic test?

Contributor Author: I actually tried to produce this error from a SQL test, but couldn't reproduce it. The reason may be that some rules in the physical optimizer run with the schema_check flag set to false, but I am not sure; I couldn't come up with a reproducer from a SQL query.

let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, true),
Field::new("b", DataType::Float32, true),
]));

let col_a = col("a", &schema)?;
let col_b = col("b", &schema)?;
let option_desc = SortOptions {
descending: true,
nulls_first: true,
};
let sort_expr = vec![PhysicalSortExpr {
expr: col_b.clone(),
options: option_desc,
}];
let sort_expr_reverse = reverse_order_bys(&sort_expr);
let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]);

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![
Arc::new(FirstValue::new(
col_b.clone(),
"FIRST_VALUE(b)".to_string(),
DataType::Float64,
sort_expr_reverse.clone(),
vec![DataType::Float64],
)),
Arc::new(LastValue::new(
col_b.clone(),
"LAST_VALUE(b)".to_string(),
DataType::Float64,
sort_expr.clone(),
vec![DataType::Float64],
)),
];
let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups,
aggregates.clone(),
vec![None, None],
blocking_exec.clone(),
schema,
)?);
let new_agg = aggregate_exec
.clone()
.with_new_children(vec![blocking_exec])?;
assert_eq!(new_agg.schema(), aggregate_exec.schema());
Ok(())
}
}
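To make the motivation concrete, here is a simplified, self-contained sketch of the failure mode, using hypothetical stand-in types rather than DataFusion's: an optimizer rule may rewrite an order-sensitive aggregate, for instance turning `FIRST_VALUE(b ORDER BY b DESC)` into a `LAST_VALUE`, and rebuilding the schema from the rewritten expressions would silently rename the output field.

```rust
// Stand-ins for AggregateExpr and the schema it produces; the field name is
// the only property that matters for this illustration.
#[derive(Clone, Debug, PartialEq)]
struct Field {
    name: String,
}

struct Agg {
    name: String,
}

// Rebuild a schema from the (possibly rewritten) aggregate expressions, the
// way a from-scratch constructor would.
fn schema_from_exprs(aggs: &[Agg]) -> Vec<Field> {
    aggs.iter().map(|a| Field { name: a.name.clone() }).collect()
}

fn main() {
    // The plan as the user wrote it.
    let original = vec![Agg { name: "FIRST_VALUE(b)".into() }];
    let original_schema = schema_from_exprs(&original);

    // An optimizer rule reverses the aggregate to avoid a sort.
    let rewritten = vec![Agg { name: "LAST_VALUE(b)".into() }];

    // Rebuilding the schema after the rewrite changes the field name, which
    // is why with_new_children now passes the stored schema to
    // try_new_with_schema instead of recomputing it.
    assert_ne!(schema_from_exprs(&rewritten), original_schema);
}
```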
12 changes: 12 additions & 0 deletions datafusion/sqllogictest/test_files/groupby.slt
@@ -4280,3 +4280,15 @@ LIMIT 5
2 0 0
3 0 0
4 0 1


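# Regression test for qualified wildcard expansion: `SELECT r.*` must project
# the functional dependencies (primary key) of sales_global_with_pk onto the
# sub-schema instead of reusing the full-schema indices verbatim.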
query ITIPTR rowsort
SELECT r.*
FROM sales_global_with_pk as l, sales_global_with_pk as r
LIMIT 5
----
0 GRC 0 2022-01-01T06:00:00 EUR 30
1 FRA 1 2022-01-01T08:00:00 EUR 50
1 FRA 3 2022-01-02T12:00:00 EUR 200
1 TUR 2 2022-01-01T11:30:00 TRY 75
1 TUR 4 2022-01-03T10:00:00 TRY 100