From 6f0da2fa486c2a580045a2e9e3133b6617875363 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 1 Oct 2014 00:08:54 -0700
Subject: [PATCH] recover from checkpoint

---
 .../apache/spark/api/python/PythonRDD.scala   |  8 +-
 .../spark/rdd/ParallelCollectionRDD.scala     |  2 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala |  8 ++
 python/pyspark/context.py                     |  8 +-
 python/pyspark/streaming/context.py           | 76 ++++++++++++++-----
 python/pyspark/streaming/tests.py             | 33 ++++++++
 python/pyspark/streaming/util.py              | 24 ++++--
 .../streaming/api/python/PythonDStream.scala  |  8 +-
 .../streaming/dstream/QueueInputDStream.scala |  7 ++
 9 files changed, 136 insertions(+), 38 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 8051b221ac3d1..b093917430a59 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -42,7 +42,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
 private[spark] class PythonRDD(
-    parent: RDD[_],
+    @transient parent: RDD[_],
     command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
@@ -61,9 +61,9 @@ private[spark] class PythonRDD(
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
   val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
 
-  override def getPartitions = parent.partitions
+  override def getPartitions = firstParent.partitions
 
-  override val partitioner = if (preservePartitoning) parent.partitioner else None
+  override val partitioner = if (preservePartitoning) firstParent.partitioner else None
 
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
         dataOut.writeInt(command.length)
         dataOut.write(command)
         // Data values
-        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
         dataOut.flush()
       } catch {
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 66c71bf7e8bb5..1069e23241302 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -84,7 +84,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
 
 private[spark] class ParallelCollectionRDD[T: ClassTag](
     @transient sc: SparkContext,
-    @transient data: Seq[T],
+    data: Seq[T],
     numSlices: Int,
     locationPrefs: Map[Int, Seq[String]])
     extends RDD[T](sc, Nil) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0e90caa5c9ca7..352ce5e00d5ec 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -82,6 +82,14 @@ abstract class RDD[T: ClassTag](
   def this(@transient oneParent: RDD[_]) =
     this(oneParent.context , List(new OneToOneDependency(oneParent)))
 
+  // setContext after loading from checkpointing
+  private[spark] def setContext(s: SparkContext) = {
+    if (sc != null && sc != s) {
+      throw new SparkException("Context is already set in " + this + ", cannot set it again")
+    }
+    sc = s
+  }
+
   private[spark] def conf = sc.conf
   // =======================================================================
   // Methods that should be implemented by subclasses of RDD
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 8e7b00469e246..ba930d949101d 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -68,7 +68,7 @@ class SparkContext(object):
 
     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                  environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
-                 gateway=None):
+                 gateway=None, jsc=None):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -103,14 +103,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf)
+                          conf, jsc)
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise
 
     def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf):
+                 conf, jsc):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size
@@ -151,7 +151,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
             self.environment[varName] = v
 
         # Create the Java SparkContext through Py4J
-        self._jsc = self._initialize_context(self._conf._jconf)
+        self._jsc = jsc or self._initialize_context(self._conf._jconf)
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 9808361eb664f..759feda169cff 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import sys
 
 from py4j.java_collections import ListConverter
 from py4j.java_gateway import java_import
 
-from pyspark import RDD
+from pyspark import RDD, SparkConf
 from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
 from pyspark.context import SparkContext
 from pyspark.storagelevel import StorageLevel
@@ -75,41 +77,81 @@ class StreamingContext(object):
     respectively. `context.awaitTransformation()` allows the current thread
     to wait for the termination of the context by `stop()` or by an exception.
     """
+    _transformerSerializer = None
 
-    def __init__(self, sparkContext, duration):
+    def __init__(self, sparkContext, duration=None, jssc=None):
         """
         Create a new StreamingContext.
 
         @param sparkContext: L{SparkContext} object.
         @param duration: number of seconds.
         """
+
         self._sc = sparkContext
         self._jvm = self._sc._jvm
-        self._start_callback_server()
-        self._jssc = self._initialize_context(self._sc, duration)
+        self._jssc = jssc or self._initialize_context(self._sc, duration)
+
+    def _initialize_context(self, sc, duration):
+        self._ensure_initialized()
+        return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
+
+    def _jduration(self, seconds):
+        """
+        Create Duration object given number of seconds
+        """
+        return self._jvm.Duration(int(seconds * 1000))
 
-    def _start_callback_server(self):
-        gw = self._sc._gateway
+    @classmethod
+    def _ensure_initialized(cls):
+        SparkContext._ensure_initialized()
+        gw = SparkContext._gateway
+        # start callback server
         # getattr will fallback to JVM
         if "_callback_server" not in gw.__dict__:
             _daemonize_callback_server()
             gw._start_callback_server(gw._python_proxy_port)
-            gw._python_proxy_port = gw._callback_server.port  # update port with real port
 
-    def _initialize_context(self, sc, duration):
-        java_import(self._jvm, "org.apache.spark.streaming.*")
-        java_import(self._jvm, "org.apache.spark.streaming.api.java.*")
-        java_import(self._jvm, "org.apache.spark.streaming.api.python.*")
+        java_import(gw.jvm, "org.apache.spark.streaming.*")
+        java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
+        java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
         # register serializer for RDDFunction
-        ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer())
-        self._jvm.PythonDStream.registerSerializer(ser)
-        return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
+        # it happens before creating SparkContext when loading from checkpointing
+        cls._transformerSerializer = RDDFunctionSerializer(SparkContext._active_spark_context,
+                                                           CloudPickleSerializer(), gw)
+        gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer)
 
-    def _jduration(self, seconds):
+    @classmethod
+    def getOrCreate(cls, path, setupFunc):
         """
-        Create Duration object given number of seconds
+        Get the StreamingContext from checkpoint file at `path`, or setup
+        it by `setupFunc`.
+
+        :param path: directory of checkpoint
+        :param setupFunc: a function used to create StreamingContext and
+                          setup DStreams.
+        :return: a StreamingContext
         """
-        return self._jvm.Duration(int(seconds * 1000))
+        if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path):
+            ssc = setupFunc()
+            ssc.checkpoint(path)
+            return ssc
+
+        cls._ensure_initialized()
+        gw = SparkContext._gateway
+
+        try:
+            jssc = gw.jvm.JavaStreamingContext(path)
+        except Exception:
+            print >>sys.stderr, "failed to load StreamingContext from checkpoint"
+            raise
+
+        jsc = jssc.sparkContext()
+        conf = SparkConf(_jconf=jsc.getConf())
+        sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)
+        # update ctx in serializer
+        SparkContext._active_spark_context = sc
+        cls._transformerSerializer.ctx = sc
+        return StreamingContext(sc, None, jssc)
 
     @property
     def sparkContext(self):
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index bd6d92255dbc6..00fea041d0be3 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -493,5 +493,38 @@ def func(rdds):
         self.assertEqual([2, 3, 1], self._take(dstream, 3))
 
 
+class TestCheckpoint(PySparkStreamingTestCase):
+
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        pass
+
+    def test_get_or_create(self):
+        result = [0]
+
+        def setup():
+            conf = SparkConf().set("spark.default.parallelism", 1)
+            sc = SparkContext(conf=conf)
+            ssc = StreamingContext(sc, .2)
+            rdd = sc.parallelize(range(10), 1)
+            dstream = ssc.queueStream([rdd], default=rdd)
+            result[0] = self._collect(dstream.countByWindow(1, .2))
+            return ssc
+        tmpd = tempfile.mkdtemp("test_streaming_cps")
+        ssc = StreamingContext.getOrCreate(tmpd, setup)
+        ssc.start()
+        ssc.awaitTermination(4)
+        ssc.stop()
+        expected = [[i * 10 + 10] for i in range(5)] + [[50]] * 5
+        self.assertEqual(expected, result[0][:10])
+
+        ssc = StreamingContext.getOrCreate(tmpd, setup)
+        ssc.start()
+        ssc.awaitTermination(2)
+        ssc.stop()
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index c15f9d98c1866..4cfaa3fc50e18 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -18,24 +18,31 @@
 from datetime import datetime
 import traceback
 
-from pyspark.rdd import RDD
+from pyspark import SparkContext, RDD
 
 
 class RDDFunction(object):
     """
     This class is for py4j callback.
     """
+    _emptyRDD = None
+
     def __init__(self, ctx, func, *deserializers):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
-        emptyRDD = getattr(self.ctx, "_emptyRDD", None)
-        if emptyRDD is None:
-            self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
-        self.emptyRDD = emptyRDD
+
+    @property
+    def emptyRDD(self):
+        if self._emptyRDD is None and self.ctx:
+            self._emptyRDD = self.ctx.parallelize([]).cache()
+        return self._emptyRDD
 
     def call(self, milliseconds, jrdds):
         try:
+            if self.ctx is None:
+                self.ctx = SparkContext._active_spark_context
+
             # extend deserializers with the first one
             sers = self.deserializers
             if len(sers) < len(jrdds):
@@ -51,20 +58,21 @@ def call(self, milliseconds, jrdds):
             traceback.print_exc()
 
     def __repr__(self):
-        return "RDDFunction(%s)" % (str(self.func))
+        return "RDDFunction(%s)" % self.func
 
     class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']
 
 
 class RDDFunctionSerializer(object):
-    def __init__(self, ctx, serializer):
+    def __init__(self, ctx, serializer, gateway=None):
         self.ctx = ctx
         self.serializer = serializer
+        self.gateway = gateway or self.ctx._gateway
 
     def dumps(self, id):
         try:
-            func = self.ctx._gateway.gateway_property.pool[id]
+            func = self.gateway.gateway_property.pool[id]
             return bytearray(self.serializer.dumps((func.func, func.deserializers)))
         except Exception:
             traceback.print_exc()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index f2ed0c507c2b7..48d1f2ae17e8c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -77,7 +77,7 @@ private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction)
 }
 
 /**
- * Inferface for Python Serializer to serialize PythonRDDFunction
+ * Interface for Python Serializer to serialize PythonRDDFunction
  */
 private[python] trait PythonRDDFunctionSerializer {
   def dumps(id: String): Array[Byte]  //
@@ -91,9 +91,9 @@ private[python] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) {
   def serialize(func: PythonRDDFunction): Array[Byte] = {
     // get the id of PythonRDDFunction in py4j
     val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
-    val f = h.getClass().getDeclaredField("id");
-    f.setAccessible(true);
-    val id = f.get(h).asInstanceOf[String];
+    val f = h.getClass().getDeclaredField("id")
+    f.setAccessible(true)
+    val id = f.get(h).asInstanceOf[String]
     pser.dumps(id)
   }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
index ed7da6dc1315e..0557ac87b5a1e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.streaming.dstream
 
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.UnionRDD
 import scala.collection.mutable.Queue
@@ -32,6 +33,12 @@ class QueueInputDStream[T: ClassTag](
     defaultRDD: RDD[T]
   ) extends InputDStream[T](ssc) {
 
+  private[streaming] override def setContext(s: StreamingContext) {
+    super.setContext(s)
+    queue.map(_.setContext(s.sparkContext))
+    defaultRDD.setContext(s.sparkContext)
+  }
+
   override def start() { }
 
   override def stop() { }
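
A minimal usage sketch of the recovery path this patch adds, not part of the patch itself; the checkpoint directory, host/port and the word-count pipeline below are illustrative assumptions, only StreamingContext.getOrCreate() and its behavior come from the change above:

    from pyspark import SparkConf, SparkContext
    from pyspark.streaming import StreamingContext

    checkpoint_dir = "/tmp/streaming_checkpoint"   # assumed path; any durable directory works

    def create_context():
        # Only called when checkpoint_dir holds no checkpoint; getOrCreate()
        # calls ssc.checkpoint(path) itself on the freshly built context.
        conf = SparkConf().setAppName("CheckpointedWordCount")
        sc = SparkContext(conf=conf)
        ssc = StreamingContext(sc, 1)  # 1-second batch interval
        counts = (ssc.socketTextStream("localhost", 9999)
                  .flatMap(lambda line: line.split(" "))
                  .map(lambda word: (word, 1))
                  .reduceByKey(lambda a, b: a + b))
        counts.pprint()
        return ssc

    # First run builds the context via create_context(); after a driver restart
    # the context (including the cloudpickled Python transform functions) is
    # reloaded from the checkpoint through JavaStreamingContext(path) instead.
    ssc = StreamingContext.getOrCreate(checkpoint_dir, create_context)
    ssc.start()
    ssc.awaitTermination()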