From 83635da0e0ff033e6c1d9aa750fba596c348c262 Mon Sep 17 00:00:00 2001
From: Li Jin
Date: Wed, 25 Jul 2018 14:27:24 +0000
Subject: [PATCH] Address comments; Use mutable state in collectEvaluableUDFs

---
 python/pyspark/sql/tests.py                   | 43 ++++++---
 .../execution/python/ExtractPythonUDFs.scala  | 93 +++++++++++++------
 .../python/ExtractPythonUDFsSuite.scala       |  1 -
 3 files changed, 94 insertions(+), 43 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b994c06f668c9..9da3b4cfa0c5b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5055,7 +5055,7 @@ def test_mixed_udf(self):
 
         df = self.spark.range(0, 1).toDF('v')
 
-        # Test mixture of multiple UDFs and Pandas UDFs
+        # Test mixture of multiple UDFs and Pandas UDFs.
 
         @udf('int')
         def f1(x):
@@ -5077,8 +5077,27 @@ def f4(x):
             assert type(x) == pd.Series
             return x + 1000
 
-        # Test mixed udfs in a single projection
-        df1 = df \
+        # Test single expression with chained UDFs
+        df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+        df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+        df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+        df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+        df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
+        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+        self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+        self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+        self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+        self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+        self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+        # Test multiple mixed UDF expressions in a single projection
+        df_multi_1 = df \
             .withColumn('f1', f1(col('v'))) \
             .withColumn('f2', f2(col('v'))) \
             .withColumn('f3', f3(col('v'))) \
@@ -5096,7 +5115,7 @@ def f4(x):
             .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
 
         # Test mixed udfs in a single expression
-        df2 = df \
+        df_multi_2 = df \
             .withColumn('f1', f1(col('v'))) \
             .withColumn('f2', f2(col('v'))) \
             .withColumn('f3', f3(col('v'))) \
@@ -5113,8 +5132,7 @@ def f4(x):
             .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
             .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
 
-        # expected result
-        df3 = df \
+        expected = df \
             .withColumn('f1', df['v'] + 1) \
             .withColumn('f2', df['v'] + 10) \
             .withColumn('f3', df['v'] + 100) \
@@ -5131,8 +5149,8 @@ def f4(x):
             .withColumn('f4_f3_f2', df['v'] + 1110) \
             .withColumn('f4_f3_f2_f1', df['v'] + 1111)
 
-        self.assertEquals(df3.collect(), df1.collect())
-        self.assertEquals(df3.collect(), df2.collect())
+        self.assertEquals(expected.collect(), df_multi_1.collect())
+        self.assertEquals(expected.collect(), df_multi_2.collect())
 
     def test_mixed_udf_and_sql(self):
         import pandas as pd
@@ -5148,6 +5166,7 @@ def f1(x):
             return x + 1
 
         def f2(x):
+            assert type(x) == pyspark.sql.Column
             return x + 10
 
         @pandas_udf('int')
@@ -5171,8 +5190,7 @@ def f3(x):
             .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
             .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
 
-        # expected result
-        df2 = df.withColumn('f1', df['v'] + 1) \
+        expected = df.withColumn('f1', df['v'] + 1) \
             .withColumn('f2', df['v'] + 10) \
             .withColumn('f3', df['v'] + 100) \
             .withColumn('f1_f2', df['v'] + 11) \
@@ -5188,7 +5206,7 @@ def f3(x):
             .withColumn('f3_f1_f2', df['v'] + 111) \
             .withColumn('f3_f2_f1', df['v'] + 111)
 
-        self.assertEquals(df2.collect(), df1.collect())
+        self.assertEquals(expected.collect(), df1.collect())
 
 
 @unittest.skipIf(
@@ -5618,9 +5636,8 @@ def dummy_pandas_udf(df):
         self.assertEquals(res.count(), 5)
 
     def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
-        # Test Pandas UDF and scalar Python UDF followed by groupby apply
-        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
         import pandas as pd
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
 
         df = self.spark.range(0, 10).toDF('v1')
         df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
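The chained-UDF tests above pin down the fusion contract that the Scala changes below implement: consecutive UDFs of the same eval type can be chained and evaluated in a single Python evaluation, while a change of eval type forces a split into separate evaluations. A toy, self-contained Scala sketch of that grouping idea follows; `EvalType` and `splitChain` are hypothetical names for illustration, not Spark APIs.

```scala
// Toy model: split a chain of UDF eval types (outermost first) into maximal
// same-type runs; each run stands for one fused Python evaluation.
object ChainSplitDemo {
  sealed trait EvalType
  case object Batched extends EvalType      // plain Python UDF, row at a time
  case object ScalarPandas extends EvalType // pandas UDF, batch at a time

  def splitChain(chain: List[EvalType]): List[List[EvalType]] = chain match {
    case Nil => Nil
    case head :: _ =>
      val (run, rest) = chain.span(_ == head)
      run :: splitChain(rest)
  }

  def main(args: Array[String]): Unit = {
    // A hypothetical chain whose outer two UDFs are pandas and inner two are
    // batched: two fused runs instead of four separate Python round trips.
    println(splitChain(List(ScalarPandas, ScalarPandas, Batched, Batched)))
    // List(List(ScalarPandas, ScalarPandas), List(Batched, Batched))

    // A fully alternating chain, as in the tests above, cannot be fused at all.
    println(splitChain(List(ScalarPandas, Batched, ScalarPandas, Batched)))
    // List(List(ScalarPandas), List(Batched), List(ScalarPandas), List(Batched))
  }
}
```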
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 90818a95ac766..b4bb89578b8b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -95,47 +95,81 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  */
 object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
-  private def hasScalarPythonUDF(e: Expression): Boolean = {
-    e.find(PythonUDF.isScalarPythonUDF).isDefined
-  }
+  private case class LazyEvalType(var evalType: Int = -1) {
 
-  private def canEvaluateInPython(e: PythonUDF, evalType: Int): Boolean = {
-    if (e.evalType != evalType) {
-      false
-    } else {
-      e.children match {
-        // single PythonUDF child could be chained and evaluated in Python
-        case Seq(u: PythonUDF) => canEvaluateInPython(u, evalType)
-        // Python UDF can't be evaluated directly in JVM
-        case children => !children.exists(hasScalarPythonUDF)
+    def isSet: Boolean = evalType >= 0
+
+    def set(evalType: Int): Unit = {
+      if (isSet) {
+        throw new IllegalStateException("Eval type has already been set")
+      } else {
+        this.evalType = evalType
+      }
+    }
+
+    def get(): Int = {
+      if (!isSet) {
+        throw new IllegalStateException("Eval type is not set")
+      } else {
+        evalType
       }
     }
   }
 
-  private def collectEvaluableUDF(expr: Expression, evalType: Int): Seq[PythonUDF] = expr match {
-    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) =>
-      Seq(udf)
-    case e => e.children.flatMap(collectEvaluableUDF(_, evalType))
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.find(PythonUDF.isScalarPythonUDF).isDefined
   }
 
   /**
-   * Collect evaluable UDFs from the current node.
+   * Check whether a PythonUDF expression can be evaluated in Python.
    *
-   * This function collects Python UDFs or Scalar Python UDFs from expressions of the input node,
-   * and returns a list of UDFs of the same eval type.
+   * If the lazy eval type is not set, this method accepts either a Batched Python UDF or a
+   * Scalar Pandas UDF. If the lazy eval type is set, this method only accepts an expression of
+   * the specified eval type.
    *
-   * If expressions contain both UDFs eval types, this function will only return Python UDFs.
+   * This method also sets the lazy eval type to the type of the first evaluable expression,
+   * i.e., if the lazy eval type is not set and we find an evaluable Python UDF expression, the
+   * lazy eval type is set to the eval type of that expression.
    *
-   * The caller should call this function multiple times until all evaluable UDFs are collected.
    */
-  private def collectEvaluableUDFs(plan: SparkPlan): Seq[PythonUDF] = {
-    val pythonUDFs =
-      plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_BATCHED_UDF))
-
-    if (pythonUDFs.isEmpty) {
-      plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_SCALAR_PANDAS_UDF))
+  private def canEvaluateInPython(e: PythonUDF, lazyEvalType: LazyEvalType): Boolean = {
+    if (!lazyEvalType.isSet) {
+      e.children match {
+        // A single PythonUDF child could be chained and evaluated in Python if the eval type is
+        // the same.
+        case Seq(u: PythonUDF) =>
+          // Need to recheck the eval type because the lazy eval type will be set if the child
+          // Python UDF is evaluable.
+          canEvaluateInPython(u, lazyEvalType) && lazyEvalType.get == e.evalType
+        // Python UDFs can't be evaluated directly in the JVM.
+        case children => if (!children.exists(hasScalarPythonUDF)) {
+          // We found the first evaluable expression; set the lazy eval type to its eval type.
+          lazyEvalType.set(e.evalType)
+          true
+        } else {
+          false
+        }
+      }
     } else {
-      pythonUDFs
+      if (e.evalType != lazyEvalType.get) {
+        false
+      } else {
+        e.children match {
+          case Seq(u: PythonUDF) => canEvaluateInPython(u, lazyEvalType)
+          case children => !children.exists(hasScalarPythonUDF)
+        }
+      }
+    }
+  }
+
+  private def collectEvaluableUDFs(
+      expr: Expression,
+      lazyEvalType: LazyEvalType): Seq[PythonUDF] = {
+    expr match {
+      case udf: PythonUDF
+          if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, lazyEvalType) =>
+        Seq(udf)
+      case e => e.children.flatMap(collectEvaluableUDFs(_, lazyEvalType))
     }
   }
@@ -147,7 +181,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
    * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
   private def extract(plan: SparkPlan): SparkPlan = {
-    val udfs = collectEvaluableUDFs(plan)
+    val lazyEvalType = new LazyEvalType
+    val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, lazyEvalType))
       // ignore the PythonUDF that come from second/third aggregate, which is not used
       .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
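To see the set-once semantics of `LazyEvalType` in isolation, here is a standalone sketch. The class body is copied from the patch above; the surrounding demo (object name, eval-type constants) is illustrative only and not part of Spark.

```scala
object LazyEvalTypeDemo {
  // Copied from the patch: a write-once holder for the pinned eval type.
  private case class LazyEvalType(var evalType: Int = -1) {
    def isSet: Boolean = evalType >= 0

    def set(evalType: Int): Unit = {
      if (isSet) {
        throw new IllegalStateException("Eval type has already been set")
      } else {
        this.evalType = evalType
      }
    }

    def get(): Int = {
      if (!isSet) {
        throw new IllegalStateException("Eval type is not set")
      } else {
        evalType
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val batched = 100      // illustrative stand-ins for PythonEvalType constants
    val scalarPandas = 200

    val lazyEvalType = LazyEvalType()
    assert(!lazyEvalType.isSet)

    // The first evaluable UDF pins the eval type for the rest of the pass...
    lazyEvalType.set(batched)
    assert(lazyEvalType.get == batched)

    // ...and a UDF of the other eval type is skipped by the caller's check,
    // never by a second set(), which would throw IllegalStateException.
    assert(lazyEvalType.get != scalarPandas)
  }
}
```

Compared with threading an `Option[Int]` through the recursion, the mutable holder keeps `collectEvaluableUDFs` a simple `flatMap`; the cost is one-shot hidden state, which the `set`/`get` guards at least make fail-fast.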
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
index 2cb2e27c7deb0..76b609d111acd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -90,4 +90,3 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
     }
   }
 }
-
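The behavioral shift in the rule is easiest to see side by side: the old code made up to two passes, always preferring batched UDFs, while the new code makes one pass and pins whichever eval type it meets first. A toy model of the two collection strategies, with an operator's expressions reduced to a flat list of top-level UDFs; all names here are hypothetical, not Spark APIs.

```scala
object OnePassCollectDemo {
  final case class Udf(name: String, evalType: Int)

  val Batched = 100      // illustrative stand-ins for PythonEvalType constants
  val ScalarPandas = 200

  // Old approach: try batched UDFs first; only if none exist, try pandas UDFs.
  def collectTwoPass(udfs: Seq[Udf]): Seq[Udf] = {
    val batched = udfs.filter(_.evalType == Batched)
    if (batched.nonEmpty) batched else udfs.filter(_.evalType == ScalarPandas)
  }

  // New approach: a single pass that pins the eval type of the first UDF seen.
  def collectOnePass(udfs: Seq[Udf]): Seq[Udf] = {
    var pinned: Option[Int] = None // the mutable "lazy eval type"
    udfs.filter { udf =>
      pinned match {
        case None =>
          pinned = Some(udf.evalType)
          true
        case Some(t) =>
          udf.evalType == t
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val udfs = Seq(Udf("f3", ScalarPandas), Udf("f1", Batched), Udf("f4", ScalarPandas))
    println(collectTwoPass(udfs).map(_.name)) // List(f1): batched wins regardless of order
    println(collectOnePass(udfs).map(_.name)) // List(f3, f4): first-seen type wins
  }
}
```

Both strategies eventually extract every UDF, because the rule is applied again to the rewritten plan until no evaluable UDFs remain; the one-pass version simply drops the fixed preference for batched UDFs and avoids re-traversing the expressions when the first pass comes up empty.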