[SPARK-17409][SQL] Do Not Optimize Query in CTAS More Than Once
### What changes were proposed in this pull request?
As explained in #14797:
> Some analyzer rules make assumptions about logical plans, and the optimizer may break those assumptions. We should not pass an optimized query plan into QueryExecution (it will be analyzed again); otherwise we may hit some weird bugs.
> For example, we have a rule for decimal calculation that promotes the precision before binary operations and uses PromotePrecision as a placeholder to indicate that the rule should not be applied twice. But an Optimizer rule removes this placeholder, which breaks the assumption; the rule is then applied twice and produces a wrong result.
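To make the failure mode concrete, here is a self-contained toy sketch of the placeholder pattern described above. The names `PlaceholderSketch`, `Dec`, `Promoted`, `promote`, and `strip` are made up for illustration; they are not Spark's real `DecimalPrecision` rule or `PromotePrecision` expression, only a minimal stand-in for how a guard node keeps an analysis rule from firing twice and how stripping it breaks that guarantee.

```scala
// Toy stand-ins for illustration only; the real pieces are Spark's
// DecimalPrecision rule and PromotePrecision expression.
object PlaceholderSketch {
  sealed trait Expr
  case class Dec(precision: Int, scale: Int) extends Expr
  case class Promoted(child: Expr) extends Expr // placeholder: "already promoted"

  // Analysis rule: widen the precision once, then guard the result so a
  // second analysis pass is a no-op.
  def promote(e: Expr): Expr = e match {
    case Dec(p, s) => Promoted(Dec(p + s, s))
    case other     => other // already guarded (or not a bare decimal): skip
  }

  // Optimizer-style cleanup: stripping the guard means a later re-analysis
  // (e.g. an optimized plan fed back into QueryExecution) widens again.
  def strip(e: Expr): Expr = e match {
    case Promoted(child) => child
    case other           => other
  }

  // promote(strip(promote(Dec(38, 18)))) widens the precision twice, which is
  // the kind of double application that produced the wrong result below.
}
```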

We should not optimize the query in CTAS more than once. For example,
```Scala
spark.range(99, 101).createOrReplaceTempView("tab1")
val sqlStmt = "SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1"
sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt")
checkAnswer(spark.table("tab2"), sql(sqlStmt))
```
Before this PR, the results do not match:
```
== Results ==
!== Correct Answer - 2 ==       == Spark Answer - 2 ==
![100,100.000000000000000000]   [100,null]
 [99,99.000000000000000000]     [99,99.000000000000000000]
```
After this PR, the results match:
```
+---+----------------------+
|id |num                   |
+---+----------------------+
|99 |99.000000000000000000 |
|100|100.000000000000000000|
+---+----------------------+
```

In this PR, we no longer treat the `query` in CTAS as a child, so the `query` is not optimized when the CTAS statement itself is optimized. However, we still need to analyze it so that the CTAS can be normalized and verified in the Analyzer. We do this in the analyzer rule `PreprocessDDL` (renamed to `AnalyzeCreateTable` in this PR), because so far only this rule needs the analyzed plan of the `query`.
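As a rough sketch of the shape of this change (the classes below are hypothetical stand-ins, not Spark's real `LogicalPlan`, `CreateTable`, or `AnalyzeCreateTable`): the CTAS query is exposed only through `innerChildren`, which generic tree transformations (and therefore the optimizer) do not walk, while the dedicated analyzer rule analyzes the inner query explicitly and stores the analyzed, never-optimized plan back in the command.

```scala
// Hypothetical stand-ins for LogicalPlan / CreateTable / AnalyzeCreateTable.
object CtasSketch {
  trait Plan {
    def children: Seq[Plan] = Seq.empty       // walked by transform/optimize
    def innerChildren: Seq[Plan] = Seq.empty  // shown in EXPLAIN, never walked
  }
  case class Query(sql: String, analyzed: Boolean = false) extends Plan
  case class CreateTableCmd(table: String, query: Option[Plan]) extends Plan {
    // Before this PR the query was a child, so the optimizer rewrote it;
    // after this PR it is only an inner child.
    override def innerChildren: Seq[Plan] = query.toSeq
  }

  // What the analyzer rule does conceptually: analyze the inner query once,
  // during analysis, and keep the analyzed (but never optimized) plan.
  def analyzeCreateTable(plan: Plan): Plan = plan match {
    case c @ CreateTableCmd(_, Some(q: Query)) =>
      c.copy(query = Some(q.copy(analyzed = true)))
    case other => other
  }
}
```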

### How was this patch tested?
Added a test.

Author: gatorsmile <[email protected]>

Closes #15048 from gatorsmile/ctasOptimized.
gatorsmile authored and cloud-fan committed Sep 14, 2016
1 parent dc0a4c9 commit 52738d4
Showing 13 changed files with 49 additions and 42 deletions.
@@ -17,9 +17,14 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute

/**
* A logical node that represents a non-query command to be executed by the system. For example,
* commands can be used by parsers to represent DDL operations. Commands, unlike queries, are
* eagerly executed.
*/
trait Command
trait Command extends LeafNode {
final override def children: Seq[LogicalPlan] = Seq.empty
override def output: Seq[Attribute] = Seq.empty
}
@@ -31,10 +31,7 @@ import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.IntegerType

/** A dummy command for testing unsupported operations. */
case class DummyCommand() extends LogicalPlan with Command {
override def output: Seq[Attribute] = Nil
override def children: Seq[LogicalPlan] = Nil
}
case class DummyCommand() extends Command

class UnsupportedOperationsSuite extends SparkFunSuite {

@@ -129,6 +129,4 @@ case object ResetCommand extends RunnableCommand with Logging {
sparkSession.sessionState.conf.clear()
Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

@@ -47,8 +46,6 @@ case class CacheTableCommand(

Seq.empty[Row]
}

override def output: Seq[Attribute] = Seq.empty
}


@@ -58,8 +55,6 @@ case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableComm
sparkSession.catalog.uncacheTable(tableIdent.quotedString)
Seq.empty[Row]
}

override def output: Seq[Attribute] = Seq.empty
}

/**
@@ -71,6 +66,4 @@ case object ClearCacheCommand extends RunnableCommand {
sparkSession.catalog.clearCache()
Seq.empty[Row]
}

override def output: Seq[Attribute] = Seq.empty
}
@@ -35,9 +35,7 @@ import org.apache.spark.sql.types._
* A logical command that is executed for its side-effects. `RunnableCommand`s are
* wrapped in `ExecutedCommand` during execution.
*/
trait RunnableCommand extends LogicalPlan with logical.Command {
override def output: Seq[Attribute] = Seq.empty
final override def children: Seq[LogicalPlan] = Seq.empty
trait RunnableCommand extends logical.Command {
def run(sparkSession: SparkSession): Seq[Row]
}

@@ -59,6 +59,4 @@ case class SetDatabaseCommand(databaseName: String) extends RunnableCommand {
sparkSession.sessionState.catalog.setCurrentDatabase(databaseName)
Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}
@@ -70,8 +70,6 @@ case class CreateDatabaseCommand(
ifNotExists)
Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}


@@ -101,8 +99,6 @@ case class DropDatabaseCommand(
sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade)
Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}

/**
@@ -126,8 +122,6 @@ case class AlterDatabasePropertiesCommand(

Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}

/**
@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.types._

case class CreateTable(tableDesc: CatalogTable, mode: SaveMode, query: Option[LogicalPlan])
extends LogicalPlan with Command {
case class CreateTable(
tableDesc: CatalogTable,
mode: SaveMode,
query: Option[LogicalPlan]) extends Command {
assert(tableDesc.provider.isDefined, "The table to be created must have a provider.")

if (query.isEmpty) {
@@ -35,9 +37,7 @@ case class CreateTable(tableDesc: CatalogTable, mode: SaveMode, query: Option[Lo
"create table without data insertion can only use ErrorIfExists or Ignore as SaveMode.")
}

override def output: Seq[Attribute] = Seq.empty[Attribute]

override def children: Seq[LogicalPlan] = query.toSeq
override def innerChildren: Seq[QueryPlan[_]] = query.toSeq
}

case class CreateTempViewUsing(
@@ -66,9 +66,10 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
}

/**
* Preprocess some DDL plans, e.g. [[CreateTable]], to do some normalization and checking.
* Analyze [[CreateTable]] and do some normalization and checking.
* For CREATE TABLE AS SELECT, the SELECT query is also analyzed.
*/
case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {
case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// When we CREATE TABLE without specifying the table schema, we should fail the query if
@@ -95,9 +96,19 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {
// * can't use all table columns as partition columns.
// * partition columns' type must be AtomicType.
// * sort columns' type must be orderable.
case c @ CreateTable(tableDesc, mode, query) if c.childrenResolved =>
val schema = if (query.isDefined) query.get.schema else tableDesc.schema
val columnNames = if (conf.caseSensitiveAnalysis) {
case c @ CreateTable(tableDesc, mode, query) =>
val analyzedQuery = query.map { q =>
// Analyze the query in CTAS and then we can do the normalization and checking.
val qe = sparkSession.sessionState.executePlan(q)
qe.assertAnalyzed()
qe.analyzed
}
val schema = if (analyzedQuery.isDefined) {
analyzedQuery.get.schema
} else {
tableDesc.schema
}
val columnNames = if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
schema.map(_.name)
} else {
schema.map(_.name.toLowerCase)
Expand All @@ -106,7 +117,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {

val partitionColsChecked = checkPartitionColumns(schema, tableDesc)
val bucketColsChecked = checkBucketColumns(schema, partitionColsChecked)
c.copy(tableDesc = bucketColsChecked)
c.copy(tableDesc = bucketColsChecked, query = analyzedQuery)
}

private def checkPartitionColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
@@ -176,6 +187,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {
colName: String,
colType: String): String = {
val tableCols = schema.map(_.name)
val conf = sparkSession.sessionState.conf
tableCols.find(conf.resolver(_, colName)).getOrElse {
failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " +
s"defined table columns are: ${tableCols.mkString(", ")}")
@@ -111,7 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) {
lazy val analyzer: Analyzer = {
new Analyzer(catalog, conf) {
override val extendedResolutionRules =
PreprocessDDL(conf) ::
AnalyzeCreateTable(sparkSession) ::
PreprocessTableInsertion(conf) ::
new FindDataSourceTable(sparkSession) ::
DataSourceAnalysis(conf) ::
@@ -236,4 +236,16 @@ class CreateTableAsSelectSuite
assert(e.contains("Expected positive number of buckets, but got `0`"))
}
}

test("CTAS of decimal calculation") {
withTable("tab2") {
withTempView("tab1") {
spark.range(99, 101).createOrReplaceTempView("tab1")
val sqlStmt =
"SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1"
sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt")
checkAnswer(spark.table("tab2"), sql(sqlStmt))
}
}
}
}
@@ -60,7 +60,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
override val extendedResolutionRules =
catalog.ParquetConversions ::
catalog.OrcConversions ::
PreprocessDDL(conf) ::
AnalyzeCreateTable(sparkSession) ::
PreprocessTableInsertion(conf) ::
DataSourceAnalysis(conf) ::
(if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
@@ -77,7 +77,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
"src")
}

test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") {
test("SPARK-17409: The EXPLAIN output of CTAS only shows the analyzed plan") {
withTempView("jt") {
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
spark.read.json(rdd).createOrReplaceTempView("jt")
@@ -98,8 +98,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
}

val physicalIndex = outputs.indexOf("== Physical Plan ==")
assert(!outputs.substring(physicalIndex).contains("Subquery"),
"Physical Plan should not contain Subquery since it's eliminated by optimizer")
assert(outputs.substring(physicalIndex).contains("Subquery"),
"Physical Plan should contain SubqueryAlias since the query should not be optimized")
}
}
