diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index 565d10247f10e..afe2c6c11ac69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -43,7 +43,7 @@ private[sql] object InferSchema {
     }
 
     // perform schema inference on each row and merge afterwards
-    schemaData.mapPartitions { iter =>
+    val rootType = schemaData.mapPartitions { iter =>
       val factory = new JsonFactory()
       iter.map { row =>
         try {
@@ -55,8 +55,13 @@ private[sql] object InferSchema {
             StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))
         }
       }
-    }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match {
-      case st: StructType => nullTypeToStringType(st)
+    }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType)
+
+    canonicalizeType(rootType) match {
+      case Some(st: StructType) => st
+      case _ =>
+        // canonicalizeType erases all empty structs, including the only one we want to keep
+        StructType(Seq())
     }
   }
 
@@ -116,22 +121,35 @@ private[sql] object InferSchema {
     }
   }
 
-  private def nullTypeToStringType(struct: StructType): StructType = {
-    val fields = struct.fields.map {
-      case StructField(fieldName, dataType, nullable, _) =>
-        val newType = dataType match {
-          case NullType => StringType
-          case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
-          case ArrayType(struct: StructType, containsNull) =>
-            ArrayType(nullTypeToStringType(struct), containsNull)
-          case struct: StructType => nullTypeToStringType(struct)
-          case other: DataType => other
-        }
+  /**
+   * Convert NullType to StringType and remove StructTypes with no fields
+   */
+  private def canonicalizeType: DataType => Option[DataType] = {
+    case at@ArrayType(elementType, _) =>
+      for {
+        canonicalType <- canonicalizeType(elementType)
+      } yield {
+        at.copy(canonicalType)
+      }
 
-        StructField(fieldName, newType, nullable)
-    }
+    case StructType(fields) =>
+      val canonicalFields = for {
+        field <- fields
+        if field.name.nonEmpty
+        canonicalType <- canonicalizeType(field.dataType)
+      } yield {
+        field.copy(dataType = canonicalType)
+      }
+
+      if (canonicalFields.nonEmpty) {
+        Some(StructType(canonicalFields))
+      } else {
+        // per SPARK-8093: empty structs should be deleted
+        None
+      }
 
-    StructType(fields)
+    case NullType => Some(StringType)
+    case other => Some(other)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index fca24364fe6ec..b5615a7ed3632 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -1103,4 +1103,8 @@ class JsonSuite extends QueryTest with TestJsonData {
     }
   }
 
+  test("SPARK-8093 Erase empty structs") {
+    val emptySchema = InferSchema(emptyRecords, 1.0, "")
+    assert(StructType(Seq()) === emptySchema)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index b6a6a8dc6a63c..eb62066ac6430 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -189,5 +189,14 @@ trait TestJsonData {
       """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
       """]""" :: Nil)
 
+  def emptyRecords: RDD[String] =
+    ctx.sparkContext.parallelize(
+      """{""" ::
+        """""" ::
+        """{"a": {}}""" ::
+        """{"a": {"b": {}}}""" ::
+        """{"b": [{"c": {}}]}""" ::
+        """]""" :: Nil)
+
   def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
 }