diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 4a5de8cd5da92..299148a912a8b 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -20,9 +20,9 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer import java.util.Locale +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.mutable import scala.util.{Failure, Success} import scala.util.control.NonFatal @@ -71,9 +71,12 @@ private[spark] class CoarseGrainedExecutorBackend( /** * Map each taskId to the information about the resource allocated to it, Please refer to * [[ResourceInformation]] for specifics. + * ConcurrentHashMap is used for thread-safety (https://issues.apache.org/jira/browse/SPARK-45227) * Exposed for testing only. 
*/ - private[executor] val taskResources = new mutable.HashMap[Long, Map[String, ResourceInformation]] + private[executor] val taskResources = new ConcurrentHashMap[ + Long, Map[String, ResourceInformation] + ] private var decommissioned = false @@ -186,7 +189,7 @@ private[spark] class CoarseGrainedExecutorBackend( } else { val taskDesc = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) - taskResources(taskDesc.taskId) = taskDesc.resources + taskResources.put(taskDesc.taskId, taskDesc.resources) executor.launchTask(this, taskDesc) } @@ -266,7 +269,7 @@ private[spark] class CoarseGrainedExecutorBackend( } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit = { - val resources = taskResources.getOrElse(taskId, Map.empty[String, ResourceInformation]) + val resources = taskResources.getOrDefault(taskId, Map.empty[String, ResourceInformation]) val cpus = executor.runningTasks.get(taskId).taskDescription.cpus val msg = StatusUpdate(executorId, taskId, state, data, cpus, resources) if (TaskState.isFinished(state)) { diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 0dcc7c7f9b4cf..909d605442575 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -302,7 +302,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf)) assert(backend.taskResources.isEmpty) - val taskId = 1000000 + val taskId = 1000000L // We don't really verify the data, just pass it around. 
val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) val taskDescription = new TaskDescription(taskId, 2, "1", "TASK 1000000", 19, @@ -339,14 +339,14 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite backend.self.send(LaunchTask(new SerializableBuffer(serializedTaskDescription))) eventually(timeout(10.seconds)) { assert(backend.taskResources.size == 1) - val resources = backend.taskResources(taskId) + val resources = backend.taskResources.get(taskId) assert(resources(GPU).addresses sameElements Array("0", "1")) } // Update the status of a running task shall not affect `taskResources` map. backend.statusUpdate(taskId, TaskState.RUNNING, data) assert(backend.taskResources.size == 1) - val resources = backend.taskResources(taskId) + val resources = backend.taskResources.get(taskId) assert(resources(GPU).addresses sameElements Array("0", "1")) // Update the status of a finished task shall remove the entry from `taskResources` map.