Tighten up field/method visibility in Executor and make some code clearer to read. #4850

Closed · wants to merge 2 commits
6 changes: 1 addition & 5 deletions core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -151,11 +151,7 @@ case object TaskKilled extends TaskFailedReason {
  * Task requested the driver to commit, but was denied.
  */
 @DeveloperApi
-case class TaskCommitDenied(
-    jobID: Int,
-    partitionID: Int,
-    attemptID: Int)
-  extends TaskFailedReason {
+case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason {
   override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
     s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
 }
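The change above is purely cosmetic, collapsing a multi-line signature onto one line; behavior is unchanged. A minimal, self-contained sketch (TaskFailedReason is reduced here to a one-method trait, since the real one lives in Spark core) showing the error string the class produces:

```scala
// Standalone sketch of the reformatted case class. TaskFailedReason is
// simplified so the snippet compiles outside Spark.
trait TaskFailedReason { def toErrorString: String }

case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason {
  override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
    s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
}

object TaskCommitDeniedDemo extends App {
  // Prints: TaskCommitDenied (Driver denied task commit) for job: 1, partition: 2, attempt: 0
  println(TaskCommitDenied(1, 2, 0).toErrorString)
}
```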
core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -22,14 +22,12 @@ import org.apache.spark.{TaskCommitDenied, TaskEndReason}
 /**
  * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
  */
-class CommitDeniedException(
+private[spark] class CommitDeniedException(

> Contributor: Since this was inadvertently public before, and thus was public in Spark 1.3, I think that this change will cause a MiMa failure once we bump the version to 1.4.0-SNAPSHOT. Therefore, this PR sort of implicitly conflicts with #5056, so we'll have to make sure to re-test whichever PR we merge second.

     msg: String,
     jobID: Int,
     splitID: Int,
     attemptID: Int)
   extends Exception(msg) {
-
-  def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID)
-
+  def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID)
 }
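As the review comment anticipates, MiMa (the Migration Manager plugin Spark runs for binary-compatibility checks) will flag a formerly public class that becomes private[spark]. A hypothetical exclude entry in the style of Spark's project/MimaExcludes.scala; treating this as a MissingClassProblem is an assumption about how MiMa would classify the change:

```scala
// Hypothetical MiMa exclude for this PR, modeled on entries in
// project/MimaExcludes.scala. Narrowing CommitDeniedException to
// private[spark] removes it from the public API, which MiMa would
// typically surface as a MissingClassProblem.
import com.typesafe.tools.mima.core._

val v14excludes = Seq(
  ProblemFilters.exclude[MissingClassProblem](
    "org.apache.spark.executor.CommitDeniedException")
)
```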
196 changes: 106 additions & 90 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,7 +21,7 @@ import java.io.File
 import java.lang.management.ManagementFactory
 import java.net.URL
 import java.nio.ByteBuffer
-import java.util.concurrent._
+import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -31,24 +31,26 @@ import akka.actor.Props
 
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader,
-  SparkUncaughtExceptionHandler, AkkaUtils, Utils}
+import org.apache.spark.util._
 
 /**
- * Spark executor used with Mesos, YARN, and the standalone scheduler.
- * In coarse-grained mode, an existing actor system is provided.
+ * Spark executor, backed by a threadpool to run tasks.
+ *
+ * This can be used with Mesos, YARN, and the standalone scheduler.
+ * An internal RPC interface (at the moment Akka) is used for communication with the driver,
+ * except in the case of Mesos fine-grained mode.
  */
 private[spark] class Executor(
     executorId: String,
     executorHostname: String,
     env: SparkEnv,
     userClassPath: Seq[URL] = Nil,
     isLocal: Boolean = false)
-  extends Logging
-{
+  extends Logging {
 
   logInfo(s"Starting executor ID $executorId on host $executorHostname")
 
   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -78,9 +80,8 @@ private[spark] class Executor(
   }
 
   // Start worker thread pool
-  val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
-
-  val executorSource = new ExecutorSource(this, executorId)
+  private val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
+  private val executorSource = new ExecutorSource(threadPool, executorId)
 
   if (!isLocal) {
     env.metricsSystem.registerSource(executorSource)
@@ -122,21 +123,21 @@ private[spark] class Executor(
       taskId: Long,
       attemptNumber: Int,
       taskName: String,
-      serializedTask: ByteBuffer) {
+      serializedTask: ByteBuffer): Unit = {
     val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
       serializedTask)
     runningTasks.put(taskId, tr)
     threadPool.execute(tr)
   }
 
-  def killTask(taskId: Long, interruptThread: Boolean) {
+  def killTask(taskId: Long, interruptThread: Boolean): Unit = {
     val tr = runningTasks.get(taskId)
     if (tr != null) {
       tr.kill(interruptThread)
     }
   }
 
-  def stop() {
+  def stop(): Unit = {
     env.metricsSystem.report()
     env.actorSystem.stop(executorActor)
     isStopped = true
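The `: Unit =` additions above replace Scala's procedure syntax (a method body with no `=`), which is discouraged because the omitted `=` makes the Unit result type easy to miss. A two-method sketch of the equivalence:

```scala
// Both methods return Unit; the second form is the one this PR standardizes on.
object ProcedureSyntaxDemo extends App {
  def stopOld() { println("procedure syntax, implicit Unit") }   // old style
  def stopNew(): Unit = { println("explicit Unit result type") } // new style
  stopOld()
  stopNew()
}
```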
@@ -146,7 +147,10 @@ private[spark] class Executor(
     }
   }
 
-  private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+  /** Returns the total amount of time this JVM process has spent in garbage collection. */
+  private def computeTotalGcTime(): Long = {
+    ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+  }
 
   class TaskRunner(
       execBackend: ExecutorBackend,
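computeTotalGcTime() simply sums collection time across all of the JVM's garbage collector MXBeans. A self-contained sketch of the same snapshot-and-subtract pattern run() uses, assuming nothing beyond the JDK and the JavaConversions import the file already has:

```scala
import java.lang.management.ManagementFactory
import scala.collection.JavaConversions._

object GcTimeDemo extends App {
  // Total milliseconds this JVM has spent in GC, summed over all collectors.
  def computeTotalGcTime(): Long =
    ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum

  val before = computeTotalGcTime()  // snapshot before the work, like startGCTime
  // Allocate short-lived arrays and request a collection so the delta is visible.
  (1 to 50).foreach(_ => new Array[Byte](4 * 1024 * 1024))
  System.gc()
  println(s"GC time while running: ${computeTotalGcTime() - before} ms")
}
```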
@@ -156,27 +160,34 @@ private[spark] class Executor(
       serializedTask: ByteBuffer)
     extends Runnable {

> Contributor: This attemptedTask vs task stuff in the old code is/was really confusing, so thanks for cleaning it up.

+    /** Whether this task has been killed. */
     @volatile private var killed = false
-    @volatile var task: Task[Any] = _
-    @volatile var attemptedTask: Option[Task[Any]] = None
+
+    /** How much the JVM process has spent in GC when the task starts to run. */
     @volatile var startGCTime: Long = _
 
-    def kill(interruptThread: Boolean) {
+    /**
+     * The task to run. This will be set in run() by deserializing the task binary coming
+     * from the driver. Once it is set, it will never be changed.
+     */
+    @volatile var task: Task[Any] = _
+
+    def kill(interruptThread: Boolean): Unit = {
       logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
       killed = true
       if (task != null) {
         task.kill(interruptThread)
       }
     }
 
-    override def run() {
+    override def run(): Unit = {
       val deserializeStartTime = System.currentTimeMillis()
       Thread.currentThread.setContextClassLoader(replClassLoader)
       val ser = env.closureSerializer.newInstance()
       logInfo(s"Running $taskName (TID $taskId)")
       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
       var taskStart: Long = 0
-      startGCTime = gcTime
+      startGCTime = computeTotalGcTime()
 
       try {
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
@@ -193,7 +204,6 @@ private[spark] class Executor(
           throw new TaskKilledException
         }
 
-        attemptedTask = Some(task)
         logDebug("Task " + taskId + "'s epoch is " + task.epoch)
         env.mapOutputTracker.updateEpoch(task.epoch)
 
@@ -215,18 +225,17 @@ private[spark] class Executor(
         for (m <- task.metrics) {
           m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
           m.setExecutorRunTime(taskFinish - taskStart)
-          m.setJvmGCTime(gcTime - startGCTime)
+          m.setJvmGCTime(computeTotalGcTime() - startGCTime)
           m.setResultSerializationTime(afterSerialization - beforeSerialization)
         }
 
         val accumUpdates = Accumulators.values
-
         val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
         val serializedDirectResult = ser.serialize(directResult)
         val resultSize = serializedDirectResult.limit
 
         // directSend = sending directly back to the driver
-        val serializedResult = {
+        val serializedResult: ByteBuffer = {
           if (maxResultSize > 0 && resultSize > maxResultSize) {
             logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
               s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
@@ -248,42 +257,40 @@ private[spark] class Executor(
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
 
       } catch {
-        case ffe: FetchFailedException => {
+        case ffe: FetchFailedException =>
           val reason = ffe.toTaskEndReason
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
-        }
 
-        case _: TaskKilledException | _: InterruptedException if task.killed => {
+        case _: TaskKilledException | _: InterruptedException if task.killed =>
           logInfo(s"Executor killed $taskName (TID $taskId)")
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
-        }
 
-        case cDE: CommitDeniedException => {
+        case cDE: CommitDeniedException =>
           val reason = cDE.toTaskEndReason
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
-        }
 
-        case t: Throwable => {
+        case t: Throwable =>
           // Attempt to exit cleanly by informing the driver of our failure.
           // If anything goes wrong (or this was a fatal exception), we will delegate to
           // the default uncaught exception handler, which will terminate the Executor.
           logError(s"Exception in $taskName (TID $taskId)", t)
 
-          val serviceTime = System.currentTimeMillis() - taskStart
-          val metrics = attemptedTask.flatMap(t => t.metrics)
-          for (m <- metrics) {
-            m.setExecutorRunTime(serviceTime)
-            m.setJvmGCTime(gcTime - startGCTime)
+          val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
+            task.metrics.map { m =>
+              m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
+              m.setJvmGCTime(computeTotalGcTime() - startGCTime)
+              m
+            }
           }
-          val reason = new ExceptionFailure(t, metrics)
-          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+          val taskEndReason = new ExceptionFailure(t, metrics)
+          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))
 
           // Don't forcibly exit unless the exception was inherently fatal, to avoid
           // stopping other tasks unnecessarily.
           if (Utils.isFatalError(t)) {
             SparkUncaughtExceptionHandler.uncaughtException(t)
           }
-        }
 
       } finally {
         // Release memory used by this thread for shuffles
         env.shuffleMemoryManager.releaseMemoryForThisThread()
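The rewritten catch block drops the attemptedTask field entirely: because task starts as null and is assigned exactly once, Option(task) recovers the same information on demand. A minimal sketch of that idiom:

```scala
// Sketch of the Option(...)-over-nullable-field idiom used above.
// Option(x) is Some(x) when x is non-null and None otherwise, so a
// @volatile var that starts as null needs no parallel Option field.
class Holder {
  @volatile var value: String = _
}

object OptionDemo extends App {
  val h = new Holder
  println(Option(h.value).map(_.length)) // None: not yet assigned
  h.value = "ready"
  println(Option(h.value).map(_.length)) // Some(5)
}
```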
@@ -358,7 +365,7 @@ private[spark] class Executor(
     for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
       logInfo("Fetching " + name + " with timestamp " + timestamp)
       // Fetch file with useCache mode, close cache for local mode.
-      Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+      Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
         env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
       currentFiles(name) = timestamp
     }
@@ -370,12 +377,12 @@ private[spark] class Executor(
       if (currentTimeStamp < timestamp) {
         logInfo("Fetching " + name + " with timestamp " + timestamp)
         // Fetch file with useCache mode, close cache for local mode.
-        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
           env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
         currentJars(name) = timestamp
         // Add it to our class loader
-        val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
-        if (!urlClassLoader.getURLs.contains(url)) {
+        val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
+        if (!urlClassLoader.getURLs().contains(url)) {
           logInfo("Adding " + url + " to class loader")
           urlClassLoader.addURL(url)
         }
@@ -384,61 +391,70 @@ private[spark] class Executor(
     }
   }
 
-  def startDriverHeartbeater() {
-    val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
-    val timeout = AkkaUtils.lookupTimeout(conf)
-    val retryAttempts = AkkaUtils.numRetries(conf)
-    val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
-    val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+  private val timeout = AkkaUtils.lookupTimeout(conf)
+  private val retryAttempts = AkkaUtils.numRetries(conf)
+  private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
+  private val heartbeatReceiverRef =
+    AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+
+  /** Reports heartbeat and metrics for active tasks to the driver. */
+  private def reportHeartBeat(): Unit = {
+    // list of (task id, metrics) to send back to the driver
+    val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+    val curGCTime = computeTotalGcTime()
+
+    for (taskRunner <- runningTasks.values()) {
+      if (taskRunner.task != null) {
+        taskRunner.task.metrics.foreach { metrics =>
+          metrics.updateShuffleReadMetrics()
+          metrics.updateInputMetrics()
+          metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+
+          if (isLocal) {
+            // JobProgressListener will hold an reference of it during
+            // onExecutorMetricsUpdate(), then JobProgressListener can not see
+            // the changes of metrics any more, so make a deep copy of it

> Contributor: Could be worth fixing the grammar in this comment.

+            val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
+            tasksMetrics += ((taskRunner.taskId, copiedMetrics))
+          } else {
+            // It will be copied by serialization
+            tasksMetrics += ((taskRunner.taskId, metrics))
+          }
+        }
+      }
+    }
 
-    val t = new Thread() {
+    val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
+    try {
+      val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
+        retryAttempts, retryIntervalMs, timeout)
+      if (response.reregisterBlockManager) {
+        logWarning("Told to re-register on heartbeat")
+        env.blockManager.reregister()
+      }
+    } catch {
+      case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e)
+    }
+  }
+
+  /**
+   * Starts a thread to report heartbeat and partial metrics for active tasks to driver.
+   * This thread stops running when the executor is stopped.
+   */
+  private def startDriverHeartbeater(): Unit = {
+    val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
+    val thread = new Thread() {
       override def run() {
         // Sleep a random interval so the heartbeats don't end up in sync
         Thread.sleep(interval + (math.random * interval).asInstanceOf[Int])
 
         while (!isStopped) {
-          val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
-          val curGCTime = gcTime
-
-          for (taskRunner <- runningTasks.values()) {
-            if (taskRunner.attemptedTask.nonEmpty) {
-              Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
-                metrics.updateShuffleReadMetrics()
-                metrics.updateInputMetrics()
-                metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
-
-                if (isLocal) {
-                  // JobProgressListener will hold an reference of it during
-                  // onExecutorMetricsUpdate(), then JobProgressListener can not see
-                  // the changes of metrics any more, so make a deep copy of it
-                  val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
-                  tasksMetrics += ((taskRunner.taskId, copiedMetrics))
-                } else {
-                  // It will be copied by serialization
-                  tasksMetrics += ((taskRunner.taskId, metrics))
-                }
-              }
-            }
-          }
-
-          val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
-          try {
-            val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
-              retryAttempts, retryIntervalMs, timeout)
-            if (response.reregisterBlockManager) {
-              logWarning("Told to re-register on heartbeat")
-              env.blockManager.reregister()
-            }
-          } catch {
-            case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t)
-          }
-
+          reportHeartBeat()
           Thread.sleep(interval)
         }
       }
     }
-    t.setDaemon(true)
-    t.setName("Driver Heartbeater")
-    t.start()
+    thread.setDaemon(true)
+    thread.setName("driver-heartbeater")
+    thread.start()
   }
 }
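Taken together, the heartbeater refactor leaves a small, testable shape: the per-beat work lives in one method, and the loop runs on a named daemon thread. A standalone sketch of that structure, with the interval hard-coded in place of spark.executor.heartbeatInterval:

```scala
// Standalone sketch of the heartbeater structure after the refactor: the
// loop body is a single reportHeartBeat() call, and the thread is a named
// daemon so it never keeps the JVM alive.
object HeartbeatDemo {
  @volatile private var isStopped = false

  private def reportHeartBeat(): Unit =
    println(s"heartbeat at ${System.currentTimeMillis()}")

  private def startDriverHeartbeater(): Unit = {
    val interval = 1000 // stands in for spark.executor.heartbeatInterval
    val thread = new Thread() {
      override def run(): Unit = {
        // Random initial offset so multiple heartbeaters do not sync up.
        Thread.sleep(interval + (math.random * interval).asInstanceOf[Int])
        while (!isStopped) {
          reportHeartBeat()
          Thread.sleep(interval)
        }
      }
    }
    thread.setDaemon(true)
    thread.setName("driver-heartbeater")
    thread.start()
  }

  def main(args: Array[String]): Unit = {
    startDriverHeartbeater()
    Thread.sleep(3500) // observe a few beats, then stop
    isStopped = true
  }
}
```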