Skip to content

Commit

Permalink
[SPARK-47352][SQL] Fix Upper, Lower, InitCap collation awareness
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Add support for Locale aware expressions.

### Why are the changes needed?
This is needed as some future collations might use different Locales then default.

### Does this PR introduce _any_ user-facing change?
Yes, we follow ICU implementations for collations that are non native.

### How was this patch tested?
Tests for Upper, Lower and InitCap already exist.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46104 from mihailom-db/SPARK-47352.

Authored-by: Mihailo Milosevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
mihailom-db authored and cloud-fan committed Apr 23, 2024
1 parent 61ac342 commit b9f2270
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
*/
package org.apache.spark.sql.catalyst.util;

import com.ibm.icu.lang.UCharacter;
import com.ibm.icu.text.BreakIterator;
import com.ibm.icu.text.StringSearch;
import com.ibm.icu.util.ULocale;

import org.apache.spark.unsafe.types.UTF8String;

Expand Down Expand Up @@ -144,6 +147,93 @@ public static boolean execICU(final UTF8String l, final UTF8String r,
}
}

public static class Upper {
public static UTF8String exec(final UTF8String v, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return execUTF8(v);
} else {
return execICU(v, collationId);
}
}
public static String genCode(final String v, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.Upper.exec";
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return String.format(expr + "UTF8(%s)", v);
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
public static UTF8String execUTF8(final UTF8String v) {
return v.toUpperCase();
}
public static UTF8String execICU(final UTF8String v, final int collationId) {
return UTF8String.fromString(CollationAwareUTF8String.toUpperCase(v.toString(), collationId));
}
}

public static class Lower {
public static UTF8String exec(final UTF8String v, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return execUTF8(v);
} else {
return execICU(v, collationId);
}
}
public static String genCode(final String v, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.Lower.exec";
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return String.format(expr + "UTF8(%s)", v);
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
public static UTF8String execUTF8(final UTF8String v) {
return v.toLowerCase();
}
public static UTF8String execICU(final UTF8String v, final int collationId) {
return UTF8String.fromString(CollationAwareUTF8String.toLowerCase(v.toString(), collationId));
}
}

public static class InitCap {
public static UTF8String exec(final UTF8String v, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return execUTF8(v);
} else {
return execICU(v, collationId);
}
}

public static String genCode(final String v, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.InitCap.exec";
if (collation.supportsBinaryEquality || collation.supportsLowercaseEquality) {
return String.format(expr + "UTF8(%s)", v);
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}

public static UTF8String execUTF8(final UTF8String v) {
return v.toLowerCase().toTitleCase();
}

public static UTF8String execICU(final UTF8String v, final int collationId) {
return UTF8String.fromString(
CollationAwareUTF8String.toTitleCase(
CollationAwareUTF8String.toLowerCase(
v.toString(),
collationId
),
collationId));
}
}

public static class FindInSet {
public static int exec(final UTF8String word, final UTF8String set, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
Expand Down Expand Up @@ -234,6 +324,24 @@ public static int execICU(final UTF8String string, final UTF8String substring,

private static class CollationAwareUTF8String {

private static String toUpperCase(final String target, final int collationId) {
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
return UCharacter.toUpperCase(locale, target);
}

private static String toLowerCase(final String target, final int collationId) {
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
return UCharacter.toLowerCase(locale, target);
}

private static String toTitleCase(final String target, final int collationId) {
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
return UCharacter.toTitleCase(locale, target, BreakIterator.getWordInstance(locale));
}

private static int findInSet(final UTF8String match, final UTF8String set, int collationId) {
if (match.contains(UTF8String.fromString(","))) {
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,157 @@ public void testEndsWith() throws SparkException {
assertEndsWith("The i̇o", "İo", "UNICODE_CI", true);
}


private void assertUpper(String target, String collationName, String expected)
throws SparkException {
UTF8String target_utf8 = UTF8String.fromString(target);
UTF8String expected_utf8 = UTF8String.fromString(expected);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected_utf8, CollationSupport.Upper.exec(target_utf8, collationId));
}

@Test
public void testUpper() throws SparkException {
// Edge cases
assertUpper("", "UTF8_BINARY", "");
assertUpper("", "UTF8_BINARY_LCASE", "");
assertUpper("", "UNICODE", "");
assertUpper("", "UNICODE_CI", "");
// Basic tests
assertUpper("abcde", "UTF8_BINARY", "ABCDE");
assertUpper("abcde", "UTF8_BINARY_LCASE", "ABCDE");
assertUpper("abcde", "UNICODE", "ABCDE");
assertUpper("abcde", "UNICODE_CI", "ABCDE");
// Uppercase present
assertUpper("AbCdE", "UTF8_BINARY", "ABCDE");
assertUpper("aBcDe", "UTF8_BINARY", "ABCDE");
assertUpper("AbCdE", "UTF8_BINARY_LCASE", "ABCDE");
assertUpper("aBcDe", "UTF8_BINARY_LCASE", "ABCDE");
assertUpper("AbCdE", "UNICODE", "ABCDE");
assertUpper("aBcDe", "UNICODE", "ABCDE");
assertUpper("AbCdE", "UNICODE_CI", "ABCDE");
assertUpper("aBcDe", "UNICODE_CI", "ABCDE");
// Accent letters
assertUpper("aBćDe","UTF8_BINARY", "ABĆDE");
assertUpper("aBćDe","UTF8_BINARY_LCASE", "ABĆDE");
assertUpper("aBćDe","UNICODE", "ABĆDE");
assertUpper("aBćDe","UNICODE_CI", "ABĆDE");
// Variable byte length characters
assertUpper("ab世De", "UTF8_BINARY", "AB世DE");
assertUpper("äbćδe", "UTF8_BINARY", "ÄBĆΔE");
assertUpper("ab世De", "UTF8_BINARY_LCASE", "AB世DE");
assertUpper("äbćδe", "UTF8_BINARY_LCASE", "ÄBĆΔE");
assertUpper("ab世De", "UNICODE", "AB世DE");
assertUpper("äbćδe", "UNICODE", "ÄBĆΔE");
assertUpper("ab世De", "UNICODE_CI", "AB世DE");
assertUpper("äbćδe", "UNICODE_CI", "ÄBĆΔE");
// Case-variable character length
assertUpper("i̇o", "UTF8_BINARY","İO");
assertUpper("i̇o", "UTF8_BINARY_LCASE","İO");
assertUpper("i̇o", "UNICODE","İO");
assertUpper("i̇o", "UNICODE_CI","İO");
}

private void assertLower(String target, String collationName, String expected)
throws SparkException {
UTF8String target_utf8 = UTF8String.fromString(target);
UTF8String expected_utf8 = UTF8String.fromString(expected);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected_utf8, CollationSupport.Lower.exec(target_utf8, collationId));
}

@Test
public void testLower() throws SparkException {
// Edge cases
assertLower("", "UTF8_BINARY", "");
assertLower("", "UTF8_BINARY_LCASE", "");
assertLower("", "UNICODE", "");
assertLower("", "UNICODE_CI", "");
// Basic tests
assertLower("ABCDE", "UTF8_BINARY", "abcde");
assertLower("ABCDE", "UTF8_BINARY_LCASE", "abcde");
assertLower("ABCDE", "UNICODE", "abcde");
assertLower("ABCDE", "UNICODE_CI", "abcde");
// Uppercase present
assertLower("AbCdE", "UTF8_BINARY", "abcde");
assertLower("aBcDe", "UTF8_BINARY", "abcde");
assertLower("AbCdE", "UTF8_BINARY_LCASE", "abcde");
assertLower("aBcDe", "UTF8_BINARY_LCASE", "abcde");
assertLower("AbCdE", "UNICODE", "abcde");
assertLower("aBcDe", "UNICODE", "abcde");
assertLower("AbCdE", "UNICODE_CI", "abcde");
assertLower("aBcDe", "UNICODE_CI", "abcde");
// Accent letters
assertLower("AbĆdE","UTF8_BINARY", "abćde");
assertLower("AbĆdE","UTF8_BINARY_LCASE", "abćde");
assertLower("AbĆdE","UNICODE", "abćde");
assertLower("AbĆdE","UNICODE_CI", "abćde");
// Variable byte length characters
assertLower("aB世De", "UTF8_BINARY", "ab世de");
assertLower("ÄBĆΔE", "UTF8_BINARY", "äbćδe");
assertLower("aB世De", "UTF8_BINARY_LCASE", "ab世de");
assertLower("ÄBĆΔE", "UTF8_BINARY_LCASE", "äbćδe");
assertLower("aB世De", "UNICODE", "ab世de");
assertLower("ÄBĆΔE", "UNICODE", "äbćδe");
assertLower("aB世De", "UNICODE_CI", "ab世de");
assertLower("ÄBĆΔE", "UNICODE_CI", "äbćδe");
// Case-variable character length
assertLower("İo", "UTF8_BINARY","i̇o");
assertLower("İo", "UTF8_BINARY_LCASE","i̇o");
assertLower("İo", "UNICODE","i̇o");
assertLower("İo", "UNICODE_CI","i̇o");
}

private void assertInitCap(String target, String collationName, String expected)
throws SparkException {
UTF8String target_utf8 = UTF8String.fromString(target);
UTF8String expected_utf8 = UTF8String.fromString(expected);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected_utf8, CollationSupport.InitCap.exec(target_utf8, collationId));
}

@Test
public void testInitCap() throws SparkException {
// Edge cases
assertInitCap("", "UTF8_BINARY", "");
assertInitCap("", "UTF8_BINARY_LCASE", "");
assertInitCap("", "UNICODE", "");
assertInitCap("", "UNICODE_CI", "");
// Basic tests
assertInitCap("ABCDE", "UTF8_BINARY", "Abcde");
assertInitCap("ABCDE", "UTF8_BINARY_LCASE", "Abcde");
assertInitCap("ABCDE", "UNICODE", "Abcde");
assertInitCap("ABCDE", "UNICODE_CI", "Abcde");
// Uppercase present
assertInitCap("AbCdE", "UTF8_BINARY", "Abcde");
assertInitCap("aBcDe", "UTF8_BINARY", "Abcde");
assertInitCap("AbCdE", "UTF8_BINARY_LCASE", "Abcde");
assertInitCap("aBcDe", "UTF8_BINARY_LCASE", "Abcde");
assertInitCap("AbCdE", "UNICODE", "Abcde");
assertInitCap("aBcDe", "UNICODE", "Abcde");
assertInitCap("AbCdE", "UNICODE_CI", "Abcde");
assertInitCap("aBcDe", "UNICODE_CI", "Abcde");
// Accent letters
assertInitCap("AbĆdE", "UTF8_BINARY", "Abćde");
assertInitCap("AbĆdE", "UTF8_BINARY_LCASE", "Abćde");
assertInitCap("AbĆdE", "UNICODE", "Abćde");
assertInitCap("AbĆdE", "UNICODE_CI", "Abćde");
// Variable byte length characters
assertInitCap("aB 世 De", "UTF8_BINARY", "Ab 世 De");
assertInitCap("ÄBĆΔE", "UTF8_BINARY", "Äbćδe");
assertInitCap("aB 世 De", "UTF8_BINARY_LCASE", "Ab 世 De");
assertInitCap("ÄBĆΔE", "UTF8_BINARY_LCASE", "Äbćδe");
assertInitCap("aB 世 De", "UNICODE", "Ab 世 De");
assertInitCap("ÄBĆΔE", "UNICODE", "Äbćδe");
assertInitCap("aB 世 de", "UNICODE_CI", "Ab 世 De");
assertInitCap("ÄBĆΔE", "UNICODE_CI", "Äbćδe");
// Case-variable character length
assertInitCap("İo", "UTF8_BINARY", "İo");
assertInitCap("İo", "UTF8_BINARY_LCASE", "İo");
assertInitCap("İo", "UNICODE", "İo");
assertInitCap("İo", "UNICODE_CI", "İo");
}

private void assertStringInstr(String string, String substring, String collationName,
Integer expected) throws SparkException {
UTF8String str = UTF8String.fromString(string);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,14 +453,14 @@ trait String2StringExpression extends ImplicitCastInputTypes {
case class Upper(child: Expression)
extends UnaryExpression with String2StringExpression with NullIntolerant {

// scalastyle:off caselocale
override def convert(v: UTF8String): UTF8String = v.toUpperCase
// scalastyle:on caselocale
final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId

override def convert(v: UTF8String): UTF8String = CollationSupport.Upper.exec(v, collationId)

final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toUpperCase()")
defineCodeGen(ctx, ev, c => CollationSupport.Upper.genCode(c, collationId))
}

override protected def withNewChildInternal(newChild: Expression): Upper = copy(child = newChild)
Expand All @@ -481,14 +481,14 @@ case class Upper(child: Expression)
case class Lower(child: Expression)
extends UnaryExpression with String2StringExpression with NullIntolerant {

// scalastyle:off caselocale
override def convert(v: UTF8String): UTF8String = v.toLowerCase
// scalastyle:on caselocale
final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId

override def convert(v: UTF8String): UTF8String = CollationSupport.Lower.exec(v, collationId)

final override val nodePatterns: Seq[TreePattern] = Seq(UPPER_OR_LOWER)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).toLowerCase()")
defineCodeGen(ctx, ev, c => CollationSupport.Lower.genCode(c, collationId))
}

override def prettyName: String =
Expand Down Expand Up @@ -1824,16 +1824,16 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
case class InitCap(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

final lazy val collationId: Int = child.dataType.asInstanceOf[StringType].collationId

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
override def dataType: DataType = child.dataType

override def nullSafeEval(string: Any): Any = {
// scalastyle:off caselocale
string.asInstanceOf[UTF8String].toLowerCase.toTitleCase
// scalastyle:on caselocale
CollationSupport.InitCap.exec(string.asInstanceOf[UTF8String], collationId)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()")
defineCodeGen(ctx, ev, str => CollationSupport.InitCap.genCode(str, collationId))
}

override protected def withNewChildInternal(newChild: Expression): InitCap =
Expand Down

0 comments on commit b9f2270

Please sign in to comment.