Skip to content

Commit

Permalink
Detect "clean socket shutdowns" and stop waiting on the socket
Browse files Browse the repository at this point in the history
  • Loading branch information
aarondav committed May 6, 2014
1 parent c0c49da commit b391ff8
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 16 deletions.
51 changes: 38 additions & 13 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,10 @@ private[spark] class PythonRDD[T: ClassTag](
} catch {
case e: Exception => logWarning("Failed to close worker socket", e)
}

// The python worker must be destroyed in the event of cancellation to ensure it unblocks.
if (context.interrupted) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap)
} catch {
case e: Exception => logError("Exception when trying to kill worker", e)
}
}
}

writerThread.start()
new MonitorThread(env, worker, context).start()

// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
Expand Down Expand Up @@ -136,6 +127,10 @@ private[spark] class PythonRDD[T: ClassTag](
}
} catch {

case e: Exception if context.interrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException

case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
Expand Down Expand Up @@ -164,6 +159,8 @@ private[spark] class PythonRDD[T: ClassTag](

@volatile private var _exception: Exception = null

setDaemon(true)

/** Contains the exception thrown while writing the parent iterator to the Python process. */
def exception: Option[Exception] = Option(_exception)

Expand Down Expand Up @@ -201,16 +198,44 @@ private[spark] class PythonRDD[T: ClassTag](
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.flush()
worker.shutdownOutput()
} catch {
case e: Exception if context.completed =>
case e: Exception if context.completed || context.interrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)

case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
Try(worker.shutdownOutput()) // kill Python worker process
} finally {
Try(worker.shutdownOutput()) // kill Python worker process
}
}
}

/**
* It is necessary to have a monitor thread for python workers if the user cancels with
* interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
* threads can block indefinitely.
*/
class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
extends Thread(s"Worker Monitor for $pythonExec") {

setDaemon(true)

override def run() {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
while (!context.interrupted && !context.completed) {
Thread.sleep(2000)
}
if (!context.completed) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
>>> lock = threading.Lock()
>>> def map_func(x):
... sleep(100)
... return x * x
... raise Exception("Task should have been cancelled")
>>> def start_job(x):
... global result
... try:
Expand Down
6 changes: 4 additions & 2 deletions python/pyspark/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ def handle_sigchld(*args):
signal.signal(SIGCHLD, handle_sigchld)

# Blocks until the socket is closed by draining the input stream
# until it raises an exception.
# until it raises an exception or returns EOF.
def waitSocketClose(sock):
try:
while True:
sock.recv(4096)
# Empty string is returned upon EOF (and only then).
if sock.recv(4096) == '':
return
except:
pass

Expand Down

0 comments on commit b391ff8

Please sign in to comment.