diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 2aca10f1bfbc7..9ad6f30c40a88 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -591,6 +591,8 @@ primaryExpression (OVER windowSpec)? #functionCall | qualifiedName '(' trimOption=(BOTH | LEADING | TRAILING) argument+=expression FROM argument+=expression ')' #functionCall + | IDENTIFIER '->' expression #lambda + | '(' IDENTIFIER (',' IDENTIFIER)+ ')' '->' expression #lambda | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 76dc86710909e..7f235ac560299 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -180,6 +180,8 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveHigherOrderFunctions(catalog) :: + ResolveLambdaVariables(conf) :: ResolveTimeZone(conf) :: ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ @@ -878,6 +880,7 @@ class Analyzer( } private def resolve(e: Expression, q: LogicalPlan): Expression = e match { + case f: LambdaFunction if !f.bound => f case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b8b311219ca8d..f7517486e5411 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -440,6 +440,7 @@ object FunctionRegistry { expression[ArrayRepeat]("array_repeat"), expression[ArrayRemove]("array_remove"), expression[ArrayDistinct]("array_distinct"), + expression[ArrayTransform]("transform"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala new file mode 100644 index 0000000000000..063ca0fc3252d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Resolve higher order functions from the catalog. This is different from regular function
+ * resolution because a lambda function can only be resolved after its enclosing function has
+ * been resolved; we therefore resolve a higher order function only when all of its children are
+ * either resolved or lambda functions.
+ */
+case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    case q: LogicalPlan =>
+      q.transformExpressions {
+        case u @ UnresolvedFunction(fn, children, false)
+            if hasLambdaAndResolvedArguments(children) =>
+          withPosition(u) {
+            catalog.lookupFunction(fn, children) match {
+              case func: HigherOrderFunction => func
+              case other => other.failAnalysis(
+                "A lambda function should only be used in a higher order function. However, " +
+                  s"its class is ${other.getClass.getCanonicalName}, which is not a " +
+                  s"higher order function.")
+            }
+          }
+      }
+  }
+
+  /**
+   * Check if the arguments of a function are either resolved or a lambda function.
+   */
+  private def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = {
+    val (lambdas, others) = expressions.partition(_.isInstanceOf[LambdaFunction])
+    lambdas.nonEmpty && others.forall(_.resolved)
+  }
+}
+
+/**
+ * Resolve the lambda variables exposed by higher order functions.
+ *
+ * This rule works in two steps:
+ * [1]. Bind the anonymous variables exposed by the higher order function to the lambda function's
+ *      arguments; this creates named and typed lambda variables. During this step the argument
+ *      names are checked for duplicates and the number of arguments is validated.
+ * [2]. Resolve the lambda variables used in the lambda function's expression tree. Note that we
+ *      allow the use of variables from outside the current lambda; such a variable can either be
+ *      defined in an outer scope's lambda, or be an attribute produced by the plan's child. If
+ *      names are duplicated, the name defined in the innermost scope is used.
+ */
+case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
+
+  type LambdaVariableMap = Map[String, NamedExpression]
+
+  private val canonicalizer = {
+    if (!conf.caseSensitiveAnalysis) {
+      s: String => s.toLowerCase
+    } else {
+      s: String => s
+    }
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.resolveOperators {
+      case q: LogicalPlan =>
+        q.mapExpressions(resolve(_, Map.empty))
+    }
+  }
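+
+  // Editor's sketch, not part of the patch: mirroring the "resolution - simple" case in
+  // ResolveLambdaVariablesSuite, this rule rewrites
+  //   ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
+  // into
+  //   ArrayTransform(values1, LambdaFunction(lv + 1, lv :: Nil))
+  // where lv = NamedLambdaVariable("x", IntegerType, nullable = true).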
+
+  /**
+   * Create a bound lambda function by binding the arguments of a lambda function to the given
+   * partial arguments (dataType and nullability only). If the expression happens to be an
+   * already bound lambda function, we assume it has been bound to the correct arguments and do
+   * nothing. This function will produce a lambda function with hidden arguments when it is
+   * passed an arbitrary expression.
+   */
+  private def createLambda(
+      e: Expression,
+      partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match {
+    case f: LambdaFunction if f.bound => f
+
+    case LambdaFunction(function, names, _) =>
+      if (names.size != partialArguments.size) {
+        e.failAnalysis(
+          s"The number of lambda function arguments '${names.size}' does not " +
+            "match the number of arguments expected by the higher order function " +
+            s"'${partialArguments.size}'.")
+      }
+
+      if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) {
+        e.failAnalysis(
+          "Lambda function arguments should not have names that are semantically the same.")
+      }
+
+      val arguments = partialArguments.zip(names).map {
+        case ((dataType, nullable), ne) =>
+          NamedLambdaVariable(ne.name, dataType, nullable)
+      }
+      LambdaFunction(function, arguments)
+
+    case _ =>
+      // This expression does not consume any of the lambda's arguments (it is independent). We
+      // still create a lambda function with default parameters because the higher order function
+      // expects one. Note that we hide the lambda variables produced here in order to prevent
+      // accidental naming collisions.
+      val arguments = partialArguments.zipWithIndex.map {
+        case ((dataType, nullable), i) =>
+          NamedLambdaVariable(s"col$i", dataType, nullable)
+      }
+      LambdaFunction(e, arguments, hidden = true)
+  }
+
+  /**
+   * Resolve lambda variables in the expression subtree, using the passed lambda variable registry.
+   */
+  private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match {
+    case _ if e.resolved => e
+
+    case h: HigherOrderFunction if h.inputResolved =>
+      h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap))
+
+    case l: LambdaFunction if !l.bound =>
+      // Do not resolve an unbound lambda function. If we see such a lambda function, it means
+      // that either the higher order function has yet to be resolved, or that we are seeing a
+      // dangling lambda function.
+      l
+
+    case l: LambdaFunction if !l.hidden =>
+      val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
+      l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))
+
+    case u @ UnresolvedAttribute(name +: nestedFields) =>
+      parentLambdaMap.get(canonicalizer(name)) match {
+        case Some(lambda) =>
+          nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
+            ExtractValue(expr, Literal(fieldName), conf.resolver)
+          }
+        case None => u
+      }
+
+    case _ =>
+      e.mapChildren(resolve(_, parentLambdaMap))
+  }
+}
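Editor's illustration, not part of the patch, of how the scoping rules above play out; the
expressions mirror the "resolution - nested" case in ResolveLambdaVariablesSuite further down,
where lvArray and lvInt are the typed NamedLambdaVariables defined by that suite:

    // before: both lambdas use the name 'x'
    ArrayTransform(values2, LambdaFunction(
      ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil))
    // after: the inner 'x' (an int) shadows the outer 'x' (an array)
    ArrayTransform(values2, LambdaFunction(
      ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))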
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
new file mode 100644
index 0000000000000..c5c3482afa134
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.util.concurrent.atomic.AtomicReference
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.types._
+
+/**
+ * A named lambda variable.
+ */
+case class NamedLambdaVariable(
+    name: String,
+    dataType: DataType,
+    nullable: Boolean,
+    value: AtomicReference[Any] = new AtomicReference(),
+    exprId: ExprId = NamedExpression.newExprId)
+  extends LeafExpression
+  with NamedExpression
+  with CodegenFallback {
+
+  override def qualifier: Option[String] = None
+
+  override def newInstance(): NamedExpression =
+    copy(value = new AtomicReference(), exprId = NamedExpression.newExprId)
+
+  override def toAttribute: Attribute = {
+    AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, None)
+  }
+
+  override def eval(input: InternalRow): Any = value.get
+
+  override def toString: String = s"lambda $name#${exprId.id}$typeSuffix"
+
+  override def simpleString: String = s"lambda $name#${exprId.id}: ${dataType.simpleString}"
+}
+
+/**
+ * A lambda function and its arguments. A lambda function can be hidden when a user wants to
+ * process a completely independent expression in a [[HigherOrderFunction]]; the lambda function
+ * and its variables are then only used for internal bookkeeping within the higher order function.
+ */
+case class LambdaFunction(
+    function: Expression,
+    arguments: Seq[NamedExpression],
+    hidden: Boolean = false)
+  extends Expression with CodegenFallback {
+
+  override def children: Seq[Expression] = function +: arguments
+  override def dataType: DataType = function.dataType
+  override def nullable: Boolean = function.nullable
+
+  lazy val bound: Boolean = arguments.forall(_.resolved)
+
+  override def eval(input: InternalRow): Any = function.eval(input)
+}
+
+/**
+ * A higher order function takes one or more (lambda) functions and applies these to some objects.
+ * The function produces a number of variables which can be consumed by some lambda function.
+ */
+trait HigherOrderFunction extends Expression {
+
+  override def children: Seq[Expression] = inputs ++ functions
+
+  /**
+   * Inputs to the higher order function.
+   */
+  def inputs: Seq[Expression]
+
+  /**
+   * All inputs have been resolved. This means that the types and nullability of (most of) the
+   * lambda function arguments are known, and that we can start binding the lambda functions.
+   */
+  lazy val inputResolved: Boolean = inputs.forall(_.resolved)
+
+  /**
+   * Functions applied by the higher order function.
+   */
+  def functions: Seq[Expression]
+
+  /**
+   * All inputs must be resolved and all functions must be resolved lambda functions.
+   */
+  override lazy val resolved: Boolean = inputResolved && functions.forall {
+    case l: LambdaFunction => l.resolved
+    case _ => false
+  }
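+
+  // Editor's sketch, not part of the patch: for an ArrayTransform over an array<int> input,
+  // the analyzer invokes bind(createLambda) with the partial arguments
+  // (IntegerType, containsNull) :: Nil (plus (IntegerType, false) for the optional index),
+  // replacing the unbound LambdaFunction child with one whose arguments are typed
+  // NamedLambdaVariables.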
+
+  /**
+   * Bind the lambda functions to the [[HigherOrderFunction]] using the given bind function. The
+   * bind function takes the potential lambda and its (partial) arguments, and converts these
+   * into a bound lambda function.
+   */
+  def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction
+
+  @transient lazy val functionsForEval: Seq[Expression] = functions.map {
+    case LambdaFunction(function, arguments, hidden) =>
+      val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap
+      function.transformUp {
+        case variable: NamedLambdaVariable if argumentMap.contains(variable.exprId) =>
+          argumentMap(variable.exprId)
+      }
+  }
+}
+
+trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
+
+  def input: Expression
+
+  override def inputs: Seq[Expression] = input :: Nil
+
+  def function: Expression
+
+  override def functions: Seq[Expression] = function :: Nil
+
+  def expectingFunctionType: AbstractDataType = AnyDataType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
+
+  @transient lazy val functionForEval: Expression = functionsForEval.head
+}
+
+/**
+ * Transform elements in an array using the transform function. This is similar to
+ * a `map` in functional programming.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr, func) - Transforms elements in an array using the function.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), x -> x + 1);
+       array(2, 3, 4)
+      > SELECT _FUNC_(array(1, 2, 3), (x, i) -> x + i);
+       array(1, 3, 5)
+  """,
+  since = "2.4.0")
+case class ArrayTransform(
+    input: Expression,
+    function: Expression)
+  extends ArrayBasedHigherOrderFunction with CodegenFallback {
+
+  override def nullable: Boolean = input.nullable
+
+  override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
+
+  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
+    val (elementType, containsNull) = input.dataType match {
+      case ArrayType(elementType, containsNull) => (elementType, containsNull)
+      case _ =>
+        val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
+        (elementType, containsNull)
+    }
+    function match {
+      case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
+        copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
+      case _ =>
+        copy(function = f(function, (elementType, containsNull) :: Nil))
+    }
+  }
+
+  @transient lazy val (elementVar, indexVar) = {
+    val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function
+    val indexVar = if (tail.nonEmpty) {
+      Some(tail.head.asInstanceOf[NamedLambdaVariable])
+    } else {
+      None
+    }
+    (elementVar, indexVar)
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val arr = this.input.eval(input).asInstanceOf[ArrayData]
+    if (arr == null) {
+      null
+    } else {
+      val f = functionForEval
+      val result = new GenericArrayData(new Array[Any](arr.numElements))
+      var i = 0
+      while (i < arr.numElements) {
+        elementVar.value.set(arr.get(i, elementVar.dataType))
+        if (indexVar.isDefined) {
+          indexVar.get.value.set(i)
+        }
+        result.update(i, f.eval(input))
+        i += 1
+      }
+      result
+    }
+  }
+
+  override def prettyName: String = "transform"
+}
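Editor's illustration, not part of the patch: evaluated directly, ArrayTransform behaves like a
per-element map. The sketch below mirrors the literals used in HigherOrderFunctionsSuite; the
GenericArrayData result wraps Seq(2, 3, 4):

    val arr = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
    val lv = NamedLambdaVariable("x", IntegerType, nullable = false)
    val expr = ArrayTransform(arr, LambdaFunction(Add(lv, Literal(1)), Seq(lv)))
    expr.eval(InternalRow.empty)   // evaluates the lambda once per element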
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8a8db6df37094..0ceeb53e1d7a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1312,6 +1312,16 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     }
   }
 
+  /**
+   * Create a [[LambdaFunction]].
+   */
+  override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
+    val arguments = ctx.IDENTIFIER().asScala.map { name =>
+      UnresolvedAttribute.quoted(name.getText)
+    }
+    LambdaFunction(expression(ctx.expression), arguments)
+  }
+
   /**
    * Create a reference to a window frame, i.e. [[WindowSpecReference]].
    */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
new file mode 100644
index 0000000000000..c4171c75ecd03
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.{ArrayType, IntegerType}
+
+/**
+ * Test suite for [[ResolveLambdaVariables]].
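+ *
+ * A typical check, shown here as an editor's illustration of the "resolution - simple" case
+ * below:
+ * {{{
+ *   // in:  ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
+ *   // out: ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
+ * }}}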
+ */ +class ResolveLambdaVariablesSuite extends PlanTest { + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + object Analyzer extends RuleExecutor[LogicalPlan] { + val batches = Batch("Resolution", FixedPoint(4), ResolveLambdaVariables(conf)) :: Nil + } + + private val key = 'key.int + private val values1 = 'values1.array(IntegerType) + private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType))) + private val data = LocalRelation(Seq(key, values1, values2)) + private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true) + private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true) + private val lvArray = NamedLambdaVariable("x", ArrayType(IntegerType), nullable = true) + + private def plan(e: Expression): LogicalPlan = data.select(e.as("res")) + + private def checkExpression(e1: Expression, e2: Expression): Unit = { + comparePlans(Analyzer.execute(plan(e1)), plan(e2)) + } + + test("resolution - no op") { + checkExpression(key, key) + } + + test("resolution - simple") { + val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil)) + val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil)) + checkExpression(in, out) + } + + test("resolution - nested") { + val in = ArrayTransform(values2, LambdaFunction( + ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil)) + val out = ArrayTransform(values2, LambdaFunction( + ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil)) + checkExpression(in, out) + } + + test("resolution - hidden") { + val in = ArrayTransform(values1, key) + val out = ArrayTransform(values1, LambdaFunction(key, lvHiddenInt :: Nil, hidden = true)) + checkExpression(in, out) + } + + test("fail - name collisions") { + val p = plan(ArrayTransform(values1, + LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil))) + val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage + assert(msg.contains("arguments should not have names that are semantically the same")) + } + + test("fail - lambda arguments") { + val p = plan(ArrayTransform(values1, + LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil))) + val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage + assert(msg.contains("does not match the number of arguments expected")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala new file mode 100644 index 0000000000000..e987ea5b8a4d1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private def createLambda( + dt: DataType, + nullable: Boolean, + f: Expression => Expression): Expression = { + val lv = NamedLambdaVariable("arg", dt, nullable) + val function = f(lv) + LambdaFunction(function, Seq(lv)) + } + + private def createLambda( + dt1: DataType, + nullable1: Boolean, + dt2: DataType, + nullable2: Boolean, + f: (Expression, Expression) => Expression): Expression = { + val lv1 = NamedLambdaVariable("arg1", dt1, nullable1) + val lv2 = NamedLambdaVariable("arg2", dt2, nullable2) + val function = f(lv1, lv2) + LambdaFunction(function, Seq(lv1, lv2)) + } + + def transform(expr: Expression, f: Expression => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f)) + } + + def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val at = expr.dataType.asInstanceOf[ArrayType] + ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f)) + } + + test("ArrayTransform") { + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val plusOne: Expression => Expression = x => x + 1 + val plusIndex: (Expression, Expression) => Expression = (x, i) => x + i + + checkEvaluation(transform(ai0, plusOne), Seq(2, 3, 4)) + checkEvaluation(transform(ai0, plusIndex), Seq(1, 3, 5)) + checkEvaluation(transform(transform(ai0, plusIndex), plusOne), Seq(2, 4, 6)) + checkEvaluation(transform(ai1, plusOne), Seq(2, null, 4)) + checkEvaluation(transform(ai1, plusIndex), Seq(1, null, 5)) + checkEvaluation(transform(transform(ai1, plusIndex), plusOne), Seq(2, null, 6)) + checkEvaluation(transform(ain, plusOne), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val repeatTwice: Expression => Expression = x => Concat(Seq(x, x)) + val repeatIndexTimes: (Expression, Expression) => Expression = (x, i) => StringRepeat(x, i) + + checkEvaluation(transform(as0, repeatTwice), Seq("aa", "bb", "cc")) + checkEvaluation(transform(as0, repeatIndexTimes), Seq("", "b", "cc")) + checkEvaluation(transform(transform(as0, repeatIndexTimes), repeatTwice), + Seq("", "bb", "cccc")) + checkEvaluation(transform(as1, repeatTwice), Seq("aa", null, "cc")) + checkEvaluation(transform(as1, repeatIndexTimes), Seq("", null, "cc")) + checkEvaluation(transform(transform(as1, repeatIndexTimes), repeatTwice), + Seq("", null, "cccc")) + checkEvaluation(transform(asn, 
repeatTwice), null)
+
+    val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
+      ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
+    checkEvaluation(transform(aai, array => Cast(transform(array, plusOne), StringType)),
+      Seq("[2, 3, 4]", null, "[5, 6]"))
+    checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)),
+      Seq("[1, 3, 5]", null, "[4, 6]"))
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index b4d422d8506fc..c37b9f148cf48 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -234,6 +234,11 @@ class ExpressionParserSuite extends PlanTest {
     intercept("foo(a x)", "extraneous input 'x'")
   }
 
+  test("lambda functions") {
+    assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
+    assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr)))
+  }
+
   test("window function expressions") {
     val func = 'foo.function(star())
     def windowed(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 6241d5cbb1d25..139785719fec7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -60,6 +60,8 @@ trait PlanTestBase extends PredicateHelper { self: Suite =>
         Alias(a.child, a.name)(exprId = ExprId(0))
       case ae: AggregateExpression =>
         ae.copy(resultId = ExprId(0))
+      case lv: NamedLambdaVariable =>
+        lv.copy(value = null, exprId = ExprId(0))
     }
   }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
new file mode 100644
index 0000000000000..8e928a41f08e0
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql
@@ -0,0 +1,26 @@
+create or replace temporary view nested as values
+  (1, array(32, 97), array(array(12, 99), array(123, 42), array(1))),
+  (2, array(77, -76), array(array(6, 96, 65), array(-1, -2))),
+  (3, array(12), array(array(17)))
+  as t(x, ys, zs);
+
+-- Only allow lambdas in higher order functions.
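+-- Editor's note: lambdas are accepted only as arguments of higher order functions such as
+-- transform(ys, y -> y * y) below; passing one to an ordinary function like upper, as in the
+-- next query, fails analysis.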
+select upper(x -> x) as v;
+
+-- Identity transform an array
+select transform(zs, z -> z) as v from nested;
+
+-- Transform an array
+select transform(ys, y -> y * y) as v from nested;
+
+-- Transform an array with index
+select transform(ys, (y, i) -> y + i) as v from nested;
+
+-- Transform an array with reference
+select transform(zs, z -> concat(ys, z)) as v from nested;
+
+-- Transform an array to an array of 0's
+select transform(ys, 0) as v from nested;
+
+-- Transform a null array
+select transform(cast(null as array<int>), x -> x + 1) as v;
diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
new file mode 100644
index 0000000000000..ca2c3c35333cc
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out
@@ -0,0 +1,81 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 8
+
+
+-- !query 0
+create or replace temporary view nested as values
+  (1, array(32, 97), array(array(12, 99), array(123, 42), array(1))),
+  (2, array(77, -76), array(array(6, 96, 65), array(-1, -2))),
+  (3, array(12), array(array(17)))
+  as t(x, ys, zs)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select upper(x -> x) as v
+-- !query 1 schema
+struct<>
+-- !query 1 output
+org.apache.spark.sql.AnalysisException
+A lambda function should only be used in a higher order function. However, its class is org.apache.spark.sql.catalyst.expressions.Upper, which is not a higher order function.; line 1 pos 7
+
+
+-- !query 2
+select transform(zs, z -> z) as v from nested
+-- !query 2 schema
+struct<v:array<array<int>>>
+-- !query 2 output
+[[12,99],[123,42],[1]]
+[[17]]
+[[6,96,65],[-1,-2]]
+
+
+-- !query 3
+select transform(ys, y -> y * y) as v from nested
+-- !query 3 schema
+struct<v:array<int>>
+-- !query 3 output
+[1024,9409]
+[144]
+[5929,5776]
+
+
+-- !query 4
+select transform(ys, (y, i) -> y + i) as v from nested
+-- !query 4 schema
+struct<v:array<int>>
+-- !query 4 output
+[12]
+[32,98]
+[77,-75]
+
+
+-- !query 5
+select transform(zs, z -> concat(ys, z)) as v from nested
+-- !query 5 schema
+struct<v:array<array<int>>>
+-- !query 5 output
+[[12,17]]
+[[32,97,12,99],[32,97,123,42],[32,97,1]]
+[[77,-76,6,96,65],[77,-76,-1,-2]]
+
+
+-- !query 6
+select transform(ys, 0) as v from nested
+-- !query 6 schema
+struct<v:array<int>>
+-- !query 6 output
+[0,0]
+[0,0]
+[0]
+
+
+-- !query 7
+select transform(cast(null as array<int>), x -> x + 1) as v
+-- !query 7 schema
+struct<v:array<int>>
+-- !query 7 output
+NULL
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index e550b142c738d..923482024b033 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1647,6 +1647,159 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     assert(result10.first.schema(0).dataType === expectedType10)
   }
 
+  test("transform function - array for primitive type not containing null") {
+    val df = Seq(
+      Seq(1, 9, 8, 7),
+      Seq(5, 8, 9, 7, 2),
+      Seq.empty,
+      null
+    ).toDF("i")
+
+    def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
+      checkAnswer(df.selectExpr("transform(i, x -> x + 1)"),
+        Seq(
+          Row(Seq(2, 10, 9, 8)),
+          Row(Seq(6, 9, 10, 8, 3)),
+          Row(Seq.empty),
+          Row(null)))
+      checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"),
+        Seq(
+          Row(Seq(1, 10, 10, 10)),
+          Row(Seq(5, 9, 11, 10,
6)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeNotContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeNotContainsNull() + } + + test("transform function - array for primitive type containing null") { + val df = Seq[Seq[Integer]]( + Seq(1, 9, 8, null, 7), + Seq(5, null, 8, 9, 7, 2), + Seq.empty, + null + ).toDF("i") + + def testArrayOfPrimitiveTypeContainsNull(): Unit = { + checkAnswer(df.selectExpr("transform(i, x -> x + 1)"), + Seq( + Row(Seq(2, 10, 9, null, 8)), + Row(Seq(6, null, 9, 10, 8, 3)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(i, (x, i) -> x + i)"), + Seq( + Row(Seq(1, 10, 10, null, 11)), + Row(Seq(5, null, 10, 12, 11, 7)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testArrayOfPrimitiveTypeContainsNull() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testArrayOfPrimitiveTypeContainsNull() + } + + test("transform function - array for non-primitive type") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("s") + + def testNonPrimitiveType(): Unit = { + checkAnswer(df.selectExpr("transform(s, x -> concat(x, x))"), + Seq( + Row(Seq("cc", "aa", "bb")), + Row(Seq("bb", null, "cc", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(s, (x, i) -> concat(x, i))"), + Seq( + Row(Seq("c0", "a1", "b2")), + Row(Seq("b0", null, "c2", null)), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testNonPrimitiveType() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testNonPrimitiveType() + } + + test("transform function - special cases") { + val df = Seq( + Seq("c", "a", "b"), + Seq("b", null, "c", null), + Seq.empty, + null + ).toDF("arg") + + def testSpecialCases(): Unit = { + checkAnswer(df.selectExpr("transform(arg, arg -> arg)"), + Seq( + Row(Seq("c", "a", "b")), + Row(Seq("b", null, "c", null)), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, arg)"), + Seq( + Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))), + Row(Seq( + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null), + Seq("b", null, "c", null))), + Row(Seq.empty), + Row(null))) + checkAnswer(df.selectExpr("transform(arg, x -> concat(arg, array(x)))"), + Seq( + Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))), + Row(Seq( + Seq("b", null, "c", null, "b"), + Seq("b", null, "c", null, null), + Seq("b", null, "c", null, "c"), + Seq("b", null, "c", null, null))), + Row(Seq.empty), + Row(null))) + } + + // Test with local relation, the Project will be evaluated without codegen + testSpecialCases() + // Test with cached relation, the Project will be evaluated with codegen + df.cache() + testSpecialCases() + } + + test("transform function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), 1), + (Seq("b", null, "c", null), 2), + (Seq.empty, 3), + (null, 4) + ).toDF("s", "i") + + val ex1 = intercept[AnalysisException] { + df.selectExpr("transform(s, (x, y, z) -> x + y + z)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex2 = intercept[AnalysisException] { + 
df.selectExpr("transform(i, x -> x)") + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {