Skip to content

Commit

Permalink
support non-ascii chars
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed May 11, 2018
1 parent 3edd243 commit ae89d00
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public static int transformChar(
*/
public static int getReplacementChar(String rep, int def) {
if (rep != null && rep.length() > 0) {
return rep.charAt(0);
return rep.codePointAt(0);
}
return def;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,25 @@ trait MaskLike {

protected val maskUtilsClassName: String = classOf[MaskExpressionsUtils].getName

def maskAndAppendToStringBuilderCode(
def inputStringLengthCode(inputString: String, length: String): String = {
s"${CodeGenerator.JAVA_INT} $length = $inputString.codePointCount(0, $inputString.length());"
}

def appendMaskedToStringBuilderCode(
ctx: CodegenContext,
sb: String,
inputString: String,
start: String,
end: String): String = {
offset: String,
numChars: String): String = {
val i = ctx.freshName("i")
val codePoint = ctx.freshName("codePoint")
s"""
|for (${CodeGenerator.JAVA_INT} $i = $start; $i < $end; $i ++) {
| $sb.appendCodePoint($maskUtilsClassName.transformChar($inputString.charAt($i),
|for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) {
| ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset);
| $sb.appendCodePoint($maskUtilsClassName.transformChar($codePoint,
| $upperReplacement, $lowerReplacement,
| $digitReplacement, $defaultMaskedOther));
| $offset += Character.charCount($codePoint);
|}
""".stripMargin
}
Expand All @@ -58,15 +65,51 @@ trait MaskLike {
ctx: CodegenContext,
sb: String,
inputString: String,
start: String,
end: String): String = {
offset: String,
numChars: String): String = {
val i = ctx.freshName("i")
val codePoint = ctx.freshName("codePoint")
s"""
|for (${CodeGenerator.JAVA_INT} $i = $start; $i < $end; $i ++) {
| $sb.appendCodePoint($inputString.charAt($i));
|for (${CodeGenerator.JAVA_INT} $i = 0; $i < $numChars; $i++) {
| ${CodeGenerator.JAVA_INT} $codePoint = $inputString.codePointAt($offset);
| $sb.appendCodePoint($codePoint);
| $offset += Character.charCount($codePoint);
|}
""".stripMargin
}

def appendMaskedToStringBuffer(
sb: StringBuffer,
inputString: String,
startOffset: Int,
numChars: Int): Int = {
var offset = startOffset
(1 to numChars) foreach { _ =>
val codePoint = inputString.codePointAt(offset)
sb.appendCodePoint(transformChar(
codePoint,
upperReplacement,
lowerReplacement,
digitReplacement,
defaultMaskedOther))
offset += Character.charCount(codePoint)
}
offset
}

def appendUnchangedToStringBuffer(
sb: StringBuffer,
inputString: String,
startOffset: Int,
numChars: Int): Int = {
var offset = startOffset
(1 to numChars) foreach { _ =>
val codePoint = inputString.codePointAt(offset)
sb.appendCodePoint(codePoint)
offset += Character.charCount(codePoint)
}
offset
}
}

trait MaskLikeWithN extends MaskLike {
Expand Down Expand Up @@ -128,21 +171,27 @@ case class Mask(child: Expression, upper: String, lower: String, digit: String)
this(child, extractReplacement(upper), extractReplacement(lower), extractReplacement(digit))

override def nullSafeEval(input: Any): Any = {
val res = input.asInstanceOf[UTF8String].toString.map(transformChar(
_, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar)
UTF8String.fromString(res)
val str = input.asInstanceOf[UTF8String].toString
val length = str.codePointCount(0, str.length())
val sb = new StringBuffer(length)
appendMaskedToStringBuffer(sb, str, 0, length)
UTF8String.fromString(sb.toString)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (input: String) => {
val sb = ctx.freshName("sb")
val length = ctx.freshName("length")
val offset = ctx.freshName("offset")
val inputString = ctx.freshName("inputString")
s"""
|String $inputString = $input.toString();
|StringBuilder $sb = new StringBuilder($inputString.length());
|${maskAndAppendToStringBuilderCode(ctx, sb, inputString, "0", s"$inputString.length()")}
|${inputStringLengthCode(inputString, length)}
|StringBuilder $sb = new StringBuilder($length);
|${CodeGenerator.JAVA_INT} $offset = 0;
|${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, length)}
|${ev.value} = UTF8String.fromString($sb.toString());
|""".stripMargin
""".stripMargin
})
}

Expand Down Expand Up @@ -197,26 +246,31 @@ case class MaskFirstN(
extractReplacement(digit))

override def nullSafeEval(input: Any): Any = {
val inputString = input.asInstanceOf[UTF8String].toString
val (firstN, others) = inputString.splitAt(charCount)
val transformed = firstN.map(transformChar(
_, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar)
UTF8String.fromString(transformed + others)
val str = input.asInstanceOf[UTF8String].toString
val length = str.codePointCount(0, str.length())
val endOfMask = if (charCount > length) length else charCount
val sb = new StringBuffer(length)
val offset = appendMaskedToStringBuffer(sb, str, 0, endOfMask)
appendUnchangedToStringBuffer(sb, str, offset, length - endOfMask)
UTF8String.fromString(sb.toString)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (input: String) => {
val sb = ctx.freshName("sb")
val length = ctx.freshName("length")
val offset = ctx.freshName("offset")
val inputString = ctx.freshName("inputString")
val endOfMask = ctx.freshName("endOfMask")
s"""
|String $inputString = $input.toString();
|${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $inputString.length() ?
| $inputString.length() : $charCount;
|StringBuilder $sb = new StringBuilder($inputString.length());
|${maskAndAppendToStringBuilderCode(ctx, sb, inputString, "0", endOfMask)}
|${inputStringLengthCode(inputString, length)}
|${CodeGenerator.JAVA_INT} $endOfMask = $charCount > $length ? $length : $charCount;
|${CodeGenerator.JAVA_INT} $offset = 0;
|StringBuilder $sb = new StringBuilder($length);
|${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)}
|${appendUnchangedToStringBuilderCode(
ctx, sb, inputString, endOfMask, s"$inputString.length()")}
ctx, sb, inputString, offset, s"$length - $endOfMask")}
|${ev.value} = UTF8String.fromString($sb.toString());
|""".stripMargin
})
Expand Down Expand Up @@ -275,26 +329,32 @@ case class MaskLastN(
extractReplacement(digit))

override def nullSafeEval(input: Any): Any = {
val inputString = input.asInstanceOf[UTF8String].toString
val (others, lastN) = inputString.splitAt(inputString.length - charCount)
val transformed = lastN.map(transformChar(
_, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar)
UTF8String.fromString(others + transformed)
val str = input.asInstanceOf[UTF8String].toString
val length = str.codePointCount(0, str.length())
val startOfMask = if (charCount >= length) 0 else length - charCount
val sb = new StringBuffer(length)
val offset = appendUnchangedToStringBuffer(sb, str, 0, startOfMask)
appendMaskedToStringBuffer(sb, str, offset, length - startOfMask)
UTF8String.fromString(sb.toString)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (input: String) => {
val sb = ctx.freshName("sb")
val length = ctx.freshName("length")
val offset = ctx.freshName("offset")
val inputString = ctx.freshName("inputString")
val startOfMask = ctx.freshName("startOfMask")
s"""
|String $inputString = $input.toString();
|${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $inputString.length() ?
| 0 : $inputString.length() - $charCount;
|StringBuilder $sb = new StringBuilder($inputString.length());
|${appendUnchangedToStringBuilderCode(ctx, sb, inputString, "0", startOfMask)}
|${maskAndAppendToStringBuilderCode(
ctx, sb, inputString, startOfMask, s"$inputString.length()")}
|${inputStringLengthCode(inputString, length)}
|${CodeGenerator.JAVA_INT} $startOfMask = $charCount >= $length ?
| 0 : $length - $charCount;
|${CodeGenerator.JAVA_INT} $offset = 0;
|StringBuilder $sb = new StringBuilder($length);
|${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)}
|${appendMaskedToStringBuilderCode(
ctx, sb, inputString, offset, s"$length - $startOfMask")}
|${ev.value} = UTF8String.fromString($sb.toString());
|""".stripMargin
})
Expand Down Expand Up @@ -353,26 +413,31 @@ case class MaskShowFirstN(
extractReplacement(digit))

override def nullSafeEval(input: Any): Any = {
val inputString = input.asInstanceOf[UTF8String].toString
val (firstN, others) = inputString.splitAt(charCount)
val transformed = others.map(transformChar(
_, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar)
UTF8String.fromString(firstN + transformed)
val str = input.asInstanceOf[UTF8String].toString
val length = str.codePointCount(0, str.length())
val startOfMask = if (charCount > length) length else charCount
val sb = new StringBuffer(length)
val offset = appendUnchangedToStringBuffer(sb, str, 0, startOfMask)
appendMaskedToStringBuffer(sb, str, offset, length - startOfMask)
UTF8String.fromString(sb.toString)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (input: String) => {
val sb = ctx.freshName("sb")
val length = ctx.freshName("length")
val offset = ctx.freshName("offset")
val inputString = ctx.freshName("inputString")
val startOfMask = ctx.freshName("startOfMask")
s"""
|String $inputString = $input.toString();
|${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $inputString.length() ?
| $inputString.length() : $charCount;
|StringBuilder $sb = new StringBuilder($inputString.length());
|${appendUnchangedToStringBuilderCode(ctx, sb, inputString, "0", startOfMask)}
|${maskAndAppendToStringBuilderCode(
ctx, sb, inputString, startOfMask, s"$inputString.length()")}
|${inputStringLengthCode(inputString, length)}
|${CodeGenerator.JAVA_INT} $startOfMask = $charCount > $length ? $length : $charCount;
|${CodeGenerator.JAVA_INT} $offset = 0;
|StringBuilder $sb = new StringBuilder($length);
|${appendUnchangedToStringBuilderCode(ctx, sb, inputString, offset, startOfMask)}
|${appendMaskedToStringBuilderCode(
ctx, sb, inputString, offset, s"$length - $startOfMask")}
|${ev.value} = UTF8String.fromString($sb.toString());
|""".stripMargin
})
Expand Down Expand Up @@ -431,26 +496,31 @@ case class MaskShowLastN(
extractReplacement(digit))

override def nullSafeEval(input: Any): Any = {
val inputString = input.asInstanceOf[UTF8String].toString
val (others, lastN) = inputString.splitAt(inputString.length - charCount)
val transformed = others.map(transformChar(
_, upperReplacement, lowerReplacement, digitReplacement, defaultMaskedOther).toChar)
UTF8String.fromString(transformed + lastN)
val str = input.asInstanceOf[UTF8String].toString
val length = str.codePointCount(0, str.length())
val endOfMask = if (charCount >= length) 0 else length - charCount
val sb = new StringBuffer(length)
val offset = appendMaskedToStringBuffer(sb, str, 0, endOfMask)
appendUnchangedToStringBuffer(sb, str, offset, length - endOfMask)
UTF8String.fromString(sb.toString)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (input: String) => {
val sb = ctx.freshName("sb")
val length = ctx.freshName("length")
val offset = ctx.freshName("offset")
val inputString = ctx.freshName("inputString")
val endOfMask = ctx.freshName("endOfMask")
s"""
|String $inputString = $input.toString();
|${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $inputString.length() ?
| 0 : $inputString.length() - $charCount;
|StringBuilder $sb = new StringBuilder($inputString.length());
|${maskAndAppendToStringBuilderCode(ctx, sb, inputString, "0", endOfMask)}
|${inputStringLengthCode(inputString, length)}
|${CodeGenerator.JAVA_INT} $endOfMask = $charCount >= $length ? 0 : $length - $charCount;
|${CodeGenerator.JAVA_INT} $offset = 0;
|StringBuilder $sb = new StringBuilder($length);
|${appendMaskedToStringBuilderCode(ctx, sb, inputString, offset, endOfMask)}
|${appendUnchangedToStringBuilderCode(
ctx, sb, inputString, endOfMask, s"$inputString.length()")}
ctx, sb, inputString, offset, s"$length - $endOfMask")}
|${ev.value} = UTF8String.fromString($sb.toString());
|""".stripMargin
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Mask(Literal("abcd-EFGH-8765-4321"), "", "", ""), "xxxx-XXXX-nnnn-nnnn")
// scalastyle:off nonascii
checkEvaluation(Mask(Literal("Ul9U"), "\u2200", null, null), "\u2200xn\u2200")
checkEvaluation(new Mask(Literal("Hello World, こんにちは, 𠀋"), Literal(""), Literal("𡈽")),
"あ𡈽𡈽𡈽𡈽 あ𡈽𡈽𡈽𡈽, こんにちは, 𠀋")
// scalastyle:on nonascii
intercept[AnalysisException] {
checkEvaluation(new Mask(Literal(""), Literal(1)), "")
Expand Down Expand Up @@ -89,6 +91,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"abcd-EFGH-8765-4321")
// scalastyle:off nonascii
checkEvaluation(MaskFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U")
checkEvaluation(new MaskFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)),
"あ, 𠀋, Xxxxo World")
// scalastyle:on nonascii
}

Expand Down Expand Up @@ -131,6 +135,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"abcd-EFGH-8765-4321")
// scalastyle:off nonascii
checkEvaluation(MaskLastN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200")
checkEvaluation(new MaskLastN(Literal("あ, 𠀋, Hello World"), Literal(10)),
"あ, 𠀋, Hxxxx Xxxxx")
// scalastyle:on nonascii
}

Expand Down Expand Up @@ -172,6 +178,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"xxxx-XXXX-nnnn-nnnn")
// scalastyle:off nonascii
checkEvaluation(MaskShowFirstN(Literal("Ul9U"), 2, "\u2200", null, null), "Uln\u2200")
checkEvaluation(new MaskShowFirstN(Literal("あ, 𠀋, Hello World"), Literal(10)),
"あ, 𠀋, Hellx Xxxxx")
// scalastyle:on nonascii
}

Expand Down Expand Up @@ -212,6 +220,8 @@ class MaskExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"xxxx-XXXX-nnnn-nnnn")
// scalastyle:off nonascii
checkEvaluation(MaskShowLastN(Literal("Ul9U"), 2, "\u2200", null, null), "\u2200x9U")
checkEvaluation(new MaskShowLastN(Literal("あ, 𠀋, Hello World"), Literal(10)),
"あ, 𠀋, Xello World")
// scalastyle:on nonascii
}

Expand Down

0 comments on commit ae89d00

Please sign in to comment.