[SPARK-17892][SQL][2.0] Do Not Optimize Query in CTAS More Than Once #15048

### What changes were proposed in this pull request?
This PR is to backport #15048 and #15459.

However, in 2.0, we do not have a unified logical node `CreateTable` and the analyzer rule `PreWriteCheck` is also different. To minimize the code changes, this PR adds a new rule `AnalyzeCreateTableAsSelect`. Please treat it as a new PR to review. Thanks!

As explained in #14797:
> Some analyzer rules make assumptions about logical plans, and the optimizer may break those assumptions. We should not pass an optimized query plan into QueryExecution (it will be analyzed again), otherwise we may hit weird bugs.
> For example, we have a rule for decimal calculation that promotes the precision before binary operations, using PromotePrecision as a placeholder to indicate that the rule should not be applied twice. But an optimizer rule removes this placeholder, which breaks the assumption; the rule is then applied twice and produces a wrong result.

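To make that failure mode concrete, here is a toy sketch (these are not Spark's real classes; the marker type, the `promote` function, and the precision adjustment are all invented for illustration) of how a placeholder keeps a rule from firing twice, and what happens if something strips the placeholder between passes:

```Scala
// Toy model only: PromotePrecision marks an expression whose precision has
// already been adjusted, so the promotion rule skips it on a later pass.
sealed trait Expr
case class DecimalLit(precision: Int, scale: Int) extends Expr
case class PromotePrecision(child: Expr) extends Expr // marker: "already promoted"

def promote(e: Expr): Expr = e match {
  case PromotePrecision(_) => e                                      // guard: apply at most once
  case DecimalLit(p, s)    => PromotePrecision(DecimalLit(p + s, s)) // illustrative adjustment
}

val once  = promote(DecimalLit(20, 18)) // PromotePrecision(DecimalLit(38, 18)): still representable
val again = promote(once)               // no-op: the marker stops a second adjustment
// If an optimizer pass strips the marker and the plan is analyzed again, the
// adjustment runs twice (20 -> 38 -> 56) and overflows the maximum decimal
// precision of 38, which is how results like the null in the "Before this PR"
// output below can arise.
```
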
We should not optimize the query in CTAS more than once. For example,
```Scala
spark.range(99, 101).createOrReplaceTempView("tab1")
val sqlStmt = "SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1"
sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt")
checkAnswer(spark.table("tab2"), sql(sqlStmt))
```
Before this PR, the results do not match
```
== Results ==
!== Correct Answer - 2 ==       == Spark Answer - 2 ==
![100,100.000000000000000000]   [100,null]
 [99,99.000000000000000000]     [99,99.000000000000000000]
```
After this PR, the results match.
```
+---+----------------------+
|id |num                   |
+---+----------------------+
|99 |99.000000000000000000 |
|100|100.000000000000000000|
+---+----------------------+
```

In this PR, we no longer treat the `query` in CTAS as a child. Thus, the `query` is not optimized again when the CTAS statement itself is optimized. However, we still need to analyze it so that the Analyzer can normalize and verify the CTAS. In the original PR this is done in the analyzer rule `PreprocessDDL`; since that rule does not exist in 2.0, this backport does it in the new analyzer rule `AnalyzeCreateTableAsSelect`, the only rule that needs the analyzed plan of the `query`.
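
A minimal sketch of that design, assuming the `Command`, `LogicalPlan`, and `QueryPlan` types shown in the diff below (the class and field names here are illustrative stand-ins, not the real Spark nodes): the CTAS node keeps the query out of `children`, so optimizer rules never rewrite it, but still exposes it through `innerChildren` for EXPLAIN output, and only reports itself resolved once the separately analyzed query is resolved.

```Scala
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}

// Illustrative stand-in for CreateTableUsingAsSelect / CreateHiveTableAsSelectLogicalPlan.
case class CreateTableAsSelectLike(
    tableName: String,
    query: LogicalPlan) extends Command {
  // `query` is not a child, so optimizer rules that transform children skip it entirely.
  override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
  // The node resolves only after the analyzer rule has analyzed the inner query.
  override lazy val resolved: Boolean = query.resolved
}
```

The diff below applies this same pattern to both the data source CTAS node and the Hive serde CTAS node.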

### How was this patch tested?

Added test cases: a data source CTAS regression test in `DataFrameSuite`, a decimal-calculation CTAS test in `CreateTableAsSelectSuite`, and a Hive serde CTAS regression test in `MetastoreRelationSuite`.

Author: gatorsmile <[email protected]>

Closes #15502 from gatorsmile/ctasOptimize2.0.
gatorsmile authored and cloud-fan committed Oct 17, 2016
1 parent ca66f52 commit d1a0211
Showing 17 changed files with 102 additions and 50 deletions.
@@ -17,9 +17,14 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.Attribute

/**
* A logical node that represents a non-query command to be executed by the system. For example,
* commands can be used by parsers to represent DDL operations. Commands, unlike queries, are
* eagerly executed.
*/
trait Command
trait Command extends LeafNode {
final override def children: Seq[LogicalPlan] = Seq.empty
override def output: Seq[Attribute] = Seq.empty
}
@@ -31,10 +31,7 @@ import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.IntegerType

/** A dummy command for testing unsupported operations. */
case class DummyCommand() extends LogicalPlan with Command {
override def output: Seq[Attribute] = Nil
override def children: Seq[LogicalPlan] = Nil
}
case class DummyCommand() extends Command

class UnsupportedOperationsSuite extends SparkFunSuite {

@@ -449,7 +449,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
c.bucketSpec,
c.mode,
c.options,
c.child)
c.query)
ExecutedCommandExec(cmd) :: Nil

case c: CreateTempViewUsing =>
@@ -129,6 +129,4 @@ case object ResetCommand extends RunnableCommand with Logging {
sparkSession.sessionState.conf.clear()
Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}
@@ -47,8 +47,6 @@ case class CacheTableCommand(

Seq.empty[Row]
}

override def output: Seq[Attribute] = Seq.empty
}


@@ -58,8 +56,6 @@ case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableComm
sparkSession.catalog.uncacheTable(tableIdent.quotedString)
Seq.empty[Row]
}

override def output: Seq[Attribute] = Seq.empty
}

/**
@@ -71,6 +67,4 @@ case object ClearCacheCommand extends RunnableCommand {
sparkSession.catalog.clearCache()
Seq.empty[Row]
}

override def output: Seq[Attribute] = Seq.empty
}
@@ -35,9 +35,7 @@ import org.apache.spark.sql.types._
* A logical command that is executed for its side-effects. `RunnableCommand`s are
* wrapped in `ExecutedCommand` during execution.
*/
trait RunnableCommand extends LogicalPlan with logical.Command {
override def output: Seq[Attribute] = Seq.empty
final override def children: Seq[LogicalPlan] = Seq.empty
trait RunnableCommand extends logical.Command {
def run(sparkSession: SparkSession): Seq[Row]
}

@@ -59,6 +59,4 @@ case class SetDatabaseCommand(databaseName: String) extends RunnableCommand {
sparkSession.sessionState.catalog.setCurrentDatabase(databaseName)
Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}
@@ -72,8 +72,6 @@ case class CreateDatabaseCommand(
ifNotExists)
Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}


@@ -103,8 +101,6 @@ case class DropDatabaseCommand(
sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade)
Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}

/**
@@ -128,8 +124,6 @@ case class AlterDatabasePropertiesCommand(

Seq.empty[Row]
}

override val output: Seq[Attribute] = Seq.empty
}

/**
@@ -34,7 +34,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.execution.datasources.PartitioningUtils
@@ -43,18 +44,18 @@ import org.apache.spark.util.Utils

case class CreateHiveTableAsSelectLogicalPlan(
tableDesc: CatalogTable,
child: LogicalPlan,
allowExisting: Boolean) extends UnaryNode with Command {
query: LogicalPlan,
allowExisting: Boolean) extends Command {

override def output: Seq[Attribute] = Seq.empty[Attribute]
override def innerChildren: Seq[QueryPlan[_]] = Seq(query)

override lazy val resolved: Boolean =
tableDesc.identifier.database.isDefined &&
tableDesc.schema.nonEmpty &&
tableDesc.storage.serde.isDefined &&
tableDesc.storage.inputFormat.isDefined &&
tableDesc.storage.outputFormat.isDefined &&
childrenResolved
query.resolved
}

/**
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
@@ -41,17 +41,10 @@ case class CreateTableUsing(
partitionColumns: Array[String],
bucketSpec: Option[BucketSpec],
allowExisting: Boolean,
managedIfNoPath: Boolean) extends LogicalPlan with logical.Command {

override def output: Seq[Attribute] = Seq.empty
override def children: Seq[LogicalPlan] = Seq.empty
}
managedIfNoPath: Boolean) extends logical.Command

/**
* A node used to support CTAS statements and saveAsTable for the data source API.
* This node is a [[logical.UnaryNode]] instead of a [[logical.Command]] because we want the
* analyzer can analyze the logical plan that will be used to populate the table.
* So, [[PreWriteCheck]] can detect cases that are not allowed.
*/
case class CreateTableUsingAsSelect(
tableIdent: TableIdentifier,
@@ -60,8 +53,10 @@ case class CreateTableUsingAsSelect(
bucketSpec: Option[BucketSpec],
mode: SaveMode,
options: Map[String, String],
child: LogicalPlan) extends logical.UnaryNode {
override def output: Seq[Attribute] = Seq.empty[Attribute]
query: LogicalPlan) extends logical.Command {

override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
override lazy val resolved: Boolean = query.resolved
}

case class CreateTempViewUsing(
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrd
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.CreateHiveTableAsSelectLogicalPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}

@@ -61,6 +62,25 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] {
}
}

/**
* Analyze the query in CREATE TABLE AS SELECT (CTAS). After analysis, [[PreWriteCheck]] also
* can detect the cases that are not allowed.
*/
case class AnalyzeCreateTableAsSelect(sparkSession: SparkSession) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case c: CreateTableUsingAsSelect if !c.query.resolved =>
c.copy(query = analyzeQuery(c.query))
case c: CreateHiveTableAsSelectLogicalPlan if !c.query.resolved =>
c.copy(query = analyzeQuery(c.query))
}

private def analyzeQuery(query: LogicalPlan): LogicalPlan = {
val qe = sparkSession.sessionState.executePlan(query)
qe.assertAnalyzed()
qe.analyzed
}
}

/**
* Preprocess the [[InsertIntoTable]] plan. Throws exception if the number of columns mismatch, or
* specified partition columns are different from the existing partition columns in the target
@@ -216,7 +236,7 @@ case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
// (the relation is a BaseRelation).
case l @ LogicalRelation(dest: BaseRelation, _, _) =>
// Get all input data source relations of the query.
val srcRelations = c.child.collect {
val srcRelations = c.query.collect {
case LogicalRelation(src: BaseRelation, _, _) => src
}
if (srcRelations.contains(dest)) {
@@ -233,12 +253,12 @@
}

PartitioningUtils.validatePartitionColumn(
c.child.schema, c.partitionColumns, conf.caseSensitiveAnalysis)
c.query.schema, c.partitionColumns, conf.caseSensitiveAnalysis)

for {
spec <- c.bucketSpec
sortColumnName <- spec.sortColumnNames
sortColumn <- c.child.schema.find(_.name == sortColumnName)
sortColumn <- c.query.schema.find(_.name == sortColumnName)
} {
if (!RowOrdering.isOrderable(sortColumn.dataType)) {
failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.")
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.AnalyzeTableCommand
import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreprocessTableInsertion, ResolveDataSource}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
import org.apache.spark.sql.util.ExecutionListenerManager

@@ -111,6 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) {
lazy val analyzer: Analyzer = {
new Analyzer(catalog, conf) {
override val extendedResolutionRules =
AnalyzeCreateTableAsSelect(sparkSession) ::
PreprocessTableInsertion(conf) ::
new FindDataSourceTable(sparkSession) ::
DataSourceAnalysis(conf) ::
20 changes: 20 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -27,9 +27,11 @@ import scala.util.Random
import org.scalatest.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -1565,4 +1567,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect
assert(d.size == d.distinct.size)
}

test("SPARK-17409: Do Not Optimize Query in CTAS (Data source tables) More Than Once") {
withTable("bar") {
withTempView("foo") {
withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") {
sql("select 0 as id").createOrReplaceTempView("foo")
val df = sql("select * from foo group by id")
// If we optimize the query in CTAS more than once, the following saveAsTable will fail
// with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
df.write.mode("overwrite").saveAsTable("bar")
checkAnswer(spark.table("bar"), Row(0) :: Nil)
val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
assert(tableMetadata.properties(DATASOURCE_PROVIDER) == "json",
"the expected table is a data source table using json")
}
}
}
}
}
@@ -220,4 +220,16 @@ class CreateTableAsSelectSuite
Some(BucketSpec(5, Seq("a"), Seq("b"))))
}
}

test("SPARK-17409: CTAS of decimal calculation") {
withTable("tab2") {
withTempView("tab1") {
spark.range(99, 101).createOrReplaceTempView("tab1")
val sqlStmt =
"SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1"
sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt")
checkAnswer(spark.table("tab2"), sql(sqlStmt))
}
}
}
}
@@ -449,7 +449,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
case p: LogicalPlan if !p.childrenResolved => p
case p: LogicalPlan if p.resolved => p

case p @ CreateHiveTableAsSelectLogicalPlan(table, child, allowExisting) =>
case p @ CreateHiveTableAsSelectLogicalPlan(table, query, allowExisting) =>
val desc = if (table.storage.serde.isEmpty) {
// add default serde
table.withNewStorage(
@@ -462,7 +462,7 @@

execution.CreateHiveTableAsSelectCommand(
desc.copy(identifier = TableIdentifier(tblName, Some(dbName))),
child,
query,
allowExisting)
}
}
@@ -510,7 +510,7 @@ private[hive] case class InsertIntoHiveTable(
child: LogicalPlan,
overwrite: Boolean,
ifNotExists: Boolean)
extends LogicalPlan with Command {
extends LogicalPlan {

override def children: Seq[LogicalPlan] = child :: Nil
override def output: Seq[Attribute] = Seq.empty
@@ -62,6 +62,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
override lazy val analyzer: Analyzer = {
new Analyzer(catalog, conf) {
override val extendedResolutionRules =
AnalyzeCreateTableAsSelect(sparkSession) ::
catalog.ParquetConversions ::
catalog.OrcConversions ::
catalog.CreateTables ::
@@ -17,11 +17,14 @@

package org.apache.spark.sql.hive

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils

class MetastoreRelationSuite extends SparkFunSuite {
class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
test("makeCopy and toJSON should work") {
val table = CatalogTable(
identifier = TableIdentifier("test", Some("db")),
@@ -35,4 +38,19 @@ class MetastoreRelationSuite extends SparkFunSuite {
// No exception should be thrown
relation.toJSON
}

test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") {
withTable("bar") {
withTempView("foo") {
sql("select 0 as id").createOrReplaceTempView("foo")
// If we optimize the query in CTAS more than once, the following saveAsTable will fail
// with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
sql("CREATE TABLE bar AS SELECT * FROM foo group by id")
checkAnswer(spark.table("bar"), Row(0) :: Nil)
val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
assert(!DDLUtils.isDatasourceTable(tableMetadata),
"the expected table is a Hive serde table")
}
}
}
}
