Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48282][SQL] Alter string search logic for UTF8_BINARY_LCASE collation (StringReplace, FindInSet) #46682

Closed
wants to merge 15 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,19 @@ private static int compareLowerCaseSlow(final UTF8String left, final UTF8String
return lowerCaseCodePoints(left).binaryCompare(lowerCaseCodePoints(right));
}

/*
* Performs string replacement for ICU collations by searching for instances of the search
* string in the src string, with respect to the specified collation, and then replacing
* them with the replace string. The method returns a new UTF8String with all instances of the
* search string replaced using the replace string. Similar to UTF8String.replace behaviour
* used for UTF8_BINARY collation, the method returns src string if the search string is empty.
*
* @param src the string to be searched in
* @param search the string to be searched for
* @param replace the string to be used as replacement
* @param collationId the collation ID to use for string search
* @return a new string with all occurrences of `search` in `src` replaced by `replace`
*/
public static UTF8String replace(final UTF8String src, final UTF8String search,
final UTF8String replace, final int collationId) {
// This collation aware implementation is based on existing implementation on UTF8String
Expand Down Expand Up @@ -286,49 +299,47 @@ public static UTF8String replace(final UTF8String src, final UTF8String search,
return buf.build();
}

/*
* Performs string replacement for UTF8_LCASE collation by searching for instances of the search
* string in the src string, with respect to lowercased string versions, and then replacing
* them with the replace string. The method returns a new UTF8String with all instances of the
* search string replaced using the replace string. Similar to UTF8String.replace behaviour
* used for UTF8_BINARY collation, the method returns src string if the search string is empty.
*
* @param src the string to be searched in
* @param search the string to be searched for
* @param replace the string to be used as replacement
* @return a new string with all occurrences of `search` in `src` replaced by `replace`
*/
public static UTF8String lowercaseReplace(final UTF8String src, final UTF8String search,
final UTF8String replace) {
if (src.numBytes() == 0 || search.numBytes() == 0) {
return src;
}
UTF8String lowercaseString = src.toLowerCase();

// TODO(SPARK-48725): Use lowerCaseCodePoints instead of UTF8String.toLowerCase.
UTF8String lowercaseSearch = search.toLowerCase();

int start = 0;
int end = lowercaseString.indexOf(lowercaseSearch, 0);
int end = lowercaseFind(src, lowercaseSearch, start);
if (end == -1) {
// Search string was not found, so string is unchanged.
return src;
}

// Initialize byte positions
int c = 0;
int byteStart = 0; // position in byte
int byteEnd = 0; // position in byte
while (byteEnd < src.numBytes() && c < end) {
byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
c += 1;
}

// At least one match was found. Estimate space needed for result.
// The 16x multiplier here is chosen to match commons-lang3's implementation.
int increase = Math.max(0, replace.numBytes() - search.numBytes()) * 16;
final UTF8StringBuilder buf = new UTF8StringBuilder(src.numBytes() + increase);
while (end != -1) {
buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart, byteEnd - byteStart);
buf.append(src.substring(start, end));
buf.append(replace);
// Update character positions
start = end + lowercaseSearch.numChars();
end = lowercaseString.indexOf(lowercaseSearch, start);
// Update byte positions
byteStart = byteEnd + search.numBytes();
while (byteEnd < src.numBytes() && c < end) {
byteEnd += UTF8String.numBytesForFirstByte(src.getByte(byteEnd));
c += 1;
}
start = end + lowercaseMatchLengthFrom(src, lowercaseSearch, end);
end = lowercaseFind(src, lowercaseSearch, start);
}
buf.appendBytes(src.getBaseObject(), src.getBaseOffset() + byteStart,
src.numBytes() - byteStart);
buf.append(src.substring(start, src.numChars()));
return buf.build();
}

Expand Down Expand Up @@ -479,34 +490,40 @@ public static UTF8String toTitleCase(final UTF8String target, final int collatio
BreakIterator.getWordInstance(locale)));
}

/*
* Returns the position of the first occurrence of the match string in the set string,
* counting ASCII commas as delimiters. The match string is compared in a collation-aware manner,
* with respect to the specified collation ID. Similar to UTF8String.findInSet behaviour used
* for UTF8_BINARY collation, the method returns 0 if the match string contains a comma.
*
* @param match the string to be searched for
* @param set the string to be searched in
* @param collationId the collation ID to use for string comparison
* @return the 1-based position of the first occurrence of `match` in `set`, or 0 if not found
*/
public static int findInSet(final UTF8String match, final UTF8String set, int collationId) {
// If the "word" string contains a comma, FindInSet should return 0.
if (match.contains(UTF8String.fromString(","))) {
return 0;
}

// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
String setString = set.toString();
StringSearch stringSearch = CollationFactory.getStringSearch(setString, match.toString(),
collationId);

int wordStart = 0;
while ((wordStart = stringSearch.next()) != StringSearch.DONE) {
boolean isValidStart = wordStart == 0 || setString.charAt(wordStart - 1) == ',';
boolean isValidEnd = wordStart + stringSearch.getMatchLength() == setString.length()
|| setString.charAt(wordStart + stringSearch.getMatchLength()) == ',';

if (isValidStart && isValidEnd) {
int pos = 0;
for (int i = 0; i < setString.length() && i < wordStart; i++) {
if (setString.charAt(i) == ',') {
pos++;
}
// Otherwise, search for commas in "set" and compare each substring with "word".
int byteIndex = 0, charIndex = 0, wordCount = 1, lastComma = -1;
while (byteIndex < set.numBytes()) {
byte nextByte = set.getByte(byteIndex);
if (nextByte == (byte) ',') {
if (set.substring(lastComma + 1, charIndex).semanticEquals(match, collationId)) {
return wordCount;
}

return pos + 1;
lastComma = charIndex;
++wordCount;
}
byteIndex += UTF8String.numBytesForFirstByte(nextByte);
++charIndex;
}

if (set.substring(lastComma + 1, set.numBytes()).semanticEquals(match, collationId)) {
return wordCount;
}
// If no match is found, return 0.
return 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,31 +318,24 @@ public static int exec(final UTF8String word, final UTF8String set, final int co
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(word, set);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(word, set);
} else {
return execICU(word, set, collationId);
return execCollationAware(word, set, collationId);
}
}
/*
 * Builds the Java source snippet that calls the FindInSet implementation matching the
 * given collation: execBinary when the collation supports binary equality, and
 * execCollationAware otherwise.
 *
 * @param word the generated-code expression for the string to be searched for
 * @param set the generated-code expression for the string to be searched in
 * @param collationId the collation ID used to select the implementation
 * @return the source text of the call expression
 */
public static String genCode(final String word, final String set, final int collationId) {
  CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
  String expr = "CollationSupport.FindInSet.exec";
  if (collation.supportsBinaryEquality) {
    return String.format(expr + "Binary(%s, %s)", word, set);
  } else {
    // `expr` already ends with "exec", so append only the method-name suffix; repeating
    // "exec" here would emit a call to the non-existent execexecCollationAware.
    return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId);
  }
}
public static int execBinary(final UTF8String word, final UTF8String set) {
return set.findInSet(word);
}
public static int execLowercase(final UTF8String word, final UTF8String set) {
return set.toLowerCase().findInSet(word.toLowerCase());
}
public static int execICU(final UTF8String word, final UTF8String set,
final int collationId) {
public static int execCollationAware(final UTF8String word, final UTF8String set,
final int collationId) {
return CollationAwareUTF8String.findInSet(word, set, collationId);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,47 +875,105 @@ public void testStringInstr() throws SparkException {
assertStringInstr("aİoi̇oxx", "XX", "UTF8_LCASE", 7);
}

private void assertFindInSet(String word, String set, String collationName,
Integer expected) throws SparkException {
private void assertFindInSet(String word, UTF8String set, String collationName,
Integer expected) throws SparkException {
UTF8String w = UTF8String.fromString(word);
UTF8String s = UTF8String.fromString(set);
int collationId = CollationFactory.collationNameToId(collationName);
assertEquals(expected, CollationSupport.FindInSet.exec(w, s, collationId));
assertEquals(expected, CollationSupport.FindInSet.exec(w, set, collationId));
}

@Test
public void testFindInSet() throws SparkException {
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("abc", "abc,b,ab,c,def", "UTF8_BINARY", 1);
assertFindInSet("def", "abc,b,ab,c,def", "UTF8_BINARY", 5);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_BINARY", 0);
assertFindInSet("a", "abc,b,ab,c,def", "UTF8_LCASE", 0);
assertFindInSet("c", "abc,b,ab,c,def", "UTF8_LCASE", 4);
assertFindInSet("AB", "abc,b,ab,c,def", "UTF8_LCASE", 3);
assertFindInSet("AbC", "abc,b,ab,c,def", "UTF8_LCASE", 1);
assertFindInSet("abcd", "abc,b,ab,c,def", "UTF8_LCASE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UTF8_LCASE", 0);
assertFindInSet("XX", "xx", "UTF8_LCASE", 1);
assertFindInSet("", "abc,b,ab,c,def", "UTF8_LCASE", 0);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UTF8_LCASE", 4);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("ab", "abc,b,ab,c,def", "UNICODE", 3);
assertFindInSet("Ab", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("d,ef", "abc,b,ab,c,def", "UNICODE", 0);
assertFindInSet("xx", "xx", "UNICODE", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE", 0);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE", 5);
assertFindInSet("a", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("C", "abc,b,ab,c,def", "UNICODE_CI", 4);
assertFindInSet("DeF", "abc,b,ab,c,dEf", "UNICODE_CI", 5);
assertFindInSet("DEFG", "abc,b,ab,c,def", "UNICODE_CI", 0);
assertFindInSet("XX", "xx", "UNICODE_CI", 1);
assertFindInSet("界x", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 4);
assertFindInSet("界x", "test,大千,界Xx,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("大", "test,大千,世,界X,大,千,世界", "UNICODE_CI", 5);
assertFindInSet("i̇o", "ab,İo,12", "UNICODE_CI", 2);
assertFindInSet("İo", "ab,i̇o,12", "UNICODE_CI", 2);
assertFindInSet("AB", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
assertFindInSet("abc", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 1);
assertFindInSet("def", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 5);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_BINARY", 0);
assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UTF8_BINARY", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UTF8_BINARY", 6);
assertFindInSet("", UTF8String.fromString("abc"), "UTF8_BINARY", 0);
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0);
assertFindInSet("c", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 4);
assertFindInSet("AB", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 3);
assertFindInSet("AbC", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 1);
assertFindInSet("abcd", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0);
assertFindInSet("XX", UTF8String.fromString("xx"), "UTF8_LCASE", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def"), "UTF8_LCASE", 0);
assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UTF8_LCASE", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UTF8_LCASE", 6);
assertFindInSet("", UTF8String.fromString("abc"), "UTF8_LCASE", 0);
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UTF8_LCASE", 4);
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 3);
assertFindInSet("Ab", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("d,ef", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE", 0);
assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UNICODE", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UNICODE", 6);
assertFindInSet("", UTF8String.fromString("abc"), "UNICODE", 0);
assertFindInSet("xx", UTF8String.fromString("xx"), "UNICODE", 1);
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 0);
assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE", 5);
assertFindInSet("a", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0);
assertFindInSet("C", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 4);
assertFindInSet("DeF", UTF8String.fromString("abc,b,ab,c,dEf"), "UNICODE_CI", 5);
assertFindInSet("DEFG", UTF8String.fromString("abc,b,ab,c,def"), "UNICODE_CI", 0);
assertFindInSet("", UTF8String.fromString(",abc,b,ab,c,def"), "UNICODE_CI", 1);
assertFindInSet("", UTF8String.fromString("abc,b,ab,c,def,"), "UNICODE_CI", 6);
assertFindInSet("", UTF8String.fromString("abc"), "UNICODE_CI", 0);
assertFindInSet("XX", UTF8String.fromString("xx"), "UNICODE_CI", 1);
assertFindInSet("界x", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 4);
assertFindInSet("界x", UTF8String.fromString("test,大千,界Xx,世,界X,大,千,世界"), "UNICODE_CI", 5);
assertFindInSet("大", UTF8String.fromString("test,大千,世,界X,大,千,世界"), "UNICODE_CI", 5);
assertFindInSet("i̇", UTF8String.fromString("İ"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("İ"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("i̇"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("İ,"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("İ,"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇,"), "UNICODE_CI", 1);
assertFindInSet("i", UTF8String.fromString("i̇,"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UNICODE_CI", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇,12"), "UNICODE_CI", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇,12"), "UNICODE_CI", 0);
assertFindInSet("i̇o", UTF8String.fromString("ab,İo,12"), "UNICODE_CI", 2);
assertFindInSet("İo", UTF8String.fromString("ab,i̇o,12"), "UNICODE_CI", 2);
assertFindInSet("i̇", UTF8String.fromString("İ"), "UTF8_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("İ"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇"), "UTF8_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("i̇"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("İ,"), "UTF8_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("İ,"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("i̇,"), "UTF8_LCASE", 1);
assertFindInSet("i", UTF8String.fromString("i̇,"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ"), "UTF8_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇"), "UTF8_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,İ,12"), "UTF8_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,İ,12"), "UTF8_LCASE", 0);
assertFindInSet("i̇", UTF8String.fromString("ab,i̇,12"), "UTF8_LCASE", 2);
assertFindInSet("i", UTF8String.fromString("ab,i̇,12"), "UTF8_LCASE", 0);
assertFindInSet("i̇o", UTF8String.fromString("ab,İo,12"), "UTF8_LCASE", 2);
assertFindInSet("İo", UTF8String.fromString("ab,i̇o,12"), "UTF8_LCASE", 2);
// Invalid UTF8 strings
assertFindInSet("C", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UTF8_BINARY", 3);
assertFindInSet("c", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UTF8_LCASE", 2);
assertFindInSet("C", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UNICODE", 2);
assertFindInSet("c", UTF8String.fromBytes(
new byte[] { 0x41, (byte) 0xC2, 0x2C, 0x42, 0x2C, 0x43, 0x2C, 0x43, 0x2C, 0x56 }),
"UNICODE_CI", 2);
}

private void assertReplace(String source, String search, String replace, String collationName,
Expand Down Expand Up @@ -952,8 +1010,23 @@ public void testReplace() throws SparkException {
assertReplace("replace", "", "123", "UNICODE_CI", "replace");
assertReplace("aBc世abc", "b", "12", "UNICODE_CI", "a12c世a12c");
assertReplace("a世Bcdabcd", "bC", "", "UNICODE_CI", "a世dad");
assertReplace("abi̇12", "i", "X", "UNICODE_CI", "abi̇12");
assertReplace("abi̇12", "\u0307", "X", "UNICODE_CI", "abi̇12");
assertReplace("abi̇12", "İ", "X", "UNICODE_CI", "abX12");
assertReplace("abİ12", "i", "X", "UNICODE_CI", "abİ12");
assertReplace("İi̇İi̇İi̇", "i̇", "x", "UNICODE_CI", "xxxxxx");
assertReplace("İi̇İi̇İi̇", "i", "x", "UNICODE_CI", "İi̇İi̇İi̇");
assertReplace("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx");
assertReplace("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy");
assertReplace("abi̇12", "i", "X", "UTF8_LCASE", "abX\u030712"); // != UNICODE_CI
assertReplace("abi̇12", "\u0307", "X", "UTF8_LCASE", "abiX12"); // != UNICODE_CI
assertReplace("abi̇12", "İ", "X", "UTF8_LCASE", "abX12");
assertReplace("abİ12", "i", "X", "UTF8_LCASE", "abİ12");
assertReplace("İi̇İi̇İi̇", "i̇", "x", "UTF8_LCASE", "xxxxxx");
assertReplace("İi̇İi̇İi̇", "i", "x", "UTF8_LCASE",
"İx\u0307İx\u0307İx\u0307"); // != UNICODE_CI
assertReplace("abİo12i̇o", "i̇o", "xx", "UTF8_LCASE", "abxx12xx");
assertReplace("abi̇o12i̇o", "İo", "yy", "UTF8_LCASE", "abyy12yy");
}

private void assertLocate(String substring, String string, Integer start, String collationName,
Expand Down