diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2d6b9f01e6525..a294d70119d0b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self):
                     'Result vector from pandas_udf was not the required length'):
                 df.select(raise_exception(col('id'))).collect()
 
-    def test_vectorized_udf_mix_udf(self):
-        from pyspark.sql.functions import pandas_udf, udf, col
-        df = self.spark.range(10)
-        row_by_row_udf = udf(lambda x: x, LongType())
-        pd_udf = pandas_udf(lambda x: x, LongType())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Can not mix vectorized and non-vectorized UDFs'):
-                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
-
     def test_vectorized_udf_chained(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10)
@@ -5060,6 +5049,166 @@ def test_type_annotation(self):
         df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
         self.assertEqual(df.first()[0], 0)
 
+    def test_mixed_udf(self):
+        import pandas as pd
+        from pyspark.sql.functions import col, udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of multiple UDFs and Pandas UDFs.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        @pandas_udf('int')
+        def f2(x):
+            assert type(x) == pd.Series
+            return x + 10
+
+        @udf('int')
+        def f3(x):
+            assert type(x) == int
+            return x + 100
+
+        @pandas_udf('int')
+        def f4(x):
+            assert type(x) == pd.Series
+            return x + 1000
+
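+        # Each UDF adds a distinct power of ten, so every expected sum below encodes
+        # exactly which UDFs were applied, e.g. f4(f2(f1(v))) == v + 1011.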
+        # Test single expression with chained UDFs
+        df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+        df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+        df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+        df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+        df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
+        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+        self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+        self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+        self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+        self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+        self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+        # Test multiple mixed UDF expressions in a single projection
+        df_multi_1 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(col('f1'))) \
+            .withColumn('f3_f1', f3(col('f1'))) \
+            .withColumn('f4_f1', f4(col('f1'))) \
+            .withColumn('f3_f2', f3(col('f2'))) \
+            .withColumn('f4_f2', f4(col('f2'))) \
+            .withColumn('f4_f3', f4(col('f3'))) \
+            .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
+            .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
+            .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
+            .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
+            .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
+
+        # Test mixed UDFs chained within a single expression
+        df_multi_2 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(f1(col('v')))) \
+            .withColumn('f3_f1', f3(f1(col('v')))) \
+            .withColumn('f4_f1', f4(f1(col('v')))) \
+            .withColumn('f3_f2', f3(f2(col('v')))) \
+            .withColumn('f4_f2', f4(f2(col('v')))) \
+            .withColumn('f4_f3', f4(f3(col('v')))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
+            .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
+            .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
+            .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
+            .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
+
+        expected = df \
+            .withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f4', df['v'] + 1000) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f4_f1', df['v'] + 1001) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f4_f2', df['v'] + 1010) \
+            .withColumn('f4_f3', df['v'] + 1100) \
+            .withColumn('f3_f2_f1', df['v'] + 111) \
+            .withColumn('f4_f2_f1', df['v'] + 1011) \
+            .withColumn('f4_f3_f1', df['v'] + 1101) \
+            .withColumn('f4_f3_f2', df['v'] + 1110) \
+            .withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+        self.assertEquals(expected.collect(), df_multi_1.collect())
+        self.assertEquals(expected.collect(), df_multi_2.collect())
+
+    def test_mixed_udf_and_sql(self):
+        import pandas as pd
+        from pyspark.sql import Column
+        from pyspark.sql.functions import udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of UDFs, Pandas UDFs and SQL expressions.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        def f2(x):
+            assert type(x) == Column
+            return x + 10
+
+        @pandas_udf('int')
+        def f3(x):
+            assert type(x) == pd.Series
+            return x + 100
+
+        df1 = df.withColumn('f1', f1(df['v'])) \
+            .withColumn('f2', f2(df['v'])) \
+            .withColumn('f3', f3(df['v'])) \
+            .withColumn('f1_f2', f1(f2(df['v']))) \
+            .withColumn('f1_f3', f1(f3(df['v']))) \
+            .withColumn('f2_f1', f2(f1(df['v']))) \
+            .withColumn('f2_f3', f2(f3(df['v']))) \
+            .withColumn('f3_f1', f3(f1(df['v']))) \
+            .withColumn('f3_f2', f3(f2(df['v']))) \
+            .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
+            .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
+            .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
+            .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
+            .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
+        expected = df.withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f1_f2', df['v'] + 11) \
+            .withColumn('f1_f3', df['v'] + 101) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f2_f3', df['v'] + 110) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f1_f2_f3', df['v'] + 111) \
+            .withColumn('f1_f3_f2', df['v'] + 111) \
+            .withColumn('f2_f1_f3', df['v'] + 111) \
+            .withColumn('f2_f3_f1', df['v'] + 111) \
+            .withColumn('f3_f1_f2', df['v'] + 111) \
+            .withColumn('f3_f2_f1', df['v'] + 111)
+
+        self.assertEquals(expected.collect(), df1.collect())
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
@@ -5487,6 +5636,21 @@ def dummy_pandas_udf(df):
                                  F.col('temp0.key') == F.col('temp1.key'))
         self.assertEquals(res.count(), 5)
 
+    def test_mixed_scalar_udfs_followed_by_groupby_apply(self):
+        import pandas as pd
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+
+        df = self.spark.range(0, 10).toDF('v1')
+        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
+            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
+
+        result = df.groupby() \
+            .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
+                              'sum int',
+                              PandasUDFType.GROUPED_MAP))
+
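+        # sum(v1) = 45, sum(v2) = 55 and sum(v3) = 65, so the grand total is 165.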
+        self.assertEquals(result.collect()[0]['sum'], 165)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 1e096100f7f43..cb75874be32ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -94,28 +95,44 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  */
 object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
-  private def hasPythonUDF(e: Expression): Boolean = {
+  private type EvalType = Int
+  private type EvalTypeChecker = EvalType => Boolean
+
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
     e.find(PythonUDF.isScalarPythonUDF).isDefined
   }
 
   private def canEvaluateInPython(e: PythonUDF): Boolean = {
     e.children match {
       // single PythonUDF child could be chained and evaluated in Python
-      case Seq(u: PythonUDF) => canEvaluateInPython(u)
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
      // Python UDF can't be evaluated directly in JVM
-      case children => !children.exists(hasPythonUDF)
+      case children => !children.exists(hasScalarPythonUDF)
     }
   }
 
-  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
-    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
-    case e => e.children.flatMap(collectEvaluatableUDF)
+  private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
+    // The eval type checker is set once, when the first evaluable UDF is found, and its value
+    // shouldn't change afterwards: it verifies that every subsequent UDF has the same eval type
+    // as the first one, since only UDFs of a single eval type can be extracted together.
+    var evalTypeChecker: Option[EvalTypeChecker] = None
+
+    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+        && evalTypeChecker.isEmpty =>
+        evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType)
+        Seq(udf)
+      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+        && evalTypeChecker.get(udf.evalType) =>
+        Seq(udf)
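+      // A UDF with a different eval type than the first one found is skipped in this pass;
+      // extract() recurses on the rewritten plan, so it is picked up in a later pass.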
+      case e => e.children.flatMap(collectEvaluableUDFs)
+    }
+
+    expressions.flatMap(collectEvaluableUDFs)
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker
-    // Therefore we don't need to extract the UDFs
-    case plan: FlatMapGroupsInPandasExec => plan
     case plan: SparkPlan => extract(plan)
   }
 
@@ -123,7 +140,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
    * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
   private def extract(plan: SparkPlan): SparkPlan = {
-    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+    val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
       // ignore the PythonUDF that come from second/third aggregate, which is not used
       .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
@@ -167,7 +184,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
           case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
             BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
           case _ =>
-            throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
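+            // Defensive: collectEvaluableUDFsFromExpressions only collects UDFs of a single
+            // eval type, so this branch is not expected to be reachable.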
+            throw new AnalysisException(
+              "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
         }
 
         attributeMap ++= validUdfs.zip(resultAttrs)
@@ -205,7 +223,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     case filter: FilterExec =>
       val (candidates, nonDeterministic) =
         splitConjunctivePredicates(filter.condition).partition(_.deterministic)
-      val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+      val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
       if (pushDown.nonEmpty) {
         val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
         FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index d456c931f5275..2cc55ff88b983 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -115,3 +115,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
   dataType = BooleanType,
   pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
   udfDeterministic = true)
+
+class MyDummyScalarPandasUDF extends UserDefinedPythonFunction(
+  name = "dummyScalarPandasUDF",
+  func = new DummyUDF,
+  dataType = BooleanType,
+  pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+  udfDeterministic = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
new file mode 100644
index 0000000000000..76b609d111acd
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
+  import testImplicits.newProductEncoder
+  import testImplicits.localSeqToDatasetHolder
+
+  val batchedPythonUDF = new MyDummyPythonUDF
+  val scalarPandasUDF = new MyDummyScalarPandasUDF
+
+  private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect {
+    case b: BatchEvalPythonExec => b
+  }
+
+  private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect {
+    case b: ArrowEvalPythonExec => b
+  }
+
+  test("Chained Batched Python UDFs should be combined into a single physical node") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
+      .withColumn("d", batchedPythonUDF(col("c")))
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+  }
+
+  test("Chained Scalar Pandas UDFs should be combined into a single physical node") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", scalarPandasUDF(col("a")))
+      .withColumn("d", scalarPandasUDF(col("c")))
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Mixed Batched Python UDFs and Scalar Pandas UDFs should be separate physical nodes") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
+      .withColumn("d", scalarPandasUDF(col("b")))
+
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
+      .withColumn("c2", batchedPythonUDF(col("c1")))
+      .withColumn("d1", scalarPandasUDF(col("a")))
+      .withColumn("d2", scalarPandasUDF(col("d1")))
+
+    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
+    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
+    assert(pythonEvalNodes.size == 1)
+    assert(arrowEvalNodes.size == 1)
+  }
+
+  test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
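+    // The eval type alternates along the dependency chain (batched -> arrow -> batched -> arrow),
+    // so each UDF is extracted in its own pass and gets its own physical node, two of each kind.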
df.withColumn("c1", batchedPythonUDF(col("a"))) + .withColumn("d1", scalarPandasUDF(col("c1"))) + .withColumn("c2", batchedPythonUDF(col("d1"))) + .withColumn("d2", scalarPandasUDF(col("c2"))) + + val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan) + val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan) + assert(pythonEvalNodes.size == 2) + assert(arrowEvalNodes.size == 2) + } +} +