[SPARK-24583] Wrong schema type in InsertIntoDataSourceCommand
maryannxue committed Jun 18, 2018
1 parent 8f225e0 commit f099a67
Showing 2 changed files with 82 additions and 2 deletions.
InsertIntoDataSourceCommand.scala

@@ -39,7 +39,8 @@ case class InsertIntoDataSourceCommand(
     val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
     val data = Dataset.ofRows(sparkSession, query)
     // Apply the schema of the existing table to the new data.
-    val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+    val df = sparkSession.internalCreateDataFrame(
+      data.queryExecution.toRdd, logicalRelation.schema.asNullable)
     relation.insert(df, overwrite)
 
     // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
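The substance of the fix is the switch from `logicalRelation.schema` to `logicalRelation.schema.asNullable`: the query output is still re-framed under the table's schema, but with every nullability flag relaxed, so rows containing nulls are no longer stamped as non-nullable before they reach the data source. `asNullable` is internal to Spark; a minimal sketch of its top-level effect, using a hypothetical `relaxNullability` helper in its place:

import org.apache.spark.sql.types._

// Hypothetical stand-in for Spark's internal asNullable: same field names and
// types, but every top-level field is marked nullable. (The internal helper
// also relaxes nullability inside nested struct, array and map types.)
def relaxNullability(schema: StructType): StructType =
  StructType(schema.fields.map(_.copy(nullable = true)))

val tableSchema = new StructType()
  .add("i", IntegerType, nullable = false)
  .add("s", StringType, nullable = false)

relaxNullability(tableSchema).forall(_.nullable)  // both fields now report nullable = true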
InsertSuite.scala

@@ -20,12 +20,47 @@ package org.apache.spark.sql.sources
import java.io.File

import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class SimpleInsertSource extends SchemaRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = {
    SimpleInsert(schema)(sqlContext.sparkSession)
  }
}

case class SimpleInsert(userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession)
  extends BaseRelation with InsertableRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override def schema: StructType = userSpecifiedSchema

  override def insert(input: DataFrame, overwrite: Boolean): Unit = {
    // Enforce the declared schema by hand: reject a null in any column whose
    // field is marked non-nullable. The check runs inside foreach, i.e. on the
    // executors, so callers see the failure wrapped in a SparkException.
    input.foreach { row =>
      schema.fields.zipWithIndex.filter(!_._1.nullable).foreach { field =>
        if (row.get(field._2) == null) {
          throw new NotNullableViolationException(field._1.name)
        }
      }
    }
  }
}

class NotNullableViolationException(val message: String)
  extends Exception(message) with Serializable {
  override def getMessage: String = s"Value for column '$message' cannot be null."
}

class InsertSuite extends DataSourceTest with SharedSQLContext {
  import testImplicits._

@@ -520,4 +555,48 @@ class InsertSuite extends DataSourceTest with SharedSQLContext {
}
}
}

test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") {
withTable("test_table") {
val schema = new StructType()
.add("i", IntegerType, false)
.add("s", StringType, false)
val newTable = CatalogTable(
identifier = TableIdentifier("test_table", None),
tableType = CatalogTableType.EXTERNAL,
storage = CatalogStorageFormat(
locationUri = None,
inputFormat = None,
outputFormat = None,
serde = None,
compressed = false,
properties = Map.empty),
schema = schema,
provider = Some("org.apache.spark.sql.sources.SimpleInsertSource"))

spark.sessionState.catalog.createTable(newTable, false)

def verifyException(e: Exception, column: String): Unit = {
var ex = e.getCause
while (ex != null &&
!ex.isInstanceOf[NotNullableViolationException]) {
ex = ex.getCause
}
if (ex == null) {
fail(s"Expected a NotNullableViolationException but got '${e.getMessage}'.")
}
assert(ex.getMessage.contains(s"Value for column '$column' cannot be null."))
}

sql("INSERT INTO TABLE test_table SELECT 1, 'a'")
verifyException(
intercept[SparkException] {
sql("INSERT INTO TABLE test_table SELECT null, 'b'")
}, "i")
verifyException(
intercept[SparkException] {
sql("INSERT INTO TABLE test_table SELECT 2, null")
}, "s")
}
}
}

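For orientation, the pieces added in the test fit together roughly as follows; the sketch below assumes an active SparkSession named `spark` and that the test classes above are on the classpath, and it paraphrases the resolution and planning steps rather than reproducing Spark's actual code. The catalog entry's `provider` names `SimpleInsertSource`, Spark asks that `SchemaRelationProvider` for a relation built from the declared schema, and `InsertIntoDataSourceCommand` then calls `InsertableRelation.insert` with the query result carrying the nullable variant of that schema.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sources.{InsertableRelation, SimpleInsertSource}
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local[*]").appName("insert-sketch").getOrCreate()
import spark.implicits._

// Once the provider class named in the catalog entry has been resolved, the
// lookup amounts to asking it for a relation over the declared schema.
val relation = new SimpleInsertSource().createRelation(
  spark.sqlContext,
  parameters = Map.empty,
  schema = new StructType().add("i", IntegerType, false).add("s", StringType, false))

// InsertIntoDataSourceCommand hands the query output to the relation. With the
// fix, that DataFrame carries the nullable form of the table schema, so it is
// SimpleInsert's own check, not the forced schema, that decides about nulls.
relation.asInstanceOf[InsertableRelation]
  .insert(Seq((1, "a")).toDF("i", "s"), overwrite = false)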