diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala index a4ce5fb120340..7597cb1d9087d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkParserUtils.scala @@ -16,8 +16,7 @@ */ package org.apache.spark.sql.catalyst.util -import java.lang.{Long => JLong} -import java.nio.CharBuffer +import java.lang.{Long => JLong, StringBuilder => JStringBuilder} import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval @@ -26,16 +25,10 @@ import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} trait SparkParserUtils { - val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r - val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r - val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r - val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r /** Unescape backslash-escaped string enclosed by quotes. */ def unescapeSQLString(b: String): String = { - val sb = new StringBuilder(b.length()) - - def appendEscapedChar(n: Char): Unit = { + def appendEscapedChar(n: Char, sb: JStringBuilder): Unit = { n match { case '0' => sb.append('\u0000') case 'b' => sb.append('\b') @@ -50,22 +43,64 @@ trait SparkParserUtils { } } - if (b.startsWith("r") || b.startsWith("R")) { + def allCharsAreHex(s: String, start: Int, length: Int): Boolean = { + val end = start + length + var i = start + while (i < end) { + val c = s.charAt(i) + val cIsHex = (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') + if (!cIsHex) { + return false + } + i += 1 + } + true + } + + def isThreeDigitOctalEscape(s: String, start: Int): Boolean = { + val firstChar = s.charAt(start) + val secondChar = s.charAt(start + 1) + val thirdChar = s.charAt(start + 2) + (firstChar == '0' || firstChar == '1') && + (secondChar >= '0' && secondChar <= '7') && + (thirdChar >= '0' && thirdChar <= '7') + } + + val isRawString = { + val firstChar = b.charAt(0) + firstChar == 'r' || firstChar == 'R' + } + + if (isRawString) { + // Skip the 'r' or 'R' and the first and last quotations enclosing the string literal. b.substring(2, b.length - 1) + } else if (b.indexOf('\\') == -1) { + // Fast path for the common case where the string has no escaped characters, + // in which case we just skip the first and last quotations enclosing the string literal. + b.substring(1, b.length - 1) } else { + val sb = new JStringBuilder(b.length()) // Skip the first and last quotations enclosing the string literal. - val charBuffer = CharBuffer.wrap(b, 1, b.length - 1) - - while (charBuffer.remaining() > 0) { - charBuffer match { - case U16_CHAR_PATTERN(cp) => + var i = 1 + val length = b.length - 1 + while (i < length) { + val c = b.charAt(i) + if (c != '\\' || i + 1 == length) { + // Either a regular character or a backslash at the end of the string: + sb.append(c) + i += 1 + } else { + // A backslash followed by at least one character: + i += 1 + val cAfterBackslash = b.charAt(i) + if (cAfterBackslash == 'u' && i + 1 + 4 <= length && allCharsAreHex(b, i + 1, 4)) { // \u0000 style 16-bit unicode character literals. - sb.append(Integer.parseInt(cp, 16).toChar) - charBuffer.position(charBuffer.position() + 6) - case U32_CHAR_PATTERN(cp) => + sb.append(Integer.parseInt(b, i + 1, i + 1 + 4, 16).toChar) + i += 1 + 4 + } else if (cAfterBackslash == 'U' && i + 1 + 8 <= length && allCharsAreHex(b, i + 1, 8)) { // \U00000000 style 32-bit unicode character literals. // Use Long to treat codePoint as unsigned in the range of 32-bit. - val codePoint = JLong.parseLong(cp, 16) + val codePoint = JLong.parseLong(b, i + 1, i + 1 + 8, 16) if (codePoint < 0x10000) { sb.append((codePoint & 0xFFFF).toChar) } else { @@ -74,21 +109,18 @@ trait SparkParserUtils { sb.append(highSurrogate.toChar) sb.append(lowSurrogate.toChar) } - charBuffer.position(charBuffer.position() + 10) - case OCTAL_CHAR_PATTERN(cp) => + i += 1 + 8 + } else if (i + 3 <= length && isThreeDigitOctalEscape(b, i)) { // \000 style character literals. - sb.append(Integer.parseInt(cp, 8).toChar) - charBuffer.position(charBuffer.position() + 4) - case ESCAPED_CHAR_PATTERN(c) => - // escaped character literals. - appendEscapedChar(c.charAt(0)) - charBuffer.position(charBuffer.position() + 2) - case _ => - // non-escaped character literals. - sb.append(charBuffer.get()) + sb.append(Integer.parseInt(b, i, i + 3, 8).toChar) + i += 3 + } else { + appendEscapedChar(cAfterBackslash, sb) + i += 1 + } } } - sb.toString() + sb.toString } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala index d9f3067d30e51..218304db3d591 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -131,6 +131,18 @@ class ParserUtilsSuite extends SparkFunSuite { |cd\ef"""".stripMargin) == """ab |cdef""".stripMargin) + + // String with an invalid '\' as the last character. + assert(unescapeSQLString(""""abc\"""") == "abc\\") + + // Strings containing invalid Unicode escapes with non-hex characters. + assert(unescapeSQLString("\"abc\\uXXXXa\"") == "abcuXXXXa") + assert(unescapeSQLString("\"abc\\uxxxxa\"") == "abcuxxxxa") + assert(unescapeSQLString("\"abc\\UXXXXXXXXa\"") == "abcUXXXXXXXXa") + assert(unescapeSQLString("\"abc\\Uxxxxxxxxa\"") == "abcUxxxxxxxxa") + // Guard against off-by-one errors in the "all chars are hex" routine: + assert(unescapeSQLString("\"abc\\uAAAXa\"") == "abcuAAAXa") + // scalastyle:on nonascii }