Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-45227][CORE] Fix a subtle thread-safety issue with CoarseGrainedExecutorBackend #43021

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ package org.apache.spark.executor
import java.net.URL
import java.nio.ByteBuffer
import java.util.Locale
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable
import scala.util.{Failure, Success}
import scala.util.control.NonFatal

Expand Down Expand Up @@ -71,9 +71,12 @@ private[spark] class CoarseGrainedExecutorBackend(
/**
* Map each taskId to the information about the resource allocated to it, Please refer to
* [[ResourceInformation]] for specifics.
* CHM is used to ensure thread-safety (https://issues.apache.org/jira/browse/SPARK-45227)
* Exposed for testing only.
*/
private[executor] val taskResources = new mutable.HashMap[Long, Map[String, ResourceInformation]]
private[executor] val taskResources = new ConcurrentHashMap[
Long, Map[String, ResourceInformation]
]

private var decommissioned = false

Expand Down Expand Up @@ -186,7 +189,7 @@ private[spark] class CoarseGrainedExecutorBackend(
} else {
val taskDesc = TaskDescription.decode(data.value)
logInfo("Got assigned task " + taskDesc.taskId)
taskResources(taskDesc.taskId) = taskDesc.resources
taskResources.put(taskDesc.taskId, taskDesc.resources)
executor.launchTask(this, taskDesc)
}

Expand Down Expand Up @@ -266,7 +269,7 @@ private[spark] class CoarseGrainedExecutorBackend(
}

override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit = {
val resources = taskResources.getOrElse(taskId, Map.empty[String, ResourceInformation])
val resources = taskResources.getOrDefault(taskId, Map.empty[String, ResourceInformation])
val cpus = executor.runningTasks.get(taskId).taskDescription.cpus
val msg = StatusUpdate(executorId, taskId, state, data, cpus, resources)
if (TaskState.isFinished(state)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
resourceProfile = ResourceProfile.getOrCreateDefaultProfile(conf))
assert(backend.taskResources.isEmpty)

val taskId = 1000000
val taskId = 1000000L
// We don't really verify the data, just pass it around.
val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
val taskDescription = new TaskDescription(taskId, 2, "1", "TASK 1000000", 19,
Expand Down Expand Up @@ -339,14 +339,14 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite
backend.self.send(LaunchTask(new SerializableBuffer(serializedTaskDescription)))
eventually(timeout(10.seconds)) {
assert(backend.taskResources.size == 1)
val resources = backend.taskResources(taskId)
val resources = backend.taskResources.get(taskId)
assert(resources(GPU).addresses sameElements Array("0", "1"))
}

// Update the status of a running task shall not affect `taskResources` map.
backend.statusUpdate(taskId, TaskState.RUNNING, data)
assert(backend.taskResources.size == 1)
val resources = backend.taskResources(taskId)
val resources = backend.taskResources.get(taskId)
assert(resources(GPU).addresses sameElements Array("0", "1"))

// Update the status of a finished task shall remove the entry from `taskResources` map.
Expand Down