Skip to content

Commit

Permalink
Clean up REST response output in Spark submit
Browse files Browse the repository at this point in the history
Now we don't log a response twice or log an error message twice.
Also, before we would actually throw a ClassCastException if the
server returns an error due to type erasure. This commit eases
the relevant complexity involved.
  • Loading branch information
Andrew Or committed Feb 1, 2015
1 parent b2fef8b commit ade28fd
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 29 deletions.
30 changes: 26 additions & 4 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import java.net.URL

import scala.collection.mutable.{ArrayBuffer, HashMap, Map}

import org.apache.spark.deploy.rest._
import org.apache.spark.executor.ExecutorURLClassLoader
import org.apache.spark.util.Utils
import org.apache.spark.deploy.rest.StandaloneRestClient

/**
* Whether to submit, kill, or request the status of an application.
Expand Down Expand Up @@ -95,7 +95,10 @@ object SparkSubmit {
private def kill(args: SparkSubmitArguments): Unit = {
val client = new StandaloneRestClient
val response = client.killDriver(args.master, args.driverToKill)
printStream.println(response.toJson)
response match {
case k: KillDriverResponse => handleRestResponse(k)
case r => handleUnexpectedRestResponse(r)
}
}

/**
Expand All @@ -105,7 +108,10 @@ object SparkSubmit {
private def requestStatus(args: SparkSubmitArguments): Unit = {
val client = new StandaloneRestClient
val response = client.requestDriverStatus(args.master, args.driverToRequestStatusFor)
printStream.println(response.toJson)
response match {
case s: DriverStatusResponse => handleRestResponse(s)
case r => handleUnexpectedRestResponse(r)
}
}

/**
Expand All @@ -126,7 +132,12 @@ object SparkSubmit {
val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
if (args.isStandaloneCluster && args.isRestEnabled) {
printStream.println("Running Spark using the REST application submission protocol.")
new StandaloneRestClient().submitDriver(args)
val client = new StandaloneRestClient
val response = client.submitDriver(args)
response match {
case s: SubmitDriverResponse => handleRestResponse(s)
case r => handleUnexpectedRestResponse(r)
}
} else {
runMain(childArgs, childClasspath, sysProps, childMainClass)
}
Expand Down Expand Up @@ -461,6 +472,17 @@ object SparkSubmit {
}
}

/** Log the response sent by the server in the REST application submission protocol. */
private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = {
printStream.println(s"Server responded with ${response.messageType}:\n${response.toJson}")
}

/** Log an appropriate error if the response sent by the server is not of the expected type. */
private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = {
printStream.println(
s"Error: Server responded with message of unexpected type ${unexpected.messageType}.")
}

/**
* Return whether the given primary resource represents a user jar.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
override def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolResponse = {
validateSubmitArgs(args)
val response = super.submitDriver(args)
val submitResponse = getResponse[SubmitDriverResponse](response).getOrElse { return response }
val submitResponse = response match {
case s: SubmitDriverResponse => s
case _ => return response
}
val submitSuccess = submitResponse.getSuccess.toBoolean
if (submitSuccess) {
val driverId = submitResponse.getDriverId
Expand Down Expand Up @@ -71,7 +74,10 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = {
(1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ =>
val response = requestDriverStatus(master, driverId)
val statusResponse = getResponse[DriverStatusResponse](response).getOrElse { return }
val statusResponse = response match {
case s: DriverStatusResponse => s
case _ => return
}
val statusSuccess = statusResponse.getSuccess.toBoolean
if (statusSuccess) {
val driverState = Option(statusResponse.getDriverState)
Expand Down Expand Up @@ -160,23 +166,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
"This REST client is only supported in standalone cluster mode.")
}
}

/**
* Return the response as the expected type, or fail with an informative error message.
* Exposed for testing.
*/
private[spark] def getResponse[T <: SubmitRestProtocolResponse](
response: SubmitRestProtocolResponse): Option[T] = {
try {
// Do not match on type T because types are erased at runtime
// Instead, manually try to cast it to type T ourselves
Some(response.asInstanceOf[T])
} catch {
case e: ClassCastException =>
logError(s"Server returned response of unexpected type:\n${response.toJson}")
None
}
}
}

private object StandaloneRestClient {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
val url = getHttpUrl(args.master)
val request = constructSubmitRequest(args)
val response = sendHttp(url, request)
validateResponse(response)
handleResponse(response)
}

/** Request that the REST server kill the specified driver. */
Expand All @@ -48,7 +48,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
val url = getHttpUrl(master)
val request = constructKillRequest(master, driverId)
val response = sendHttp(url, request)
validateResponse(response)
handleResponse(response)
}

/** Request the status of the specified driver from the REST server. */
Expand All @@ -57,7 +57,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
val url = getHttpUrl(master)
val request = constructStatusRequest(master, driverId)
val response = sendHttp(url, request)
validateResponse(response)
handleResponse(response)
}

/** Return the HTTP URL of the REST server that corresponds to the given master URL. */
Expand Down Expand Up @@ -95,10 +95,14 @@ private[spark] abstract class SubmitRestClient extends Logging {
}
}

/** Validate the response... */
private def validateResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = {
/** Validate the response and log any error messages provided by the server. */
private def handleResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = {
try {
response.validate()
response match {
case e: ErrorResponse => logError(s"Server responded with error:\n${e.getMessage}")
case _ =>
}
} catch {
case e: SubmitRestProtocolException =>
throw new SubmitRestProtocolException("Malformed response received from server", e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils
@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY)
@JsonPropertyOrder(alphabetic = true)
abstract class SubmitRestProtocolMessage {
private val messageType = Utils.getFormattedClassName(this)
val messageType = Utils.getFormattedClassName(this)
protected val action: String = messageType
protected val sparkVersion: SubmitRestProtocolField[String]
protected val message = new SubmitRestProtocolField[String]("message")
Expand Down

0 comments on commit ade28fd

Please sign in to comment.