diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 123fa67f837e3..60bcf86783e95 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -130,48 +130,7 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
         # Stop Callback server
         SparkContext._gateway.shutdown()
 
-    def checkpoint(self, directory):
-        """
-        Not tested
-        """
-        self._jssc.checkpoint(directory)
-
     def _testInputStream(self, test_inputs, numSlices=None):
-        """
-        Generate multiple files to make "stream" in Scala side for test.
-        Scala chooses one of the files and generates RDD using PythonRDD.readRDDFromFile.
-
-        QueStream maybe good way to implement this function
-        """
-        numSlices = numSlices or self._sc.defaultParallelism
-        # Calling the Java parallelize() method with an ArrayList is too slow,
-        # because it sends O(n) Py4J commands. As an alternative, serialized
-        # objects are written to a file and loaded through textFile().
-
-        tempFiles = list()
-        for test_input in test_inputs:
-            tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
-
-            # Make sure we distribute data evenly if it's smaller than self.batchSize
-            if "__len__" not in dir(test_input):
-                test_input = list(test_input)    # Make it a list so we can compute its length
-            batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
-            if batchSize > 1:
-                serializer = BatchedSerializer(self._sc._unbatched_serializer,
-                                               batchSize)
-            else:
-                serializer = self._sc._unbatched_serializer
-            serializer.dump_stream(test_input, tempFile)
-            tempFile.close()
-            tempFiles.append(tempFile.name)
-
-        jtempFiles = ListConverter().convert(tempFiles, SparkContext._gateway._gateway_client)
-        jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
-                                                        jtempFiles,
-                                                        numSlices).asJavaDStream()
-        return DStream(jinput_stream, self, BatchedSerializer(PickleSerializer()))
-
-    def _testInputStream2(self, test_inputs, numSlices=None):
         """
         This is inpired by QueStream implementation.
         Give list of RDD and generate DStream which contain the RDD.
@@ -184,7 +143,7 @@ def _testInputStream2(self, test_inputs, numSlices=None):
             test_rdd_deserializers.append(test_rdd._jrdd_deserializer)
 
         jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client)
-        jinput_stream = self._jvm.PythonTestInputStream2(self._jssc, jtest_rdds).asJavaDStream()
+        jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtest_rdds).asJavaDStream()
 
         dstream = DStream(jinput_stream, self, test_rdd_deserializers[0])
         return dstream
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index f96efb1fd1db7..183826cf2ef96 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -17,12 +17,13 @@
 
 from collections import defaultdict
 from itertools import chain, ifilter, imap
-import time
 import operator
 
 from pyspark.serializers import NoOpSerializer,\
     BatchedSerializer, CloudPickleSerializer, pack_long
 from pyspark.rdd import _JavaStackTrace
+from pyspark.storagelevel import StorageLevel
+from pyspark.resultiterable import ResultIterable
 
 from py4j.java_collections import ListConverter, MapConverter
 
@@ -35,6 +36,8 @@ def __init__(self, jdstream, ssc, jrdd_deserializer):
         self._ssc = ssc
         self.ctx = ssc._sc
         self._jrdd_deserializer = jrdd_deserializer
+        self.is_cached = False
+        self.is_checkpointed = False
 
     def context(self):
         """
@@ -247,8 +250,6 @@ def takeAndPrint(rdd, time):
             taken = rdd.take(11)
             print "-------------------------------------------"
             print "Time: %s" % (str(time))
-            print rdd.glom().collect()
-            print "-------------------------------------------"
             print "-------------------------------------------"
             for record in taken[:10]:
                 print record
@@ -303,32 +304,65 @@ def get_output(rdd, time):
 
         self.foreachRDD(get_output)
 
-    def _test_switch_dserializer(self, serializer_que):
+    def cache(self):
+        """
+        Persist this DStream with the default storage level (C{MEMORY_ONLY_SER}).
+        """
+        self.is_cached = True
+        self.persist(StorageLevel.MEMORY_ONLY_SER)
+        return self
+
+    def persist(self, storageLevel):
+        """
+        Set this DStream's storage level to persist its values across operations
+        after the first time it is computed. This can only be used to assign
+        a new storage level if the DStream does not have a storage level set yet.
+        """
+        self.is_cached = True
+        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
+        self._jdstream.persist(javaStorageLevel)
+        return self
+
+    def checkpoint(self, interval):
         """
-        Deserializer is dynamically changed based on numSlice and the number of
-        input. This function choose deserializer. Currently this is just FIFO.
+        Mark this DStream for checkpointing. It will be saved to a file inside the
+        checkpoint directory set with L{SparkContext.setCheckpointDir()} and
+        all references to its parent RDDs will be removed. This function must
+        be called before any job has been executed on this RDD. It is strongly
+        recommended that this RDD is persisted in memory, otherwise saving it
+        on a file will require recomputation.
+        (Note: I am not sure the part about parent RDD references also applies
+        to DStream.)
+
+        interval must be a pyspark.streaming.duration
         """
-
-        jrdd_deserializer = self._jrdd_deserializer
+        self.is_checkpointed = True
+        self._jdstream.checkpoint(interval)
+        return self
+
+    def groupByKey(self, numPartitions=None):
+        def createCombiner(x):
+            return [x]
 
-        def switch(rdd, jtime):
-            try:
-                print serializer_que
-                jrdd_deserializer = serializer_que.pop(0)
-                print jrdd_deserializer
-            except Exception as e:
-                print e
+        def mergeValue(xs, x):
+            xs.append(x)
+            return xs
 
-        self.foreachRDD(switch)
+        def mergeCombiners(a, b):
+            a.extend(b)
+            return a
 
+        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
+                                 numPartitions).mapValues(lambda x: ResultIterable(x))
 
 # TODO: implement groupByKey
+# TODO: implement saveAsTextFile
+
+# The following operations depend on transform
 # TODO: impelment union
-# TODO: implement cache
-# TODO: implement persist
 # TODO: implement repertitions
-# TODO: implement saveAsTextFile
 # TODO: implement cogroup
 # TODO: implement join
 # TODO: implement countByValue
@@ -355,6 +389,7 @@ def pipeline_func(split, iterator):
         self._prev_jdstream = prev._prev_jdstream  # maintain the pipeline
         self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
+        self.is_checkpointed = False
         self._ssc = prev._ssc
         self.ctx = prev.ctx
         self.prev = prev
@@ -391,4 +426,4 @@ def _jdstream(self):
         return self._jdstream_val
 
     def _is_pipelinable(self):
-        return not self.is_cached
+        return not (self.is_cached or self.is_checkpointed)
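
Usage sketch (not part of the patch): a minimal example of how the DStream methods added above (cache, persist, checkpoint, groupByKey) might be combined. Only persist/cache/checkpoint/groupByKey/mapValues, StorageLevel.MEMORY_ONLY_SER and the two-argument foreachRDD callback are taken from this diff; StreamingContext(sc, Seconds(1)), socketTextStream, flatMap/map and start()/awaitTermination() are assumptions about the surrounding work-in-progress streaming API and may look different in this branch.

from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.context import StreamingContext   # assumed module layout
from pyspark.streaming.duration import Seconds           # assumed Duration helper

sc = SparkContext(appName="GroupByKeySketch")
ssc = StreamingContext(sc, Seconds(1))                    # assumed constructor: 1s batches

lines = ssc.socketTextStream("localhost", 9999)           # assumed input source
pairs = lines.flatMap(lambda line: line.split(" ")) \
             .map(lambda word: (word, 1))

# persist()/cache() set is_cached = True, so _is_pipelinable() returns False
# and downstream operators run against the persisted JavaDStream.
pairs.persist(StorageLevel.MEMORY_ONLY_SER)

# checkpoint() expects a pyspark.streaming.duration value and sets
# is_checkpointed = True, which also disables pipelining.
counts = pairs.groupByKey().checkpoint(Seconds(10))

# groupByKey() is built on combineByKey(); each value is a ResultIterable,
# so materialize it with mapValues(list) before printing.
def dump(rdd, time):
    print "%s: %s" % (time, rdd.take(10))

counts.mapValues(list).foreachRDD(dump)

ssc.start()                                               # assumed lifecycle methods
ssc.awaitTermination()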