Some refactoring #6

Closed · wants to merge 3 commits
115 changes: 31 additions & 84 deletions core/src/main/scala/org/apache/spark/ResourceUtils.scala
@@ -36,11 +36,14 @@ import org.apache.spark.util.Utils.executeAndGetOutput
*/
private[spark] case class ResourceID(componentName: String, resourceName: String) {
def confPrefix: String = s"$componentName.resource.$resourceName." // with ending dot
def amountConf: String = s"$confPrefix${ResourceUtils.AMOUNT}"
def discoveryScriptConf: String = s"$confPrefix${ResourceUtils.DISCOVERY_SCRIPT}"
def vendorConf: String = s"$confPrefix${ResourceUtils.VENDOR}"
}

private[spark] case class ResourceRequest(
id: ResourceID,
count: Int,
amount: Int,
discoveryScript: Option[String],
vendor: Option[String])
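
For reference, the new ResourceID helpers centralize construction of the fully qualified config keys. A quick illustration, not part of the patch, assuming SPARK_DRIVER_PREFIX resolves to "spark.driver":

    ResourceID("spark.driver", "gpu").amountConf           // "spark.driver.resource.gpu.amount"
    ResourceID("spark.driver", "gpu").discoveryScriptConf  // "spark.driver.resource.gpu.discoveryScript"
    ResourceID("spark.driver", "gpu").vendorConf           // "spark.driver.resource.gpu.vendor"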

@@ -65,16 +68,21 @@ private[spark] object ResourceUtils extends Logging {
val DISCOVERY_SCRIPT = "discoveryScript"
val VENDOR = "vendor"
// user facing configs use .amount to allow to extend in the future,
// internally we currnetly only support addresses, so its just an integer count
// internally we currently only support addresses, so its just an integer count
val AMOUNT = "amount"

// case class to make extracting the JSON resource information easy
case class JsonResourceInformation(name: String, addresses: Seq[String])
private case class JsonResourceInformation(name: String, addresses: Seq[String]) {
def toResourceInformation: ResourceInformation = {
new ResourceInformation(name, addresses.toArray)
}
}

def parseResourceRequest(sparkConf: SparkConf, resourceId: ResourceID): ResourceRequest = {
val settings = sparkConf.getAllWithPrefix(resourceId.confPrefix).toMap
val amount = settings.get(AMOUNT).getOrElse(
throw new SparkException(s"You must specify an amount for ${resourceId.resourceName}")).toInt
val amount = settings.getOrElse(AMOUNT,
throw new SparkException(s"You must specify an amount for ${resourceId.resourceName}")
).toInt
val discoveryScript = settings.get(DISCOVERY_SCRIPT)
val vendor = settings.get(VENDOR)
ResourceRequest(resourceId, amount, discoveryScript, vendor)
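
The old and new lookups behave the same because Map.getOrElse evaluates its default by name, so the throw only fires when the amount key is missing. A minimal standalone illustration, not part of the patch:

    val settings = Map("amount" -> "2")
    settings.getOrElse("amount", throw new IllegalArgumentException("missing")).toInt  // 2
    // Map.empty[String, String].getOrElse("amount", ...) would evaluate the default and throw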
@@ -95,19 +103,12 @@ private[spark] object ResourceUtils extends Logging {
}

def parseTaskResourceRequirements(sparkConf: SparkConf): Seq[TaskResourceRequirement] = {
listResourceIds(sparkConf, SPARK_TASK_PREFIX).map { id =>
val settings = sparkConf.getAllWithPrefix(id.confPrefix).toMap
val amount = settings.get(AMOUNT).getOrElse(
throw new SparkException(s"You must specify an amount for ${id.resourceName}")).toInt
TaskResourceRequirement(id.resourceName, amount)
parseAllResourceRequests(sparkConf, SPARK_TASK_PREFIX).map { request =>
TaskResourceRequirement(request.id.resourceName, request.amount)
}
}

def hasTaskResourceRequirements(sparkConf: SparkConf): Boolean = {
Author comment: I deleted this method because parseTaskResourceRequirements doesn't cost much.

sparkConf.getAllWithPrefix(s"$SPARK_TASK_PREFIX.resource.").nonEmpty
}
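
With the helper removed, callers express the same check directly against the parsed requirements, as the CoarseGrainedExecutorBackend change further down does; roughly:

    if (parseTaskResourceRequirements(env.conf).nonEmpty) {
      // task resources are configured, so discover and validate executor resources
    }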

def parseAllocatedFromJsonFile(resourcesFile: String): Seq[ResourceAllocation] = {
private def parseAllocatedFromJsonFile(resourcesFile: String): Seq[ResourceAllocation] = {
implicit val formats = DefaultFormats
val resourceInput = new BufferedInputStream(new FileInputStream(resourcesFile))
try {
@@ -120,22 +121,21 @@
}
}

def parseResourceInformationFromJson(resourcesJson: String): JsonResourceInformation = {
private def parseResourceInformationFromJson(resourcesJson: String): ResourceInformation = {
implicit val formats = DefaultFormats
try {
parse(resourcesJson).extract[JsonResourceInformation]
parse(resourcesJson).extract[JsonResourceInformation].toResourceInformation
} catch {
case e@(_: MappingException | _: MismatchedInputException | _: ClassCastException) =>
throw new SparkException(s"Exception parsing the resources in $resourcesJson", e)
}
}

def parseAllocatedAndDiscoverResources(
private def parseAllocatedOrDiscoverResources(
sparkConf: SparkConf,
componentName: String,
resourcesFileOpt: Option[String]): Seq[ResourceAllocation] = {
val allocated = resourcesFileOpt.map(parseAllocatedFromJsonFile(_))
.getOrElse(Seq.empty[ResourceAllocation])
val allocated = resourcesFileOpt.toSeq.flatMap(parseAllocatedFromJsonFile)
.filter(_.id.componentName == componentName)
val otherResourceIds = listResourceIds(sparkConf, componentName).diff(allocated.map(_.id))
allocated ++ otherResourceIds.map { id =>
@@ -144,28 +144,32 @@
}
}

def assertResourceAllocationMeetsRequest(
private def assertResourceAllocationMeetsRequest(
allocation: ResourceAllocation,
request: ResourceRequest): Unit = {
require(allocation.id == request.id && allocation.addresses.size >= request.count,
require(allocation.id == request.id && allocation.addresses.size >= request.amount,
s"Resource: ${allocation.id.resourceName}, with addresses: " +
s"${allocation.addresses.mkString(",")} " +
s"is less than what the user requested: ${request.count})")
s"is less than what the user requested: ${request.amount})")
}

def assertAllResourceAllocationsMeetRequests(
private def assertAllResourceAllocationsMeetRequests(
allocations: Seq[ResourceAllocation],
requests: Seq[ResourceRequest]): Unit = {
val allocated = allocations.map(x => x.id -> x).toMap
requests.foreach(r => assertResourceAllocationMeetsRequest(allocated(r.id), r))
}

/**
* Gets all resource information for the input component.
* @return a map from resource name to resource info
*/
def getAllResources(
sparkConf: SparkConf,
componentName: String,
resourcesFileOpt: Option[String]): Map[String, ResourceInformation] = {
val requests = parseAllResourceRequests(sparkConf, componentName)
val allocations = parseAllocatedAndDiscoverResources(sparkConf, componentName, resourcesFileOpt)
val allocations = parseAllocatedOrDiscoverResources(sparkConf, componentName, resourcesFileOpt)
assertAllResourceAllocationsMeetRequests(allocations, requests)
val resourceInfoMap = allocations.map(a => (a.id.resourceName, a.toResourceInfo)).toMap
logInfo("==============================================================")
@@ -175,7 +179,8 @@
resourceInfoMap
}

def discoverResource(resourceRequest: ResourceRequest): JsonResourceInformation = {
// visible for test
def discoverResource(resourceRequest: ResourceRequest): ResourceInformation = {
val resourceName = resourceRequest.id.resourceName
val script = resourceRequest.discoveryScript
val result = if (script.nonEmpty) {
@@ -199,64 +204,6 @@ private[spark] object ResourceUtils extends Logging {
result
}

def resourceAmountConfigName(id: ResourceID): String = s"${id.confPrefix}$AMOUNT"

def resourceDiscoveryScriptConfigName(id: ResourceID): String = {
s"${id.confPrefix}$DISCOVERY_SCRIPT"
}

def resourceVendorConfigName(id: ResourceID): String = s"${id.confPrefix}$VENDOR"

def setResourceAmountConf(conf: SparkConf, id: ResourceID, value: String) {
conf.set(resourceAmountConfigName(id), value)
}

def setResourceDiscoveryScriptConf(conf: SparkConf, id: ResourceID, value: String) {
conf.set(resourceDiscoveryScriptConfigName(id), value)
}

def setResourceVendorConf(conf: SparkConf, id: ResourceID, value: String) {
conf.set(resourceVendorConfigName(id), value)
}

def setDriverResourceAmountConf(conf: SparkConf, resourceName: String, value: String): Unit = {
val resourceId = ResourceID(SPARK_DRIVER_PREFIX, resourceName)
setResourceAmountConf(conf, resourceId, value)
}

def setDriverResourceDiscoveryConf(conf: SparkConf, resourceName: String, value: String): Unit = {
val resourceId = ResourceID(SPARK_DRIVER_PREFIX, resourceName)
setResourceDiscoveryScriptConf(conf, resourceId, value)
}

def setDriverResourceVendorConf(conf: SparkConf, resourceName: String, value: String): Unit = {
val resourceId = ResourceID(SPARK_DRIVER_PREFIX, resourceName)
setResourceVendorConf(conf, resourceId, value)
}

def setExecutorResourceAmountConf(conf: SparkConf, resourceName: String, value: String): Unit = {
val resourceId = ResourceID(SPARK_EXECUTOR_PREFIX, resourceName)
setResourceAmountConf(conf, resourceId, value)
}

def setExecutorResourceDiscoveryConf(
conf: SparkConf,
resourceName: String,
value: String): Unit = {
val resourceId = ResourceID(SPARK_EXECUTOR_PREFIX, resourceName)
setResourceDiscoveryScriptConf(conf, resourceId, value)
}

def setExecutorResourceVendorConf(conf: SparkConf, resourceName: String, value: String): Unit = {
val resourceId = ResourceID(SPARK_EXECUTOR_PREFIX, resourceName)
setResourceVendorConf(conf, resourceId, value)
}

def setTaskResourceAmountConf(conf: SparkConf, resourceName: String, value: String): Unit = {
val resourceId = ResourceID(SPARK_TASK_PREFIX, resourceName)
setResourceAmountConf(conf, resourceId, value)
}

// known types of resources
final val GPU: String = "gpu"
final val FPGA: String = "fpga"
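
The deleted set*Conf helpers become one-liners through ResourceID, which is how the updated tests write these configs; for example (illustrative, mirroring the test changes below):

    conf.set(ResourceID(SPARK_EXECUTOR_PREFIX, GPU).amountConf, "2")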
10 changes: 5 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2691,25 +2691,25 @@ object SparkContext extends Logging {
val taskResourceRequirements = parseTaskResourceRequirements(sc.conf)
val executorResourcesAndCounts =
parseAllResourceRequests(sc.conf, SPARK_EXECUTOR_PREFIX)
.map(request => (request.id.resourceName, request.count)).toMap
.map(request => (request.id.resourceName, request.amount)).toMap
var numSlots = execCores / taskCores
var limitingResourceName = "CPU"

taskResourceRequirements.foreach { taskReq =>
// Make sure the executor resources were specified through config.
val execCount = executorResourcesAndCounts.getOrElse(taskReq.resourceName,
throw new SparkException("The executor resource config: " +
resourceAmountConfigName(ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName)) +
ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName).amountConf +
" needs to be specified since a task requirement config: " +
resourceAmountConfigName(ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName)) +
ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName).amountConf +
" was specified")
)
// Make sure the executor resources are large enough to launch at least one task.
if (execCount < taskReq.count) {
throw new SparkException("The executor resource config: " +
resourceAmountConfigName(ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName)) +
ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName).amountConf +
s" = $execCount has to be >= the task config: " +
resourceAmountConfigName(ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName)) +
ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName).amountConf +
s" = ${taskReq.count}")
}
// Compare and update the max slots each executor can provide.
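
As a hypothetical illustration of this validation: with 8 executor cores, 2 CPUs per task, 4 GPUs per executor, and 1 GPU per task, the check passes (4 >= 1); the CPU-based slot count is 8 / 2 = 4 and the GPU-based count is 4 / 1 = 4, so the executor can run 4 concurrent tasks. If only 2 GPUs were configured per executor, the GPU-based count of 2 would presumably become the limiting value reported in limitingResourceName.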
@@ -89,7 +89,7 @@ private[spark] class CoarseGrainedExecutorBackend(
// visible for testing
def parseOrFindResources(resourcesFileOpt: Option[String]): Map[String, ResourceInformation] = {
// only parse the resources if a task requires them
val resourceInfo = if (hasTaskResourceRequirements(env.conf)) {
val resourceInfo = if (parseTaskResourceRequirements(env.conf).nonEmpty) {
val resources = getAllResources(env.conf, SPARK_EXECUTOR_PREFIX, resourcesFileOpt)
if (resources.isEmpty) {
throw new SparkException("User specified resources per task via: " +
48 changes: 24 additions & 24 deletions core/src/test/scala/org/apache/spark/ResourceUtilsSuite.scala
@@ -21,6 +21,7 @@ import java.io.File
import java.nio.file.{Files => JavaFiles}

import org.apache.spark.ResourceUtils._
import org.apache.spark.TestResourceIDs._
import org.apache.spark.TestUtils._
import org.apache.spark.internal.config._
import org.apache.spark.util.Utils
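
The DRIVER_* and EXECUTOR_* IDs come from a new TestResourceIDs helper that is not shown in this diff; a minimal sketch of what it presumably defines:

    private[spark] object TestResourceIDs {
      val DRIVER_GPU_ID = ResourceID(SPARK_DRIVER_PREFIX, GPU)
      val DRIVER_FPGA_ID = ResourceID(SPARK_DRIVER_PREFIX, FPGA)
      val EXECUTOR_GPU_ID = ResourceID(SPARK_EXECUTOR_PREFIX, GPU)
      val EXECUTOR_FPGA_ID = ResourceID(SPARK_EXECUTOR_PREFIX, FPGA)
    }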
@@ -34,8 +35,8 @@ class ResourceUtilsSuite extends SparkFunSuite
withTempDir { dir =>
val gpuFile = new File(dir, "gpuDiscoverScript")
val scriptPath = writeStringToFileAndSetPermissions(gpuFile, """'{"name": "gpu"}'""")
setExecutorResourceAmountConf(conf, GPU, "2")
setExecutorResourceDiscoveryConf(conf, GPU, scriptPath)
conf.set(EXECUTOR_GPU_ID.amountConf, "2")
conf.set(EXECUTOR_GPU_ID.discoveryScriptConf, scriptPath)

val error = intercept[IllegalArgumentException] {
getAllResources(conf, SPARK_EXECUTOR_PREFIX, None)
@@ -52,14 +53,14 @@ class ResourceUtilsSuite extends SparkFunSuite
val gpuFile = new File(dir, "gpuDiscoverScript")
val gpuDiscovery = writeStringToFileAndSetPermissions(gpuFile,
"""'{"name": "gpu", "addresses": ["0", "1"]}'""")
setExecutorResourceAmountConf(conf, GPU, "2")
setExecutorResourceDiscoveryConf(conf, GPU, gpuDiscovery)
conf.set(EXECUTOR_GPU_ID.amountConf, "2")
conf.set(EXECUTOR_GPU_ID.discoveryScriptConf, gpuDiscovery)

val fpgaFile = new File(dir, "fpgaDiscoverScript")
val fpgaDiscovery = writeStringToFileAndSetPermissions(fpgaFile,
"""'{"name": "fpga", "addresses": ["f1", "f2", "f3"]}'""")
setExecutorResourceAmountConf(conf, FPGA, "2")
setExecutorResourceDiscoveryConf(conf, FPGA, fpgaDiscovery)
conf.set(EXECUTOR_FPGA_ID.amountConf, "2")
conf.set(EXECUTOR_FPGA_ID.discoveryScriptConf, fpgaDiscovery)

val resources = getAllResources(conf, SPARK_EXECUTOR_PREFIX, None)
assert(resources.size === 2)
@@ -80,12 +81,12 @@ class ResourceUtilsSuite extends SparkFunSuite

test("list resource ids") {
val conf = new SparkConf
setDriverResourceAmountConf(conf, GPU, "2")
conf.set(DRIVER_GPU_ID.amountConf, "2")
var resources = listResourceIds(conf, SPARK_DRIVER_PREFIX)
assert(resources.size === 1, "should only have GPU for resource")
assert(resources(0).resourceName == GPU, "name should be gpu")

setDriverResourceAmountConf(conf, FPGA, "2")
conf.set(DRIVER_FPGA_ID.amountConf, "2")
val resourcesMap = listResourceIds(conf, SPARK_DRIVER_PREFIX)
.map{ rId => (rId.resourceName, 1)}.toMap
assert(resourcesMap.size === 2, "should only have GPU for resource")
@@ -95,27 +96,26 @@ class ResourceUtilsSuite extends SparkFunSuite

test("parse resource request") {
val conf = new SparkConf
setDriverResourceAmountConf(conf, GPU, "2")
val gpuResourceID = ResourceID(SPARK_DRIVER_PREFIX, GPU)
var request = parseResourceRequest(conf, gpuResourceID)
conf.set(DRIVER_GPU_ID.amountConf, "2")
var request = parseResourceRequest(conf, DRIVER_GPU_ID)
assert(request.id.resourceName === GPU, "should only have GPU for resource")
assert(request.count === 2, "GPU count should be 2")
assert(request.amount === 2, "GPU count should be 2")
assert(request.discoveryScript === None, "discovery script should be empty")
assert(request.vendor === None, "vendor should be empty")

val vendor = "nvidia.com"
val discoveryScript = "discoveryScriptGPU"
setDriverResourceDiscoveryConf(conf, GPU, discoveryScript)
setDriverResourceVendorConf(conf, GPU, vendor)
request = parseResourceRequest(conf, gpuResourceID)
conf.set(DRIVER_GPU_ID.discoveryScriptConf, discoveryScript)
conf.set(DRIVER_GPU_ID.vendorConf, vendor)
request = parseResourceRequest(conf, DRIVER_GPU_ID)
assert(request.id.resourceName === GPU, "should only have GPU for resource")
assert(request.count === 2, "GPU count should be 2")
assert(request.amount === 2, "GPU count should be 2")
assert(request.discoveryScript.get === discoveryScript, "discovery script should be empty")
assert(request.vendor.get === vendor, "vendor should be empty")

conf.remove(s"${gpuResourceID.confPrefix}$AMOUNT")
conf.remove(DRIVER_GPU_ID.amountConf)
val error = intercept[SparkException] {
request = parseResourceRequest(conf, gpuResourceID)
request = parseResourceRequest(conf, DRIVER_GPU_ID)
}.getMessage()

assert(error.contains("You must specify an amount for gpu"))
@@ -128,8 +128,8 @@ class ResourceUtilsSuite extends SparkFunSuite
val gpuFile = new File(dir, "gpuDiscoverScript")
val gpuDiscovery = writeStringToFileAndSetPermissions(gpuFile,
"""'{"name": "gpu", "addresses": ["0", "1"]}'""")
setDriverResourceAmountConf(conf, GPU, "2")
setDriverResourceDiscoveryConf(conf, GPU, gpuDiscovery)
conf.set(DRIVER_GPU_ID.amountConf, "2")
conf.set(DRIVER_GPU_ID.discoveryScriptConf, gpuDiscovery)

// make sure it reads from correct config, here it should use driver
val resources = getAllResources(conf, SPARK_DRIVER_PREFIX, None)
@@ -150,7 +150,7 @@ class ResourceUtilsSuite extends SparkFunSuite
"""'{"name": "fpga", "addresses": ["0", "1"]}'""")
val request =
ResourceRequest(
ResourceID(SPARK_DRIVER_PREFIX, GPU),
DRIVER_GPU_ID,
2,
Some(gpuDiscovery),
None)
@@ -174,7 +174,7 @@ class ResourceUtilsSuite extends SparkFunSuite

val request =
ResourceRequest(
ResourceID(SPARK_EXECUTOR_PREFIX, GPU),
EXECUTOR_GPU_ID,
2,
Some(gpuDiscovery),
None)
@@ -194,7 +194,7 @@ class ResourceUtilsSuite extends SparkFunSuite
try {
val request =
ResourceRequest(
ResourceID(SPARK_EXECUTOR_PREFIX, GPU),
EXECUTOR_GPU_ID,
2,
Some(file1.getPath()),
None)
@@ -211,7 +211,7 @@ class ResourceUtilsSuite extends SparkFunSuite
}

test("gpu's specified but not a discovery script") {
val request = ResourceRequest(ResourceID(SPARK_EXECUTOR_PREFIX, GPU), 2, None, None)
val request = ResourceRequest(EXECUTOR_GPU_ID, 2, None, None)

val error = intercept[SparkException] {
discoverResource(request)