Commit
support transform(), refactor and cleanup
davies committed Sep 25, 2014
1 parent df098fc commit 7f53086
Showing 11 changed files with 384 additions and 905 deletions.
@@ -52,6 +52,10 @@ private[spark] class PythonRDD(
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {

def copyTo(rdd: RDD[_]): PythonRDD = {
new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning, pythonExec, broadcastVars, accumulator)
}

val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

114 changes: 33 additions & 81 deletions python/pyspark/streaming/context.py
@@ -23,7 +23,7 @@
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.context import SparkContext
from pyspark.streaming.dstream import DStream
from pyspark.streaming.duration import Duration
from pyspark.streaming.duration import Duration, Seconds

from py4j.java_collections import ListConverter

@@ -35,68 +35,31 @@ class StreamingContext(object):
broadcast variables on that cluster.
"""

def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
gateway=None, sparkContext=None, duration=None):
def __init__(self, sparkContext, duration):
"""
Create a new StreamingContext. At least the master, app name, and duration
should be set, either through the named parameters here or through C{conf}.
@param master: Cluster URL to connect to
(e.g. mesos://host:port, spark://host:port, local[4]).
@param appName: A name for your job, to display on the cluster web UI.
@param sparkHome: Location where Spark is installed on cluster nodes.
@param pyFiles: Collection of .zip or .py files to send to the cluster
and add to PYTHONPATH. These can be paths on the local file
system or HDFS, HTTP, HTTPS, or FTP URLs.
@param environment: A dictionary of environment variables to set on
worker nodes.
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
@param serializer: The serializer for RDDs.
@param conf: A L{SparkConf} object setting Spark properties.
@param gateway: Use an existing gateway and JVM, otherwise a new JVM
will be instantiated.
@param sparkContext: L{SparkContext} object.
@param duration: A L{Duration} object for SparkStreaming.
@param duration: A L{Duration} object or seconds for SparkStreaming.
"""
if isinstance(duration, (int, long, float)):
duration = Seconds(duration)

if not isinstance(duration, Duration):
raise TypeError("Input should be pyspark.streaming.duration.Duration object")

if sparkContext is None:
# Create the Python SparkContext
self._sc = SparkContext(master=master, appName=appName, sparkHome=sparkHome,
pyFiles=pyFiles, environment=environment, batchSize=batchSize,
serializer=serializer, conf=conf, gateway=gateway)
else:
self._sc = sparkContext

# Start the py4j callback server.
# The callback server is needed only by Spark Streaming; therefore the callback server
# is started in StreamingContext.
SparkContext._gateway.restart_callback_server()
self._set_clean_up_handler()
self._sc = sparkContext
self._jvm = self._sc._jvm
self._jssc = self._initialize_context(self._sc._jsc, duration._jduration)
self._start_callback_server()
self._jssc = self._initialize_context(self._sc, duration)

# Initialize StreamingContext in a function to allow subclass-specific initialization
def _initialize_context(self, jspark_context, jduration):
return self._jvm.JavaStreamingContext(jspark_context, jduration)
def _start_callback_server(self):
gw = self._sc._gateway
# getattr will fallback to JVM
if "_callback_server" not in gw.__dict__:
gw._start_callback_server(gw._python_proxy_port)

def _set_clean_up_handler(self):
""" set clean up hander using atexit """

def clean_up_handler():
SparkContext._gateway.shutdown()

atexit.register(clean_up_handler)
# atexit is not called when the program is killed by a signal not handled by
# Python.
for sig in (SIGINT, SIGTERM):
signal(sig, clean_up_handler)
def _initialize_context(self, sc, duration):
return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration)
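A minimal usage sketch of the new constructor shown above; the master URL and app name are illustrative, and the plain number 1 is wrapped in Seconds(1) by the type check in __init__:

from pyspark.context import SparkContext
from pyspark.streaming.context import StreamingContext

sc = SparkContext("local[2]", "StreamingExample")  # illustrative master and app name
ssc = StreamingContext(sc, 1)                      # numeric duration becomes Seconds(1)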

@property
def sparkContext(self):
@@ -121,17 +84,26 @@ def awaitTermination(self, timeout=None):
else:
self._jssc.awaitTermination(timeout)

def stop(self, stopSparkContext=True, stopGraceFully=False):
"""
Stop the execution of the streams immediately (does not wait for all received data
to be processed).
"""
self._jssc.stop(stopSparkContext, stopGraceFully)
if stopSparkContext:
self._sc.stop()
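A short sketch of one way a caller might use this, assuming ssc is a running StreamingContext: passing stopSparkContext=False stops only the streaming computation and keeps the SparkContext alive for further use.

ssc.stop(stopSparkContext=False)
# the SparkContext is still usable here; only the streaming computation has stopped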

def remember(self, duration):
"""
Set each DStream in this context to remember the RDDs it generated in the last given duration.
DStreams remember RDDs only for a limited duration of time and release them for garbage
collection. This method allows the developer to specify how long to remember the RDDs
(if the developer wishes to query old data outside the DStream computation).
@param duration pyspark.streaming.duration.Duration object.
@param duration pyspark.streaming.duration.Duration object or seconds.
Minimum duration that each DStream should remember its RDDs
"""
if not isinstance(duration, Duration):
raise TypeError("Input should be pyspark.streaming.duration.Duration object")
if isinstance(duration, (int, long, float)):
duration = Seconds(duration)

self._jssc.remember(duration._jduration)
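A sketch of the two accepted forms, given the numeric conversion above (ssc is assumed to be an existing StreamingContext):

from pyspark.streaming.duration import Seconds

ssc.remember(60)           # plain seconds, converted to Seconds(60) by the check above
ssc.remember(Seconds(60))  # an explicit Duration object works as before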

@@ -153,34 +125,14 @@ def textFileStream(self, directory):
"""
return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())

def stop(self, stopSparkContext=True, stopGraceFully=False):
"""
Stop the execution of the streams immediately (does not wait for all received data
to be processed).
"""
self._jssc.stop(stopSparkContext, stopGraceFully)
if stopSparkContext:
self._sc.stop()

# Shut down only the callback server; all py4j clients are shut down
# by the clean up handler
SparkContext._gateway._shutdown_callback_server()

def _testInputStream(self, test_inputs, numSlices=None):
def _makeStream(self, inputs, numSlices=None):
"""
This function is only for unit tests.
It takes a list as input, and returns the i-th element at the i-th batch
under the manual clock.
"""
test_rdds = list()
test_rdd_deserializers = list()
for test_input in test_inputs:
test_rdd = self._sc.parallelize(test_input, numSlices)
test_rdds.append(test_rdd._jrdd)
test_rdd_deserializers.append(test_rdd._jrdd_deserializer)
# All deserializers have to be the same.
# TODO: add deserializer validation
jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client)
jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream()

return DStream(jinput_stream, self, test_rdd_deserializers[0])
rdds = [self._sc.parallelize(input, numSlices) for input in inputs]
jrdds = ListConverter().convert([r._jrdd for r in rdds],
SparkContext._gateway._gateway_client)
jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream()
return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
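A hedged sketch of how this test helper is meant to be used, assuming ssc was created for tests with the manual clock (as the docstring above describes): each element of the input list becomes the RDD for one batch.

# inside a unit test
input_stream = ssc._makeStream([[1, 2, 3], [4, 5], [6]])
# batch 0 -> [1, 2, 3], batch 1 -> [4, 5], batch 2 -> [6]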
