diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 2e4f9873de280..30b7305a0af9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -438,7 +438,7 @@ abstract class RegExpExtractBase override def children: Seq[Expression] = subject :: regexp :: idx :: Nil protected def getLastMatcher(s: Any, p: Any): Matcher = { - if (!p.equals(lastRegex)) { + if (p != lastRegex) { // regex value changed lastRegex = p.asInstanceOf[UTF8String].clone() pattern = Pattern.compile(lastRegex.toString) @@ -472,10 +472,9 @@ abstract class RegExpExtractBase if the config is enabled, the `regexp` that can match "\abc" is "^\abc$". * idx - an integer expression that representing the group index. The regex maybe contains multiple groups. `idx` indicates which regex group to extract. The group index should - be non-negative. If `idx` is not specified, the default group index value is 1. The - `idx` parameter is the Java regex Matcher group() method index. See - docs/api/java/util/regex/Matcher.html for more information on the `idx` or Java regex - group() method. + be non-negative. The minimum value of `idx` is 0, which means matching the entire + regular expression. If `idx` is not specified, the default group index value is 1. The + `idx` parameter is the Java regex Matcher group() method index. """, examples = """ Examples: @@ -521,6 +520,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio } else { "" } + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" if (!$regexp.equals($termLastRegex)) { @@ -572,17 +572,16 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio if the config is enabled, the `regexp` that can match "\abc" is "^\abc$". * idx - an integer expression that representing the group index. The regex may contains multiple groups. `idx` indicates which regex group to extract. The group index should - be non-negative. If `idx` is not specified, the default group index value is 1. The - `idx` parameter is the Java regex Matcher group() method index. See - docs/api/java/util/regex/Matcher.html for more information on the `idx` or Java regex - group() method. + be non-negative. The minimum value of `idx` is 0, which means matching the entire + regular expression. If `idx` is not specified, the default group index value is 1. The + `idx` parameter is the Java regex Matcher group() method index. """, examples = """ Examples: > SELECT _FUNC_('100-200, 300-400', '(\\d+)-(\\d+)', 1); ["100","300"] """, - since = "3.0.0") + since = "3.1.0") case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expression) extends RegExpExtractBase { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b9abb85fbdce1..653fe5bf7c9b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2495,7 +2495,7 @@ object functions { * will be thrown. * * @group string_funcs - * @since 3.0.0 + * @since 3.1.0 */ def regexp_extract_all(e: Column, exp: String, groupIdx: Int): Column = withExpr { RegExpExtractAll(e.expr, lit(exp).expr, lit(groupIdx).expr) diff --git a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql index 573f6edc40897..1f74d7ef4d1be 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql @@ -16,8 +16,10 @@ SELECT regexp_extract_all('1a 2b 14m', '\\d+'); SELECT regexp_extract_all('1a 2b 14m', '\\d+', 0); SELECT regexp_extract_all('1a 2b 14m', '\\d+', 1); SELECT regexp_extract_all('1a 2b 14m', '\\d+', 2); +SELECT regexp_extract_all('1a 2b 14m', '\\d+', -1); SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)'); SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 0); SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 1); SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 2); SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 3); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', -1); diff --git a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out index f0596856dab96..0bfd5118864ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 20 +-- Number of queries: 22 -- !query @@ -79,7 +79,6 @@ a -- !query - SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 3) -- !query schema struct<> @@ -132,6 +131,15 @@ java.lang.IllegalArgumentException Regex group count is 0, but the specified group index is 2 +-- !query +SELECT regexp_extract_all('1a 2b 14m', '\\d+', -1) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +The specified group index cannot be less than zero + + -- !query SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)') -- !query schema @@ -171,3 +179,12 @@ struct<> -- !query output java.lang.IllegalArgumentException Regex group count is 2, but the specified group index is 3 + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', -1) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +The specified group index cannot be less than zero diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index a67685c461517..8d5166b5398cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -167,19 +167,6 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { Row(Seq("100", "300"), Seq("200", "400")) :: Row(Seq("101", "301"), Seq("201", "401")) :: Row(Seq("102", "302"), Seq("202", "402")) :: Nil) - - // for testing the mutable state of the expression in code gen. - // This is a hack way to enable the codegen, thus the codegen is enable by default, - // it will still use the interpretProjection if projection followed by a LocalRelation, - // hence we add a filter operator. - // See the optimizer rule `ConvertToLocalRelation` - checkAnswer( - df.filter("isnotnull(a)").selectExpr( - "regexp_extract_all(a, b, 0)", - "regexp_extract_all(a, b, 1)"), - Row(Seq("100-200", "300-400"), Seq("100", "300")) :: - Row(Seq("101-201", "301-401"), Seq("101", "301")) :: - Row(Seq("102", "202", "302", "402"), Seq("102", "202", "302", "402")) :: Nil) } test("non-matching optional group") {