From e5ebdad41645c0058f1cd2788f6cc1d4158ff2e9 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 10 Apr 2018 15:49:53 +0200 Subject: [PATCH 01/22] [SPARK-23922][SQL] Add arrays_overlap function --- python/pyspark/sql/functions.py | 14 +++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 109 ++++++++++++++++++ .../CollectionExpressionsSuite.scala | 25 ++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 +++ 6 files changed, 176 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b192680f0795..88bed09c563fe 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1846,6 +1846,20 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def arrays_overlap(a1, a2): + """ + Collection function: returns true if the arrays contain any common non-null element; if not, returns + null if any of the arrays contains a null element and false otherwise. + + >>> df = spark.createDataFrame([(["a", "b", "c"], ["c", "d", "e"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + + @since(1.4) def explode(col): """Returns a new row for each element in the given array or map. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 747016beb06e7..e672b9f7063c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -402,6 +402,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArraysOverlap]("arrays_overlap"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 91188da8b0bd3..c3e78935386f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -287,3 +287,112 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Checks if the two arrays contain at least one common element. 
+ */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least an element present also in a2.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); + true + """, since = "2.4.0") +case class ArraysOverlap(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + private lazy val elementType = inputTypes.head.asInstanceOf[ArrayType].elementType + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = left.dataType match { + case la: ArrayType if la.sameType(right.dataType) => + Seq(la, la) + case _ => Seq.empty + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (!left.dataType.isInstanceOf[ArrayType] || !right.dataType.isInstanceOf[ArrayType] || + !left.dataType.sameType(right.dataType)) { + TypeCheckResult.TypeCheckFailure("Arguments must be arrays with the same element type.") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(a1: Any, a2: Any): Any = { + var hasNull = false + val arr1 = a1.asInstanceOf[ArrayData] + val arr2 = a2.asInstanceOf[ArrayData] + if (arr1.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v2 == null) { + hasNull = true + } else if (v1 == v2) { + return true + } + ) + } + ) + } else { + arr2.foreach(elementType, (_, v) => + if (v == null) { + return null + } + ) + } + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (a1, a2) => { + val i1 = ctx.freshName("i") + val i2 = ctx.freshName("i") + val getValue1 = CodeGenerator.getValue(a1, elementType, i1) + val getValue2 = CodeGenerator.getValue(a2, elementType, i2) + s""" + |if ($a1.numElements() > 0) { + | for (int $i1 = 0; $i1 < $a1.numElements(); $i1 ++) { + | if ($a1.isNullAt($i1)) { + | ${ev.isNull} = true; + | } else { + | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { + | if ($a2.isNullAt($i2)) { + | ${ev.isNull} = true; + | } else if (${ctx.genEqual(elementType, getValue1, getValue2)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + | } + | } + | if (${ev.value}) { + | break; + | } + | } + | } + |} else { + | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { + | if ($a2.isNullAt($i2)) { + | ${ev.isNull} = true; + | break; + | } + | } + |} + |""".stripMargin + }) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 020687e4b3a27..2e93d6f2533f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -105,4 +105,29 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("ArraysOverlap") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq(4, 5, 3), 
ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) + val a4 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + val a5 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a7 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + + checkEvaluation(ArraysOverlap(a0, a1), true) + checkEvaluation(ArraysOverlap(a0, a2), null) + checkEvaluation(ArraysOverlap(a1, a2), true) + checkEvaluation(ArraysOverlap(a1, a3), false) + checkEvaluation(ArraysOverlap(a0, a4), false) + checkEvaluation(ArraysOverlap(a2, a4), null) + checkEvaluation(ArraysOverlap(a4, a2), null) + + checkEvaluation(ArraysOverlap(a5, a6), true) + checkEvaluation(ArraysOverlap(a5, a7), null) + checkEvaluation(ArraysOverlap(a6, a7), false) + + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c658f25ced053..3c580ba6f6ceb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3046,6 +3046,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and + * any of the arrays contains a `null`, it returns `null`. It returns `false` otherwise. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_overlap(a1: Column, a2: Column): Column = withExpr { + ArraysOverlap(a1.expr, a2.expr) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e475984f458..4fcf8681e4bd0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("arrays_overlap function") { + val df = Seq( + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), + (Seq.empty[Option[Int]], Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) + ).toDF("a", "b") + + val answer = Seq(Row(false), Row(null), Row(true)) + + checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) + checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + + intercept[AnalysisException] { + df.selectExpr("arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + } + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 682bc7327ce4e4442ae5d1ddfe662f6b5dc99593 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 10 Apr 2018 17:19:47 +0200 Subject: [PATCH 02/22] fix python style --- python/pyspark/sql/functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 88bed09c563fe..3757afbd033d9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1849,10 +1849,10 @@ def array_contains(col, value): @since(2.4) def 
arrays_overlap(a1, a2): """ - Collection function: returns true if the arrays contain any common non-null element; if not, returns - null if any of the arrays contains a null element and false otherwise. + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if any of the arrays contains a null element and false otherwise. - >>> df = spark.createDataFrame([(["a", "b", "c"], ["c", "d", "e"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() [Row(overlap=True), Row(overlap=False)] """ From 88e09b3e9ebfc6b162a1d7403592c34c5fb1f71f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 20 Apr 2018 17:38:08 +0200 Subject: [PATCH 03/22] review comments --- .../expressions/collectionOperations.scala | 80 +++++++++++++------ .../CollectionExpressionsSuite.scala | 7 ++ 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a43376c0a66d6..48135fceb6987 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -291,13 +291,15 @@ case class ArrayContains(left: Expression, right: Expression) /** * Checks if the two arrays contain at least one common element. */ +// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least an element present also in a2.", + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least an element present also in a2. 
If the arrays have no common element and either of them contains a null element null is returned, false otherwise.", examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); true """, since = "2.4.0") +// scalastyle:off line.size.limit case class ArraysOverlap(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -307,7 +309,7 @@ case class ArraysOverlap(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = left.dataType match { case la: ArrayType if la.sameType(right.dataType) => - Seq(la, la) + Seq(la, right.dataType) case _ => Seq.empty } @@ -363,37 +365,63 @@ case class ArraysOverlap(left: Expression, right: Expression) val i2 = ctx.freshName("i") val getValue1 = CodeGenerator.getValue(a1, elementType, i1) val getValue2 = CodeGenerator.getValue(a2, elementType, i2) + val leftEmptyCode = if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |else { + | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { + | if ($a2.isNullAt($i2)) { + | ${ev.isNull} = true; + | break; + | } + | } + |} + """.stripMargin + } else { + "" + } s""" |if ($a1.numElements() > 0) { | for (int $i1 = 0; $i1 < $a1.numElements(); $i1 ++) { - | if ($a1.isNullAt($i1)) { - | ${ev.isNull} = true; - | } else { - | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { - | if ($a2.isNullAt($i2)) { - | ${ev.isNull} = true; - | } else if (${ctx.genEqual(elementType, getValue1, getValue2)}) { - | ${ev.isNull} = false; - | ${ev.value} = true; - | break; - | } - | } - | if (${ev.value}) { - | break; - | } - | } + | ${nullSafeElementCodegen(left.dataType.asInstanceOf[ArrayType], a1, i1, + s""" + |for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { + | ${nullSafeElementCodegen(right.dataType.asInstanceOf[ArrayType], a2, i2, + s""" + |if (${ctx.genEqual(elementType, getValue1, getValue2)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin, s"${ev.isNull} = true;")} + |} + |if (${ev.value}) { + | break; + |} + """.stripMargin, s"${ev.isNull} = true;")} | } - |} else { - | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { - | if ($a2.isNullAt($i2)) { - | ${ev.isNull} = true; - | break; - | } - | } - |} + |} $leftEmptyCode |""".stripMargin }) } + + def nullSafeElementCodegen( + arrayType: ArrayType, + arrayVar: String, + index: String, + code: String, + isNullCode: String): String = { + if (arrayType.containsNull) { + s""" + |if ($arrayVar.isNullAt($index)) { + | $isNullCode + |} else { + | $code + |} + """.stripMargin + } else { + code + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index b159f56204909..97a5a50c38752 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -128,6 +128,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArraysOverlap(a5, a6), true) checkEvaluation(ArraysOverlap(a5, a7), null) checkEvaluation(ArraysOverlap(a6, a7), false) + + // null handling + checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) + checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) + checkEvaluation(ArraysOverlap( + 
Literal.create(Seq(null), ArrayType(IntegerType)), + Literal.create(Seq(null), ArrayType(IntegerType))), null) } test("Array Min") { From 65b7d6d4ce998b026e580f48c1a7467415c9c261 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Apr 2018 16:36:12 +0200 Subject: [PATCH 04/22] introduce BinaryArrayExpressionWithImplicitCast --- .../expressions/collectionOperations.scala | 47 +++++++++++-------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2fd0bdb7e21d8..7b3188321148a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ @@ -27,6 +27,32 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +/** + * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit + * casting. + */ +trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression + with ImplicitCastInputTypes { + + protected lazy val elementType: DataType = inputTypes.head.asInstanceOf[ArrayType].elementType + + override def inputTypes: Seq[AbstractDataType] = { + TypeCoercion.findWiderTypeForTwo(left.dataType, right.dataType) match { + case Some(arrayType) => Seq(arrayType, arrayType) + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + TypeCoercion.findWiderTypeForTwo(left.dataType, right.dataType) match { + case Some(ArrayType(_, _)) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + + s"been two ${ArrayType.simpleString}s with same element type, but it's " + + s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") + } + } +} + + /** * Given an array or map, returns its size. Returns -1 if null. 
*/ @@ -391,27 +417,10 @@ case class ArrayContains(left: Expression, right: Expression) """, since = "2.4.0") // scalastyle:off line.size.limit case class ArraysOverlap(left: Expression, right: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - private lazy val elementType = inputTypes.head.asInstanceOf[ArrayType].elementType + extends BinaryArrayExpressionWithImplicitCast { override def dataType: DataType = BooleanType - override def inputTypes: Seq[AbstractDataType] = left.dataType match { - case la: ArrayType if la.sameType(right.dataType) => - Seq(la, right.dataType) - case _ => Seq.empty - } - - override def checkInputDataTypes(): TypeCheckResult = { - if (!left.dataType.isInstanceOf[ArrayType] || !right.dataType.isInstanceOf[ArrayType] || - !left.dataType.sameType(right.dataType)) { - TypeCheckResult.TypeCheckFailure("Arguments must be arrays with the same element type.") - } else { - TypeCheckResult.TypeCheckSuccess - } - } - override def nullable: Boolean = { left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || right.dataType.asInstanceOf[ArrayType].containsNull From 1dbcd0c68171ee5375e54a320b4314741a135fbd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Apr 2018 17:31:46 +0200 Subject: [PATCH 05/22] fix type check --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 1 + .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5d3f836c6cbe6..00e5ddc8c4594 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -39,6 +39,7 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression override def inputTypes: Seq[AbstractDataType] = { TypeCoercion.findWiderTypeForTwo(left.dataType, right.dataType) match { case Some(arrayType) => Seq(arrayType, arrayType) + case None => Seq.empty } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b4a8ce9aa80b8..833f7f2ee65bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -425,8 +425,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + checkAnswer(sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))"), Row(false)) + intercept[AnalysisException] { - df.selectExpr("arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + sql("select arrays_overlap(array(array(1)), array('a'))") } } From 076fc698d4054b757e5afb14d1d6bfc190c2c6f7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Apr 2018 17:52:12 +0200 Subject: [PATCH 06/22] fix scalastyle --- .../sql/catalyst/expressions/CollectionExpressionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index e7fbb51d91af6..e170f9d7d27da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -136,7 +136,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Literal.create(Seq(null), ArrayType(IntegerType)), Literal.create(Seq(null), ArrayType(IntegerType))), null) } - + test("ArrayJoin") { def testArrays( arrays: Seq[Expression], From eafca0f1cc969f66c569215508d4c687b5527305 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Apr 2018 18:33:41 +0200 Subject: [PATCH 07/22] fix build error --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 00e5ddc8c4594..4661454348418 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -39,7 +39,7 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression override def inputTypes: Seq[AbstractDataType] = { TypeCoercion.findWiderTypeForTwo(left.dataType, right.dataType) match { case Some(arrayType) => Seq(arrayType, arrayType) - case None => Seq.empty + case _ => Seq.empty } } From 592510461622cd8eccd6f93af2e1fdbc0521fb98 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 27 Apr 2018 19:00:53 +0200 Subject: [PATCH 08/22] fix --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4661454348418..664e83569342e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -46,7 +46,7 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression override def checkInputDataTypes(): TypeCheckResult = { TypeCoercion.findWiderTypeForTwo(left.dataType, right.dataType) match { case Some(ArrayType(_, _)) => TypeCheckResult.TypeCheckSuccess - case None => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + + case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + s"been two ${ArrayType.simpleString}s with same element type, but it's " + s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") } From 2a1121c8699b5c1fb2d1566338bd11d0bb70caa2 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 3 May 2018 12:17:47 +0200 Subject: [PATCH 09/22] address comments --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 
664e83569342e..26d200e9dd3f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -445,7 +445,7 @@ case class ArraysOverlap(left: Expression, right: Expression) ) } ) - } else { + } else if (right.dataType.asInstanceOf[ArrayType].containsNull) { arr2.foreach(elementType, (_, v) => if (v == null) { return null @@ -522,6 +522,8 @@ case class ArraysOverlap(left: Expression, right: Expression) code } } + + override def prettyName: String = "arrays_overlap" } /** From bf81e4a6c842f5196d501f0891956bfdaab4fbfb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 3 May 2018 15:37:46 +0200 Subject: [PATCH 10/22] use sets instead of nested loops --- .../expressions/collectionOperations.scala | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 26d200e9dd3f0..7663a0dc8728a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -432,25 +432,19 @@ case class ArraysOverlap(left: Expression, right: Expression) val arr1 = a1.asInstanceOf[ArrayData] val arr2 = a2.asInstanceOf[ArrayData] if (arr1.numElements() > 0) { + val set2 = arr2.array.toSet arr1.foreach(elementType, (_, v1) => if (v1 == null) { hasNull = true - } else { - arr2.foreach(elementType, (_, v2) => - if (v2 == null) { - hasNull = true - } else if (v1 == v2) { - return true - } - ) - } - ) - } else if (right.dataType.asInstanceOf[ArrayType].containsNull) { - arr2.foreach(elementType, (_, v) => - if (v == null) { - return null + } else if (set2.contains(v1)) { + return true } ) + if (!hasNull && containsNull(arr2, right.dataType.asInstanceOf[ArrayType])) { + hasNull = true + } + } else if (containsNull(arr2, right.dataType.asInstanceOf[ArrayType])) { + hasNull = true } if (hasNull) { null @@ -459,6 +453,17 @@ case class ArraysOverlap(left: Expression, right: Expression) } } + def containsNull(arr: ArrayData, dt: ArrayType): Boolean = { + if (dt.containsNull) { + arr.foreach(elementType, (_, v) => + if (v == null) { + return true + } + ) + } + false + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (a1, a2) => { val i1 = ctx.freshName("i") @@ -479,25 +484,25 @@ case class ArraysOverlap(left: Expression, right: Expression) } else { "" } + val javaElementClass = CodeGenerator.boxedType(elementType) + val javaSet = classOf[java.util.HashSet[_]].getName + val set2 = ctx.freshName("set") s""" |if ($a1.numElements() > 0) { + | $javaSet<$javaElementClass> $set2 = new $javaSet<$javaElementClass>(); + | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { + | ${nullSafeElementCodegen(right.dataType.asInstanceOf[ArrayType], a2, i2, + s"$set2.add($getValue2);", s"${ev.isNull} = true;")} + | } | for (int $i1 = 0; $i1 < $a1.numElements(); $i1 ++) { | ${nullSafeElementCodegen(left.dataType.asInstanceOf[ArrayType], a1, i1, s""" - |for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { - | ${nullSafeElementCodegen(right.dataType.asInstanceOf[ArrayType], a2, i2, - s""" - |if (${ctx.genEqual(elementType, getValue1, getValue2)}) { - | ${ev.isNull} = false; - | 
${ev.value} = true; - | break; - |} - """.stripMargin, s"${ev.isNull} = true;")} - |} - |if (${ev.value}) { + |if ($set2.contains($getValue1)) { + | ${ev.isNull} = false; + | ${ev.value} = true; | break; |} - """.stripMargin, s"${ev.isNull} = true;")} + |""".stripMargin, s"${ev.isNull} = true;")} | } |} $leftEmptyCode |""".stripMargin From 4a18ba89c4489de5efe34eb31108665474f70f23 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 4 May 2018 11:15:19 +0200 Subject: [PATCH 11/22] address review comments --- .../expressions/collectionOperations.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7663a0dc8728a..c34c5a2f33dc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -44,8 +46,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression } override def checkInputDataTypes(): TypeCheckResult = { - TypeCoercion.findWiderTypeForTwo(left.dataType, right.dataType) match { - case Some(ArrayType(_, _)) => TypeCheckResult.TypeCheckSuccess + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => + TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + s"been two ${ArrayType.simpleString}s with same element type, but it's " + s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") @@ -432,7 +435,13 @@ case class ArraysOverlap(left: Expression, right: Expression) val arr1 = a1.asInstanceOf[ArrayData] val arr2 = a2.asInstanceOf[ArrayData] if (arr1.numElements() > 0) { - val set2 = arr2.array.toSet + val set2 = new mutable.HashSet[Any] + arr2.foreach(elementType, (_, v) => + if (v == null) { + hasNull = true + } else { + set2 += v + }) arr1.foreach(elementType, (_, v1) => if (v1 == null) { hasNull = true @@ -440,9 +449,6 @@ case class ArraysOverlap(left: Expression, right: Expression) return true } ) - if (!hasNull && containsNull(arr2, right.dataType.asInstanceOf[ArrayType])) { - hasNull = true - } } else if (containsNull(arr2, right.dataType.asInstanceOf[ArrayType])) { hasNull = true } From 566946a8bf9d4ca295fabc305b204bb0dd8752f0 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 4 May 2018 13:53:09 +0200 Subject: [PATCH 12/22] address review comments --- .../expressions/collectionOperations.scala | 72 ++++++++++++------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c34c5a2f33dc7..b994d398b5f38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -434,22 +434,27 
@@ case class ArraysOverlap(left: Expression, right: Expression) var hasNull = false val arr1 = a1.asInstanceOf[ArrayData] val arr2 = a2.asInstanceOf[ArrayData] - if (arr1.numElements() > 0) { - val set2 = new mutable.HashSet[Any] - arr2.foreach(elementType, (_, v) => + val (biggestArr, smallestArr) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2) + } else { + (arr2, arr1) + } + if (smallestArr.numElements() > 0) { + val smallestSet = new mutable.HashSet[Any] + smallestArr.foreach(elementType, (_, v) => if (v == null) { hasNull = true } else { - set2 += v + smallestSet += v }) - arr1.foreach(elementType, (_, v1) => + biggestArr.foreach(elementType, (_, v1) => if (v1 == null) { hasNull = true - } else if (set2.contains(v1)) { + } else if (smallestSet.contains(v1)) { return true } ) - } else if (containsNull(arr2, right.dataType.asInstanceOf[ArrayType])) { + } else if (containsNull(biggestArr, right.dataType.asInstanceOf[ArrayType])) { hasNull = true } if (hasNull) { @@ -472,15 +477,16 @@ case class ArraysOverlap(left: Expression, right: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (a1, a2) => { - val i1 = ctx.freshName("i") - val i2 = ctx.freshName("i") - val getValue1 = CodeGenerator.getValue(a1, elementType, i1) - val getValue2 = CodeGenerator.getValue(a2, elementType, i2) + val i = ctx.freshName("i") + val smallestArray = ctx.freshName("smallestArray") + val biggestArray = ctx.freshName("biggestArray") + val getFromSmallest = CodeGenerator.getValue(smallestArray, elementType, i) + val getFromBiggest = CodeGenerator.getValue(biggestArray, elementType, i) val leftEmptyCode = if (right.dataType.asInstanceOf[ArrayType].containsNull) { s""" |else { - | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { - | if ($a2.isNullAt($i2)) { + | for (int $i = 0; $i < $biggestArray.numElements(); $i ++) { + | if ($biggestArray.isNullAt($i)) { | ${ev.isNull} = true; | break; | } @@ -493,22 +499,36 @@ case class ArraysOverlap(left: Expression, right: Expression) val javaElementClass = CodeGenerator.boxedType(elementType) val javaSet = classOf[java.util.HashSet[_]].getName val set2 = ctx.freshName("set") + val addToSetFromSmallestCode = nullSafeElementCodegen(right.dataType.asInstanceOf[ArrayType], + smallestArray, i, s"$set2.add($getFromSmallest);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen(left.dataType.asInstanceOf[ArrayType], + biggestArray, + i, + s""" + |if ($set2.contains($getFromBiggest)) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + |""".stripMargin, + s"${ev.isNull} = true;") s""" - |if ($a1.numElements() > 0) { + |ArrayData $smallestArray; + |ArrayData $biggestArray; + |if ($a1.numElements() > $a2.numElements()) { + | $biggestArray = $a1; + | $smallestArray = $a2; + |} else { + | $smallestArray = $a1; + | $biggestArray = $a2; + |} + |if ($smallestArray.numElements() > 0) { | $javaSet<$javaElementClass> $set2 = new $javaSet<$javaElementClass>(); - | for (int $i2 = 0; $i2 < $a2.numElements(); $i2 ++) { - | ${nullSafeElementCodegen(right.dataType.asInstanceOf[ArrayType], a2, i2, - s"$set2.add($getValue2);", s"${ev.isNull} = true;")} + | for (int $i = 0; $i < $smallestArray.numElements(); $i ++) { + | $addToSetFromSmallestCode | } - | for (int $i1 = 0; $i1 < $a1.numElements(); $i1 ++) { - | ${nullSafeElementCodegen(left.dataType.asInstanceOf[ArrayType], a1, i1, - s""" - |if ($set2.contains($getValue1)) { - | ${ev.isNull} = false; - | ${ev.value} = true; - | break; 
- |} - |""".stripMargin, s"${ev.isNull} = true;")} + | for (int $i = 0; $i < $biggestArray.numElements(); $i ++) { + | $elementIsInSetCode | } |} $leftEmptyCode |""".stripMargin From 710433ea3a69db4d6286aff676a57e475e499a5d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 4 May 2018 20:46:48 +0200 Subject: [PATCH 13/22] add test case for null --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 833f7f2ee65bd..e4a55366eac32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -430,6 +430,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { sql("select arrays_overlap(array(array(1)), array('a'))") } + + intercept[AnalysisException] { + sql("select arrays_overlap(null, null)") + } } test("array_join function") { From 9d086f992215f479576f4337b072a11698c95ac2 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 7 May 2018 15:36:14 +0200 Subject: [PATCH 14/22] address comments --- .../expressions/collectionOperations.scala | 80 ++++++++++--------- .../org/apache/spark/sql/functions.scala | 2 +- 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 3e2c06406fed6..5f70bd8e76328 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -585,27 +585,27 @@ case class ArraysOverlap(left: Expression, right: Expression) var hasNull = false val arr1 = a1.asInstanceOf[ArrayData] val arr2 = a2.asInstanceOf[ArrayData] - val (biggestArr, smallestArr) = if (arr1.numElements() > arr2.numElements()) { - (arr1, arr2) + val (bigger, smaller, biggerDt) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2, left.dataType.asInstanceOf[ArrayType]) } else { - (arr2, arr1) + (arr2, arr1, right.dataType.asInstanceOf[ArrayType]) } - if (smallestArr.numElements() > 0) { + if (smaller.numElements() > 0) { val smallestSet = new mutable.HashSet[Any] - smallestArr.foreach(elementType, (_, v) => + smaller.foreach(elementType, (_, v) => if (v == null) { hasNull = true } else { smallestSet += v }) - biggestArr.foreach(elementType, (_, v1) => + bigger.foreach(elementType, (_, v1) => if (v1 == null) { hasNull = true } else if (smallestSet.contains(v1)) { return true } ) - } else if (containsNull(biggestArr, right.dataType.asInstanceOf[ArrayType])) { + } else if (containsNull(bigger, biggerDt)) { hasNull = true } if (hasNull) { @@ -617,27 +617,30 @@ case class ArraysOverlap(left: Expression, right: Expression) def containsNull(arr: ArrayData, dt: ArrayType): Boolean = { if (dt.containsNull) { - arr.foreach(elementType, (_, v) => - if (v == null) { - return true - } - ) + var i = 0 + var hasNull = false + while (i < arr.numElements && !hasNull) { + hasNull = arr.isNullAt(i) + i += 1 + } + hasNull + } else { + false } - false } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (a1, a2) => { val i = ctx.freshName("i") - 
val smallestArray = ctx.freshName("smallestArray") - val biggestArray = ctx.freshName("biggestArray") - val getFromSmallest = CodeGenerator.getValue(smallestArray, elementType, i) - val getFromBiggest = CodeGenerator.getValue(biggestArray, elementType, i) - val leftEmptyCode = if (right.dataType.asInstanceOf[ArrayType].containsNull) { + val smaller = ctx.freshName("smallerArray") + val bigger = ctx.freshName("biggerArray") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val smallerEmptyCode = if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { s""" |else { - | for (int $i = 0; $i < $biggestArray.numElements(); $i ++) { - | if ($biggestArray.isNullAt($i)) { + | for (int $i = 0; $i < $bigger.numElements(); $i ++) { + | if ($bigger.isNullAt($i)) { | ${ev.isNull} = true; | break; | } @@ -650,13 +653,13 @@ case class ArraysOverlap(left: Expression, right: Expression) val javaElementClass = CodeGenerator.boxedType(elementType) val javaSet = classOf[java.util.HashSet[_]].getName val set2 = ctx.freshName("set") - val addToSetFromSmallestCode = nullSafeElementCodegen(right.dataType.asInstanceOf[ArrayType], - smallestArray, i, s"$set2.add($getFromSmallest);", s"${ev.isNull} = true;") - val elementIsInSetCode = nullSafeElementCodegen(left.dataType.asInstanceOf[ArrayType], - biggestArray, + val addToSetFromSmallerCode = nullSafeElementCodegen( + smaller, i, s"$set2.add($getFromSmaller);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen( + bigger, i, s""" - |if ($set2.contains($getFromBiggest)) { + |if ($set2.contains($getFromBigger)) { | ${ev.isNull} = false; | ${ev.value} = true; | break; @@ -664,42 +667,41 @@ case class ArraysOverlap(left: Expression, right: Expression) |""".stripMargin, s"${ev.isNull} = true;") s""" - |ArrayData $smallestArray; - |ArrayData $biggestArray; + |ArrayData $smaller; + |ArrayData $bigger; |if ($a1.numElements() > $a2.numElements()) { - | $biggestArray = $a1; - | $smallestArray = $a2; + | $bigger = $a1; + | $smaller = $a2; |} else { - | $smallestArray = $a1; - | $biggestArray = $a2; + | $smaller = $a1; + | $bigger = $a2; |} - |if ($smallestArray.numElements() > 0) { + |if ($smaller.numElements() > 0) { | $javaSet<$javaElementClass> $set2 = new $javaSet<$javaElementClass>(); - | for (int $i = 0; $i < $smallestArray.numElements(); $i ++) { - | $addToSetFromSmallestCode + | for (int $i = 0; $i < $smaller.numElements(); $i ++) { + | $addToSetFromSmallerCode | } - | for (int $i = 0; $i < $biggestArray.numElements(); $i ++) { + | for (int $i = 0; $i < $bigger.numElements(); $i ++) { | $elementIsInSetCode | } - |} $leftEmptyCode + |} $smallerEmptyCode |""".stripMargin }) } def nullSafeElementCodegen( - arrayType: ArrayType, arrayVar: String, index: String, code: String, isNullCode: String): String = { - if (arrayType.containsNull) { + if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { s""" |if ($arrayVar.isNullAt($index)) { | $isNullCode |} else { | $code |} - """.stripMargin + |""".stripMargin } else { code } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f772949e6201d..5696a95eb970b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3047,7 +3047,7 @@ object functions { */ def arrays_overlap(a1: Column, a2: Column): Column = withExpr { 
ArraysOverlap(a1.expr, a2.expr) - } + } /** * Returns an array containing all the elements in `x` from index `start` (or starting from the From 964f7af012fca13e86289d11131927d37b74ed41 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 8 May 2018 13:58:44 +0200 Subject: [PATCH 15/22] use findTightestCommonType for type inference --- .../catalyst/expressions/collectionOperations.scala | 11 ++++++++--- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 6 ++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5f70bd8e76328..600d14a20ce56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -36,11 +36,16 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression with ImplicitCastInputTypes { - protected lazy val elementType: DataType = inputTypes.head.asInstanceOf[ArrayType].elementType + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType override def inputTypes: Seq[AbstractDataType] = { - TypeCoercion.findWiderTypeForTwo(left.dataType, right.dataType) match { - case Some(arrayType) => Seq(arrayType, arrayType) + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) + case _ => Seq.empty + } case _ => Seq.empty } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 983738315556f..1490e6a4b5c97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -454,10 +454,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) - checkAnswer(sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))"), Row(false)) + checkAnswer( + Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b").selectExpr("arrays_overlap(a, b)"), + Row(true)) intercept[AnalysisException] { - sql("select arrays_overlap(array(array(1)), array('a'))") + sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") } intercept[AnalysisException] { From 41ef6c6716587ef8a35055f57944ec05264cc16d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 9 May 2018 19:01:35 +0200 Subject: [PATCH 16/22] support binary and complex data types --- .../expressions/collectionOperations.scala | 169 +++++++++++++++--- .../CollectionExpressionsSuite.scala | 22 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 4 + 3 files changed, 168 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 600d14a20ce56..4e3df98c39328 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -579,6 +579,31 @@ case class ArrayContains(left: Expression, right: Expression) case class ArraysOverlap(left: Expression, right: Expression) extends BinaryArrayExpressionWithImplicitCast { + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + if (RowOrdering.isOrderable(elementType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.") + } + case failure => failure + } + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + fastEval _ + } else { + bruteForceEval _ + } + override def dataType: DataType = BooleanType override def nullable: Boolean = { @@ -587,9 +612,16 @@ case class ArraysOverlap(left: Expression, right: Expression) } override def nullSafeEval(a1: Any, a2: Any): Any = { + doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData]) + } + + /** + * A fast implementation which puts all the elements from the smaller array in a set + * and then performs a lookup on it for each element of the bigger one. + * This eval mode works only for data types which implements properly the equals method. + */ + private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { var hasNull = false - val arr1 = a1.asInstanceOf[ArrayData] - val arr2 = a2.asInstanceOf[ArrayData] val (bigger, smaller, biggerDt) = if (arr1.numElements() > arr2.numElements()) { (arr1, arr2, left.dataType.asInstanceOf[ArrayType]) } else { @@ -620,6 +652,34 @@ case class ArraysOverlap(left: Expression, right: Expression) } } + /** + * A slower evaluation which performs a nested loop and supports all the data types. 
+ */ + private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + if (arr1.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v1 == null) { + hasNull = true + } else if (ordering.equiv(v1, v2)) { + return true + } + ) + }) + } else if (containsNull(arr2, right.dataType.asInstanceOf[ArrayType])) { + hasNull = true + } + if (hasNull) { + null + } else { + false + } + } + def containsNull(arr: ArrayData, dt: ArrayType): Boolean = { if (dt.containsNull) { var i = 0 @@ -639,8 +699,6 @@ case class ArraysOverlap(left: Expression, right: Expression) val i = ctx.freshName("i") val smaller = ctx.freshName("smallerArray") val bigger = ctx.freshName("biggerArray") - val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) - val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) val smallerEmptyCode = if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { s""" |else { @@ -655,22 +713,11 @@ case class ArraysOverlap(left: Expression, right: Expression) } else { "" } - val javaElementClass = CodeGenerator.boxedType(elementType) - val javaSet = classOf[java.util.HashSet[_]].getName - val set2 = ctx.freshName("set") - val addToSetFromSmallerCode = nullSafeElementCodegen( - smaller, i, s"$set2.add($getFromSmaller);", s"${ev.isNull} = true;") - val elementIsInSetCode = nullSafeElementCodegen( - bigger, - i, - s""" - |if ($set2.contains($getFromBigger)) { - | ${ev.isNull} = false; - | ${ev.value} = true; - | break; - |} - |""".stripMargin, - s"${ev.isNull} = true;") + val comparisonCode = if (elementTypeSupportEquals) { + fastCodegen(ctx, ev, smaller, bigger) + } else { + bruteForceCodegen(ctx, ev, smaller, bigger) + } s""" |ArrayData $smaller; |ArrayData $bigger; @@ -682,18 +729,86 @@ case class ArraysOverlap(left: Expression, right: Expression) | $bigger = $a2; |} |if ($smaller.numElements() > 0) { - | $javaSet<$javaElementClass> $set2 = new $javaSet<$javaElementClass>(); - | for (int $i = 0; $i < $smaller.numElements(); $i ++) { - | $addToSetFromSmallerCode - | } - | for (int $i = 0; $i < $bigger.numElements(); $i ++) { - | $elementIsInSetCode - | } + | $comparisonCode |} $smallerEmptyCode |""".stripMargin }) } + /** + * Code generation for a fast implementation which puts all the elements from the smaller array + * in a set and then performs a lookup on it for each element of the bigger one. + * It works only for data types which implements properly the equals method. 
+ */ + private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val javaElementClass = CodeGenerator.boxedType(elementType) + val javaSet = classOf[java.util.HashSet[_]].getName + val set2 = ctx.freshName("set") + val addToSetFromSmallerCode = nullSafeElementCodegen( + smaller, i, s"$set2.add($getFromSmaller);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen( + bigger, + i, + s""" + |if ($set2.contains($getFromBigger)) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + |""".stripMargin, + s"${ev.isNull} = true;") + s""" + |$javaSet<$javaElementClass> $set2 = new $javaSet<$javaElementClass>(); + |for (int $i = 0; $i < $smaller.numElements(); $i ++) { + | $addToSetFromSmallerCode + |} + |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + | $elementIsInSetCode + |} + """.stripMargin + } + + /** + * Code generation for a slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val compareValues = nullSafeElementCodegen( + smaller, + j, + s""" + |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin, + s"${ev.isNull} = true;") + val isInSmaller = nullSafeElementCodegen( + bigger, + i, + s""" + |for (int $j = 0; $j < $smaller.numElements(); $j ++) { + | $compareValues + | if (${ev.value}) { + | break; + | } + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + |$isInSmaller + |} + """.stripMargin + } + def nullSafeElementCodegen( arrayVar: String, index: String, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6df3a624b55df..6cb6e7f546364 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -165,6 +165,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArraysOverlap( Literal.create(Seq(null), ArrayType(IntegerType)), Literal.create(Seq(null), ArrayType(IntegerType))), null) + + // arrays of binaries + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + + checkEvaluation(ArraysOverlap(b0, b1), true) + checkEvaluation(ArraysOverlap(b0, b2), false) + + // arrays of complex data types + val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), + ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", 
"f"), Array[String]("a", "b")), + ArrayType(ArrayType(StringType))) + val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(ArraysOverlap(aa0, aa1), true) + checkEvaluation(ArraysOverlap(aa0, aa2), false) } test("Slice") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 1490e6a4b5c97..5d78793a8f271 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -465,6 +465,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { intercept[AnalysisException] { sql("select arrays_overlap(null, null)") } + + intercept[AnalysisException] { + sql("select arrays_overlap(map(1, 2), map(3, 4))") + } } test("slice function") { From 3dd724b27329461eb1d1ff37c86be62fdc8f7b3d Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 11 May 2018 15:17:53 +0200 Subject: [PATCH 17/22] review comments --- python/pyspark/sql/functions.py | 3 +- .../expressions/collectionOperations.scala | 59 ++++--------------- .../CollectionExpressionsSuite.scala | 21 +++---- .../org/apache/spark/sql/functions.scala | 5 +- .../spark/sql/DataFrameFunctionsSuite.scala | 2 +- 5 files changed, 30 insertions(+), 60 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 067daae15e89c..d7449580cfcfd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1838,7 +1838,8 @@ def array_contains(col, value): def arrays_overlap(a1, a2): """ Collection function: returns true if the arrays contain any common non-null element; if not, - returns null if any of the arrays contains a null element and false otherwise. + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4e3df98c39328..9300296a71e53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -569,7 +569,7 @@ case class ArrayContains(left: Expression, right: Expression) */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least an element present also in a2. If the arrays have no common element and either of them contains a null element null is returned, false otherwise.", + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. 
If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.", examples = """ Examples: > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); @@ -599,10 +599,10 @@ case class ArraysOverlap(left: Expression, right: Expression) } @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { - fastEval _ - } else { - bruteForceEval _ - } + fastEval _ + } else { + bruteForceEval _ + } override def dataType: DataType = BooleanType @@ -642,8 +642,6 @@ case class ArraysOverlap(left: Expression, right: Expression) return true } ) - } else if (containsNull(bigger, biggerDt)) { - hasNull = true } if (hasNull) { null @@ -670,8 +668,6 @@ case class ArraysOverlap(left: Expression, right: Expression) } ) }) - } else if (containsNull(arr2, right.dataType.asInstanceOf[ArrayType])) { - hasNull = true } if (hasNull) { null @@ -680,39 +676,10 @@ case class ArraysOverlap(left: Expression, right: Expression) } } - def containsNull(arr: ArrayData, dt: ArrayType): Boolean = { - if (dt.containsNull) { - var i = 0 - var hasNull = false - while (i < arr.numElements && !hasNull) { - hasNull = arr.isNullAt(i) - i += 1 - } - hasNull - } else { - false - } - } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (a1, a2) => { - val i = ctx.freshName("i") val smaller = ctx.freshName("smallerArray") val bigger = ctx.freshName("biggerArray") - val smallerEmptyCode = if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { - s""" - |else { - | for (int $i = 0; $i < $bigger.numElements(); $i ++) { - | if ($bigger.isNullAt($i)) { - | ${ev.isNull} = true; - | break; - | } - | } - |} - """.stripMargin - } else { - "" - } val comparisonCode = if (elementTypeSupportEquals) { fastCodegen(ctx, ev, smaller, bigger) } else { @@ -730,8 +697,8 @@ case class ArraysOverlap(left: Expression, right: Expression) |} |if ($smaller.numElements() > 0) { | $comparisonCode - |} $smallerEmptyCode - |""".stripMargin + |} + """.stripMargin }) } @@ -746,22 +713,22 @@ case class ArraysOverlap(left: Expression, right: Expression) val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) val javaElementClass = CodeGenerator.boxedType(elementType) val javaSet = classOf[java.util.HashSet[_]].getName - val set2 = ctx.freshName("set") + val set = ctx.freshName("set") val addToSetFromSmallerCode = nullSafeElementCodegen( - smaller, i, s"$set2.add($getFromSmaller);", s"${ev.isNull} = true;") + smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") val elementIsInSetCode = nullSafeElementCodegen( bigger, i, s""" - |if ($set2.contains($getFromBigger)) { + |if ($set.contains($getFromBigger)) { | ${ev.isNull} = false; | ${ev.value} = true; | break; |} - |""".stripMargin, + """.stripMargin, s"${ev.isNull} = true;") s""" - |$javaSet<$javaElementClass> $set2 = new $javaSet<$javaElementClass>(); + |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); |for (int $i = 0; $i < $smaller.numElements(); $i ++) { | $addToSetFromSmallerCode |} @@ -821,7 +788,7 @@ case class ArraysOverlap(left: Expression, right: Expression) |} else { | $code |} - |""".stripMargin + """.stripMargin } else { code } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 6cb6e7f546364..14f15225d34f9 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -141,25 +141,26 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) - val a4 = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) - val a5 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) - val a6 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) - val a7 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) checkEvaluation(ArraysOverlap(a0, a1), true) checkEvaluation(ArraysOverlap(a0, a2), null) checkEvaluation(ArraysOverlap(a1, a2), true) checkEvaluation(ArraysOverlap(a1, a3), false) - checkEvaluation(ArraysOverlap(a0, a4), false) - checkEvaluation(ArraysOverlap(a2, a4), null) - checkEvaluation(ArraysOverlap(a4, a2), null) + checkEvaluation(ArraysOverlap(a0, emptyIntArray), false) + checkEvaluation(ArraysOverlap(a2, emptyIntArray), false) + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) - checkEvaluation(ArraysOverlap(a5, a6), true) - checkEvaluation(ArraysOverlap(a5, a7), null) - checkEvaluation(ArraysOverlap(a6, a7), false) + checkEvaluation(ArraysOverlap(a4, a5), true) + checkEvaluation(ArraysOverlap(a4, a6), null) + checkEvaluation(ArraysOverlap(a5, a6), false) // null handling + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) checkEvaluation(ArraysOverlap( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5696a95eb970b..f06ffd14f6515 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3040,8 +3040,9 @@ object functions { } /** - * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and - * any of the arrays contains a `null`, it returns `null`. It returns `false` otherwise. + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and both + * the arrays are non-empty and any of them contains a `null`, it returns `null`. It returns + * `false` otherwise. 
* @group collection_funcs * @since 2.4.0 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5d78793a8f271..b8e560d5c1296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -445,7 +445,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("arrays_overlap function") { val df = Seq( (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), - (Seq.empty[Option[Int]], Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), None)), (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) ).toDF("a", "b") From e36a5d7a1e80743838ca00cd9047a721f784cdac Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 11 May 2018 18:30:14 +0200 Subject: [PATCH 18/22] fix compilation error --- .../sql/catalyst/expressions/collectionOperations.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9300296a71e53..9908cb901c3e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -36,13 +37,15 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression with ImplicitCastInputTypes { + private val caseSensitive = SQLConf.get.caseSensitiveAnalysis + @transient protected lazy val elementType: DataType = inputTypes.head.asInstanceOf[ArrayType].elementType override def inputTypes: Seq[AbstractDataType] = { (left.dataType, right.dataType) match { case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => - TypeCoercion.findTightestCommonType(e1, e2) match { + TypeCoercion.findTightestCommonType(e1, e2, caseSensitive) match { case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) case _ => Seq.empty } From 49d937293d57ea561d03b6e2ee8cc947e789a68c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 14 May 2018 15:24:28 +0200 Subject: [PATCH 19/22] address comments --- .../expressions/collectionOperations.scala | 14 +++++++------- .../expressions/CollectionExpressionsSuite.scala | 2 ++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9908cb901c3e8..9d8b4fea858e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -625,10 +625,10 @@ case class ArraysOverlap(left: Expression, right: Expression) */ private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { var hasNull = false - val (bigger, smaller, biggerDt) = if (arr1.numElements() > arr2.numElements()) { - (arr1, arr2, left.dataType.asInstanceOf[ArrayType]) + val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2) } else { - (arr2, arr1, right.dataType.asInstanceOf[ArrayType]) + (arr2, arr1) } if (smaller.numElements() > 0) { val smallestSet = new mutable.HashSet[Any] @@ -766,15 +766,15 @@ case class ArraysOverlap(left: Expression, right: Expression) s""" |for (int $j = 0; $j < $smaller.numElements(); $j ++) { | $compareValues - | if (${ev.value}) { - | break; - | } |} """.stripMargin, s"${ev.isNull} = true;") s""" |for (int $i = 0; $i < $bigger.numElements(); $i ++) { - |$isInSmaller + | $isInSmaller + | if (${ev.value}) { + | break; + | } |} """.stripMargin } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 14f15225d34f9..f216f9a37b5f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -161,6 +161,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper // null handling checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + checkEvaluation(ArraysOverlap( + emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false) checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) checkEvaluation(ArraysOverlap( From 227437bb81a4baf7cf985dd00e86bb57baf5df0b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 15 May 2018 10:21:18 +0200 Subject: [PATCH 20/22] address comment --- .../sql/catalyst/expressions/collectionOperations.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 9d8b4fea858e8..a9dfc7bd8ab48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -756,7 +756,6 @@ case class ArraysOverlap(left: Expression, right: Expression) |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { | ${ev.isNull} = false; | ${ev.value} = true; - | break; |} """.stripMargin, s"${ev.isNull} = true;") @@ -764,17 +763,14 @@ case class ArraysOverlap(left: Expression, right: Expression) bigger, i, s""" - |for (int $j = 0; $j < $smaller.numElements(); $j ++) { + |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) { | $compareValues |} """.stripMargin, s"${ev.isNull} = true;") s""" - |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) { | $isInSmaller - | if (${ev.value}) { - | break; - | } |} """.stripMargin } From 
2e9e0249a5a7a8bdcb453a523a3fa63f320e8f0f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 16 May 2018 12:58:58 +0200 Subject: [PATCH 21/22] fix null handling with complex types --- .../catalyst/expressions/collectionOperations.scala | 4 ++-- .../expressions/CollectionExpressionsSuite.scala | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a9dfc7bd8ab48..50914920b8a88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -658,13 +658,13 @@ case class ArraysOverlap(left: Expression, right: Expression) */ private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { var hasNull = false - if (arr1.numElements() > 0) { + if (arr1.numElements() > 0 && arr2.numElements() > 0) { arr1.foreach(elementType, (_, v1) => if (v1 == null) { hasNull = true } else { arr2.foreach(elementType, (_, v2) => - if (v1 == null) { + if (v2 == null) { hasNull = true } else if (ordering.equiv(v1, v2)) { return true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f216f9a37b5f4..32cb9fd19e1e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -190,6 +190,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArraysOverlap(aa0, aa1), true) checkEvaluation(ArraysOverlap(aa0, aa2), false) + + // null handling with complex datatypes + val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false) + checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null) + checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null) } test("Slice") { From 56c59ae0ce985fbfa2b4ab900ae23e7237cf2ceb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 17 May 2018 10:22:53 +0200 Subject: [PATCH 22/22] fix build --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index bac28f6cc08cb..c82db839438ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -37,15 +37,13 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression with ImplicitCastInputTypes { - private val 
caseSensitive = SQLConf.get.caseSensitiveAnalysis
-
   @transient protected lazy val elementType: DataType =
     inputTypes.head.asInstanceOf[ArrayType].elementType
 
   override def inputTypes: Seq[AbstractDataType] = {
     (left.dataType, right.dataType) match {
       case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) =>
-        TypeCoercion.findTightestCommonType(e1, e2, caseSensitive) match {
+        TypeCoercion.findTightestCommonType(e1, e2) match {
           case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2))
           case _ => Seq.empty
         }
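
The behavior the patches converge on for arrays_overlap is three-valued: true when the two arrays share at least one non-null element, null when they have no common element but are both non-empty and either of them contains a null, and false otherwise (in particular whenever one of the arrays is empty). Below is a minimal standalone Scala sketch of that contract, modelled on the hash-set fast path (fastEval) that the expression uses when the element type supports hash-based equality; the object and method names and the Option encoding of SQL NULL are illustrative only, not the Spark implementation.

```scala
import scala.collection.mutable

object ArraysOverlapSketch {

  // Three-valued result, with None standing in for SQL NULL.
  def arraysOverlap[T](a1: Seq[Option[T]], a2: Seq[Option[T]]): Option[Boolean] = {
    if (a1.isEmpty || a2.isEmpty) return Some(false)   // an empty side gives false, never null
    // As in the fast path of the patch: hash the smaller array, probe with the bigger one.
    val (smaller, bigger) = if (a1.size <= a2.size) (a1, a2) else (a2, a1)
    var hasNull = false
    val seen = mutable.HashSet.empty[T]
    smaller.foreach {
      case Some(v) => seen += v
      case None    => hasNull = true
    }
    bigger.foreach {
      case Some(v) => if (seen.contains(v)) return Some(true)
      case None    => hasNull = true
    }
    if (hasNull) None else Some(false)                 // null only if no match and a null was seen
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the expectations in CollectionExpressionsSuite above.
    assert(arraysOverlap(Seq(Some(1), Some(2), Some(3)), Seq(Some(4), Some(5), Some(3))) == Some(true))
    assert(arraysOverlap(Seq(Some(1), Some(2), Some(3)), Seq(None, Some(5), Some(6))).isEmpty)
    assert(arraysOverlap(Seq(Some(7), Some(8)), Seq(Some(1), Some(2))) == Some(false))
    assert(arraysOverlap(Seq.empty[Option[Int]], Seq(None, Some(5))) == Some(false))
  }
}
```

Building the lookup set from the smaller array keeps the extra memory proportional to the smaller input, which is the same reasoning behind the smaller/bigger swap in fastEval and in the generated code.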
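
For element types that cannot be probed through a hash set (the tests above exercise BinaryType and arrays of arrays), the expression falls back to a nested-loop comparison based on the type's ordering (bruteForceEval), and PATCH 21 above corrects its inner null check to test v2 instead of v1. The sketch below illustrates that fallback under the same assumptions as the previous one; the equivalence parameter stands in for ordering.equiv and is an invention of this example.

```scala
object BruteForceOverlapSketch {

  // Nested-loop variant for element types that need a custom equivalence
  // (e.g. byte arrays compared by content). None again stands in for SQL NULL.
  def bruteForceOverlap[T](a1: Seq[Option[T]], a2: Seq[Option[T]])
                          (equiv: (T, T) => Boolean): Option[Boolean] = {
    if (a1.isEmpty || a2.isEmpty) return Some(false)
    var hasNull = false
    for (x <- a1; y <- a2) {
      (x, y) match {
        case (Some(v1), Some(v2)) => if (equiv(v1, v2)) return Some(true)
        case _                    => hasNull = true    // a null on either side is remembered
      }
    }
    if (hasNull) None else Some(false)
  }

  def main(args: Array[String]): Unit = {
    // Byte arrays compared element-wise, as in the BinaryType tests above.
    val b0 = Seq(Some(Array[Byte](1, 2)), Some(Array[Byte](3, 4)))
    val b1 = Seq(Some(Array[Byte](5, 6)), Some(Array[Byte](1, 2)))
    val b2 = Seq(Some(Array[Byte](2, 1)), Some(Array[Byte](4, 3)))
    val byContent = (x: Array[Byte], y: Array[Byte]) => x.sameElements(y)

    assert(bruteForceOverlap(b0, b1)(byContent) == Some(true))
    assert(bruteForceOverlap(b0, b2)(byContent) == Some(false))
    assert(bruteForceOverlap(Seq(None), b0)(byContent).isEmpty)  // null vs non-empty array => null
  }
}
```

The last assertion mirrors the new test added in PATCH 21: an array containing only a null, overlapped with a non-empty array, yields null rather than false.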