Fall back to Akka if endpoint was not REST
In this commit we also introduce a new exception type to serve
this purpose, and we make as many of the REST submission classes as possible private[spark].
Andrew Or committed Feb 4, 2015
1 parent 252d53c commit 9165ae8
Showing 6 changed files with 61 additions and 56 deletions.
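
At a glance, the change wraps the REST submission path in a try/catch: SparkSubmit attempts the REST gateway first and reruns the legacy code path only when the connection fails. A condensed sketch of the new control flow, distilled from the SparkSubmit.scala hunk below (simplified; response handling elided, not the verbatim diff):

  private[spark] def submit(args: SparkSubmitArguments): Unit = {
    val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
    if (args.isStandaloneCluster) {
      // Standalone cluster mode: try the REST gateway first (the default since Spark 1.3)
      try {
        new StandaloneRestClient().createSubmission(args)
      } catch {
        case e: SubmitRestConnectionException =>
          // The master endpoint is not a REST server: fall back to the legacy Akka gateway
          runMain(childArgs, childClasspath, sysProps, childMainClass)
      }
    } else {
      // All other modes: just run the main class as prepared
      runMain(childArgs, childClasspath, sysProps, childMainClass)
    }
  }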
core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala (26 additions & 14 deletions)
@@ -124,22 +124,33 @@ object SparkSubmit {
* running the child main class based on the cluster manager and the deploy mode.
* Second, we use this launch environment to invoke the main method of the child
* main class.
*
* As of Spark 1.3, a REST-based application submission gateway is introduced.
* If this is enabled, then we will run standalone cluster mode by passing the submit
* parameters directly to a REST client, which will submit the application using the
* REST protocol instead.
*/
private[spark] def submit(args: SparkSubmitArguments): Unit = {
val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
if (args.isStandaloneCluster && args.isRestEnabled) {
printStream.println("Running Spark using the REST application submission protocol.")
val client = new StandaloneRestClient
val response = client.createSubmission(args)
response match {
case s: CreateSubmissionResponse => handleRestResponse(s)
case r => handleUnexpectedRestResponse(r)
/*
* In standalone cluster mode, there are two submission gateways:
* (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper
* (2) The new REST-based gateway introduced in Spark 1.3
 * The latter is the default behavior as of Spark 1.3, but Spark submit will fail over
 * to the legacy gateway if the master endpoint turns out not to be a REST server.
*/
if (args.isStandaloneCluster) {
try {
printStream.println("Running Spark using the REST application submission protocol.")
val client = new StandaloneRestClient
val response = client.createSubmission(args)
response match {
case s: CreateSubmissionResponse => handleRestResponse(s)
case r => handleUnexpectedRestResponse(r)
}
} catch {
// Fail over to use the legacy submission gateway
case e: SubmitRestConnectionException =>
printStream.println(s"Master endpoint ${args.master} was not a " +
s"REST server. Falling back to legacy submission gateway instead.")
runMain(childArgs, childClasspath, sysProps, childMainClass)
}
// In all other modes, just run the main class as prepared
} else {
runMain(childArgs, childClasspath, sysProps, childMainClass)
}
@@ -152,6 +163,7 @@ object SparkSubmit {
* (2) a list of classpath entries for the child,
* (3) a list of system properties and env vars, and
* (4) the main class for the child
* In standalone cluster mode, this mutates the original arguments passed in.
* Exposed for testing.
*/
private[spark] def prepareSubmitEnvironment(args: SparkSubmitArguments)
@@ -347,7 +359,7 @@ object SparkSubmit {

// In standalone-cluster mode, use Client as a wrapper around the user class
// Note that we won't actually launch this class if we're using the REST protocol
if (args.isStandaloneCluster && !args.isRestEnabled) {
if (args.isStandaloneCluster) {
childMainClass = "org.apache.spark.deploy.Client"
if (args.supervise) {
childArgs += "--supervise"
@@ -419,7 +431,7 @@ object SparkSubmit {
// NOTE: If we are using the REST gateway, we will use the original arguments directly.
// Since we mutate the values of some configs in this method, we must update the
// corresponding fields in the original SparkSubmitArguments to reflect these changes.
if (args.isStandaloneCluster && args.isRestEnabled) {
if (args.isStandaloneCluster) {
args.sparkProperties.clear()
args.sparkProperties ++= sysProps
sysProps.get("spark.jars").foreach { args.jars = _ }
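
For context (not part of the diff): the REST client serializes the original SparkSubmitArguments, so any config that prepareSubmitEnvironment rewrites only into sysProps would otherwise never reach the server. A small illustration, with a made-up resolved value:

  // Suppose prepareSubmitEnvironment resolved a local jar to a cluster-visible path:
  //   sysProps("spark.jars") == "hdfs://nn:8020/jars/app.jar"   (hypothetical value)
  // The copy-back above propagates that resolved value into the args object that
  // StandaloneRestClient will serialize and submit.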
core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -61,8 +61,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var driverToKill: String = null
var driverToRequestStatusFor: String = null

private val restEnabledKey = "spark.submit.rest.enabled"

/** Default properties present in the currently defined defaults file. */
lazy val defaultSparkProperties: HashMap[String, String] = {
val defaultProperties = new HashMap[String, String]()
@@ -220,10 +218,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
if (!isStandaloneCluster) {
SparkSubmit.printErrorAndExit("Killing drivers is only supported in standalone cluster mode")
}
if (!isRestEnabled) {
SparkSubmit.printErrorAndExit("Killing drivers is currently only supported " +
s"through the REST interface. Please set $restEnabledKey to true.")
}
if (driverToKill == null) {
SparkSubmit.printErrorAndExit("Please specify a driver to kill")
}
@@ -234,10 +228,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
SparkSubmit.printErrorAndExit(
"Requesting driver statuses is only supported in standalone cluster mode")
}
if (!isRestEnabled) {
SparkSubmit.printErrorAndExit("Requesting driver statuses is currently only " +
s"supported through the REST interface. Please set $restEnabledKey to true.")
}
if (driverToRequestStatusFor == null) {
SparkSubmit.printErrorAndExit("Please specify a driver to request status for")
}
@@ -247,11 +237,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
master.startsWith("spark://") && deployMode == "cluster"
}

/** Return whether the REST application submission protocol is enabled. */
def isRestEnabled: Boolean = {
sparkProperties.get(restEnabledKey).getOrElse("false").toBoolean
}

override def toString = {
s"""Parsed arguments:
| master $master
@@ -472,8 +457,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| Spark standalone with cluster deploy mode only:
| --driver-cores NUM Cores for driver (Default: 1).
| --supervise If given, restarts the driver on failure.
| --kill DRIVER_ID If given, kills the driver specified.
| --status DRIVER_ID If given, requests the status of the driver specified.
| --kill SUBMISSION_ID If given, kills the driver specified.
| --status SUBMISSION_ID If given, requests the status of the driver specified.
|
| Spark standalone and Mesos only:
| --total-executor-cores NUM Total cores for all executors.
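
For context (not part of the diff): with the REST-only code path, --kill and --status flow through the REST client against the submission IDs the gateway returns. A hedged sketch of how that might look; killSubmission and requestSubmissionStatus are assumed public counterparts of the private getKillUrl / getStatusUrl helpers shown below, and the master address and submission ID are made up:

  // Hypothetical: servicing --kill / --status through the REST client
  val client = new StandaloneRestClient
  client.killSubmission("spark://1.2.3.4:6066", "driver-20150204123456-0000")
  client.requestSubmissionStatus("spark://1.2.3.4:6066", "driver-20150204123456-0000")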
core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
@@ -17,14 +17,14 @@

package org.apache.spark.deploy.rest

import java.io.{FileNotFoundException, DataOutputStream}
import java.net.{HttpURLConnection, URL}
import java.io.{DataOutputStream, FileNotFoundException}
import java.net.{HttpURLConnection, SocketException, URL}

import scala.io.Source

import com.google.common.base.Charsets

import org.apache.spark.{Logging, SparkException, SPARK_VERSION => sparkVersion}
import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion}
import org.apache.spark.deploy.SparkSubmitArguments

/**
@@ -43,8 +43,8 @@ import org.apache.spark.deploy.SparkSubmitArguments
* Additionally, the base URL includes the version of the protocol. For instance:
* http://1.2.3.4:6066/v1/submissions/create. Since the protocol is expected to be stable
* across Spark versions, existing fields cannot be added or removed. In the rare event that
* backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2).
* The client and the server must communicate on the same version of the protocol.
* forward or backward compatibility is broken, Spark must introduce a new protocol version
* (e.g. v2). The client and the server must communicate on the same version of the protocol.
*/
private[spark] class StandaloneRestClient extends Logging {
import StandaloneRestClient._
@@ -123,7 +123,7 @@ private[spark] class StandaloneRestClient extends Logging {
private def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
try {
val responseJson = Source.fromInputStream(connection.getInputStream).mkString
logDebug(s"Response from the REST server:\n$responseJson")
logDebug(s"Response from the server:\n$responseJson")
val response = SubmitRestProtocolMessage.fromJson(responseJson)
// The response should have already been validated on the server.
// In case this is not true, validate it ourselves to avoid potential NPEs.
@@ -139,39 +139,40 @@ private[spark] class StandaloneRestClient extends Logging {
case error: ErrorResponse =>
logError(s"Server responded with error:\n${error.message}")
error
case response: SubmitRestProtocolResponse =>
response
case response: SubmitRestProtocolResponse => response
case unexpected =>
throw new SubmitRestProtocolException(
s"Unexpected message received from server:\n$unexpected")
s"Message received from server was not a response:\n${unexpected.toJson}")
}
} catch {
case e: FileNotFoundException =>
throw new SparkException(s"Unable to connect to server ${connection.getURL}", e)
case e @ (_: FileNotFoundException | _: SocketException) =>
throw new SubmitRestConnectionException(
s"Unable to connect to server ${connection.getURL}", e)
}
}

/** Return the REST URL for creating a new submission. */
private def getSubmitUrl(master: String): URL = {
val baseUrl = getBaseUrl(master)
new URL(s"$baseUrl/submissions/create")
new URL(s"$baseUrl/create")
}

/** Return the REST URL for killing an existing submission. */
private def getKillUrl(master: String, submissionId: String): URL = {
val baseUrl = getBaseUrl(master)
new URL(s"$baseUrl/submissions/kill/$submissionId")
new URL(s"$baseUrl/kill/$submissionId")
}

/** Return the REST URL for requesting the status of an existing submission. */
private def getStatusUrl(master: String, submissionId: String): URL = {
val baseUrl = getBaseUrl(master)
new URL(s"$baseUrl/submissions/status/$submissionId")
new URL(s"$baseUrl/status/$submissionId")
}

/** Return the base URL for communicating with the server, including the protocol version. */
private def getBaseUrl(master: String): String = {
"http://" + master.stripPrefix("spark://").stripSuffix("/") + "/" + PROTOCOL_VERSION
val masterUrl = master.stripPrefix("spark://").stripSuffix("/")
s"http://$masterUrl/$PROTOCOL_VERSION/submissions"
}

/** Throw an exception if this is not standalone mode. */
@@ -223,10 +224,11 @@ private[spark] class StandaloneRestClient extends Logging {
if (submitSuccess) {
val submissionId = submitResponse.submissionId
if (submissionId != null) {
logInfo(s"Driver successfully submitted as $submissionId. Polling driver state...")
logInfo(s"Submission successfully created as $submissionId. Polling submission state...")
pollSubmissionStatus(master, submissionId)
} else {
logError("Application successfully submitted, but driver ID was not provided!")
// should never happen
logError("Application successfully submitted, but submission ID was not provided!")
}
} else {
val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("")
Expand Down Expand Up @@ -267,7 +269,7 @@ private[spark] class StandaloneRestClient extends Logging {
}
Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL)
}
logError(s"Error: Master did not recognize submission $submissionId.")
logError(s"Error: Master did not recognize driver $submissionId.")
}
}
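
For context (not part of the diff): with the consolidated getBaseUrl, all three endpoints hang off a versioned /submissions base, matching the v1 example in the class doc above. A quick sketch of the resulting URLs, mirroring the private helpers; the master address and submission ID are made up:

  val master = "spark://1.2.3.4:6066"
  val base   = "http://" + master.stripPrefix("spark://").stripSuffix("/") + "/v1/submissions"
  val create = s"$base/create"              // http://1.2.3.4:6066/v1/submissions/create
  val kill   = s"$base/kill/driver-0001"    // http://1.2.3.4:6066/v1/submissions/kill/driver-0001
  val status = s"$base/status/driver-0001"  // http://1.2.3.4:6066/v1/submissions/status/driver-0001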

core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala
@@ -20,10 +20,17 @@ package org.apache.spark.deploy.rest
/**
* An exception thrown in the REST application submission protocol.
*/
class SubmitRestProtocolException(message: String, cause: Exception = null)
private[spark] class SubmitRestProtocolException(message: String, cause: Throwable = null)
extends Exception(message, cause)

/**
* An exception thrown if a field is missing from a [[SubmitRestProtocolMessage]].
*/
class SubmitRestMissingFieldException(message: String) extends SubmitRestProtocolException(message)
private[spark] class SubmitRestMissingFieldException(message: String)
extends SubmitRestProtocolException(message)

/**
* An exception thrown if the REST client cannot reach the REST server.
*/
private[spark] class SubmitRestConnectionException(message: String, cause: Throwable)
extends SubmitRestProtocolException(message, cause)
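
For context (not part of the diff): making the new exception a subclass keeps the distinction catchable. The fallback in SparkSubmit catches only the connection failure; protocol errors from a real REST server still propagate. A simplified sketch (the more specific subclass must be matched first, as it is here):

  try {
    new StandaloneRestClient().createSubmission(args)
  } catch {
    case e: SubmitRestConnectionException =>
      // Could not reach a REST server at all: safe to fall back to the legacy gateway
      runMain(childArgs, childClasspath, sysProps, childMainClass)
    case e: SubmitRestProtocolException =>
      // Reached a REST server but the exchange itself failed: surface it, do not fall back
      throw e
  }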
core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
@@ -41,7 +41,7 @@ import org.apache.spark.util.Utils
@JsonInclude(Include.NON_NULL)
@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY)
@JsonPropertyOrder(alphabetic = true)
abstract class SubmitRestProtocolMessage {
private[spark] abstract class SubmitRestProtocolMessage {
@JsonIgnore
val messageType = Utils.getFormattedClassName(this)

core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala
@@ -38,11 +38,11 @@ import org.apache.spark.deploy.worker.Worker
/**
* End-to-end tests for the REST application submission protocol in standalone mode.
*/
class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
private val systemsToStop = new ArrayBuffer[ActorSystem]
private val masterRestUrl = startLocalCluster()
private val client = new StandaloneRestClient
private val mainJar = StandaloneRestProtocolSuite.createJar()
private val mainJar = StandaloneRestSubmitSuite.createJar()
private val mainClass = StandaloneRestApp.getClass.getName.stripSuffix("$")

override def afterAll() {
@@ -125,7 +125,6 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B
"--master", masterRestUrl,
"--name", mainClass,
"--class", mainClass,
"--conf", "spark.submit.rest.enabled=true",
mainJar) ++ appArgs
val args = new SparkSubmitArguments(commandLineArgs)
SparkSubmit.prepareSubmitEnvironment(args)
@@ -195,7 +194,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B
}
}

private object StandaloneRestProtocolSuite {
private object StandaloneRestSubmitSuite {
private val pathPrefix = "org/apache/spark/deploy/rest"
