[SPARK-17756][PYTHON][STREAMING] Workaround to avoid return type mismatch in PythonTransformFunction

## What changes were proposed in this pull request?

This PR proposes to wrap the transformed RDD within `TransformFunction`. `PythonTransformFunction` appears to require `_jrdd` to be a `JavaRDD`:

https://github.com/apache/spark/blob/39e2bad6a866d27c3ca594d15e574a1da3ee84cc/python/pyspark/streaming/util.py#L67

https://github.com/apache/spark/blob/6ee28423ad1b2e6089b82af64a31d77d3552bb38/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala#L43

However, some APIs, for example `zip` in PySpark's RDD API, can produce a `JavaPairRDD` instead. This can be checked via `_jrdd`:

```python
>>> rdd.zip(rdd)._jrdd.getClass().toString()
u'class org.apache.spark.api.java.JavaPairRDD'
```

So, here, I wrapped it with `map` to ensure a `JavaRDD` is returned:

```python
>>> rdd.zip(rdd).map(lambda x: x)._jrdd.getClass().toString()
u'class org.apache.spark.api.java.JavaRDD'
```
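
In essence, the workaround behaves like the sketch below. This is a minimal, standalone illustration only: the helper name `ensure_java_rdd` is hypothetical, and the actual fix is applied inside `TransformFunction.call` in `python/pyspark/streaming/util.py` (see the diff further down).

```python
from py4j.java_gateway import is_instance_of

def ensure_java_rdd(rdd):
    # Wrap with an identity `map` only when the underlying Java object is not
    # already a JavaRDD (e.g. it is a JavaPairRDD after `zip` or `cartesian`).
    gateway = rdd.ctx._gateway
    if is_instance_of(gateway, rdd._jrdd, "org.apache.spark.api.java.JavaRDD"):
        return rdd
    return rdd.map(lambda x: x)
```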

I reproduced several failure cases, as shown below:

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]) \
    .transform(lambda rdd: rdd.cartesian(rdd)) \
    .pprint()
ssc.start()
```

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.cartesian(rdd))
ssc.start()
```

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd))
ssc.start()
```

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).union(rdd.zip(rdd)))
ssc.start()
```

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).coalesce(1))
ssc.start()
```
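
With the fix applied, the same pair-RDD pattern also works through `transform`; a quick sanity check, assuming the same `spark` and `sc` setup as in the snippets above:

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]) \
    .transform(lambda rdd: rdd.zip(rdd)) \
    .pprint()
ssc.start()
```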

## How was this patch tested?

Unit tests were added in `python/pyspark/streaming/tests.py`, and the change was also tested manually.

Author: hyukjinkwon <[email protected]>

Closes #19498 from HyukjinKwon/SPARK-17756.
HyukjinKwon committed Jun 8, 2018
1 parent 1a644af commit b070ded
Showing 3 changed files with 17 additions and 2 deletions.
python/pyspark/streaming/context.py: 2 changes, 1 addition & 1 deletion

```diff
@@ -338,7 +338,7 @@ def transform(self, dstreams, transformFunc):
         jdstreams = [d._jdstream for d in dstreams]
         # change the final serializer to sc.serializer
         func = TransformFunction(self._sc,
-                                 lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+                                 lambda t, *rdds: transformFunc(rdds),
                                  *[d._jrdd_deserializer for d in dstreams])
         jfunc = self._jvm.TransformFunction(func)
         jdstream = self._jssc.transform(jdstreams, jfunc)
```
python/pyspark/streaming/tests.py: 6 changes, 6 additions & 0 deletions

```diff
@@ -779,6 +779,12 @@ def func(rdds):
 
         self.assertEqual([2, 3, 1], self._take(dstream, 3))
 
+    def test_transform_pairrdd(self):
+        # This regression test case is for SPARK-17756.
+        dstream = self.ssc.queueStream(
+            [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd))
+        self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3))
+
     def test_get_active(self):
         self.assertEqual(StreamingContext.getActive(), None)
```
python/pyspark/streaming/util.py: 11 changes, 10 additions & 1 deletion

```diff
@@ -20,6 +20,8 @@
 import traceback
 import sys
 
+from py4j.java_gateway import is_instance_of
+
 from pyspark import SparkContext, RDD
 
 
@@ -65,7 +67,14 @@ def call(self, milliseconds, jrdds):
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)
             if r:
-                return r._jrdd
+                # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`.
+                # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return
+                # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`.
+                # See SPARK-17756.
+                if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"):
+                    return r._jrdd
+                else:
+                    return r.map(lambda x: x)._jrdd
         except:
             self.failure = traceback.format_exc()
```
