From ae89d00c52ffeab937456aa444bce299f834d1c2 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 11 May 2018 18:21:32 +0200 Subject: [PATCH] support non-ascii chars --- .../expressions/MaskExpressionsUtils.java | 2 +- .../expressions/maskExpressions.scala | 184 ++++++++++++------ .../expressions/MaskExpressionsSuite.scala | 10 + 3 files changed, 138 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java index b88b12df4531a..02f3a3f0c5343 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/MaskExpressionsUtils.java @@ -73,7 +73,7 @@ public static int transformChar( */ public static int getReplacementChar(String rep, int def) { if (rep != null && rep.length() > 0) { - return rep.charAt(0); + return rep.codePointAt(0); } return def; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index 65eb9e449cf14..4ce396ea9d4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -38,18 +38,25 @@ trait MaskLike { protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName - def maskAndAppendToStringBuilderCode( + def inputStringLengthCode(inputString: String, length: String): String = { + s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());" + } + + def appendMaskedToStringBuilderCode( ctx: CodegenContext, sb: String, inputString: String, - start: String, - end: String): String = { + offset: String, + numChars: String): String = { val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") s""" - |for (${CodeGenerator.JAVA_INT} $i = $start; $i < $end; $i ++) { - | $sb.appendCodePoint($maskUtilsClassName.transformChar($inputString.charAt($i), + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint, | $upperReplacement, $lowerReplacement, | $digitReplacement, $defaultMaskedOther)); + | $offset += Character.charCount($codePoint); |} """.stripMargin } @@ -58,15 +65,51 @@ trait MaskLike { ctx: CodegenContext, sb: String, inputString: String, - start: String, - end: String): String = { + offset: String, + numChars: String): String = { val i = ctx.freshName("i") + val codePoint = ctx.freshName("codePoint") s""" - |for (${CodeGenerator.JAVA_INT} $i = $start; $i < $end; $i ++) { - | $sb.appendCodePoint($inputString.charAt($i)); + |for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) { + | ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset); + | $sb.appendCodePoint($codePoint); + | $offset += Character.charCount($codePoint); |} """.stripMargin } + + def appendMaskedToStringBuffer( + sb: StringBuffer, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(transformChar( + codePoint, + upperReplacement, + lowerReplacement, + digitReplacement, + defaultMaskedOther)) + offset += Character.charCount(codePoint) + } + offset + } + + def appendUnchangedToStringBuffer( + sb: StringBuffer, + inputString: String, + startOffset: Int, + numChars: Int): Int = { + var offset = startOffset + (1 to numChars) foreach { _ => + val codePoint = inputString.codePointAt(offset) + sb.appendCodePoint(codePoint) + offset += Character.charCount(codePoint) + } + offset + } } trait MaskLikeWithN extends MaskLike { @@ -128,21 +171,27 @@ case class Mask(child: Expression, upper: String, lower: String, digit: String) this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit)) override def nullSafeEval(input: Any): Any = { - val res = input.asInstanceOf[UTF8String].toString.map(transformChar( - _, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar) - UTF8String.fromString(res) + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val sb = new StringBuffer(length) + appendMaskedToStringBuffer(sb, str, 0, length) + UTF8String.fromString(sb.toString) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (input: String) => { val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") val inputString = ctx.freshName("inputString") s""" |String $inputString = $input.toString(); - |StringBuilder $sb = new StringBuilder($inputString.length()); - |${maskAndAppendToStringBuilderCode(ctx, sb, inputString, "0", s"$inputString.length()")} + |${inputStringLengthCode(inputString, length)} + |StringBuilder $sb = new StringBuilder($length); + |${CodeGenerator.JAVA_INT} $offset = 0; + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)} |${ev.value} = UTF8String.fromString($sb.toString()); - |""".stripMargin + """.stripMargin }) } @@ -197,26 +246,31 @@ case class MaskFirstN( extractReplacement(digit)) override def nullSafeEval(input: Any): Any = { - val inputString = input.asInstanceOf[UTF8String].toString - val (firstN, others) = inputString.splitAt(charCount) - val transformed = firstN.map(transformChar( - _, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar) - UTF8String.fromString(transformed + others) + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount > length) length else charCount + val sb = new StringBuffer(length) + val offset = appendMaskedToStringBuffer(sb, str, 0, endOfMask) + appendUnchangedToStringBuffer(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (input: String) => { val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") val inputString = ctx.freshName("inputString") val endOfMask = ctx.freshName("endOfMask") s""" |String $inputString = $input.toString(); - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $inputString.length() ? - | $inputString.length() : $charCount; - |StringBuilder $sb = new StringBuilder($inputString.length()); - |${maskAndAppendToStringBuilderCode(ctx, sb, inputString, "0", endOfMask)} + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, endOfMask, s"$inputString.length()")} + ctx, sb, inputString, offset, s"$length - $endOfMask")} |${ev.value} = UTF8String.fromString($sb.toString()); |""".stripMargin }) @@ -275,26 +329,32 @@ case class MaskLastN( extractReplacement(digit)) override def nullSafeEval(input: Any): Any = { - val inputString = input.asInstanceOf[UTF8String].toString - val (others, lastN) = inputString.splitAt(inputString.length - charCount) - val transformed = lastN.map(transformChar( - _, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar) - UTF8String.fromString(others + transformed) + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount >= length) 0 else length - charCount + val sb = new StringBuffer(length) + val offset = appendUnchangedToStringBuffer(sb, str, 0, startOfMask) + appendMaskedToStringBuffer(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (input: String) => { val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") val inputString = ctx.freshName("inputString") val startOfMask = ctx.freshName("startOfMask") s""" |String $inputString = $input.toString(); - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $inputString.length() ? - | 0 : $inputString.length() - $charCount; - |StringBuilder $sb = new StringBuilder($inputString.length()); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, "0", startOfMask)} - |${maskAndAppendToStringBuilderCode( - ctx, sb, inputString, startOfMask, s"$inputString.length()")} + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ? + | 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} |${ev.value} = UTF8String.fromString($sb.toString()); |""".stripMargin }) @@ -353,26 +413,31 @@ case class MaskShowFirstN( extractReplacement(digit)) override def nullSafeEval(input: Any): Any = { - val inputString = input.asInstanceOf[UTF8String].toString - val (firstN, others) = inputString.splitAt(charCount) - val transformed = others.map(transformChar( - _, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar) - UTF8String.fromString(firstN + transformed) + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val startOfMask = if (charCount > length) length else charCount + val sb = new StringBuffer(length) + val offset = appendUnchangedToStringBuffer(sb, str, 0, startOfMask) + appendMaskedToStringBuffer(sb, str, offset, length - startOfMask) + UTF8String.fromString(sb.toString) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (input: String) => { val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") val inputString = ctx.freshName("inputString") val startOfMask = ctx.freshName("startOfMask") s""" |String $inputString = $input.toString(); - |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $inputString.length() ? - | $inputString.length() : $charCount; - |StringBuilder $sb = new StringBuilder($inputString.length()); - |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, "0", startOfMask)} - |${maskAndAppendToStringBuilderCode( - ctx, sb, inputString, startOfMask, s"$inputString.length()")} + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)} + |${appendMaskedToStringBuilderCode( + ctx, sb, inputString, offset, s"$length - $startOfMask")} |${ev.value} = UTF8String.fromString($sb.toString()); |""".stripMargin }) @@ -431,26 +496,31 @@ case class MaskShowLastN( extractReplacement(digit)) override def nullSafeEval(input: Any): Any = { - val inputString = input.asInstanceOf[UTF8String].toString - val (others, lastN) = inputString.splitAt(inputString.length - charCount) - val transformed = others.map(transformChar( - _, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar) - UTF8String.fromString(transformed + lastN) + val str = input.asInstanceOf[UTF8String].toString + val length = str.codePointCount(0, str.length()) + val endOfMask = if (charCount >= length) 0 else length - charCount + val sb = new StringBuffer(length) + val offset = appendMaskedToStringBuffer(sb, str, 0, endOfMask) + appendUnchangedToStringBuffer(sb, str, offset, length - endOfMask) + UTF8String.fromString(sb.toString) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (input: String) => { val sb = ctx.freshName("sb") + val length = ctx.freshName("length") + val offset = ctx.freshName("offset") val inputString = ctx.freshName("inputString") val endOfMask = ctx.freshName("endOfMask") s""" |String $inputString = $input.toString(); - |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $inputString.length() ? - | 0 : $inputString.length() - $charCount; - |StringBuilder $sb = new StringBuilder($inputString.length()); - |${maskAndAppendToStringBuilderCode(ctx, sb, inputString, "0", endOfMask)} + |${inputStringLengthCode(inputString, length)} + |${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount; + |${CodeGenerator.JAVA_INT} $offset = 0; + |StringBuilder $sb = new StringBuilder($length); + |${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)} |${appendUnchangedToStringBuilderCode( - ctx, sb, inputString, endOfMask, s"$inputString.length()")} + ctx, sb, inputString, offset, s"$length - $endOfMask")} |${ev.value} = UTF8String.fromString($sb.toString()); |""".stripMargin }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala index 088c89d8335df..45ed081af35ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MaskExpressionsSuite.scala @@ -46,6 +46,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn") // scalastyle:off nonascii checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200") + checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal("あ"), Literal("𡈽")), + "あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋") // scalastyle:on nonascii intercept[AnalysisException] { checkEvaluation(new Mask(Literal(""), Literal(1)), "") @@ -89,6 +91,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "abcd-EFGH-8765-4321") // scalastyle:off nonascii checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xxxxo World") // scalastyle:on nonascii } @@ -131,6 +135,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "abcd-EFGH-8765-4321") // scalastyle:off nonascii checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Hxxxx Xxxxx") // scalastyle:on nonascii } @@ -172,6 +178,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "xxxx-XXXX-nnnn-nnnn") // scalastyle:off nonascii checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200") + checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Hellx Xxxxx") // scalastyle:on nonascii } @@ -212,6 +220,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { "xxxx-XXXX-nnnn-nnnn") // scalastyle:off nonascii checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U") + checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)), + "あ, 𠀋, Xello World") // scalastyle:on nonascii }