From e5ebdad41645c0058f1cd2788f6cc1d4158ff2e9 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 10 Apr 2018 15:49:53 +0200 Subject: [PATCH] [SPARK-23922][SQL] Add arrays_overlap function --- python/pyspark/sql/functions.py | 14 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 109 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 25 ++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 +++ 6 files changed, 176 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b192680f0795..88bed09c563fe 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1846,6 +1846,20 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def arrays_overlap(a1, a2): + """ + Collection function: returns true if the arrays contain any common non-null element; if not, returns + null if any of the arrays contains a null element and false otherwise. + + >>> df = spark.createDataFrame([(["a", "b", "c"], ["c", "d", "e"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..e672b9f7063c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -402,6 +402,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArraysOverlap]("arrays_overlap"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 91188da8b0bd3..c3e78935386f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -287,3 +287,112 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Checks if the two arrays contain at least one common element. + */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least an element present also in a2.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); + true + """, since = "2.4.0") +case class ArraysOverlap(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + private lazy val elementType = inputTypes.head.asInstanceOf[ArrayType].elementType + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = left.dataType match { + case la: ArrayType if la.sameType(right.dataType) => + Seq(la, la) + case _ => Seq.empty + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (!left.dataType.isInstanceOf[ArrayType] || !right.dataType.isInstanceOf[ArrayType] || + !left.dataType.sameType(right.dataType)) { + TypeCheckResult.TypeCheckFailure("Arguments must be arrays with the same element type.") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(a1: Any, a2: Any): Any = { + var hasNull = false + val arr1 = a1.asInstanceOf[ArrayData] + val arr2 = a2.asInstanceOf[ArrayData] + if (arr1.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v2 == null) { + hasNull = true + } else if (v1 == v2) { + return true + } + ) + } + ) + } else { + arr2.foreach(elementType, (_, v) => + if (v == null) { + return null + } + ) + } + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (a1, a2) => { + val i1 = ctx.freshName("i") + val i2 = ctx.freshName("i") + val getValue1 = CodeGenerator.getValue(a1, elementType, i1) + val getValue2 = CodeGenerator.getValue(a2, elementType, i2) + s""" + |if ($a1.numElements() > 0) { + | for (int $i1 = 0; $i1 < $a1.numElements(); $i1 ++) { + | if ($a1.isNullAt($i1)) { + | ${ev.isNull} = true; + | } else { + | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { + | if ($a2.isNullAt($i2)) { + | ${ev.isNull} = true; + | } else if (${ctx.genEqual(elementType, getValue1, getValue2)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + | } + | } + | if (${ev.value}) { + | break; + | } + | } + | } + |} else { + | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { + | if ($a2.isNullAt($i2)) { + | ${ev.isNull} = true; + | break; + | } + | } + |} + |""".stripMargin + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..2e93d6f2533f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,29 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("ArraysOverlap") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) + val a4 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + val a5 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a7 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + + checkEvaluation(ArraysOverlap(a0, a1), true) + checkEvaluation(ArraysOverlap(a0, a2), null) + checkEvaluation(ArraysOverlap(a1, a2), true) + checkEvaluation(ArraysOverlap(a1, a3), false) + checkEvaluation(ArraysOverlap(a0, a4), false) + checkEvaluation(ArraysOverlap(a2, a4), null) + checkEvaluation(ArraysOverlap(a4, a2), null) + + checkEvaluation(ArraysOverlap(a5, a6), true) + checkEvaluation(ArraysOverlap(a5, a7), null) + checkEvaluation(ArraysOverlap(a6, a7), false) + + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c658f25ced053..3c580ba6f6ceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3046,6 +3046,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and + * any of the arrays contains a `null`, it returns `null`. It returns `false` otherwise. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_overlap(a1: Column, a2: Column): Column = withExpr { + ArraysOverlap(a1.expr, a2.expr) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..4fcf8681e4bd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("arrays_overlap function") { + val df = Seq( + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), + (Seq.empty[Option[Int]], Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) + ).toDF("a", "b") + + val answer = Seq(Row(false), Row(null), Row(true)) + + checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) + checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + + intercept[AnalysisException] { + df.selectExpr("arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {