Skip to content

Commit

Permalink
[SPARK-23922][SQL] Add arrays_overlap function
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

The PR adds the function `arrays_overlap`. This function returns `true` if the input arrays contain a non-null common element; if not, it returns `null` if any of the arrays contains a `null` element, `false` otherwise.

## How was this patch tested?

added UTs

Author: Marco Gaido <[email protected]>

Closes #21028 from mgaido91/SPARK-23922.
  • Loading branch information
mgaido91 authored and cloud-fan committed May 17, 2018
1 parent 6ec0582 commit 69350aa
Show file tree
Hide file tree
Showing 6 changed files with 388 additions and 1 deletion.
15 changes: 15 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,21 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(2.4)
def arrays_overlap(a1, a2):
"""
Collection function: returns true if the arrays contain any common non-null element; if not,
returns null if both the arrays are non-empty and any of them contains a null element; returns
false otherwise.
>>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
>>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
[Row(overlap=True), Row(overlap=False)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2)))


@since(2.4)
def slice(x, start, length):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArraysOverlap]("arrays_overlap"),
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,51 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.Comparator

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
* casting.
*/
trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
with ImplicitCastInputTypes {

@transient protected lazy val elementType: DataType =
inputTypes.head.asInstanceOf[ArrayType].elementType

override def inputTypes: Seq[AbstractDataType] = {
(left.dataType, right.dataType) match {
case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) =>
TypeCoercion.findTightestCommonType(e1, e2) match {
case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2))
case _ => Seq.empty
}
case _ => Seq.empty
}
}

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) =>
TypeCheckResult.TypeCheckSuccess
case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " +
s"been two ${ArrayType.simpleString}s with same element type, but it's " +
s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]")
}
}
}


/**
* Given an array or map, returns its size. Returns -1 if null.
*/
Expand Down Expand Up @@ -529,6 +565,235 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}

/**
* Checks if the two arrays contain at least one common element.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
true
""", since = "2.4.0")
// scalastyle:off line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
extends BinaryArrayExpressionWithImplicitCast {

override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (RowOrdering.isOrderable(elementType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.")
}
case failure => failure
}

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)

@transient private lazy val elementTypeSupportEquals = elementType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

@transient private lazy val doEvaluation = if (elementTypeSupportEquals) {
fastEval _
} else {
bruteForceEval _
}

override def dataType: DataType = BooleanType

override def nullable: Boolean = {
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
right.dataType.asInstanceOf[ArrayType].containsNull
}

override def nullSafeEval(a1: Any, a2: Any): Any = {
doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData])
}

/**
* A fast implementation which puts all the elements from the smaller array in a set
* and then performs a lookup on it for each element of the bigger one.
* This eval mode works only for data types which implements properly the equals method.
*/
private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = {
var hasNull = false
val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) {
(arr1, arr2)
} else {
(arr2, arr1)
}
if (smaller.numElements() > 0) {
val smallestSet = new mutable.HashSet[Any]
smaller.foreach(elementType, (_, v) =>
if (v == null) {
hasNull = true
} else {
smallestSet += v
})
bigger.foreach(elementType, (_, v1) =>
if (v1 == null) {
hasNull = true
} else if (smallestSet.contains(v1)) {
return true
}
)
}
if (hasNull) {
null
} else {
false
}
}

/**
* A slower evaluation which performs a nested loop and supports all the data types.
*/
private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = {
var hasNull = false
if (arr1.numElements() > 0 && arr2.numElements() > 0) {
arr1.foreach(elementType, (_, v1) =>
if (v1 == null) {
hasNull = true
} else {
arr2.foreach(elementType, (_, v2) =>
if (v2 == null) {
hasNull = true
} else if (ordering.equiv(v1, v2)) {
return true
}
)
})
}
if (hasNull) {
null
} else {
false
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (a1, a2) => {
val smaller = ctx.freshName("smallerArray")
val bigger = ctx.freshName("biggerArray")
val comparisonCode = if (elementTypeSupportEquals) {
fastCodegen(ctx, ev, smaller, bigger)
} else {
bruteForceCodegen(ctx, ev, smaller, bigger)
}
s"""
|ArrayData $smaller;
|ArrayData $bigger;
|if ($a1.numElements() > $a2.numElements()) {
| $bigger = $a1;
| $smaller = $a2;
|} else {
| $smaller = $a1;
| $bigger = $a2;
|}
|if ($smaller.numElements() > 0) {
| $comparisonCode
|}
""".stripMargin
})
}

/**
* Code generation for a fast implementation which puts all the elements from the smaller array
* in a set and then performs a lookup on it for each element of the bigger one.
* It works only for data types which implements properly the equals method.
*/
private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = {
val i = ctx.freshName("i")
val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i)
val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
val javaElementClass = CodeGenerator.boxedType(elementType)
val javaSet = classOf[java.util.HashSet[_]].getName
val set = ctx.freshName("set")
val addToSetFromSmallerCode = nullSafeElementCodegen(
smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;")
val elementIsInSetCode = nullSafeElementCodegen(
bigger,
i,
s"""
|if ($set.contains($getFromBigger)) {
| ${ev.isNull} = false;
| ${ev.value} = true;
| break;
|}
""".stripMargin,
s"${ev.isNull} = true;")
s"""
|$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>();
|for (int $i = 0; $i < $smaller.numElements(); $i ++) {
| $addToSetFromSmallerCode
|}
|for (int $i = 0; $i < $bigger.numElements(); $i ++) {
| $elementIsInSetCode
|}
""".stripMargin
}

/**
* Code generation for a slower evaluation which performs a nested loop and supports all the data types.
*/
private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = {
val i = ctx.freshName("i")
val j = ctx.freshName("j")
val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j)
val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
val compareValues = nullSafeElementCodegen(
smaller,
j,
s"""
|if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) {
| ${ev.isNull} = false;
| ${ev.value} = true;
|}
""".stripMargin,
s"${ev.isNull} = true;")
val isInSmaller = nullSafeElementCodegen(
bigger,
i,
s"""
|for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) {
| $compareValues
|}
""".stripMargin,
s"${ev.isNull} = true;")
s"""
|for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) {
| $isInSmaller
|}
""".stripMargin
}

def nullSafeElementCodegen(
arrayVar: String,
index: String,
code: String,
isNullCode: String): String = {
if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
s"""
|if ($arrayVar.isNullAt($index)) {
| $isNullCode
|} else {
| $code
|}
""".stripMargin
} else {
code
}
}

override def prettyName: String = "arrays_overlap"
}

/**
* Slices an array according to the requested start index and length
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("ArraysOverlap") {
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType))
val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType))
val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType))
val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType))
val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType))

val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType))

checkEvaluation(ArraysOverlap(a0, a1), true)
checkEvaluation(ArraysOverlap(a0, a2), null)
checkEvaluation(ArraysOverlap(a1, a2), true)
checkEvaluation(ArraysOverlap(a1, a3), false)
checkEvaluation(ArraysOverlap(a0, emptyIntArray), false)
checkEvaluation(ArraysOverlap(a2, emptyIntArray), false)
checkEvaluation(ArraysOverlap(emptyIntArray, a2), false)

checkEvaluation(ArraysOverlap(a4, a5), true)
checkEvaluation(ArraysOverlap(a4, a6), null)
checkEvaluation(ArraysOverlap(a5, a6), false)

// null handling
checkEvaluation(ArraysOverlap(emptyIntArray, a2), false)
checkEvaluation(ArraysOverlap(
emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false)
checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null)
checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null)
checkEvaluation(ArraysOverlap(
Literal.create(Seq(null), ArrayType(IntegerType)),
Literal.create(Seq(null), ArrayType(IntegerType))), null)

// arrays of binaries
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)),
ArrayType(BinaryType))
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
ArrayType(BinaryType))
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
ArrayType(BinaryType))

checkEvaluation(ArraysOverlap(b0, b1), true)
checkEvaluation(ArraysOverlap(b0, b2), false)

// arrays of complex data types
val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")),
ArrayType(ArrayType(StringType)))
val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")),
ArrayType(ArrayType(StringType)))
val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")),
ArrayType(ArrayType(StringType)))

checkEvaluation(ArraysOverlap(aa0, aa1), true)
checkEvaluation(ArraysOverlap(aa0, aa2), false)

// null handling with complex datatypes
val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType))
val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType))
checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false)
checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false)
checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false)
checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false)
checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null)
checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null)
}

test("Slice") {
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
Expand Down
Loading

0 comments on commit 69350aa

Please sign in to comment.