-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23922][SQL] Add arrays_overlap function #21028
Changes from 23 commits
e5ebdad
682bc73
876cd93
88e09b3
c895707
65b7d6d
f9a1ecf
1dbcd0c
076fc69
eafca0f
5925104
2a1121c
bf81e4a
4a18ba8
566946a
710433e
3cf410a
9d086f9
964f7af
41ef6c6
3dd724b
f7089f5
e36a5d7
49d9372
227437b
2e9e024
92730a1
56c59ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,15 +18,53 @@ package org.apache.spark.sql.catalyst.expressions | |
|
||
import java.util.Comparator | ||
|
||
import scala.collection.mutable | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} | ||
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder | ||
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} | ||
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.unsafe.array.ByteArrayMethods | ||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String} | ||
|
||
/** | ||
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit | ||
* casting. | ||
*/ | ||
trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression | ||
with ImplicitCastInputTypes { | ||
|
||
private val caseSensitive = SQLConf.get.caseSensitiveAnalysis | ||
|
||
@transient protected lazy val elementType: DataType = | ||
inputTypes.head.asInstanceOf[ArrayType].elementType | ||
|
||
override def inputTypes: Seq[AbstractDataType] = { | ||
(left.dataType, right.dataType) match { | ||
case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => | ||
TypeCoercion.findTightestCommonType(e1, e2, caseSensitive) match { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should fail the build now |
||
case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) | ||
case _ => Seq.empty | ||
} | ||
case _ => Seq.empty | ||
} | ||
} | ||
|
||
override def checkInputDataTypes(): TypeCheckResult = { | ||
(left.dataType, right.dataType) match { | ||
case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => | ||
TypeCheckResult.TypeCheckSuccess | ||
case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + | ||
s"been two ${ArrayType.simpleString}s with same element type, but it's " + | ||
s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") | ||
} | ||
} | ||
} | ||
|
||
|
||
/** | ||
* Given an array or map, returns its size. Returns -1 if null. | ||
*/ | ||
|
@@ -529,6 +567,239 @@ case class ArrayContains(left: Expression, right: Expression) | |
override def prettyName: String = "array_contains" | ||
} | ||
|
||
/** | ||
* Checks if the two arrays contain at least one common element. | ||
*/ | ||
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); | ||
true | ||
""", since = "2.4.0") | ||
// scalastyle:on line.size.limit | ||
case class ArraysOverlap(left: Expression, right: Expression) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't you override
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done, thanks! |
||
extends BinaryArrayExpressionWithImplicitCast { | ||
|
||
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { | ||
case TypeCheckResult.TypeCheckSuccess => | ||
if (RowOrdering.isOrderable(elementType)) { | ||
TypeCheckResult.TypeCheckSuccess | ||
} else { | ||
TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") | ||
} | ||
case failure => failure | ||
} | ||
|
||
@transient private lazy val ordering: Ordering[Any] = | ||
TypeUtils.getInterpretedOrdering(elementType) | ||
|
||
@transient private lazy val elementTypeSupportEquals = elementType match { | ||
case BinaryType => false | ||
case _: AtomicType => true | ||
case _ => false | ||
} | ||
|
||
@transient private lazy val doEvaluation = if (elementTypeSupportEquals) { | ||
fastEval _ | ||
} else { | ||
bruteForceEval _ | ||
} | ||
|
||
override def dataType: DataType = BooleanType | ||
|
||
override def nullable: Boolean = { | ||
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || | ||
right.dataType.asInstanceOf[ArrayType].containsNull | ||
} | ||
|
||
override def nullSafeEval(a1: Any, a2: Any): Any = { | ||
doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData]) | ||
} | ||
|
||
/** | ||
* A fast implementation which puts all the elements from the smaller array in a set | ||
* and then performs a lookup on it for each element of the bigger one. | ||
* This eval mode works only for data types which implements properly the equals method. | ||
*/ | ||
private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { | ||
var hasNull = false | ||
val (bigger, smaller, biggerDt) = if (arr1.numElements() > arr2.numElements()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the |
||
(arr1, arr2, left.dataType.asInstanceOf[ArrayType]) | ||
} else { | ||
(arr2, arr1, right.dataType.asInstanceOf[ArrayType]) | ||
} | ||
if (smaller.numElements() > 0) { | ||
val smallestSet = new mutable.HashSet[Any] | ||
smaller.foreach(elementType, (_, v) => | ||
if (v == null) { | ||
hasNull = true | ||
} else { | ||
smallestSet += v | ||
}) | ||
bigger.foreach(elementType, (_, v1) => | ||
if (v1 == null) { | ||
hasNull = true | ||
} else if (smallestSet.contains(v1)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this doesn't work with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually it was not working also with |
||
return true | ||
} | ||
) | ||
} | ||
if (hasNull) { | ||
null | ||
} else { | ||
false | ||
} | ||
} | ||
|
||
/** | ||
* A slower evaluation which performs a nested loop and supports all the data types. | ||
* Returns true as soon as a non-null element of `arr1` is equal (according to `ordering`) to a | ||
* non-null element of `arr2`. Otherwise returns null when both arrays are non-empty and a null | ||
* element was encountered, and false in every other case. | ||
*/ | ||
private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { | ||
var hasNull = false | ||
// Both arrays must be non-empty: an empty input can never overlap, so the result stays false. | ||
if (arr1.numElements() > 0 && arr2.numElements() > 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should also compare the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks, I added also some test cases for this |
||
arr1.foreach(elementType, (_, v1) => | ||
if (v1 == null) { | ||
hasNull = true | ||
} else { | ||
arr2.foreach(elementType, (_, v2) => | ||
// Test the inner element (v2); v1 is already known to be non-null in this branch. | ||
if (v2 == null) { | ||
hasNull = true | ||
} else if (ordering.equiv(v1, v2)) { | ||
return true | ||
} | ||
) | ||
}) | ||
} | ||
if (hasNull) { | ||
null | ||
} else { | ||
false | ||
} | ||
} | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
nullSafeCodeGen(ctx, ev, (a1, a2) => { | ||
val smaller = ctx.freshName("smallerArray") | ||
val bigger = ctx.freshName("biggerArray") | ||
val comparisonCode = if (elementTypeSupportEquals) { | ||
fastCodegen(ctx, ev, smaller, bigger) | ||
} else { | ||
bruteForceCodegen(ctx, ev, smaller, bigger) | ||
} | ||
s""" | ||
|ArrayData $smaller; | ||
|ArrayData $bigger; | ||
|if ($a1.numElements() > $a2.numElements()) { | ||
| $bigger = $a1; | ||
| $smaller = $a2; | ||
|} else { | ||
| $smaller = $a1; | ||
| $bigger = $a2; | ||
|} | ||
|if ($smaller.numElements() > 0) { | ||
| $comparisonCode | ||
|} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah i see, sorry I misread the code in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. np, thanks for checking it. Always better one check more than one less. |
||
""".stripMargin | ||
}) | ||
} | ||
|
||
/** | ||
* Code generation for a fast implementation which puts all the elements from the smaller array | ||
* in a set and then performs a lookup on it for each element of the bigger one. | ||
* It works only for data types which implements properly the equals method. | ||
*/ | ||
private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { | ||
val i = ctx.freshName("i") | ||
val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) | ||
val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) | ||
val javaElementClass = CodeGenerator.boxedType(elementType) | ||
val javaSet = classOf[java.util.HashSet[_]].getName | ||
val set = ctx.freshName("set") | ||
val addToSetFromSmallerCode = nullSafeElementCodegen( | ||
smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") | ||
val elementIsInSetCode = nullSafeElementCodegen( | ||
bigger, | ||
i, | ||
s""" | ||
|if ($set.contains($getFromBigger)) { | ||
| ${ev.isNull} = false; | ||
| ${ev.value} = true; | ||
| break; | ||
|} | ||
""".stripMargin, | ||
s"${ev.isNull} = true;") | ||
s""" | ||
|$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); | ||
|for (int $i = 0; $i < $smaller.numElements(); $i ++) { | ||
| $addToSetFromSmallerCode | ||
|} | ||
|for (int $i = 0; $i < $bigger.numElements(); $i ++) { | ||
| $elementIsInSetCode | ||
|} | ||
""".stripMargin | ||
} | ||
|
||
/** | ||
* Code generation for a slower evaluation which performs a nested loop and supports all the data types. | ||
*/ | ||
private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { | ||
val i = ctx.freshName("i") | ||
val j = ctx.freshName("j") | ||
val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) | ||
val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) | ||
val compareValues = nullSafeElementCodegen( | ||
smaller, | ||
j, | ||
s""" | ||
|if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { | ||
| ${ev.isNull} = false; | ||
| ${ev.value} = true; | ||
| break; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so we wanna break 2 loops here, that's why we generate 2
|
||
|} | ||
""".stripMargin, | ||
s"${ev.isNull} = true;") | ||
val isInSmaller = nullSafeElementCodegen( | ||
bigger, | ||
i, | ||
s""" | ||
|for (int $j = 0; $j < $smaller.numElements(); $j ++) { | ||
| $compareValues | ||
| if (${ev.value}) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need this if? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it was in the wrong place, thanks, nice catch! |
||
| break; | ||
| } | ||
|} | ||
""".stripMargin, | ||
s"${ev.isNull} = true;") | ||
s""" | ||
|for (int $i = 0; $i < $bigger.numElements(); $i ++) { | ||
|$isInSmaller | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2 space |
||
|} | ||
""".stripMargin | ||
} | ||
|
||
def nullSafeElementCodegen( | ||
arrayVar: String, | ||
index: String, | ||
code: String, | ||
isNullCode: String): String = { | ||
if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shouldn't this depend on whether the input array There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unfortunately we don't know which one we have here (the left or the right) as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i see, makes sense! |
||
s""" | ||
|if ($arrayVar.isNullAt($index)) { | ||
| $isNullCode | ||
|} else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. |
||
| $code | ||
|} | ||
""".stripMargin | ||
} else { | ||
code | ||
} | ||
} | ||
|
||
override def prettyName: String = "arrays_overlap" | ||
} | ||
|
||
/** | ||
* Slices an array according to the requested start index and length | ||
*/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,6 +136,60 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) | ||
} | ||
|
||
test("ArraysOverlap") { | ||
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) | ||
val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) | ||
val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) | ||
val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) | ||
val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) | ||
val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) | ||
val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) | ||
|
||
val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) | ||
|
||
checkEvaluation(ArraysOverlap(a0, a1), true) | ||
checkEvaluation(ArraysOverlap(a0, a2), null) | ||
checkEvaluation(ArraysOverlap(a1, a2), true) | ||
checkEvaluation(ArraysOverlap(a1, a3), false) | ||
checkEvaluation(ArraysOverlap(a0, emptyIntArray), false) | ||
checkEvaluation(ArraysOverlap(a2, emptyIntArray), false) | ||
checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) | ||
|
||
checkEvaluation(ArraysOverlap(a4, a5), true) | ||
checkEvaluation(ArraysOverlap(a4, a6), null) | ||
checkEvaluation(ArraysOverlap(a5, a6), false) | ||
|
||
// null handling | ||
checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) | ||
checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) | ||
checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) | ||
checkEvaluation(ArraysOverlap( | ||
Literal.create(Seq(null), ArrayType(IntegerType)), | ||
Literal.create(Seq(null), ArrayType(IntegerType))), null) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am returning
I will add a sentence to clarify the behavior in our docs. Thanks for this nice catch! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we have a test case for it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This case is covered by https://github.com/apache/spark/pull/21028/files#diff-d31eca9f1c4c33104dc2cb8950486910R163 for instance. Anyway, I am adding another one which is exactly this one. |
||
|
||
// arrays of binaries | ||
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), | ||
ArrayType(BinaryType)) | ||
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), | ||
ArrayType(BinaryType)) | ||
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), | ||
ArrayType(BinaryType)) | ||
|
||
checkEvaluation(ArraysOverlap(b0, b1), true) | ||
checkEvaluation(ArraysOverlap(b0, b2), false) | ||
|
||
// arrays of complex data types | ||
val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), | ||
ArrayType(ArrayType(StringType))) | ||
val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")), | ||
ArrayType(ArrayType(StringType))) | ||
val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), | ||
ArrayType(ArrayType(StringType))) | ||
|
||
checkEvaluation(ArraysOverlap(aa0, aa1), true) | ||
checkEvaluation(ArraysOverlap(aa0, aa2), false) | ||
} | ||
|
||
test("Slice") { | ||
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) | ||
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `ImplicitCastInputTypes` trait is able to work with any number of children. Would it be possible to implement this trait to behave in the same way?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's possible indeed. Though, as far as I know there is no use case for a function with a different number of children, so I am not sure if it makes sense to generalize it. @cloud-fan @kiszk @ueshin WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As @ueshin pointed out here, `concat` is also a use case that has a different number of children. Am I wrong?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kiszk you are not wrong, but `Concat` is a very specific case, since it also supports `String`s and `Binary`s, so it would anyway require a specific implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, I would like to hear other opinions