diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 2e1e52906ceeb..936e7dd2d1910 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -29,8 +29,7 @@ import org.apache.spark.util.MemoryParam * Command-line parser for the driver client. */ private[spark] class ClientArguments(args: Array[String]) { - val defaultCores = 1 - val defaultMemory = 512 + import ClientArguments._ var cmd: String = "" // 'launch' or 'kill' var logLevel = Level.WARN @@ -39,9 +38,9 @@ private[spark] class ClientArguments(args: Array[String]) { var master: String = "" var jarUrl: String = "" var mainClass: String = "" - var supervise: Boolean = false - var memory: Int = defaultMemory - var cores: Int = defaultCores + var supervise: Boolean = DEFAULT_SUPERVISE + var memory: Int = DEFAULT_MEMORY + var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() def driverOptions = _driverOptions.toSeq @@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) { |Usage: DriverClient kill | |Options: - | -c CORES, --cores CORES Number of cores to request (default: $defaultCores) - | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory) + | -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES) + | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY) | -s, --supervise Whether to restart the driver on failure + | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin System.err.println(usage) @@ -117,6 +117,10 @@ private[spark] class ClientArguments(args: Array[String]) { } object ClientArguments { + private[spark] val DEFAULT_CORES = 1 + private[spark] val DEFAULT_MEMORY = 512 // MB + private[spark] val DEFAULT_SUPERVISE = false + def isValidJarUrl(s: String): Boolean = { try { val uri = new URI(s) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1c89b2452e830..7c89f0bcaebae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import org.apache.spark.executor.ExecutorURLClassLoader import org.apache.spark.util.Utils +import org.apache.spark.deploy.rest.StandaloneRestClient /** * Main gateway of launching a Spark application. 
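Hoisting the defaults into the ClientArguments companion object gives them a single authoritative home: the new REST server handler added later in this patch imports ClientArguments._, so REST submissions and the legacy DriverClient fall back to the same values. A minimal sketch of the pattern (DefaultsDemo is a hypothetical stand-in for illustration, not part of the patch):

    // Hypothetical stand-in for ClientArguments: callers share one set of
    // defaults via the companion object instead of hard-coding their own.
    class DefaultsDemo(requestedMemory: Option[Int]) {
      import DefaultsDemo._
      val memory: Int = requestedMemory.getOrElse(DEFAULT_MEMORY) // 512 MB fallback
    }

    object DefaultsDemo {
      val DEFAULT_CORES = 1
      val DEFAULT_MEMORY = 512 // MB
      val DEFAULT_SUPERVISE = false
    }
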
@@ -72,6 +73,16 @@ object SparkSubmit {
     if (appArgs.verbose) {
       printStream.println(appArgs)
     }
+
+    // In standalone cluster mode, use the new REST client to submit the application
+    val doingRest = appArgs.master.startsWith("spark://") && appArgs.deployMode == "cluster"
+    if (doingRest) {
+      printStream.println("Submitting driver through the REST interface.")
+      new StandaloneRestClient().submitDriver(appArgs)
+      printStream.println("Done submitting driver.")
+      return
+    }
+
     val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
     launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
   }
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 47059b08a397f..310b34a926338 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -104,6 +104,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
       .orElse(sparkProperties.get("spark.master"))
       .orElse(env.get("MASTER"))
       .orNull
+    driverExtraClassPath = Option(driverExtraClassPath)
+      .orElse(sparkProperties.get("spark.driver.extraClassPath"))
+      .orNull
+    driverExtraJavaOptions = Option(driverExtraJavaOptions)
+      .orElse(sparkProperties.get("spark.driver.extraJavaOptions"))
+      .orNull
+    driverExtraLibraryPath = Option(driverExtraLibraryPath)
+      .orElse(sparkProperties.get("spark.driver.extraLibraryPath"))
+      .orNull
     driverMemory = Option(driverMemory)
       .orElse(sparkProperties.get("spark.driver.memory"))
       .orElse(env.get("SPARK_DRIVER_MEMORY"))
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 4b631ec639071..24c08373ade4c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -43,6 +43,7 @@ import org.apache.spark.deploy.history.HistoryServer
 import org.apache.spark.deploy.master.DriverState.DriverState
 import org.apache.spark.deploy.master.MasterMessages._
 import org.apache.spark.deploy.master.ui.MasterWebUI
+import org.apache.spark.deploy.rest.StandaloneRestServer
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
 import org.apache.spark.ui.SparkUI
@@ -121,6 +122,8 @@ private[spark] class Master(
     throw new SparkException("spark.deploy.defaultCores must be positive")
   }
 
+  val restServer = new StandaloneRestServer(this, host, 6677)
+
   override def preStart() {
     logInfo("Starting Spark master at " + masterUrl)
     logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala
index 0a1b5efc2488d..69b9a4f4bdabb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala
@@ -27,8 +27,8 @@ private[spark] object KillDriverResponseField extends StandaloneRestProtocolFiel
   case object MESSAGE extends KillDriverResponseField
   case object MASTER extends KillDriverResponseField
   case object DRIVER_ID extends KillDriverResponseField
-  case object DRIVER_STATE extends SubmitDriverResponseField
-  override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER,
-    DRIVER_ID, DRIVER_STATE)
+  case object SUCCESS extends KillDriverResponseField
+  override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, SUCCESS)
   override val optionalFields = Seq.empty
 }
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
index c805d75968242..6059344d93b6f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
@@ -27,6 +27,7 @@ import com.google.common.base.Charsets
 
 import org.apache.spark.{SPARK_VERSION => sparkVersion}
 import org.apache.spark.deploy.SparkSubmitArguments
+import org.apache.spark.util.Utils
 
 /**
  * A client that submits Spark applications using a stable REST protocol in standalone
@@ -63,6 +64,12 @@ private[spark] class StandaloneRestClient {
    */
   private def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage = {
     import SubmitDriverRequestField._
+    val driverMemory = Option(args.driverMemory)
+      .map { m => Utils.memoryStringToMb(m).toString }
+      .orNull
+    val executorMemory = Option(args.executorMemory)
+      .map { m => Utils.memoryStringToMb(m).toString }
+      .orNull
     val message = new SubmitDriverRequestMessage()
       .setField(SPARK_VERSION, sparkVersion)
       .setField(MASTER, args.master)
@@ -72,19 +79,21 @@ private[spark] class StandaloneRestClient {
       .setFieldIfNotNull(JARS, args.jars)
       .setFieldIfNotNull(FILES, args.files)
       .setFieldIfNotNull(PY_FILES, args.pyFiles)
-      .setFieldIfNotNull(DRIVER_MEMORY, args.driverMemory)
+      .setFieldIfNotNull(DRIVER_MEMORY, driverMemory)
       .setFieldIfNotNull(DRIVER_CORES, args.driverCores)
       .setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions)
       .setFieldIfNotNull(DRIVER_EXTRA_CLASS_PATH, args.driverExtraClassPath)
       .setFieldIfNotNull(DRIVER_EXTRA_LIBRARY_PATH, args.driverExtraLibraryPath)
       .setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString)
-      .setFieldIfNotNull(EXECUTOR_MEMORY, args.executorMemory)
+      .setFieldIfNotNull(EXECUTOR_MEMORY, executorMemory)
       .setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores)
-    // Set each Spark property as its own field
-    // TODO: Include environment variables?
+    args.childArgs.zipWithIndex.foreach { case (arg, i) =>
+      message.setFieldIfNotNull(APP_ARG(i), arg)
+    }
     args.sparkProperties.foreach { case (k, v) =>
       message.setFieldIfNotNull(SPARK_PROPERTY(k), v)
     }
+    // TODO: set environment variables?
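For reference: the DRIVER_MEMORY and EXECUTOR_MEMORY fields above carry a plain megabyte count, so the client normalizes JVM-style strings such as "512m" or "2g" before they go on the wire. A rough, self-contained sketch of the conversion (an approximation of Utils.memoryStringToMb for illustration, not the exact Spark implementation):

    // Approximates Utils.memoryStringToMb: "512m" -> 512, "2g" -> 2048.
    // Assumption: a bare number is interpreted as bytes, matching the JVM -Xmx style.
    def memoryStringToMb(str: String): Int = {
      val lower = str.toLowerCase
      if (lower.endsWith("k")) (lower.dropRight(1).toLong / 1024).toInt
      else if (lower.endsWith("m")) lower.dropRight(1).toInt
      else if (lower.endsWith("g")) lower.dropRight(1).toInt * 1024
      else if (lower.endsWith("t")) lower.dropRight(1).toInt * 1024 * 1024
      else (lower.toLong / 1024 / 1024).toInt // no suffix: treat as bytes
    }

Sending a canonical number means the server never has to re-parse memory suffixes.
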
     message.validate()
   }
@@ -175,8 +184,8 @@ private[spark] class StandaloneRestClient {
 object StandaloneRestClient {
   def main(args: Array[String]): Unit = {
     assert(args.length > 0)
-    val client = new StandaloneRestClient
-    client.killDriver("spark://" + args(0), "abc_driver")
+    //val client = new StandaloneRestClient
+    //client.submitDriver("spark://" + args(0))
     println("Done.")
   }
 }
\ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala
index c8ea8fd395c6e..7945271a870f3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala
@@ -63,19 +63,28 @@ private[spark] abstract class StandaloneRestProtocolMessage(
 
   import StandaloneRestProtocolField._
 
-  private val fields = new mutable.HashMap[StandaloneRestProtocolField, String]
   private val className = Utils.getFormattedClassName(this)
+  protected val fields = new mutable.HashMap[StandaloneRestProtocolField, String]
 
   // Set the action field
   fields(actionField) = action.toString
 
+  /** Return all fields currently set in this message. */
+  def getFields: Map[StandaloneRestProtocolField, String] = fields.toMap
+
+  /** Return the value of the given field. If the field is not present, return null. */
+  def getField(key: StandaloneRestProtocolField): String = getFieldOption(key).orNull
+
   /** Return the value of the given field. If the field is not present, throw an exception. */
-  def getField(key: StandaloneRestProtocolField): String = {
-    fields.get(key).getOrElse {
+  def getFieldNotNull(key: StandaloneRestProtocolField): String = {
+    getFieldOption(key).getOrElse {
       throw new IllegalArgumentException(s"Field $key is not set in message $className")
     }
   }
 
+  /** Return the value of the given field as an option. */
+  def getFieldOption(key: StandaloneRestProtocolField): Option[String] = fields.get(key)
+
   /** Assign the given value to the field, overriding any existing value. */
   def setField(key: StandaloneRestProtocolField, value: String): this.type = {
     if (key == actionField) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
index f4294b64a6530..344a3ef89a4d7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.deploy.rest
 
 import java.io.DataOutputStream
+import java.net.InetSocketAddress
 import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
 
 import scala.io.Source
@@ -26,25 +27,37 @@ import com.google.common.base.Charsets
 import org.eclipse.jetty.server.{Request, Server}
 import org.eclipse.jetty.server.handler.AbstractHandler
 
-import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion}
-import org.apache.spark.deploy.rest.StandaloneRestProtocolAction._
-import org.apache.spark.util.Utils
+import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging}
+import org.apache.spark.deploy.master.Master
+import org.apache.spark.util.{AkkaUtils, Utils}
 
 /**
  * A server that responds to requests submitted by the StandaloneRestClient.
*/ -private[spark] class StandaloneRestServer(requestedPort: Int) { - val server = new Server(requestedPort) - server.setHandler(new StandaloneRestHandler) +private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) { + val server = new Server(new InetSocketAddress(host, requestedPort)) + server.setHandler(new StandaloneRestServerHandler(master)) server.start() - server.join() } /** * A Jetty handler that responds to requests submitted via the standalone REST protocol. */ -private[spark] class StandaloneRestHandler extends AbstractHandler with Logging { +private[spark] abstract class StandaloneRestHandler(master: Master) + extends AbstractHandler with Logging { + private implicit val askTimeout = AkkaUtils.askTimeout(master.conf) + + /** Handle a request to submit a driver. */ + protected def handleSubmit(request: SubmitDriverRequestMessage): SubmitDriverResponseMessage + /** Handle a request to kill a driver. */ + protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage + /** Handle a request for a driver's status. */ + protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage + + /** + * Handle a request submitted by the StandaloneRestClient. + */ override def handle( target: String, baseRequest: Request, @@ -67,6 +80,10 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging } } + /** + * Construct the appropriate response message based on the type of the request message. + * If an IllegalArgumentException is thrown in the process, construct an error message. + */ private def constructResponseMessage( request: StandaloneRestProtocolMessage): StandaloneRestProtocolMessage = { // If the request is sent via the StandaloneRestClient, it should have already been @@ -74,67 +91,21 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging // against potential NPEs. If validation fails, return an ERROR message to the sender. try { request.validate() + request match { + case submit: SubmitDriverRequestMessage => handleSubmit(submit) + case kill: KillDriverRequestMessage => handleKill(kill) + case status: DriverStatusRequestMessage => handleStatus(status) + case unexpected => handleError( + s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + } } catch { - case e: IllegalArgumentException => - return handleError(e.getMessage) - } - request match { - case submit: SubmitDriverRequestMessage => handleSubmitRequest(submit) - case kill: KillDriverRequestMessage => handleKillRequest(kill) - case status: DriverStatusRequestMessage => handleStatusRequest(status) - case unexpected => handleError( - s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + // Propagate exception to user in an ErrorMessage. If the construction of the + // ErrorMessage itself throws an exception, log the exception and ignore the request. + case e: IllegalArgumentException => handleError(e.getMessage) } } - private def handleSubmitRequest( - request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = { - import SubmitDriverResponseField._ - // TODO: Actually submit the driver - val message = "Driver is submitted successfully..." 
-    val master = request.getField(SubmitDriverRequestField.MASTER)
-    val driverId = "new_driver_id"
-    val driverState = "SUBMITTED"
-    new SubmitDriverResponseMessage()
-      .setField(SPARK_VERSION, sparkVersion)
-      .setField(MESSAGE, message)
-      .setField(MASTER, master)
-      .setField(DRIVER_ID, driverId)
-      .setField(DRIVER_STATE, driverState)
-      .validate()
-  }
-
-  private def handleKillRequest(request: KillDriverRequestMessage): KillDriverResponseMessage = {
-    import KillDriverResponseField._
-    // TODO: Actually kill the driver
-    val message = "Driver is killed successfully..."
-    val master = request.getField(KillDriverRequestField.MASTER)
-    val driverId = request.getField(KillDriverRequestField.DRIVER_ID)
-    val driverState = "KILLED"
-    new KillDriverResponseMessage()
-      .setField(SPARK_VERSION, sparkVersion)
-      .setField(MESSAGE, message)
-      .setField(MASTER, master)
-      .setField(DRIVER_ID, driverId)
-      .setField(DRIVER_STATE, driverState)
-      .validate()
-  }
-
-  private def handleStatusRequest(
-      request: DriverStatusRequestMessage): DriverStatusResponseMessage = {
-    import DriverStatusResponseField._
-    // TODO: Actually look up the status of the driver
-    val master = request.getField(DriverStatusRequestField.MASTER)
-    val driverId = request.getField(DriverStatusRequestField.DRIVER_ID)
-    val driverState = "HEALTHY"
-    new DriverStatusResponseMessage()
-      .setField(SPARK_VERSION, sparkVersion)
-      .setField(MASTER, master)
-      .setField(DRIVER_ID, driverId)
-      .setField(DRIVER_STATE, driverState)
-      .validate()
-  }
-
+  /** Construct an error message to signal that an exception was thrown. */
   private def handleError(message: String): ErrorMessage = {
     import ErrorField._
     new ErrorMessage()
@@ -144,10 +115,10 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging
   }
 }
 
-object StandaloneRestServer {
-  def main(args: Array[String]): Unit = {
-    println("Hey boy I'm starting a server.")
-    new StandaloneRestServer(6677)
-    readLine()
-  }
-}
\ No newline at end of file
+//object StandaloneRestServer {
+//  def main(args: Array[String]): Unit = {
+//    println("Hey boy I'm starting a server.")
+//    new StandaloneRestServer(6677)
+//    readLine()
+//  }
+//}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala
new file mode 100644
index 0000000000000..e11698e51bf19
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SPARK_VERSION => sparkVersion}
+import org.apache.spark.SparkConf
+import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.deploy.{Command, DriverDescription}
+import org.apache.spark.deploy.ClientArguments._
+import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.master.Master
+
+/**
+ * A handler for requests submitted to the standalone Master via the REST protocol.
+ * This handler constructs responses by messaging the Master actor directly.
+ */
+private[spark] class StandaloneRestServerHandler(master: Master)
+  extends StandaloneRestHandler(master) {
+
+  private implicit val askTimeout = AkkaUtils.askTimeout(master.conf)
+
+  override protected def handleSubmit(
+      request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = {
+    import SubmitDriverResponseField._
+    val driverDescription = buildDriverDescription(request)
+    val response = AkkaUtils.askWithReply[SubmitDriverResponse](
+      RequestSubmitDriver(driverDescription), master.self, askTimeout)
+    new SubmitDriverResponseMessage()
+      .setField(SPARK_VERSION, sparkVersion)
+      .setField(MESSAGE, response.message)
+      .setField(MASTER, master.masterUrl)
+      .setField(SUCCESS, response.success.toString)
+      .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull)
+      .validate()
+  }
+
+  override protected def handleKill(
+      request: KillDriverRequestMessage): KillDriverResponseMessage = {
+    import KillDriverResponseField._
+    val driverId = request.getFieldNotNull(KillDriverRequestField.DRIVER_ID)
+    val response = AkkaUtils.askWithReply[KillDriverResponse](
+      RequestKillDriver(driverId), master.self, askTimeout)
+    new KillDriverResponseMessage()
+      .setField(SPARK_VERSION, sparkVersion)
+      .setField(MESSAGE, response.message)
+      .setField(MASTER, master.masterUrl)
+      .setField(DRIVER_ID, driverId)
+      .setField(SUCCESS, response.success.toString)
+      .validate()
+  }
+
+  override protected def handleStatus(
+      request: DriverStatusRequestMessage): DriverStatusResponseMessage = {
+    import DriverStatusResponseField._
+    // TODO: Actually look up the status of the driver
+    val master = request.getField(DriverStatusRequestField.MASTER)
+    val driverId = request.getField(DriverStatusRequestField.DRIVER_ID)
+    val driverState = "HEALTHY"
+    new DriverStatusResponseMessage()
+      .setField(SPARK_VERSION, sparkVersion)
+      .setField(MASTER, master)
+      .setField(DRIVER_ID, driverId)
+      .setField(DRIVER_STATE, driverState)
+      .validate()
+  }
+
+  private def buildDriverDescription(request: SubmitDriverRequestMessage): DriverDescription = {
+    import SubmitDriverRequestField._
+
+    // Required fields
+    //val _master = request.getFieldNotNull(MASTER)
+    val appName = request.getFieldNotNull(APP_NAME)
+    val appResource = request.getFieldNotNull(APP_RESOURCE)
+
+    // Since standalone cluster mode does not yet support python,
+    // we treat the main class as required
+    val mainClass = request.getFieldNotNull(MAIN_CLASS)
+
+    // Optional fields
+    val jars = request.getFieldOption(JARS)
+    val files = request.getFieldOption(FILES)
+    val driverMemory = request.getFieldOption(DRIVER_MEMORY)
+    val driverCores = request.getFieldOption(DRIVER_CORES)
+    val driverExtraJavaOptions = request.getFieldOption(DRIVER_EXTRA_JAVA_OPTIONS)
+    val driverExtraClassPath = request.getFieldOption(DRIVER_EXTRA_CLASS_PATH)
+    val driverExtraLibraryPath = request.getFieldOption(DRIVER_EXTRA_LIBRARY_PATH)
+    val superviseDriver = request.getFieldOption(SUPERVISE_DRIVER)
+    val executorMemory = request.getFieldOption(EXECUTOR_MEMORY)
+    val totalExecutorCores = request.getFieldOption(TOTAL_EXECUTOR_CORES)
+
+    // Parse special fields that take arguments
+    val conf = new SparkConf(false)
+    val env = new mutable.HashMap[String, String]
+    val appArgs = new ArrayBuffer[(Int, String)]
+    request.getFields.foreach { case (k, v) =>
+      println(s"> Found this field: $k = $v")
+      k match {
+        case APP_ARG(index) => appArgs += ((index, v))
+        case SPARK_PROPERTY(propKey) => conf.set(propKey, v)
+        case ENVIRONMENT_VARIABLE(envKey) => env(envKey) = v
+        case _ =>
+      }
+    }
+
+    // Use the actual master URL instead of the one that refers to this REST server
+    // Otherwise, once the driver is launched it will contact the wrong server
+    conf.set("spark.master", master.masterUrl)
+    conf.set("spark.app.name", appName)
+    conf.set("spark.jars", jars.map(_ + ",").getOrElse("") + appResource) // include app resource
+    files.foreach { f => conf.set("spark.files", f) }
+    driverExtraJavaOptions.foreach { j => conf.set("spark.driver.extraJavaOptions", j) }
+    driverExtraClassPath.foreach { cp => conf.set("spark.driver.extraClassPath", cp) }
+    driverExtraLibraryPath.foreach { lp => conf.set("spark.driver.extraLibraryPath", lp) }
+    executorMemory.foreach { m => conf.set("spark.executor.memory", m) }
+    totalExecutorCores.foreach { c => conf.set("spark.cores.max", c) }
+
+    // Construct driver description and submit it
+    val actualDriverMemory = driverMemory.map(_.toInt).getOrElse(DEFAULT_MEMORY)
+    val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES)
+    val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE)
+    val actualAppArgs = appArgs.sortBy(_._1).map(_._2) // sort by index, map to value
+    val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator))
+    val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator))
+    val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty)
+    val sparkJavaOpts = Utils.sparkJavaOpts(conf)
+    val javaOpts = sparkJavaOpts ++ extraJavaOpts
+    val command = new Command(
+      "org.apache.spark.deploy.worker.DriverWrapper",
+      Seq("{{WORKER_URL}}", mainClass) ++ actualAppArgs,
+      env, extraClassPath, extraLibraryPath, javaOpts)
+    new DriverDescription(
+      appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala
index fd95ecb1aefbe..72f92f2c0d49a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.deploy.rest
 
+import scala.util.matching.Regex
+
 import org.apache.spark.util.Utils
 
 /**
@@ -39,9 +41,12 @@ private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFie
   case object DRIVER_EXTRA_JAVA_OPTIONS extends SubmitDriverRequestField
   case object DRIVER_EXTRA_CLASS_PATH extends SubmitDriverRequestField
   case object DRIVER_EXTRA_LIBRARY_PATH extends SubmitDriverRequestField
-  case object SUPERVISE_DRIVER extends SubmitDriverRequestField
+  case object SUPERVISE_DRIVER extends SubmitDriverRequestField // standalone cluster mode only
   case object EXECUTOR_MEMORY extends SubmitDriverRequestField
   case object TOTAL_EXECUTOR_CORES extends SubmitDriverRequestField
+  case class APP_ARG(index: Int) extends SubmitDriverRequestField {
+    override def toString: String = Utils.getFormattedClassName(this) + "_" + index
+  }
   case class SPARK_PROPERTY(prop: String) extends SubmitDriverRequestField {
     override def toString: String = Utils.getFormattedClassName(this) + "_" + prop
   }
@@ -52,6 +57,22 @@ private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFie
   override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, DRIVER_MEMORY,
     DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH,
     SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES)
+
+  // Because certain fields take arguments, we cannot simply rely on the
+  // list of all fields to reconstruct a field from its String representation.
+  // Instead, we must treat these fields as special cases and match on their prefixes.
+  override def withName(field: String): StandaloneRestProtocolField = {
+    def buildRegex(obj: AnyRef): Regex = s"${Utils.getFormattedClassName(obj)}_(.*)".r
+    val appArg = buildRegex(APP_ARG)
+    val sparkProperty = buildRegex(SPARK_PROPERTY)
+    val environmentVariable = buildRegex(ENVIRONMENT_VARIABLE)
+    field match {
+      case appArg(f) => APP_ARG(f.toInt)
+      case sparkProperty(f) => SPARK_PROPERTY(f)
+      case environmentVariable(f) => ENVIRONMENT_VARIABLE(f)
+      case _ => super.withName(field)
+    }
+  }
 }
 
 /**
@@ -60,7 +81,19 @@ private[spark] object SubmitDriverRequestFie
 private[spark] class SubmitDriverRequestMessage extends StandaloneRestProtocolMessage(
   StandaloneRestProtocolAction.SUBMIT_DRIVER_REQUEST,
   SubmitDriverRequestField.ACTION,
-  SubmitDriverRequestField.requiredFields)
+  SubmitDriverRequestField.requiredFields) {
+
+  // Ensure that the app arg indices form a contiguous range starting from 0
+  override def validate(): this.type = {
+    import SubmitDriverRequestField._
+    val indices = fields.collect { case (a: APP_ARG, _) => a }.toSeq.sortBy(_.index).map(_.index)
+    val expectedIndices = (0 until indices.size).toSeq
+    if (indices != expectedIndices) {
+      throw new IllegalArgumentException(s"Malformed app arg indices: ${indices.mkString(",")}")
+    }
+    super.validate()
+  }
+}
 
 private[spark] object SubmitDriverRequestMessage extends StandaloneRestProtocolMessageCompanion {
   protected override def newMessage(): StandaloneRestProtocolMessage =
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala
index 034c7f80de234..e656c35ad9657 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala
@@ -26,10 +26,10 @@ private[spark] object SubmitDriverResponseField extends StandaloneRestProtocolFi
   case object SPARK_VERSION extends SubmitDriverResponseField
   case object MESSAGE extends SubmitDriverResponseField
   case object MASTER extends SubmitDriverResponseField
+  case object SUCCESS extends SubmitDriverResponseField
   case object DRIVER_ID extends SubmitDriverResponseField
-  case object DRIVER_STATE extends SubmitDriverResponseField
-  override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, DRIVER_STATE)
-  override val optionalFields = Seq.empty
+  override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, SUCCESS)
+  override val optionalFields = Seq(DRIVER_ID)
 }
 
 /**
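To make the parameterized-field machinery concrete: an APP_ARG field serializes its index into the field name, and withName re-parses that name with a prefix regex. A condensed, standalone sketch of the round trip (names simplified for illustration; the real code derives the prefix via Utils.getFormattedClassName):

    import scala.util.matching.Regex

    object FieldNameRoundTrip {
      sealed trait Field
      case class AppArg(index: Int) extends Field {
        override def toString: String = "APP_ARG_" + index // e.g. "APP_ARG_0"
      }

      private val appArg: Regex = "APP_ARG_(.*)".r

      // Rebuild a parameterized field from its serialized name
      def withName(field: String): Field = field match {
        case appArg(i) => AppArg(i.toInt)
        case _ => throw new IllegalArgumentException(s"Unknown field: $field")
      }

      def main(args: Array[String]): Unit = {
        assert(withName(AppArg(3).toString) == AppArg(3)) // round-trips
      }
    }

Because each application argument travels as its own APP_ARG field, the validate() override above can detect a dropped argument simply by checking that the indices form the range 0 until n.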