Skip to content

Commit

Permalink
Update style, tests, comments, and add migration guide
Browse files Browse the repository at this point in the history
  • Loading branch information
Kimahriman committed Jun 12, 2021
1 parent 0b0990a commit 7bca531
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 172 deletions.
2 changes: 2 additions & 0 deletions docs/sql-migration-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ license: |
- In Spark 3.2, `FloatType` is mapped to `FLOAT` in MySQL. Prior to this, it used to be mapped to `REAL`, which is by default a synonym to `DOUBLE PRECISION` in MySQL.

- In Spark 3.2, the query executions triggered by `DataFrameWriter` are always named `command` when being sent to `QueryExecutionListener`. In Spark 3.1 and earlier, the name is one of `save`, `insertInto`, `saveAsTable`, `create`, `append`, `overwrite`, `overwritePartitions`, `replace`.

- In Spark 3.2, `Dataset.unionByName` with `allowMissingColumns` set to true will add missing nested fields to the end of structs. In Spark 3.1, nested struct fields are sorted alphabetically.

## Upgrading from Spark SQL 3.0 to 3.1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ import org.apache.spark.sql.util.SchemaUtils
*/
object ResolveUnion extends Rule[LogicalPlan] {
/**
* Adds missing fields recursively into given `col` expression, based on the target `StructType`.
* This is called by `compareAndAddFields` when we find two struct columns with same name but
* different nested fields. This method will find out the missing nested fields from `col` to
* `target` struct and add these missing nested fields. Currently we don't support finding out
* missing nested fields of struct nested in array or struct nested in map.
* Adds missing fields recursively into given `col` expression, based on the expected struct
* fields from merging the two schemas. This is called by `compareAndAddFields` when we find two
* struct columns with same name but different nested fields. This method will recursively
* return a new struct with all of the expected fields, adding null values when `col` doesn't
* already contain them. Currently we don't support merging structs nested inside of arrays
* or maps.
*/
private def addFields(col: Expression, expectedFields: Seq[StructField]): Expression = {
assert(col.dataType.isInstanceOf[StructType], "Only support StructType.")

val resolver = conf.resolver
val colType = col.dataType.asInstanceOf[StructType]
val newStructFields = expectedFields.flatMap(expectedField => {
val newStructFields = expectedFields.flatMap { expectedField =>
val currentField = colType.fields.find(f => resolver(f.name, expectedField.name))

val newExpression = (currentField, expectedField.dataType) match {
Expand All @@ -60,11 +61,11 @@ object ResolveUnion extends Rule[LogicalPlan] {
}
case (Some(cf), _) =>
ExtractValue(col, Literal(cf.name), resolver)
case (_, expectedType) =>
case (None, expectedType) =>
Literal(null, expectedType)
}
Literal(expectedField.name) :: newExpression :: Nil
})
}
CreateNamedStruct(newStructFields)
}

Expand Down Expand Up @@ -99,11 +100,7 @@ object ResolveUnion extends Rule[LogicalPlan] {
// like that. We will sort columns in the struct expression to make sure two sides of
// union have consistent schema.
aliased += foundAttr
val targetType = try {
target.merge(source, conf.resolver)
} catch { case e: Throwable =>
throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError(target, source, e)
}
val targetType = target.merge(source, conf.resolver)
Alias(addFields(foundAttr, targetType.fields.toSeq), foundAttr.name)()
case _ =>
// We don't need/try to add missing fields if:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be
* thrown.
*/
private[sql] def merge(that: StructType, resolver: Resolver = _ == _): StructType =
private[sql] def merge(that: StructType, resolver: Resolver): StructType =
StructType.merge(resolver)(this, that).asInstanceOf[StructType]

override private[spark] def asNullable: StructType = {
Expand Down Expand Up @@ -630,39 +630,4 @@ object StructType extends AbstractDataType {
fields.foreach(s => map.put(s.name, s))
map
}

/**
* Returns a `StructType` that contains missing fields recursively from `source` to `target`.
* Note that this doesn't support looking into array type and map type recursively.
*/
def findMissingFields(
source: StructType,
target: StructType,
resolver: Resolver): Option[StructType] = {
def bothStructType(dt1: DataType, dt2: DataType): Boolean =
dt1.isInstanceOf[StructType] && dt2.isInstanceOf[StructType]

val newFields = mutable.ArrayBuffer.empty[StructField]

target.fields.foreach { field =>
val found = source.fields.find(f => resolver(field.name, f.name))
if (found.isEmpty) {
// Found a missing field in `source`.
newFields += field
} else if (bothStructType(found.get.dataType, field.dataType) &&
!found.get.dataType.sameType(field.dataType)) {
// Found a field with same name, but different data type.
findMissingFields(found.get.dataType.asInstanceOf[StructType],
field.dataType.asInstanceOf[StructType], resolver).map { missingType =>
newFields += found.get.copy(dataType = missingType)
}
}
}

if (newFields.isEmpty) {
None
} else {
Some(StructType(newFields.toSeq))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.fasterxml.jackson.core.JsonParseException
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataTypeTestUtils.dayTimeIntervalTypes

class DataTypeSuite extends SparkFunSuite {
Expand Down Expand Up @@ -153,7 +154,7 @@ class DataTypeSuite extends SparkFunSuite {
StructField("b", LongType) :: Nil)

val message = intercept[SparkException] {
left.merge(right)
left.merge(right, SQLConf.get.resolver)
}.getMessage
assert(message.equals("Failed to merge fields 'b' and 'b'. " +
"Failed to merge incompatible data types float and bigint"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,95 +150,36 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
assert(fromDDL(interval).toDDL === interval)
}

test("find missing (nested) fields") {
val schema = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
test("SPARK-35290: Struct merging case insensitive") {
val schema1 = StructType.fromDDL("a1 INT, a2 STRING, nested STRUCT<b1: INT, b2: STRING>")
val schema2 = StructType.fromDDL("A2 STRING, a3 DOUBLE, nested STRUCT<B2: STRING, b3: DOUBLE>")
val resolver = SQLConf.get.resolver

val source1 = StructType.fromDDL("c1 INT")
val missing1 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source1, schema, resolver)
.exists(_.sameType(missing1)))
assert(schema1.merge(schema2, resolver).sameType(StructType.fromDDL(
"a1 INT, a2 STRING, nested STRUCT<b1: INT, b2: STRING, b3: DOUBLE>, a3 DOUBLE"
)))

val source2 = StructType.fromDDL("c1 INT, c3 STRING")
val missing2 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source2, schema, resolver)
.exists(_.sameType(missing2)))

val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")
val missing3 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source3, schema, resolver)
.exists(_.sameType(missing3)))

val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
val missing4 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT>>")
assert(StructType.findMissingFields(source4, schema, resolver)
.exists(_.sameType(missing4)))
assert(schema2.merge(schema1, resolver).sameType(StructType.fromDDL(
"a2 STRING, a3 DOUBLE, nested STRUCT<b2: STRING, b3: DOUBLE, b1: INT>, a1 INT"
)))
}

test("find missing (nested) fields: array and map") {
val resolver = SQLConf.get.resolver

val schemaWithArray = StructType.fromDDL("c1 INT, c2 ARRAY<STRUCT<c3: INT, c4: LONG>>")
val source5 = StructType.fromDDL("c1 INT")
val missing5 = StructType.fromDDL("c2 ARRAY<STRUCT<c3: INT, c4: LONG>>")
assert(
StructType.findMissingFields(source5, schemaWithArray, resolver)
.exists(_.sameType(missing5)))

val schemaWithMap1 = StructType.fromDDL(
"c1 INT, c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>, c3 LONG")
val source6 = StructType.fromDDL("c1 INT, c3 LONG")
val missing6 = StructType.fromDDL("c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>")
assert(
StructType.findMissingFields(source6, schemaWithMap1, resolver)
.exists(_.sameType(missing6)))

val schemaWithMap2 = StructType.fromDDL(
"c1 INT, c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>, c3 STRING")
val source7 = StructType.fromDDL("c1 INT, c3 STRING")
val missing7 = StructType.fromDDL("c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>")
assert(
StructType.findMissingFields(source7, schemaWithMap2, resolver)
.exists(_.sameType(missing7)))

// Unsupported: nested struct in array, map
val source8 = StructType.fromDDL("c1 INT, c2 ARRAY<STRUCT<c3: INT>>")
// `findMissingFields` doesn't support looking into nested struct in array type.
assert(StructType.findMissingFields(source8, schemaWithArray, resolver).isEmpty)

val source9 = StructType.fromDDL("c1 INT, c2 MAP<STRUCT<c3: INT>, STRING>, c3 LONG")
// `findMissingFields` doesn't support looking into nested struct in map type.
assert(StructType.findMissingFields(source9, schemaWithMap1, resolver).isEmpty)

val source10 = StructType.fromDDL("c1 INT, c2 MAP<STRING, STRUCT<c3: INT>>, c3 STRING")
// `findMissingFields` doesn't support looking into nested struct in map type.
assert(StructType.findMissingFields(source10, schemaWithMap2, resolver).isEmpty)
}

test("find missing (nested) fields: case sensitive cases") {
test("SPARK-35290: Struct merging case sensitive") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val schema = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, C4: STRUCT<C5: INT, c6: INT>>")
val schema1 = StructType.fromDDL("a1 INT, a2 STRING, nested STRUCT<b1: INT, b2: STRING>")
val schema2 = StructType.fromDDL(
"A2 STRING, a3 DOUBLE, nested STRUCT<B2: STRING, b3: DOUBLE>")
val resolver = SQLConf.get.resolver

val source1 = StructType.fromDDL("c1 INT, C2 LONG")
val missing1 = StructType.fromDDL("c2 STRUCT<c3: INT, C4: STRUCT<C5: INT, c6: INT>>")
assert(StructType.findMissingFields(source1, schema, resolver)
.exists(_.sameType(missing1)))

val source2 = StructType.fromDDL("c2 LONG")
val missing2 = StructType.fromDDL("c1 INT")
assert(StructType.findMissingFields(source2, schema, resolver)
.exists(_.sameType(missing2)))

val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, C4: STRUCT<c5: INT>>")
val missing3 = StructType.fromDDL("c2 STRUCT<C4: STRUCT<C5: INT, c6: INT>>")
assert(StructType.findMissingFields(source3, schema, resolver)
.exists(_.sameType(missing3)))

val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, C4: STRUCT<C5: Int>>")
val missing4 = StructType.fromDDL("c2 STRUCT<C4: STRUCT<c6: INT>>")
assert(StructType.findMissingFields(source4, schema, resolver)
.exists(_.sameType(missing4)))
assert(schema1.merge(schema2, resolver).sameType(StructType.fromDDL(
"a1 INT, a2 STRING, nested STRUCT<b1: INT, b2: STRING, B2: STRING, b3: DOUBLE>, " +
"A2 STRING, a3 DOUBLE"
)))

assert(schema2.merge(schema1, resolver).sameType(StructType.fromDDL(
"A2 STRING, a3 DOUBLE, nested STRUCT<B2: STRING, b3: DOUBLE, b1: INT, b2: STRING>, " +
"a1 INT, a2 STRING"
)))
}
}

Expand Down
6 changes: 2 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2080,10 +2080,8 @@ class Dataset[T] private[sql](
* }}}
*
* Note that `allowMissingColumns` supports nested column in struct types. Missing nested columns
* of struct columns with same name will also be filled with null values. This currently does not
* support nested columns in array and map types. Note that if there is any missing nested columns
* to be filled, in order to make consistent schema between two sides of union, the nested fields
* of structs will be sorted after merging schema.
* of struct columns with the same name will also be filled with null values and added to the end
* of struct. This currently does not support nested columns in array and map types.
*
* @group typedrel
* @since 3.1.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -736,37 +736,6 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
}
}

test("SPARK-32376: Make unionByName null-filling behavior work with struct columns" +
" - case-insensitive cases") {
val df1 = Seq((0, UnionClass1d(0, 1L, UnionClass2(1, "2")))).toDF("id", "a")
val df2 = Seq((1, UnionClass1c(1, 2L, UnionClass4(2, 3L)))).toDF("id", "a")

var unionDf = df1.unionByName(df2, true)
assert(unionDf.schema.toDDL ==
"`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
"`Nested`: STRUCT<`a`: INT, `c`: STRING, `b`: BIGINT>>")
checkAnswer(unionDf,
Row(0, Row(0, 1, Row(1, "2", null))) ::
Row(1, Row(1, 2, Row(2, null, 3L))) :: Nil)

unionDf = df2.unionByName(df1, true)
assert(unionDf.schema.toDDL ==
"`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
"`nested`: STRUCT<`A`: INT, `b`: BIGINT, `c`: STRING>>")
checkAnswer(unionDf,
Row(1, Row(1, 2, Row(2, 3L, null))) ::
Row(0, Row(0, 1, Row(1, null, "2"))) :: Nil)

val df3 = Seq((2, UnionClass1b(2, 3L, UnionClass3(4, 5L)))).toDF("id", "a")
unionDf = df2.unionByName(df3, true)
assert(unionDf.schema.toDDL ==
"`id` INT,`a` STRUCT<`a`: INT, `b`: BIGINT, " +
"`nested`: STRUCT<`A`: INT, `b`: BIGINT>>")
checkAnswer(unionDf,
Row(1, Row(1, 2, Row(2, 3L))) ::
Row(2, Row(2, 3, Row(4, 5L))) :: Nil)
}

test("SPARK-32376: Make unionByName null-filling behavior work with struct columns - edge case") {
val nestedStructType1 = StructType(Seq(
StructField("b", StringType)))
Expand All @@ -777,11 +746,11 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
StructField("a", StringType)))
val nestedStructValues2 = Row("b", "a")

val df1: DataFrame = spark.createDataFrame(
val df1 = spark.createDataFrame(
sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
StructType(Seq(StructField("topLevelCol", nestedStructType1))))

val df2: DataFrame = spark.createDataFrame(
val df2 = spark.createDataFrame(
sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
StructType(Seq(StructField("topLevelCol", nestedStructType2))))

Expand Down Expand Up @@ -809,11 +778,11 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
))
val nestedStructValues2 = Row(Row("aa"), Row("bb"))

val df1: DataFrame = spark.createDataFrame(
val df1 = spark.createDataFrame(
sparkContext.parallelize(Row(nestedStructValues1) :: Nil),
StructType(Seq(StructField("topLevelCol", nestedStructType1))))

val df2: DataFrame = spark.createDataFrame(
val df2 = spark.createDataFrame(
sparkContext.parallelize(Row(nestedStructValues2) :: Nil),
StructType(Seq(StructField("topLevelCol", nestedStructType2))))

Expand Down Expand Up @@ -853,7 +822,7 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
depthCounter -= 1
}

val df: DataFrame = spark.createDataFrame(
val df = spark.createDataFrame(
sparkContext.parallelize(Row(struct) :: Nil),
StructType(Seq(StructField("nested0Col0", structType))))

Expand Down Expand Up @@ -958,7 +927,6 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession {
case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)
case class UnionClass1b(a: Int, b: Long, nested: UnionClass3)
case class UnionClass1c(a: Int, b: Long, nested: UnionClass4)
case class UnionClass1d(a: Int, b: Long, Nested: UnionClass2)

case class UnionClass2(a: Int, c: String)
case class UnionClass3(a: Int, b: Long)
Expand Down

0 comments on commit 7bca531

Please sign in to comment.