From 21259574100955fa646891af8f38763cd58321c8 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 14 Sep 2016 23:10:20 +0800 Subject: [PATCH] [SPARK-17409][SQL] Do Not Optimize Query in CTAS More Than Once ### What changes were proposed in this pull request? As explained in https://github.com/apache/spark/pull/14797: >Some analyzer rules have assumptions on logical plans, optimizer may break these assumption, we should not pass an optimized query plan into QueryExecution (will be analyzed again), otherwise we may some weird bugs. For example, we have a rule for decimal calculation to promote the precision before binary operations, use PromotePrecision as placeholder to indicate that this rule should not apply twice. But a Optimizer rule will remove this placeholder, that break the assumption, then the rule applied twice, cause wrong result. We should not optimize the query in CTAS more than once. For example, ```Scala spark.range(99, 101).createOrReplaceTempView("tab1") val sqlStmt = "SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1" sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt") checkAnswer(spark.table("tab2"), sql(sqlStmt)) ``` Before this PR, the results do not match ``` == Results == !== Correct Answer - 2 == == Spark Answer - 2 == ![100,100.000000000000000000] [100,null] [99,99.000000000000000000] [99,99.000000000000000000] ``` After this PR, the results match. ``` +---+----------------------+ |id |num | +---+----------------------+ |99 |99.000000000000000000 | |100|100.000000000000000000| +---+----------------------+ ``` In this PR, we do not treat the `query` in CTAS as a child. Thus, the `query` will not be optimized when optimizing CTAS statement. However, we still need to analyze it for normalizing and verifying the CTAS in the Analyzer. Thus, we do it in the analyzer rule `PreprocessDDL`, because so far only this rule needs the analyzed plan of the `query`. ### How was this patch tested? Added a test Author: gatorsmile Closes #15048 from gatorsmile/ctasOptimized. --- .../sql/catalyst/plans/logical/Command.scala | 7 +++++- .../analysis/UnsupportedOperationsSuite.scala | 5 +--- .../sql/execution/command/SetCommand.scala | 2 -- .../spark/sql/execution/command/cache.scala | 7 ------ .../sql/execution/command/commands.scala | 4 +--- .../sql/execution/command/databases.scala | 2 -- .../spark/sql/execution/command/ddl.scala | 6 ----- .../spark/sql/execution/datasources/ddl.scala | 12 +++++----- .../sql/execution/datasources/rules.scala | 24 ++++++++++++++----- .../spark/sql/internal/SessionState.scala | 2 +- .../sources/CreateTableAsSelectSuite.scala | 12 ++++++++++ .../spark/sql/hive/HiveSessionState.scala | 2 +- .../sql/hive/execution/HiveExplainSuite.scala | 6 ++--- 13 files changed, 49 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 75a5b10d9ed04..64f57835c8898 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.expressions.Attribute + /** * A logical node that represents a non-query command to be executed by the system. For example, * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command +trait Command extends LeafNode { + final override def children: Seq[LogicalPlan] = Seq.empty + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 6df47acaba85b..ff1bb126f463d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -31,10 +31,7 @@ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.IntegerType /** A dummy command for testing unsupported operations. */ -case class DummyCommand() extends LogicalPlan with Command { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil -} +case class DummyCommand() extends Command class UnsupportedOperationsSuite extends SparkFunSuite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index b0e2d03af070d..af6def52d07d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -129,6 +129,4 @@ case object ResetCommand extends RunnableCommand with Logging { sparkSession.sessionState.conf.clear() Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 697e2ff21159b..c31f4dc9aba4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -47,8 +46,6 @@ case class CacheTableCommand( Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } @@ -58,8 +55,6 @@ case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableComm sparkSession.catalog.uncacheTable(tableIdent.quotedString) Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } /** @@ -71,6 +66,4 @@ case object ClearCacheCommand extends RunnableCommand { sparkSession.catalog.clearCache() Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 424a962b5eb1c..698c625d617fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -35,9 +35,7 @@ import org.apache.spark.sql.types._ * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. */ -trait RunnableCommand extends LogicalPlan with logical.Command { - override def output: Seq[Attribute] = Seq.empty - final override def children: Seq[LogicalPlan] = Seq.empty +trait RunnableCommand extends logical.Command { def run(sparkSession: SparkSession): Seq[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index 597ec27ce6698..e5a6a5f60b8a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -59,6 +59,4 @@ case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { sparkSession.sessionState.catalog.setCurrentDatabase(databaseName) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bc1c4f85e3315..dcda2f8d1c52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -70,8 +70,6 @@ case class CreateDatabaseCommand( ifNotExists) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } @@ -101,8 +99,6 @@ case class DropDatabaseCommand( sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } /** @@ -126,8 +122,6 @@ case class AlterDatabasePropertiesCommand( Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 1b1e2123b7c47..fa95af2648cf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.types._ -case class CreateTable(tableDesc: CatalogTable, mode: SaveMode, query: Option[LogicalPlan]) - extends LogicalPlan with Command { +case class CreateTable( + tableDesc: CatalogTable, + mode: SaveMode, + query: Option[LogicalPlan]) extends Command { assert(tableDesc.provider.isDefined, "The table to be created must have a provider.") if (query.isEmpty) { @@ -35,9 +37,7 @@ case class CreateTable(tableDesc: CatalogTable, mode: SaveMode, query: Option[Lo "create table without data insertion can only use ErrorIfExists or Ignore as SaveMode.") } - override def output: Seq[Attribute] = Seq.empty[Attribute] - - override def children: Seq[LogicalPlan] = query.toSeq + override def innerChildren: Seq[QueryPlan[_]] = query.toSeq } case class CreateTempViewUsing( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index fbf4063ff63b8..bd6eb6e0535ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -66,9 +66,10 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] { } /** - * Preprocess some DDL plans, e.g. [[CreateTable]], to do some normalization and checking. + * Analyze [[CreateTable]] and do some normalization and checking. + * For CREATE TABLE AS SELECT, the SELECT query is also analyzed. */ -case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] { +case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // When we CREATE TABLE without specifying the table schema, we should fail the query if @@ -95,9 +96,19 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] { // * can't use all table columns as partition columns. // * partition columns' type must be AtomicType. // * sort columns' type must be orderable. - case c @ CreateTable(tableDesc, mode, query) if c.childrenResolved => - val schema = if (query.isDefined) query.get.schema else tableDesc.schema - val columnNames = if (conf.caseSensitiveAnalysis) { + case c @ CreateTable(tableDesc, mode, query) => + val analyzedQuery = query.map { q => + // Analyze the query in CTAS and then we can do the normalization and checking. + val qe = sparkSession.sessionState.executePlan(q) + qe.assertAnalyzed() + qe.analyzed + } + val schema = if (analyzedQuery.isDefined) { + analyzedQuery.get.schema + } else { + tableDesc.schema + } + val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) { schema.map(_.name) } else { schema.map(_.name.toLowerCase) @@ -106,7 +117,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] { val partitionColsChecked = checkPartitionColumns(schema, tableDesc) val bucketColsChecked = checkBucketColumns(schema, partitionColsChecked) - c.copy(tableDesc = bucketColsChecked) + c.copy(tableDesc = bucketColsChecked, query = analyzedQuery) } private def checkPartitionColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = { @@ -176,6 +187,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] { colName: String, colType: String): String = { val tableCols = schema.map(_.name) + val conf = sparkSession.sessionState.conf tableCols.find(conf.resolver(_, colName)).getOrElse { failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " + s"defined table columns are: ${tableCols.mkString(", ")}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 8fdbd0f2c6dab..c899773b6b36f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -111,7 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = - PreprocessDDL(conf) :: + AnalyzeCreateTable(sparkSession) :: PreprocessTableInsertion(conf) :: new FindDataSourceTable(sparkSession) :: DataSourceAnalysis(conf) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 729c9fdda543e..344d4aa6cfea4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -236,4 +236,16 @@ class CreateTableAsSelectSuite assert(e.contains("Expected positive number of buckets, but got `0`")) } } + + test("CTAS of decimal calculation") { + withTable("tab2") { + withTempView("tab1") { + spark.range(99, 101).createOrReplaceTempView("tab1") + val sqlStmt = + "SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1" + sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt") + checkAnswer(spark.table("tab2"), sql(sqlStmt)) + } + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 15e1255653f88..eb10c11382e83 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -60,7 +60,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) override val extendedResolutionRules = catalog.ParquetConversions :: catalog.OrcConversions :: - PreprocessDDL(conf) :: + AnalyzeCreateTable(sparkSession) :: PreprocessTableInsertion(conf) :: DataSourceAnalysis(conf) :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 98afd99a203ac..f9751e3d5f2eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -77,7 +77,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "src") } - test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { + test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") { withTempView("jt") { val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) spark.read.json(rdd).createOrReplaceTempView("jt") @@ -98,8 +98,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(!outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should not contain Subquery since it's eliminated by optimizer") + assert(outputs.substring(physicalIndex).contains("Subquery"), + "Physical Plan should contain SubqueryAlias since the query should not be optimized") } }