[SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF #21650
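As a quick orientation (not part of the diff; the names below are illustrative): after this change, a row-at-a-time Python UDF and a Scalar Pandas UDF can be used together in a single projection, where previously the planner raised "Can not mix vectorized and non-vectorized UDFs". A minimal sketch, assuming a local SparkSession with pandas and PyArrow available:

# Illustrative sketch only; plus_one and times_two are made-up names.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, udf

spark = SparkSession.builder.getOrCreate()
df = spark.range(10)

plus_one = udf(lambda x: x + 1, 'long')           # batched, row-at-a-time Python UDF
times_two = pandas_udf(lambda s: s * 2, 'long')   # vectorized Scalar Pandas UDF

# Both UDF flavors in one select now plan and run instead of failing.
df.select(plus_one(col('id')).alias('p1'), times_two(col('id')).alias('t2')).show()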
@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self):
                    'Result vector from pandas_udf was not the required length'):
                df.select(raise_exception(col('id'))).collect()

-    def test_vectorized_udf_mix_udf(self):
-        from pyspark.sql.functions import pandas_udf, udf, col
-        df = self.spark.range(10)
-        row_by_row_udf = udf(lambda x: x, LongType())
-        pd_udf = pandas_udf(lambda x: x, LongType())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Can not mix vectorized and non-vectorized UDFs'):
-                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
-
    def test_vectorized_udf_chained(self):
        from pyspark.sql.functions import pandas_udf, col
        df = self.spark.range(10)

@@ -5060,6 +5049,166 @@ def test_type_annotation(self):
        df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
        self.assertEqual(df.first()[0], 0)

    def test_mixed_udf(self):
        import pandas as pd
        from pyspark.sql.functions import col, udf, pandas_udf

        df = self.spark.range(0, 1).toDF('v')

        # Test mixture of multiple UDFs and Pandas UDFs.

        @udf('int')
        def f1(x):
            assert type(x) == int
            return x + 1

        @pandas_udf('int')
        def f2(x):
            assert type(x) == pd.Series
            return x + 10

        @udf('int')
        def f3(x):
            assert type(x) == int
            return x + 100

        @pandas_udf('int')
        def f4(x):
            assert type(x) == pd.Series
            return x + 1000

        # Test single expression with chained UDFs
        df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
        df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
        df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
        df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
        df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))

        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)

        self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
        self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
        self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
        self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
        self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
        # Test multiple mixed UDF expressions in a single projection
        df_multi_1 = df \
            .withColumn('f1', f1(col('v'))) \
            .withColumn('f2', f2(col('v'))) \
            .withColumn('f3', f3(col('v'))) \
            .withColumn('f4', f4(col('v'))) \
            .withColumn('f2_f1', f2(col('f1'))) \
            .withColumn('f3_f1', f3(col('f1'))) \
            .withColumn('f4_f1', f4(col('f1'))) \
            .withColumn('f3_f2', f3(col('f2'))) \
            .withColumn('f4_f2', f4(col('f2'))) \
            .withColumn('f4_f3', f4(col('f3'))) \
            .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
            .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
            .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
            .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
            .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))

        # Test mixed udfs in a single expression
        df_multi_2 = df \
            .withColumn('f1', f1(col('v'))) \
            .withColumn('f2', f2(col('v'))) \
            .withColumn('f3', f3(col('v'))) \
            .withColumn('f4', f4(col('v'))) \
            .withColumn('f2_f1', f2(f1(col('v')))) \
            .withColumn('f3_f1', f3(f1(col('v')))) \
            .withColumn('f4_f1', f4(f1(col('v')))) \
            .withColumn('f3_f2', f3(f2(col('v')))) \
            .withColumn('f4_f2', f4(f2(col('v')))) \
            .withColumn('f4_f3', f4(f3(col('v')))) \
            .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
            .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
            .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
            .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
            .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))

        expected = df \
            .withColumn('f1', df['v'] + 1) \
            .withColumn('f2', df['v'] + 10) \
            .withColumn('f3', df['v'] + 100) \
            .withColumn('f4', df['v'] + 1000) \
            .withColumn('f2_f1', df['v'] + 11) \
            .withColumn('f3_f1', df['v'] + 101) \
            .withColumn('f4_f1', df['v'] + 1001) \
            .withColumn('f3_f2', df['v'] + 110) \
            .withColumn('f4_f2', df['v'] + 1010) \
            .withColumn('f4_f3', df['v'] + 1100) \
            .withColumn('f3_f2_f1', df['v'] + 111) \
            .withColumn('f4_f2_f1', df['v'] + 1011) \
            .withColumn('f4_f3_f1', df['v'] + 1101) \
            .withColumn('f4_f3_f2', df['v'] + 1110) \
            .withColumn('f4_f3_f2_f1', df['v'] + 1111)

        self.assertEquals(expected.collect(), df_multi_1.collect())
        self.assertEquals(expected.collect(), df_multi_2.collect())

    def test_mixed_udf_and_sql(self):
        import pandas as pd
        from pyspark.sql import Column
        from pyspark.sql.functions import udf, pandas_udf

        df = self.spark.range(0, 1).toDF('v')

        # Test mixture of UDFs, Pandas UDFs and SQL expression.

        @udf('int')
        def f1(x):
            assert type(x) == int
            return x + 1

        def f2(x):
            assert type(x) == Column
            return x + 10

Review thread on the f2 definition:
- Seems like this is neither …
- Yes, the purpose is to test mixing udf, pandas_udf and sql expression. I will add comments to make it clearer.
- Added comments in test.
- Ah, I see why it looks confusing. Can we add an assert here too (check if it's a column)?
- Added.

        @pandas_udf('int')
        def f3(x):
            assert type(x) == pd.Series
            return x + 100

        df1 = df.withColumn('f1', f1(df['v'])) \
            .withColumn('f2', f2(df['v'])) \
            .withColumn('f3', f3(df['v'])) \
            .withColumn('f1_f2', f1(f2(df['v']))) \
            .withColumn('f1_f3', f1(f3(df['v']))) \
            .withColumn('f2_f1', f2(f1(df['v']))) \
            .withColumn('f2_f3', f2(f3(df['v']))) \
            .withColumn('f3_f1', f3(f1(df['v']))) \
            .withColumn('f3_f2', f3(f2(df['v']))) \
            .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
            .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
            .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
            .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
            .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
            .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))

Review thread on the f1/f3 combinations:
- Looks like the combination between f1 and f3 duplicates a few tests in …
- Yeah, the way the test is written is that I am trying to test many combinations, so there are some duplicate cases. Do you prefer that I remove these?
- Yeah, I know it's still minor since the elapsed time will be virtually the same, but recently the build / test time was an issue, and I wonder if there's a better way than avoiding duplicated tests for now.
- It was discussed here: #21845
- I see. I don't think it's necessary (we are only likely to remove a few cases and, like you said, the test time is virtually the same), and it helps the readability of the tests (so it doesn't look like some test cases are missed). But if that's the preferred practice I can remove the duplicate cases in the next commit.
- I am okay to leave it here too since it's clear they are virtually the same, but let's remove duplicated tests or use orthogonal tests next time.
- Gotcha. I will keep that in mind next time.

        expected = df.withColumn('f1', df['v'] + 1) \
            .withColumn('f2', df['v'] + 10) \
            .withColumn('f3', df['v'] + 100) \
            .withColumn('f1_f2', df['v'] + 11) \
            .withColumn('f1_f3', df['v'] + 101) \
            .withColumn('f2_f1', df['v'] + 11) \
            .withColumn('f2_f3', df['v'] + 110) \
            .withColumn('f3_f1', df['v'] + 101) \
            .withColumn('f3_f2', df['v'] + 110) \
            .withColumn('f1_f2_f3', df['v'] + 111) \
            .withColumn('f1_f3_f2', df['v'] + 111) \
            .withColumn('f2_f1_f3', df['v'] + 111) \
            .withColumn('f2_f3_f1', df['v'] + 111) \
            .withColumn('f3_f1_f2', df['v'] + 111) \
            .withColumn('f3_f2_f1', df['v'] + 111)

        self.assertEquals(expected.collect(), df1.collect())


@unittest.skipIf(
    not _have_pandas or not _have_pyarrow,

@@ -5487,6 +5636,21 @@ def dummy_pandas_udf(df):
                F.col('temp0.key') == F.col('temp1.key'))
        self.assertEquals(res.count(), 5)

    def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
        import pandas as pd
        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType

        df = self.spark.range(0, 10).toDF('v1')
        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))

        result = df.groupby() \
            .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
                              'sum int',
                              PandasUDFType.GROUPED_MAP))

        self.assertEquals(result.collect()[0]['sum'], 165)


@unittest.skipIf(
    not _have_pandas or not _have_pyarrow,

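The Scala changes below make ExtractPythonUDFs pull out UDFs of a single eval type per pass, so a projection that mixes both flavors is planned as separate Python evaluation operators instead of being rejected. A rough way to observe this (illustrative only, not from the PR; the operator names are assumed from the BatchEvalPythonExec class used by the rule and its Arrow counterpart, and `spark` is assumed to be an active SparkSession with pandas and PyArrow available):

# Illustrative sketch only.
from pyspark.sql.functions import col, pandas_udf, udf

plus_one = udf(lambda x: x + 1, 'long')           # batched Python UDF
times_two = pandas_udf(lambda s: s * 2, 'long')   # Scalar Pandas UDF

mixed = spark.range(10).select(plus_one(col('id')), times_two(col('id')))
# The physical plan should now contain both a batched and an Arrow-based Python
# evaluation operator rather than the old "Can not mix ..." error.
mixed.explain()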
@@ -21,6 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}

@@ -94,36 +95,52 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 */
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

-  private def hasPythonUDF(e: Expression): Boolean = {
+  private type EvalType = Int
+  private type EvalTypeChecker = EvalType => Boolean
+
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
    e.find(PythonUDF.isScalarPythonUDF).isDefined
  }

  private def canEvaluateInPython(e: PythonUDF): Boolean = {
    e.children match {
      // single PythonUDF child could be chained and evaluated in Python
-      case Seq(u: PythonUDF) => canEvaluateInPython(u)
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
      // Python UDF can't be evaluated directly in JVM
-      case children => !children.exists(hasPythonUDF)
+      case children => !children.exists(hasScalarPythonUDF)
    }
  }

-  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
-    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
-    case e => e.children.flatMap(collectEvaluatableUDF)
-  }
+  private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
+    // Eval type checker is set once when we find the first evaluable UDF and its value
+    // shouldn't change later.
+    // Used to check if subsequent UDFs are of the same type as the first UDF (since we can only
+    // extract UDFs of the same eval type).
+    var evalTypeChecker: Option[EvalTypeChecker] = None
+
+    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+          && evalTypeChecker.isEmpty =>
+        evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType)
+        Seq(udf)
+      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+          && evalTypeChecker.get(udf.evalType) =>
+        Seq(udf)
+      case e => e.children.flatMap(collectEvaluableUDFs)
+    }
+
+    expressions.flatMap(collectEvaluableUDFs)
+  }

Review comment on the first Seq(udf) case above:
- @HyukjinKwon In your code this line is …

  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker
-    // Therefore we don't need to extract the UDFs
-    case plan: FlatMapGroupsInPandasExec => plan
    case plan: SparkPlan => extract(plan)
  }

Review comment on the removed FlatMapGroupsInPandasExec case:
- This is no longer needed because this rule will only extract Python UDFs and Scalar Pandas UDFs, and it ignores other types of UDFs.

  /**
   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
   */
  private def extract(plan: SparkPlan): SparkPlan = {
-    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+    val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
      // ignore the PythonUDF that come from second/third aggregate, which is not used
      .filter(udf => udf.references.subsetOf(plan.inputSet))
    if (udfs.isEmpty) {

@@ -167,7 +184,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
          case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
            BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
          case _ =>
-            throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
+            throw new AnalysisException(
+              "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
        }

        attributeMap ++= validUdfs.zip(resultAttrs)

@@ -205,7 +223,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
    case filter: FilterExec =>
      val (candidates, nonDeterministic) =
        splitConjunctivePredicates(filter.condition).partition(_.deterministic)
-      val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+      val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
      if (pushDown.nonEmpty) {
        val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
        FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)

Review thread on the test combinations above:
- This looks like it is testing udf + udf.
- Yeah, the way the test is written is that I am trying to test many combinations, so some combinations might not be mixed UDFs. Do you prefer that I remove these cases?
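One consequence of the `e.evalType == u.evalType` check added to canEvaluateInPython above: chained UDFs are fused into a single Python evaluation only when they share an eval type, while a chain that crosses types is split into separate evaluation steps. A rough sketch of the user-facing behaviour (illustrative only; `spark` is assumed to be an active SparkSession with pandas and PyArrow available):

# Illustrative sketch only.
from pyspark.sql.functions import col, pandas_udf, udf

plus_one = udf(lambda x: x + 1, 'long')           # batched Python UDF
times_two = pandas_udf(lambda s: s * 2, 'long')   # Scalar Pandas UDF

df = spark.range(3)
# Same eval type: the two calls can be chained inside one Python evaluation.
same_type = df.select(plus_one(plus_one(col('id'))).alias('out'))
# Mixed eval types: the result is still correct, but each UDF is extracted into
# its own evaluation step, as exercised by test_mixed_udf above.
mixed_type = df.select(times_two(plus_one(col('id'))).alias('out'))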