Skip to content

Commit

Permalink
[SPARK-48686][SQL] Improve performance of ParserUtils.unescapeSQLString
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR implements multiple performance optimizations for `ParserUtils.unescapeSQLString`:

1. Don't use regex: following apache#31362, the existing code uses regexes for parsing escaped character patterns. However, in the worst case (the expected common case of "no escaping needed") it will perform four regex match attempts per input character, resulting in significant garbage creation because the matchers aren't reused.
2. Skip the StringBuilder allocation for raw strings and for strings that don't need any unescaping.
3. Minor: use Java StringBuilder instead of the Scala version: this removes a layer of indirection and may benefit JIT (we've seen positive results in some scenarios from this type of switch).

### Why are the changes needed?

unescapeSQLString showed up as a CPU and allocation hotspot in certain testing scenarios. See this flamegraph for an illustration of the relative costs of repeated regex matching in the old code:

![image](https://github.com/apache/spark/assets/50748/e045d9da-da0f-493c-a634-188acaeab1a9)

The new code is almost arbitrarily faster (e.g. can show ~arbitrary relative speedups, depending on the choice of input) for strings that don't require unescaping. For strings that _do_ need escaping, I tested extreme cases where _every_ character needs escaping: in these cases I see ~10-20x speedups (depending on the type of escaping). The new code should be faster in every scenario.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Correctness is covered by existing unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47062 from JoshRosen/unescapeSQLString-optimizations.

Authored-by: Josh Rosen <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
JoshRosen authored and HyukjinKwon committed Jun 25, 2024
1 parent b49479b commit 51f1103
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
*/
package org.apache.spark.sql.catalyst.util

import java.lang.{Long => JLong}
import java.nio.CharBuffer
import java.lang.{Long => JLong, StringBuilder => JStringBuilder}

import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
Expand All @@ -26,16 +25,10 @@ import org.antlr.v4.runtime.tree.TerminalNode
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}

trait SparkParserUtils {
val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r
val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r
val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r
val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r

/** Unescape backslash-escaped string enclosed by quotes. */
def unescapeSQLString(b: String): String = {
val sb = new StringBuilder(b.length())

def appendEscapedChar(n: Char): Unit = {
def appendEscapedChar(n: Char, sb: JStringBuilder): Unit = {
n match {
case '0' => sb.append('\u0000')
case 'b' => sb.append('\b')
Expand All @@ -50,22 +43,64 @@ trait SparkParserUtils {
}
}

if (b.startsWith("r") || b.startsWith("R")) {
def allCharsAreHex(s: String, start: Int, length: Int): Boolean = {
val end = start + length
var i = start
while (i < end) {
val c = s.charAt(i)
val cIsHex = (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')
if (!cIsHex) {
return false
}
i += 1
}
true
}

def isThreeDigitOctalEscape(s: String, start: Int): Boolean = {
val firstChar = s.charAt(start)
val secondChar = s.charAt(start + 1)
val thirdChar = s.charAt(start + 2)
(firstChar == '0' || firstChar == '1') &&
(secondChar >= '0' && secondChar <= '7') &&
(thirdChar >= '0' && thirdChar <= '7')
}

val isRawString = {
val firstChar = b.charAt(0)
firstChar == 'r' || firstChar == 'R'
}

if (isRawString) {
// Skip the 'r' or 'R' and the first and last quotations enclosing the string literal.
b.substring(2, b.length - 1)
} else if (b.indexOf('\\') == -1) {
// Fast path for the common case where the string has no escaped characters,
// in which case we just skip the first and last quotations enclosing the string literal.
b.substring(1, b.length - 1)
} else {
val sb = new JStringBuilder(b.length())
// Skip the first and last quotations enclosing the string literal.
val charBuffer = CharBuffer.wrap(b, 1, b.length - 1)

while (charBuffer.remaining() > 0) {
charBuffer match {
case U16_CHAR_PATTERN(cp) =>
var i = 1
val length = b.length - 1
while (i < length) {
val c = b.charAt(i)
if (c != '\\' || i + 1 == length) {
// Either a regular character or a backslash at the end of the string:
sb.append(c)
i += 1
} else {
// A backslash followed by at least one character:
i += 1
val cAfterBackslash = b.charAt(i)
if (cAfterBackslash == 'u' && i + 1 + 4 <= length && allCharsAreHex(b, i + 1, 4)) {
// \u0000 style 16-bit unicode character literals.
sb.append(Integer.parseInt(cp, 16).toChar)
charBuffer.position(charBuffer.position() + 6)
case U32_CHAR_PATTERN(cp) =>
sb.append(Integer.parseInt(b, i + 1, i + 1 + 4, 16).toChar)
i += 1 + 4
} else if (cAfterBackslash == 'U' && i + 1 + 8 <= length && allCharsAreHex(b, i + 1, 8)) {
// \U00000000 style 32-bit unicode character literals.
// Use Long to treat codePoint as unsigned in the range of 32-bit.
val codePoint = JLong.parseLong(cp, 16)
val codePoint = JLong.parseLong(b, i + 1, i + 1 + 8, 16)
if (codePoint < 0x10000) {
sb.append((codePoint & 0xFFFF).toChar)
} else {
Expand All @@ -74,21 +109,18 @@ trait SparkParserUtils {
sb.append(highSurrogate.toChar)
sb.append(lowSurrogate.toChar)
}
charBuffer.position(charBuffer.position() + 10)
case OCTAL_CHAR_PATTERN(cp) =>
i += 1 + 8
} else if (i + 3 <= length && isThreeDigitOctalEscape(b, i)) {
// \000 style character literals.
sb.append(Integer.parseInt(cp, 8).toChar)
charBuffer.position(charBuffer.position() + 4)
case ESCAPED_CHAR_PATTERN(c) =>
// escaped character literals.
appendEscapedChar(c.charAt(0))
charBuffer.position(charBuffer.position() + 2)
case _ =>
// non-escaped character literals.
sb.append(charBuffer.get())
sb.append(Integer.parseInt(b, i, i + 3, 8).toChar)
i += 3
} else {
appendEscapedChar(cAfterBackslash, sb)
i += 1
}
}
}
sb.toString()
sb.toString
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ class ParserUtilsSuite extends SparkFunSuite {
|cd\ef"""".stripMargin) ==
"""ab
|cdef""".stripMargin)

// String with an invalid '\' as the last character.
assert(unescapeSQLString(""""abc\"""") == "abc\\")

// Strings containing invalid Unicode escapes with non-hex characters.
assert(unescapeSQLString("\"abc\\uXXXXa\"") == "abcuXXXXa")
assert(unescapeSQLString("\"abc\\uxxxxa\"") == "abcuxxxxa")
assert(unescapeSQLString("\"abc\\UXXXXXXXXa\"") == "abcUXXXXXXXXa")
assert(unescapeSQLString("\"abc\\Uxxxxxxxxa\"") == "abcUxxxxxxxxa")
// Guard against off-by-one errors in the "all chars are hex" routine:
assert(unescapeSQLString("\"abc\\uAAAXa\"") == "abcuAAAXa")

// scalastyle:on nonascii
}

Expand Down

0 comments on commit 51f1103

Please sign in to comment.