Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-24884][SQL] Support regexp function regexp_extract_all #27507

Closed
wants to merge 17 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ object FunctionRegistry {
expression[StringLocate]("position"),
expression[FormatString]("printf"),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpExtractAll]("regexp_extract_all"),
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
import java.util.regex.{Matcher, MatchResult, Pattern}

import scala.collection.mutable.ArrayBuffer

import org.apache.commons.text.StringEscapeUtils

Expand Down Expand Up @@ -410,7 +412,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
}
}

object RegExpExtract {
object RegExpExtractBase {
def checkGroupIndex(groupCount: Int, groupIndex: Int): Unit = {
if (groupCount < groupIndex) {
throw new IllegalArgumentException(
Expand All @@ -419,39 +421,104 @@ object RegExpExtract {
}
}

abstract class RegExpExtractBase extends TernaryExpression with ImplicitCastInputTypes {
def subject: Expression
def regexp: Expression
def idx: Expression

// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
// last regex pattern, we cache it for performance concern
@transient private var pattern: Pattern = _

override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def children: Seq[Expression] = subject :: regexp :: idx :: Nil

protected def getLastMatcher(s: Any, p: Any): Matcher = {
if (!p.equals(lastRegex)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is scala, we can just write p != lastRegex

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

// regex value changed
lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
pattern.matcher(s.toString)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val classNamePattern = classOf[Pattern].getCanonicalName
val classNameRegExpExtractBase = classOf[RegExpExtractBase].getCanonicalName
val matcher = ctx.freshName("matcher")
val matchResult = ctx.freshName("matchResult")

val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
val termPattern = ctx.addMutableState(classNamePattern, "pattern")

val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
} else {
""
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH I don't think there is much common code to share. Maybe we can have a
protected def setNotNullCode(ev: ExprCode) = ... but that's all.

How about we just let each sub-class implement doGenCode individually?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK.

doNullSafeCodeGen(
ctx,
ev,
classNamePattern,
classNameRegExpExtractBase,
matcher,
matchResult,
termLastRegex,
termPattern,
setEvNotNull)
}

def doNullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
classNamePattern: String,
classNameRegExpExtractBase: String,
matcher: String,
matchResult: String,
termLastRegex: String,
termPattern: String,
setEvNotNull: String): ExprCode
}

/**
* Extract a specific(idx) group identified by a Java regex.
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
@ExpressionDescription(
usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches `regexp`.",
arguments = """
Arguments:
* str - a string expression
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a string expression of the input string.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

* regexp - a string expression. The regex string should be a Java regular expression.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a string expression of the regex string.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is Java regular expression?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. There just references the comment of RLIKE.


Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
parser. For example, to match "\abc", a regular expression for `regexp` can be
"^\\abc$".

There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to
fallback to the Spark 1.6 behavior regarding string literal parsing. For example,
if the config is enabled, the `regexp` that can match "\abc" is "^\abc$".
* idx - a int expression. The regex maybe contains multiple groups. `idx` represents the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

an int expression of the regex group index.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

index of regex group.
Copy link
Contributor

@cloud-fan cloud-fan Feb 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idx indicates which regex group to extract.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

""",
examples = """
Examples:
> SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1);
100
""",
since = "1.5.0")
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
extends RegExpExtractBase {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))

// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
// last regex pattern, we cache it for performance concern
@transient private var pattern: Pattern = _

override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
if (!p.equals(lastRegex)) {
// regex value changed
lastRegex = p.asInstanceOf[UTF8String].clone()
pattern = Pattern.compile(lastRegex.toString)
}
val m = pattern.matcher(s.toString)
val m = getLastMatcher(s, p)
if (m.find) {
val mr: MatchResult = m.toMatchResult
val index = r.asInstanceOf[Int]
RegExpExtract.checkGroupIndex(mr.groupCount, index)
RegExpExtractBase.checkGroupIndex(mr.groupCount, index)
val group = mr.group(index)
if (group == null) { // Pattern matched, but not optional group
UTF8String.EMPTY_UTF8
Expand All @@ -464,25 +531,18 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
}

override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
override def prettyName: String = "regexp_extract"

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val classNamePattern = classOf[Pattern].getCanonicalName
val classNameRegExpExtract = classOf[RegExpExtract].getCanonicalName
val matcher = ctx.freshName("matcher")
val matchResult = ctx.freshName("matchResult")

val termLastRegex = ctx.addMutableState("UTF8String", "lastRegex")
val termPattern = ctx.addMutableState(classNamePattern, "pattern")

val setEvNotNull = if (nullable) {
s"${ev.isNull} = false;"
} else {
""
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's keep the blank line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

override def doNullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
classNamePattern: String,
classNameRegExpExtractBase: String,
matcher: String,
matchResult: String,
termLastRegex: String,
termPattern: String,
setEvNotNull: String): ExprCode = {
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
if (!$regexp.equals($termLastRegex)) {
Expand All @@ -494,7 +554,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
$termPattern.matcher($subject.toString());
if ($matcher.find()) {
java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
$classNameRegExpExtract.checkGroupIndex($matchResult.groupCount(), $idx);
$classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx);
if ($matchResult.group($idx) == null) {
${ev.value} = UTF8String.EMPTY_UTF8;
} else {
Expand All @@ -508,3 +568,96 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
})
}
}

/**
* Extract all specific(idx) group identified by a Java regex.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

group -> groups

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
@ExpressionDescription(
usage = "_FUNC_(str, regexp[, idx]) - Extracts all group that matches `regexp`.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we explain the semantic of idx?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we refine it a little bit? What the "group" means here if the idx is specified or not specified?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

arguments = """
Arguments:
* str - a string expression
* regexp - a string expression. The regex string should be a Java regular expression.

Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
parser. For example, to match "\abc", a regular expression for `regexp` can be
"^\\abc$".

There is a SQL config 'spark.sql.parser.escapedStringLiterals' that can be used to
fallback to the Spark 1.6 behavior regarding string literal parsing. For example,
if the config is enabled, the `regexp` that can match "\abc" is "^\abc$".
* idx - a int expression. The regex maybe contains multiple groups. `idx` represents the
index of regex group.
""",
examples = """
Examples:
> SELECT _FUNC_('100-200, 300-400', '(\\d+)-(\\d+)', 1);
["100","300"]
""",
since = "3.0.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.1.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

case class RegExpExtractAll(subject: Expression, regexp: Expression, idx: Expression)
extends RegExpExtractBase {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))

override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
val m = getLastMatcher(s, p)
val matchResults = new ArrayBuffer[UTF8String]()
val mr: MatchResult = m.toMatchResult
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do we use this mr?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will remove it.

while(m.find) {
val mr: MatchResult = m.toMatchResult
val index = r.asInstanceOf[Int]
RegExpExtractBase.checkGroupIndex(mr.groupCount, index)
val group = mr.group(index)
if (group == null) { // Pattern matched, but not optional group
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but not optional group -> but it's an optional group?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just reference

if (group == null) { // Pattern matched, but not optional group

matchResults += UTF8String.EMPTY_UTF8
} else {
matchResults += UTF8String.fromString(group)
}
}

new GenericArrayData(matchResults.toArray.asInstanceOf[Array[Any]])
}

override def dataType: DataType = ArrayType(StringType)
override def prettyName: String = "regexp_extract_all"

override def doNullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
classNamePattern: String,
classNameRegExpExtractBase: String,
matcher: String,
matchResult: String,
termLastRegex: String,
termPattern: String,
setEvNotNull: String): ExprCode = {
val matchResults = ctx.freshName("matchResults")
val arrayClass = classOf[GenericArrayData].getName

nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
s"""
| if (!$regexp.equals($termLastRegex)) {
| // regex value changed
| $termLastRegex = $regexp.clone();
| $termPattern = $classNamePattern.compile($termLastRegex.toString());
| }
| java.util.regex.Matcher $matcher = $termPattern.matcher($subject.toString());
| java.util.ArrayList $matchResults = new java.util.ArrayList<UTF8String>();
| while ($matcher.find()) {
| java.util.regex.MatchResult $matchResult = $matcher.toMatchResult();
| $classNameRegExpExtractBase.checkGroupIndex($matchResult.groupCount(), $idx);
| if ($matchResult.group($idx) == null) {
| $matchResults.add(UTF8String.EMPTY_UTF8);
| } else {
| $matchResults.add(UTF8String.fromString($matchResult.group($idx)));
| }
| }
| ${ev.value} =
| new $arrayClass($matchResults.toArray(new UTF8String[$matchResults.size()]));
| $setEvNotNull
"""
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,48 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
expr, row10, "Regex group count is 0, but the specified group index is 1")
}

test("RegexExtractAll") {
val row1 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we test group 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

val row2 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 2)
val row3 = create_row("100-200,300-400,500-600", "(\\d+).*", 1)
val row4 = create_row("100-200,300-400,500-600", "([a-z])", 1)
val row5 = create_row(null, "([a-z])", 1)
val row6 = create_row("100-200,300-400,500-600", null, 1)
val row7 = create_row("100-200,300-400,500-600", "([a-z])", null)

val s = 's.string.at(0)
val p = 'p.string.at(1)
val r = 'r.int.at(2)

val expr = RegExpExtractAll(s, p, r)
checkEvaluation(expr, Seq("100", "300", "500"), row1)
checkEvaluation(expr, Seq("200", "400", "600"), row2)
checkEvaluation(expr, Seq("100"), row3)
checkEvaluation(expr, Seq(), row4)
checkEvaluation(expr, null, row5)
checkEvaluation(expr, null, row6)
checkEvaluation(expr, null, row7)

val expr1 = new RegExpExtractAll(s, p)
checkEvaluation(expr1, Seq("100", "300", "500"), row1)

val nonNullExpr = RegExpExtractAll(Literal("100-200,300-400,500-600"),
Literal("(\\d+)-(\\d+)"), Literal(1))
checkEvaluation(nonNullExpr, Seq("100", "300", "500"), row1)

// invalid group index
val row8 = create_row("100-200,300-400,500-600", "(\\d+)-(\\d+)", 3)
val row9 = create_row("100-200,300-400,500-600", "(\\d+).*", 2)
val row10 = create_row("100-200,300-400,500-600", "\\d+", 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about negative group index?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

throw new IllegalArgumentException("The specified group index cannot be less than zero")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we test it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK


checkExceptionInExpression[IllegalArgumentException](
expr, row8, "Regex group count is 2, but the specified group index is 3")
checkExceptionInExpression[IllegalArgumentException](
expr, row9, "Regex group count is 1, but the specified group index is 2")
checkExceptionInExpression[IllegalArgumentException](
expr, row10, "Regex group count is 0, but the specified group index is 1")
}

test("SPLIT") {
val s1 = 'a.string.at(0)
val s2 = 'b.string.at(1)
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2383,6 +2383,17 @@ object functions {
RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr)
}

/**
* Extract all specific group matched by a Java regex, from the specified string column.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

groups

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

* If the regex did not match, or the specified group did not match, an empty array is returned.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior seems to be

  1. If the regex does not match, return an empty array
  2. if the specified group does not match, put an empty string to the result array.

Can we document the behavior in SQL expression? And can you verify this is the standard behavior in other databases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. should throw a IllegalArgumentException.
    [SPARK-30763][SQL] Fix java.lang.IndexOutOfBoundsException No group 1 for regexp_extract #27508
    the behavior of Hive is :
    FAILED: SemanticException [Error 10014]: Line 1:7 Wrong arguments ‘2’: org.apache.hadoop.hive.ql.metadata.HiveException: Unable to execute method public java.lang.String org.apache.hadoop.hive.ql.udf.UDFRegExpExtract.evaluate(java.lang.String,java.lang.String,java.lang.Integer) on object org.apache.hadoop.hive.ql.udf.UDFRegExpExtract@2cf5e0f0 of class org.apache.hadoop.hive.ql.udf.UDFRegExpExtract with arguments {x=a3&x=18abc&x=2&y=3&x=4:java.lang.String, x=([0-9]+)[a-z]:java.lang.String, 2:java.lang.Integer} of size 3

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's document the behavior clearly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

*
* @group string_funcs
* @since 3.0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.1.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

*/
def regexp_extract_all(e: Column, exp: String, groupIdx: Int): Column = withExpr {
RegExpExtractAll(e.expr, lit(exp).expr, lit(groupIdx).expr)
}

/**
* Replace all substrings of the specified string value that match regexp with rep.
*
Expand Down
12 changes: 12 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,15 @@ SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)');
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 0);
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 1);
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 2);
SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 3);

-- regexp_extract_all
SELECT regexp_extract_all('1a 2b 14m', '\\d+');
SELECT regexp_extract_all('1a 2b 14m', '\\d+', 0);
SELECT regexp_extract_all('1a 2b 14m', '\\d+', 1);
SELECT regexp_extract_all('1a 2b 14m', '\\d+', 2);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)');
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 0);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 1);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 2);
SELECT regexp_extract_all('1a 2b 14m', '(\\d+)([a-z]+)', 3);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we test optional group here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

Loading