[SPARK-48413][SQL] ALTER COLUMN with collation
### What changes were proposed in this pull request?

Add support for changing the collation of a column with the `ALTER COLUMN` command, reusing the existing `ALTER COLUMN ... TYPE` support. Syntax example:
```
ALTER TABLE t1 ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE
```
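The change also covers nested types (array elements, map values, and struct fields; map keys are excluded), as exercised by the tests below. Illustrative statements against the hypothetical tables `t2`, `t3`, and `t4`:
```sql
-- Collation can also change inside nested types (but not map keys):
ALTER TABLE t2 ALTER COLUMN col TYPE ARRAY<STRING COLLATE UTF8_BINARY_LCASE>;
ALTER TABLE t3 ALTER COLUMN col TYPE MAP<STRING, STRING COLLATE UTF8_BINARY_LCASE>;
ALTER TABLE t4 ALTER COLUMN col TYPE STRUCT<a:STRING COLLATE UTF8_BINARY_LCASE>;
```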

### Why are the changes needed?

To enable changing the collation of a column.

### Does this PR introduce _any_ user-facing change?

Yes, it adds support for changing the collation of a column.

### How was this patch tested?

Added tests to `DDLSuite` and `DataTypeSuite`.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46734 from nikolamand-db/SPARK-48413.

Authored-by: Nikola Mandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
nikolamand-db authored and cloud-fan committed Jun 3, 2024
1 parent 5baaa61 commit f9542d0
Showing 6 changed files with 290 additions and 13 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -119,6 +119,12 @@
],
"sqlState" : "42KDE"
},
"CANNOT_ALTER_COLLATION_BUCKET_COLUMN" : {
"message" : [
"ALTER TABLE (ALTER|CHANGE) COLUMN cannot change collation of type/subtypes of bucket columns, but found the bucket column <columnName> in the table <tableName>."
],
"sqlState" : "428FR"
},
"CANNOT_ALTER_PARTITION_COLUMN" : {
"message" : [
"ALTER TABLE (ALTER|CHANGE) COLUMN is not supported for partition columns, but found the partition column <columnName> in the table <tableName>."
35 changes: 35 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -408,6 +408,41 @@ object DataType {
}
}

/**
* Check if `from` is equal to `to` type except for collations, which are checked to be
* compatible so that data of type `from` can be interpreted as of type `to`.
*/
private[sql] def equalsIgnoreCompatibleCollation(
from: DataType,
to: DataType): Boolean = {
(from, to) match {
// String types with possibly different collations are compatible.
case (_: StringType, _: StringType) => true

case (ArrayType(fromElement, fromContainsNull), ArrayType(toElement, toContainsNull)) =>
(fromContainsNull == toContainsNull) &&
equalsIgnoreCompatibleCollation(fromElement, toElement)

case (MapType(fromKey, fromValue, fromContainsNull),
MapType(toKey, toValue, toContainsNull)) =>
fromContainsNull == toContainsNull &&
// Map keys cannot change collation.
fromKey == toKey &&
equalsIgnoreCompatibleCollation(fromValue, toValue)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall { case (fromField, toField) =>
fromField.name == toField.name &&
fromField.nullable == toField.nullable &&
fromField.metadata == toField.metadata &&
equalsIgnoreCompatibleCollation(fromField.dataType, toField.dataType)
}

case (fromDataType, toDataType) => fromDataType == toDataType
}
}

/**
* Returns true if the two data types share the same "shape", i.e. the types
* are the same, but the field names don't need to be the same.
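For illustration, a minimal sketch of the new helper's semantics, mirroring the `DataTypeSuite` cases added below. Note that the method is `private[sql]`, so code like this compiles only inside the `org.apache.spark.sql` package, as in the test suite:
```scala
import org.apache.spark.sql.types._

// Differing string collations are compatible, recursively inside arrays,
// map values, and struct fields...
DataType.equalsIgnoreCompatibleCollation(
  StringType, StringType("UTF8_BINARY_LCASE"))                        // true
DataType.equalsIgnoreCompatibleCollation(
  ArrayType(StringType), ArrayType(StringType("UTF8_BINARY_LCASE")))  // true
// ...but map keys must keep their exact type, and non-string types must match.
DataType.equalsIgnoreCompatibleCollation(
  MapType(StringType("UTF8_BINARY_LCASE"), StringType),
  MapType(StringType, StringType))                                    // false
DataType.equalsIgnoreCompatibleCollation(IntegerType, BooleanType)    // false
```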
@@ -2687,6 +2687,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with CompilationErrors {
Map("tableName" -> toSQLId(tableName), "columnName" -> toSQLId(columnName))
)
}

def cannotAlterCollationBucketColumn(tableName: String, columnName: String): Throwable = {
new AnalysisException(
errorClass = "CANNOT_ALTER_COLLATION_BUCKET_COLUMN",
messageParameters =
Map("tableName" -> toSQLId(tableName), "columnName" -> toSQLId(columnName))
)
}

def cannotFindColumnError(name: String, fieldNames: Array[String]): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1246",
@@ -689,6 +689,115 @@ class DataTypeSuite extends SparkFunSuite {
false,
caseSensitive = true)

def checkEqualsIgnoreCompatibleCollation(
from: DataType,
to: DataType,
expected: Boolean): Unit = {
val testName = s"equalsIgnoreCompatibleCollation: (from: $from, to: $to)"

test(testName) {
assert(DataType.equalsIgnoreCompatibleCollation(from, to) === expected)
}
}

// Simple types.
checkEqualsIgnoreCompatibleCollation(IntegerType, IntegerType, expected = true)
checkEqualsIgnoreCompatibleCollation(BooleanType, BooleanType, expected = true)
checkEqualsIgnoreCompatibleCollation(StringType, StringType, expected = true)
checkEqualsIgnoreCompatibleCollation(IntegerType, BooleanType, expected = false)
checkEqualsIgnoreCompatibleCollation(BooleanType, IntegerType, expected = false)
checkEqualsIgnoreCompatibleCollation(StringType, BooleanType, expected = false)
checkEqualsIgnoreCompatibleCollation(BooleanType, StringType, expected = false)
checkEqualsIgnoreCompatibleCollation(StringType, IntegerType, expected = false)
checkEqualsIgnoreCompatibleCollation(IntegerType, StringType, expected = false)
// Collated `StringType`.
checkEqualsIgnoreCompatibleCollation(StringType, StringType("UTF8_BINARY_LCASE"),
expected = true)
checkEqualsIgnoreCompatibleCollation(
StringType("UTF8_BINARY"), StringType("UTF8_BINARY_LCASE"), expected = true)
// Complex types.
checkEqualsIgnoreCompatibleCollation(
ArrayType(StringType),
ArrayType(StringType("UTF8_BINARY_LCASE")),
expected = true
)
checkEqualsIgnoreCompatibleCollation(
ArrayType(StringType),
ArrayType(ArrayType(StringType("UTF8_BINARY_LCASE"))),
expected = false
)
checkEqualsIgnoreCompatibleCollation(
ArrayType(ArrayType(StringType)),
ArrayType(ArrayType(StringType("UTF8_BINARY_LCASE"))),
expected = true
)
checkEqualsIgnoreCompatibleCollation(
MapType(StringType, StringType),
MapType(StringType, StringType("UTF8_BINARY_LCASE")),
expected = true
)
checkEqualsIgnoreCompatibleCollation(
MapType(StringType("UTF8_BINARY_LCASE"), StringType),
MapType(StringType, StringType),
expected = false
)
checkEqualsIgnoreCompatibleCollation(
MapType(StringType("UTF8_BINARY_LCASE"), ArrayType(StringType)),
MapType(StringType("UTF8_BINARY_LCASE"), ArrayType(StringType("UTF8_BINARY_LCASE"))),
expected = true
)
checkEqualsIgnoreCompatibleCollation(
MapType(ArrayType(StringType), IntegerType),
MapType(ArrayType(StringType("UTF8_BINARY_LCASE")), IntegerType),
expected = false
)
checkEqualsIgnoreCompatibleCollation(
MapType(ArrayType(StringType("UTF8_BINARY_LCASE")), IntegerType),
MapType(ArrayType(StringType("UTF8_BINARY_LCASE")), IntegerType),
expected = true
)
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", StringType) :: Nil),
StructType(StructField("a", StringType("UTF8_BINARY_LCASE")) :: Nil),
expected = true
)
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", ArrayType(StringType)) :: Nil),
StructType(StructField("a", ArrayType(StringType("UTF8_BINARY_LCASE"))) :: Nil),
expected = true
)
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", MapType(StringType, IntegerType)) :: Nil),
StructType(StructField("a", MapType(StringType("UTF8_BINARY_LCASE"), IntegerType)) :: Nil),
expected = false
)
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", StringType) :: Nil),
StructType(StructField("b", StringType("UTF8_BINARY_LCASE")) :: Nil),
expected = false
)
// Null compatibility checks.
checkEqualsIgnoreCompatibleCollation(
ArrayType(StringType, containsNull = true),
ArrayType(StringType, containsNull = false),
expected = false
)
checkEqualsIgnoreCompatibleCollation(
ArrayType(StringType, containsNull = true),
ArrayType(StringType("UTF8_BINARY_LCASE"), containsNull = false),
expected = false
)
checkEqualsIgnoreCompatibleCollation(
MapType(StringType, StringType, valueContainsNull = true),
MapType(StringType, StringType, valueContainsNull = false),
expected = false
)
checkEqualsIgnoreCompatibleCollation(
StructType(StructField("a", StringType) :: Nil),
StructType(StructField("a", StringType, nullable = false) :: Nil),
expected = false
)

test("SPARK-25031: MapType should produce current formatted string for complex types") {
val keyType: DataType = StructType(Seq(
StructField("a", DataTypes.IntegerType),
@@ -356,8 +356,8 @@ case class AlterTableUnsetPropertiesCommand(


/**
- * A command to change the column for a table, only support changing the comment of a non-partition
- * column for now.
+ * A command to change the column for a table, only support changing the comment or collation of
+ * the data type or nested types (recursively) of a non-partition column for now.
*
* The syntax of using this command in SQL is:
* {{{
@@ -387,32 +387,45 @@ case class AlterTableChangeColumnCommand(
}
// Find the origin column from dataSchema by column name.
val originColumn = findColumnByName(table.dataSchema, columnName, resolver)
- // Throw an AnalysisException if the column name/dataType is changed.
- if (!columnEqual(originColumn, newColumn, resolver)) {
+ val validType = canEvolveType(originColumn, newColumn)
+ // Throw an AnalysisException on attempt to change collation of bucket column.
+ if (validType && originColumn.dataType != newColumn.dataType) {
+ val isBucketColumn = table.bucketSpec match {
+ case Some(bucketSpec) => bucketSpec.bucketColumnNames.exists(resolver(columnName, _))
+ case _ => false
+ }
+ if (isBucketColumn) {
+ throw QueryCompilationErrors.cannotAlterCollationBucketColumn(
+ table.qualifiedName, columnName)
+ }
+ }
+ // Throw an AnalysisException if the column name is changed or we cannot evolve the data type.
+ // Only changes in collation of column data type or its nested types (recursively) are allowed.
+ if (!validType || !namesEqual(originColumn, newColumn, resolver)) {
throw QueryCompilationErrors.alterTableChangeColumnNotSupportedForColumnTypeError(
toSQLId(table.identifier.nameParts), originColumn, newColumn, this.origin)
}

val newDataSchema = table.dataSchema.fields.map { field =>
if (field.name == originColumn.name) {
- // Create a new column from the origin column with the new comment.
- val withNewComment: StructField =
- addComment(field, newColumn.getComment())
+ // Create a new column from the origin column with the new type and new comment.
+ val withNewTypeAndComment: StructField =
+ addComment(withNewType(field, newColumn.dataType), newColumn.getComment())
// Create a new column from the origin column with the new current default value.
if (newColumn.getCurrentDefaultValue().isDefined) {
if (newColumn.getCurrentDefaultValue().get.nonEmpty) {
val result: StructField =
- addCurrentDefaultValue(withNewComment, newColumn.getCurrentDefaultValue())
+ addCurrentDefaultValue(withNewTypeAndComment, newColumn.getCurrentDefaultValue())
// Check that the proposed default value parses and analyzes correctly, and that the
// type of the resulting expression is equivalent or coercible to the destination column
// type.
ResolveDefaultColumns.analyze(result, "ALTER TABLE ALTER COLUMN")
result
} else {
- withNewComment.clearCurrentDefaultValue()
+ withNewTypeAndComment.clearCurrentDefaultValue()
}
} else {
- withNewComment
+ withNewTypeAndComment
}
} else {
field
@@ -432,6 +445,10 @@
}.getOrElse(throw QueryCompilationErrors.cannotFindColumnError(name, schema.fieldNames))
}

// Change the dataType of the column.
private def withNewType(column: StructField, dataType: DataType): StructField =
column.copy(dataType = dataType)

// Add the comment to a column, if comment is empty, return the original column.
private def addComment(column: StructField, comment: Option[String]): StructField =
comment.map(column.withComment).getOrElse(column)
Expand All @@ -442,10 +459,17 @@ case class AlterTableChangeColumnCommand(
value.map(column.withCurrentDefaultValue).getOrElse(column)

// Compare a [[StructField]] to another, return true if they have the same column
- // name(by resolver) and dataType.
- private def columnEqual(
+ // name(by resolver).
+ private def namesEqual(
field: StructField, other: StructField, resolver: Resolver): Boolean = {
- resolver(field.name, other.name) && field.dataType == other.dataType
+ resolver(field.name, other.name)
}

// Compare dataType of [[StructField]] to another, return true if it is valid to evolve the type
// when altering column. Only changes in collation of data type or its nested types (recursively)
// are allowed.
private def canEvolveType(from: StructField, to: StructField): Boolean = {
DataType.equalsIgnoreCompatibleCollation(from.dataType, to.dataType)
}
}

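Taken together, `canEvolveType` and the bucket-column guard mean `ALTER COLUMN ... TYPE` now accepts collation-only changes and rejects everything else. A minimal SQL sketch with hypothetical tables `t` and `tb` (`tb` bucketed by `col`); the exact errors are asserted in the `DDLSuite` tests below:
```sql
-- Accepted: only the collation changes.
ALTER TABLE t ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE;

-- Rejected with NOT_SUPPORTED_CHANGE_COLUMN: the underlying type changes.
ALTER TABLE t ALTER COLUMN col TYPE INTEGER;

-- Rejected with CANNOT_ALTER_COLLATION_BUCKET_COLUMN: col is a bucket column.
ALTER TABLE tb ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE;
```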
@@ -2437,6 +2437,100 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase {
)
}
}

test("Change column collation") {
withTable("t1", "t2", "t3", "t4") {
// Plain `StringType`.
sql("CREATE TABLE t1(col STRING) USING parquet")
sql("INSERT INTO t1 VALUES ('a')")
checkAnswer(sql("SELECT COLLATION(col) FROM t1"), Row("UTF8_BINARY"))
sql("ALTER TABLE t1 ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE")
checkAnswer(sql("SELECT COLLATION(col) FROM t1"), Row("UTF8_BINARY_LCASE"))

// Invalid "ALTER COLUMN" to Integer.
val alterInt = "ALTER TABLE t1 ALTER COLUMN col TYPE INTEGER"
checkError(
exception = intercept[AnalysisException] {
sql(alterInt)
},
errorClass = "NOT_SUPPORTED_CHANGE_COLUMN",
parameters = Map(
"originType" -> "\"STRING COLLATE UTF8_BINARY_LCASE\"",
"originName" -> "`col`",
"table" -> "`spark_catalog`.`default`.`t1`",
"newType" -> "\"INT\"",
"newName" -> "`col`"
),
context = ExpectedContext(fragment = alterInt, start = 0, stop = alterInt.length - 1)
)

// `ArrayType` with collation.
sql("CREATE TABLE t2(col ARRAY<STRING>) USING parquet")
sql("INSERT INTO t2 VALUES (ARRAY('a'))")
checkAnswer(sql("SELECT COLLATION(col[0]) FROM t2"), Row("UTF8_BINARY"))
sql("ALTER TABLE t2 ALTER COLUMN col TYPE ARRAY<STRING COLLATE UTF8_BINARY_LCASE>")
checkAnswer(sql("SELECT COLLATION(col[0]) FROM t2"), Row("UTF8_BINARY_LCASE"))

// `MapType` with collation.
sql("CREATE TABLE t3(col MAP<STRING, STRING>) USING parquet")
sql("INSERT INTO t3 VALUES (MAP('k', 'v'))")
checkAnswer(sql("SELECT COLLATION(col['k']) FROM t3"), Row("UTF8_BINARY"))
sql(
"""
|ALTER TABLE t3 ALTER COLUMN col TYPE
|MAP<STRING, STRING COLLATE UTF8_BINARY_LCASE>""".stripMargin)
checkAnswer(sql("SELECT COLLATION(col['k']) FROM t3"), Row("UTF8_BINARY_LCASE"))

// Invalid change of map key collation.
val alterMap =
"ALTER TABLE t3 ALTER COLUMN col TYPE " +
"MAP<STRING COLLATE UTF8_BINARY_LCASE, STRING>"
checkError(
exception = intercept[AnalysisException] {
sql(alterMap)
},
errorClass = "NOT_SUPPORTED_CHANGE_COLUMN",
parameters = Map(
"originType" -> "\"MAP<STRING, STRING COLLATE UTF8_BINARY_LCASE>\"",
"originName" -> "`col`",
"table" -> "`spark_catalog`.`default`.`t3`",
"newType" -> "\"MAP<STRING COLLATE UTF8_BINARY_LCASE, STRING>\"",
"newName" -> "`col`"
),
context = ExpectedContext(fragment = alterMap, start = 0, stop = alterMap.length - 1)
)

// `StructType` with collation.
sql("CREATE TABLE t4(col STRUCT<a:STRING>) USING parquet")
sql("INSERT INTO t4 VALUES (NAMED_STRUCT('a', 'value'))")
checkAnswer(sql("SELECT COLLATION(col.a) FROM t4"), Row("UTF8_BINARY"))
sql("ALTER TABLE t4 ALTER COLUMN col TYPE STRUCT<a:STRING COLLATE UTF8_BINARY_LCASE>")
checkAnswer(sql("SELECT COLLATION(col.a) FROM t4"), Row("UTF8_BINARY_LCASE"))
}
}

test("Invalid collation change on partition and bucket columns") {
withTable("t1", "t2") {
sql("CREATE TABLE t1(col STRING, i INTEGER) USING parquet PARTITIONED BY (col)")
checkError(
exception = intercept[AnalysisException] {
sql("ALTER TABLE t1 ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE")
},
errorClass = "CANNOT_ALTER_PARTITION_COLUMN",
sqlState = "428FR",
parameters = Map("tableName" -> "`spark_catalog`.`default`.`t1`", "columnName" -> "`col`")
)
sql("CREATE TABLE t2(col STRING) USING parquet CLUSTERED BY (col) INTO 1 BUCKETS")
checkError(
exception = intercept[AnalysisException] {
sql("ALTER TABLE t2 ALTER COLUMN col TYPE STRING COLLATE UTF8_BINARY_LCASE")
},
errorClass = "CANNOT_ALTER_COLLATION_BUCKET_COLUMN",
sqlState = "428FR",
parameters = Map("tableName" -> "`spark_catalog`.`default`.`t2`", "columnName" -> "`col`")
)
}
}
}

object FakeLocalFsFileSystem {