diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index f7e356319ecac..dbb6fdf1694ad 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -72,7 +72,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         # Callback sever is need only by SparkStreming; therefore the callback sever
         # is started in StreamingContext.
         SparkContext._gateway.restart_callback_server()
-        self._clean_up_trigger()
+        self._set_clean_up_trigger()
         self._jvm = self._sc._jvm
         self._jssc = self._initialize_context(self._sc._jsc, duration._jduration)
 
@@ -80,13 +80,11 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
     def _initialize_context(self, jspark_context, jduration):
         return self._jvm.JavaStreamingContext(jspark_context, jduration)
 
-    def _clean_up_trigger(self):
+    def _set_clean_up_trigger(self):
         """Kill py4j callback server properly using signal lib"""
 
         def clean_up_handler(*args):
             # Make sure stop callback server.
-            # This need improvement how to terminate callback sever properly.
-            SparkContext._gateway._shutdown_callback_server()
             SparkContext._gateway.shutdown()
             sys.exit(0)
 
@@ -132,18 +130,15 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
         Stop the execution of the streams immediately (does not wait for all received data
         to be processed).
         """
-
         try:
             self._jssc.stop(stopSparkContext, stopGraceFully)
         finally:
-            # Stop Callback server
-            SparkContext._gateway._shutdown_callback_server()
             SparkContext._gateway.shutdown()
 
     def _testInputStream(self, test_inputs, numSlices=None):
         """
         This function is only for unittest.
-        It requires a sequence as input, and returns the i_th element at the i_th batch
+        It requires a list as input, and returns the i_th element at the i_th batch
         under manual clock.
         """
         test_rdds = list()
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 0b01a9f02f51f..22a2751138c41 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -207,7 +207,7 @@ def _defaultReducePartitions(self):
         """
         Returns the default number of partitions to use during reduce tasks (e.g., groupBy).
         If spark.default.parallelism is set, then we'll use the value from SparkContext
-        defaultParallelism, otherwise we'll use the number of partitions in this RDD.
+        defaultParallelism, otherwise we'll use the number of partitions in this RDD
 
         This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce
         the likelihood of OOMs.  Once PySpark adopts Partitioner-based APIs, this behavior will
@@ -222,7 +222,8 @@ def getNumPartitions(self):
         """
         Return the number of partitions in RDD
         """
-        # TODO: remove hardcoding. RDD has NumPartitions but DStream does not have.
+        # TODO: remove hardcoding. RDD has NumPartitions. How do we get the number of partitions
+        # through a DStream?
         return 2
 
     def foreachRDD(self, func):
@@ -243,6 +244,10 @@ def pyprint(self):
         operator, so this DStream will be registered as an output stream and there materialized.
         """
         def takeAndPrint(rdd, time):
+            """
+            Closure that takes elements from the RDD and prints the first 10 of them.
+            This closure is called by the py4j callback server.
+            """
             taken = rdd.take(11)
             print "-------------------------------------------"
             print "Time: %s" % (str(time))
@@ -307,17 +312,11 @@ def checkpoint(self, interval):
         Mark this DStream for checkpointing.
         It will be saved to a file inside the checkpoint directory set with L{SparkContext.setCheckpointDir()}
 
-        I am not sure this part in DStream
-        and
-        all references to its parent RDDs will be removed. This function must
-        be called before any job has been executed on this RDD. It is strongly
-        recommended that this RDD is persisted in memory, otherwise saving it
-        on a file will require recomputation.
-
-        interval must be pysprak.streaming.duration
+        @param interval: Time interval after which the generated RDDs will be checkpointed;
+                         it has to be a pyspark.streaming.duration.Duration
         """
         self.is_checkpointed = True
-        self._jdstream.checkpoint(interval)
+        self._jdstream.checkpoint(interval._jduration)
         return self
 
     def groupByKey(self, numPartitions=None):
@@ -369,6 +368,10 @@ def saveAsTextFiles(self, prefix, suffix=None):
         Save this DStream as a text file, using string representations of elements.
         """
         def saveAsTextFile(rdd, time):
+            """
+            Closure that saves the elements of each RDD in this DStream to a text file.
+            This closure is called by the py4j callback server.
+            """
             path = rddToFileName(prefix, suffix, time)
             rdd.saveAsTextFile(path)
 
@@ -410,9 +413,10 @@ def get_output(rdd, time):
 # TODO: implement countByWindow
 # TODO: implement reduceByWindow
 
-# Following operation has dependency to transform
+# transform operations
 # TODO: implement transform
 # TODO: implement transformWith
+# The following operations depend on transform
 # TODO: implement union
 # TODO: implement repertitions
 # TODO: implement cogroup
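
Note on the shutdown change in `context.py`: the patch drops the explicit `_shutdown_callback_server()` calls and relies on `SparkContext._gateway.shutdown()`, invoked from the signal handler installed by `_set_clean_up_trigger`. Below is a minimal, Spark-free sketch of that pattern; `FakeGateway` and `install_clean_up_trigger` are illustrative names invented for this note, not part of the patch.

```python
import signal
import sys


class FakeGateway(object):
    """Stand-in for SparkContext._gateway, used only for this sketch."""
    def shutdown(self):
        # In the real patch this is the py4j gateway's shutdown(), which the
        # change relies on to also bring down the callback server.
        print "gateway shut down"


def install_clean_up_trigger(gateway):
    """Register SIGINT/SIGTERM handlers that stop the gateway before exiting."""
    def clean_up_handler(*args):
        gateway.shutdown()
        sys.exit(0)

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, clean_up_handler)


if __name__ == "__main__":
    install_clean_up_trigger(FakeGateway())
    signal.pause()  # wait until Ctrl-C or SIGTERM triggers the handler
```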
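
Similarly, the new `takeAndPrint` docstring says the closure prints the first 10 elements, and the visible `rdd.take(11)` line hints at how: one extra element is taken so the closure can tell whether the batch held more than 10 records. The tail of the closure is outside the hunk, so the `"..."` handling below is only the pattern implied by `take(11)`; `FakeRDD` and `take_and_print` are stand-ins so the sketch runs without Spark.

```python
class FakeRDD(object):
    """Stand-in for an RDD; only take() is needed by this sketch."""
    def __init__(self, data):
        self._data = data

    def take(self, num):
        return self._data[:num]


def take_and_print(rdd, time):
    # Take one element more than is printed, so we know whether to show "...".
    taken = rdd.take(11)
    print "-------------------------------------------"
    print "Time: %s" % (str(time))
    print "-------------------------------------------"
    for record in taken[:10]:
        print record
    if len(taken) > 10:
        print "..."
    print ""


if __name__ == "__main__":
    take_and_print(FakeRDD(range(25)), "2014-09-14 12:00:00")
```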