Skip to content

Commit

Permalink
[SPARK-31343][SQL][TESTS] Check codegen does not fail on expressions …
Browse files Browse the repository at this point in the history
…with escape chars in string parameters

### What changes were proposed in this pull request?
In this PR, I propose adding tests to check that code generation doesn't fail when an expression's string argument contains escape chars. The PR adds tests similar to those added by apache#20182 for `from_utc_timestamp` / `to_utc_timestamp`.

### Why are the changes needed?
To prevent regressions in the future.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
By running the affected tests

Closes apache#28115 from MaxGekk/tests-arg-escape.

Authored-by: Maxim Gekk <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
MaxGekk authored and Seongjin Cho committed Apr 14, 2020
1 parent 98fcdf6 commit cc073f9
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.sql.Timestamp

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StringType}

/** A static class for testing purpose. */
Expand Down Expand Up @@ -101,6 +102,11 @@ class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelp
checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four")
}

test("escaping of class and method names") {
  // Only code generation is exercised here: the test passes as long as
  // generate() completes without throwing for names that contain a double quote.
  def quoted = Literal("\"quote")
  val reflectiveCall = CallMethodViaReflection(Seq(quoted, quoted, Literal(null)))
  GenerateUnsafeProjection.generate(reflectiveCall :: Nil)
}

private def createExpr(className: String, methodName: String, args: Any*) = {
CallMethodViaReflection(
Literal.create(className, StringType) +:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal("H"), JST_OPT), "0")
checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), JST_OPT), "22")

// Test escaping of format
GenerateUnsafeProjection.generate(
DateFormatClass(Literal(ts), Literal("\"quote"), JST_OPT) :: Nil)

// SPARK-28072 The codegen path should work
checkEvaluation(
expression = DateFormatClass(
Expand Down Expand Up @@ -602,6 +606,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(NextDay(Literal.create(null, DateType), Literal("xx")), null)
checkEvaluation(
NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
// Test escaping of dayOfWeek
GenerateUnsafeProjection.generate(
NextDay(Literal(Date.valueOf("2015-07-23")), Literal("\"quote")) :: Nil)
}

test("TruncDate") {
Expand All @@ -625,6 +632,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
testTrunc(date, null, null)
testTrunc(null, "MON", null)
testTrunc(null, null, null)
// Test escaping of format
GenerateUnsafeProjection.generate(TruncDate(Literal(0, DateType), Literal("\"quote")) :: Nil)

testTrunc(Date.valueOf("2000-03-08"), "decade", Date.valueOf("2000-01-01"))
testTrunc(Date.valueOf("2000-03-08"), "century", Date.valueOf("1901-01-01"))
Expand Down Expand Up @@ -751,6 +760,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
// Test escaping of format
GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote")) :: Nil)
}

test("unix_timestamp") {
Expand Down Expand Up @@ -818,6 +829,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
// Test escaping of format
GenerateUnsafeProjection.generate(
UnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil)
}

test("to_unix_timestamp") {
Expand Down Expand Up @@ -893,6 +907,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}
// Test escaping of format
GenerateUnsafeProjection.generate(
ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil)
}

test("datediff") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.scalatest.exceptions.TestFailedException
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{PST, UTC, UTC_OPT}
Expand All @@ -49,6 +50,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */
val badJson = "\u0000\u0000\u0000A\u0001AAA"

test("get_json_object escaping") {
  // Both the JSON input and the path literal carry a double quote; the test
  // passes as long as generating the projection does not throw.
  def quoted = Literal("\"quote")
  val lookup = GetJsonObject(quoted, quoted)
  GenerateUnsafeProjection.generate(lookup :: Nil)
}

test("$.store.bicycle") {
checkEvaluation(
GetJsonObject(Literal(json), Literal("$.store.bicycle")),
Expand Down Expand Up @@ -265,6 +270,11 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
assert(jt.eval(null).toSeq.head === expected)
}

test("json_tuple escaping") {
  // Both children of JsonTuple are literals containing a double quote;
  // the test passes as long as generating the projection does not throw.
  def quoted = Literal("\"quote")
  val tuple = JsonTuple(List(quoted, quoted))
  GenerateUnsafeProjection.generate(tuple :: Nil)
}

test("json_tuple - hive key 1") {
checkJsonTuple(
JsonTuple(
Expand Down Expand Up @@ -396,6 +406,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
InternalRow(UTF8String.fromString("1"), null, UTF8String.fromString("1")))
}

test("from_json escaping") {
  // The double quote appears both in the schema's field name and in the input
  // literal; the test passes as long as generating the projection does not throw.
  val quotedField = StructField("\"quote", IntegerType)
  val parser = JsonToStructs(
    StructType(quotedField :: Nil), Map.empty, Literal("\"quote"), UTC_OPT)
  GenerateUnsafeProjection.generate(parser :: Nil)
}

test("from_json") {
val jsonData = """{"a": 1}"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
Expand Down Expand Up @@ -549,6 +565,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
)
}

test("to_json escaping") {
  // The struct's only field name contains a double quote; the test passes as
  // long as generating the projection does not throw.
  val quotedSchema = StructType(StructField("\"quote", IntegerType) :: Nil)
  val input = Literal.create(create_row(1), quotedSchema)
  val serializer = StructsToJson(Map.empty, input, UTC_OPT)
  GenerateUnsafeProjection.generate(serializer :: Nil)
}

test("to_json - struct") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
val struct = Literal.create(create_row(1), schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.StringType

/**
Expand Down Expand Up @@ -255,6 +256,10 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

val nonNullExpr = RegExpReplace(Literal("100-200"), Literal("(\\d+)"), Literal("num"))
checkEvaluation(nonNullExpr, "num-num", row1)

// Test escaping of arguments
GenerateUnsafeProjection.generate(
RegExpReplace(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)
}

test("SPARK-22570: RegExpReplace should not create a lot of global variables") {
Expand Down Expand Up @@ -305,6 +310,10 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
expr, row9, "Regex group count is 1, but the specified group index is 2")
checkExceptionInExpression[IllegalArgumentException](
expr, row10, "Regex group count is 0, but the specified group index is 1")

// Test escaping of arguments
GenerateUnsafeProjection.generate(
RegExpExtract(Literal("\"quote"), Literal("\"quote"), Literal(1)) :: Nil)
}

test("SPLIT") {
Expand All @@ -327,6 +336,10 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
StringSplit(s1, s2, -1), Seq("aa", "bb", "cc"), row1)
checkEvaluation(StringSplit(s1, s2, -1), null, row2)
checkEvaluation(StringSplit(s1, s2, -1), null, row3)

// Test escaping of arguments
GenerateUnsafeProjection.generate(
StringSplit(Literal("\"quote"), Literal("\"quote"), Literal(-1)) :: Nil)
}

test("SPARK-30759: cache initialization for literal patterns") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types._

class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -155,6 +156,11 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(c1 endsWith "b", false, row)
checkEvaluation(c2 endsWith "b", null, row)
checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row)

// Test escaping of arguments
GenerateUnsafeProjection.generate(Contains(Literal("\"quote"), Literal("\"quote")) :: Nil)
GenerateUnsafeProjection.generate(EndsWith(Literal("\"quote"), Literal("\"quote")) :: Nil)
GenerateUnsafeProjection.generate(StartsWith(Literal("\"quote"), Literal("\"quote")) :: Nil)
}

test("Substring") {
Expand Down Expand Up @@ -352,6 +358,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null))
checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null)
checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))

// Test escaping of charset
GenerateUnsafeProjection.generate(Encode(a, Literal("\"quote")) :: Nil)
GenerateUnsafeProjection.generate(Decode(b, Literal("\"quote")) :: Nil)
}

test("initcap unit test") {
Expand Down Expand Up @@ -379,6 +389,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3)
checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4)
// scalastyle:on

// Test escaping of arguments:
GenerateUnsafeProjection.generate(Levenshtein(Literal("\"quotea"), Literal("\"quoteb")) :: Nil)
}

test("soundex unit test") {
Expand Down Expand Up @@ -560,6 +573,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringTrim(Literal("a"), Literal.create(null, StringType)), null)
checkEvaluation(StringTrim(Literal.create(null, StringType), Literal("a")), null)

// Test escaping of arguments
GenerateUnsafeProjection.generate(StringTrim(Literal("\"quote"), Literal("\"quote")) :: Nil)

checkEvaluation(StringTrim(Literal("yxTomxx"), Literal("xyz")), "Tom")
checkEvaluation(StringTrim(Literal("xxxbarxxx"), Literal("x")), "bar")
}
Expand Down Expand Up @@ -587,6 +603,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringTrimLeft(Literal.create(null, StringType), Literal("a")), null)
checkEvaluation(StringTrimLeft(Literal("a"), Literal.create(null, StringType)), null)

// Test escaping of arguments
GenerateUnsafeProjection.generate(
StringTrimLeft(Literal("\"quote"), Literal("\"quote")) :: Nil)

checkEvaluation(StringTrimLeft(Literal("zzzytest"), Literal("xyz")), "test")
checkEvaluation(StringTrimLeft(Literal("zzzytestxyz"), Literal("xyz")), "testxyz")
checkEvaluation(StringTrimLeft(Literal("xyxXxyLAST WORD"), Literal("xy")), "XxyLAST WORD")
Expand Down Expand Up @@ -616,6 +636,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringTrimRight(Literal("a"), Literal.create(null, StringType)), null)
checkEvaluation(StringTrimRight(Literal.create(null, StringType), Literal("a")), null)

// Test escaping of arguments
GenerateUnsafeProjection.generate(
StringTrimRight(Literal("\"quote"), Literal("\"quote")) :: Nil)

checkEvaluation(StringTrimRight(Literal("testxxzx"), Literal("xyz")), "test")
checkEvaluation(StringTrimRight(Literal("xyztestxxzx"), Literal("xyz")), "xyztest")
checkEvaluation(StringTrimRight(Literal("TURNERyxXxy"), Literal("xy")), "TURNERyxX")
Expand All @@ -632,6 +656,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
checkEvaluation(
FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")

// Test escaping of arguments
GenerateUnsafeProjection.generate(FormatString(Literal("\"quote"), Literal("\"quote")) :: Nil)
}

test("SPARK-22603: FormatString should not generate codes beyond 64KB") {
Expand Down Expand Up @@ -662,6 +689,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringInstr(s1, s2), 1, create_row("花花世界", ""))
checkEvaluation(StringInstr(s1, s2), 0, create_row("花花世界", ""))
// scalastyle:on

// Test escaping of arguments
GenerateUnsafeProjection.generate(StringInstr(Literal("\"quote"), Literal("\"quote")) :: Nil)
}

test("LOCATE") {
Expand Down Expand Up @@ -718,6 +748,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringRPad(s1, s2, s3), null, row3)
checkEvaluation(StringRPad(s1, s2, s3), null, row4)
checkEvaluation(StringRPad(s1, s2, s3), null, row5)

// Test escaping of arguments
GenerateUnsafeProjection.generate(StringLPad(Literal("\"quote"), s2, Literal("\"quote")) :: Nil)
GenerateUnsafeProjection.generate(StringRPad(Literal("\"quote"), s2, Literal("\"quote")) :: Nil)
checkEvaluation(StringRPad(Literal("hi"), Literal(5)), "hi ")
checkEvaluation(StringRPad(Literal("hi"), Literal(1)), "h")
}
Expand All @@ -732,6 +766,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringRepeat(Literal("hi"), Literal(-1)), "", row1)
checkEvaluation(StringRepeat(s1, s2), "hihi", row1)
checkEvaluation(StringRepeat(s1, s2), null, row2)

// Test escaping of arguments
GenerateUnsafeProjection.generate(StringRepeat(Literal("\"quote"), Literal(2)) :: Nil)
}

test("REVERSE") {
Expand Down Expand Up @@ -897,6 +934,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure)
assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure)

// Test escaping of arguments
GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil)
}

test("Sentences") {
Expand All @@ -919,5 +959,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
answer)
checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXX", "YYY"),
answer)

// Test escaping of arguments
GenerateUnsafeProjection.generate(
Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)
}
}

0 comments on commit cc073f9

Please sign in to comment.