Skip to content

Commit

Permalink
fix: preserve qualifiers when rewriting expressions (#12341)
Browse files Browse the repository at this point in the history
* fix: preserve qualifiers in `NamePreserver`

* Add test

* Review feedback
  • Loading branch information
jonahgao authored Sep 6, 2024
1 parent 41b10ca commit a444528
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 66 deletions.
15 changes: 14 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ use sqlparser::ast::{
pub enum Expr {
/// An expression with a specific name.
Alias(Alias),
/// A named reference to a qualified filed in a schema.
/// A named reference to a qualified field in a schema.
Column(Column),
/// A named reference to a variable in a registry.
ScalarVariable(DataType, Vec<String>),
Expand Down Expand Up @@ -1115,6 +1115,19 @@ impl Expr {
SchemaDisplay(self)
}

/// Returns the `(qualifier, name)` pair that identifies this expression.
///
/// Intended for expressions that form an output field of a plan: the pair
/// is the field's qualifier and field name in that plan's output schema,
/// and can be used to reference the field.
///
/// * `Expr::Column` and `Expr::Alias` yield clones of their own `relation`
///   and `name`.
/// * Every other expression is unqualified (`None`) and is named by its
///   `schema_name()`.
pub fn qualified_name(&self) -> (Option<TableReference>, String) {
    let (qualifier, field_name) = match self {
        Expr::Column(column) => (column.relation.clone(), column.name.clone()),
        Expr::Alias(alias) => (alias.relation.clone(), alias.name.clone()),
        // Anything else carries no qualifier of its own.
        other => (None, other.schema_name().to_string()),
    };
    (qualifier, field_name)
}

/// Returns a full and complete string representation of this expression.
#[deprecated(note = "use format! instead")]
pub fn canonical_name(&self) -> String {
Expand Down
85 changes: 49 additions & 36 deletions datafusion/expr/src/expr_rewriter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ use crate::logical_plan::Projection;
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};

use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRewriter,
};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::TableReference;
use datafusion_common::{Column, DFSchema, Result};

Expand Down Expand Up @@ -279,22 +277,10 @@ pub fn unalias(expr: Expr) -> Expr {
}
}

/// Applies `rewriter` to `expr` while keeping the expression's name stable.
///
/// The name `expr` has before the rewrite is recorded via `name_for_alias`,
/// and the rewritten expression is passed to `alias_if_changed` with that
/// name, so an alias is added only if the rewrite changed the name.
///
/// This matters when optimizing plans, because the output schema of a plan
/// node must not change after optimization.
///
/// # Errors
///
/// Propagates any error returned by `name_for_alias`, by the rewrite itself,
/// or by `alias_if_changed`.
pub fn rewrite_preserving_name<R>(expr: Expr, rewriter: &mut R) -> Result<Expr>
where
    R: TreeNodeRewriter<Node = Expr>,
{
    // Capture the name first: `rewrite` consumes `expr`.
    let saved_name = expr.name_for_alias()?;
    let rewritten = expr.rewrite(rewriter)?;
    rewritten.data.alias_if_changed(saved_name)
}

/// Handles ensuring the name of rewritten expressions is not changed.
///
/// This is important when optimizing plans to ensure the output
/// schema of plan nodes doesn't change after optimization.
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
Expand All @@ -303,9 +289,17 @@ pub struct NamePreserver {
use_alias: bool,
}

/// If the name of an expression is remembered, it will be preserved when
/// rewriting the expression
pub struct SavedName(Option<String>);
/// The qualified name of an expression, captured before it is rewritten.
///
/// If the qualified name of an expression is remembered, it will be preserved
/// when rewriting the expression
pub enum SavedName {
/// Saved qualified name to be preserved
Saved {
/// Optional table qualifier of the original expression
relation: Option<TableReference>,
/// Name of the original expression
name: String,
},
/// Name is not preserved
None,
}

impl NamePreserver {
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
Expand All @@ -326,23 +320,30 @@ impl NamePreserver {

pub fn save(&self, expr: &Expr) -> Result<SavedName> {
let original_name = if self.use_alias {
Some(expr.name_for_alias()?)
let (relation, name) = expr.qualified_name();
SavedName::Saved { relation, name }
} else {
None
SavedName::None
};

Ok(SavedName(original_name))
Ok(original_name)
}
}

impl SavedName {
/// Ensures the name of the rewritten expression is preserved
/// Ensures the qualified name of the rewritten expression is preserved
pub fn restore(self, expr: Expr) -> Result<Expr> {
let Self(original_name) = self;
match original_name {
Some(name) => expr.alias_if_changed(name),
None => Ok(expr),
}
let expr = match self {
SavedName::Saved { relation, name } => {
let (new_relation, new_name) = expr.qualified_name();
if new_relation != relation || new_name != name {
expr.alias_qualified(relation, name)
} else {
expr
}
}
SavedName::None => expr,
};
Ok(expr)
}
}

Expand All @@ -353,6 +354,7 @@ mod test {
use super::*;
use crate::{col, lit, Cast};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::tree_node::TreeNodeRewriter;
use datafusion_common::ScalarValue;

#[derive(Default)]
Expand Down Expand Up @@ -511,10 +513,20 @@ mod test {

// change literal type from i32 to i64
test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64)));

// test preserve qualifier
test_rewrite(
Expr::Column(Column::new(Some("test"), "a")),
Expr::Column(Column::new_unqualified("test.a")),
);
test_rewrite(
Expr::Column(Column::new_unqualified("test.a")),
Expr::Column(Column::new(Some("test"), "a")),
);
}

/// rewrites `expr_from` to `rewrite_to` using
/// `rewrite_preserving_name` verifying the result is `expected_expr`
/// rewrites `expr_from` to `rewrite_to` while preserving the original qualified name
/// by using the `NamePreserver`
fn test_rewrite(expr_from: Expr, rewrite_to: Expr) {
struct TestRewriter {
rewrite_to: Expr,
Expand All @@ -531,11 +543,12 @@ mod test {
let mut rewriter = TestRewriter {
rewrite_to: rewrite_to.clone(),
};
let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap();

let original_name = expr_from.schema_name().to_string();
let new_name = expr.schema_name().to_string();
let saved_name = NamePreserver { use_alias: true }.save(&expr_from).unwrap();
let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data;
let new_expr = saved_name.restore(new_expr).unwrap();

let original_name = expr_from.qualified_name();
let new_name = new_expr.qualified_name();
assert_eq!(
original_name, new_name,
"mismatch rewriting expr_from: {expr_from} to {rewrite_to}"
Expand Down
35 changes: 6 additions & 29 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,35 +462,12 @@ impl ExprSchemable for Expr {
&self,
input_schema: &dyn ExprSchema,
) -> Result<(Option<TableReference>, Arc<Field>)> {
match self {
Expr::Column(c) => {
let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
Ok((
c.relation.clone(),
Field::new(&c.name, data_type, nullable)
.with_metadata(self.metadata(input_schema)?)
.into(),
))
}
Expr::Alias(Alias { relation, name, .. }) => {
let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
Ok((
relation.clone(),
Field::new(name, data_type, nullable)
.with_metadata(self.metadata(input_schema)?)
.into(),
))
}
_ => {
let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
Ok((
None,
Field::new(self.schema_name().to_string(), data_type, nullable)
.with_metadata(self.metadata(input_schema)?)
.into(),
))
}
}
let (relation, schema_name) = self.qualified_name();
let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
let field = Field::new(schema_name, data_type, nullable)
.with_metadata(self.metadata(input_schema)?)
.into();
Ok((relation, field))
}

/// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
Expand Down
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,9 @@ SELECT i + i FROM test WHERE i > 2;
----
6

statement ok
DROP TABLE test;

query error DataFusion error: Arrow error: Parser error: Error parsing timestamp from 'I AM NOT A TIMESTAMP': error parsing date
SELECT to_timestamp('I AM NOT A TIMESTAMP');

Expand Down Expand Up @@ -1741,3 +1744,15 @@ select a from t;

statement ok
set datafusion.optimizer.max_passes=3;

# Test issue: https://github.com/apache/datafusion/issues/12183
statement ok
CREATE TABLE test(a BIGINT) AS VALUES (1);

query I
SELECT "test.a" FROM (SELECT a AS "test.a" FROM test)
----
1

statement ok
DROP TABLE test;

0 comments on commit a444528

Please sign in to comment.