Skip to content

Commit

Permalink
Use nullable fields for integer and boolean values
Browse files Browse the repository at this point in the history
This allows us to use the raw values for these types of fields
in the JSON instead of a string version of them.
  • Loading branch information
Andrew Or committed Feb 6, 2015
1 parent dfe4bd7 commit 6f0c597
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ private[spark] class StandaloneRestClient extends Logging {
private def reportSubmissionStatus(
master: String,
submitResponse: CreateSubmissionResponse): Unit = {
val submitSuccess = submitResponse.success.toBoolean
if (submitSuccess) {
if (submitResponse.success) {
val submissionId = submitResponse.submissionId
if (submissionId != null) {
logInfo(s"Submission successfully created as $submissionId. Polling submission state...")
Expand All @@ -245,8 +244,7 @@ private[spark] class StandaloneRestClient extends Logging {
case s: SubmissionStatusResponse => s
case _ => return // unexpected type, let upstream caller handle it
}
val statusSuccess = statusResponse.success.toBoolean
if (statusSuccess) {
if (statusResponse.success) {
val driverState = Option(statusResponse.driverState)
val workerId = Option(statusResponse.workerId)
val workerHostPort = Option(statusResponse.workerHostPort)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
k.serverSparkVersion = sparkVersion
k.message = response.message
k.submissionId = submissionId
k.success = response.success.toString
k.success = response.success
k
}
}
Expand Down Expand Up @@ -275,7 +275,7 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
val d = new SubmissionStatusResponse
d.serverSparkVersion = sparkVersion
d.submissionId = submissionId
d.success = response.found.toString
d.success = response.found
d.driverState = response.state.map(_.toString).orNull
d.workerId = response.workerId.orNull
d.workerHostPort = response.workerHostPort.orNull
Expand Down Expand Up @@ -339,7 +339,7 @@ private class SubmitRequestServlet(
val submitResponse = new CreateSubmissionResponse
submitResponse.serverSparkVersion = sparkVersion
submitResponse.message = response.message
submitResponse.success = response.success.toString
submitResponse.success = response.success
submitResponse.submissionId = response.driverId.orNull
val unknownFields = findUnknownFields(requestMessageJson, requestMessage)
if (unknownFields.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,45 +81,12 @@ private[spark] abstract class SubmitRestProtocolMessage {
}

/** Assert that the specified field is set in this message. */
protected def assertFieldIsSet(value: String, name: String): Unit = {
protected def assertFieldIsSet[T](value: T, name: String): Unit = {
if (value == null) {
throw new SubmitRestMissingFieldException(s"'$name' is missing in message $messageType.")
}
}

/** Assert that the value of the specified field is a boolean. */
protected def assertFieldIsBoolean(value: String, name: String): Unit = {
if (value != null) {
Try(value.toBoolean).getOrElse {
throw new SubmitRestProtocolException(
s"'$name' expected boolean value: actual was '$value'.")
}
}
}

/** Assert that the value of the specified field is a numeric. */
protected def assertFieldIsNumeric(value: String, name: String): Unit = {
if (value != null) {
Try(value.toInt).getOrElse {
throw new SubmitRestProtocolException(
s"'$name' expected numeric value: actual was '$value'.")
}
}
}

/**
* Assert that the value of the specified field is a memory string.
* Examples of valid memory strings include 3g, 512m, 128k, 4096.
*/
protected def assertFieldIsMemory(value: String, name: String): Unit = {
if (value != null) {
Try(Utils.memoryStringToMb(value)).getOrElse {
throw new SubmitRestProtocolException(
s"'$name' expected memory value: actual was '$value'.")
}
}
}

/**
* Assert a condition when validating this message.
* If the assertion fails, throw a [[SubmitRestProtocolException]].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.deploy.rest

import scala.util.Try

import org.apache.spark.util.Utils

/**
* An abstract request sent from the client in the REST application submission protocol.
*/
Expand Down Expand Up @@ -54,11 +58,21 @@ private[spark] class CreateSubmissionRequest extends SubmitRestProtocolRequest {
assertFieldIsSet(sparkProperties.getOrElse(key, null), key)

private def assertPropertyIsBoolean(key: String): Unit =
assertFieldIsBoolean(sparkProperties.getOrElse(key, null), key)
assertProperty[Boolean](key, "boolean", _.toBoolean)

private def assertPropertyIsNumeric(key: String): Unit =
assertFieldIsNumeric(sparkProperties.getOrElse(key, null), key)
assertProperty[Int](key, "numeric", _.toInt)

private def assertPropertyIsMemory(key: String): Unit =
assertFieldIsMemory(sparkProperties.getOrElse(key, null), key)
assertProperty[Int](key, "memory", Utils.memoryStringToMb)

/** Assert that a Spark property can be converted to a certain type. */
private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = {
sparkProperties.get(key).foreach { value =>
Try(convert(value)).getOrElse {
throw new SubmitRestProtocolException(
s"Property '$key' expected $valueType value: actual was '$value'.")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@

package org.apache.spark.deploy.rest

import java.lang.Boolean

/**
* An abstract response sent from the server in the REST application submission protocol.
*/
private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
var serverSparkVersion: String = null
var success: String = null
var success: Boolean = null
var unknownFields: Array[String] = null
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(serverSparkVersion, "serverSparkVersion")
assertFieldIsBoolean(success, "success")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with Bef
val response = client.killSubmission(masterRestUrl, "submission-that-does-not-exist")
val killResponse = getKillResponse(response)
val killSuccess = killResponse.success
assert(killSuccess === "false")
assert(!killSuccess)
}

test("kill running submission") {
Expand All @@ -78,8 +78,8 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with Bef
val statusResponse = getStatusResponse(response2)
val statusSuccess = statusResponse.success
val driverState = statusResponse.driverState
assert(killSuccess === "true")
assert(statusSuccess === "true")
assert(killSuccess)
assert(statusSuccess)
assert(driverState === DriverState.KILLED.toString)
// we should not see the expected results because we killed the submission
intercept[TestFailedException] { validateResult(resultsFile, numbers, size) }
Expand All @@ -89,7 +89,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with Bef
val response = client.requestSubmissionStatus(masterRestUrl, "submission-that-does-not-exist")
val statusResponse = getStatusResponse(response)
val statusSuccess = statusResponse.success
assert(statusSuccess === "false")
assert(!statusSuccess)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.spark.deploy.rest

import java.lang.Boolean
import java.lang.Integer

import org.json4s.jackson.JsonMethods._
import org.scalatest.FunSuite

Expand All @@ -34,9 +37,9 @@ class SubmitRestProtocolSuite extends FunSuite {
intercept[SubmitRestProtocolException] { request.validate() } // missing name and age
request.name = "something"
intercept[SubmitRestProtocolException] { request.validate() } // missing only age
request.age = "2"
request.age = 2
intercept[SubmitRestProtocolException] { request.validate() } // age too low
request.age = "10"
request.age = 10
request.validate() // everything is set properly
request.clientSparkVersion = null
intercept[SubmitRestProtocolException] { request.validate() } // missing only Spark version
Expand All @@ -47,52 +50,34 @@ class SubmitRestProtocolSuite extends FunSuite {
intercept[SubmitRestProtocolException] { request.validate() } // still missing name
}

test("validate with illegal argument") {
val request = new DummyRequest
request.clientSparkVersion = "1.2.3"
request.name = "abc"
request.age = "not-a-number"
intercept[SubmitRestProtocolException] { request.validate() }
request.age = "true"
intercept[SubmitRestProtocolException] { request.validate() }
request.age = "150"
request.validate()
request.active = "not-a-boolean"
intercept[SubmitRestProtocolException] { request.validate() }
request.active = "150"
intercept[SubmitRestProtocolException] { request.validate() }
request.active = "true"
request.validate()
}

test("request to and from JSON") {
val request = new DummyRequest
intercept[SubmitRestProtocolException] { request.toJson } // implicit validation
request.clientSparkVersion = "1.2.3"
request.active = "true"
request.age = "25"
request.active = true
request.age = 25
request.name = "jung"
val json = request.toJson
assertJsonEquals(json, dummyRequestJson)
val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest])
assert(newRequest.clientSparkVersion === "1.2.3")
assert(newRequest.clientSparkVersion === "1.2.3")
assert(newRequest.active === "true")
assert(newRequest.age === "25")
assert(newRequest.active)
assert(newRequest.age === 25)
assert(newRequest.name === "jung")
assert(newRequest.message === null)
}

test("response to and from JSON") {
val response = new DummyResponse
response.serverSparkVersion = "3.3.4"
response.success = "true"
response.success = true
val json = response.toJson
assertJsonEquals(json, dummyResponseJson)
val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse])
assert(newResponse.serverSparkVersion === "3.3.4")
assert(newResponse.serverSparkVersion === "3.3.4")
assert(newResponse.success === "true")
assert(newResponse.success)
assert(newResponse.message === null)
}

Expand Down Expand Up @@ -160,64 +145,52 @@ class SubmitRestProtocolSuite extends FunSuite {
intercept[SubmitRestProtocolException] { message.validate() }
message.serverSparkVersion = "1.2.3"
message.submissionId = "driver_123"
message.success = "true"
message.success = true
message.validate()
// bad fields
message.success = "maybe not"
intercept[SubmitRestProtocolException] { message.validate() }
message.success = "true"
// test JSON
val json = message.toJson
assertJsonEquals(json, submitDriverResponseJson)
val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionResponse])
assert(newMessage.serverSparkVersion === "1.2.3")
assert(newMessage.submissionId === "driver_123")
assert(newMessage.success === "true")
assert(newMessage.success)
}

test("KillSubmissionResponse") {
val message = new KillSubmissionResponse
intercept[SubmitRestProtocolException] { message.validate() }
message.serverSparkVersion = "1.2.3"
message.submissionId = "driver_123"
message.success = "true"
message.success = true
message.validate()
// bad fields
message.success = "maybe not"
intercept[SubmitRestProtocolException] { message.validate() }
message.success = "true"
// test JSON
val json = message.toJson
assertJsonEquals(json, killDriverResponseJson)
val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillSubmissionResponse])
assert(newMessage.serverSparkVersion === "1.2.3")
assert(newMessage.submissionId === "driver_123")
assert(newMessage.success === "true")
assert(newMessage.success)
}

test("SubmissionStatusResponse") {
val message = new SubmissionStatusResponse
intercept[SubmitRestProtocolException] { message.validate() }
message.serverSparkVersion = "1.2.3"
message.submissionId = "driver_123"
message.success = "true"
message.success = true
message.validate()
// optional fields
message.driverState = "RUNNING"
message.workerId = "worker_123"
message.workerHostPort = "1.2.3.4:7780"
// bad fields
message.success = "maybe"
intercept[SubmitRestProtocolException] { message.validate() }
message.success = "true"
// test JSON
val json = message.toJson
assertJsonEquals(json, driverStatusResponseJson)
val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmissionStatusResponse])
assert(newMessage.serverSparkVersion === "1.2.3")
assert(newMessage.submissionId === "driver_123")
assert(newMessage.driverState === "RUNNING")
assert(newMessage.success === "true")
assert(newMessage.success)
assert(newMessage.workerId === "worker_123")
assert(newMessage.workerHostPort === "1.2.3.4:7780")
}
Expand All @@ -240,8 +213,8 @@ class SubmitRestProtocolSuite extends FunSuite {
"""
|{
| "action" : "DummyRequest",
| "active" : "true",
| "age" : "25",
| "active" : true,
| "age" : 25,
| "clientSparkVersion" : "1.2.3",
| "name" : "jung"
|}
Expand All @@ -252,7 +225,7 @@ class SubmitRestProtocolSuite extends FunSuite {
|{
| "action" : "DummyResponse",
| "serverSparkVersion" : "3.3.4",
| "success": "true"
| "success": true
|}
""".stripMargin

Expand Down Expand Up @@ -289,7 +262,7 @@ class SubmitRestProtocolSuite extends FunSuite {
| "action" : "CreateSubmissionResponse",
| "serverSparkVersion" : "1.2.3",
| "submissionId" : "driver_123",
| "success" : "true"
| "success" : true
|}
""".stripMargin

Expand All @@ -299,7 +272,7 @@ class SubmitRestProtocolSuite extends FunSuite {
| "action" : "KillSubmissionResponse",
| "serverSparkVersion" : "1.2.3",
| "submissionId" : "driver_123",
| "success" : "true"
| "success" : true
|}
""".stripMargin

Expand All @@ -310,7 +283,7 @@ class SubmitRestProtocolSuite extends FunSuite {
| "driverState" : "RUNNING",
| "serverSparkVersion" : "1.2.3",
| "submissionId" : "driver_123",
| "success" : "true",
| "success" : true,
| "workerHostPort" : "1.2.3.4:7780",
| "workerId" : "worker_123"
|}
Expand Down Expand Up @@ -339,15 +312,13 @@ class SubmitRestProtocolSuite extends FunSuite {

private class DummyResponse extends SubmitRestProtocolResponse
private class DummyRequest extends SubmitRestProtocolRequest {
var active: String = null
var age: String = null
var active: Boolean = null
var age: Integer = null
var name: String = null
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(name, "name")
assertFieldIsSet(age, "age")
assertFieldIsBoolean(active, "active")
assertFieldIsNumeric(age, "age")
assert(age.toInt > 5, "Not old enough!")
assert(age > 5, "Not old enough!")
}
}

0 comments on commit 6f0c597

Please sign in to comment.