Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-17756][PYTHON][STREAMING] Workaround to avoid return type mism…
…atch in PythonTransformFunction ## What changes were proposed in this pull request? This PR proposes to wrap the transformed rdd within `TransformFunction`. `PythonTransformFunction` looks requiring to return `JavaRDD` in `_jrdd`. https://github.com/apache/spark/blob/39e2bad6a866d27c3ca594d15e574a1da3ee84cc/python/pyspark/streaming/util.py#L67 https://github.com/apache/spark/blob/6ee28423ad1b2e6089b82af64a31d77d3552bb38/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala#L43 However, this could be `JavaPairRDD` by some APIs, for example, `zip` in PySpark's RDD API. `_jrdd` could be checked as below: ```python >>> rdd.zip(rdd)._jrdd.getClass().toString() u'class org.apache.spark.api.java.JavaPairRDD' ``` So, here, I wrapped it with `map` so that it ensures returning `JavaRDD`. ```python >>> rdd.zip(rdd).map(lambda x: x)._jrdd.getClass().toString() u'class org.apache.spark.api.java.JavaRDD' ``` I tried to elaborate some failure cases as below: ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]) \ .transform(lambda rdd: rdd.cartesian(rdd)) \ .pprint() ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.cartesian(rdd)) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd)) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).union(rdd.zip(rdd))) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).coalesce(1)) ssc.start() ``` ## How was this patch tested? Unit tests were added in `python/pyspark/streaming/tests.py` and manually tested. Author: hyukjinkwon <[email protected]> Closes #19498 from HyukjinKwon/SPARK-17756.
- Loading branch information