Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-24884][SQL] Support regexp function regexp_extract_all #27507

Closed
wants to merge 17 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ object FunctionRegistry {
expression[StringLocate]("position", true),
expression[FormatString]("printf", true),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpExtractAll]("regexp_extract_all"),
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
import java.util.regex.{Matcher, MatchResult, Pattern}

import scala.collection.mutable.ArrayBuffer

import org.apache.commons.text.StringEscapeUtils

Expand Down Expand Up @@ -410,7 +412,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
}
}

object RegExpExtract {
object RegExpExtractBase {
def checkGroupIndex(groupCount: Int, groupIndex: Int): Unit = {
if (groupIndex < 0) {
throw new IllegalArgumentException("The specified group index cannot be less than zero")
Expand All @@ -421,20 +423,58 @@ object RegExpExtract {
}
}

abstract class RegExpExtractBase
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
def subject: Expression
def regexp: Expression
def idx: Expression

// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
// last regex pattern, we cache it for performance concern
@transient private var pattern: Pattern = _

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def children: Seq[Expression] = subject :: regexp :: idx :: Nil

protected def getLastMatcher(s: Any, p: Any): Matcher = {
if (p != lastRegex) {
// regex value changed
lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
pattern.matcher(s.toString)
}
}

/**
* Extract a specific(idx) group identified by a Java regex.
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
@ExpressionDescription(
usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches `regexp`.",
usage = """
_FUNC_(str, regexp[, idx]) - Extract the first string in the `str` that match the `regexp`
expression and corresponding to the regex group index.
""",
arguments = """
Arguments:
* str - a string expression.
* regexp - a string representing a regular expression.
The regex string should be a Java regular expression.
* idx - an integer expression that representing the group index. The group index should be
non-negative. If `idx` is not specified, the default group index value is 1.
* regexp - a string representing a regular expression. The regex string should be a
Java regular expression.

Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
parser. For example, to match "\abc", a regular expression for `regexp` can be
"^\\abc$".

There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to
fallback to the Spark 1.6 behavior regarding string literal parsing. For example,
if the config is enabled, the `regexp` that can match "\abc" is "^\abc$".
* idx - an integer expression that representing the group index. The regex maybe contains
multiple groups. `idx` indicates which regex group to extract. The group index should
be non-negative. The minimum value of `idx` is 0, which means matching the entire
regular expression. If `idx` is not specified, the default group index value is 1. The
`idx` parameter is the Java regex Matcher group() method index.
""",
examples = """
Examples:
Expand All @@ -443,27 +483,17 @@ object RegExpExtract {
""",
since = "1.5.0")
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
extends RegExpExtractBase {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))

// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
// last regex pattern, we cache it for performance concern
@transient private var pattern: Pattern = _

override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
if (!p.equals(lastRegex)) {
// regex value changed
lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
val m = pattern.matcher(s.toString)
val m = getLastMatcher(s, p)
if (m.find) {
val mr: MatchResult = m.toMatchResult
val index = r.asInstanceOf[Int]
RegExpExtract.checkGroupIndex(mr.groupCount, index)
RegExpExtractBase.checkGroupIndex(mr.groupCount, index)
val group = mr.group(index)
if (group == null) { // Pattern matched, but not optional group
if (group == null) { // Pattern matched, but it's an optional group
UTF8String.EMPTY_UTF8
} else {
UTF8String.fromString(group)
Expand All @@ -474,13 +504,11 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
}

override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
override def prettyName: String = "regexp_extract"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val classNamePattern = classOf[Pattern].getCanonicalName
val classNameRegExpExtract = classOf[RegExpExtract].getCanonicalName
val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
val matcher = ctx.freshName("matcher")
val matchResult = ctx.freshName("matchResult")

Expand All @@ -504,7 +532,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
$termPattern.matcher($subject.toString());
if ($matcher.find()) {
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
$classNameRegExpExtract.checkGroupIndex($matchResult.groupCount(), $idx);
$classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx);
if ($matchResult.group($idx) == null) {
${ev.value} = UTF8String.EMPTY_UTF8;
} else {
Expand All @@ -518,3 +546,105 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
})
}
}

/**
* Extract all specific(idx) groups identified by a Java regex.
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
@ExpressionDescription(
usage = """
_FUNC_(str, regexp[, idx]) - Extract all strings in the `str` that match the `regexp`
expression and corresponding to the regex group index.
""",
arguments = """
Arguments:
* str - a string expression.
* regexp - a string representing a regular expression. The regex string should be a
Java regular expression.

Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
parser. For example, to match "\abc", a regular expression for `regexp` can be
"^\\abc$".

There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to
fallback to the Spark 1.6 behavior regarding string literal parsing. For example,
if the config is enabled, the `regexp` that can match "\abc" is "^\abc$".
* idx - an integer expression that representing the group index. The regex may contains
multiple groups. `idx` indicates which regex group to extract. The group index should
be non-negative. The minimum value of `idx` is 0, which means matching the entire
regular expression. If `idx` is not specified, the default group index value is 1. The
`idx` parameter is the Java regex Matcher group() method index.
""",
examples = """
Examples:
> SELECT _FUNC_('100-200, 300-400', '(\\d+)-(\\d+)', 1);
["100","300"]
""",
since = "3.1.0")
case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expression)
extends RegExpExtractBase {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))

override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
val m = getLastMatcher(s, p)
val matchResults = new ArrayBuffer[UTF8String]()
while(m.find) {
val mr: MatchResult = m.toMatchResult
val index = r.asInstanceOf[Int]
RegExpExtractBase.checkGroupIndex(mr.groupCount, index)
val group = mr.group(index)
if (group == null) { // Pattern matched, but it's an optional group
matchResults += UTF8String.EMPTY_UTF8
} else {
matchResults += UTF8String.fromString(group)
}
}

new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]])
}

override def dataType: DataType = ArrayType(StringType)
override def prettyName: String = "regexp_extract_all"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val classNamePattern = classOf[Pattern].getCanonicalName
val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
val arrayClass = classOf[GenericArrayData].getName
val matcher = ctx.freshName("matcher")
val matchResult = ctx.freshName("matchResult")
val matchResults = ctx.freshName("matchResults")

val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
val termPattern = ctx.addMutableState(classNamePattern, "pattern")

val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
} else {
""
}
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
| if (!$regexp.equals($termLastRegex)) {
| // regex value changed
| $termLastRegex = $regexp.clone();
| $termPattern = $classNamePattern.compile($termLastRegex.toString());
| }
| java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());
| java.util.ArrayList $matchResults = new java.util.ArrayList<UTF8String>();
| while ($matcher.find()) {
| java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
| $classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx);
| if ($matchResult.group($idx) == null) {
| $matchResults.add(UTF8String.EMPTY_UTF8);
| } else {
| $matchResults.add(UTF8String.fromString($matchResult.group($idx)));
| }
| }
| ${ev.value} =
| new $arrayClass($matchResults.toArray(new UTF8String[$matchResults.size()]));
| $setEvNotNull
"""
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,56 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
RegExpExtract(Literal("\"quote"), Literal("\"quote"), Literal(1)) :: Nil)
}

test("RegexExtractAll") {
val row1 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 0)
val row2 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 1)
val row3 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 2)
val row4 = create_row("100-200,300-400,500-600", "(\\d+).*", 1)
val row5 = create_row("100-200,300-400,500-600", "([a-z])", 1)
val row6 = create_row(null, "([a-z])", 1)
val row7 = create_row("100-200,300-400,500-600", null, 1)
val row8 = create_row("100-200,300-400,500-600", "([a-z])", null)

val s = 's.string.at(0)
val p = 'p.string.at(1)
val r = 'r.int.at(2)

val expr = RegExpExtractAll(s, p, r)
checkEvaluation(expr, Seq("100-200", "300-400", "500-600"), row1)
checkEvaluation(expr, Seq("100", "300", "500"), row2)
checkEvaluation(expr, Seq("200", "400", "600"), row3)
checkEvaluation(expr, Seq("100"), row4)
checkEvaluation(expr, Seq(), row5)
checkEvaluation(expr, null, row6)
checkEvaluation(expr, null, row7)
checkEvaluation(expr, null, row8)

val expr1 = new RegExpExtractAll(s, p)
checkEvaluation(expr1, Seq("100", "300", "500"), row2)

val nonNullExpr = RegExpExtractAll(Literal("100-200,300-400,500-600"),
Literal("(\\d+)-(\\d+)"), Literal(1))
checkEvaluation(nonNullExpr, Seq("100", "300", "500"), row2)

// invalid group index
val row9 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 3)
val row10 = create_row("100-200,300-400,500-600", "(\\d+).*", 2)
val row11 = create_row("100-200,300-400,500-600", "\\d+", 1)
val row12 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", -1)
val row13 = create_row("100-200,300-400,500-600", "\\d+", -1)

checkExceptionInExpression[IllegalArgumentException](
expr, row9, "Regex group count is 2, but the specified group index is 3")
checkExceptionInExpression[IllegalArgumentException](
expr, row10, "Regex group count is 1, but the specified group index is 2")
checkExceptionInExpression[IllegalArgumentException](
expr, row11, "Regex group count is 0, but the specified group index is 1")
checkExceptionInExpression[IllegalArgumentException](
expr, row12, "The specified group index cannot be less than zero")
checkExceptionInExpression[IllegalArgumentException](
expr, row13, "The specified group index cannot be less than zero")
}

test("SPLIT") {
val s1 = 'a.string.at(0)
val s2 = 'b.string.at(1)
Expand Down
15 changes: 15 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2478,6 +2478,8 @@ object functions {
/**
* Extract a specific group matched by a Java regex, from the specified string column.
* If the regex did not match, or the specified group did not match, an empty string is returned.
* if the specified group index exceeds the group count of regex, an IllegalArgumentException
* will be thrown.
*
* @group string_funcs
* @since 1.5.0
Expand All @@ -2486,6 +2488,19 @@ object functions {
RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr)
}

/**
* Extract all specific groups matched by a Java regex, from the specified string column.
* If the regex did not match, or the specified group did not match, return an empty array.
* if the specified group index exceeds the group count of regex, an IllegalArgumentException
* will be thrown.
*
* @group string_funcs
* @since 3.1.0
*/
def regexp_extract_all(e: Column, exp: String, groupIdx: Int): Column = withExpr {
RegExpExtractAll(e.expr, lit(exp).expr, lit(groupIdx).expr)
}

/**
* Replace all substrings of the specified string value that match regexp with rep.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
| org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct<randn():double> |
| org.apache.spark.sql.catalyst.expressions.Rank | rank | N/A | N/A |
| org.apache.spark.sql.catalyst.expressions.RegExpExtract | regexp_extract | SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1) | struct<regexp_extract(100-200, (\d+)-(\d+), 1):string> |
| org.apache.spark.sql.catalyst.expressions.RegExpExtractAll | regexp_extract_all | SELECT regexp_extract_all('100-200, 300-400', '(\\d+)-(\\d+)', 1) | struct<regexp_extract_all(100-200, 300-400, (\d+)-(\d+), 1):array<string>> |
| org.apache.spark.sql.catalyst.expressions.RegExpReplace | regexp_replace | SELECT regexp_replace('100-200', '(\\d+)', 'num') | struct<regexp_replace(100-200, (\d+), num):string> |
| org.apache.spark.sql.catalyst.expressions.Remainder | % | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> |
| org.apache.spark.sql.catalyst.expressions.Remainder | mod | SELECT 2 % 1.8 | struct<(CAST(CAST(2 AS DECIMAL(1,0)) AS DECIMAL(2,1)) % CAST(1.8 AS DECIMAL(2,1))):decimal(2,1)> |
Expand Down
22 changes: 22 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,30 @@ SELECT regexp_extract('1a 2b 14m', '\\d+', 0);
SELECT regexp_extract('1a 2b 14m', '\\d+', 1);
SELECT regexp_extract('1a 2b 14m', '\\d+', 2);
SELECT regexp_extract('1a 2b 14m', '\\d+', -1);
SELECT regexp_extract('1a 2b 14m', '(\\d+)?', 1);
SELECT regexp_extract('a b m', '(\\d+)?', 1);
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)');
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 0);
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 1);
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 2);
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 3);
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', -1);
SELECT regexp_extract('1a 2b 14m', '(\\d+)?([a-z]+)', 1);
SELECT regexp_extract('a b m', '(\\d+)?([a-z]+)', 1);

-- regexp_extract_all
SELECT regexp_extract_all('1a 2b 14m', '\\d+');
SELECT regexp_extract_all('1a 2b 14m', '\\d+', 0);
SELECT regexp_extract_all('1a 2b 14m', '\\d+', 1);
SELECT regexp_extract_all('1a 2b 14m', '\\d+', 2);
SELECT regexp_extract_all('1a 2b 14m', '\\d+', -1);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)?', 1);
SELECT regexp_extract_all('a 2b 14m', '(\\d+)?', 1);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)');
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 0);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 1);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 2);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 3);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we test optional group here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', -1);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)?([a-z]+)', 1);
SELECT regexp_extract_all('a 2b 14m', '(\\d+)?([a-z]+)', 1);
Loading