Skip to content

Commit

Permalink
[SPARK-23922][SQL] Add arrays_overlap function
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Apr 10, 2018
1 parent 6498884 commit e5ebdad
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 0 deletions.
14 changes: 14 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,20 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(2.4)
def arrays_overlap(a1, a2):
"""
Collection function: returns true if the arrays contain any common non-null element; if not, returns
null if any of the arrays contains a null element and false otherwise.
>>> df = spark.createDataFrame([(["a", "b", "c"], ["c", "d", "e"]), (["a"], ["b", "c"])], ['x', 'y'])
>>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
[Row(overlap=True), Row(overlap=False)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2)))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArraysOverlap]("arrays_overlap"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[MapKeys]("map_keys"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,112 @@ case class ArrayContains(left: Expression, right: Expression)

override def prettyName: String = "array_contains"
}


/**
* Checks if the two arrays contain at least one common element.
*/
@ExpressionDescription(
usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least an element present also in a2.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
true
""", since = "2.4.0")
case class ArraysOverlap(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

private lazy val elementType = inputTypes.head.asInstanceOf[ArrayType].elementType

override def dataType: DataType = BooleanType

override def inputTypes: Seq[AbstractDataType] = left.dataType match {
case la: ArrayType if la.sameType(right.dataType) =>
Seq(la, la)
case _ => Seq.empty
}

override def checkInputDataTypes(): TypeCheckResult = {
if (!left.dataType.isInstanceOf[ArrayType] || !right.dataType.isInstanceOf[ArrayType] ||
!left.dataType.sameType(right.dataType)) {
TypeCheckResult.TypeCheckFailure("Arguments must be arrays with the same element type.")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def nullable: Boolean = {
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
right.dataType.asInstanceOf[ArrayType].containsNull
}

override def nullSafeEval(a1: Any, a2: Any): Any = {
var hasNull = false
val arr1 = a1.asInstanceOf[ArrayData]
val arr2 = a2.asInstanceOf[ArrayData]
if (arr1.numElements() > 0) {
arr1.foreach(elementType, (_, v1) =>
if (v1 == null) {
hasNull = true
} else {
arr2.foreach(elementType, (_, v2) =>
if (v2 == null) {
hasNull = true
} else if (v1 == v2) {
return true
}
)
}
)
} else {
arr2.foreach(elementType, (_, v) =>
if (v == null) {
return null
}
)
}
if (hasNull) {
null
} else {
false
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (a1, a2) => {
val i1 = ctx.freshName("i")
val i2 = ctx.freshName("i")
val getValue1 = CodeGenerator.getValue(a1, elementType, i1)
val getValue2 = CodeGenerator.getValue(a2, elementType, i2)
s"""
|if ($a1.numElements() > 0) {
| for (int $i1 = 0; $i1 < $a1.numElements(); $i1 ++) {
| if ($a1.isNullAt($i1)) {
| ${ev.isNull} = true;
| } else {
| for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) {
| if ($a2.isNullAt($i2)) {
| ${ev.isNull} = true;
| } else if (${ctx.genEqual(elementType, getValue1, getValue2)}) {
| ${ev.isNull} = false;
| ${ev.value} = true;
| break;
| }
| }
| if (${ev.value}) {
| break;
| }
| }
| }
|} else {
| for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) {
| if ($a2.isNullAt($i2)) {
| ${ev.isNull} = true;
| break;
| }
| }
|}
|""".stripMargin
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,29 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("ArraysOverlap") {
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType))
val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType))
val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType))
val a4 = Literal.create(Seq.empty[Int], ArrayType(IntegerType))

val a5 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a6 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType))
val a7 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType))

checkEvaluation(ArraysOverlap(a0, a1), true)
checkEvaluation(ArraysOverlap(a0, a2), null)
checkEvaluation(ArraysOverlap(a1, a2), true)
checkEvaluation(ArraysOverlap(a1, a3), false)
checkEvaluation(ArraysOverlap(a0, a4), false)
checkEvaluation(ArraysOverlap(a2, a4), null)
checkEvaluation(ArraysOverlap(a4, a2), null)

checkEvaluation(ArraysOverlap(a5, a6), true)
checkEvaluation(ArraysOverlap(a5, a7), null)
checkEvaluation(ArraysOverlap(a6, a7), false)

}
}
10 changes: 10 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3046,6 +3046,16 @@ object functions {
ArrayContains(column.expr, Literal(value))
}

/**
* Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and
* any of the arrays contains a `null`, it returns `null`. It returns `false` otherwise.
* @group collection_funcs
* @since 2.4.0
*/
def arrays_overlap(a1: Column, a2: Column): Column = withExpr {
ArraysOverlap(a1.expr, a2.expr)
}

/**
* Creates a new row for each element in the given array or map column.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("arrays_overlap function") {
val df = Seq(
(Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))),
(Seq.empty[Option[Int]], Seq[Option[Int]](Some(-1), None)),
(Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2)))
).toDF("a", "b")

val answer = Seq(Row(false), Row(null), Row(true))

checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer)
checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer)

intercept[AnalysisException] {
df.selectExpr("arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))")
}
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down

0 comments on commit e5ebdad

Please sign in to comment.