Commit: refactor

davies committed Sep 28, 2014
1 parent 74df565 commit 26ea396
Showing 4 changed files with 40 additions and 104 deletions.
25 changes: 13 additions & 12 deletions python/pyspark/streaming/dstream.py
@@ -21,7 +21,7 @@

from pyspark import RDD
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.util import rddToFileName, RDDFunction, RDDFunction2
from pyspark.streaming.util import rddToFileName, RDDFunction
from pyspark.rdd import portable_hash
from pyspark.resultiterable import ResultIterable

@@ -141,7 +141,7 @@ def foreachRDD(self, func):
This is an output operator, so this DStream will be registered as an output
stream and materialized there.
"""
jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer)
self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), jfunc)

def pprint(self):
@@ -294,7 +294,7 @@ def transformWithTime(self, func):
return TransformedDStream(self, func, False)

def transformWith(self, func, other, keepSerializer=False):
jfunc = RDDFunction2(self.ctx, func, self._jrdd_deserializer)
jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, b), self._jrdd_deserializer)
dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
other._jdstream.dstream(), jfunc)
jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
@@ -304,16 +304,16 @@ def repartitions(self, numPartitions):
return self.transform(lambda rdd: rdd.repartition(numPartitions))

def union(self, other):
return self.transformWith(lambda a, b, t: a.union(b), other, True)
return self.transformWith(lambda a, b: a.union(b), other, True)

def cogroup(self, other):
return self.transformWith(lambda a, b, t: a.cogroup(b), other)
return self.transformWith(lambda a, b: a.cogroup(b), other)

def leftOuterJoin(self, other):
return self.transformWith(lambda a, b, t: a.leftOuterJoin(b), other)
return self.transformWith(lambda a, b: a.leftOuterJoin(b), other)

def rightOuterJoin(self, other):
return self.transformWith(lambda a, b, t: a.rightOuterJoin(b), other)
return self.transformWith(lambda a, b: a.rightOuterJoin(b), other)

def _jtime(self, milliseconds):
return self.ctx._jvm.Time(milliseconds)
@@ -364,8 +364,8 @@ def invReduceFunc(a, b, t):
joined = a.leftOuterJoin(b, numPartitions)
return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)

jreduceFunc = RDDFunction2(self.ctx, reduceFunc, reduced._jrdd_deserializer)
jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
jinvReduceFunc = RDDFunction(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
jreduceFunc, jinvReduceFunc,
self._ssc._jduration(windowDuration),
@@ -384,8 +384,8 @@ def reduceFunc(a, b, t):
(k, list(vb), list(va)[0] if len(va) else None))
return g.mapPartitions(lambda x: updateFunc(x) or [])

jreduceFunc = RDDFunction2(self.ctx, reduceFunc,
self.ctx.serializer, self._jrdd_deserializer)
jreduceFunc = RDDFunction(self.ctx, reduceFunc,
self.ctx.serializer, self._jrdd_deserializer)
dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer)

@@ -417,7 +417,8 @@ def _jdstream(self):
if self._jdstream_val is not None:
return self._jdstream_val

jfunc = RDDFunction(self.ctx, self.func, self.prev._jrdd_deserializer)
func = self.func
jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self.prev._jrdd_deserializer)
jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(),
jfunc, self.reuse).asJavaDStream()
self._jdstream_val = jdstream
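The dstream.py changes above all follow one pattern: with RDDFunction2 removed, the single RDDFunction callback always receives three arguments (the RDD, an optional second RDD, and the batch time), so one-input operators such as foreachRDD, transform and transformWith wrap the user's function in a small lambda that drops the argument it does not need, while the stateful operators (updateStateByKey, reduceByKeyAndWindow) already used three-argument functions and only swap the wrapper class. A minimal sketch of that adapter idea, with plain lists standing in for RDDs and purely illustrative names (user_transform, jfunc_one):

# Sketch only: the adapter pattern used in dstream.py above; lists stand in
# for RDDs and none of these names are part of the PySpark API.

def user_transform(rdd, time):           # shape passed to transform()/foreachRDD()
    return rdd

def user_transform_with(rdd, other):     # shape passed to transformWith()
    return rdd + other

# The unified callback always takes (rdd, other_rdd, time):
jfunc_one = lambda a, _, t: user_transform(a, t)       # ignore the unused second RDD
jfunc_two = lambda a, b, t: user_transform_with(a, b)  # ignore the batch time

print(jfunc_one([1, 2], None, 1000))     # [1, 2]
print(jfunc_two([1], [2], 1000))         # [1, 2]
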
9 changes: 5 additions & 4 deletions python/pyspark/streaming/tests.py
@@ -346,17 +346,18 @@ def test_queueStream(self):
result = dstream.collect()
self.ssc.start()
time.sleep(1)
self.assertEqual(input, result)
self.assertEqual(input, result[:3])

def test_union(self):
input = [range(i) for i in range(3)]
dstream = self.ssc.queueStream(input)
dstream2 = self.ssc.union(dstream, dstream)
result = dstream.collect()
dstream2 = self.ssc.queueStream(input)
dstream3 = self.ssc.union(dstream, dstream2)
result = dstream3.collect()
self.ssc.start()
time.sleep(1)
expected = [i * 2 for i in input]
self.assertEqual(input, result)
self.assertEqual(expected, result[:3])


if __name__ == "__main__":
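Two things drive the test changes above: test_union previously collected from dstream rather than from the unioned stream, and both tests now compare against result[:3] because more batches than the three queued inputs may have been collected by the time the assertion runs. The doubled expectation is simply each input batch unioned with itself; a toy check of that arithmetic (Python 2-style, where range() returns a list, as the surrounding code assumes):

# Illustrative only: reproduces the expected values used by test_union above.
input = [range(i) for i in range(3)]     # [[], [0], [0, 1]]
expected = [i * 2 for i in input]        # each batch unioned with itself
assert expected == [[], [0, 0], [0, 1, 0, 1]]
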
47 changes: 8 additions & 39 deletions python/pyspark/streaming/util.py
@@ -20,65 +20,34 @@

class RDDFunction(object):
"""
This class is for py4j callback. This class is related with
org.apache.spark.streaming.api.python.PythonRDDFunction.
This class is for py4j callback.
"""
def __init__(self, ctx, func, jrdd_deserializer):
def __init__(self, ctx, func, deserializer, deserializer2=None):
self.ctx = ctx
self.func = func
self.deserializer = jrdd_deserializer

def call(self, jrdd, milliseconds):
try:
emptyRDD = getattr(self.ctx, "_emptyRDD", None)
if emptyRDD is None:
self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD
r = self.func(rdd, milliseconds)
if r:
return r._jrdd
except:
import traceback
traceback.print_exc()

def __repr__(self):
return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))

class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']


class RDDFunction2(object):
"""
This class is for py4j callback. This class is related with
org.apache.spark.streaming.api.python.PythonRDDFunction2.
"""
def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None):
self.ctx = ctx
self.func = func
self.jrdd_deserializer = jrdd_deserializer
self.jrdd_deserializer2 = jrdd_deserializer2 or jrdd_deserializer
self.deserializer = deserializer
self.deserializer2 = deserializer2 or deserializer

def call(self, jrdd, jrdd2, milliseconds):
try:
emptyRDD = getattr(self.ctx, "_emptyRDD", None)
if emptyRDD is None:
self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()

rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else emptyRDD
other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else emptyRDD
rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD
other = RDD(jrdd2, self.ctx, self.deserializer2) if jrdd2 else emptyRDD
r = self.func(rdd, other, milliseconds)
if r:
return r._jrdd
except:
except Exception:
import traceback
traceback.print_exc()

def __repr__(self):
return "RDDFunction2(%s)" % (str(self.func))

class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2']
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']


def rddToFileName(prefix, suffix, time):
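After the merge, util.py has a single RDDFunction whose call(jrdd, jrdd2, milliseconds) shape serves every stream type: one-input DStreams arrive with jrdd2 set to null from the JVM side and the wrapper substitutes a cached empty RDD, while the optional second deserializer covers the stateful streams where the state RDD and the incoming batch can be serialized differently (see the updateStateByKey call site in dstream.py above). A rough, self-contained sketch of that calling convention; the toy class and lists below are stand-ins, not the real py4j wiring:

# Toy model of the merged callback shape; ToyRDDFunction is illustrative and
# plain lists stand in for RDDs.

class ToyRDDFunction(object):
    def __init__(self, func):
        self.func = func

    def call(self, rdd, rdd2, milliseconds):
        # One-input streams hand over None (null on the JVM side) for rdd2;
        # the real class substitutes a cached empty RDD instead.
        rdd = rdd if rdd is not None else []
        rdd2 = rdd2 if rdd2 is not None else []
        return self.func(rdd, rdd2, milliseconds)

# One-input stream (foreachRDD / transform): the second slot arrives empty.
foreach = ToyRDDFunction(lambda a, _, t: a)
print(foreach.call([1, 2], None, 1000))      # [1, 2]

# Two-input stream (PythonTransformed2DStream): both RDDs are present.
union = ToyRDDFunction(lambda a, b, t: a + b)
print(union.call([1], [2], 1000))            # [1, 2]
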
63 changes: 14 additions & 49 deletions streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -19,50 +19,22 @@ package org.apache.spark.streaming.api.python

import java.util.{ArrayList => JArrayList}

import org.apache.spark.Partitioner
import org.apache.spark.rdd.{CoGroupedRDD, UnionRDD, PartitionerAwareUnionRDD, RDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java._
import org.apache.spark.api.python._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Interval, Duration, Time}
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.api.java._

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag


/**
* Interface for Python callback function with two arguments
*/
trait PythonRDDFunction {
def call(rdd: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
}

class RDDFunction(pfunc: PythonRDDFunction) {
def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
val jrdd = if (rdd.isDefined) {
JavaRDD.fromRDD(rdd.get)
} else {
null
}
val r = pfunc.call(jrdd, time.milliseconds)
if (r != null) {
Some(r.rdd)
} else {
None
}
}
}

/**
* Interface for Python callback function with three arguments
*/
trait PythonRDDFunction2 {
trait PythonRDDFunction {
def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
}

class RDDFunction2(pfunc: PythonRDDFunction2) {
class RDDFunction(pfunc: PythonRDDFunction) {
def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
val jrdd = if (rdd.isDefined) {
JavaRDD.fromRDD(rdd.get)
@@ -114,7 +86,7 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: Python
if (reuse && lastResult != null) {
Some(lastResult.copyTo(rdd1.get))
} else {
val r = func(rdd1, validTime)
val r = func(rdd1, None, validTime)
if (reuse && r.isDefined && lastResult == null) {
r.get match {
case rdd: PythonRDD =>
@@ -137,10 +109,10 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
*/
private[spark]
class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
pfunc: PythonRDDFunction2)
pfunc: PythonRDDFunction)
extends DStream[Array[Byte]] (parent.ssc) {

val func = new RDDFunction2(pfunc)
val func = new RDDFunction(pfunc)

override def slideDuration: Duration = parent.slideDuration

@@ -157,10 +129,10 @@ class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
* similar to StateDStream
*/
private[spark]
class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction2)
class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction)
extends PythonDStream(parent) {

val reduceFunc = new RDDFunction2(preduceFunc)
val reduceFunc = new RDDFunction(preduceFunc)

super.persist(StorageLevel.MEMORY_ONLY)
override val mustCheckpoint = true
@@ -177,12 +149,12 @@ class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFun
}

/**
* Copied from ReducedWindowedDStream
* similar to ReducedWindowedDStream
*/
private[spark]
class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
preduceFunc: PythonRDDFunction2,
pinvReduceFunc: PythonRDDFunction2,
preduceFunc: PythonRDDFunction,
pinvReduceFunc: PythonRDDFunction,
_windowDuration: Duration,
_slideDuration: Duration
) extends PythonStateDStream(parent, preduceFunc) {
@@ -197,7 +169,7 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
"must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")"
)

val invReduceFunc = new RDDFunction2(pinvReduceFunc)
val invReduceFunc = new RDDFunction(pinvReduceFunc)

def windowDuration: Duration = _windowDuration
override def slideDuration: Duration = _slideDuration
@@ -209,12 +181,6 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
currentTime)
val previousWindow = currentWindow - slideDuration

logDebug("Window time = " + windowDuration)
logDebug("Slide time = " + slideDuration)
logDebug("ZeroTime = " + zeroTime)
logDebug("Current window = " + currentWindow)
logDebug("Previous window = " + previousWindow)

// _____________________________
// | previous window _________|___________________
// |___________________| current window | --------------> Time
@@ -271,7 +237,7 @@ class PythonForeachDStream(
prev,
(rdd: RDD[Array[Byte]], time: Time) => {
if (rdd != null) {
foreachFunction.call(rdd, time.milliseconds)
foreachFunction.call(rdd, null, time.milliseconds)
}
}
) {
@@ -283,7 +249,6 @@ class PythonForeachDStream(
/**
* similar to QueueInputStream
*/

class PythonDataInputStream(
ssc_ : JavaStreamingContext,
inputRDDs: JArrayList[JavaRDD[Array[Byte]]],
@@ -294,7 +259,7 @@ class PythonDataInputStream(
val emptyRDD = if (defaultRDD != null) {
Some(defaultRDD.rdd)
} else {
None // ssc.sparkContext.emptyRDD[Array[Byte]]
Some(ssc.sparkContext.emptyRDD[Array[Byte]])
}

def start() {}
