From f099a678e1f6dc0bfb591206402397d492e7720d Mon Sep 17 00:00:00 2001
From: Maryann Xue
Date: Mon, 18 Jun 2018 13:47:53 -0700
Subject: [PATCH] [SPARK-24583] Wrong schema type in InsertIntoDataSourceCommand

---
 .../InsertIntoDataSourceCommand.scala         |  3 +-
 .../spark/sql/sources/InsertSuite.scala       | 81 ++++++++++++++++++-
 2 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index a813829d50cb1..3cfdfcc1812f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -39,7 +39,8 @@ case class InsertIntoDataSourceCommand(
     val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
     val data = Dataset.ofRows(sparkSession, query)
     // Apply the schema of the existing table to the new data.
-    val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+    val df = sparkSession.internalCreateDataFrame(
+      data.queryExecution.toRdd, logicalRelation.schema.asNullable)
     relation.insert(df, overwrite)
 
     // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index fef01c860db6e..e130e03113b17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -20,12 +20,47 @@ package org.apache.spark.sql.sources
 import java.io.File
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
+class SimpleInsertSource extends SchemaRelationProvider {
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String],
+      schema: StructType): BaseRelation = {
+    SimpleInsert(schema)(sqlContext.sparkSession)
+  }
+}
+
+case class SimpleInsert(userSpecifiedSchema: StructType)(@transient val sparkSession: SparkSession)
+  extends BaseRelation with InsertableRelation {
+
+  override def sqlContext: SQLContext = sparkSession.sqlContext
+
+  override def schema: StructType = userSpecifiedSchema
+
+  override def insert(input: DataFrame, overwrite: Boolean): Unit = {
+    input.foreach { row =>
+      schema.fields.zipWithIndex.filter(!_._1.nullable).foreach { field =>
+        if (row.get(field._2) == null) {
+          throw new NotNullableViolationException(field._1.name)
+        }
+      }
+    }
+  }
+}
+
+class NotNullableViolationException(val message: String)
+  extends Exception(message) with Serializable {
+  override def getMessage: String = s"Value for column '$message' cannot be null."
+}
+
 class InsertSuite extends DataSourceTest with SharedSQLContext {
   import testImplicits._
 
@@ -520,4 +555,48 @@ class InsertSuite extends DataSourceTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-24583 Wrong schema type in InsertIntoDataSourceCommand") {
+    withTable("test_table") {
+      val schema = new StructType()
+        .add("i", IntegerType, false)
+        .add("s", StringType, false)
+      val newTable = CatalogTable(
+        identifier = TableIdentifier("test_table", None),
+        tableType = CatalogTableType.EXTERNAL,
+        storage = CatalogStorageFormat(
+          locationUri = None,
+          inputFormat = None,
+          outputFormat = None,
+          serde = None,
+          compressed = false,
+          properties = Map.empty),
+        schema = schema,
+        provider = Some("org.apache.spark.sql.sources.SimpleInsertSource"))
+
+      spark.sessionState.catalog.createTable(newTable, false)
+
+      def verifyException(e: Exception, column: String): Unit = {
+        var ex = e.getCause
+        while (ex != null &&
+            !ex.isInstanceOf[NotNullableViolationException]) {
+          ex = ex.getCause
+        }
+        if (ex == null) {
+          fail(s"Expected a NotNullableViolationException but got '${e.getMessage}'.")
+        }
+        assert(ex.getMessage.contains(s"Value for column '$column' cannot be null."))
+      }
+
+      sql("INSERT INTO TABLE test_table SELECT 1, 'a'")
+      verifyException(
+        intercept[SparkException] {
+          sql("INSERT INTO TABLE test_table SELECT null, 'b'")
+        }, "i")
+      verifyException(
+        intercept[SparkException] {
+          sql("INSERT INTO TABLE test_table SELECT 2, null")
+        }, "s")
+    }
+  }
 }
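
A minimal sketch, separate from the patch above, of the effect the fix relies on: the table schema is relaxed so every field is nullable before it is re-applied to the incoming rows, and the DataFrame handed to InsertableRelation.insert therefore no longer claims non-nullability that the query output may violate. The object name NullabilitySketch is illustrative only, and the copy-based relaxation below approximates what the patched code obtains from schema.asNullable for a flat schema, without relying on Spark-internal helpers.

    import org.apache.spark.sql.types._

    object NullabilitySketch {
      def main(args: Array[String]): Unit = {
        // Table schema with non-nullable columns, as in the new test case.
        val tableSchema = new StructType()
          .add("i", IntegerType, nullable = false)
          .add("s", StringType, nullable = false)
        // Relax every field to nullable, roughly what the fix applies to the
        // schema before re-creating the DataFrame from the query's RDD.
        val relaxed = StructType(tableSchema.map(_.copy(nullable = true)))
        assert(relaxed.forall(_.nullable))
      }
    }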