From 080541a6d77cb85f788c297670cca24fbbc9f9b5 Mon Sep 17 00:00:00 2001 From: giwa Date: Thu, 14 Aug 2014 02:19:46 -0700 Subject: [PATCH] broke something --- python/pyspark/rdd.py | 3 ++- python/pyspark/streaming/context.py | 10 ++++++---- python/pyspark/streaming/dstream.py | 20 +++++++++++++++++++ python/pyspark/streaming_tests.py | 2 ++ python/pyspark/worker.py | 11 ++++++++++ .../streaming/api/python/PythonDStream.scala | 1 - 6 files changed, 41 insertions(+), 6 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f64f48e3a4c9c..942382b40d28f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -283,7 +283,8 @@ def mapPartitions(self, f, preservesPartitioning=False): >>> rdd.mapPartitions(f).collect() [3, 7] """ - def func(s, iterator): return f(iterator) + def func(s, iterator): + return f(iterator) return self.mapPartitionsWithIndex(func) def mapPartitionsWithIndex(self, f, preservesPartitioning=False): diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 32b52f74e16f0..809158aedbc96 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -169,8 +169,7 @@ def _testInputStream(self, test_inputs, numSlices=None): jinput_stream = self._jvm.PythonTestInputStream(self._jssc, jtempFiles, numSlices).asJavaDStream() - return DStream(jinput_stream, self, PickleSerializer()) - + return DStream(jinput_stream, self, BatchedSerializer(PickleSerializer())) def _testInputStream2(self, test_inputs, numSlices=None): """ @@ -178,12 +177,15 @@ def _testInputStream2(self, test_inputs, numSlices=None): which contain the RDD. """ test_rdds = list() + test_rdd_deserializers = list() for test_input in test_inputs: test_rdd = self._sc.parallelize(test_input, numSlices) - print test_rdd.glom().collect() test_rdds.append(test_rdd._jrdd) + test_rdd_deserializers.append(test_rdd._jrdd_deserializer) jtest_rdds = ListConverter().convert(test_rdds, SparkContext._gateway._gateway_client) jinput_stream = self._jvm.PythonTestInputStream2(self._jssc, jtest_rdds).asJavaDStream() - return DStream(jinput_stream, self, BatchedSerializer(PickleSerializer())) + dstream = DStream(jinput_stream, self, test_rdd_deserializers[0]) + dstream._test_switch_dserializer(test_rdd_deserializers) + return dstream diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 101bfdbca0102..0a93a46d2b2a2 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -17,6 +17,7 @@ from collections import defaultdict from itertools import chain, ifilter, imap +import time import operator from pyspark.serializers import NoOpSerializer,\ @@ -289,6 +290,25 @@ def get_output(rdd, time): self.foreachRDD(get_output) + def _test_switch_dserializer(self, serializer_que): + """ + Deserializer is dynamically changed based on numSlice and the number of + input. This function choose deserializer. Currently this is just FIFO. + """ + + jrdd_deserializer = self._jrdd_deserializer + + def switch(rdd, jtime): + try: + print serializer_que + jrdd_deserializer = serializer_que.pop(0) + print jrdd_deserializer + except Exception as e: + print e + + self.foreachRDD(switch) + + # TODO: implement groupByKey # TODO: impelment union diff --git a/python/pyspark/streaming_tests.py b/python/pyspark/streaming_tests.py index e346bc227fe46..e23b86e8f040e 100644 --- a/python/pyspark/streaming_tests.py +++ b/python/pyspark/streaming_tests.py @@ -118,6 +118,8 @@ def test_count(self): test_input = [[], [1], range(1, 3), range(1, 4), range(1, 5)] def test_func(dstream): + print "count" + dstream.count().pyprint() return dstream.count() expected_output = map(lambda x: [len(x)], test_input) output = self._run_stream(test_input, test_func, expected_output) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7ca3252270d5a..8ee2f0b3a260f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,7 @@ import time import socket import traceback +import itertools # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. from pyspark.accumulators import _accumulatorRegistry @@ -74,6 +75,16 @@ def main(infile, outfile): (func, deserializer, serializer) = command init_time = time.time() iterator = deserializer.load_stream(infile) + print "deserializer in worker: %s" % str(deserializer) + iterator, walk = itertools.tee(iterator) + if isinstance(walk, int): + print "this is int" + print walk + else: + try: + print list(walk) + except: + print list(walk) serializer.dump_stream(func(split_index, iterator), outfile) except Exception as e: # Write the error to stderr in addition to trying to pass it back to diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 20e0b0d177d0f..e8788d4579dea 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -165,7 +165,6 @@ class PythonTestInputStream(ssc_ : JavaStreamingContext, inputFiles: JArrayList[ tempFile.getAbsolutePath } } - println("PythonTestInputStreaming numPartitons" + numPartitions ) val rdd = PythonRDD.readRDDFromFile(JavaSparkContext.fromSparkContext(ssc_.sparkContext), selectedInputFile, numPartitions).rdd logInfo("Created RDD " + rdd.id + " with " + selectedInputFile) Some(rdd)