From 6f0c5970b842bef25685e4d7b1b4cfacd43db0f9 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 6 Feb 2015 14:18:00 -0800 Subject: [PATCH] Use nullable fields for integer and boolean values This allows us to use the raw values for these types of fields in the JSON instead of a string version of them. --- .../deploy/rest/StandaloneRestClient.scala | 6 +- .../deploy/rest/StandaloneRestServer.scala | 6 +- .../rest/SubmitRestProtocolMessage.scala | 35 +------- .../rest/SubmitRestProtocolRequest.scala | 20 ++++- .../rest/SubmitRestProtocolResponse.scala | 5 +- .../rest/StandaloneRestSubmitSuite.scala | 8 +- .../deploy/rest/SubmitRestProtocolSuite.scala | 81 ++++++------------- 7 files changed, 56 insertions(+), 105 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index adeaf6d795db1..115aa5278bb62 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -218,8 +218,7 @@ private[spark] class StandaloneRestClient extends Logging { private def reportSubmissionStatus( master: String, submitResponse: CreateSubmissionResponse): Unit = { - val submitSuccess = submitResponse.success.toBoolean - if (submitSuccess) { + if (submitResponse.success) { val submissionId = submitResponse.submissionId if (submissionId != null) { logInfo(s"Submission successfully created as $submissionId. Polling submission state...") @@ -245,8 +244,7 @@ private[spark] class StandaloneRestClient extends Logging { case s: SubmissionStatusResponse => s case _ => return // unexpected type, let upstream caller handle it } - val statusSuccess = statusResponse.success.toBoolean - if (statusSuccess) { + if (statusResponse.success) { val driverState = Option(statusResponse.driverState) val workerId = Option(statusResponse.workerId) val workerHostPort = Option(statusResponse.workerHostPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index d580c63701740..2033d67e1f394 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -238,7 +238,7 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) k.serverSparkVersion = sparkVersion k.message = response.message k.submissionId = submissionId - k.success = response.success.toString + k.success = response.success k } } @@ -275,7 +275,7 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) val d = new SubmissionStatusResponse d.serverSparkVersion = sparkVersion d.submissionId = submissionId - d.success = response.found.toString + d.success = response.found d.driverState = response.state.map(_.toString).orNull d.workerId = response.workerId.orNull d.workerHostPort = response.workerHostPort.orNull @@ -339,7 +339,7 @@ private class SubmitRequestServlet( val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion submitResponse.message = response.message - submitResponse.success = response.success.toString + submitResponse.success = response.success submitResponse.submissionId = response.driverId.orNull val unknownFields = findUnknownFields(requestMessageJson, requestMessage) if (unknownFields.nonEmpty) { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 7e3c7bec5494f..b877898231e3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -81,45 +81,12 @@ private[spark] abstract class SubmitRestProtocolMessage { } /** Assert that the specified field is set in this message. */ - protected def assertFieldIsSet(value: String, name: String): Unit = { + protected def assertFieldIsSet[T](value: T, name: String): Unit = { if (value == null) { throw new SubmitRestMissingFieldException(s"'$name' is missing in message $messageType.") } } - /** Assert that the value of the specified field is a boolean. */ - protected def assertFieldIsBoolean(value: String, name: String): Unit = { - if (value != null) { - Try(value.toBoolean).getOrElse { - throw new SubmitRestProtocolException( - s"'$name' expected boolean value: actual was '$value'.") - } - } - } - - /** Assert that the value of the specified field is a numeric. */ - protected def assertFieldIsNumeric(value: String, name: String): Unit = { - if (value != null) { - Try(value.toInt).getOrElse { - throw new SubmitRestProtocolException( - s"'$name' expected numeric value: actual was '$value'.") - } - } - } - - /** - * Assert that the value of the specified field is a memory string. - * Examples of valid memory strings include 3g, 512m, 128k, 4096. - */ - protected def assertFieldIsMemory(value: String, name: String): Unit = { - if (value != null) { - Try(Utils.memoryStringToMb(value)).getOrElse { - throw new SubmitRestProtocolException( - s"'$name' expected memory value: actual was '$value'.") - } - } - } - /** * Assert a condition when validating this message. * If the assertion fails, throw a [[SubmitRestProtocolException]]. diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 5d59d4654b8e1..9e1fd8c40cabd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -17,6 +17,10 @@ package org.apache.spark.deploy.rest +import scala.util.Try + +import org.apache.spark.util.Utils + /** * An abstract request sent from the client in the REST application submission protocol. */ @@ -54,11 +58,21 @@ private[spark] class CreateSubmissionRequest extends SubmitRestProtocolRequest { assertFieldIsSet(sparkProperties.getOrElse(key, null), key) private def assertPropertyIsBoolean(key: String): Unit = - assertFieldIsBoolean(sparkProperties.getOrElse(key, null), key) + assertProperty[Boolean](key, "boolean", _.toBoolean) private def assertPropertyIsNumeric(key: String): Unit = - assertFieldIsNumeric(sparkProperties.getOrElse(key, null), key) + assertProperty[Int](key, "numeric", _.toInt) private def assertPropertyIsMemory(key: String): Unit = - assertFieldIsMemory(sparkProperties.getOrElse(key, null), key) + assertProperty[Int](key, "memory", Utils.memoryStringToMb) + + /** Assert that a Spark property can be converted to a certain type. */ + private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = { + sparkProperties.get(key).foreach { value => + Try(convert(value)).getOrElse { + throw new SubmitRestProtocolException( + s"Property '$key' expected $valueType value: actual was '$value'.") + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index d137cdc804baa..16dfe041d4bea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -17,17 +17,18 @@ package org.apache.spark.deploy.rest +import java.lang.Boolean + /** * An abstract response sent from the server in the REST application submission protocol. */ private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { var serverSparkVersion: String = null - var success: String = null + var success: Boolean = null var unknownFields: Array[String] = null protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(serverSparkVersion, "serverSparkVersion") - assertFieldIsBoolean(success, "success") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index f213372f3ab48..29aed89b67aa7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -62,7 +62,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with Bef val response = client.killSubmission(masterRestUrl, "submission-that-does-not-exist") val killResponse = getKillResponse(response) val killSuccess = killResponse.success - assert(killSuccess === "false") + assert(!killSuccess) } test("kill running submission") { @@ -78,8 +78,8 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with Bef val statusResponse = getStatusResponse(response2) val statusSuccess = statusResponse.success val driverState = statusResponse.driverState - assert(killSuccess === "true") - assert(statusSuccess === "true") + assert(killSuccess) + assert(statusSuccess) assert(driverState === DriverState.KILLED.toString) // we should not see the expected results because we killed the submission intercept[TestFailedException] { validateResult(resultsFile, numbers, size) } @@ -89,7 +89,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterAll with Bef val response = client.requestSubmissionStatus(masterRestUrl, "submission-that-does-not-exist") val statusResponse = getStatusResponse(response) val statusSuccess = statusResponse.success - assert(statusSuccess === "false") + assert(!statusSuccess) } /** diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 6cfe535a8a462..1d64ec201e647 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.deploy.rest +import java.lang.Boolean +import java.lang.Integer + import org.json4s.jackson.JsonMethods._ import org.scalatest.FunSuite @@ -34,9 +37,9 @@ class SubmitRestProtocolSuite extends FunSuite { intercept[SubmitRestProtocolException] { request.validate() } // missing name and age request.name = "something" intercept[SubmitRestProtocolException] { request.validate() } // missing only age - request.age = "2" + request.age = 2 intercept[SubmitRestProtocolException] { request.validate() } // age too low - request.age = "10" + request.age = 10 request.validate() // everything is set properly request.clientSparkVersion = null intercept[SubmitRestProtocolException] { request.validate() } // missing only Spark version @@ -47,38 +50,20 @@ class SubmitRestProtocolSuite extends FunSuite { intercept[SubmitRestProtocolException] { request.validate() } // still missing name } - test("validate with illegal argument") { - val request = new DummyRequest - request.clientSparkVersion = "1.2.3" - request.name = "abc" - request.age = "not-a-number" - intercept[SubmitRestProtocolException] { request.validate() } - request.age = "true" - intercept[SubmitRestProtocolException] { request.validate() } - request.age = "150" - request.validate() - request.active = "not-a-boolean" - intercept[SubmitRestProtocolException] { request.validate() } - request.active = "150" - intercept[SubmitRestProtocolException] { request.validate() } - request.active = "true" - request.validate() - } - test("request to and from JSON") { val request = new DummyRequest intercept[SubmitRestProtocolException] { request.toJson } // implicit validation request.clientSparkVersion = "1.2.3" - request.active = "true" - request.age = "25" + request.active = true + request.age = 25 request.name = "jung" val json = request.toJson assertJsonEquals(json, dummyRequestJson) val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest]) assert(newRequest.clientSparkVersion === "1.2.3") assert(newRequest.clientSparkVersion === "1.2.3") - assert(newRequest.active === "true") - assert(newRequest.age === "25") + assert(newRequest.active) + assert(newRequest.age === 25) assert(newRequest.name === "jung") assert(newRequest.message === null) } @@ -86,13 +71,13 @@ class SubmitRestProtocolSuite extends FunSuite { test("response to and from JSON") { val response = new DummyResponse response.serverSparkVersion = "3.3.4" - response.success = "true" + response.success = true val json = response.toJson assertJsonEquals(json, dummyResponseJson) val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse]) assert(newResponse.serverSparkVersion === "3.3.4") assert(newResponse.serverSparkVersion === "3.3.4") - assert(newResponse.success === "true") + assert(newResponse.success) assert(newResponse.message === null) } @@ -160,19 +145,15 @@ class SubmitRestProtocolSuite extends FunSuite { intercept[SubmitRestProtocolException] { message.validate() } message.serverSparkVersion = "1.2.3" message.submissionId = "driver_123" - message.success = "true" + message.success = true message.validate() - // bad fields - message.success = "maybe not" - intercept[SubmitRestProtocolException] { message.validate() } - message.success = "true" // test JSON val json = message.toJson assertJsonEquals(json, submitDriverResponseJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[CreateSubmissionResponse]) assert(newMessage.serverSparkVersion === "1.2.3") assert(newMessage.submissionId === "driver_123") - assert(newMessage.success === "true") + assert(newMessage.success) } test("KillSubmissionResponse") { @@ -180,19 +161,15 @@ class SubmitRestProtocolSuite extends FunSuite { intercept[SubmitRestProtocolException] { message.validate() } message.serverSparkVersion = "1.2.3" message.submissionId = "driver_123" - message.success = "true" + message.success = true message.validate() - // bad fields - message.success = "maybe not" - intercept[SubmitRestProtocolException] { message.validate() } - message.success = "true" // test JSON val json = message.toJson assertJsonEquals(json, killDriverResponseJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillSubmissionResponse]) assert(newMessage.serverSparkVersion === "1.2.3") assert(newMessage.submissionId === "driver_123") - assert(newMessage.success === "true") + assert(newMessage.success) } test("SubmissionStatusResponse") { @@ -200,16 +177,12 @@ class SubmitRestProtocolSuite extends FunSuite { intercept[SubmitRestProtocolException] { message.validate() } message.serverSparkVersion = "1.2.3" message.submissionId = "driver_123" - message.success = "true" + message.success = true message.validate() // optional fields message.driverState = "RUNNING" message.workerId = "worker_123" message.workerHostPort = "1.2.3.4:7780" - // bad fields - message.success = "maybe" - intercept[SubmitRestProtocolException] { message.validate() } - message.success = "true" // test JSON val json = message.toJson assertJsonEquals(json, driverStatusResponseJson) @@ -217,7 +190,7 @@ class SubmitRestProtocolSuite extends FunSuite { assert(newMessage.serverSparkVersion === "1.2.3") assert(newMessage.submissionId === "driver_123") assert(newMessage.driverState === "RUNNING") - assert(newMessage.success === "true") + assert(newMessage.success) assert(newMessage.workerId === "worker_123") assert(newMessage.workerHostPort === "1.2.3.4:7780") } @@ -240,8 +213,8 @@ class SubmitRestProtocolSuite extends FunSuite { """ |{ | "action" : "DummyRequest", - | "active" : "true", - | "age" : "25", + | "active" : true, + | "age" : 25, | "clientSparkVersion" : "1.2.3", | "name" : "jung" |} @@ -252,7 +225,7 @@ class SubmitRestProtocolSuite extends FunSuite { |{ | "action" : "DummyResponse", | "serverSparkVersion" : "3.3.4", - | "success": "true" + | "success": true |} """.stripMargin @@ -289,7 +262,7 @@ class SubmitRestProtocolSuite extends FunSuite { | "action" : "CreateSubmissionResponse", | "serverSparkVersion" : "1.2.3", | "submissionId" : "driver_123", - | "success" : "true" + | "success" : true |} """.stripMargin @@ -299,7 +272,7 @@ class SubmitRestProtocolSuite extends FunSuite { | "action" : "KillSubmissionResponse", | "serverSparkVersion" : "1.2.3", | "submissionId" : "driver_123", - | "success" : "true" + | "success" : true |} """.stripMargin @@ -310,7 +283,7 @@ class SubmitRestProtocolSuite extends FunSuite { | "driverState" : "RUNNING", | "serverSparkVersion" : "1.2.3", | "submissionId" : "driver_123", - | "success" : "true", + | "success" : true, | "workerHostPort" : "1.2.3.4:7780", | "workerId" : "worker_123" |} @@ -339,15 +312,13 @@ class SubmitRestProtocolSuite extends FunSuite { private class DummyResponse extends SubmitRestProtocolResponse private class DummyRequest extends SubmitRestProtocolRequest { - var active: String = null - var age: String = null + var active: Boolean = null + var age: Integer = null var name: String = null protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(name, "name") assertFieldIsSet(age, "age") - assertFieldIsBoolean(active, "active") - assertFieldIsNumeric(age, "age") - assert(age.toInt > 5, "Not old enough!") + assert(age > 5, "Not old enough!") } }