diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala index a01008bf62f..1bfc74ee1bc 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaUtils.scala @@ -286,8 +286,10 @@ def normalizeColumnNamesInDataType( // When schema evolution adds a new column during MERGE, it can be represented with // a NullType in the schema of the data written by the MERGE. sourceDataType - case (_: IntegralType, _: IntegralType) => - // The integral types can be cast to each other later on. + case (_: AtomicType, _: AtomicType) => + // Some atomic types (e.g. integral types) can be cast to each other later on. For now, + // it's enough to know that there are no nested fields inside the atomic types that might + // require normalization. sourceDataType case _ => if (DeltaUtils.isTesting) { diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala index c6939677d71..5bd6a73c924 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala @@ -1566,61 +1566,66 @@ class SchemaUtilsSuite extends QueryTest ///////////////////////////////// // normalizeColumnNamesInDataType ///////////////////////////////// - - private def checkNormalizedColumnNamesInDataType( + private def runNormalizeColumnNamesInDataType( sourceDataType: DataType, - tableDataType: DataType, - expectedDataType: DataType): Unit = { - assert(normalizeColumnNamesInDataType( + tableDataType: DataType): DataType = { + normalizeColumnNamesInDataType( deltaLog = null, sourceDataType, tableDataType, sourceParentFields = Seq.empty, - tableSchema = new StructType()) == expectedDataType) + tableSchema = new StructType()) } - test("normalize column names in data type - atomic types") { + test("normalize column names in data type - top-level atomic types") { val source = new StructType() .add("a", IntegerType) .add("b", StringType) + .add("c", LongType) + .add("d", DateType) val table = new StructType() .add("B", StringType) - .add("A", IntegerType) + .add("A", LongType) // LongType != IntegerType + .add("D", DecimalType(10, 0)) // DecimalType != DateType + .add("C", StringType) // StringType != LongType val expected = new StructType() .add("A", IntegerType) .add("B", StringType) - checkNormalizedColumnNamesInDataType(source, table, expected) + .add("C", LongType) + .add("D", DateType) + assert(runNormalizeColumnNamesInDataType(source, table) == expected) } - test("normalize column names in data type - incompatible atomic types") { - val source = new StructType() + test("normalize column names in data type - incompatible top-level types") { + val schema1a = new StructType() .add("a", IntegerType) .add("b", StringType) - val table = new StructType() + val schema1b = new StructType() .add("B", StringType) - .add("A", StringType) // StringType != IntegerType - val exception = intercept[AssertionError] { - normalizeColumnNamesInDataType( - deltaLog = null, - source, - table, - sourceParentFields = Seq.empty, - tableSchema = new StructType()) + .add("A", new StructType()) // StructType != IntegerType + intercept[AssertionError] { + runNormalizeColumnNamesInDataType(schema1a, schema1b) + } + intercept[AssertionError] { + runNormalizeColumnNamesInDataType(schema1b, schema1a) } - assert(exception.getMessage.contains("Types without nesting should match")) - } - test("normalize column names in data type - different integral types") { - val source = new StructType() - .add("a", IntegerType) - .add("b", StringType) - val table = new StructType() - .add("B", StringType) - .add("A", LongType) // LongType != IntegerType - val expected = new StructType() - .add("A", IntegerType) - .add("B", StringType) - checkNormalizedColumnNamesInDataType(source, table, expected) + val schema2a = new StructType() + .add("x", StringType) + .add("y", new StructType() + .add("z", IntegerType) + ) + val schema2b = new StructType() + .add("x", StringType) + .add("Y", new StructType() + .add("z", ArrayType(IntegerType)) // ArrayType != IntegerType + ) + intercept[AssertionError] { + runNormalizeColumnNamesInDataType(schema2a, schema2b) + } + intercept[AssertionError] { + runNormalizeColumnNamesInDataType(schema2b, schema2a) + } } test("normalize column names in data type - nested structs") { @@ -1669,10 +1674,10 @@ class SchemaUtilsSuite extends QueryTest .add("D1", IntegerType) .add("D2", LongType) ) - checkNormalizedColumnNamesInDataType(source, table, expected) + assert(runNormalizeColumnNamesInDataType(source, table) == expected) } - test("normalize column names in data type - incompatible types in a struct") { + test("normalize column names in data type - different atomic types in a map") { val source = new StructType() .add("a", new StructType() .add("b", new StructType() @@ -1685,8 +1690,33 @@ class SchemaUtilsSuite extends QueryTest .add("A", new StructType() .add("B", new StructType() .add("C", MapType(StringType, IntegerType)))) - assertThrows[AssertionError] { - checkNormalizedColumnNamesInDataType(source, table, expected) + assert(runNormalizeColumnNamesInDataType(source, table) == expected) + } + + test("normalize column names in data type - incompatible nested types") { + val schema1 = new StructType() + .add("a", new StructType() + .add("b", new StructType() + .add("c", IntegerType))) + val schema2 = new StructType() + .add("A", new StructType() + .add("B", new StructType() + .add("C", ArrayType(IntegerType)))) + val schema3 = new StructType() + .add("A", new StructType() + .add("b", new StructType() + .add("C", new StructType()))) + val schemas = Seq(schema1, schema2, schema3) + + for (left <- schemas; right <- schemas) { + if (left == right) { + // Make sure there's no error when the schemas are the same. + assert(runNormalizeColumnNamesInDataType(left, right) == left) + } else { + intercept[AssertionError] { + runNormalizeColumnNamesInDataType(left, right) + } + } } } @@ -1713,7 +1743,7 @@ class SchemaUtilsSuite extends QueryTest ArrayType(new StructType() .add("Aa", IntegerType) .add("Bb", StringType))) - checkNormalizedColumnNamesInDataType(source, table, expected) + assert(runNormalizeColumnNamesInDataType(source, table) == expected) } test("normalize column names in data type - missing column") { @@ -1770,14 +1800,14 @@ class SchemaUtilsSuite extends QueryTest nullable = true, comment = "comment for b3"), nullable = false, comment = "comment for a2" ) - checkNormalizedColumnNamesInDataType(source, table, expected) + assert(runNormalizeColumnNamesInDataType(source, table) == expected) } test("normalize column names in data type - empty source struct") { val source = new StructType() val table = new StructType().add("a", IntegerType) val expected = new StructType() - checkNormalizedColumnNamesInDataType(source, table, expected) + assert(runNormalizeColumnNamesInDataType(source, table) == expected) } ////////////////////////////