diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3989df5d29467..7e73667e4b85f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -354,6 +354,7 @@ object FunctionRegistry { expression[StringLocate]("position", true), expression[FormatString]("printf", true), expression[RegExpExtract]("regexp_extract"), + expression[RegExpExtractAll]("regexp_extract_all"), expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReplace]("replace"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 1af1636e1df75..8eb7f463e049c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import java.util.regex.{MatchResult, Pattern} +import java.util.regex.{Matcher, MatchResult, Pattern} + +import scala.collection.mutable.ArrayBuffer import org.apache.commons.text.StringEscapeUtils @@ -410,7 +412,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } } -object RegExpExtract { +object RegExpExtractBase { def checkGroupIndex(groupCount: Int, groupIndex: Int): Unit = { if (groupIndex < 0) { throw new IllegalArgumentException("The specified group index cannot be less than zero") @@ -421,20 +423,58 @@ object RegExpExtract { } } +abstract class RegExpExtractBase + extends TernaryExpression with ImplicitCastInputTypes with 
NullIntolerant { + def subject: Expression + def regexp: Expression + def idx: Expression + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = subject :: regexp :: idx :: Nil + + protected def getLastMatcher(s: Any, p: Any): Matcher = { + if (p != lastRegex) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + pattern.matcher(s.toString) + } +} + /** * Extract a specific(idx) group identified by a Java regex. * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ @ExpressionDescription( - usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches `regexp`.", + usage = """ + _FUNC_(str, regexp[, idx]) - Extract the first string in the `str` that match the `regexp` + expression and corresponding to the regex group index. + """, arguments = """ Arguments: * str - a string expression. - * regexp - a string representing a regular expression. - The regex string should be a Java regular expression. - * idx - an integer expression that representing the group index. The group index should be - non-negative. If `idx` is not specified, the default group index value is 1. + * regexp - a string representing a regular expression. The regex string should be a + Java regular expression. + + Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL + parser. For example, to match "\abc", a regular expression for `regexp` can be + "^\\abc$". + + There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to + fallback to the Spark 1.6 behavior regarding string literal parsing. 
For example, + if the config is enabled, the `regexp` that can match "\abc" is "^\abc$". + * idx - an integer expression that representing the group index. The regex maybe contains + multiple groups. `idx` indicates which regex group to extract. The group index should + be non-negative. The minimum value of `idx` is 0, which means matching the entire + regular expression. If `idx` is not specified, the default group index value is 1. The + `idx` parameter is the Java regex Matcher group() method index. """, examples = """ Examples: @@ -443,27 +483,17 @@ object RegExpExtract { """, since = "1.5.0") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { + extends RegExpExtractBase { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) - // last regex in string, we will update the pattern iff regexp value changed. - @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString) + val m = getLastMatcher(s, p) if (m.find) { val mr: MatchResult = m.toMatchResult val index = r.asInstanceOf[Int] - RegExpExtract.checkGroupIndex(mr.groupCount, index) + RegExpExtractBase.checkGroupIndex(mr.groupCount, index) val group = mr.group(index) - if (group == null) { // Pattern matched, but not optional group + if (group == null) { // Pattern matched, but it's an optional group UTF8String.EMPTY_UTF8 } else { UTF8String.fromString(group) @@ -474,13 +504,11 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio } override def dataType: DataType = StringType - override def inputTypes: 
Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
-  override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
   override def prettyName: String = "regexp_extract"
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val classNamePattern = classOf[Pattern].getCanonicalName
-    val classNameRegExpExtract = classOf[RegExpExtract].getCanonicalName
+    val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
     val matcher = ctx.freshName("matcher")
     val matchResult = ctx.freshName("matchResult")
 
@@ -504,7 +532,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
       $termPattern.matcher($subject.toString());
       if ($matcher.find()) {
         java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
-        $classNameRegExpExtract.checkGroupIndex($matchResult.groupCount(), $idx);
+        $classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx);
         if ($matchResult.group($idx) == null) {
           ${ev.value} = UTF8String.EMPTY_UTF8;
         } else {
@@ -518,3 +546,108 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
     })
   }
 }
+
+/**
+ * Extract all specific(idx) groups identified by a Java regex.
+ *
+ * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(str, regexp[, idx]) - Extract all strings in the `str` that match the `regexp`
+    expression and corresponding to the regex group index.
+  """,
+  arguments = """
+    Arguments:
+      * str - a string expression.
+      * regexp - a string representing a regular expression. The regex string should be a
+          Java regular expression.
+
+          Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
+          parser. For example, to match "\abc", a regular expression for `regexp` can be
+          "^\\abc$".
+
+          There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to
+          fallback to the Spark 1.6 behavior regarding string literal parsing. For example,
+          if the config is enabled, the `regexp` that can match "\abc" is "^\abc$".
+      * idx - an integer expression that representing the group index. The regex may contains
+          multiple groups. `idx` indicates which regex group to extract. The group index should
+          be non-negative. The minimum value of `idx` is 0, which means matching the entire
+          regular expression. If `idx` is not specified, the default group index value is 1. The
+          `idx` parameter is the Java regex Matcher group() method index.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_('100-200, 300-400', '(\\d+)-(\\d+)', 1);
+       ["100","300"]
+  """,
+  since = "3.1.0")
+case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expression)
+  extends RegExpExtractBase {
+  def this(s: Expression, r: Expression) = this(s, r, Literal(1))
+
+  override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
+    val m = getLastMatcher(s, p)
+    // The group index is fixed for the row, so resolve it once outside the match loop.
+    val index = r.asInstanceOf[Int]
+    val matchResults = new ArrayBuffer[UTF8String]()
+    while (m.find) {
+      val mr: MatchResult = m.toMatchResult
+      RegExpExtractBase.checkGroupIndex(mr.groupCount, index)
+      val group = mr.group(index)
+      if (group == null) { // Pattern matched, but it's an optional group
+        matchResults += UTF8String.EMPTY_UTF8
+      } else {
+        matchResults += UTF8String.fromString(group)
+      }
+    }
+
+    new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]])
+  }
+
+  override def dataType: DataType = ArrayType(StringType)
+  override def prettyName: String = "regexp_extract_all"
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val classNamePattern = classOf[Pattern].getCanonicalName
+    val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
+    val arrayClass = classOf[GenericArrayData].getName
+    val matcher = ctx.freshName("matcher")
+    val matchResult = ctx.freshName("matchResult")
+    val matchResults = ctx.freshName("matchResults")
+
+    val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
+    val termPattern = ctx.addMutableState(classNamePattern, "pattern")
+
+    val setEvNotNull = if (nullable) {
+      s"${ev.isNull} = false;"
+    } else {
+      ""
+    }
+    nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
+      // Strip the `|` margin markers here so they are never part of the Java source
+      // returned by this lambda.
+      s"""
+         | if (!$regexp.equals($termLastRegex)) {
+         |   // regex value changed
+         |   $termLastRegex = $regexp.clone();
+         |   $termPattern = $classNamePattern.compile($termLastRegex.toString());
+         | }
+         | java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());
+         | java.util.ArrayList $matchResults = new java.util.ArrayList();
+         | while ($matcher.find()) {
+         |   java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
+         |   $classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx);
+         |   if ($matchResult.group($idx) == null) {
+         |     $matchResults.add(UTF8String.EMPTY_UTF8);
+         |   } else {
+         |     $matchResults.add(UTF8String.fromString($matchResult.group($idx)));
+         |   }
+         | }
+         | ${ev.value} =
+         |   new $arrayClass($matchResults.toArray(new UTF8String[$matchResults.size()]));
+         | $setEvNotNull
+         """.stripMargin
+    })
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
index cab81f85fda06..205dc10efc8a8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala
@@ -322,6 +322,56 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       RegExpExtract(Literal("\"quote"), Literal("\"quote"), Literal(1)) :: Nil)
   }
 
+  test("RegexExtractAll") {
+    val row1 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 0)
+    val row2 =
create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 1) + val row3 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 2) + val row4 = create_row("100-200,300-400,500-600", "(\\d+).*", 1) + val row5 = create_row("100-200,300-400,500-600", "([a-z])", 1) + val row6 = create_row(null, "([a-z])", 1) + val row7 = create_row("100-200,300-400,500-600", null, 1) + val row8 = create_row("100-200,300-400,500-600", "([a-z])", null) + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.int.at(2) + + val expr = RegExpExtractAll(s, p, r) + checkEvaluation(expr, Seq("100-200", "300-400", "500-600"), row1) + checkEvaluation(expr, Seq("100", "300", "500"), row2) + checkEvaluation(expr, Seq("200", "400", "600"), row3) + checkEvaluation(expr, Seq("100"), row4) + checkEvaluation(expr, Seq(), row5) + checkEvaluation(expr, null, row6) + checkEvaluation(expr, null, row7) + checkEvaluation(expr, null, row8) + + val expr1 = new RegExpExtractAll(s, p) + checkEvaluation(expr1, Seq("100", "300", "500"), row2) + + val nonNullExpr = RegExpExtractAll(Literal("100-200,300-400,500-600"), + Literal("(\\d+)-(\\d+)"), Literal(1)) + checkEvaluation(nonNullExpr, Seq("100", "300", "500"), row2) + + // invalid group index + val row9 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 3) + val row10 = create_row("100-200,300-400,500-600", "(\\d+).*", 2) + val row11 = create_row("100-200,300-400,500-600", "\\d+", 1) + val row12 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", -1) + val row13 = create_row("100-200,300-400,500-600", "\\d+", -1) + + checkExceptionInExpression[IllegalArgumentException]( + expr, row9, "Regex group count is 2, but the specified group index is 3") + checkExceptionInExpression[IllegalArgumentException]( + expr, row10, "Regex group count is 1, but the specified group index is 2") + checkExceptionInExpression[IllegalArgumentException]( + expr, row11, "Regex group count is 0, but the specified group index is 1") + 
checkExceptionInExpression[IllegalArgumentException]( + expr, row12, "The specified group index cannot be less than zero") + checkExceptionInExpression[IllegalArgumentException]( + expr, row13, "The specified group index cannot be less than zero") + } + test("SPLIT") { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index fe0057c3d588b..653fe5bf7c9b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2478,6 +2478,8 @@ object functions { /** * Extract a specific group matched by a Java regex, from the specified string column. * If the regex did not match, or the specified group did not match, an empty string is returned. + * If the specified group index exceeds the group count of regex, an IllegalArgumentException + * will be thrown. * * @group string_funcs * @since 1.5.0 @@ -2486,6 +2488,19 @@ object functions { RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) } + /** + * Extract all specific groups matched by a Java regex, from the specified string column. + * If the regex did not match, an empty array is returned. For a match in which the specified + * group did not participate, an empty string is added to the array. If the specified group index + * exceeds the group count of regex, an IllegalArgumentException will be thrown. + * + * @group string_funcs + * @since 3.1.0 + */ + def regexp_extract_all(e: Column, exp: String, groupIdx: Int): Column = withExpr { + RegExpExtractAll(e.expr, lit(exp).expr, lit(groupIdx).expr) + } + /** * Replace all substrings of the specified string value that match regexp with rep. 
* diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index edf2ede9e5a44..a212d8ce40642 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -214,6 +214,7 @@ | org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct | | org.apache.spark.sql.catalyst.expressions.Rank | rank | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.RegExpExtract | regexp_extract | SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1) | struct | +| org.apache.spark.sql.catalyst.expressions.RegExpExtractAll | regexp_extract_all | SELECT regexp_extract_all('100-200, 300-400', '(\\d+)-(\\d+)', 1) | struct> | | org.apache.spark.sql.catalyst.expressions.RegExpReplace | regexp_replace | SELECT regexp_replace('100-200', '(\\d+)', 'num') | struct | | org.apache.spark.sql.catalyst.expressions.Remainder | % | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> | | org.apache.spark.sql.catalyst.expressions.Remainder | mod | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql index 8a531be30d896..7128dee0a00d7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql @@ -4,8 +4,30 @@ SELECT regexp_extract('1a 2b 14m', '\\d+', 0); SELECT regexp_extract('1a 2b 14m', '\\d+', 1); SELECT regexp_extract('1a 2b 14m', '\\d+', 2); SELECT regexp_extract('1a 2b 14m', '\\d+', -1); +SELECT regexp_extract('1a 2b 14m', '(\\d+)?', 1); +SELECT regexp_extract('a b m', '(\\d+)?', 1); SELECT regexp_extract('1a 2b 14m', 
'(\\d+)([a-z]+)'); SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 0); SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 1); SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 2); +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 3); SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', -1); +SELECT regexp_extract('1a 2b 14m', '(\\d+)?([a-z]+)', 1); +SELECT regexp_extract('a b m', '(\\d+)?([a-z]+)', 1); + +-- regexp_extract_all +SELECT regexp_extract_all('1a 2b 14m', '\\d+'); +SELECT regexp_extract_all('1a 2b 14m', '\\d+', 0); +SELECT regexp_extract_all('1a 2b 14m', '\\d+', 1); +SELECT regexp_extract_all('1a 2b 14m', '\\d+', 2); +SELECT regexp_extract_all('1a 2b 14m', '\\d+', -1); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)?', 1); +SELECT regexp_extract_all('a 2b 14m', '(\\d+)?', 1); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)'); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 0); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 1); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 2); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 3); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', -1); +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)?([a-z]+)', 1); +SELECT regexp_extract_all('a 2b 14m', '(\\d+)?([a-z]+)', 1); diff --git a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out index 64aa6053d8d70..2eef926f63e37 100644 --- a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 30 -- !query @@ -46,6 +46,22 @@ java.lang.IllegalArgumentException The specified group index cannot be less than zero +-- !query +SELECT regexp_extract('1a 2b 14m', '(\\d+)?', 1) +-- !query schema +struct +-- !query output +1 
+ + +-- !query +SELECT regexp_extract('a b m', '(\\d+)?', 1) +-- !query schema +struct +-- !query output + + + -- !query SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)') -- !query schema @@ -78,10 +94,161 @@ struct a +-- !query +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 3) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Regex group count is 2, but the specified group index is 3 + + -- !query SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', -1) -- !query schema struct<> -- !query output java.lang.IllegalArgumentException -The specified group index cannot be less than zero \ No newline at end of file +The specified group index cannot be less than zero + + +-- !query +SELECT regexp_extract('1a 2b 14m', '(\\d+)?([a-z]+)', 1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT regexp_extract('a b m', '(\\d+)?([a-z]+)', 1) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '\\d+') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Regex group count is 0, but the specified group index is 1 + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '\\d+', 0) +-- !query schema +struct> +-- !query output +["1","2","14"] + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '\\d+', 1) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Regex group count is 0, but the specified group index is 1 + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '\\d+', 2) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Regex group count is 0, but the specified group index is 2 + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '\\d+', -1) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +The specified group index cannot be less than zero + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)?', 1) +-- !query schema +struct> +-- 
!query output +["1","","","2","","","14","",""] + + +-- !query +SELECT regexp_extract_all('a 2b 14m', '(\\d+)?', 1) +-- !query schema +struct> +-- !query output +["","","2","","","14","",""] + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)') +-- !query schema +struct> +-- !query output +["1","2","14"] + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 0) +-- !query schema +struct> +-- !query output +["1a","2b","14m"] + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 1) +-- !query schema +struct> +-- !query output +["1","2","14"] + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 2) +-- !query schema +struct> +-- !query output +["a","b","m"] + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 3) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Regex group count is 2, but the specified group index is 3 + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', -1) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +The specified group index cannot be less than zero + + +-- !query +SELECT regexp_extract_all('1a 2b 14m', '(\\d+)?([a-z]+)', 1) +-- !query schema +struct> +-- !query output +["1","2","14"] + + +-- !query +SELECT regexp_extract_all('a 2b 14m', '(\\d+)?([a-z]+)', 1) +-- !query schema +struct> +-- !query output +["","2","14"] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index f904b53fe47eb..8d5166b5398cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -154,8 +154,25 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) } + test("string regex_extract_all") 
{ + val df = Seq( + ("100-200,300-400", "(\\d+)-(\\d+)"), + ("101-201,301-401", "(\\d+)-(\\d+)"), + ("102-202,302-402", "(\\d+)")).toDF("a", "b") + + checkAnswer( + df.select( + regexp_extract_all($"a", "(\\d+)-(\\d+)", 1), + regexp_extract_all($"a", "(\\d+)-(\\d+)", 2)), + Row(Seq("100", "300"), Seq("200", "400")) :: + Row(Seq("101", "301"), Seq("201", "401")) :: + Row(Seq("102", "302"), Seq("202", "402")) :: Nil) + } + test("non-matching optional group") { val df = Seq(Tuple1("aaaac")).toDF("s") + + // regexp_extract checkAnswer( df.select(regexp_extract($"s", "(foo)", 1)), Row("") @@ -164,6 +181,16 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { df.select(regexp_extract($"s", "(a+)(b)?(c)", 2)), Row("") ) + + // regexp_extract_all + checkAnswer( + df.select(regexp_extract_all($"s", "(foo)", 1)), + Row(Seq()) + ) + checkAnswer( + df.select(regexp_extract_all($"s", "(a+)(b)?(c)", 2)), + Row(Seq("")) + ) } test("string ascii function") {