diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index dc012cc381346..fc4812753d005 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -42,9 +42,13 @@ class TaskContext( // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] + // Set to true when the task is completed, before the onCompleteCallbacks are executed. + @volatile var completed: Boolean = false + /** * Add a callback function to be executed on task completion. An example use * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. * @param f Callback function. */ def addOnCompleteCallback(f: () => Unit) { @@ -52,6 +56,7 @@ class TaskContext( } def executeOnCompleteCallbacks() { + completed = true // Process complete callbacks in the reverse order of registration onCompleteCallbacks.reverse.foreach{_()} } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 61407007087c6..928460581ed2d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -56,122 +56,46 @@ private[spark] class PythonRDD[T: ClassTag]( val env = SparkEnv.get val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) - // Ensure worker socket is closed on task completion. Closing sockets is idempotent. - context.addOnCompleteCallback(() => + // Start a thread to feed the process input from our parent's iterator + val writerThread = new WriterThread(env, worker, split, context) + + context.addOnCompleteCallback { () => + writerThread.shutdownOnTaskCompletion() + + // Cleanup the worker socket. This will also cause the Python worker to exit. try { worker.close() } catch { case e: Exception => logWarning("Failed to close worker socket", e) } - ) - - @volatile var readerException: Exception = null - // Start a thread to feed the process input from our parent's iterator - new Thread("stdin writer for " + pythonExec) { - override def run() { + // The python worker must be destroyed in the event of cancellation to ensure it unblocks. + if (context.interrupted) { try { - SparkEnv.set(env) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(split.index) - // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) - // Broadcast variables - dataOut.writeInt(broadcastVars.length) - for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) - } - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - dataOut.flush() - // Serialized command: - dataOut.writeInt(command.length) - dataOut.write(command) - // Data values - PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) - dataOut.flush() - worker.shutdownOutput() + logWarning("Incomplete task interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.toMap) } catch { - - case e: java.io.FileNotFoundException => - readerException = e - Try(worker.shutdownOutput()) // kill Python worker process - - case e: IOException => - // This can happen for legitimate reasons if the Python code stops returning data - // before we are done passing elements through, e.g., for take(). Just log a message to - // say it happened (as it could also be hiding a real IOException from a data source). - logInfo("stdin writer to Python finished early (may not be an error)", e) - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see Executor). - readerException = e - Try(worker.shutdownOutput()) // kill Python worker process + case e: Exception => logError("Exception when trying to kill worker", e) } } - }.start() - - // Necessary to distinguish between a task that has failed and a task that is finished - @volatile var complete: Boolean = false - - // It is necessary to have a monitor thread for python workers if the user cancels with - // interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - // threads can block indefinitely. - new Thread(s"Worker Monitor for $pythonExec") { - override def run() { - // Kill the worker if it is interrupted or completed - // When a python task completes, the context is always set to interupted - while (!context.interrupted) { - Thread.sleep(2000) - } - if (!complete) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - }.start() - - /* - * Partial fix for SPARK-1019: Attempts to stop reading the input stream since - * other completion callbacks might invalidate the input. Because interruption - * is not synchronous this still leaves a potential race where the interruption is - * processed only after the stream becomes invalid. - */ - context.addOnCompleteCallback{ () => - complete = true // Indicate that the task has completed successfully - context.interrupted = true } + writerThread.start() + // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { - // FIXME: can deadlock if worker is waiting for us to - // respond to current message (currently irrelevant because - // output is shutdown before we read any input) _nextObj = read() } obj } private def read(): Array[Byte] = { - if (readerException != null) { - throw readerException + if (writerThread.exception.isDefined) { + throw writerThread.exception.get } try { stream.readInt() match { @@ -190,13 +114,14 @@ private[spark] class PythonRDD[T: ClassTag]( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) - read + read() case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) - throw new PythonException(new String(obj, "utf-8"), readerException) + throw new PythonException(new String(obj, "utf-8"), + writerThread.exception.getOrElse(null)) case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still // read some accumulator updates: @@ -210,10 +135,11 @@ private[spark] class PythonRDD[T: ClassTag]( Array.empty[Byte] } } catch { - case e: Exception if readerException != null => + + case e: Exception if writerThread.exception.isDefined => logError("Python worker exited unexpectedly (crashed)", e) - logError("Python crash may have been caused by prior exception:", readerException) - throw readerException + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get case eof: EOFException => throw new SparkException("Python worker exited unexpectedly (crashed)", eof) @@ -224,10 +150,70 @@ private[spark] class PythonRDD[T: ClassTag]( def hasNext = _nextObj.length != 0 } - stdoutIterator + new InterruptibleIterator(context, stdoutIterator) } val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + + /** + * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Python process. + */ + class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Exception = null + + /** Contains the exception thrown while writing the parent iterator to the Python process. */ + def exception: Option[Exception] = Option(_exception) + + /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ + def shutdownOnTaskCompletion() { + assert(context.completed) + this.interrupt() + } + + override def run() { + try { + SparkEnv.set(env) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(split.index) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + // Broadcast variables + dataOut.writeInt(broadcastVars.length) + for (broadcast <- broadcastVars) { + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) + } + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.length) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + dataOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) + // Data values + PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + dataOut.flush() + worker.shutdownOutput() + } catch { + case e: Exception if context.completed => + logDebug("Exception thrown after task completion (likely due to cleanup)", e) + + case e: Exception => + // We must avoid throwing exceptions here, because the thread uncaught exception handler + // will kill the whole executor (see org.apache.spark.executor.Executor). + _exception = e + Try(worker.shutdownOutput()) // kill Python worker process + } + } + } } /** Thrown for exceptions in user Python code. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 02b62de7e36b6..2259df0b56bad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,11 +17,13 @@ package org.apache.spark.scheduler +import scala.language.existentials + import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashMap -import scala.language.existentials +import scala.util.Try import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics @@ -196,7 +198,11 @@ private[spark] class ShuffleMapTask( } finally { // Release the writers back to the shuffle block manager. if (shuffle != null && shuffle.writers != null) { - shuffle.releaseWriters(success) + try { + shuffle.releaseWriters(success) + } catch { + case e: Exception => logError("Failed to release shuffle writers", e) + } } // Execute the callbacks on task completion. context.executeOnCompleteCallbacks() diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index eb18ec08c9139..ef730dab664b7 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -74,6 +74,15 @@ def handle_sigchld(*args): raise signal.signal(SIGCHLD, handle_sigchld) + # Blocks until the socket is closed by draining the input stream + # until it raises an exception. + def waitSocketClose(sock): + try: + while True: + sock.recv(4096) + except: + pass + # Handle clients while not should_exit(): # Wait until a client arrives or we have to exit @@ -105,7 +114,8 @@ def handle_sigchld(*args): exit_code = exc.code finally: outfile.flush() - sock.close() + # The Scala side will close the socket upon task completion. + waitSocketClose(sock) os._exit(compute_real_exit_code(exit_code)) else: sock.close()