Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23922][SQL] Add arrays_overlap function #21028

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e5ebdad
[SPARK-23922][SQL] Add arrays_overlap function
mgaido91 Apr 10, 2018
682bc73
fix python style
mgaido91 Apr 10, 2018
876cd93
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 Apr 17, 2018
88e09b3
review comments
mgaido91 Apr 20, 2018
c895707
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 Apr 20, 2018
65b7d6d
introduce BinaryArrayExpressionWithImplicitCast
mgaido91 Apr 27, 2018
f9a1ecf
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 Apr 27, 2018
1dbcd0c
fix type check
mgaido91 Apr 27, 2018
076fc69
fix scalastyle
mgaido91 Apr 27, 2018
eafca0f
fix build error
mgaido91 Apr 27, 2018
5925104
fix
mgaido91 Apr 27, 2018
2a1121c
address comments
mgaido91 May 3, 2018
bf81e4a
use sets instead of nested loops
mgaido91 May 3, 2018
4a18ba8
address review comments
mgaido91 May 4, 2018
566946a
address review comments
mgaido91 May 4, 2018
710433e
add test case for null
mgaido91 May 4, 2018
3cf410a
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 May 7, 2018
9d086f9
address comments
mgaido91 May 7, 2018
964f7af
use findTightestCommonType for type inference
mgaido91 May 8, 2018
41ef6c6
support binary and complex data types
mgaido91 May 9, 2018
3dd724b
review comments
mgaido91 May 11, 2018
f7089f5
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 May 11, 2018
e36a5d7
fix compilation error
mgaido91 May 11, 2018
49d9372
address comments
mgaido91 May 14, 2018
227437b
address comment
mgaido91 May 15, 2018
2e9e024
fix null handling with complex types
mgaido91 May 16, 2018
92730a1
Merge branch 'master' of github.com:apache/spark into SPARK-23922
mgaido91 May 17, 2018
56c59ae
fix build
mgaido91 May 17, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,6 +1852,21 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(2.4)
def arrays_overlap(a1, a2):
"""
Collection function: returns true if the arrays contain any common non-null element; if not,
returns null if both the arrays are non-empty and any of them contains a null element; returns
false otherwise.

>>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
>>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
[Row(overlap=True), Row(overlap=False)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2)))


@since(2.4)
def slice(x, start, length):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArraysOverlap]("arrays_overlap"),
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,53 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.Comparator

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
* casting.
*/
trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ImplicitCastInputTypes trait is able to work with any number of children. Would it be possible to implement this trait to behave in the same way?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible indeed. Though, as far as I know there is no use case for a function with a different number of children, so I am not sure if it makes sense to generalize it. @cloud-fan @kiszk @ueshin WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @ueshin pointed out here, concat is also a use case that has a different number of children. Am I wrong?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kiszk you are not wrong, but Concat is a very specific case, since it supports also Strings and Binarys, so it would anyway require a specific implementation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I would like to hear other opinions

with ImplicitCastInputTypes {

private val caseSensitive = SQLConf.get.caseSensitiveAnalysis

@transient protected lazy val elementType: DataType =
inputTypes.head.asInstanceOf[ArrayType].elementType

override def inputTypes: Seq[AbstractDataType] = {
(left.dataType, right.dataType) match {
case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) =>
TypeCoercion.findTightestCommonType(e1, e2, caseSensitive) match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should fail the build now

case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2))
case _ => Seq.empty
}
case _ => Seq.empty
}
}

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) =>
TypeCheckResult.TypeCheckSuccess
case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " +
s"been two ${ArrayType.simpleString}s with same element type, but it's " +
s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]")
}
}
}


/**
* Given an array or map, returns its size. Returns -1 if null.
*/
Expand Down Expand Up @@ -529,6 +567,239 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}

/**
* Checks if the two arrays contain at least one common element.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
true
""", since = "2.4.0")
// scalastyle:off line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you override prettyName to a value following the conventions?

override def prettyName: String = "arrays_overlap"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks!

extends BinaryArrayExpressionWithImplicitCast {

override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (RowOrdering.isOrderable(elementType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.")
}
case failure => failure
}

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)

@transient private lazy val elementTypeSupportEquals = elementType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

@transient private lazy val doEvaluation = if (elementTypeSupportEquals) {
fastEval _
} else {
bruteForceEval _
}

override def dataType: DataType = BooleanType

override def nullable: Boolean = {
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
right.dataType.asInstanceOf[ArrayType].containsNull
}

override def nullSafeEval(a1: Any, a2: Any): Any = {
doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData])
}

/**
* A fast implementation which puts all the elements from the smaller array in a set
* and then performs a lookup on it for each element of the bigger one.
* This eval mode works only for data types which implements properly the equals method.
*/
private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = {
var hasNull = false
val (bigger, smaller, biggerDt) = if (arr1.numElements() > arr2.numElements()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the biggerDt is not used

(arr1, arr2, left.dataType.asInstanceOf[ArrayType])
} else {
(arr2, arr1, right.dataType.asInstanceOf[ArrayType])
}
if (smaller.numElements() > 0) {
val smallestSet = new mutable.HashSet[Any]
smaller.foreach(elementType, (_, v) =>
if (v == null) {
hasNull = true
} else {
smallestSet += v
})
bigger.foreach(elementType, (_, v1) =>
if (v1 == null) {
hasNull = true
} else if (smallestSet.contains(v1)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't work with BinaryType(the data is byte[]). We may need to wrap values with ByteBuffer first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually it was not working also with ArrayType, so I addressed the problem in a more general way which supports both these cases. Thanks.

return true
}
)
}
if (hasNull) {
null
} else {
false
}
}

/**
* A slower evaluation which performs a nested loop and supports all the data types.
*/
private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = {
var hasNull = false
if (arr1.numElements() > 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also compare the numElements here and check the array is empty for the smaller one? Otherwise the result is different if the arr1 is not empty and contains null and arr2 is empty?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I added also some test cases for this

arr1.foreach(elementType, (_, v1) =>
if (v1 == null) {
hasNull = true
} else {
arr2.foreach(elementType, (_, v2) =>
if (v1 == null) {
hasNull = true
} else if (ordering.equiv(v1, v2)) {
return true
}
)
})
}
if (hasNull) {
null
} else {
false
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (a1, a2) => {
val smaller = ctx.freshName("smallerArray")
val bigger = ctx.freshName("biggerArray")
val comparisonCode = if (elementTypeSupportEquals) {
fastCodegen(ctx, ev, smaller, bigger)
} else {
bruteForceCodegen(ctx, ev, smaller, bigger)
}
s"""
|ArrayData $smaller;
|ArrayData $bigger;
|if ($a1.numElements() > $a2.numElements()) {
| $bigger = $a1;
| $smaller = $a2;
|} else {
| $smaller = $a1;
| $bigger = $a2;
|}
|if ($smaller.numElements() > 0) {
| $comparisonCode
|}
Copy link
Contributor

@cloud-fan cloud-fan May 14, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if smaller is empty, we should return false. However, the code here depends on the initial value of ev.value and ev.isNull, which, according to nullSafeCodeGen, depends on nullable.

Copy link
Contributor Author

@mgaido91 mgaido91 May 14, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but ev.isNull anyway is it initiated to false, unless one of the input is null. And in that case we don't even reach this point because we just return null.

ev.value is initiated always to the defaultValue which is false. So when we arrive here we are sure that they are both false.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i see, sorry I misread the code in nullSafeCodeGen

Copy link
Contributor Author

@mgaido91 mgaido91 May 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

np, thanks for checking it. Always better one check more than one less.

""".stripMargin
})
}

/**
* Code generation for a fast implementation which puts all the elements from the smaller array
* in a set and then performs a lookup on it for each element of the bigger one.
* It works only for data types which implements properly the equals method.
*/
private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = {
val i = ctx.freshName("i")
val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i)
val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
val javaElementClass = CodeGenerator.boxedType(elementType)
val javaSet = classOf[java.util.HashSet[_]].getName
val set = ctx.freshName("set")
val addToSetFromSmallerCode = nullSafeElementCodegen(
smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;")
val elementIsInSetCode = nullSafeElementCodegen(
bigger,
i,
s"""
|if ($set.contains($getFromBigger)) {
| ${ev.isNull} = false;
| ${ev.value} = true;
| break;
|}
""".stripMargin,
s"${ev.isNull} = true;")
s"""
|$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>();
|for (int $i = 0; $i < $smaller.numElements(); $i ++) {
| $addToSetFromSmallerCode
|}
|for (int $i = 0; $i < $bigger.numElements(); $i ++) {
| $elementIsInSetCode
|}
""".stripMargin
}

/**
* Code generation for a slower evaluation which performs a nested loop and supports all the data types.
*/
private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = {
val i = ctx.freshName("i")
val j = ctx.freshName("j")
val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j)
val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
val compareValues = nullSafeElementCodegen(
smaller,
j,
s"""
|if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) {
| ${ev.isNull} = false;
| ${ev.value} = true;
| break;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we wanna break 2 loops here, that's why we generate 2 breaks. it might be cleaner to generate

for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) {
  for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) {
    ...
  }
}

|}
""".stripMargin,
s"${ev.isNull} = true;")
val isInSmaller = nullSafeElementCodegen(
bigger,
i,
s"""
|for (int $j = 0; $j < $smaller.numElements(); $j ++) {
| $compareValues
| if (${ev.value}) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this if?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was in the wrong place, thanks, nice catch!

| break;
| }
|}
""".stripMargin,
s"${ev.isNull} = true;")
s"""
|for (int $i = 0; $i < $bigger.numElements(); $i ++) {
|$isInSmaller
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 space

|}
""".stripMargin
}

def nullSafeElementCodegen(
arrayVar: String,
index: String,
code: String,
isNullCode: String): String = {
if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this depend on whether the input array arrayVar contains null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately we don't know which one we have here (the left or the rigth) as arrayVar, since we don't know which one is the smaller/bigger and this can change record to record. So we can skip the null check only if both them don't contain null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see, makes sense!

s"""
|if ($arrayVar.isNullAt($index)) {
| $isNullCode
|} else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

| $code
|}
""".stripMargin
} else {
code
}
}

override def prettyName: String = "arrays_overlap"
}

/**
* Slices an array according to the requested start index and length
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,60 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("ArraysOverlap") {
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType))
val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType))
val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType))
val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType))
val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType))

val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType))

checkEvaluation(ArraysOverlap(a0, a1), true)
checkEvaluation(ArraysOverlap(a0, a2), null)
checkEvaluation(ArraysOverlap(a1, a2), true)
checkEvaluation(ArraysOverlap(a1, a3), false)
checkEvaluation(ArraysOverlap(a0, emptyIntArray), false)
checkEvaluation(ArraysOverlap(a2, emptyIntArray), false)
checkEvaluation(ArraysOverlap(emptyIntArray, a2), false)

checkEvaluation(ArraysOverlap(a4, a5), true)
checkEvaluation(ArraysOverlap(a4, a6), null)
checkEvaluation(ArraysOverlap(a5, a6), false)

// null handling
checkEvaluation(ArraysOverlap(emptyIntArray, a2), false)
checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null)
checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null)
checkEvaluation(ArraysOverlap(
Literal.create(Seq(null), ArrayType(IntegerType)),
Literal.create(Seq(null), ArrayType(IntegerType))), null)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if arrays_overlap(array(), array(null))?
Seems like Presto returns false for the case. TestArrayOperators.java#L1041
Also can you add the test case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am returning null for it. This is interesting. I checked Presto's implementation and it returns false if any of the input arrays is empty. I am copying Presto's behavior but this is quite against what the docs say:

Returns null if there are no non-null elements in common but either array contains null.

I will add a sentence to clarify the behavior in our docs. Thanks for this nice catch!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have a test case for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is covered by https://github.com/apache/spark/pull/21028/files#diff-d31eca9f1c4c33104dc2cb8950486910R163 for instance. Anyway, I am adding another on which is exactly this one.


// arrays of binaries
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)),
ArrayType(BinaryType))
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
ArrayType(BinaryType))
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
ArrayType(BinaryType))

checkEvaluation(ArraysOverlap(b0, b1), true)
checkEvaluation(ArraysOverlap(b0, b2), false)

// arrays of complex data types
val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")),
ArrayType(ArrayType(StringType)))
val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")),
ArrayType(ArrayType(StringType)))
val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")),
ArrayType(ArrayType(StringType)))

checkEvaluation(ArraysOverlap(aa0, aa1), true)
checkEvaluation(ArraysOverlap(aa0, aa2), false)
}

test("Slice") {
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
Expand Down
Loading