From f9542d008402f8cef96d5ec347583c7c1d30d840 Mon Sep 17 00:00:00 2001 From: Nikola Mandic Date: Mon, 3 Jun 2024 13:00:34 -0700 Subject: [PATCH] [SPARK-48413][SQL] ALTER COLUMN with collation ### What changes were proposed in this pull request? Add support for changing collation of a column with `ALTER COLUMN` command. Use existing support for `ALTER COLUMN` with type to enable changing collations of column. Syntax example: ``` ALTER TABLE t1 ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE ``` ### Why are the changes needed? Enable changing collation on column. ### Does this PR introduce _any_ user-facing change? Yes, it adds support for changing collation of column. ### How was this patch tested? Added tests to `DDLSuite` and `DataTypeSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46734 from nikolamand-db/SPARK-48413. Authored-by: Nikola Mandic Signed-off-by: Wenchen Fan --- .../resources/error/error-conditions.json | 6 + .../org/apache/spark/sql/types/DataType.scala | 35 ++++++ .../sql/errors/QueryCompilationErrors.scala | 9 ++ .../spark/sql/types/DataTypeSuite.scala | 109 ++++++++++++++++++ .../spark/sql/execution/command/ddl.scala | 50 +++++--- .../sql/execution/command/DDLSuite.scala | 94 +++++++++++++++ 6 files changed, 290 insertions(+), 13 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 69965e58fb79c..5bab14e3eebf7 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -119,6 +119,12 @@ ], "sqlState" : "42KDE" }, + "CANNOT_ALTER_COLLATION_BUCKET_COLUMN" : { + "message" : [ + "ALTER TABLE (ALTER|CHANGE) COLUMN cannot change collation of type/subtypes of bucket columns, but found the bucket column in the table ." + ], + "sqlState" : "428FR" + }, "CANNOT_ALTER_PARTITION_COLUMN" : { "message" : [ "ALTER TABLE (ALTER|CHANGE) COLUMN is not supported for partition columns, but found the partition column in the table ." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index ea90aa2ca397b..12c7905f62d1a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -408,6 +408,41 @@ object DataType { } } + /** + * Check if `from` is equal to `to` type except for collations, which are checked to be + * compatible so that data of type `from` can be interpreted as of type `to`. + */ + private[sql] def equalsIgnoreCompatibleCollation( + from: DataType, + to: DataType): Boolean = { + (from, to) match { + // String types with possibly different collations are compatible. + case (_: StringType, _: StringType) => true + + case (ArrayType(fromElement, fromContainsNull), ArrayType(toElement, toContainsNull)) => + (fromContainsNull == toContainsNull) && + equalsIgnoreCompatibleCollation(fromElement, toElement) + + case (MapType(fromKey, fromValue, fromContainsNull), + MapType(toKey, toValue, toContainsNull)) => + fromContainsNull == toContainsNull && + // Map keys cannot change collation. + fromKey == toKey && + equalsIgnoreCompatibleCollation(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { case (fromField, toField) => + fromField.name == toField.name && + fromField.nullable == toField.nullable && + fromField.metadata == toField.metadata && + equalsIgnoreCompatibleCollation(fromField.dataType, toField.dataType) + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } + /** * Returns true if the two data types share the same "shape", i.e. the types * are the same, but the field names don't need to be the same. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index f84abe270853b..7b9eb2020a5f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2687,6 +2687,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat Map("tableName" -> toSQLId(tableName), "columnName" -> toSQLId(columnName)) ) } + + def cannotAlterCollationBucketColumn(tableName: String, columnName: String): Throwable = { + new AnalysisException( + errorClass = "CANNOT_ALTER_COLLATION_BUCKET_COLUMN", + messageParameters = + Map("tableName" -> toSQLId(tableName), "columnName" -> toSQLId(columnName)) + ) + } + def cannotFindColumnError(name: String, fieldNames: Array[String]): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1246", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index c06d90f6e9522..cfda19ed67a03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -689,6 +689,115 @@ class DataTypeSuite extends SparkFunSuite { false, caseSensitive = true) + def checkEqualsIgnoreCompatibleCollation( + from: DataType, + to: DataType, + expected: Boolean): Unit = { + val testName = s"equalsIgnoreCompatibleCollation: (from: $from, to: $to)" + + test(testName) { + assert(DataType.equalsIgnoreCompatibleCollation(from, to) === expected) + } + } + + // Simple types. + checkEqualsIgnoreCompatibleCollation(IntegerType, IntegerType, expected = true) + checkEqualsIgnoreCompatibleCollation(BooleanType, BooleanType, expected = true) + checkEqualsIgnoreCompatibleCollation(StringType, StringType, expected = true) + checkEqualsIgnoreCompatibleCollation(IntegerType, BooleanType, expected = false) + checkEqualsIgnoreCompatibleCollation(BooleanType, IntegerType, expected = false) + checkEqualsIgnoreCompatibleCollation(StringType, BooleanType, expected = false) + checkEqualsIgnoreCompatibleCollation(BooleanType, StringType, expected = false) + checkEqualsIgnoreCompatibleCollation(StringType, IntegerType, expected = false) + checkEqualsIgnoreCompatibleCollation(IntegerType, StringType, expected = false) + // Collated `StringType`. + checkEqualsIgnoreCompatibleCollation(StringType, StringType("UTF8_BINARY_LCASE"), + expected = true) + checkEqualsIgnoreCompatibleCollation( + StringType("UTF8_BINARY"), StringType("UTF8_BINARY_LCASE"), expected = true) + // Complex types. + checkEqualsIgnoreCompatibleCollation( + ArrayType(StringType), + ArrayType(StringType("UTF8_BINARY_LCASE")), + expected = true + ) + checkEqualsIgnoreCompatibleCollation( + ArrayType(StringType), + ArrayType(ArrayType(StringType("UTF8_BINARY_LCASE"))), + expected = false + ) + checkEqualsIgnoreCompatibleCollation( + ArrayType(ArrayType(StringType)), + ArrayType(ArrayType(StringType("UTF8_BINARY_LCASE"))), + expected = true + ) + checkEqualsIgnoreCompatibleCollation( + MapType(StringType, StringType), + MapType(StringType, StringType("UTF8_BINARY_LCASE")), + expected = true + ) + checkEqualsIgnoreCompatibleCollation( + MapType(StringType("UTF8_BINARY_LCASE"), StringType), + MapType(StringType, StringType), + expected = false + ) + checkEqualsIgnoreCompatibleCollation( + MapType(StringType("UTF8_BINARY_LCASE"), ArrayType(StringType)), + MapType(StringType("UTF8_BINARY_LCASE"), ArrayType(StringType("UTF8_BINARY_LCASE"))), + expected = true + ) + checkEqualsIgnoreCompatibleCollation( + MapType(ArrayType(StringType), IntegerType), + MapType(ArrayType(StringType("UTF8_BINARY_LCASE")), IntegerType), + expected = false + ) + checkEqualsIgnoreCompatibleCollation( + MapType(ArrayType(StringType("UTF8_BINARY_LCASE")), IntegerType), + MapType(ArrayType(StringType("UTF8_BINARY_LCASE")), IntegerType), + expected = true + ) + checkEqualsIgnoreCompatibleCollation( + StructType(StructField("a", StringType) :: Nil), + StructType(StructField("a", StringType("UTF8_BINARY_LCASE")) :: Nil), + expected = true + ) + checkEqualsIgnoreCompatibleCollation( + StructType(StructField("a", ArrayType(StringType)) :: Nil), + StructType(StructField("a", ArrayType(StringType("UTF8_BINARY_LCASE"))) :: Nil), + expected = true + ) + checkEqualsIgnoreCompatibleCollation( + StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil), + StructType(StructField("a", MapType(StringType("UTF8_BINARY_LCASE"), IntegerType)) :: Nil), + expected = false + ) + checkEqualsIgnoreCompatibleCollation( + StructType(StructField("a", StringType) :: Nil), + StructType(StructField("b", StringType("UTF8_BINARY_LCASE")) :: Nil), + expected = false + ) + // Null compatibility checks. + checkEqualsIgnoreCompatibleCollation( + ArrayType(StringType, containsNull = true), + ArrayType(StringType, containsNull = false), + expected = false + ) + checkEqualsIgnoreCompatibleCollation( + ArrayType(StringType, containsNull = true), + ArrayType(StringType("UTF8_BINARY_LCASE"), containsNull = false), + expected = false + ) + checkEqualsIgnoreCompatibleCollation( + MapType(StringType, StringType, valueContainsNull = true), + MapType(StringType, StringType, valueContainsNull = false), + expected = false + ) + checkEqualsIgnoreCompatibleCollation( + StructType(StructField("a", StringType) :: Nil), + StructType(StructField("a", StringType, nullable = false) :: Nil), + expected = false + ) + test("SPARK-25031: MapType should produce current formatted string for complex types") { val keyType: DataType = StructType(Seq( StructField("a", DataTypes.IntegerType), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index cdeb4716e1265..6f402188910e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -356,8 +356,8 @@ case class AlterTableUnsetPropertiesCommand( /** - * A command to change the column for a table, only support changing the comment of a non-partition - * column for now. + * A command to change the column for a table, only support changing the comment or collation of + * the data type or nested types (recursively) of a non-partition column for now. * * The syntax of using this command in SQL is: * {{{ @@ -387,32 +387,45 @@ case class AlterTableChangeColumnCommand( } // Find the origin column from dataSchema by column name. val originColumn = findColumnByName(table.dataSchema, columnName, resolver) - // Throw an AnalysisException if the column name/dataType is changed. - if (!columnEqual(originColumn, newColumn, resolver)) { + val validType = canEvolveType(originColumn, newColumn) + // Throw an AnalysisException on attempt to change collation of bucket column. + if (validType && originColumn.dataType != newColumn.dataType) { + val isBucketColumn = table.bucketSpec match { + case Some(bucketSpec) => bucketSpec.bucketColumnNames.exists(resolver(columnName, _)) + case _ => false + } + if (isBucketColumn) { + throw QueryCompilationErrors.cannotAlterCollationBucketColumn( + table.qualifiedName, columnName) + } + } + // Throw an AnalysisException if the column name is changed or we cannot evolve the data type. + // Only changes in collation of column data type or its nested types (recursively) are allowed. + if (!validType || !namesEqual(originColumn, newColumn, resolver)) { throw QueryCompilationErrors.alterTableChangeColumnNotSupportedForColumnTypeError( toSQLId(table.identifier.nameParts), originColumn, newColumn, this.origin) } val newDataSchema = table.dataSchema.fields.map { field => if (field.name == originColumn.name) { - // Create a new column from the origin column with the new comment. - val withNewComment: StructField = - addComment(field, newColumn.getComment()) + // Create a new column from the origin column with the new type and new comment. + val withNewTypeAndComment: StructField = + addComment(withNewType(field, newColumn.dataType), newColumn.getComment()) // Create a new column from the origin column with the new current default value. if (newColumn.getCurrentDefaultValue().isDefined) { if (newColumn.getCurrentDefaultValue().get.nonEmpty) { val result: StructField = - addCurrentDefaultValue(withNewComment, newColumn.getCurrentDefaultValue()) + addCurrentDefaultValue(withNewTypeAndComment, newColumn.getCurrentDefaultValue()) // Check that the proposed default value parses and analyzes correctly, and that the // type of the resulting expression is equivalent or coercible to the destination column // type. ResolveDefaultColumns.analyze(result, "ALTER TABLE ALTER COLUMN") result } else { - withNewComment.clearCurrentDefaultValue() + withNewTypeAndComment.clearCurrentDefaultValue() } } else { - withNewComment + withNewTypeAndComment } } else { field @@ -432,6 +445,10 @@ case class AlterTableChangeColumnCommand( }.getOrElse(throw QueryCompilationErrors.cannotFindColumnError(name, schema.fieldNames)) } + // Change the dataType of the column. + private def withNewType(column: StructField, dataType: DataType): StructField = + column.copy(dataType = dataType) + // Add the comment to a column, if comment is empty, return the original column. private def addComment(column: StructField, comment: Option[String]): StructField = comment.map(column.withComment).getOrElse(column) @@ -442,10 +459,17 @@ case class AlterTableChangeColumnCommand( value.map(column.withCurrentDefaultValue).getOrElse(column) // Compare a [[StructField]] to another, return true if they have the same column - // name(by resolver) and dataType. - private def columnEqual( + // name(by resolver). + private def namesEqual( field: StructField, other: StructField, resolver: Resolver): Boolean = { - resolver(field.name, other.name) && field.dataType == other.dataType + resolver(field.name, other.name) + } + + // Compare dataType of [[StructField]] to another, return true if it is valid to evolve the type + // when altering column. Only changes in collation of data type or its nested types (recursively) + // are allowed. + private def canEvolveType(from: StructField, to: StructField): Boolean = { + DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e8af606d797e3..b4eeffab8d855 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2437,6 +2437,100 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase { ) } } + + test("Change column collation") { + withTable("t1", "t2", "t3", "t4") { + // Plain `StringType`. + sql("CREATE TABLE t1(col STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('a')") + checkAnswer(sql("SELECT COLLATION(col) FROM t1"), Row("UTF8_BINARY")) + sql("ALTER TABLE t1 ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE") + checkAnswer(sql("SELECT COLLATION(col) FROM t1"), Row("UTF8_BINARY_LCASE")) + + // Invalid "ALTER COLUMN" to Integer. + val alterInt = "ALTER TABLE t1 ALTER COLUMN col TYPE INTEGER" + checkError( + exception = intercept[AnalysisException] { + sql(alterInt) + }, + errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + parameters = Map( + "originType" -> "\"STRING COLLATE UTF8_BINARY_LCASE\"", + "originName" -> "`col`", + "table" -> "`spark_catalog`.`default`.`t1`", + "newType" -> "\"INT\"", + "newName" -> "`col`" + ), + context = ExpectedContext(fragment = alterInt, start = 0, stop = alterInt.length - 1) + ) + + // `ArrayType` with collation. + sql("CREATE TABLE t2(col ARRAY) USING parquet") + sql("INSERT INTO t2 VALUES (ARRAY('a'))") + checkAnswer(sql("SELECT COLLATION(col[0]) FROM t2"), Row("UTF8_BINARY")) + sql("ALTER TABLE t2 ALTER COLUMN col TYPE ARRAY") + checkAnswer(sql("SELECT COLLATION(col[0]) FROM t2"), Row("UTF8_BINARY_LCASE")) + + // `MapType` with collation. + sql("CREATE TABLE t3(col MAP) USING parquet") + sql("INSERT INTO t3 VALUES (MAP('k', 'v'))") + checkAnswer(sql("SELECT COLLATION(col['k']) FROM t3"), Row("UTF8_BINARY")) + sql( + """ + |ALTER TABLE t3 ALTER COLUMN col TYPE + |MAP""".stripMargin) + checkAnswer(sql("SELECT COLLATION(col['k']) FROM t3"), Row("UTF8_BINARY_LCASE")) + + // Invalid change of map key collation. + val alterMap = + "ALTER TABLE t3 ALTER COLUMN col TYPE " + + "MAP" + checkError( + exception = intercept[AnalysisException] { + sql(alterMap) + }, + errorClass = "NOT_SUPPORTED_CHANGE_COLUMN", + parameters = Map( + "originType" -> "\"MAP\"", + "originName" -> "`col`", + "table" -> "`spark_catalog`.`default`.`t3`", + "newType" -> "\"MAP\"", + "newName" -> "`col`" + ), + context = ExpectedContext(fragment = alterMap, start = 0, stop = alterMap.length - 1) + ) + + // `StructType` with collation. + sql("CREATE TABLE t4(col STRUCT) USING parquet") + sql("INSERT INTO t4 VALUES (NAMED_STRUCT('a', 'value'))") + checkAnswer(sql("SELECT COLLATION(col.a) FROM t4"), Row("UTF8_BINARY")) + sql("ALTER TABLE t4 ALTER COLUMN col TYPE STRUCT") + checkAnswer(sql("SELECT COLLATION(col.a) FROM t4"), Row("UTF8_BINARY_LCASE")) + } + } + + test("Invalid collation change on partition and bucket columns") { + withTable("t1", "t2") { + sql("CREATE TABLE t1(col STRING, i INTEGER) USING parquet PARTITIONED BY (col)") + checkError( + exception = intercept[AnalysisException] { + sql("ALTER TABLE t1 ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE") + }, + errorClass = "CANNOT_ALTER_PARTITION_COLUMN", + sqlState = "428FR", + parameters = Map("tableName" -> "`spark_catalog`.`default`.`t1`", "columnName" -> "`col`") + ) + sql("CREATE TABLE t2(col STRING) USING parquet CLUSTERED BY (col) INTO 1 BUCKETS") + checkError( + exception = intercept[AnalysisException] { + sql("ALTER TABLE t2 ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE") + }, + errorClass = "CANNOT_ALTER_COLLATION_BUCKET_COLUMN", + sqlState = "428FR", + parameters = Map("tableName" -> "`spark_catalog`.`default`.`t2`", "columnName" -> "`col`") + ) + } + } } object FakeLocalFsFileSystem {