From 9b8521e53e56a53b44c02366a99f8a8ee1307bbf Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 9 Aug 2018 14:41:59 -0700 Subject: [PATCH] [SPARK-25068][SQL] Add exists function. ## What changes were proposed in this pull request? This pr adds `exists` function which tests whether a predicate holds for one or more elements in the array. ```sql > SELECT exists(array(1, 2, 3), x -> x % 2 == 0); true ``` ## How was this patch tested? Added tests. Closes #22052 from ueshin/issues/SPARK-25068/exists. Authored-by: Takuya UESHIN Signed-off-by: Xiao Li --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 47 +++++++++ .../HigherOrderFunctionsSuite.scala | 37 +++++++ .../inputs/higher-order-functions.sql | 6 ++ .../results/higher-order-functions.sql.out | 18 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 96 +++++++++++++++++++ 6 files changed, 205 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 390debd865eed..15543c909a271 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -444,6 +444,7 @@ object FunctionRegistry { expression[ArrayTransform]("transform"), expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), + expression[ArrayExists]("exists"), expression[ArrayAggregate]("aggregate"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index d20673359129b..7f8203ab92213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -356,6 +356,53 @@ case class ArrayFilter( override def prettyName: String = "filter" } +/** + * Tests whether a predicate holds for one or more elements in the array. + */ +@ExpressionDescription(usage = + "_FUNC_(expr, pred) - Tests whether a predicate holds for one or more elements in the array.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0); + true + """, + since = "2.4.0") +case class ArrayExists( + input: Expression, + function: Expression) + extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = input.nullable + + override def dataType: DataType = BooleanType + + override def expectingFunctionType: AbstractDataType = BooleanType + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = { + val elem = HigherOrderFunction.arrayArgumentType(input.dataType) + copy(function = f(function, elem :: Nil)) + } + + @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function + + override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { + val arr = value.asInstanceOf[ArrayData] + val f = functionForEval + var exists = false + var i = 0 + while (i < arr.numElements && !exists) { + elementVar.value.set(arr.get(i, elementVar.dataType)) + if (f.eval(inputRow).asInstanceOf[Boolean]) { + exists = true + } + i += 1 + } + exists + } + + override def prettyName: String = "exists" +} + /** * Applies a binary operator to a start value and all elements in the array. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index f7e84b8757916..bc7d04c77fa9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -202,6 +202,43 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Seq(Seq(1, 3), null, Seq(5))) } + test("ArrayExists") { + def exists(expr: Expression, f: Expression => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + ArrayExists(expr, createLambda(at.elementType, at.containsNull, f)) + } + + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val isEven: Expression => Expression = x => x % 2 === 0 + val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1 + + checkEvaluation(exists(ai0, isEven), true) + checkEvaluation(exists(ai0, isNullOrOdd), true) + checkEvaluation(exists(ai1, isEven), false) + checkEvaluation(exists(ai1, isNullOrOdd), true) + checkEvaluation(exists(ain, isEven), null) + checkEvaluation(exists(ain, isNullOrOdd), null) + + val as0 = + Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val startsWithA: Expression => Expression = x => x.startsWith("a") + + checkEvaluation(exists(as0, startsWithA), true) + checkEvaluation(exists(as1, startsWithA), false) + checkEvaluation(exists(asn, startsWithA), null) + + val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation(transform(aai, ix => exists(ix, isNullOrOdd)), + Seq(true, null, true)) + } + test("ArrayAggregate") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 136396d9553db..ce1d0daa4d397 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -45,3 +45,9 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as -- Aggregate a null array select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v; + +-- Check for element existence +select exists(ys, y -> y > 30) as v from nested; + +-- Check for element existence in a null array +select exists(cast(null as array), y -> y > 30) as v; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index e6f62f2e1bb67..e18abce3b617b 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -145,3 +145,21 @@ select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) a struct -- !query 14 output NULL + + +-- !query 15 +select exists(ys, y -> y > 30) as v from nested +-- !query 15 schema +struct +-- !query 15 output +false +true +true + + +-- !query 16 +select exists(cast(null as array), y -> y > 30) as v +-- !query 16 schema +struct +-- !query 16 output +NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 24091f2128049..2c4238e69ad7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1996,6 +1996,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) } + test("exists function - array for primitive type not containing null") { + val df = Seq( + Seq(1, 9, 8, 7), + Seq(5, 9, 7), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("exists function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, null, 9, 7, null), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), + Seq( + Row(true), + Row(false), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("exists function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("exists(s, x -> x is null)"), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("exists function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("exists(s, (x, y) -> x + y)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) + + val ex2 = intercept[AnalysisException] { + df.selectExpr("exists(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + + val ex3 = intercept[AnalysisException] { + df.selectExpr("exists(s, x -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + } + test("aggregate function - array for primitive type not containing null") { val df = Seq( Seq(1, 9, 8, 7),