From 54de169323b1b95ce35c45d852c43ff437dd24f1 Mon Sep 17 00:00:00 2001
From: Nikola Mandic
Date: Tue, 28 May 2024 09:59:53 -0700
Subject: [PATCH] [SPARK-48273][SQL] Fix late rewrite of
 PlanWithUnresolvedIdentifier

### What changes were proposed in this pull request?

`PlanWithUnresolvedIdentifier` is rewritten late in analysis, which causes
rules like `SubstituteUnresolvedOrdinals` to miss the new plan. This causes
the following queries to fail:
```
create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
--
cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
--
create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1;
```
Fix this by explicitly applying the early analyzer rules after the plan rewrite.
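In short, `ResolveIdentifierClause` becomes a class that carries the analyzer's early batches and replays them on the plan it builds. A condensed sketch of the mechanism (the full change is in the diff below; the `case other` branch and the `evalIdentifierExpr` helper are elided here):

```scala
import org.apache.spark.sql.catalyst.analysis.PlanWithUnresolvedIdentifier
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER

// Condensed from the ResolveIdentifierClause change in this patch.
class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch])
  extends Rule[LogicalPlan] with AliasHelper with EvalHelper {

  // Nested executor that replays only the early batches (e.g. "Substitution",
  // which contains SubstituteUnresolvedOrdinals).
  private val executor = new RuleExecutor[LogicalPlan] {
    override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]]
  }

  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
    _.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
    case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved =>
      // Previously the built plan was returned as-is and never saw these rules.
      executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr)))
  }
}
```

This requires `RuleExecutor.Batch` to be visible outside its defining executor, hence the `protected[catalyst]` change below.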
### Why are the changes needed?

To fix the described bug.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes the mentioned problematic queries.
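For illustration, a hypothetical standalone repro (not part of this patch; assumes a local `SparkSession` and the usual Spark SQL dependency):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical repro for SPARK-48273: before this fix, both identifier('t2')
// statements below failed during analysis because the ordinal in `group by 1`
// was never substituted in the plan built for the IDENTIFIER clause.
object Spark48273Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("SPARK-48273").getOrCreate()
    spark.sql(
      "create table identifier('t2') as " +
        "(select my_col from (values (1), (2), (1) as (my_col)) group by 1)")
    spark.sql(
      "insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1")
    spark.sql("select * from t2 order by my_col").show() // expected rows: 1, 2, 3
    spark.stop()
  }
}
```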
### How was this patch tested?

Updated the existing `identifier-clause.sql` golden files.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46580 from nikolamand-db/SPARK-48273.

Authored-by: Nikola Mandic
Signed-off-by: Wenchen Fan
---
 .../sql/catalyst/analysis/Analyzer.scala      |  9 ++-
 .../analysis/ResolveIdentifierClause.scala    | 11 +++-
 .../sql/catalyst/rules/RuleExecutor.scala     |  2 +-
 .../identifier-clause.sql.out                 | 59 +++++++++++++++++++
 .../sql-tests/inputs/identifier-clause.sql    |  9 +++
 .../results/identifier-clause.sql.out         | 56 ++++++++++++++++++
 6 files changed, 139 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 55b6f1af7fd8b..a233161713c3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -254,7 +254,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     TypeCoercion.typeCoercionRules
   }

-  override def batches: Seq[Batch] = Seq(
+  private def earlyBatches: Seq[Batch] = Seq(
     Batch("Substitution", fixedPoint,
       new SubstituteExecuteImmediate(catalogManager),
       // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule.
@@ -274,7 +274,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     Batch("Simple Sanity Check", Once,
       LookupFunctions),
     Batch("Keep Legacy Outputs", Once,
-      KeepLegacyOutputs),
+      KeepLegacyOutputs)
+  )
+
+  override def batches: Seq[Batch] = earlyBatches ++ Seq(
     Batch("Resolution", fixedPoint,
       new ResolveCatalogs(catalogManager) ::
       ResolveInsertInto ::
@@ -319,7 +322,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveTimeZone ::
       ResolveRandomSeed ::
       ResolveBinaryArithmetic ::
-      ResolveIdentifierClause ::
+      new ResolveIdentifierClause(earlyBatches) ::
       ResolveUnion ::
       ResolveRowLevelCommandAssignments ::
       RewriteDeleteFromTable ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
index ced7123dfcc14..f04b7799e35ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveIdentifierClause.scala
@@ -20,19 +20,24 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
 import org.apache.spark.sql.types.StringType

 /**
  * Resolves the identifier expressions and builds the original plans/expressions.
  */
-object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
+class ResolveIdentifierClause(earlyBatches: Seq[RuleExecutor[LogicalPlan]#Batch])
+  extends Rule[LogicalPlan] with AliasHelper with EvalHelper {
+
+  private val executor = new RuleExecutor[LogicalPlan] {
+    override def batches: Seq[Batch] = earlyBatches.asInstanceOf[Seq[Batch]]
+  }

   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
     _.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
     case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved =>
-      p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr))
+      executor.execute(p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr)))
     case other =>
       other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
         case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 0aa01e4f5c517..c8b3f224a3129 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -147,7 +147,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
     override val maxIterationsSetting: String = null) extends Strategy

   /** A batch of rules. */
-  protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)
+  protected[catalyst] case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*)

   /** Defines a sequence of rule batches, to be overridden by the implementation. */
   protected def batches: Seq[Batch]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
index 7389c7be87af7..f799c19a3bb8e 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
@@ -926,6 +926,65 @@ org.apache.spark.sql.catalyst.parser.ParseException
 }


+-- !query
+create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query analysis
+CreateViewCommand `v1`, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, false, LocalTempView, UNSUPPORTED, true
+   +- Aggregate [my_col#x], [my_col#x]
+      +- SubqueryAlias __auto_generated_subquery_name
+         +- SubqueryAlias as
+            +- LocalRelation [my_col#x]
+
+
+-- !query
+cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query analysis
+CacheTableAsSelect t1, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, true
+   +- Aggregate [my_col#x], [my_col#x]
+      +- SubqueryAlias __auto_generated_subquery_name
+         +- SubqueryAlias as
+            +- LocalRelation [my_col#x]
+
+
+-- !query
+create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query analysis
+CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t2`, ErrorIfExists, [my_col]
+   +- Aggregate [my_col#x], [my_col#x]
+      +- SubqueryAlias __auto_generated_subquery_name
+         +- SubqueryAlias as
+            +- LocalRelation [my_col#x]
+
+
+-- !query
+insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t2, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/t2], Append, `spark_catalog`.`default`.`t2`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t2), [my_col]
++- Aggregate [my_col#x], [my_col#x]
+   +- SubqueryAlias __auto_generated_subquery_name
+      +- SubqueryAlias as
+         +- LocalRelation [my_col#x]
+
+
+-- !query
+drop view v1
+-- !query analysis
+DropTempViewCommand v1
+
+
+-- !query
+drop table t1
+-- !query analysis
+DropTempViewCommand t1
+
+
+-- !query
+drop table t2
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2
+
+
 -- !query
 SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
 -- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql
index fd53f44d3c33c..978b82c331feb 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql
@@ -132,6 +132,15 @@ CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.a
 DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg');
 CREATE TEMPORARY VIEW IDENTIFIER('default.v')(c1) AS VALUES(1);

+-- SPARK-48273: Aggregation operation in statements using identifier clause for table name
+create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
+cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
+create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1);
+insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1;
+drop view v1;
+drop table t1;
+drop table t2;
+
 -- Not supported
 SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1);
 SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'));
diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
index 9dfc6a66b0782..68aa5956a91c1 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
@@ -1059,6 +1059,62 @@ org.apache.spark.sql.catalyst.parser.ParseException
 }


+-- !query
+create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+create table identifier('t2') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+drop view v1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+drop table t1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+drop table t2
+-- !query schema
+struct<>
+-- !query output
+
+
+
 -- !query
 SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
 -- !query schema