[SPARK-2781][SQL] Check resolution of LogicalPlans in Analyzer. #1706

Closed · wants to merge 5 commits
@@ -40,7 +40,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
   // TODO: pass this in as a parameter.
   val fixedPoint = FixedPoint(100)

-  val batches: Seq[Batch] = Seq(
+  /**
+   * Override to provide additional rules for the "Resolution" batch.
+   */
+  val extendedRules: Seq[Rule[LogicalPlan]] = Nil
+
+  lazy val batches: Seq[Batch] = Seq(
     Batch("MultiInstanceRelations", Once,
       NewRelationInstances),
     Batch("CaseInsensitiveAttributeReferences", Once,
@@ -54,23 +59,31 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       StarExpansion ::
       ResolveFunctions ::
       GlobalAggregates ::
-      UnresolvedHavingClauseAttributes ::
-      typeCoercionRules :_*),
+      UnresolvedHavingClauseAttributes ::
+      typeCoercionRules ++
+      extendedRules : _*),
     Batch("Check Analysis", Once,
       CheckResolution),
     Batch("AnalysisOperators", fixedPoint,
       EliminateAnalysisOperators)
   )

   /**
-   * Makes sure all attributes have been resolved.
+   * Makes sure all attributes and logical plans have been resolved.
    */
   object CheckResolution extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = {
       plan.transform {
         case p if p.expressions.exists(!_.resolved) =>
           throw new TreeNodeException(p,
             s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}")
+        case p if !p.resolved && p.childrenResolved =>
+          throw new TreeNodeException(p, "Unresolved plan found")
+      } match {
+        // As a backstop, use the root node to check that the entire plan tree is resolved.
+        case p if !p.resolved =>
+          throw new TreeNodeException(p, "Unresolved plan in tree")
+        case p => p
       }
     }
   }
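Note: the effect of the CheckResolution change is that analysis now fails fast instead of handing an unresolved plan to later phases. A minimal sketch of the new failure mode (MyUnresolvedPlan is a hypothetical stand-in, mirroring the UnresolvedTestPlan added to AnalysisSuite below):

    // Hypothetical: a leaf plan that can never be resolved.
    case class MyUnresolvedPlan() extends LeafNode {
      override lazy val resolved = false
      override def output = Nil
    }

    // Previously the analyzer returned such a plan unchanged; with the new
    // checks it throws, roughly:
    //   analyzer(MyUnresolvedPlan())
    //   => TreeNodeException: Unresolved plan found, tree: ...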
@@ -286,6 +286,10 @@ trait HiveTypeCoercion {
       // If the data type is not boolean and is being cast boolean, turn it into a comparison
       // with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
       case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
+      // Stringify boolean if casting to StringType.
+      // TODO Ensure true/false string letter casing is consistent with Hive in all cases.
+      case Cast(e, StringType) if e.dataType == BooleanType =>
+        If(e, Literal("true"), Literal("false"))
       // Turn true into 1, and false into 0 if casting boolean into other types.
       case Cast(e, dataType) if e.dataType == BooleanType =>
         Cast(If(e, Literal(1), Literal(0)), dataType)
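Note: a sketch of what the new case does, assuming the catalyst expression and type imports used elsewhere in this diff. Casting a boolean to a string now goes through an If instead of the numeric 1/0 path:

    Cast(Literal(true), StringType)    // matched by the new case, rewrites to:
    If(Literal(true), Literal("true"), Literal("false"))    // evaluates to "true"

    Cast(Literal(true), IntegerType)   // non-string targets still take the numeric path:
    Cast(If(Literal(true), Literal(1), Literal(0)), IntegerType)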
@@ -58,7 +58,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {

   /**
    * Returns true if this expression and all its children have been resolved to a specific schema
-   * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan
+   * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
    * can override this (e.g.
    * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
    * should return `false`).
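Note: for context only (a sketch from memory of the surrounding code, not part of this diff), the default implementation that this comment documents combines expression and child resolution, which is what the new `!p.resolved && p.childrenResolved` check in CheckResolution relies on:

    // Default in LogicalPlan: resolved once every expression and child is resolved.
    lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved

    // True if all of this operator's children have been resolved.
    def childrenResolved: Boolean = !children.exists(!_.resolved)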
@@ -93,6 +93,17 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     val e = intercept[TreeNodeException[_]] {
       caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
     }
-    assert(e.getMessage().toLowerCase.contains("unresolved"))
+    assert(e.getMessage().toLowerCase.contains("unresolved attribute"))
   }
+
+  test("throw errors for unresolved plans during analysis") {
+    case class UnresolvedTestPlan() extends LeafNode {
+      override lazy val resolved = false
+      override def output = Nil
+    }
+    val e = intercept[TreeNodeException[_]] {
+      caseSensitiveAnalyze(UnresolvedTestPlan())
+    }
+    assert(e.getMessage().toLowerCase.contains("unresolved plan"))
+  }
 }
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis

 import org.scalatest.FunSuite

+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.types._

 class HiveTypeCoercionSuite extends FunSuite {
@@ -84,4 +86,17 @@ class HiveTypeCoercionSuite extends FunSuite {
     widenTest(StringType, MapType(IntegerType, StringType, true), None)
     widenTest(ArrayType(IntegerType), StructType(Seq()), None)
   }
+
+  test("boolean casts") {
+    val booleanCasts = new HiveTypeCoercion { }.BooleanCasts
+    def ruleTest(initial: Expression, transformed: Expression) {
+      val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+      assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+        Project(Seq(Alias(transformed, "a")()), testRelation))
+    }
+    // Remove superfluous boolean -> boolean casts.
+    ruleTest(Cast(Literal(true), BooleanType), Literal(true))
+    // Stringify boolean when casting to string.
+    ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false")))
+  }
 }
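Note: the same helper could also pin down the pre-existing numeric path (a hypothetical addition, not part of this PR):

    // Boolean -> numeric casts still rewrite through the 1/0 branch:
    ruleTest(
      Cast(Literal(true), IntegerType),
      Cast(If(Literal(true), Literal(1), Literal(0)), IntegerType))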
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (41 additions, 4 deletions)
@@ -17,6 +17,7 @@

 package org.apache.spark.sql

+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test._
 import org.scalatest.BeforeAndAfterAll
@@ -477,18 +478,48 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       (3, null)))
   }

-  test("EXCEPT") {
+  test("UNION") {
+    checkAnswer(
+      sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"),
+      (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
+      (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+    checkAnswer(
+      sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"),
+      (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil)
+    checkAnswer(
+      sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"),
+      (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") ::
+      (4, "d") :: (4, "d") :: Nil)
+  }
+
+  test("UNION with column mismatches") {
+    // Column name mismatches are allowed.
+    checkAnswer(
+      sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"),
+      (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
+      (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+    // Column type mismatches are not allowed, forcing a type coercion.
     checkAnswer(
-      sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData "),
+      sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"),
+      ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_)))
+    // Column type mismatches where a coercion is not possible, in this case between integer
+    // and array types, trigger a TreeNodeException.
+    intercept[TreeNodeException[_]] {
+      sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect()
+    }
+  }
+
+  test("EXCEPT") {
     checkAnswer(
+      sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"),
       (1, "a") ::
       (2, "b") ::
       (3, "c") ::
       (4, "d") :: Nil)
     checkAnswer(
-      sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData "), Nil)
+      sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil)
     checkAnswer(
-      sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData "), Nil)
+      sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil)
   }

   test("INTERSECT") {
@@ -634,6 +665,12 @@
       sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"),
       Seq()
     )
   }

+  test("cast boolean to string") {
+    // TODO Ensure true/false string letter casing is consistent with Hive in all cases.
+    checkAnswer(
+      sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
+      ("true", "false") :: Nil)
+  }
 }
@@ -262,7 +262,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   /* An analyzer that uses the Hive metastore. */
   @transient
   override protected[sql] lazy val analyzer =
-    new Analyzer(catalog, functionRegistry, caseSensitive = false)
+    new Analyzer(catalog, functionRegistry, caseSensitive = false) {
+      override val extendedRules =
+        catalog.CreateTables ::
+        catalog.PreInsertionCasts ::
+        ExtractPythonUdfs ::
+        Nil
+    }

   /**
    * Runs the specified SQL query using Hive.
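Note: this replaces the QueryExecution override removed below, and is the pattern any Analyzer subclass can now follow to inject resolution rules (MyRule is hypothetical):

    // Hypothetical third-party rule hooked into the "Resolution" batch:
    object MyRule extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan  // no-op placeholder
    }

    val customAnalyzer = new Analyzer(catalog, functionRegistry, caseSensitive = true) {
      override val extendedRules = MyRule :: Nil
    }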
@@ -353,9 +359,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {

   /** Extends QueryExecution with hive specific features. */
   protected[sql] abstract class QueryExecution extends super.QueryExecution {
-    // TODO: Create mixin for the analyzer instead of overriding things here.
-    override lazy val optimizedPlan =
-      optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))))

     override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
@@ -109,15 +109,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
    */
   object CreateTables extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      // Wait until children are resolved.
+      case p: LogicalPlan if !p.childrenResolved => p
+
       case InsertIntoCreatedTable(db, tableName, child) =>
         val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
         val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)

         createTable(databaseName, tblName, child.output)

         InsertIntoTable(
-          EliminateAnalysisOperators(
-            lookupRelation(Some(databaseName), tblName, None)),
+          lookupRelation(Some(databaseName), tblName, None),
           Map.empty,
           child,
           overwrite = false)
@@ -130,15 +132,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
    */
   object PreInsertionCasts extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
-      // Wait until children are resolved
+      // Wait until children are resolved.
       case p: LogicalPlan if !p.childrenResolved => p

-      case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
+      case p @ InsertIntoTable(
+          LowerCaseSchema(table: MetastoreRelation), _, child, _) =>
         castChildOutput(p, table, child)

       case p @ logical.InsertIntoTable(
-          InMemoryRelation(_, _, _,
-            HiveTableScan(_, table, _)), _, child, _) =>
+          LowerCaseSchema(
+            InMemoryRelation(_, _, _,
+              HiveTableScan(_, table, _))), _, child, _) =>
         castChildOutput(p, table, child)
     }
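Note: the `case p: LogicalPlan if !p.childrenResolved => p` guard is what lets these rules run safely inside the fixed-point "Resolution" batch: on early iterations the rule passes the plan through untouched and only fires once its inputs have resolved. The idiom, sketched (SomeRule is hypothetical):

    // Hypothetical rule using the same guard idiom:
    object SomeRule extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        // Not ready yet; the fixed-point batch will apply this rule again.
        case p: LogicalPlan if !p.childrenResolved => p
        case p @ InsertIntoCreatedTable(db, tableName, child) =>
          // ...rewrite using the now-resolved child...
          p
      }
    }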