[SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF

## What changes were proposed in this pull request?

This PR adds support for mixing Python UDFs and Scalar Pandas UDFs, in the following two cases:

(1)
```
from pyspark.sql.functions import col, pandas_udf, udf

@udf('int')
def f1(x):
    return x + 1

@pandas_udf('int')
def f2(x):
    return x + 1

df = spark.range(0, 1).toDF('v') \
    .withColumn('foo', f1(col('v'))) \
    .withColumn('bar', f2(col('v')))
```

QueryPlan:
```
>>> df.explain(True)
== Parsed Logical Plan ==
'Project [v#2L, foo#5, f2('v) AS bar#9]
+- AnalysisBarrier
      +- Project [v#2L, f1(v#2L) AS foo#5]
         +- Project [id#0L AS v#2L]
            +- Range (0, 1, step=1, splits=Some(4))

== Analyzed Logical Plan ==
v: bigint, foo: int, bar: int
Project [v#2L, foo#5, f2(v#2L) AS bar#9]
+- Project [v#2L, f1(v#2L) AS foo#5]
   +- Project [id#0L AS v#2L]
      +- Range (0, 1, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#0L AS v#2L, f1(id#0L) AS foo#5, f2(id#0L) AS bar#9]
+- Range (0, 1, step=1, splits=Some(4))

== Physical Plan ==
*(2) Project [id#0L AS v#2L, pythonUDF0#13 AS foo#5, pythonUDF0#14 AS bar#9]
+- ArrowEvalPython [f2(id#0L)], [id#0L, pythonUDF0#13, pythonUDF0#14]
   +- BatchEvalPython [f1(id#0L)], [id#0L, pythonUDF0#13]
      +- *(1) Range (0, 1, step=1, splits=4)
```
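
In the physical plan, the plain Python UDF f1 is evaluated in a BatchEvalPython node and the Scalar Pandas UDF f2 in a separate ArrowEvalPython node, instead of failing with the old "Can not mix vectorized and non-vectorized UDFs" error. As a quick sanity check (assuming Spark's default show() formatting), the result should be:

```
>>> df.show()
+---+---+---+
|  v|foo|bar|
+---+---+---+
|  0|  1|  1|
+---+---+---+
```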

(2)
```
from pyspark.sql.functions import pandas_udf, udf

@udf('int')
def f1(x):
    return x + 1

@pandas_udf('int')
def f2(x):
    return x + 1

df = spark.range(0, 1).toDF('v')
df = df.withColumn('foo', f2(f1(df['v'])))
```

QueryPlan:
```
>>> df.explain(True)
== Parsed Logical Plan ==
Project [v#21L, f2(f1(v#21L)) AS foo#46]
+- AnalysisBarrier
      +- Project [v#21L, f1(f2(v#21L)) AS foo#39]
         +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#32]
            +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#25]
               +- Project [id#19L AS v#21L]
                  +- Range (0, 1, step=1, splits=Some(4))

== Analyzed Logical Plan ==
v: bigint, foo: int
Project [v#21L, f2(f1(v#21L)) AS foo#46]
+- Project [v#21L, f1(f2(v#21L)) AS foo#39]
   +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#32]
      +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#25]
         +- Project [id#19L AS v#21L]
            +- Range (0, 1, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#19L AS v#21L, f2(f1(id#19L)) AS foo#46]
+- Range (0, 1, step=1, splits=Some(4))

== Physical Plan ==
*(2) Project [id#19L AS v#21L, pythonUDF0#50 AS foo#46]
+- ArrowEvalPython [f2(pythonUDF0#49)], [id#19L, pythonUDF0#49, pythonUDF0#50]
   +- BatchEvalPython [f1(id#19L)], [id#19L, pythonUDF0#49]
      +- *(1) Range (0, 1, step=1, splits=4)
```
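
The two eval nodes appear because ExtractPythonUDFs can only chain UDFs of the same eval type into a single Python evaluation node: each pass of the rule collects the scalar Python UDFs that share the eval type of the first evaluable UDF it finds, and remaining UDFs are extracted by a later pass over the rewritten plan. Here is a simplified, self-contained Python sketch of that collection logic (the toy types and helper names are illustrative, not Spark's API; the real implementation is collectEvaluableUDFsFromExpressions in ExtractPythonUDFs.scala below):

```
from dataclasses import dataclass, field
from typing import List


# Toy stand-ins for Catalyst expressions; illustrative only.
@dataclass
class Expr:
    children: List['Expr'] = field(default_factory=list)


@dataclass
class PythonUDF(Expr):
    eval_type: str = 'SQL_BATCHED_UDF'  # or 'SQL_SCALAR_PANDAS_UDF'


def has_scalar_python_udf(e: Expr) -> bool:
    return isinstance(e, PythonUDF) or any(has_scalar_python_udf(c) for c in e.children)


def can_evaluate_in_python(udf: PythonUDF) -> bool:
    # A single PythonUDF child can be chained only if it has the same eval
    # type (the condition this PR adds); otherwise no child may contain
    # another scalar Python UDF.
    if len(udf.children) == 1 and isinstance(udf.children[0], PythonUDF):
        child = udf.children[0]
        return udf.eval_type == child.eval_type and can_evaluate_in_python(child)
    return not any(has_scalar_python_udf(c) for c in udf.children)


def collect_evaluable_udfs(expressions: List[Expr]) -> List[PythonUDF]:
    first_eval_type = None  # locked in by the first evaluable UDF found
    collected = []

    def visit(expr: Expr) -> None:
        nonlocal first_eval_type
        if isinstance(expr, PythonUDF) and can_evaluate_in_python(expr):
            if first_eval_type is None:
                first_eval_type = expr.eval_type
            if expr.eval_type == first_eval_type:
                collected.append(expr)
                return  # a matching UDF (and its chained children) is taken whole
        for child in expr.children:
            visit(child)

    for expr in expressions:
        visit(expr)
    return collected


# f2(f1(v)) with mixed eval types: only f1 is collected on the first pass;
# f2 is extracted when the rule runs again on the rewritten plan, which is
# why the physical plan above has both BatchEvalPython and ArrowEvalPython.
v = Expr()
f1 = PythonUDF(children=[v], eval_type='SQL_BATCHED_UDF')
f2 = PythonUDF(children=[f1], eval_type='SQL_SCALAR_PANDAS_UDF')
assert collect_evaluable_udfs([f2]) == [f1]
```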

## How was this patch tested?

New tests are added to BatchEvalPythonExecSuite and ScalarPandasUDFTests, and the now-obsolete test_vectorized_udf_mix_udf, which asserted that mixing the two UDF types raises an error, is removed.

Author: Li Jin <[email protected]>

Closes apache#21650 from icexelloss/SPARK-24624-mix-udf.
icexelloss authored and HyukjinKwon committed Jul 28, 2018
1 parent 6424b14 commit e875209
Showing 4 changed files with 304 additions and 23 deletions.
186 changes: 175 additions & 11 deletions python/pyspark/sql/tests.py
@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self):
'Result vector from pandas_udf was not the required length'):
df.select(raise_exception(col('id'))).collect()

def test_vectorized_udf_mix_udf(self):
from pyspark.sql.functions import pandas_udf, udf, col
df = self.spark.range(10)
row_by_row_udf = udf(lambda x: x, LongType())
pd_udf = pandas_udf(lambda x: x, LongType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(
Exception,
'Can not mix vectorized and non-vectorized UDFs'):
df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()

def test_vectorized_udf_chained(self):
from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
@@ -5060,6 +5049,166 @@ def test_type_annotation(self):
df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
self.assertEqual(df.first()[0], 0)

def test_mixed_udf(self):
import pandas as pd
from pyspark.sql.functions import col, udf, pandas_udf

df = self.spark.range(0, 1).toDF('v')

# Test mixture of multiple UDFs and Pandas UDFs.
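# Each UDF adds a distinct power of ten, so every expected value below
# encodes exactly which UDFs were applied (e.g. f4_f2_f1 == v + 1011).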

@udf('int')
def f1(x):
assert type(x) == int
return x + 1

@pandas_udf('int')
def f2(x):
assert type(x) == pd.Series
return x + 10

@udf('int')
def f3(x):
assert type(x) == int
return x + 100

@pandas_udf('int')
def f4(x):
assert type(x) == pd.Series
return x + 1000

# Test single expression with chained UDFs
df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))

expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)

self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())

# Test multiple mixed UDF expressions in a single projection
df_multi_1 = df \
.withColumn('f1', f1(col('v'))) \
.withColumn('f2', f2(col('v'))) \
.withColumn('f3', f3(col('v'))) \
.withColumn('f4', f4(col('v'))) \
.withColumn('f2_f1', f2(col('f1'))) \
.withColumn('f3_f1', f3(col('f1'))) \
.withColumn('f4_f1', f4(col('f1'))) \
.withColumn('f3_f2', f3(col('f2'))) \
.withColumn('f4_f2', f4(col('f2'))) \
.withColumn('f4_f3', f4(col('f3'))) \
.withColumn('f3_f2_f1', f3(col('f2_f1'))) \
.withColumn('f4_f2_f1', f4(col('f2_f1'))) \
.withColumn('f4_f3_f1', f4(col('f3_f1'))) \
.withColumn('f4_f3_f2', f4(col('f3_f2'))) \
.withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))

# Test mixed udfs in a single expression
df_multi_2 = df \
.withColumn('f1', f1(col('v'))) \
.withColumn('f2', f2(col('v'))) \
.withColumn('f3', f3(col('v'))) \
.withColumn('f4', f4(col('v'))) \
.withColumn('f2_f1', f2(f1(col('v')))) \
.withColumn('f3_f1', f3(f1(col('v')))) \
.withColumn('f4_f1', f4(f1(col('v')))) \
.withColumn('f3_f2', f3(f2(col('v')))) \
.withColumn('f4_f2', f4(f2(col('v')))) \
.withColumn('f4_f3', f4(f3(col('v')))) \
.withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
.withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
.withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
.withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))

expected = df \
.withColumn('f1', df['v'] + 1) \
.withColumn('f2', df['v'] + 10) \
.withColumn('f3', df['v'] + 100) \
.withColumn('f4', df['v'] + 1000) \
.withColumn('f2_f1', df['v'] + 11) \
.withColumn('f3_f1', df['v'] + 101) \
.withColumn('f4_f1', df['v'] + 1001) \
.withColumn('f3_f2', df['v'] + 110) \
.withColumn('f4_f2', df['v'] + 1010) \
.withColumn('f4_f3', df['v'] + 1100) \
.withColumn('f3_f2_f1', df['v'] + 111) \
.withColumn('f4_f2_f1', df['v'] + 1011) \
.withColumn('f4_f3_f1', df['v'] + 1101) \
.withColumn('f4_f3_f2', df['v'] + 1110) \
.withColumn('f4_f3_f2_f1', df['v'] + 1111)

self.assertEquals(expected.collect(), df_multi_1.collect())
self.assertEquals(expected.collect(), df_multi_2.collect())

def test_mixed_udf_and_sql(self):
import pandas as pd
from pyspark.sql import Column
from pyspark.sql.functions import udf, pandas_udf

df = self.spark.range(0, 1).toDF('v')

# Test mixture of UDFs, Pandas UDFs and SQL expression.

@udf('int')
def f1(x):
assert type(x) == int
return x + 1

def f2(x):
assert type(x) == Column
return x + 10
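# (f2 is a plain Python function on Column, i.e. a SQL expression, not a UDF)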

@pandas_udf('int')
def f3(x):
assert type(x) == pd.Series
return x + 100

df1 = df.withColumn('f1', f1(df['v'])) \
.withColumn('f2', f2(df['v'])) \
.withColumn('f3', f3(df['v'])) \
.withColumn('f1_f2', f1(f2(df['v']))) \
.withColumn('f1_f3', f1(f3(df['v']))) \
.withColumn('f2_f1', f2(f1(df['v']))) \
.withColumn('f2_f3', f2(f3(df['v']))) \
.withColumn('f3_f1', f3(f1(df['v']))) \
.withColumn('f3_f2', f3(f2(df['v']))) \
.withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
.withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
.withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
.withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
.withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))

expected = df.withColumn('f1', df['v'] + 1) \
.withColumn('f2', df['v'] + 10) \
.withColumn('f3', df['v'] + 100) \
.withColumn('f1_f2', df['v'] + 11) \
.withColumn('f1_f3', df['v'] + 101) \
.withColumn('f2_f1', df['v'] + 11) \
.withColumn('f2_f3', df['v'] + 110) \
.withColumn('f3_f1', df['v'] + 101) \
.withColumn('f3_f2', df['v'] + 110) \
.withColumn('f1_f2_f3', df['v'] + 111) \
.withColumn('f1_f3_f2', df['v'] + 111) \
.withColumn('f2_f1_f3', df['v'] + 111) \
.withColumn('f2_f3_f1', df['v'] + 111) \
.withColumn('f3_f1_f2', df['v'] + 111) \
.withColumn('f3_f2_f1', df['v'] + 111)

self.assertEquals(expected.collect(), df1.collect())


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
@@ -5487,6 +5636,21 @@ def dummy_pandas_udf(df):
F.col('temp0.key') == F.col('temp1.key'))
self.assertEquals(res.count(), 5)

def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
import pandas as pd
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType

df = self.spark.range(0, 10).toDF('v1')
df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
.withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))

result = df.groupby() \
.apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
'sum int',
PandasUDFType.GROUPED_MAP))

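# sum over v1 (0..9) is 45, over v2 (= v1 + 1) is 55, over v3 (= v1 + 2) is 65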
self.assertEquals(result.collect()[0]['sum'], 165)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -94,36 +95,52 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

private def hasPythonUDF(e: Expression): Boolean = {
private type EvalType = Int
private type EvalTypeChecker = EvalType => Boolean

private def hasScalarPythonUDF(e: Expression): Boolean = {
e.find(PythonUDF.isScalarPythonUDF).isDefined
}

private def canEvaluateInPython(e: PythonUDF): Boolean = {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
case Seq(u: PythonUDF) => canEvaluateInPython(u)
case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
// Python UDF can't be evaluated directly in JVM
case children => !children.exists(hasPythonUDF)
case children => !children.exists(hasScalarPythonUDF)
}
}

private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
case e => e.children.flatMap(collectEvaluatableUDF)
private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
// The eval-type checker is set once, when the first evaluable UDF is found, and must not
// change afterwards: it verifies that subsequent UDFs share the eval type of the first
// UDF, since only UDFs of the same eval type can be extracted together.
var evalTypeChecker: Option[EvalTypeChecker] = None

def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
&& evalTypeChecker.isEmpty =>
evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType)
Seq(udf)
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
&& evalTypeChecker.get(udf.evalType) =>
Seq(udf)
case e => e.children.flatMap(collectEvaluableUDFs)
}

expressions.flatMap(collectEvaluableUDFs)
}

def apply(plan: SparkPlan): SparkPlan = plan transformUp {
// AggregateInPandasExec and FlatMapGroupsInPandasExec evaluate their UDFs directly in the
// Python worker, so we don't need to extract them.
case plan: FlatMapGroupsInPandasExec => plan
case plan: SparkPlan => extract(plan)
}

/**
* Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/
private def extract(plan: SparkPlan): SparkPlan = {
val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
// Ignore PythonUDFs that come from the second/third aggregate, which are not used
.filter(udf => udf.references.subsetOf(plan.inputSet))
if (udfs.isEmpty) {
@@ -167,7 +184,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
case _ =>
throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
throw new AnalysisException(
"Expected either Scalar Pandas UDFs or Batched UDFs but got both")
}

attributeMap ++= validUdfs.zip(resultAttrs)
@@ -205,7 +223,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
case filter: FilterExec =>
val (candidates, nonDeterministic) =
splitConjunctivePredicates(filter.condition).partition(_.deterministic)
val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
if (pushDown.nonEmpty) {
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -115,3 +115,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
dataType = BooleanType,
pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
udfDeterministic = true)

class MyDummyScalarPandasUDF extends UserDefinedPythonFunction(
name = "dummyScalarPandasUDF",
func = new DummyUDF,
dataType = BooleanType,
pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
udfDeterministic = true)