Skip to content

Commit

Permalink
Implement status requests + fix validation behavior
Browse files Browse the repository at this point in the history
This commit makes the StandaloneRestServer actually handle status
requests. The existing polling behavior from org.apache.spark.deploy.Client
is also ported to the StandaloneRestClient, with amendments.

Additionally, the validation behavior was confusing before this
commit. Previously the error message would seem to indicate that
the user constructed a malformed message even if the message was
constructed on the server side. This commit ensures that the error
message is different for these two cases.
  • Loading branch information
Andrew Or committed Jan 22, 2015
1 parent 120ab9d commit b44e103
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ object SparkSubmit {
private val STANDALONE = 2
private val MESOS = 4
private val LOCAL = 8
private val REST = 16
private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | REST
private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL

// Deploy modes
private val CLIENT = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ private[spark] class DriverStatusRequestMessage extends SubmitRestProtocolMessag
private[spark] object DriverStatusRequestMessage
extends SubmitRestProtocolMessageCompanion[DriverStatusRequestMessage] {
protected override def newMessage() = new DriverStatusRequestMessage
protected override def fieldFromString(field: String) = DriverStatusRequestField.fromString(field)
protected override def fieldFromString(f: String) = DriverStatusRequestField.fromString(f)
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ private[spark] object DriverStatusResponseField
case object MESSAGE extends DriverStatusResponseField
case object MASTER extends DriverStatusResponseField
case object DRIVER_ID extends DriverStatusResponseField
case object SUCCESS extends DriverStatusResponseField
// Standalone specific fields
case object DRIVER_STATE extends DriverStatusResponseField
case object WORKER_ID extends DriverStatusResponseField
case object WORKER_HOST_PORT extends DriverStatusResponseField
override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE,
MASTER, DRIVER_ID, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT)
override val optionalFields = Seq.empty
override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, DRIVER_ID, SUCCESS)
override val optionalFields = Seq(MESSAGE, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT)
}

/**
Expand All @@ -48,5 +49,5 @@ private[spark] class DriverStatusResponseMessage extends SubmitRestProtocolMessa
private[spark] object DriverStatusResponseMessage
extends SubmitRestProtocolMessageCompanion[DriverStatusResponseMessage] {
protected override def newMessage() = new DriverStatusResponseMessage
protected override def fieldFromString(field: String) = DriverStatusResponseField.fromString(field)
protected override def fieldFromString(f: String) = DriverStatusResponseField.fromString(f)
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ private[spark] class ErrorMessage extends SubmitRestProtocolMessage(

private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion[ErrorMessage] {
protected override def newMessage() = new ErrorMessage
protected override def fieldFromString(field: String) = ErrorField.fromString(field)
protected override def fieldFromString(f: String) = ErrorField.fromString(f)
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ private[spark] class KillDriverRequestMessage extends SubmitRestProtocolMessage(
private[spark] object KillDriverRequestMessage
extends SubmitRestProtocolMessageCompanion[KillDriverRequestMessage] {
protected override def newMessage() = new KillDriverRequestMessage
protected override def fieldFromString(field: String) = KillDriverRequestField.fromString(field)
protected override def fieldFromString(f: String) = KillDriverRequestField.fromString(f)
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ private[spark] class KillDriverResponseMessage extends SubmitRestProtocolMessage
private[spark] object KillDriverResponseMessage
extends SubmitRestProtocolMessageCompanion[KillDriverResponseMessage] {
protected override def newMessage() = new KillDriverResponseMessage
protected override def fieldFromString(field: String) = KillDriverResponseField.fromString(field)
protected override def fieldFromString(f: String) = KillDriverResponseField.fromString(f)
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,58 @@ import org.apache.spark.util.Utils
* This client is intended to communicate with the StandaloneRestServer. Cluster mode only.
*/
private[spark] class StandaloneRestClient extends SubmitRestClient {
import StandaloneRestClient._

/**
 * Ask the REST server to submit the driver described by the given arguments.
 *
 * When the response reports success, the status of the newly submitted driver is polled
 * and reported to the user. When it reports failure, the message supplied in the response
 * is logged as an error. The response is returned to the caller in both cases.
 */
override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponseMessage = {
  import SubmitDriverResponseField._
  val response = super.submitDriver(args).asInstanceOf[SubmitDriverResponseMessage]
  val failed = !response.getFieldNotNull(SUCCESS).toBoolean
  if (failed) {
    // Failure path: surface the message carried in the response.
    val submitMessage = response.getFieldNotNull(MESSAGE)
    logError(s"Application submission failed: $submitMessage")
  } else {
    // Success path: report the driver ID, then follow up with status polling.
    val driverId = response.getFieldNotNull(DRIVER_ID)
    logInfo(s"Driver successfully submitted as $driverId. Polling driver state...")
    pollSubmittedDriverStatus(args.master, driverId)
  }
  response
}

/**
 * Poll the status of the driver that was just submitted and report it.
 *
 * The status is requested up to REPORT_DRIVER_STATUS_MAX_TRIES times, pausing for
 * REPORT_DRIVER_STATUS_INTERVAL milliseconds between consecutive attempts. Polling stops at
 * the first response whose SUCCESS field is true; the driver state, the worker it runs on
 * (if present) and any message from the response (if present) are then logged. If no attempt
 * succeeds, an error is logged instead.
 *
 * @param master master identifier forwarded to requestDriverStatus
 * @param driverId ID of the driver whose status is polled
 */
private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = {
  import DriverStatusResponseField._
  // `exists` stops iterating at the first attempt that yields true, so no non-local
  // `return` from inside the closure is needed to end the polling early.
  val reported = (1 to REPORT_DRIVER_STATUS_MAX_TRIES).exists { attempt =>
    val statusResponse = requestDriverStatus(master, driverId)
      .asInstanceOf[DriverStatusResponseMessage]
    val statusSuccess = statusResponse.getFieldNotNull(SUCCESS).toBoolean
    if (statusSuccess) {
      val driverState = statusResponse.getFieldNotNull(DRIVER_STATE)
      val workerId = statusResponse.getFieldOption(WORKER_ID)
      val workerHostPort = statusResponse.getFieldOption(WORKER_HOST_PORT)
      val exception = statusResponse.getFieldOption(MESSAGE)
      logInfo(s"State of driver $driverId is now $driverState.")
      // Log worker node, if present
      for (id <- workerId; hp <- workerHostPort) {
        logInfo(s"Driver is running on worker $id at $hp.")
      }
      // Log exception stack trace, if present
      exception.foreach { e => logError(e) }
    } else if (attempt < REPORT_DRIVER_STATUS_MAX_TRIES) {
      // Wait before asking again; without this pause every attempt fires back to back
      // and the configured interval is never honored. No wait after the final attempt.
      Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL)
    }
    statusSuccess
  }
  if (!reported) {
    logError(s"Error: Master did not recognize driver $driverId.")
  }
}

/** Construct a submit driver request message. */
override protected def constructSubmitRequest(
Expand All @@ -54,7 +106,7 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
args.childArgs.foreach(message.appendAppArg)
args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) }
// TODO: send special environment variables?
message.validate()
message
}

/** Construct a kill driver request message. */
Expand All @@ -66,7 +118,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
.setField(SPARK_VERSION, sparkVersion)
.setField(MASTER, master)
.setField(DRIVER_ID, driverId)
.validate()
}

/** Construct a driver status request message. */
Expand All @@ -78,7 +129,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
.setField(SPARK_VERSION, sparkVersion)
.setField(MASTER, master)
.setField(DRIVER_ID, driverId)
.validate()
}

/** Throw an exception if this is not standalone mode. */
Expand All @@ -101,3 +151,8 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
new URL("http://" + master.stripPrefix("spark://"))
}
}

/** Constants governing how the client reports the status of a submitted driver. */
private object StandaloneRestClient {
  // Maximum number of driver status requests made before the client gives up.
  val REPORT_DRIVER_STATUS_MAX_TRIES: Int = 10
  // Interval between driver status reports; presumably milliseconds — TODO confirm with its consumer.
  val REPORT_DRIVER_STATUS_INTERVAL: Int = 1000
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ private[spark] class StandaloneRestServerHandler(
.setField(MASTER, masterUrl)
.setField(SUCCESS, response.success.toString)
.setFieldIfNotNull(DRIVER_ID, response.driverId.orNull)
.validate()
}

/** Handle a request to kill a driver. */
Expand All @@ -83,23 +82,29 @@ private[spark] class StandaloneRestServerHandler(
.setField(MASTER, masterUrl)
.setField(DRIVER_ID, driverId)
.setField(SUCCESS, response.success.toString)
.validate()
}

/** Handle a request for a driver's status. */
override protected def handleStatus(
request: DriverStatusRequestMessage): DriverStatusResponseMessage = {
import DriverStatusResponseField._
// TODO: Actually look up the status of the driver
val master = request.getField(DriverStatusRequestField.MASTER)
val driverId = request.getField(DriverStatusRequestField.DRIVER_ID)
val driverState = "HEALTHY"
val response = AkkaUtils.askWithReply[DriverStatusResponse](
RequestDriverStatus(driverId), masterActor, askTimeout)
// Format exception nicely, if it exists
val message = response.exception.map { e =>
val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n")
s"Exception from the cluster:\n$e\n$stackTraceString"
}
new DriverStatusResponseMessage()
.setField(SPARK_VERSION, sparkVersion)
.setField(MASTER, master)
.setField(MASTER, masterUrl)
.setField(DRIVER_ID, driverId)
.setField(DRIVER_STATE, driverState)
.validate()
.setField(SUCCESS, response.found.toString)
.setFieldIfNotNull(DRIVER_STATE, response.state.map(_.toString).orNull)
.setFieldIfNotNull(WORKER_ID, response.workerId.orNull)
.setFieldIfNotNull(WORKER_HOST_PORT, response.workerHostPort.orNull)
.setFieldIfNotNull(MESSAGE, message.orNull)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ private[spark] object SubmitDriverRequestMessage
import SubmitDriverRequestField._

protected override def newMessage() = new SubmitDriverRequestMessage
protected override def fieldFromString(field: String) = SubmitDriverRequestField.fromString(field)
protected override def fieldFromString(f: String) = SubmitDriverRequestField.fromString(f)

/**
* Process the given field and value appropriately based on the type of the field.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ private[spark] class SubmitDriverResponseMessage extends SubmitRestProtocolMessa
private[spark] object SubmitDriverResponseMessage
extends SubmitRestProtocolMessageCompanion[SubmitDriverResponseMessage] {
protected override def newMessage() = new SubmitDriverResponseMessage
protected override def fieldFromString(field: String) = SubmitDriverResponseField.fromString(field)
protected override def fieldFromString(f: String) = SubmitDriverResponseField.fromString(f)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.spark.deploy.rest

import java.io.DataOutputStream
import java.io.{DataOutputStream, FileNotFoundException}
import java.net.{HttpURLConnection, URL}

import scala.io.Source

import com.google.common.base.Charsets

import org.apache.spark.Logging
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.deploy.SparkSubmitArguments

/**
Expand All @@ -33,17 +33,17 @@ import org.apache.spark.deploy.SparkSubmitArguments
*/
private[spark] abstract class SubmitRestClient extends Logging {

/** Request that the REST server submits a driver specified by the provided arguments. */
def submitDriver(args: SparkSubmitArguments): Unit = {
/** Request that the REST server submit a driver specified by the provided arguments. */
def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolMessage = {
validateSubmitArguments(args)
val url = getHttpUrl(args.master)
val request = constructSubmitRequest(args)
logInfo(s"Submitting a request to launch a driver in ${args.master}.")
sendHttp(url, request)
}

/** Request that the REST server kills the specified driver. */
def killDriver(master: String, driverId: String): Unit = {
/** Request that the REST server kill the specified driver. */
def killDriver(master: String, driverId: String): SubmitRestProtocolMessage = {
validateMaster(master)
val url = getHttpUrl(master)
val request = constructKillRequest(master, driverId)
Expand All @@ -52,7 +52,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
}

/** Request the status of the specified driver from the REST server. */
def requestDriverStatus(master: String, driverId: String): Unit = {
def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolMessage = {
validateMaster(master)
val url = getHttpUrl(master)
val request = constructStatusRequest(master, driverId)
Expand Down Expand Up @@ -82,18 +82,24 @@ private[spark] abstract class SubmitRestClient extends Logging {
* Return the response received from the REST server.
*/
private def sendHttp(url: URL, request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = {
val conn = url.openConnection().asInstanceOf[HttpURLConnection]
conn.setRequestMethod("POST")
conn.setRequestProperty("Content-Type", "application/json")
conn.setRequestProperty("charset", "utf-8")
conn.setDoOutput(true)
val requestJson = request.toJson
logDebug(s"Sending the following request to the REST server:\n$requestJson")
val out = new DataOutputStream(conn.getOutputStream)
out.write(requestJson.getBytes(Charsets.UTF_8))
out.close()
val responseJson = Source.fromInputStream(conn.getInputStream).mkString
logDebug(s"Response from the REST server:\n$responseJson")
SubmitRestProtocolMessage.fromJson(responseJson)
try {
val conn = url.openConnection().asInstanceOf[HttpURLConnection]
conn.setRequestMethod("POST")
conn.setRequestProperty("Content-Type", "application/json")
conn.setRequestProperty("charset", "utf-8")
conn.setDoOutput(true)
request.validate()
val requestJson = request.toJson
logDebug(s"Sending the following request to the REST server:\n$requestJson")
val out = new DataOutputStream(conn.getOutputStream)
out.write(requestJson.getBytes(Charsets.UTF_8))
out.close()
val responseJson = Source.fromInputStream(conn.getInputStream).mkString
logDebug(s"Response from the REST server:\n$responseJson")
SubmitRestProtocolMessage.fromJson(responseJson)
} catch {
case e: FileNotFoundException =>
throw new SparkException(s"Unable to connect to REST server $url", e)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,29 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi
*/
private def constructResponseMessage(
request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = {
// If the request is sent via the SubmitRestClient, it should have already been validated
// remotely. In case this is not true, validate the request here again to guard against
// potential NPEs. If validation fails, send an error message back to the sender.
try {
request.validate()
request match {
case submit: SubmitDriverRequestMessage => handleSubmit(submit)
case kill: KillDriverRequestMessage => handleKill(kill)
case status: DriverStatusRequestMessage => handleStatus(status)
case unexpected => handleError(
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
// Validate the request message to ensure that it is correctly constructed. If the request
// is sent via the SubmitRestClient, it should have already been validated remotely. In case
// this is not true, do it again here to guard against potential NPEs. If validation fails,
// send an error message back to the sender.
val response =
try {
request.validate()
request match {
case submit: SubmitDriverRequestMessage => handleSubmit(submit)
case kill: KillDriverRequestMessage => handleKill(kill)
case status: DriverStatusRequestMessage => handleStatus(status)
case unexpected => handleError(
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
}
} catch {
case e: IllegalArgumentException => handleError(e.getMessage)
}
// Validate the response message to ensure that it is correctly constructed. If it is not,
// propagate the exception back to the client and signal that it is a server error.
try {
response.validate()
} catch {
// Propagate exception to user in an ErrorMessage.
// Note that the construction of the error message itself may throw an exception.
// In this case, let the higher level caller take care of this request.
case e: IllegalArgumentException => handleError(e.getMessage)
case e: IllegalArgumentException => handleError(s"Internal server error: ${e.getMessage}")
}
}

Expand All @@ -126,6 +132,5 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi
new ErrorMessage()
.setField(SPARK_VERSION, sparkVersion)
.setField(MESSAGE, message)
.validate()
}
}

0 comments on commit b44e103

Please sign in to comment.