Tighten up field/method visibility in Executor and make some code clearer to read. #4850
@@ -21,7 +21,7 @@ import java.io.File
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
import java.util.concurrent._
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -31,24 +31,26 @@ import akka.actor.Props

import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader,
SparkUncaughtExceptionHandler, AkkaUtils, Utils}
import org.apache.spark.util._

/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
* In coarse-grained mode, an existing actor system is provided.
* Spark executor, backed by a threadpool to run tasks.
*
* This can be used with Mesos, YARN, and the standalone scheduler.
* An internal RPC interface (at the moment Akka) is used for communication with the driver,
* except in the case of Mesos fine-grained mode.
*/
private[spark] class Executor(
executorId: String,
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false)
extends Logging
{
extends Logging {

logInfo(s"Starting executor ID $executorId on host $executorHostname")

// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -78,9 +80,8 @@ private[spark] class Executor(
}

// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")

val executorSource = new ExecutorSource(this, executorId)
private val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
private val executorSource = new ExecutorSource(threadPool, executorId)

if (!isLocal) {
env.metricsSystem.registerSource(executorSource)
@@ -122,21 +123,21 @@ private[spark] class Executor(
taskId: Long,
attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer) {
serializedTask: ByteBuffer): Unit = {
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}

def killTask(taskId: Long, interruptThread: Boolean) {
def killTask(taskId: Long, interruptThread: Boolean): Unit = {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill(interruptThread)
}
}

def stop() {
def stop(): Unit = {
env.metricsSystem.report()
env.actorSystem.stop(executorActor)
isStopped = true
@@ -146,7 +147,10 @@ private[spark] class Executor(
}
}

private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
/** Returns the total amount of time this JVM process has spent in garbage collection. */
private def computeTotalGcTime(): Long = {
ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
}

class TaskRunner(
execBackend: ExecutorBackend,
@@ -156,27 +160,34 @@ private[spark] class Executor(
serializedTask: ByteBuffer)
extends Runnable {

/** Whether this task has been killed. */
@volatile private var killed = false
@volatile var task: Task[Any] = _
@volatile var attemptedTask: Option[Task[Any]] = None

/** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _

def kill(interruptThread: Boolean) {
/**
* The task to run. This will be set in run() by deserializing the task binary coming
* from the driver. Once it is set, it will never be changed.
*/
@volatile var task: Task[Any] = _

def kill(interruptThread: Boolean): Unit = {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
killed = true
if (task != null) {
task.kill(interruptThread)
}
}

override def run() {
override def run(): Unit = {
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
startGCTime = gcTime
startGCTime = computeTotalGcTime()

try {
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
@@ -193,7 +204,6 @@ private[spark] class Executor(
throw new TaskKilledException
}

attemptedTask = Some(task)
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
@@ -215,18 +225,17 @@ private[spark] class Executor(
for (m <- task.metrics) {
m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
m.setExecutorRunTime(taskFinish - taskStart)
m.setJvmGCTime(gcTime - startGCTime)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m.setResultSerializationTime(afterSerialization - beforeSerialization)
}

val accumUpdates = Accumulators.values

val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit

// directSend = sending directly back to the driver
val serializedResult = {
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
@@ -248,42 +257,40 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

} catch {
case ffe: FetchFailedException => {
case ffe: FetchFailedException =>
val reason = ffe.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

case _: TaskKilledException | _: InterruptedException if task.killed => {
case _: TaskKilledException | _: InterruptedException if task.killed =>
logInfo(s"Executor killed $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}

case cDE: CommitDeniedException => {
case cDE: CommitDeniedException =>
val reason = cDE.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

case t: Throwable => {
case t: Throwable =>
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName (TID $taskId)", t)

val serviceTime = System.currentTimeMillis() - taskStart
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.setExecutorRunTime(serviceTime)
m.setJvmGCTime(gcTime - startGCTime)
val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
task.metrics.map { m =>
m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m
}
}
val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
val taskEndReason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))

// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
SparkUncaughtExceptionHandler.uncaughtException(t)
}
}

} finally {
// Release memory used by this thread for shuffles
env.shuffleMemoryManager.releaseMemoryForThisThread()
@@ -358,7 +365,7 @@ private[spark] class Executor(
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
@@ -370,12 +377,12 @@ private[spark] class Executor(
if (currentTimeStamp < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
if (!urlClassLoader.getURLs().contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
@@ -384,61 +391,70 @@ private[spark] class Executor(
}
}

def startDriverHeartbeater() {
val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
val timeout = AkkaUtils.lookupTimeout(conf)
val retryAttempts = AkkaUtils.numRetries(conf)
val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
private val timeout = AkkaUtils.lookupTimeout(conf)
private val retryAttempts = AkkaUtils.numRetries(conf)
private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
private val heartbeatReceiverRef =
AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)

/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
// list of (task id, metrics) to send back to the driver
val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
val curGCTime = computeTotalGcTime()

for (taskRunner <- runningTasks.values()) {
if (taskRunner.task != null) {
taskRunner.task.metrics.foreach { metrics =>
metrics.updateShuffleReadMetrics()
metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)

if (isLocal) {
// JobProgressListener will hold an reference of it during
Review comment: Could be worth fixing the grammar in this comment.
// onExecutorMetricsUpdate(), then JobProgressListener can not see
// the changes of metrics any more, so make a deep copy of it
val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
tasksMetrics += ((taskRunner.taskId, copiedMetrics))
} else {
// It will be copied by serialization
tasksMetrics += ((taskRunner.taskId, metrics))
}
}
}
}

val t = new Thread() {
val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
try {
val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
retryAttempts, retryIntervalMs, timeout)
if (response.reregisterBlockManager) {
logWarning("Told to re-register on heartbeat")
env.blockManager.reregister()
}
} catch {
case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e)
}
}

/**
* Starts a thread to report heartbeat and partial metrics for active tasks to driver.
* This thread stops running when the executor is stopped.
*/
private def startDriverHeartbeater(): Unit = {
val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
val thread = new Thread() {
override def run() {
// Sleep a random interval so the heartbeats don't end up in sync
Thread.sleep(interval + (math.random * interval).asInstanceOf[Int])

while (!isStopped) {
val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
val curGCTime = gcTime

for (taskRunner <- runningTasks.values()) {
if (taskRunner.attemptedTask.nonEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
metrics.updateShuffleReadMetrics()
metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)

if (isLocal) {
// JobProgressListener will hold an reference of it during
// onExecutorMetricsUpdate(), then JobProgressListener can not see
// the changes of metrics any more, so make a deep copy of it
val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
tasksMetrics += ((taskRunner.taskId, copiedMetrics))
} else {
// It will be copied by serialization
tasksMetrics += ((taskRunner.taskId, metrics))
}
}
}
}

val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
try {
val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
retryAttempts, retryIntervalMs, timeout)
if (response.reregisterBlockManager) {
logWarning("Told to re-register on heartbeat")
env.blockManager.reregister()
}
} catch {
case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t)
}

reportHeartBeat()
Thread.sleep(interval)
}
}
}
t.setDaemon(true)
t.setName("Driver Heartbeater")
t.start()
thread.setDaemon(true)
thread.setName("driver-heartbeater")
thread.start()
}
}
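The refactored heartbeater above boils down to a small pattern: a daemon thread that first sleeps a randomized interval (so heartbeats from different executors don't end up in sync), then loops until a stop flag is set, reporting metrics once per interval. Below is a minimal, self-contained sketch of that pattern; `isStopped`, `reportHeartbeat`, and the interval value are placeholders for illustration, not Spark's actual members.

```scala
// Minimal sketch of the daemon-heartbeater pattern used in this PR.
// `isStopped` and `reportHeartbeat` are stand-ins, not the real Executor members.
object HeartbeaterSketch {
  @volatile private var isStopped = false

  // In the real executor this collects per-task metrics and sends a Heartbeat to the driver.
  private def reportHeartbeat(): Unit =
    println(s"heartbeat at ${System.currentTimeMillis()}")

  def startDriverHeartbeater(intervalMs: Int): Thread = {
    val thread = new Thread() {
      override def run(): Unit = {
        // Jitter the first sleep so multiple executors don't heartbeat in lockstep.
        Thread.sleep(intervalMs + (scala.util.Random.nextDouble() * intervalMs).toInt)
        while (!isStopped) {
          reportHeartbeat()
          Thread.sleep(intervalMs)
        }
      }
    }
    thread.setDaemon(true) // never keep the JVM alive just for heartbeats
    thread.setName("driver-heartbeater")
    thread.start()
    thread
  }

  def main(args: Array[String]): Unit = {
    startDriverHeartbeater(intervalMs = 1000)
    Thread.sleep(5000) // let a few heartbeats fire
    isStopped = true   // mirrors how Executor.stop() sets isStopped in the diff above
  }
}
```

Note the design choice mirrored in the diff: the loop polls a volatile stop flag rather than relying on interruption, so `stop()` only has to set the flag and the thread winds down on its next iteration.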
Review comment: Since this was inadvertently public before, and thus was public in Spark 1.3, I think that this change will cause a MiMa failure once we bump the version to 1.4.0-SNAPSHOT. Therefore, this PR sort of implicitly conflicts with #5056, so we'll have to make sure to re-test whichever PR we merge second.
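For context on the MiMa concern: MiMa (the Migration Manager) compares the previously published binaries against the new build and flags removed or newly private members. If this change does trip it, the usual Spark remedy is an exclusion in `project/MimaExcludes.scala`. A hypothetical sketch of what such entries could look like is below; the problem types and member names are assumptions for illustration, not what MiMa actually reported for this PR.

```scala
// Hypothetical MiMa exclusions for Executor members that were unintentionally
// public in 1.3 and are now private. Names and problem types are assumptions.
import com.typesafe.tools.mima.core._

object ExampleExecutorExcludes {
  val excludes = Seq(
    ProblemFilters.exclude[MissingMethodProblem](
      "org.apache.spark.executor.Executor.threadPool"),
    ProblemFilters.exclude[MissingMethodProblem](
      "org.apache.spark.executor.Executor.executorSource")
  )
}
```

Whichever of the two PRs is merged second would then carry the exclusions (or an equivalent fix) so the MiMa check stays green.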