[SPARK-17514] df.take(1) and df.limit(1).collect() should perform the same in Python

## What changes were proposed in this pull request?

In PySpark, `df.take(1)` runs a single-stage job which computes only one partition of the DataFrame, while `df.limit(1).collect()` computes all partitions and runs a two-stage job. This difference in performance is confusing.
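
For reference, the comparison can be reproduced with a small script like the following sketch (the session setup and DataFrame shape are illustrative and mirror the regression test below; they are not part of this patch):

```python
from pyspark.sql import SparkSession

# Illustrative local session; any multi-partition DataFrame shows the same behavior.
spark = SparkSession.builder.master("local[4]").appName("spark-17514-repro").getOrCreate()
df = spark.range(1, 1000, numPartitions=10)

df.take(1)             # a single-stage job that reads only one partition
df.limit(1).collect()  # before this patch: a two-stage job that computes every partition
```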

The reason why `limit(1).collect()` is so much slower is that `collect()` internally maps to `df.rdd.<some-pyspark-conversions>.toLocalIterator`, which causes Spark SQL to build a query where a global limit appears in the middle of the plan. That query ends up being executed inefficiently because limits in the middle of a plan are now implemented by repartitioning to a single partition rather than by running a `take()` job on the driver (a change made in #7334, a patch which was a prerequisite to allowing partition-local limits to be pushed beneath unions, etc.).
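
One way to see the plan shape being described is to look at `explain()` output; the snippet below continues the session from the sketch above and is only illustrative (exact operator names vary across Spark versions):

```python
# Continues the SparkSession and df from the previous sketch.
limited = df.limit(1)
limited.explain()
# When the limit is the terminal operator of the query, Spark SQL can plan it as a
# CollectLimit and answer it with a driver-side take(). When the query is instead
# routed through limited.rdd (as the old collect() path effectively was), the limit
# sits in the middle of the plan and is executed by shuffling rows into a single
# partition, which is the two-stage job described above.
```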

In order to fix this performance problem, I think we should generalize the fix from SPARK-10731 / #8876 so that `DataFrame.collect()` also delegates to the Scala implementation and shares the same performance properties. This patch changes `Dataset.collectToPython()` to collect all results on the JVM driver (via `executeCollect()`) and then serve them to Python over a socket, allowing `limit(n).collect()` queries to be planned using Spark's `CollectLimit` optimizations.
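
For context, the Python side already delegated to the JVM after SPARK-10731; `collect()` looked roughly like the sketch below (simplified and reconstructed for illustration, not taken from this diff), so the behavior change lives entirely in what `collectToPython()` does on the Scala side:

```python
# Simplified sketch of pyspark.sql.DataFrame.collect() around Spark 2.0; the real code
# lives in python/pyspark/sql/dataframe.py and uses these internal helpers.
from pyspark.rdd import _load_from_socket
from pyspark.serializers import BatchedSerializer, PickleSerializer
from pyspark.traceback_utils import SCCallSiteSync


def collect(self):
    with SCCallSiteSync(self._sc) as css:
        # Runs the query on the JVM and returns a local socket port to read results from.
        port = self._jdf.collectToPython()
    return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
```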

## How was this patch tested?

Added a regression test in `python/pyspark/sql/tests.py` (`test_limit_and_take`) which asserts that the expected number of jobs, stages, and tasks are run for both queries.

Author: Josh Rosen <[email protected]>

Closes #15068 from JoshRosen/pyspark-collect-limit.
JoshRosen authored and davies committed Sep 14, 2016
1 parent 52738d4 commit 6d06ff6
Showing 4 changed files with 26 additions and 18 deletions.
5 changes: 1 addition & 4 deletions python/pyspark/sql/dataframe.py
@@ -357,10 +357,7 @@ def take(self, num):
         >>> df.take(2)
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
-        with SCCallSiteSync(self._sc) as css:
-            port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe(
-                self._jdf, num)
-        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+        return self.limit(num).collect()
 
     @since(1.3)
     def foreach(self, f):
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests.py
@@ -1862,6 +1862,24 @@ def test_collect_functions(self):
             sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
             ["1", "2", "2", "2"])
 
+    def test_limit_and_take(self):
+        df = self.spark.range(1, 1000, numPartitions=10)
+
+        def assert_runs_only_one_job_stage_and_task(job_group_name, f):
+            tracker = self.sc.statusTracker()
+            self.sc.setJobGroup(job_group_name, description="")
+            f()
+            jobs = tracker.getJobIdsForGroup(job_group_name)
+            self.assertEqual(1, len(jobs))
+            stages = tracker.getJobInfo(jobs[0]).stageIds
+            self.assertEqual(1, len(stages))
+            self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)
+
+        # Regression test for SPARK-10731: take should delegate to Scala implementation
+        assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
+        # Regression test for SPARK-17514: limit(n).collect() should perform the same as take(n)
+        assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())
+

if __name__ == "__main__":
from pyspark.sql.tests import *
8 changes: 6 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,7 +29,7 @@ import org.apache.commons.lang3.StringUtils
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
-import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
@@ -2567,8 +2567,12 @@ class Dataset[T] private[sql](
   }
 
   private[sql] def collectToPython(): Int = {
+    EvaluatePython.registerPicklers()
     withNewExecutionId {
-      PythonRDD.collectAndServe(javaToPython.rdd)
+      val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
+      val iter = new SerDeUtil.AutoBatchedPickler(
+        queryExecution.executedPlan.executeCollect().iterator.map(toJava))
+      PythonRDD.serveIterator(iter, "serve-DataFrame")
     }
   }
 
13 changes: 1 addition & 12 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -24,26 +24,15 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler}
 
-import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 object EvaluatePython {
-  def takeAndServe(df: DataFrame, n: Int): Int = {
-    registerPicklers()
-    df.withNewExecutionId {
-      val iter = new SerDeUtil.AutoBatchedPickler(
-        df.queryExecution.executedPlan.executeTake(n).iterator.map { row =>
-          EvaluatePython.toJava(row, df.schema)
-        })
-      PythonRDD.serveIterator(iter, s"serve-DataFrame")
-    }
-  }
 
   def needConversionInPython(dt: DataType): Boolean = dt match {
     case DateType | TimestampType => true
