From b44e103b78b36fadd887dc4b894027a03069b1f7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 21 Jan 2015 17:19:33 -0800 Subject: [PATCH] Implement status requests + fix validation behavior This commit makes the StandaloneRestServer actually handle status requests. The existing polling behavior from o.a.s.deploy.Client is also implemented in the StandaloneRestClient and amended. Additionally, the validation behavior was confusing before this commit. Previously the error message would seem to indicate that the user constructed a malformed message even if the message was constructed on the server side. This commit ensures that the error message is different for these two cases. --- .../org/apache/spark/deploy/SparkSubmit.scala | 3 +- .../rest/DriverStatusRequestMessage.scala | 2 +- .../rest/DriverStatusResponseMessage.scala | 9 +-- .../spark/deploy/rest/ErrorMessage.scala | 2 +- .../rest/KillDriverRequestMessage.scala | 2 +- .../rest/KillDriverResponseMessage.scala | 2 +- .../deploy/rest/StandaloneRestClient.scala | 61 ++++++++++++++++++- .../deploy/rest/StandaloneRestServer.scala | 21 ++++--- .../rest/SubmitDriverRequestMessage.scala | 2 +- .../rest/SubmitDriverResponseMessage.scala | 2 +- .../spark/deploy/rest/SubmitRestClient.scala | 46 ++++++++------ .../spark/deploy/rest/SubmitRestServer.scala | 37 ++++++----- 12 files changed, 130 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1e3c7c2f1bb18..30b982822dbaf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -49,8 +49,7 @@ object SparkSubmit { private val STANDALONE = 2 private val MESOS = 4 private val LOCAL = 8 - private val REST = 16 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | REST + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL // Deploy modes private val CLIENT = 1 diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala index 57f79554151e2..d435687606e3f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala @@ -44,5 +44,5 @@ private[spark] class DriverStatusRequestMessage extends SubmitRestProtocolMessag private[spark] object DriverStatusRequestMessage extends SubmitRestProtocolMessageCompanion[DriverStatusRequestMessage] { protected override def newMessage() = new DriverStatusRequestMessage - protected override def fieldFromString(field: String) = DriverStatusRequestField.fromString(field) + protected override def fieldFromString(f: String) = DriverStatusRequestField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala index 42c64dc601758..a0264643890a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -28,12 +28,13 @@ private[spark] object DriverStatusResponseField case object MESSAGE extends DriverStatusResponseField case object MASTER extends DriverStatusResponseField case object DRIVER_ID extends DriverStatusResponseField + case object SUCCESS extends DriverStatusResponseField + // Standalone specific fields case object DRIVER_STATE extends DriverStatusResponseField case object WORKER_ID extends DriverStatusResponseField case object WORKER_HOST_PORT extends DriverStatusResponseField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, - MASTER, DRIVER_ID, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) - override val optionalFields = Seq.empty + override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, DRIVER_ID, SUCCESS) + override val optionalFields = Seq(MESSAGE, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) } /** @@ -48,5 +49,5 @@ private[spark] class DriverStatusResponseMessage extends SubmitRestProtocolMessa private[spark] object DriverStatusResponseMessage extends SubmitRestProtocolMessageCompanion[DriverStatusResponseMessage] { protected override def newMessage() = new DriverStatusResponseMessage - protected override def fieldFromString(field: String) = DriverStatusResponseField.fromString(field) + protected override def fieldFromString(f: String) = DriverStatusResponseField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala index 04a298d98a349..aefd7b60d32af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala @@ -39,5 +39,5 @@ private[spark] class ErrorMessage extends SubmitRestProtocolMessage( private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion[ErrorMessage] { protected override def newMessage() = new ErrorMessage - protected override def fieldFromString(field: String) = ErrorField.fromString(field) + protected override def fieldFromString(f: String) = ErrorField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala index 3245058ce4ba7..3353bfba5a690 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -44,5 +44,5 @@ private[spark] class KillDriverRequestMessage extends SubmitRestProtocolMessage( private[spark] object KillDriverRequestMessage extends SubmitRestProtocolMessageCompanion[KillDriverRequestMessage] { protected override def newMessage() = new KillDriverRequestMessage - protected override def fieldFromString(field: String) = KillDriverRequestField.fromString(field) + protected override def fieldFromString(f: String) = KillDriverRequestField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index 92db6cfa2d640..974fcb9936fcb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -45,5 +45,5 @@ private[spark] class KillDriverResponseMessage extends SubmitRestProtocolMessage private[spark] object KillDriverResponseMessage extends SubmitRestProtocolMessageCompanion[KillDriverResponseMessage] { protected override def newMessage() = new KillDriverResponseMessage - protected override def fieldFromString(field: String) = KillDriverResponseField.fromString(field) + protected override def fieldFromString(f: String) = KillDriverResponseField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 03eaa93f0d33e..43164ae3a4c88 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -28,6 +28,58 @@ import org.apache.spark.util.Utils * This client is intended to communicate with the StandaloneRestServer. Cluster mode only. */ private[spark] class StandaloneRestClient extends SubmitRestClient { + import StandaloneRestClient._ + + /** + * Request that the REST server submit a driver specified by the provided arguments. + * + * If the driver was successfully submitted, this polls the status of the driver that was + * just submitted and reports it to the user. Otherwise, if the submission was unsuccessful, + * this reports failure and logs an error message provided by the REST server. + */ + override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponseMessage = { + import SubmitDriverResponseField._ + val submitResponse = super.submitDriver(args).asInstanceOf[SubmitDriverResponseMessage] + val submitSuccess = submitResponse.getFieldNotNull(SUCCESS).toBoolean + if (submitSuccess) { + val driverId = submitResponse.getFieldNotNull(DRIVER_ID) + logInfo(s"Driver successfully submitted as $driverId. Polling driver state...") + pollSubmittedDriverStatus(args.master, driverId) + } else { + val submitMessage = submitResponse.getFieldNotNull(MESSAGE) + logError(s"Application submission failed: $submitMessage") + } + submitResponse + } + + /** + * Poll the status of the driver that was just submitted and report it. + * This retries up to a fixed number of times until giving up. + */ + private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { + import DriverStatusResponseField._ + (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => + val statusResponse = requestDriverStatus(master, driverId) + .asInstanceOf[DriverStatusResponseMessage] + val statusSuccess = statusResponse.getFieldNotNull(SUCCESS).toBoolean + if (statusSuccess) { + val driverState = statusResponse.getFieldNotNull(DRIVER_STATE) + val workerId = statusResponse.getFieldOption(WORKER_ID) + val workerHostPort = statusResponse.getFieldOption(WORKER_HOST_PORT) + val exception = statusResponse.getFieldOption(MESSAGE) + logInfo(s"State of driver $driverId is now $driverState.") + // Log worker node, if present + (workerId, workerHostPort) match { + case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") + case _ => + } + // Log exception stack trace, if present + exception.foreach { e => logError(e) } + return + } + } + logError(s"Error: Master did not recognize driver $driverId.") + } /** Construct a submit driver request message. */ override protected def constructSubmitRequest( @@ -54,7 +106,7 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { args.childArgs.foreach(message.appendAppArg) args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } // TODO: send special environment variables? - message.validate() + message } /** Construct a kill driver request message. */ @@ -66,7 +118,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setField(SPARK_VERSION, sparkVersion) .setField(MASTER, master) .setField(DRIVER_ID, driverId) - .validate() } /** Construct a driver status request message. */ @@ -78,7 +129,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setField(SPARK_VERSION, sparkVersion) .setField(MASTER, master) .setField(DRIVER_ID, driverId) - .validate() } /** Throw an exception if this is not standalone mode. */ @@ -101,3 +151,8 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { new URL("http://" + master.stripPrefix("spark://")) } } + +private object StandaloneRestClient { + val REPORT_DRIVER_STATUS_INTERVAL = 1000 + val REPORT_DRIVER_STATUS_MAX_TRIES = 10 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 7916029517ccc..563ee1c251442 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -67,7 +67,6 @@ private[spark] class StandaloneRestServerHandler( .setField(MASTER, masterUrl) .setField(SUCCESS, response.success.toString) .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull) - .validate() } /** Handle a request to kill a driver. */ @@ -83,23 +82,29 @@ private[spark] class StandaloneRestServerHandler( .setField(MASTER, masterUrl) .setField(DRIVER_ID, driverId) .setField(SUCCESS, response.success.toString) - .validate() } /** Handle a request for a driver's status. */ override protected def handleStatus( request: DriverStatusRequestMessage): DriverStatusResponseMessage = { import DriverStatusResponseField._ - // TODO: Actually look up the status of the driver - val master = request.getField(DriverStatusRequestField.MASTER) val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) - val driverState = "HEALTHY" + val response = AkkaUtils.askWithReply[DriverStatusResponse]( + RequestDriverStatus(driverId), masterActor, askTimeout) + // Format exception nicely, if it exists + val message = response.exception.map { e => + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"Exception from the cluster:\n$e\n$stackTraceString" + } new DriverStatusResponseMessage() .setField(SPARK_VERSION, sparkVersion) - .setField(MASTER, master) + .setField(MASTER, masterUrl) .setField(DRIVER_ID, driverId) - .setField(DRIVER_STATE, driverState) - .validate() + .setField(SUCCESS, response.found.toString) + .setFieldIfNotNull(DRIVER_STATE, response.state.map(_.toString).orNull) + .setFieldIfNotNull(WORKER_ID, response.workerId.orNull) + .setFieldIfNotNull(WORKER_HOST_PORT, response.workerHostPort.orNull) + .setFieldIfNotNull(MESSAGE, message.orNull) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index 47f97b4fdc77f..1ce867febcf9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -105,7 +105,7 @@ private[spark] object SubmitDriverRequestMessage import SubmitDriverRequestField._ protected override def newMessage() = new SubmitDriverRequestMessage - protected override def fieldFromString(field: String) = SubmitDriverRequestField.fromString(field) + protected override def fieldFromString(f: String) = SubmitDriverRequestField.fromString(f) /** * Process the given field and value appropriately based on the type of the field. diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index 70670fd6c9c78..4551707660377 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -45,5 +45,5 @@ private[spark] class SubmitDriverResponseMessage extends SubmitRestProtocolMessa private[spark] object SubmitDriverResponseMessage extends SubmitRestProtocolMessageCompanion[SubmitDriverResponseMessage] { protected override def newMessage() = new SubmitDriverResponseMessage - protected override def fieldFromString(field: String) = SubmitDriverResponseField.fromString(field) + protected override def fieldFromString(f: String) = SubmitDriverResponseField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index b3e0d9e02fabd..513c17deee89c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -17,14 +17,14 @@ package org.apache.spark.deploy.rest -import java.io.DataOutputStream +import java.io.{DataOutputStream, FileNotFoundException} import java.net.{HttpURLConnection, URL} import scala.io.Source import com.google.common.base.Charsets -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkException} import org.apache.spark.deploy.SparkSubmitArguments /** @@ -33,8 +33,8 @@ import org.apache.spark.deploy.SparkSubmitArguments */ private[spark] abstract class SubmitRestClient extends Logging { - /** Request that the REST server submits a driver specified by the provided arguments. */ - def submitDriver(args: SparkSubmitArguments): Unit = { + /** Request that the REST server submit a driver specified by the provided arguments. */ + def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolMessage = { validateSubmitArguments(args) val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) @@ -42,8 +42,8 @@ private[spark] abstract class SubmitRestClient extends Logging { sendHttp(url, request) } - /** Request that the REST server kills the specified driver. */ - def killDriver(master: String, driverId: String): Unit = { + /** Request that the REST server kill the specified driver. */ + def killDriver(master: String, driverId: String): SubmitRestProtocolMessage = { validateMaster(master) val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) @@ -52,7 +52,7 @@ private[spark] abstract class SubmitRestClient extends Logging { } /** Request the status of the specified driver from the REST server. */ - def requestDriverStatus(master: String, driverId: String): Unit = { + def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolMessage = { validateMaster(master) val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) @@ -82,18 +82,24 @@ private[spark] abstract class SubmitRestClient extends Logging { * Return the response received from the REST server. */ private def sendHttp(url: URL, request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { - val conn = url.openConnection().asInstanceOf[HttpURLConnection] - conn.setRequestMethod("POST") - conn.setRequestProperty("Content-Type", "application/json") - conn.setRequestProperty("charset", "utf-8") - conn.setDoOutput(true) - val requestJson = request.toJson - logDebug(s"Sending the following request to the REST server:\n$requestJson") - val out = new DataOutputStream(conn.getOutputStream) - out.write(requestJson.getBytes(Charsets.UTF_8)) - out.close() - val responseJson = Source.fromInputStream(conn.getInputStream).mkString - logDebug(s"Response from the REST server:\n$responseJson") - SubmitRestProtocolMessage.fromJson(responseJson) + try { + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + request.validate() + val requestJson = request.toJson + logDebug(s"Sending the following request to the REST server:\n$requestJson") + val out = new DataOutputStream(conn.getOutputStream) + out.write(requestJson.getBytes(Charsets.UTF_8)) + out.close() + val responseJson = Source.fromInputStream(conn.getInputStream).mkString + logDebug(s"Response from the REST server:\n$responseJson") + SubmitRestProtocolMessage.fromJson(responseJson) + } catch { + case e: FileNotFoundException => + throw new SparkException(s"Unable to connect to REST server $url", e) + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 980d6089b6760..c659dfddbf6ab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -100,23 +100,29 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi */ private def constructResponseMessage( request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { - // If the request is sent via the SubmitRestClient, it should have already been validated - // remotely. In case this is not true, validate the request here again to guard against - // potential NPEs. If validation fails, send an error message back to the sender. - try { - request.validate() - request match { - case submit: SubmitDriverRequestMessage => handleSubmit(submit) - case kill: KillDriverRequestMessage => handleKill(kill) - case status: DriverStatusRequestMessage => handleStatus(status) - case unexpected => handleError( - s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + // Validate the request message to ensure that it is correctly constructed. If the request + // is sent via the SubmitRestClient, it should have already been validated remotely. In case + // this is not true, do it again here to guard against potential NPEs. If validation fails, + // send an error message back to the sender. + val response = + try { + request.validate() + request match { + case submit: SubmitDriverRequestMessage => handleSubmit(submit) + case kill: KillDriverRequestMessage => handleKill(kill) + case status: DriverStatusRequestMessage => handleStatus(status) + case unexpected => handleError( + s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + } + } catch { + case e: IllegalArgumentException => handleError(e.getMessage) } + // Validate the response message to ensure that it is correctly constructed. If it is not, + // propagate the exception back to the client and signal that it is a server error. + try { + response.validate() } catch { - // Propagate exception to user in an ErrorMessage. - // Note that the construction of the error message itself may throw an exception. - // In this case, let the higher level caller take care of this request. - case e: IllegalArgumentException => handleError(e.getMessage) + case e: IllegalArgumentException => handleError(s"Internal server error: ${e.getMessage}") } } @@ -126,6 +132,5 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi new ErrorMessage() .setField(SPARK_VERSION, sparkVersion) .setField(MESSAGE, message) - .validate() } }