Address comments; Use mutable state in collectEvaluableUDFs
icexelloss committed Jul 25, 2018
1 parent 4c9c007 commit 83635da
Showing 3 changed files with 94 additions and 43 deletions.
43 changes: 30 additions & 13 deletions python/pyspark/sql/tests.py
@@ -5055,7 +5055,7 @@ def test_mixed_udf(self):

df = self.spark.range(0, 1).toDF('v')

# Test mixture of multiple UDFs and Pandas UDFs
# Test mixture of multiple UDFs and Pandas UDFs.

@udf('int')
def f1(x):
@@ -5077,8 +5077,27 @@ def f4(x):
assert type(x) == pd.Series
return x + 1000

# Test mixed udfs in a single projection
df1 = df \
# Test single expression with chained UDFs
df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))

expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)

self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())

# Test multiple mixed UDF expressions in a single projection
df_multi_1 = df \
.withColumn('f1', f1(col('v'))) \
.withColumn('f2', f2(col('v'))) \
.withColumn('f3', f3(col('v'))) \
@@ -5096,7 +5115,7 @@ def f4(x):
.withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))

# Test mixed udfs in a single expression
df2 = df \
df_multi_2 = df \
.withColumn('f1', f1(col('v'))) \
.withColumn('f2', f2(col('v'))) \
.withColumn('f3', f3(col('v'))) \
@@ -5113,8 +5132,7 @@ def f4(x):
.withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))

# expected result
df3 = df \
expected = df \
.withColumn('f1', df['v'] + 1) \
.withColumn('f2', df['v'] + 10) \
.withColumn('f3', df['v'] + 100) \
@@ -5131,8 +5149,8 @@ def f4(x):
.withColumn('f4_f3_f2', df['v'] + 1110) \
.withColumn('f4_f3_f2_f1', df['v'] + 1111)

self.assertEquals(df3.collect(), df1.collect())
self.assertEquals(df3.collect(), df2.collect())
self.assertEquals(expected.collect(), df_multi_1.collect())
self.assertEquals(expected.collect(), df_multi_2.collect())

def test_mixed_udf_and_sql(self):
import pandas as pd
@@ -5148,6 +5166,7 @@ def f1(x):
return x + 1

def f2(x):
assert type(x) == pyspark.sql.Column
return x + 10

@pandas_udf('int')
@@ -5171,8 +5190,7 @@ def f3(x):
.withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))

# expected result
df2 = df.withColumn('f1', df['v'] + 1) \
expected = df.withColumn('f1', df['v'] + 1) \
.withColumn('f2', df['v'] + 10) \
.withColumn('f3', df['v'] + 100) \
.withColumn('f1_f2', df['v'] + 11) \
@@ -5188,7 +5206,7 @@ def f3(x):
.withColumn('f3_f1_f2', df['v'] + 111) \
.withColumn('f3_f2_f1', df['v'] + 111)

self.assertEquals(df2.collect(), df1.collect())
self.assertEquals(expected.collect(), df1.collect())


@unittest.skipIf(
@@ -5618,9 +5636,8 @@ def dummy_pandas_udf(df):
self.assertEquals(res.count(), 5)

def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
# Test Pandas UDF and scalar Python UDF followed by groupby apply
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
import pandas as pd
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType

df = self.spark.range(0, 10).toDF('v1')
df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
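
For context, the pattern these tests exercise — chaining a row-at-a-time Python UDF with a Scalar Pandas UDF in a single expression — looks roughly like the standalone sketch below. It is a minimal usage sketch, not part of the commit; the function and column names are illustrative, and a running SparkSession named spark is assumed.

import pandas as pd
from pyspark.sql.functions import col, pandas_udf, udf

df = spark.range(0, 3).toDF('v')

@udf('int')
def plus_one(x):
    # Row-at-a-time Python UDF: x is a single int value.
    return x + 1

@pandas_udf('int')
def plus_ten(x):
    # Scalar Pandas UDF: x is a pandas.Series.
    assert type(x) == pd.Series
    return x + 10

# Mixing both UDF types in one expression; the planner is expected to split
# the chain into separate Python evaluation stages by eval type.
result = df.withColumn('out', plus_ten(plus_one(col('v'))))
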
93 changes: 64 additions & 29 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -95,47 +95,81 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

private def hasScalarPythonUDF(e: Expression): Boolean = {
e.find(PythonUDF.isScalarPythonUDF).isDefined
}
private case class LazyEvalType(var evalType: Int = -1) {

private def canEvaluateInPython(e: PythonUDF, evalType: Int): Boolean = {
if (e.evalType != evalType) {
false
} else {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
case Seq(u: PythonUDF) => canEvaluateInPython(u, evalType)
// Python UDF can't be evaluated directly in JVM
case children => !children.exists(hasScalarPythonUDF)
def isSet: Boolean = evalType >= 0

def set(evalType: Int): Unit = {
if (isSet) {
throw new IllegalStateException("Eval type has already been set")
} else {
this.evalType = evalType
}
}

def get(): Int = {
if (!isSet) {
throw new IllegalStateException("Eval type is not set")
} else {
evalType
}
}
}

private def collectEvaluableUDF(expr: Expression, evalType: Int): Seq[PythonUDF] = expr match {
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) =>
Seq(udf)
case e => e.children.flatMap(collectEvaluableUDF(_, evalType))
private def hasScalarPythonUDF(e: Expression): Boolean = {
e.find(PythonUDF.isScalarPythonUDF).isDefined
}

/**
* Collect evaluable UDFs from the current node.
* Check whether a PythonUDF expression can be evaluated in Python.
*
* This function collects Python UDFs or Scalar Python UDFs from expressions of the input node,
* and returns a list of UDFs of the same eval type.
   * If the lazy eval type is not set, this method checks for either Batched Python UDFs or Scalar
   * Pandas UDFs. If the lazy eval type is set, this method only accepts expressions of the
   * specified eval type.
*
* If expressions contain both UDFs eval types, this function will only return Python UDFs.
   * This method also sets the lazy eval type to the type of the first evaluable expression,
   * i.e., if the lazy eval type is not set and we find an evaluable Python UDF expression, the
   * lazy eval type will be set to the eval type of that expression.
*
* The caller should call this function multiple times until all evaluable UDFs are collected.
*/
private def collectEvaluableUDFs(plan: SparkPlan): Seq[PythonUDF] = {
val pythonUDFs =
plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_BATCHED_UDF))

if (pythonUDFs.isEmpty) {
plan.expressions.flatMap(collectEvaluableUDF(_, PythonEvalType.SQL_SCALAR_PANDAS_UDF))
private def canEvaluateInPython(e: PythonUDF, lazyEvalType: LazyEvalType): Boolean = {
if (!lazyEvalType.isSet) {
e.children match {
// single PythonUDF child could be chained and evaluated in Python if eval type is the same
case Seq(u: PythonUDF) =>
// Need to recheck the eval type because lazy eval type will be set if child Python UDF is
// evaluable
canEvaluateInPython(u, lazyEvalType) && lazyEvalType.get == e.evalType
// Python UDF can't be evaluated directly in JVM
case children => if (!children.exists(hasScalarPythonUDF)) {
// We found the first evaluable expression, set lazy eval type to its eval type.
lazyEvalType.set(e.evalType)
true
} else {
false
}
}
} else {
pythonUDFs
if (e.evalType != lazyEvalType.get) {
false
} else {
e.children match {
case Seq(u: PythonUDF) => canEvaluateInPython(u, lazyEvalType)
case children => !children.exists(hasScalarPythonUDF)
}
}
}
}

private def collectEvaluableUDFs(
expr: Expression,
evalType: LazyEvalType
): Seq[PythonUDF] = {
expr match {
case udf: PythonUDF if
PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf, evalType) =>
Seq(udf)
case e => e.children.flatMap(collectEvaluableUDFs(_, evalType))
}
}

@@ -147,7 +181,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
private def extract(plan: SparkPlan): SparkPlan = {
val udfs = collectEvaluableUDFs(plan)
val lazyEvalType = new LazyEvalType
val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, lazyEvalType))
// ignore the PythonUDF that come from second/third aggregate, which is not used
.filter(udf => udf.references.subsetOf(plan.inputSet))
if (udfs.isEmpty) {
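
The core of the change: extract() now creates one mutable LazyEvalType per plan and threads it through collectEvaluableUDFs, so the eval type of the first evaluable UDF encountered decides which UDFs that pass collects; UDFs of the other eval type are left for a later extraction pass. Below is a minimal Python sketch of this first-eval-type-wins idea over a toy expression tree — the classes, names, and eval-type constants are illustrative stand-ins, not Spark's, and the evaluability check is much simplified.

class Udf:
    # Toy stand-in for PythonUDF: an eval type plus child expressions.
    def __init__(self, eval_type, *children):
        self.eval_type = eval_type
        self.children = list(children)

class Col:
    # Toy leaf expression with no UDFs underneath.
    children = ()

def collect_evaluable_udfs(expr, lazy_eval_type):
    # lazy_eval_type is a one-element list used as shared mutable state,
    # mirroring the role of LazyEvalType in the commit (much simplified).
    if isinstance(expr, Udf):
        if not lazy_eval_type:
            lazy_eval_type.append(expr.eval_type)  # first evaluable UDF wins
        if expr.eval_type == lazy_eval_type[0]:
            return [expr]
    return [u for child in expr.children
            for u in collect_evaluable_udfs(child, lazy_eval_type)]

# A projection mixing two eval types: only the first type's UDFs are
# collected in this pass; a later pass would pick up the rest.
BATCHED, SCALAR_PANDAS = 100, 200
exprs = [Udf(BATCHED, Col()), Udf(SCALAR_PANDAS, Col()), Udf(BATCHED, Col())]
state = []
udfs = [u for e in exprs for u in collect_evaluable_udfs(e, state)]
assert [u.eval_type for u in udfs] == [BATCHED, BATCHED]
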
1 change: 0 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala
@@ -90,4 +90,3 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
}
}

