diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 90cec5e72c1a7..a32a9405676f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -45,7 +45,7 @@ object DataSourceUtils {
    */
   private def verifySchema(format: FileFormat, schema: StructType, isReadPath: Boolean): Unit = {
     schema.foreach { field =>
-      if (!format.supportDataType(field.dataType, isReadPath)) {
+      if (!format.supportDataType(field.dataType)) {
         throw new AnalysisException(
           s"$format data source does not support ${field.dataType.catalogString} data type.")
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index 2c162e23644ef..f0b49715c8da5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -156,7 +156,7 @@ trait FileFormat {
    * Returns whether this format supports the given [[DataType]] in read/write path.
    * By default all data types are supported.
    */
-  def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = true
+  def supportDataType(dataType: DataType): Boolean = true
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index f4f139d180058..d08a54cc9b1f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -153,10 +153,10 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
   override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
 
-  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+  override def supportDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
 
-    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
 
     case _ => false
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 40f55e7068010..d3f04145b83dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -140,17 +140,17 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
   override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat]
 
-  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+  override def supportDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
 
-    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
+    case st: StructType => st.forall { f => supportDataType(f.dataType) }
 
-    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
+    case ArrayType(elementType, _) => supportDataType(elementType)
 
     case MapType(keyType, valueType, _) =>
-      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
+      supportDataType(keyType) && supportDataType(valueType)
 
-    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
 
     case _: NullType => true
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 14779cdba4178..2a764957be11b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -235,19 +235,17 @@ class OrcFileFormat
     }
   }
 
-  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+  override def supportDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
 
-    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
+    case st: StructType => st.forall { f => supportDataType(f.dataType) }
 
-    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
+    case ArrayType(elementType, _) => supportDataType(elementType)
 
     case MapType(keyType, valueType, _) =>
-      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
+      supportDataType(keyType) && supportDataType(valueType)
 
-    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
-
-    case _: NullType => isReadPath
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
 
     case _ => false
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index f04502d113acb..efa4f3f166d98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -453,17 +453,17 @@ class ParquetFileFormat
     }
   }
 
-  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+  override def supportDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
 
-    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
+    case st: StructType => st.forall { f => supportDataType(f.dataType) }
 
-    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
+    case ArrayType(elementType, _) => supportDataType(elementType)
 
     case MapType(keyType, valueType, _) =>
-      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
+      supportDataType(keyType) && supportDataType(valueType)
 
-    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
 
     case _ => false
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 0607f7b3c0d4a..f8a24eb080294 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -139,7 +139,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
     }
   }
 
-  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean =
+  override def supportDataType(dataType: DataType): Boolean =
     dataType == StringType
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 54299e9808bf1..5e6705094e602 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -367,69 +367,43 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
   }
 
   test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") {
-    withTempDir { dir =>
-      val tempDir = new File(dir, "files").getCanonicalPath
-
-      Seq("orc").foreach { format =>
-        // write path
-        var msg = intercept[AnalysisException] {
-          sql("select null").write.format(format).mode("overwrite").save(tempDir)
-        }.getMessage
-        assert(msg.toLowerCase(Locale.ROOT)
-          .contains(s"$format data source does not support null data type."))
-
-        msg = intercept[AnalysisException] {
-          spark.udf.register("testType", () => new NullData())
-          sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
-        }.getMessage
-        assert(msg.toLowerCase(Locale.ROOT)
-          .contains(s"$format data source does not support null data type."))
-
-        // read path
-        // We expect the types below should be passed for backward-compatibility
-
-        // Null type
-        var schema = StructType(StructField("a", NullType, true) :: Nil)
-        spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-        spark.read.schema(schema).format(format).load(tempDir).collect()
-
-        // UDT having null data
-        schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
-        spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-        spark.read.schema(schema).format(format).load(tempDir).collect()
-      }
-
-      Seq("parquet", "csv").foreach { format =>
-        // write path
-        var msg = intercept[AnalysisException] {
-          sql("select null").write.format(format).mode("overwrite").save(tempDir)
-        }.getMessage
-        assert(msg.toLowerCase(Locale.ROOT)
-          .contains(s"$format data source does not support null data type."))
-
-        msg = intercept[AnalysisException] {
-          spark.udf.register("testType", () => new NullData())
-          sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
-        }.getMessage
-        assert(msg.toLowerCase(Locale.ROOT)
-          .contains(s"$format data source does not support null data type."))
-
-        // read path
-        msg = intercept[AnalysisException] {
-          val schema = StructType(StructField("a", NullType, true) :: Nil)
-          spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-          spark.read.schema(schema).format(format).load(tempDir).collect()
-        }.getMessage
-        assert(msg.toLowerCase(Locale.ROOT)
-          .contains(s"$format data source does not support null data type."))
-
-        msg = intercept[AnalysisException] {
-          val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
-          spark.range(1).write.format(format).mode("overwrite").save(tempDir)
-          spark.read.schema(schema).format(format).load(tempDir).collect()
-        }.getMessage
-        assert(msg.toLowerCase(Locale.ROOT)
-          .contains(s"$format data source does not support null data type."))
+    // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
+    withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc") {
+      withTempDir { dir =>
+        val tempDir = new File(dir, "files").getCanonicalPath
+
+        Seq("parquet", "csv", "orc").foreach { format =>
+          // write path
+          var msg = intercept[AnalysisException] {
+            sql("select null").write.format(format).mode("overwrite").save(tempDir)
+          }.getMessage
+          assert(msg.toLowerCase(Locale.ROOT)
+            .contains(s"$format data source does not support null data type."))
+
+          msg = intercept[AnalysisException] {
+            spark.udf.register("testType", () => new NullData())
+            sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
+          }.getMessage
+          assert(msg.toLowerCase(Locale.ROOT)
+            .contains(s"$format data source does not support null data type."))
+
+          // read path
+          msg = intercept[AnalysisException] {
+            val schema = StructType(StructField("a", NullType, true) :: Nil)
+            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+            spark.read.schema(schema).format(format).load(tempDir).collect()
+          }.getMessage
+          assert(msg.toLowerCase(Locale.ROOT)
+            .contains(s"$format data source does not support null data type."))
+
+          msg = intercept[AnalysisException] {
+            val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
+            spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+            spark.read.schema(schema).format(format).load(tempDir).collect()
+          }.getMessage
+          assert(msg.toLowerCase(Locale.ROOT)
+            .contains(s"$format data source does not support null data type."))
+        }
       }
     }
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 4e641e34c18d9..bfb0a95d4e707 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -181,19 +181,17 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
     }
   }
 
-  override def supportDataType(dataType: DataType, isReadPath: Boolean): Boolean = dataType match {
+  override def supportDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
 
-    case st: StructType => st.forall { f => supportDataType(f.dataType, isReadPath) }
+    case st: StructType => st.forall { f => supportDataType(f.dataType) }
 
-    case ArrayType(elementType, _) => supportDataType(elementType, isReadPath)
+    case ArrayType(elementType, _) => supportDataType(elementType)
 
     case MapType(keyType, valueType, _) =>
-      supportDataType(keyType, isReadPath) && supportDataType(valueType, isReadPath)
+      supportDataType(keyType) && supportDataType(valueType)
 
-    case udt: UserDefinedType[_] => supportDataType(udt.sqlType, isReadPath)
-
-    case _: NullType => isReadPath
+    case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
 
     case _ => false
   }
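
Note (not part of the patch): after this change, supportDataType no longer receives isReadPath, so a format can no longer declare a type readable but not writable; the same recursive check now applies to both paths. The sketch below is illustrative only. The object name, the chosen set of supported leaf types, and the schema used in main are assumptions made for the example; it simply shows the shape of a recursive check against the new single-argument signature and how a caller such as DataSourceUtils.verifySchema would use it, field by field, to report unsupported types.

    import org.apache.spark.sql.types._

    // Hypothetical sketch: a type-support predicate in the style of the new
    // supportDataType(dataType: DataType) signature. NullType and anything not
    // listed below is rejected, on both the read and the write path.
    object ExampleTypeSupport {
      def supportDataType(dataType: DataType): Boolean = dataType match {
        case StringType | BooleanType | IntegerType | LongType | DoubleType => true
        case st: StructType => st.forall(f => supportDataType(f.dataType))
        case ArrayType(elementType, _) => supportDataType(elementType)
        case MapType(keyType, valueType, _) =>
          supportDataType(keyType) && supportDataType(valueType)
        case _ => false
      }

      def main(args: Array[String]): Unit = {
        // An assumed schema containing one unsupported (NullType) field.
        val schema = StructType(Seq(
          StructField("id", LongType),
          StructField("tags", ArrayType(StringType)),
          StructField("missing", NullType)))
        // Report each offending field, field by field, the way verifySchema flags them.
        schema.foreach { field =>
          if (!supportDataType(field.dataType)) {
            println(s"unsupported: ${field.name} -> ${field.dataType.catalogString}")
          }
        }
      }
    }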