From cc073f9c9cde79e42f3501f48cbabfcea10e7ceb Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 6 Apr 2020 05:43:29 +0000 Subject: [PATCH] [SPARK-31343][SQL][TESTS] Check codegen does not fail on expressions with escape chars in string parameters ### What changes were proposed in this pull request? In the PR, I propose to add tests to check that code generation does not fail when an expression's string argument contains escape chars. The PR adds tests similar to those added by https://github.com/apache/spark/pull/20182 for `from_utc_timestamp` / `to_utc_timestamp`. ### Why are the changes needed? To prevent regressions in the future. ### Does this PR introduce any user-facing change? No ### How was this patch tested? By running the affected tests. Closes #28115 from MaxGekk/tests-arg-escape. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../CallMethodViaReflectionSuite.scala | 6 +++ .../expressions/DateExpressionsSuite.scala | 17 +++++++ .../expressions/JsonExpressionsSuite.scala | 23 ++++++++++ .../expressions/RegexpExpressionsSuite.scala | 13 ++++++ .../expressions/StringExpressionsSuite.scala | 44 +++++++++++++++++++ 5 files changed, 103 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala index 88d4d460751b6..d8f3ad24246a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala @@ -21,6 +21,7 @@ import java.sql.Timestamp import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.types.{IntegerType, StringType} /** A static class for testing 
purpose. */ @@ -101,6 +102,11 @@ class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelp checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four") } + test("escaping of class and method names") { + GenerateUnsafeProjection.generate( + CallMethodViaReflection(Seq(Literal("\"quote"), Literal("\"quote"), Literal(null))) :: Nil) + } + private def createExpr(className: String, methodName: String, args: Any*) = { CallMethodViaReflection( Literal.create(className, StringType) +: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 024276355edbc..d9b508a563a8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -265,6 +265,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal("H"), JST_OPT), "0") checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), JST_OPT), "22") + // Test escaping of format + GenerateUnsafeProjection.generate( + DateFormatClass(Literal(ts), Literal("\"quote"), JST_OPT) :: Nil) + // SPARK-28072 The codegen path should work checkEvaluation( expression = DateFormatClass( @@ -602,6 +606,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(NextDay(Literal.create(null, DateType), Literal("xx")), null) checkEvaluation( NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) + // Test escaping of dayOfWeek + GenerateUnsafeProjection.generate( + NextDay(Literal(Date.valueOf("2015-07-23")), Literal("\"quote")) :: Nil) } test("TruncDate") { @@ -625,6 +632,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testTrunc(date, null, null) testTrunc(null, "MON", 
null) testTrunc(null, null, null) + // Test escaping of format + GenerateUnsafeProjection.generate(TruncDate(Literal(0, DateType), Literal("\"quote")) :: Nil) testTrunc(Date.valueOf("2000-03-08"), "decade", Date.valueOf("2000-01-01")) testTrunc(Date.valueOf("2000-03-08"), "century", Date.valueOf("1901-01-01")) @@ -751,6 +760,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + // Test escaping of format + GenerateUnsafeProjection.generate(FromUnixTime(Literal(0L), Literal("\"quote")) :: Nil) } test("unix_timestamp") { @@ -818,6 +829,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + // Test escaping of format + GenerateUnsafeProjection.generate( + UnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil) } test("to_unix_timestamp") { @@ -893,6 +907,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + // Test escaping of format + GenerateUnsafeProjection.generate( + ToUnixTimestamp(Literal("2015-07-24"), Literal("\"quote")) :: Nil) } test("datediff") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 032e0ac61884b..90c4d8f789660 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.util._ import 
org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{PST, UTC, UTC_OPT} @@ -49,6 +50,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */ val badJson = "\u0000\u0000\u0000A\u0001AAA" + test("get_json_object escaping") { + GenerateUnsafeProjection.generate(GetJsonObject(Literal("\"quote"), Literal("\"quote")) :: Nil) + } + test("$.store.bicycle") { checkEvaluation( GetJsonObject(Literal(json), Literal("$.store.bicycle")), @@ -265,6 +270,11 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with assert(jt.eval(null).toSeq.head === expected) } + test("json_tuple escaping") { + GenerateUnsafeProjection.generate( + JsonTuple(Literal("\"quote") :: Literal("\"quote") :: Nil) :: Nil) + } + test("json_tuple - hive key 1") { checkJsonTuple( JsonTuple( @@ -396,6 +406,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with InternalRow(UTF8String.fromString("1"), null, UTF8String.fromString("1"))) } + test("from_json escaping") { + val schema = StructType(StructField("\"quote", IntegerType) :: Nil) + GenerateUnsafeProjection.generate( + JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT) :: Nil) + } + test("from_json") { val jsonData = """{"a": 1}""" val schema = StructType(StructField("a", IntegerType) :: Nil) @@ -549,6 +565,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with ) } + test("to_json escaping") { + val schema = StructType(StructField("\"quote", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), schema) + GenerateUnsafeProjection.generate( + StructsToJson(Map.empty, struct, UTC_OPT) :: Nil) + } + test("to_json - struct") { val schema = StructType(StructField("a", IntegerType) :: Nil) val struct = Literal.create(create_row(1), schema) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 712d2bc4c4736..ad9492a8d3ab1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.types.StringType /** @@ -255,6 +256,10 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val nonNullExpr = RegExpReplace(Literal("100-200"), Literal("(\\d+)"), Literal("num")) checkEvaluation(nonNullExpr, "num-num", row1) + + // Test escaping of arguments + GenerateUnsafeProjection.generate( + RegExpReplace(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil) } test("SPARK-22570: RegExpReplace should not create a lot of global variables") { @@ -305,6 +310,10 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { expr, row9, "Regex group count is 1, but the specified group index is 2") checkExceptionInExpression[IllegalArgumentException]( expr, row10, "Regex group count is 0, but the specified group index is 1") + + // Test escaping of arguments + GenerateUnsafeProjection.generate( + RegExpExtract(Literal("\"quote"), Literal("\"quote"), Literal(1)) :: Nil) } test("SPLIT") { @@ -327,6 +336,10 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { StringSplit(s1, s2, -1), Seq("aa", "bb", "cc"), row1) checkEvaluation(StringSplit(s1, s2, -1), null, row2) checkEvaluation(StringSplit(s1, s2, 
-1), null, row3) + + // Test escaping of arguments + GenerateUnsafeProjection.generate( + StringSplit(Literal("\"quote"), Literal("\"quote"), Literal(-1)) :: Nil) } test("SPARK-30759: cache initialization for literal patterns") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 4308f98d6969a..f18364d844ce1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.types._ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -155,6 +156,11 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c1 endsWith "b", false, row) checkEvaluation(c2 endsWith "b", null, row) checkEvaluation(c1 endsWith Literal.create(null, StringType), null, row) + + // Test escaping of arguments + GenerateUnsafeProjection.generate(Contains(Literal("\"quote"), Literal("\"quote")) :: Nil) + GenerateUnsafeProjection.generate(EndsWith(Literal("\"quote"), Literal("\"quote")) :: Nil) + GenerateUnsafeProjection.generate(StartsWith(Literal("\"quote"), Literal("\"quote")) :: Nil) } test("Substring") { @@ -352,6 +358,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null)) checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null) checkEvaluation(Decode(b, Literal.create(null, 
StringType)), null, create_row(null)) + + // Test escaping of charset + GenerateUnsafeProjection.generate(Encode(a, Literal("\"quote")) :: Nil) + GenerateUnsafeProjection.generate(Decode(b, Literal("\"quote")) :: Nil) } test("initcap unit test") { @@ -379,6 +389,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3) checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4) // scalastyle:on + + // Test escaping of arguments: + GenerateUnsafeProjection.generate(Levenshtein(Literal("\"quotea"), Literal("\"quoteb")) :: Nil) } test("soundex unit test") { @@ -560,6 +573,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrim(Literal("a"), Literal.create(null, StringType)), null) checkEvaluation(StringTrim(Literal.create(null, StringType), Literal("a")), null) + // Test escaping of arguments + GenerateUnsafeProjection.generate(StringTrim(Literal("\"quote"), Literal("\"quote")) :: Nil) + checkEvaluation(StringTrim(Literal("yxTomxx"), Literal("xyz")), "Tom") checkEvaluation(StringTrim(Literal("xxxbarxxx"), Literal("x")), "bar") } @@ -587,6 +603,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimLeft(Literal.create(null, StringType), Literal("a")), null) checkEvaluation(StringTrimLeft(Literal("a"), Literal.create(null, StringType)), null) + // Test escaping of arguments + GenerateUnsafeProjection.generate( + StringTrimLeft(Literal("\"quote"), Literal("\"quote")) :: Nil) + checkEvaluation(StringTrimLeft(Literal("zzzytest"), Literal("xyz")), "test") checkEvaluation(StringTrimLeft(Literal("zzzytestxyz"), Literal("xyz")), "testxyz") checkEvaluation(StringTrimLeft(Literal("xyxXxyLAST WORD"), Literal("xy")), "XxyLAST WORD") @@ -616,6 +636,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimRight(Literal("a"), 
Literal.create(null, StringType)), null) checkEvaluation(StringTrimRight(Literal.create(null, StringType), Literal("a")), null) + // Test escaping of arguments + GenerateUnsafeProjection.generate( + StringTrimRight(Literal("\"quote"), Literal("\"quote")) :: Nil) + checkEvaluation(StringTrimRight(Literal("testxxzx"), Literal("xyz")), "test") checkEvaluation(StringTrimRight(Literal("xyztestxxzx"), Literal("xyz")), "xyztest") checkEvaluation(StringTrimRight(Literal("TURNERyxXxy"), Literal("xy")), "TURNERyxX") @@ -632,6 +656,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") checkEvaluation( FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") + + // Test escaping of arguments + GenerateUnsafeProjection.generate(FormatString(Literal("\"quote"), Literal("\"quote")) :: Nil) } test("SPARK-22603: FormatString should not generate codes beyond 64KB") { @@ -662,6 +689,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringInstr(s1, s2), 1, create_row("花花世界", "花")) checkEvaluation(StringInstr(s1, s2), 0, create_row("花花世界", "小")) // scalastyle:on + + // Test escaping of arguments + GenerateUnsafeProjection.generate(StringInstr(Literal("\"quote"), Literal("\"quote")) :: Nil) } test("LOCATE") { @@ -718,6 +748,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringRPad(s1, s2, s3), null, row3) checkEvaluation(StringRPad(s1, s2, s3), null, row4) checkEvaluation(StringRPad(s1, s2, s3), null, row5) + + // Test escaping of arguments + GenerateUnsafeProjection.generate(StringLPad(Literal("\"quote"), s2, Literal("\"quote")) :: Nil) + GenerateUnsafeProjection.generate(StringRPad(Literal("\"quote"), s2, Literal("\"quote")) :: Nil) checkEvaluation(StringRPad(Literal("hi"), Literal(5)), "hi ") checkEvaluation(StringRPad(Literal("hi"), 
Literal(1)), "h") } @@ -732,6 +766,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringRepeat(Literal("hi"), Literal(-1)), "", row1) checkEvaluation(StringRepeat(s1, s2), "hihi", row1) checkEvaluation(StringRepeat(s1, s2), null, row2) + + // Test escaping of arguments + GenerateUnsafeProjection.generate(StringRepeat(Literal("\"quote"), Literal(2)) :: Nil) } test("REVERSE") { @@ -897,6 +934,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure) assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure) assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure) + + // Test escaping of arguments + GenerateUnsafeProjection.generate(ParseUrl(Seq(Literal("\"quote"), Literal("\"quote"))) :: Nil) } test("Sentences") { @@ -919,5 +959,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { answer) checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXX", "YYY"), answer) + + // Test escaping of arguments + GenerateUnsafeProjection.generate( + Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil) } }