Skip to content

Commit

Permalink
Addressed code review feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin committed Aug 19, 2014
1 parent 4e5faa2 commit 6c08b07
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 40 deletions.
72 changes: 42 additions & 30 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ class DAGScheduler(
}

private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
val stageInfo = stageIdToStage(task.stageId).info
val stageInfo = stageIdToStage(task.stageId).latestInfo
listenerBus.post(SparkListenerTaskStart(task.stageId, stageInfo.attemptId, taskInfo))
submitWaitingStages()
}
Expand All @@ -696,8 +696,8 @@ class DAGScheduler(
// is in the process of getting stopped.
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
runningStages.foreach { stage =>
stage.info.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(stage.info))
stage.latestInfo.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
}
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
}
Expand Down Expand Up @@ -782,7 +782,16 @@ class DAGScheduler(
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
var tasks = ArrayBuffer[Task[_]]()

// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = {
if (stage.isShuffleMap) {
(0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil)
} else {
val job = stage.resultOfJob.get
(0 until job.numPartitions).filter(id => !job.finished(id))
}
}

val properties = if (jobIdToActiveJob.contains(jobId)) {
jobIdToActiveJob(stage.jobId).properties
Expand All @@ -796,7 +805,8 @@ class DAGScheduler(
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
// Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
Expand Down Expand Up @@ -827,25 +837,22 @@ class DAGScheduler(
return
}

if (stage.isShuffleMap) {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
val part = stage.rdd.partitions(p)
tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
val tasks: Seq[Task[_]] = if (stage.isShuffleMap) {
partitionsToCompute.map { id =>
val locs = getPreferredLocs(stage.rdd, id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, taskBinary, part, locs)
}
} else {
// This is a final stage; figure out its job's missing partitions
val job = stage.resultOfJob.get
for (id <- 0 until job.numPartitions if !job.finished(id)) {
partitionsToCompute.map { id =>
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
new ResultTask(stage.id, taskBinary, part, locs, id)
}
}

stage.info = StageInfo.fromStage(stage, Some(tasks.size))

if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
// exception here because it would be fairly hard to catch the non-serializable exception
Expand All @@ -872,11 +879,11 @@ class DAGScheduler(
logDebug("New pending tasks: " + stage.pendingTasks)
taskScheduler.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
stage.info.submissionTime = Some(clock.getTime())
stage.latestInfo.submissionTime = Some(clock.getTime())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should post
// SparkListenerStageCompleted here in case there are no tasks to run.
listenerBus.post(SparkListenerStageCompleted(stage.info))
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
runningStages -= stage
Expand All @@ -890,7 +897,7 @@ class DAGScheduler(
private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stageId = task.stageId
val stageInfo = stageIdToStage(task.stageId).info
val stageInfo = stageIdToStage(task.stageId).latestInfo
val taskType = Utils.getFormattedClassName(task)

// The success case is dealt with separately below, since we need to compute accumulator
Expand All @@ -906,14 +913,19 @@ class DAGScheduler(
}
val stage = stageIdToStage(task.stageId)

def markStageAsFinished(stage: Stage) = {
val serviceTime = stage.info.submissionTime match {
def markStageAsFinished(stage: Stage, isSuccessful: Boolean) = {
val serviceTime = stage.latestInfo.submissionTime match {
case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stage.info.completionTime = Some(clock.getTime())
listenerBus.post(SparkListenerStageCompleted(stage.info))
if (isSuccessful) {
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
} else {

logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
}
stage.latestInfo.completionTime = Some(clock.getTime())
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
runningStages -= stage
}
event.reason match {
Expand All @@ -928,7 +940,7 @@ class DAGScheduler(
val name = acc.name.get
val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
val stringValue = Accumulators.stringifyValue(acc.value)
stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue)
stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
event.taskInfo.accumulables +=
AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
}
Expand All @@ -951,7 +963,7 @@ class DAGScheduler(
job.numFinished += 1
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
markStageAsFinished(stage)
markStageAsFinished(stage, isSuccessful = true)
cleanupStateForJobAndIndependentStages(job)
listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
}
Expand Down Expand Up @@ -980,7 +992,7 @@ class DAGScheduler(
stage.addOutputLoc(smt.partitionId, status)
}
if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
markStageAsFinished(stage)
markStageAsFinished(stage, isSuccessful = true)
logInfo("looking for newly runnable stages")
logInfo("running: " + runningStages)
logInfo("waiting: " + waitingStages)
Expand Down Expand Up @@ -1033,7 +1045,7 @@ class DAGScheduler(
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
// Mark the stage that the reducer was in as unrunnable
val failedStage = stageIdToStage(task.stageId)
listenerBus.post(SparkListenerStageCompleted(failedStage.info))
markStageAsFinished(failedStage, isSuccessful = false)
runningStages -= failedStage
// TODO: Cancel running tasks in the stage
logInfo("Marking " + failedStage + " (" + failedStage.name +
Expand Down Expand Up @@ -1147,7 +1159,7 @@ class DAGScheduler(
}
val dependentJobs: Seq[ActiveJob] =
activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
failedStage.info.completionTime = Some(clock.getTime())
failedStage.latestInfo.completionTime = Some(clock.getTime())
for (job <- dependentJobs) {
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
}
Expand Down Expand Up @@ -1187,8 +1199,8 @@ class DAGScheduler(
if (runningStages.contains(stage)) {
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
stage.info.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stage.info))
stage.latestInfo.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
} catch {
case e: UnsupportedOperationException =>
logInfo(s"Could not cancel tasks for stage $stageId", e)
Expand Down
7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ import org.apache.spark.util.CallSite
* stage, the callSite gives the user code that created the RDD being shuffled. For a result
* stage, the callSite gives the user code that executes the associated action (e.g. count()).
*
* A single stage can consist of multiple attempts. In that case, the latestInfo field will
* be updated for each attempt.
*
*/
private[spark] class Stage(
val id: Int,
Expand Down Expand Up @@ -71,8 +74,8 @@ private[spark] class Stage(
val name = callSite.shortForm
val details = callSite.longForm

/** Pointer to the [StageInfo] object, set by DAGScheduler. */
var info: StageInfo = StageInfo.fromStage(this)
/** Pointer to the latest [StageInfo] object, set by DAGScheduler. */
var latestInfo: StageInfo = StageInfo.fromStage(this)

def isAvailable: Boolean = {
if (!isShuffleMap) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
// Map from stageId to StageInfo
val activeStages = new HashMap[Int, StageInfo]

// Map from (stageId, attemptId) to StageInfo
// Map from (stageId, attemptId) to StageUIData
val stageIdToData = new HashMap[(Int, Int), StageUIData]

val completedStages = ListBuffer[StageInfo]()
val failedStages = ListBuffer[StageInfo]()

val poolToActiveStages = HashMap[String, HashMap[(Int, Int), StageInfo]]()
// Map from pool name to a hash map (map from stage id to StageInfo).
val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()

val executorIdToBlockManagerId = HashMap[String, BlockManagerId]()

Expand All @@ -72,7 +73,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
}

poolToActiveStages.get(stageData.schedulingPool).foreach { hashMap =>
hashMap.remove((stage.stageId, stage.attemptId))
hashMap.remove(stage.stageId)
}
activeStages.remove(stage.stageId)
if (stage.failureReason.isEmpty) {
Expand Down Expand Up @@ -109,8 +110,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}

val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[(Int, Int), StageInfo]())
stages((stage.stageId, stage.attemptId)) = stage
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo])
stages(stage.stageId) = stage
}

override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) {
}

private def poolTable(
makeRow: (Schedulable, HashMap[String, HashMap[(Int, Int), StageInfo]]) => Seq[Node],
makeRow: (Schedulable, HashMap[String, HashMap[Int, StageInfo]]) => Seq[Node],
rows: Seq[Schedulable]): Seq[Node] = {
<table class="table table-bordered table-striped table-condensed sortable table-fixed">
<thead>
Expand All @@ -53,7 +53,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) {

private def poolRow(
p: Schedulable,
poolToActiveStages: HashMap[String, HashMap[(Int, Int), StageInfo]]): Seq[Node] = {
poolToActiveStages: HashMap[String, HashMap[Int, StageInfo]]): Seq[Node] = {
val activeStages = poolToActiveStages.get(p.name) match {
case Some(stages) => stages.size
case None => 0
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
<h4>Summary Metrics</h4> No tasks have started yet
<h4>Tasks</h4> No tasks have started yet
</div>
return UIUtils.headerSparkPage("Details for Stage %s".format(stageId), content, parent)
return UIUtils.headerSparkPage(
s"Details for Stage $stageId (Attempt $stageAttemptId)", content, parent)
}

val stageData = stageDataOption.get
Expand Down

0 comments on commit 6c08b07

Please sign in to comment.