[SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF #21650
Changes from 12 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self): | |
'Result vector from pandas_udf was not the required length'): | ||
df.select(raise_exception(col('id'))).collect() | ||
|
||
def test_vectorized_udf_mix_udf(self): | ||
from pyspark.sql.functions import pandas_udf, udf, col | ||
df = self.spark.range(10) | ||
row_by_row_udf = udf(lambda x: x, LongType()) | ||
pd_udf = pandas_udf(lambda x: x, LongType()) | ||
with QuietTest(self.sc): | ||
with self.assertRaisesRegexp( | ||
Exception, | ||
'Can not mix vectorized and non-vectorized UDFs'): | ||
df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect() | ||
|
||
def test_vectorized_udf_chained(self): | ||
from pyspark.sql.functions import pandas_udf, col | ||
df = self.spark.range(10) | ||
|
@@ -5060,6 +5049,166 @@ def test_type_annotation(self): | |
df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id')) | ||
self.assertEqual(df.first()[0], 0) | ||
|
||
def test_mixed_udf(self): | ||
import pandas as pd | ||
from pyspark.sql.functions import col, udf, pandas_udf | ||
|
||
df = self.spark.range(0, 1).toDF('v') | ||
|
||
# Test mixture of multiple UDFs and Pandas UDFs. | ||
|
||
@udf('int') | ||
def f1(x): | ||
assert type(x) == int | ||
return x + 1 | ||
|
||
@pandas_udf('int') | ||
def f2(x): | ||
assert type(x) == pd.Series | ||
return x + 10 | ||
|
||
@udf('int') | ||
def f3(x): | ||
assert type(x) == int | ||
return x + 100 | ||
|
||
@pandas_udf('int') | ||
def f4(x): | ||
assert type(x) == pd.Series | ||
return x + 1000 | ||
|
||
# Test single expression with chained UDFs | ||
df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v']))) | ||
df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) | ||
df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v']))))) | ||
df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v'])))) | ||
df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v'])))) | ||
|
||
expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11) | ||
expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111) | ||
expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111) | ||
expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011) | ||
expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101) | ||
|
||
self.assertEquals(expected_chained_1.collect(), df_chained_1.collect()) | ||
self.assertEquals(expected_chained_2.collect(), df_chained_2.collect()) | ||
self.assertEquals(expected_chained_3.collect(), df_chained_3.collect()) | ||
self.assertEquals(expected_chained_4.collect(), df_chained_4.collect()) | ||
self.assertEquals(expected_chained_5.collect(), df_chained_5.collect()) | ||
|
||
# Test multiple mixed UDF expressions in a single projection | ||
df_multi_1 = df \ | ||
.withColumn('f1', f1(col('v'))) \ | ||
.withColumn('f2', f2(col('v'))) \ | ||
.withColumn('f3', f3(col('v'))) \ | ||
.withColumn('f4', f4(col('v'))) \ | ||
.withColumn('f2_f1', f2(col('f1'))) \ | ||
.withColumn('f3_f1', f3(col('f1'))) \ | ||
.withColumn('f4_f1', f4(col('f1'))) \ | ||
.withColumn('f3_f2', f3(col('f2'))) \ | ||
.withColumn('f4_f2', f4(col('f2'))) \ | ||
.withColumn('f4_f3', f4(col('f3'))) \ | ||
.withColumn('f3_f2_f1', f3(col('f2_f1'))) \ | ||
.withColumn('f4_f2_f1', f4(col('f2_f1'))) \ | ||
.withColumn('f4_f3_f1', f4(col('f3_f1'))) \ | ||
.withColumn('f4_f3_f2', f4(col('f3_f2'))) \ | ||
.withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1'))) | ||
|
||
# Test mixed udfs in a single expression | ||
df_multi_2 = df \ | ||
.withColumn('f1', f1(col('v'))) \ | ||
.withColumn('f2', f2(col('v'))) \ | ||
.withColumn('f3', f3(col('v'))) \ | ||
.withColumn('f4', f4(col('v'))) \ | ||
.withColumn('f2_f1', f2(f1(col('v')))) \ | ||
.withColumn('f3_f1', f3(f1(col('v')))) \ | ||
.withColumn('f4_f1', f4(f1(col('v')))) \ | ||
.withColumn('f3_f2', f3(f2(col('v')))) \ | ||
.withColumn('f4_f2', f4(f2(col('v')))) \ | ||
.withColumn('f4_f3', f4(f3(col('v')))) \ | ||
.withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \ | ||
.withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \ | ||
.withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \ | ||
.withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \ | ||
.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v')))))) | ||
|
||
expected = df \ | ||
.withColumn('f1', df['v'] + 1) \ | ||
.withColumn('f2', df['v'] + 10) \ | ||
.withColumn('f3', df['v'] + 100) \ | ||
.withColumn('f4', df['v'] + 1000) \ | ||
.withColumn('f2_f1', df['v'] + 11) \ | ||
.withColumn('f3_f1', df['v'] + 101) \ | ||
.withColumn('f4_f1', df['v'] + 1001) \ | ||
.withColumn('f3_f2', df['v'] + 110) \ | ||
.withColumn('f4_f2', df['v'] + 1010) \ | ||
.withColumn('f4_f3', df['v'] + 1100) \ | ||
.withColumn('f3_f2_f1', df['v'] + 111) \ | ||
.withColumn('f4_f2_f1', df['v'] + 1011) \ | ||
.withColumn('f4_f3_f1', df['v'] + 1101) \ | ||
.withColumn('f4_f3_f2', df['v'] + 1110) \ | ||
.withColumn('f4_f3_f2_f1', df['v'] + 1111) | ||
|
||
self.assertEquals(expected.collect(), df_multi_1.collect()) | ||
self.assertEquals(expected.collect(), df_multi_2.collect()) | ||
|
||
def test_mixed_udf_and_sql(self): | ||
import pandas as pd | ||
from pyspark.sql import Column | ||
from pyspark.sql.functions import udf, pandas_udf | ||
|
||
df = self.spark.range(0, 1).toDF('v') | ||
|
||
# Test mixture of UDFs, Pandas UDFs and SQL expression. | ||
|
||
@udf('int') | ||
def f1(x): | ||
assert type(x) == int | ||
return x + 1 | ||
|
||
def f2(x): | ||
Review discussion:
- Seems like this is neither a udf nor a pandas_udf.
- Yes, the purpose is to test mixing udf, pandas_udf and SQL expressions. I will add comments to make it clearer.
- Added comments in the test.
- Ah, I see why it looks confusing. Can we add an assert here too (to check that it's a Column)?
- Added.
A standalone sketch illustrating these three flavors (plain Column expression, udf, pandas_udf) appears after this test.
||
assert type(x) == Column | ||
return x + 10 | ||
|
||
@pandas_udf('int') | ||
def f3(x): | ||
assert type(x) == pd.Series | ||
return x + 100 | ||
|
||
df1 = df.withColumn('f1', f1(df['v'])) \ | ||
.withColumn('f2', f2(df['v'])) \ | ||
.withColumn('f3', f3(df['v'])) \ | ||
.withColumn('f1_f2', f1(f2(df['v']))) \ | ||
.withColumn('f1_f3', f1(f3(df['v']))) \ | ||
.withColumn('f2_f1', f2(f1(df['v']))) \ | ||
.withColumn('f2_f3', f2(f3(df['v']))) \ | ||
.withColumn('f3_f1', f3(f1(df['v']))) \ | ||
Review discussion:
- Looks like the combinations of f1 and f3 duplicate a few cases in the previous test.
- Yeah, the way the test is written is that I am trying to test many combinations, so there are some duplicate cases. Do you prefer that I remove these?
- Yeah, I know it's still minor since the elapsed time will be virtually the same, but recently the build/test time was an issue, and I wonder if there's a better way than avoiding duplicated tests for now.
- It was discussed here: #21845
- I see. I don't think it's necessary (we are only likely to remove a few cases and, like you said, the test time is virtually the same), and keeping them helps the readability of the tests (so it doesn't look like some test cases are missed). But if that's the preferred practice I can remove the duplicate cases in the next commit.
- I am okay to leave it here since it's clear they are virtually the same, but let's remove duplicated or orthogonal tests next time.
- Gotcha. I will keep that in mind next time.
||
.withColumn('f3_f2', f3(f2(df['v']))) \ | ||
.withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \ | ||
.withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \ | ||
.withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \ | ||
.withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \ | ||
.withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \ | ||
.withColumn('f3_f2_f1', f3(f2(f1(df['v'])))) | ||
|
||
expected = df.withColumn('f1', df['v'] + 1) \ | ||
.withColumn('f2', df['v'] + 10) \ | ||
.withColumn('f3', df['v'] + 100) \ | ||
.withColumn('f1_f2', df['v'] + 11) \ | ||
.withColumn('f1_f3', df['v'] + 101) \ | ||
.withColumn('f2_f1', df['v'] + 11) \ | ||
.withColumn('f2_f3', df['v'] + 110) \ | ||
.withColumn('f3_f1', df['v'] + 101) \ | ||
.withColumn('f3_f2', df['v'] + 110) \ | ||
.withColumn('f1_f2_f3', df['v'] + 111) \ | ||
.withColumn('f1_f3_f2', df['v'] + 111) \ | ||
.withColumn('f2_f1_f3', df['v'] + 111) \ | ||
.withColumn('f2_f3_f1', df['v'] + 111) \ | ||
.withColumn('f3_f1_f2', df['v'] + 111) \ | ||
.withColumn('f3_f2_f1', df['v'] + 111) | ||
|
||
self.assertEquals(expected.collect(), df1.collect()) | ||
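For context on the review thread above: a minimal standalone sketch (not part of this diff) of the three flavors the test mixes, namely a row-at-a-time Python UDF, a Scalar Pandas UDF, and a plain Python function that only builds a SQL Column expression. It assumes a local SparkSession with pandas and pyarrow installed, and the names plus_one / plus_ten / plus_hundred are illustrative; mixing UDF types in one projection is only expected to work with this PR's change.

import pandas as pd
from pyspark.sql import SparkSession, Column
from pyspark.sql.functions import udf, pandas_udf

spark = SparkSession.builder.getOrCreate()
df = spark.range(0, 1).toDF('v')

# Row-at-a-time Python UDF: receives one Python value per row.
plus_one = udf(lambda x: x + 1, 'int')

# Scalar Pandas UDF: receives a whole pd.Series per batch.
plus_ten = pandas_udf(lambda s: s + 10, 'int')

# Plain Python function: not a UDF at all; it just composes a Column
# expression that is evaluated by Catalyst in the JVM (hence the assert).
def plus_hundred(x):
    assert isinstance(x, Column)
    return x + 100

# All three can be chained in the same expression.
result = df.select(plus_hundred(plus_ten(plus_one(df['v']))).alias('out'))
print(result.collect())  # [Row(out=111)]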
|
||
|
||
@unittest.skipIf( | ||
not _have_pandas or not _have_pyarrow, | ||
|
@@ -5487,6 +5636,21 @@ def dummy_pandas_udf(df): | |
F.col('temp0.key') == F.col('temp1.key')) | ||
self.assertEquals(res.count(), 5) | ||
|
||
def test_mixed_scalar_udfs_followed_by_grouby_apply(self): | ||
import pandas as pd | ||
Review comment: Not a big deal at all really, but I would swap the import order (third-party first, then pyspark).
||
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType | ||
|
||
df = self.spark.range(0, 10).toDF('v1') | ||
df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \ | ||
.withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1'])) | ||
|
||
result = df.groupby() \ | ||
.apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]), | ||
'sum int', | ||
PandasUDFType.GROUPED_MAP)) | ||
|
||
self.assertEquals(result.collect()[0]['sum'], 165) | ||
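A quick plain-Python sanity check of the expected value (not part of the diff): v1 is 0..9, the udf adds 1 and the pandas_udf adds 2, so the grouped-map sum over all rows and columns is 45 + 55 + 65 = 165.

# Illustrative check of the expected grouped-map total.
v1 = list(range(10))          # 0..9, sum 45
v2 = [x + 1 for x in v1]      # udf adds 1, sum 55
v3 = [x + 2 for x in v1]      # pandas_udf adds 2, sum 65
assert sum(v1) + sum(v2) + sum(v3) == 165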
|
||
|
||
@unittest.skipIf( | ||
not _have_pandas or not _have_pyarrow, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ import scala.collection.mutable | |
import scala.collection.mutable.ArrayBuffer | ||
|
||
import org.apache.spark.api.python.PythonEvalType | ||
import org.apache.spark.sql.AnalysisException | ||
import org.apache.spark.sql.catalyst.expressions._ | ||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression | ||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} | ||
|
@@ -94,36 +95,60 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { | |
*/ | ||
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { | ||
|
||
private def hasPythonUDF(e: Expression): Boolean = { | ||
private case class EvalTypeHolder(private var evalType: Int = -1) { | ||
Review discussion:
- How about this:

  private type EvalType = Int
  private type EvalTypeChecker = EvalType => Boolean

  private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
    // Eval type checker is set in the middle of checking because once it's found,
    // the same eval type should be checked .. blah blah
    var evalChecker: Option[EvalTypeChecker] = None
    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
          && evalChecker.isEmpty =>
        evalChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType)
        collectEvaluableUDFs(expr)
      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
          && evalChecker.get(udf.evalType) =>
        Seq(udf)
      case e => e.children.flatMap(collectEvaluableUDFs)
    }
    expressions.flatMap(collectEvaluableUDFs)
  }

  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
    case plan: SparkPlan => extract(plan)
  }

  /**
   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
   */
  private def extract(plan: SparkPlan): SparkPlan = {
    val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)

- I see... You use a var and a nested function definition to remove the need for a holder object. IMHO I usually find nested function definitions, and functions that refer to variables outside their own scope, hard to read, but that could be my personal preference. Another thing I like about the current impl is ... That being said, I am OK with your suggestion too if you insist or @BryanCutler also prefers it.
- Yup. I do avoid nested functions, but I found this is a place where it's needed. If it's clear when it's set and unset within a function, I think the shorter one is fine.
- Ok, I will update the code then.
||
def isSet: Boolean = evalType >= 0 | ||
|
||
def set(evalType: Int): Unit = { | ||
if (isSet && evalType != this.evalType) { | ||
throw new IllegalStateException("Cannot reset eval type to a different value") | ||
} else { | ||
this.evalType = evalType | ||
} | ||
} | ||
|
||
def get(): Int = { | ||
if (!isSet) { | ||
throw new IllegalStateException("Eval type is not set") | ||
} else { | ||
evalType | ||
} | ||
} | ||
} | ||
|
||
private def hasScalarPythonUDF(e: Expression): Boolean = { | ||
e.find(PythonUDF.isScalarPythonUDF).isDefined | ||
} | ||
|
||
private def canEvaluateInPython(e: PythonUDF): Boolean = { | ||
e.children match { | ||
// single PythonUDF child could be chained and evaluated in Python | ||
case Seq(u: PythonUDF) => canEvaluateInPython(u) | ||
case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u) | ||
// Python UDF can't be evaluated directly in JVM | ||
case children => !children.exists(hasPythonUDF) | ||
case children => !children.exists(hasScalarPythonUDF) | ||
} | ||
} | ||
|
||
private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { | ||
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf) | ||
case e => e.children.flatMap(collectEvaluatableUDF) | ||
private def collectEvaluableUDFs( | ||
expr: Expression, | ||
firstEvalType: EvalTypeHolder): Seq[PythonUDF] = expr match { | ||
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) | ||
&& (!firstEvalType.isSet || firstEvalType.get == udf.evalType) | ||
&& canEvaluateInPython(udf) => | ||
firstEvalType.set(udf.evalType) | ||
Seq(udf) | ||
case e => e.children.flatMap(collectEvaluableUDFs(_, firstEvalType)) | ||
} | ||
|
||
def apply(plan: SparkPlan): SparkPlan = plan transformUp { | ||
// AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker | ||
// Therefore we don't need to extract the UDFs | ||
case plan: FlatMapGroupsInPandasExec => plan | ||
Review comment: This is no longer needed because this rule now only extracts Python UDFs and Scalar Pandas UDFs and ignores other types of UDFs.
||
case plan: SparkPlan => extract(plan) | ||
} | ||
|
||
/** | ||
* Extract all the PythonUDFs from the current operator and evaluate them before the operator. | ||
*/ | ||
private def extract(plan: SparkPlan): SparkPlan = { | ||
val udfs = plan.expressions.flatMap(collectEvaluatableUDF) | ||
val firstEvalType = new EvalTypeHolder | ||
val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, firstEvalType)) | ||
// ignore the PythonUDF that come from second/third aggregate, which is not used | ||
.filter(udf => udf.references.subsetOf(plan.inputSet)) | ||
if (udfs.isEmpty) { | ||
|
@@ -167,7 +192,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { | |
case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => | ||
BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child) | ||
case _ => | ||
throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs") | ||
throw new AnalysisException( | ||
Review discussion:
- Why change the exception type? Can you make a test that causes this?
- This is because we shouldn't reach here (otherwise it's a bug). I don't know what the best exception type is here, though.
||
"Expected either Scalar Pandas UDFs or Batched UDFs but got both") | ||
} | ||
|
||
attributeMap ++= validUdfs.zip(resultAttrs) | ||
|
@@ -205,7 +231,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { | |
case filter: FilterExec => | ||
val (candidates, nonDeterministic) = | ||
splitConjunctivePredicates(filter.condition).partition(_.deterministic) | ||
val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) | ||
val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_)) | ||
if (pushDown.nonEmpty) { | ||
val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) | ||
FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild) | ||
|
Review discussion:
- This looks like it is testing udf + udf.
- Yeah, the way the test is written is that I am trying to test many combinations, so some combinations might not be mixed UDFs. Do you prefer that I remove these cases?