diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index ec306c1e20456..a34b67db388f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,7 +75,7 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
+  val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
   val taskIdToStageIdAndAttempt = new HashMap[Long, (Int, Int)]
   val taskIdToExecutorId = new HashMap[Long, String]
@@ -163,7 +163,8 @@ private[spark] class TaskSchedulerImpl(
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
       val stage = taskSet.stageId
-      val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+      val stageTaskSets =
+        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
       stageTaskSets(taskSet.stageAttemptId) = manager
       val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
         ts.taskSet != taskSet && !ts.isZombie
@@ -201,7 +202,7 @@ private[spark] class TaskSchedulerImpl(
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    taskSetsByStage.get(stageId).foreach { attempts =>
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
       attempts.foreach { case (_, tsm) =>
         // There are two possible cases here:
         // 1. The task set manager has been created and some tasks have been scheduled.
@@ -225,10 +226,10 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
       taskSetsForStage -= manager.taskSet.stageAttemptId
       if (taskSetsForStage.isEmpty) {
-        taskSetsByStage -= manager.taskSet.stageId
+        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
       }
     }
     manager.parent.removeSchedulable(manager)
@@ -380,7 +381,7 @@ private[spark] class TaskSchedulerImpl(
       taskMetrics.flatMap { case (id, metrics) =>
         for {
           (stageId, stageAttemptId) <- taskIdToStageIdAndAttempt.get(id)
-          attempts <- taskSetsByStage.get(stageId)
+          attempts <- taskSetsByStageIdAndAttempt.get(stageId)
           taskSetMgr <- attempts.get(stageAttemptId)
         } yield {
           (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
@@ -416,10 +417,10 @@ private[spark] class TaskSchedulerImpl(
 
   def error(message: String) {
     synchronized {
-      if (taskSetsByStage.nonEmpty) {
+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
         // Have each task set throw a SparkException with the error
         for {
-          attempts <- taskSetsByStage.values
+          attempts <- taskSetsByStageIdAndAttempt.values
           manager <- attempts.values
         } {
           try {
@@ -552,7 +553,7 @@ private[spark] class TaskSchedulerImpl(
       stageId: Int,
       stageAttemptId: Int): Option[TaskSetManager] = {
     for {
-      attempts <- taskSetsByStage.get(stageId)
+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
       manager <- attempts.get(stageAttemptId)
     } yield {
       manager
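
For context, the renamed map is a two-level structure keyed first by stage ID and then by stage attempt ID, and every hunk above either registers into it or resolves a TaskSetManager through it. Below is a minimal, standalone Scala sketch of that register/lookup pattern; it is not part of the patch, and FakeTaskSetManager plus the StageAttemptMapSketch object are hypothetical stand-ins introduced only for illustration.

import scala.collection.mutable.HashMap

object StageAttemptMapSketch {
  // Hypothetical stand-in for TaskSetManager, used only for illustration.
  case class FakeTaskSetManager(stageId: Int, stageAttemptId: Int)

  // Same nesting as taskSetsByStageIdAndAttempt: stageId -> (stageAttemptId -> manager).
  val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, FakeTaskSetManager]]

  // Mirrors the registration in submitTasks: create the inner map for a stage on
  // first use, then index the manager by its attempt id.
  def register(manager: FakeTaskSetManager): Unit = {
    val stageTaskSets =
      taskSetsByStageIdAndAttempt.getOrElseUpdate(
        manager.stageId, new HashMap[Int, FakeTaskSetManager])
    stageTaskSets(manager.stageAttemptId) = manager
  }

  // Mirrors the for-comprehension lookup used in the patch: yields None if either
  // the stage or the attempt is unknown.
  def managerFor(stageId: Int, stageAttemptId: Int): Option[FakeTaskSetManager] =
    for {
      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
      manager <- attempts.get(stageAttemptId)
    } yield manager

  def main(args: Array[String]): Unit = {
    register(FakeTaskSetManager(stageId = 3, stageAttemptId = 0))
    register(FakeTaskSetManager(stageId = 3, stageAttemptId = 1))
    println(managerFor(3, 1)) // Some(FakeTaskSetManager(3,1))
    println(managerFor(4, 0)) // None
  }
}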