From df90e8b32ce017294cc0a47bcb78e118943662f9 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 29 Jan 2015 11:01:41 -0800 Subject: [PATCH] Use Jackson for JSON de/serialization This involves a major refactor of all message representations. The main motivation for this change is to simplify the logic to enforce type safety, such that we no longer depend on the behavior of all the scala class magic we used to rely on. This commit also introduces a differentiation between request and response messages to provide further type safety. This would have introduced much additional complexity without the refactor. --- .../deploy/rest/DriverStatusRequest.scala | 31 + .../rest/DriverStatusRequestMessage.scala | 47 -- .../deploy/rest/DriverStatusResponse.scala | 45 ++ .../rest/DriverStatusResponseMessage.scala | 52 -- .../spark/deploy/rest/ErrorMessage.scala | 44 -- .../spark/deploy/rest/ErrorResponse.scala | 26 + .../spark/deploy/rest/KillDriverRequest.scala | 31 + .../rest/KillDriverRequestMessage.scala | 47 -- .../deploy/rest/KillDriverResponse.scala | 36 ++ .../rest/KillDriverResponseMessage.scala | 48 -- .../deploy/rest/StandaloneRestClient.scala | 79 ++- .../deploy/rest/StandaloneRestServer.scala | 95 ++- .../deploy/rest/SubmitDriverRequest.scala | 131 ++++ .../rest/SubmitDriverRequestMessage.scala | 155 ----- .../deploy/rest/SubmitDriverResponse.scala | 35 ++ .../rest/SubmitDriverResponseMessage.scala | 48 -- .../spark/deploy/rest/SubmitRestClient.scala | 22 +- .../deploy/rest/SubmitRestProtocolField.scala | 89 +-- .../rest/SubmitRestProtocolMessage.scala | 294 ++++----- .../spark/deploy/rest/SubmitRestServer.scala | 26 +- .../org/apache/spark/util/JsonProtocol.scala | 9 + .../rest/StandaloneRestProtocolSuite.scala | 14 +- .../deploy/rest/SubmitRestProtocolSuite.scala | 577 +++++++++--------- 23 files changed, 910 insertions(+), 1071 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala new file mode 100644 index 0000000000000..f5d4d95cebf14 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class DriverStatusRequest extends SubmitRestProtocolRequest { + protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_REQUEST + private val driverId = new SubmitRestProtocolField[String] + + def getDriverId: String = driverId.toString + def setDriverId(s: String): this.type = setField(driverId, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(driverId, "driver_id") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala deleted file mode 100644 index f0d0c5f874d57..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a DriverStatusRequestMessage. - */ -private[spark] abstract class DriverStatusRequestField extends SubmitRestProtocolField -private[spark] object DriverStatusRequestField - extends SubmitRestProtocolFieldCompanion[DriverStatusRequestField] { - case object ACTION extends DriverStatusRequestField with ActionField - case object CLIENT_SPARK_VERSION extends DriverStatusRequestField - case object MESSAGE extends DriverStatusRequestField - case object DRIVER_ID extends DriverStatusRequestField - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, DRIVER_ID) - override val optionalFields = Seq(MESSAGE) -} - -/** - * A request sent to the cluster manager to query the status of a driver - * in the stable application submission REST protocol. - */ -private[spark] class DriverStatusRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.DRIVER_STATUS_REQUEST, - DriverStatusRequestField.ACTION, - DriverStatusRequestField.requiredFields) - -private[spark] object DriverStatusRequestMessage - extends SubmitRestProtocolMessageCompanion[DriverStatusRequestMessage] { - protected override def newMessage() = new DriverStatusRequestMessage - protected override def fieldFromString(f: String) = DriverStatusRequestField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala new file mode 100644 index 0000000000000..1e8090c336812 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class DriverStatusResponse extends SubmitRestProtocolResponse { + protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE + private val driverId = new SubmitRestProtocolField[String] + private val success = new SubmitRestProtocolField[Boolean] + private val driverState = new SubmitRestProtocolField[String] + private val workerId = new SubmitRestProtocolField[String] + private val workerHostPort = new SubmitRestProtocolField[String] + + def getDriverId: String = driverId.toString + def getSuccess: String = success.toString + def getDriverState: String = driverState.toString + def getWorkerId: String = workerId.toString + def getWorkerHostPort: String = workerHostPort.toString + + def setDriverId(s: String): this.type = setField(driverId, s) + def setSuccess(s: String): this.type = setBooleanField(success, s) + def setDriverState(s: String): this.type = setField(driverState, s) + def setWorkerId(s: String): this.type = setField(workerId, s) + def setWorkerHostPort(s: String): this.type = setField(workerHostPort, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(driverId, "driver_id") + assertFieldIsSet(success, "success") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala deleted file mode 100644 index d651452485055..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a DriverStatusResponseMessage. - */ -private[spark] abstract class DriverStatusResponseField extends SubmitRestProtocolField -private[spark] object DriverStatusResponseField - extends SubmitRestProtocolFieldCompanion[DriverStatusResponseField] { - case object ACTION extends DriverStatusResponseField with ActionField - case object SERVER_SPARK_VERSION extends DriverStatusResponseField - case object MESSAGE extends DriverStatusResponseField - case object DRIVER_ID extends DriverStatusResponseField - case object SUCCESS extends DriverStatusResponseField with BooleanField - // Standalone specific fields - case object DRIVER_STATE extends DriverStatusResponseField - case object WORKER_ID extends DriverStatusResponseField - case object WORKER_HOST_PORT extends DriverStatusResponseField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, DRIVER_ID, SUCCESS) - override val optionalFields = Seq(MESSAGE, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) -} - -/** - * A message sent from the cluster manager in response to a DriverStatusRequestMessage - * in the stable application submission REST protocol. - */ -private[spark] class DriverStatusResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE, - DriverStatusResponseField.ACTION, - DriverStatusResponseField.requiredFields) - -private[spark] object DriverStatusResponseMessage - extends SubmitRestProtocolMessageCompanion[DriverStatusResponseMessage] { - protected override def newMessage() = new DriverStatusResponseMessage - protected override def fieldFromString(f: String) = DriverStatusResponseField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala deleted file mode 100644 index f1fbdd227507c..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in an ErrorMessage. - */ -private[spark] abstract class ErrorField extends SubmitRestProtocolField -private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion[ErrorField] { - case object ACTION extends ErrorField with ActionField - case object SERVER_SPARK_VERSION extends ErrorField - case object MESSAGE extends ErrorField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE) - override val optionalFields = Seq.empty -} - -/** - * An error message sent from the cluster manager - * in the stable application submission REST protocol. - */ -private[spark] class ErrorMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.ERROR, - ErrorField.ACTION, - ErrorField.requiredFields) - -private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion[ErrorMessage] { - protected override def newMessage() = new ErrorMessage - protected override def fieldFromString(f: String) = ErrorField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala new file mode 100644 index 0000000000000..8c30d31850880 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class ErrorResponse extends SubmitRestProtocolResponse { + protected override val action = SubmitRestProtocolAction.ERROR + override def validate(): Unit = { + super.validate() + assertFieldIsSet(message, "message") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala new file mode 100644 index 0000000000000..c44c94d95a1fc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class KillDriverRequest extends SubmitRestProtocolRequest { + protected override val action = SubmitRestProtocolAction.KILL_DRIVER_REQUEST + private val driverId = new SubmitRestProtocolField[String] + + def getDriverId: String = driverId.toString + def setDriverId(s: String): this.type = setField(driverId, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(driverId, "driver_id") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala deleted file mode 100644 index 232bb364e8899..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a KillDriverRequestMessage. - */ -private[spark] abstract class KillDriverRequestField extends SubmitRestProtocolField -private[spark] object KillDriverRequestField - extends SubmitRestProtocolFieldCompanion[KillDriverRequestField] { - case object ACTION extends KillDriverRequestField with ActionField - case object CLIENT_SPARK_VERSION extends KillDriverRequestField - case object MESSAGE extends KillDriverRequestField - case object DRIVER_ID extends KillDriverRequestField - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, DRIVER_ID) - override val optionalFields = Seq(MESSAGE) -} - -/** - * A request sent to the cluster manager to kill a driver - * in the stable application submission REST protocol. - */ -private[spark] class KillDriverRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.KILL_DRIVER_REQUEST, - KillDriverRequestField.ACTION, - KillDriverRequestField.requiredFields) - -private[spark] object KillDriverRequestMessage - extends SubmitRestProtocolMessageCompanion[KillDriverRequestMessage] { - protected override def newMessage() = new KillDriverRequestMessage - protected override def fieldFromString(f: String) = KillDriverRequestField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala new file mode 100644 index 0000000000000..e75a52bc9bf0b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class KillDriverResponse extends SubmitRestProtocolResponse { + protected override val action = SubmitRestProtocolAction.KILL_DRIVER_RESPONSE + private val driverId = new SubmitRestProtocolField[String] + private val success = new SubmitRestProtocolField[Boolean] + + def getDriverId: String = driverId.toString + def getSuccess: String = success.toString + + def setDriverId(s: String): this.type = setField(driverId, s) + def setSuccess(s: String): this.type = setBooleanField(success, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(driverId, "driver_id") + assertFieldIsSet(success, "success") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala deleted file mode 100644 index 0717131ab2ec0..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a KillDriverResponseMessage. - */ -private[spark] abstract class KillDriverResponseField extends SubmitRestProtocolField -private[spark] object KillDriverResponseField - extends SubmitRestProtocolFieldCompanion[KillDriverResponseField] { - case object ACTION extends KillDriverResponseField with ActionField - case object SERVER_SPARK_VERSION extends KillDriverResponseField - case object MESSAGE extends KillDriverResponseField - case object DRIVER_ID extends KillDriverResponseField - case object SUCCESS extends KillDriverResponseField with BooleanField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, DRIVER_ID, SUCCESS) - override val optionalFields = Seq(MESSAGE) -} - -/** - * A message sent from the cluster manager in response to a KillDriverRequestMessage - * in the stable application submission REST protocol. - */ -private[spark] class KillDriverResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.KILL_DRIVER_RESPONSE, - KillDriverResponseField.ACTION, - KillDriverResponseField.requiredFields) - -private[spark] object KillDriverResponseMessage - extends SubmitRestProtocolMessageCompanion[KillDriverResponseMessage] { - protected override def newMessage() = new KillDriverResponseMessage - protected override def fieldFromString(f: String) = KillDriverResponseField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 278c9af749b14..b564006fd7457 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -37,16 +37,15 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { * just submitted and reports it to the user. Otherwise, if the submission was unsuccessful, * this reports failure and logs an error message provided by the REST server. */ - override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponseMessage = { - import SubmitDriverResponseField._ - val submitResponse = super.submitDriver(args).asInstanceOf[SubmitDriverResponseMessage] - val submitSuccess = submitResponse.getFieldNotNull(SUCCESS).toBoolean + override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { + val submitResponse = super.submitDriver(args).asInstanceOf[SubmitDriverResponse] + val submitSuccess = submitResponse.getSuccess.toBoolean if (submitSuccess) { - val driverId = submitResponse.getFieldNotNull(DRIVER_ID) + val driverId = submitResponse.getDriverId logInfo(s"Driver successfully submitted as $driverId. Polling driver state...") pollSubmittedDriverStatus(args.master, driverId) } else { - val submitMessage = submitResponse.getFieldNotNull(MESSAGE) + val submitMessage = submitResponse.getMessage logError(s"Application submission failed: $submitMessage") } submitResponse @@ -57,16 +56,15 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { * This retries up to a fixed number of times until giving up. */ private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { - import DriverStatusResponseField._ (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => val statusResponse = requestDriverStatus(master, driverId) - .asInstanceOf[DriverStatusResponseMessage] - val statusSuccess = statusResponse.getFieldNotNull(SUCCESS).toBoolean + .asInstanceOf[DriverStatusResponse] + val statusSuccess = statusResponse.getSuccess.toBoolean if (statusSuccess) { - val driverState = statusResponse.getFieldNotNull(DRIVER_STATE) - val workerId = statusResponse.getFieldOption(WORKER_ID) - val workerHostPort = statusResponse.getFieldOption(WORKER_HOST_PORT) - val exception = statusResponse.getFieldOption(MESSAGE) + val driverState = statusResponse.getDriverState + val workerId = Option(statusResponse.getWorkerId) + val workerHostPort = Option(statusResponse.getWorkerHostPort) + val exception = Option(statusResponse.getMessage) logInfo(s"State of driver $driverId is now $driverState.") // Log worker node, if present (workerId, workerHostPort) match { @@ -83,26 +81,23 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { /** Construct a submit driver request message. */ override protected def constructSubmitRequest( - args: SparkSubmitArguments): SubmitDriverRequestMessage = { - import SubmitDriverRequestField._ - val dm = Option(args.driverMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull - val em = Option(args.executorMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull - val message = new SubmitDriverRequestMessage() - .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(APP_NAME, args.name) - .setField(APP_RESOURCE, args.primaryResource) - .setFieldIfNotNull(MAIN_CLASS, args.mainClass) - .setFieldIfNotNull(JARS, args.jars) - .setFieldIfNotNull(FILES, args.files) - .setFieldIfNotNull(DRIVER_MEMORY, dm) - .setFieldIfNotNull(DRIVER_CORES, args.driverCores) - .setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions) - .setFieldIfNotNull(DRIVER_EXTRA_CLASS_PATH, args.driverExtraClassPath) - .setFieldIfNotNull(DRIVER_EXTRA_LIBRARY_PATH, args.driverExtraLibraryPath) - .setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString) - .setFieldIfNotNull(EXECUTOR_MEMORY, em) - .setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores) - args.childArgs.foreach(message.appendAppArg) + args: SparkSubmitArguments): SubmitDriverRequest = { + val message = new SubmitDriverRequest() + .setSparkVersion(sparkVersion) + .setAppName(args.name) + .setAppResource(args.primaryResource) + .setMainClass(args.mainClass) + .setJars(args.jars) + .setFiles(args.files) + .setDriverMemory(args.driverMemory) + .setDriverCores(args.driverCores) + .setDriverExtraJavaOptions(args.driverExtraJavaOptions) + .setDriverExtraClassPath(args.driverExtraClassPath) + .setDriverExtraLibraryPath(args.driverExtraLibraryPath) + .setSuperviseDriver(args.supervise.toString) + .setExecutorMemory(args.executorMemory) + .setTotalExecutorCores(args.totalExecutorCores) + args.childArgs.foreach(message.addAppArg) args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } // TODO: send special environment variables? message @@ -111,21 +106,19 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { /** Construct a kill driver request message. */ override protected def constructKillRequest( master: String, - driverId: String): KillDriverRequestMessage = { - import KillDriverRequestField._ - new KillDriverRequestMessage() - .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(DRIVER_ID, driverId) + driverId: String): KillDriverRequest = { + new KillDriverRequest() + .setSparkVersion(sparkVersion) + .setDriverId(driverId) } /** Construct a driver status request message. */ override protected def constructStatusRequest( master: String, - driverId: String): DriverStatusRequestMessage = { - import DriverStatusRequestField._ - new DriverStatusRequestMessage() - .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(DRIVER_ID, driverId) + driverId: String): DriverStatusRequest = { + new DriverStatusRequest() + .setSparkVersion(sparkVersion) + .setDriverId(driverId) } /** Throw an exception if this is not standalone mode. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index eb6065ff16c4b..3fcfe189c6a10 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkConf import org.apache.spark.util.{AkkaUtils, Utils} import org.apache.spark.deploy.{Command, DriverDescription} import org.apache.spark.deploy.ClientArguments._ -import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.DeployMessages import org.apache.spark.deploy.master.Master /** @@ -56,52 +56,49 @@ private[spark] class StandaloneRestServerHandler( /** Handle a request to submit a driver. */ override protected def handleSubmit( - request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = { - import SubmitDriverResponseField._ + request: SubmitDriverRequest): SubmitDriverResponse = { val driverDescription = buildDriverDescription(request) - val response = AkkaUtils.askWithReply[SubmitDriverResponse]( - RequestSubmitDriver(driverDescription), masterActor, askTimeout) - new SubmitDriverResponseMessage() - .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(MESSAGE, response.message) - .setField(SUCCESS, response.success.toString) - .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + new SubmitDriverResponse() + .setSparkVersion(sparkVersion) + .setMessage(response.message) + .setSuccess(response.success.toString) + .setDriverId(response.driverId.orNull) } /** Handle a request to kill a driver. */ override protected def handleKill( - request: KillDriverRequestMessage): KillDriverResponseMessage = { - import KillDriverResponseField._ - val driverId = request.getFieldNotNull(KillDriverRequestField.DRIVER_ID) - val response = AkkaUtils.askWithReply[KillDriverResponse]( - RequestKillDriver(driverId), masterActor, askTimeout) - new KillDriverResponseMessage() - .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(MESSAGE, response.message) - .setField(DRIVER_ID, driverId) - .setField(SUCCESS, response.success.toString) + request: KillDriverRequest): KillDriverResponse = { + val driverId = request.getDriverId + val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(driverId), masterActor, askTimeout) + new KillDriverResponse() + .setSparkVersion(sparkVersion) + .setMessage(response.message) + .setDriverId(driverId) + .setSuccess(response.success.toString) } /** Handle a request for a driver's status. */ override protected def handleStatus( - request: DriverStatusRequestMessage): DriverStatusResponseMessage = { - import DriverStatusResponseField._ - val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) - val response = AkkaUtils.askWithReply[DriverStatusResponse]( - RequestDriverStatus(driverId), masterActor, askTimeout) + request: DriverStatusRequest): DriverStatusResponse = { + val driverId = request.getDriverId + val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(driverId), masterActor, askTimeout) // Format exception nicely, if it exists val message = response.exception.map { e => val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") s"Exception from the cluster:\n$e\n$stackTraceString" } - new DriverStatusResponseMessage() - .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(DRIVER_ID, driverId) - .setField(SUCCESS, response.found.toString) - .setFieldIfNotNull(DRIVER_STATE, response.state.map(_.toString).orNull) - .setFieldIfNotNull(WORKER_ID, response.workerId.orNull) - .setFieldIfNotNull(WORKER_HOST_PORT, response.workerHostPort.orNull) - .setFieldIfNotNull(MESSAGE, message.orNull) + new DriverStatusResponse() + .setSparkVersion(sparkVersion) + .setDriverId(driverId) + .setSuccess(response.found.toString) + .setDriverState(response.state.map(_.toString).orNull) + .setWorkerId(response.workerId.orNull) + .setWorkerHostPort(response.workerHostPort.orNull) + .setMessage(message.orNull) } /** @@ -109,25 +106,23 @@ private[spark] class StandaloneRestServerHandler( * This does not currently consider fields used by python applications since * python is not supported in standalone cluster mode yet. */ - private def buildDriverDescription(request: SubmitDriverRequestMessage): DriverDescription = { - import SubmitDriverRequestField._ - + private def buildDriverDescription(request: SubmitDriverRequest): DriverDescription = { // Required fields, including the main class because python is not yet supported - val appName = request.getFieldNotNull(APP_NAME) - val appResource = request.getFieldNotNull(APP_RESOURCE) - val mainClass = request.getFieldNotNull(MAIN_CLASS) + val appName = request.getAppName + val appResource = request.getAppResource + val mainClass = request.getMainClass // Optional fields - val jars = request.getFieldOption(JARS) - val files = request.getFieldOption(FILES) - val driverMemory = request.getFieldOption(DRIVER_MEMORY) - val driverCores = request.getFieldOption(DRIVER_CORES) - val driverExtraJavaOptions = request.getFieldOption(DRIVER_EXTRA_JAVA_OPTIONS) - val driverExtraClassPath = request.getFieldOption(DRIVER_EXTRA_CLASS_PATH) - val driverExtraLibraryPath = request.getFieldOption(DRIVER_EXTRA_LIBRARY_PATH) - val superviseDriver = request.getFieldOption(SUPERVISE_DRIVER) - val executorMemory = request.getFieldOption(EXECUTOR_MEMORY) - val totalExecutorCores = request.getFieldOption(TOTAL_EXECUTOR_CORES) + val jars = Option(request.getJars) + val files = Option(request.getFiles) + val driverMemory = Option(request.getDriverMemory) + val driverCores = Option(request.getDriverCores) + val driverExtraJavaOptions = Option(request.getDriverExtraJavaOptions) + val driverExtraClassPath = Option(request.getDriverExtraClassPath) + val driverExtraLibraryPath = Option(request.getDriverExtraLibraryPath) + val superviseDriver = Option(request.getSuperviseDriver) + val executorMemory = Option(request.getExecutorMemory) + val totalExecutorCores = Option(request.getTotalExecutorCores) val appArgs = request.getAppArgs val sparkProperties = request.getSparkProperties val environmentVariables = request.getEnvironmentVariables @@ -155,7 +150,7 @@ private[spark] class StandaloneRestServerHandler( "org.apache.spark.deploy.worker.DriverWrapper", Seq("{{WORKER_URL}}", mainClass) ++ appArgs, // args to the DriverWrapper environmentVariables, extraClassPath, extraLibraryPath, javaOpts) - val actualDriverMemory = driverMemory.map(_.toInt).getOrElse(DEFAULT_MEMORY) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) new DriverDescription( diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala new file mode 100644 index 0000000000000..9bde3345d03fa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.fasterxml.jackson.annotation.{JsonIgnore, JsonProperty} +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.util.JsonProtocol + +class SubmitDriverRequest extends SubmitRestProtocolRequest { + protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST + private val appName = new SubmitRestProtocolField[String] + private val appResource = new SubmitRestProtocolField[String] + private val mainClass = new SubmitRestProtocolField[String] + private val jars = new SubmitRestProtocolField[String] + private val files = new SubmitRestProtocolField[String] + private val pyFiles = new SubmitRestProtocolField[String] + private val driverMemory = new SubmitRestProtocolField[String] + private val driverCores = new SubmitRestProtocolField[Int] + private val driverExtraJavaOptions = new SubmitRestProtocolField[String] + private val driverExtraClassPath = new SubmitRestProtocolField[String] + private val driverExtraLibraryPath = new SubmitRestProtocolField[String] + private val superviseDriver = new SubmitRestProtocolField[Boolean] + private val executorMemory = new SubmitRestProtocolField[String] + private val totalExecutorCores = new SubmitRestProtocolField[Int] + + // Special fields + private val appArgs = new ArrayBuffer[String] + private val sparkProperties = new mutable.HashMap[String, String] + private val envVars = new mutable.HashMap[String, String] + + def getAppName: String = appName.toString + def getAppResource: String = appResource.toString + def getMainClass: String = mainClass.toString + def getJars: String = jars.toString + def getFiles: String = files.toString + def getPyFiles: String = pyFiles.toString + def getDriverMemory: String = driverMemory.toString + def getDriverCores: String = driverCores.toString + def getDriverExtraJavaOptions: String = driverExtraJavaOptions.toString + def getDriverExtraClassPath: String = driverExtraClassPath.toString + def getDriverExtraLibraryPath: String = driverExtraLibraryPath.toString + def getSuperviseDriver: String = superviseDriver.toString + def getExecutorMemory: String = executorMemory.toString + def getTotalExecutorCores: String = totalExecutorCores.toString + + // Special getters required for JSON de/serialization + @JsonProperty("appArgs") + private def getAppArgsJson: String = arrayToJson(getAppArgs) + @JsonProperty("sparkProperties") + private def getSparkPropertiesJson: String = mapToJson(getSparkProperties) + @JsonProperty("environmentVariables") + private def getEnvironmentVariablesJson: String = mapToJson(getEnvironmentVariables) + + def setAppName(s: String): this.type = setField(appName, s) + def setAppResource(s: String): this.type = setField(appResource, s) + def setMainClass(s: String): this.type = setField(mainClass, s) + def setJars(s: String): this.type = setField(jars, s) + def setFiles(s: String): this.type = setField(files, s) + def setPyFiles(s: String): this.type = setField(pyFiles, s) + def setDriverMemory(s: String): this.type = setField(driverMemory, s) + def setDriverCores(s: String): this.type = setNumericField(driverCores, s) + def setDriverExtraJavaOptions(s: String): this.type = setField(driverExtraJavaOptions, s) + def setDriverExtraClassPath(s: String): this.type = setField(driverExtraClassPath, s) + def setDriverExtraLibraryPath(s: String): this.type = setField(driverExtraLibraryPath, s) + def setSuperviseDriver(s: String): this.type = setBooleanField(superviseDriver, s) + def setExecutorMemory(s: String): this.type = setField(executorMemory, s) + def setTotalExecutorCores(s: String): this.type = setNumericField(totalExecutorCores, s) + + // Special setters required for JSON de/serialization + @JsonProperty("appArgs") + private def setAppArgsJson(s: String): Unit = { + appArgs.clear() + appArgs ++= JsonProtocol.arrayFromJson(parse(s)) + } + @JsonProperty("sparkProperties") + private def setSparkPropertiesJson(s: String): Unit = { + sparkProperties.clear() + sparkProperties ++= JsonProtocol.mapFromJson(parse(s)) + } + @JsonProperty("environmentVariables") + private def setEnvironmentVariablesJson(s: String): Unit = { + envVars.clear() + envVars ++= JsonProtocol.mapFromJson(parse(s)) + } + + @JsonIgnore + def getAppArgs: Array[String] = appArgs.toArray + @JsonIgnore + def getSparkProperties: Map[String, String] = sparkProperties.toMap + @JsonIgnore + def getEnvironmentVariables: Map[String, String] = envVars.toMap + @JsonIgnore + def addAppArg(s: String): this.type = { appArgs += s; this } + @JsonIgnore + def setSparkProperty(k: String, v: String): this.type = { sparkProperties(k) = v; this } + @JsonIgnore + def setEnvironmentVariable(k: String, v: String): this.type = { envVars(k) = v; this } + + private def arrayToJson(arr: Array[String]): String = { + if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else { null } + } + + private def mapToJson(map: Map[String, String]): String = { + if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else { null } + } + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(appName, "app_name") + assertFieldIsSet(appResource, "app_resource") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala deleted file mode 100644 index 90d7e408fefc1..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.json4s.JsonAST._ - -import org.apache.spark.util.JsonProtocol - -/** - * A field used in a SubmitDriverRequestMessage. - */ -private[spark] abstract class SubmitDriverRequestField extends SubmitRestProtocolField -private[spark] object SubmitDriverRequestField - extends SubmitRestProtocolFieldCompanion[SubmitDriverRequestField] { - case object ACTION extends SubmitDriverRequestField with ActionField - case object CLIENT_SPARK_VERSION extends SubmitDriverRequestField - case object MESSAGE extends SubmitDriverRequestField - case object APP_NAME extends SubmitDriverRequestField - case object APP_RESOURCE extends SubmitDriverRequestField - case object MAIN_CLASS extends SubmitDriverRequestField - case object JARS extends SubmitDriverRequestField - case object FILES extends SubmitDriverRequestField - case object PY_FILES extends SubmitDriverRequestField - case object DRIVER_MEMORY extends SubmitDriverRequestField with MemoryField - case object DRIVER_CORES extends SubmitDriverRequestField with NumericField - case object DRIVER_EXTRA_JAVA_OPTIONS extends SubmitDriverRequestField - case object DRIVER_EXTRA_CLASS_PATH extends SubmitDriverRequestField - case object DRIVER_EXTRA_LIBRARY_PATH extends SubmitDriverRequestField - case object SUPERVISE_DRIVER extends SubmitDriverRequestField with BooleanField - case object EXECUTOR_MEMORY extends SubmitDriverRequestField with MemoryField - case object TOTAL_EXECUTOR_CORES extends SubmitDriverRequestField with NumericField - - // Special fields that should not be set directly - case object APP_ARGS extends SubmitDriverRequestField { - override def validateValue(v: String): Unit = { - validateFailed(v, "Use message.appendAppArg(arg) instead") - } - } - case object SPARK_PROPERTIES extends SubmitDriverRequestField { - override def validateValue(v: String): Unit = { - validateFailed(v, "Use message.setSparkProperty(k, v) instead") - } - } - case object ENVIRONMENT_VARIABLES extends SubmitDriverRequestField { - override def validateValue(v: String): Unit = { - validateFailed(v, "Use message.setEnvironmentVariable(k, v) instead") - } - } - - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, APP_NAME, APP_RESOURCE) - override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, DRIVER_MEMORY, - DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH, - SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES, APP_ARGS, SPARK_PROPERTIES, - ENVIRONMENT_VARIABLES) -} - -/** - * A request sent to the cluster manager to submit a driver - * in the stable application submission REST protocol. - */ -private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST, - SubmitDriverRequestField.ACTION, - SubmitDriverRequestField.requiredFields) { - - import SubmitDriverRequestField._ - - private val appArgs = new ArrayBuffer[String] - private val sparkProperties = new mutable.HashMap[String, String] - private val environmentVariables = new mutable.HashMap[String, String] - - // Setters for special fields - def appendAppArg(arg: String): Unit = { appArgs += arg } - def setSparkProperty(k: String, v: String): Unit = { sparkProperties(k) = v } - def setEnvironmentVariable(k: String, v: String): Unit = { environmentVariables(k) = v } - - // Getters for special fields - def getAppArgs: Seq[String] = appArgs.clone() - def getSparkProperties: Map[String, String] = sparkProperties.toMap - def getEnvironmentVariables: Map[String, String] = environmentVariables.toMap - - // Include app args, spark properties, and environment variables in the JSON object - // The order imposed here is as follows: * < APP_ARGS < SPARK_PROPERTIES < ENVIRONMENT_VARIABLES - override def toJsonObject: JObject = { - val otherFields = super.toJsonObject.obj - val appArgsJson = JArray(appArgs.map(JString).toList) - val sparkPropertiesJson = JsonProtocol.mapToJson(sparkProperties) - val environmentVariablesJson = JsonProtocol.mapToJson(environmentVariables) - val jsonFields = new ArrayBuffer[JField] - jsonFields ++= otherFields - if (appArgs.nonEmpty) { - jsonFields += JField(APP_ARGS.toString, appArgsJson) - } - if (sparkProperties.nonEmpty) { - jsonFields += JField(SPARK_PROPERTIES.toString, sparkPropertiesJson) - } - if (environmentVariables.nonEmpty) { - jsonFields += JField(ENVIRONMENT_VARIABLES.toString, environmentVariablesJson) - } - JObject(jsonFields.toList) - } -} - -private[spark] object SubmitDriverRequestMessage - extends SubmitRestProtocolMessageCompanion[SubmitDriverRequestMessage] { - - import SubmitDriverRequestField._ - - protected override def newMessage() = new SubmitDriverRequestMessage - protected override def fieldFromString(f: String) = SubmitDriverRequestField.fromString(f) - - /** - * Process the given field and value appropriately based on the type of the field. - * This handles certain nested values in addition to flat values. - */ - override def handleField( - message: SubmitDriverRequestMessage, - field: SubmitRestProtocolField, - value: JValue): Unit = { - (field, value) match { - case (APP_ARGS, JArray(args)) => - args.map(_.asInstanceOf[JString].s).foreach { arg => - message.appendAppArg(arg) - } - case (SPARK_PROPERTIES, props: JObject) => - JsonProtocol.mapFromJson(props).foreach { case (k, v) => - message.setSparkProperty(k, v) - } - case (ENVIRONMENT_VARIABLES, envVars: JObject) => - JsonProtocol.mapFromJson(envVars).foreach { case (envKey, envValue) => - message.setEnvironmentVariable(envKey, envValue) - } - // All other fields are assumed to have flat values - case _ => super.handleField(message, field, value) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala new file mode 100644 index 0000000000000..8a1676767cec9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class SubmitDriverResponse extends SubmitRestProtocolResponse { + protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE + private val success = new SubmitRestProtocolField[Boolean] + private val driverId = new SubmitRestProtocolField[String] + + def getSuccess: String = success.toString + def getDriverId: String = driverId.toString + + def setSuccess(s: String): this.type = setBooleanField(success, s) + def setDriverId(s: String): this.type = setField(driverId, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(success, "success") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala deleted file mode 100644 index d5a2e1660eb0b..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a SubmitDriverResponseMessage. - */ -private[spark] abstract class SubmitDriverResponseField extends SubmitRestProtocolField -private[spark] object SubmitDriverResponseField - extends SubmitRestProtocolFieldCompanion[SubmitDriverResponseField] { - case object ACTION extends SubmitDriverResponseField with ActionField - case object SERVER_SPARK_VERSION extends SubmitDriverResponseField - case object MESSAGE extends SubmitDriverResponseField - case object SUCCESS extends SubmitDriverResponseField with BooleanField - case object DRIVER_ID extends SubmitDriverResponseField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, SUCCESS) - override val optionalFields = Seq(DRIVER_ID) -} - -/** - * A message sent from the cluster manager in response to a SubmitDriverRequestMessage - * in the stable application submission REST protocol. - */ -private[spark] class SubmitDriverResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE, - SubmitDriverResponseField.ACTION, - SubmitDriverResponseField.requiredFields) - -private[spark] object SubmitDriverResponseMessage - extends SubmitRestProtocolMessageCompanion[SubmitDriverResponseMessage] { - protected override def newMessage() = new SubmitDriverResponseMessage - protected override def fieldFromString(f: String) = SubmitDriverResponseField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index 513c17deee89c..eb258290bdc7b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -34,39 +34,39 @@ import org.apache.spark.deploy.SparkSubmitArguments private[spark] abstract class SubmitRestClient extends Logging { /** Request that the REST server submit a driver specified by the provided arguments. */ - def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolMessage = { + def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { validateSubmitArguments(args) val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) logInfo(s"Submitting a request to launch a driver in ${args.master}.") - sendHttp(url, request) + sendHttp(url, request).asInstanceOf[SubmitDriverResponse] } /** Request that the REST server kill the specified driver. */ - def killDriver(master: String, driverId: String): SubmitRestProtocolMessage = { + def killDriver(master: String, driverId: String): KillDriverResponse = { validateMaster(master) val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) logInfo(s"Submitting a request to kill driver $driverId in $master.") - sendHttp(url, request) + sendHttp(url, request).asInstanceOf[KillDriverResponse] } /** Request the status of the specified driver from the REST server. */ - def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolMessage = { + def requestDriverStatus(master: String, driverId: String): DriverStatusResponse = { validateMaster(master) val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) logInfo(s"Submitting a request for the status of driver $driverId in $master.") - sendHttp(url, request) + sendHttp(url, request).asInstanceOf[DriverStatusResponse] } /** Return the HTTP URL of the REST server that corresponds to the given master URL. */ protected def getHttpUrl(master: String): URL // Construct the appropriate type of message based on the request type - protected def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage - protected def constructKillRequest(master: String, driverId: String): KillDriverRequestMessage - protected def constructStatusRequest(master: String, driverId: String): DriverStatusRequestMessage + protected def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequest + protected def constructKillRequest(master: String, driverId: String): KillDriverRequest + protected def constructStatusRequest(master: String, driverId: String): DriverStatusRequest // If the provided arguments are not as expected, throw an exception protected def validateMaster(master: String): Unit @@ -81,7 +81,7 @@ private[spark] abstract class SubmitRestClient extends Logging { * This assumes both the request and the response use the JSON format. * Return the response received from the REST server. */ - private def sendHttp(url: URL, request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { + private def sendHttp(url: URL, request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { try { val conn = url.openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod("POST") @@ -96,7 +96,7 @@ private[spark] abstract class SubmitRestClient extends Logging { out.close() val responseJson = Source.fromInputStream(conn.getInputStream).mkString logDebug(s"Response from the REST server:\n$responseJson") - SubmitRestProtocolMessage.fromJson(responseJson) + SubmitRestProtocolResponse.fromJson(responseJson) } catch { case e: FileNotFoundException => throw new SparkException(s"Unable to connect to REST server $url", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala index 639e00d912e7f..4c0c45b450faf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala @@ -17,68 +17,11 @@ package org.apache.spark.deploy.rest -import scala.collection.Map -import scala.util.Try - -import org.apache.spark.util.Utils - -/** - * A field used in a SubmitRestProtocolMessage. - * There are a few special fields: - * - ACTION entirely specifies the type of the message and is required in all messages - * - MESSAGE contains arbitrary messages and is common, but not required, in all messages - * - CLIENT_SPARK_VERSION is required in all messages sent from the client - * - SERVER_SPARK_VERSION is required in all messages sent from the server - */ -private[spark] abstract class SubmitRestProtocolField { - protected val fieldName = Utils.getFormattedClassName(this) - def validateValue(value: String): Unit = { } - def validateFailed(v: String, msg: String): Unit = { - throw new IllegalArgumentException(s"Detected setting of $fieldName to $v: $msg") - } -} -private[spark] object SubmitRestProtocolField { - def isActionField(field: String): Boolean = field == "ACTION" -} - -/** A field that should accept only boolean values. */ -private[spark] trait BooleanField extends SubmitRestProtocolField { - override def validateValue(v: String): Unit = { - Try(v.toBoolean).getOrElse { validateFailed(v, s"Error parsing $v as a boolean!") } - } -} - -/** A field that should accept only numeric values. */ -private[spark] trait NumericField extends SubmitRestProtocolField { - override def validateValue(v: String): Unit = { - Try(v.toInt).getOrElse { validateFailed(v, s"Error parsing $v as an integer!") } - } -} - -/** A field that should accept only memory values. */ -private[spark] trait MemoryField extends SubmitRestProtocolField { - override def validateValue(v: String): Unit = { - Try(Utils.memoryStringToMb(v)).getOrElse { - validateFailed(v, s"Error parsing $v as a memory string!") - } - } -} - -/** - * The main action field in every message. - * This should be set only on message instantiation. - */ -private[spark] trait ActionField extends SubmitRestProtocolField { - override def validateValue(v: String): Unit = { - validateFailed(v, "The ACTION field must not be set directly after instantiation.") - } -} - /** * All possible values of the ACTION field in a SubmitRestProtocolMessage. */ -private[spark] abstract class SubmitRestProtocolAction -private[spark] object SubmitRestProtocolAction { +abstract class SubmitRestProtocolAction +object SubmitRestProtocolAction { case object SUBMIT_DRIVER_REQUEST extends SubmitRestProtocolAction case object SUBMIT_DRIVER_RESPONSE extends SubmitRestProtocolAction case object KILL_DRIVER_REQUEST extends SubmitRestProtocolAction @@ -98,24 +41,12 @@ private[spark] object SubmitRestProtocolAction { } } -/** - * Common methods used by companion objects of SubmitRestProtocolField's subclasses. - * This keeps track of all fields that belong to this object in order to reconstruct - * the fields from their names. - */ -private[spark] trait SubmitRestProtocolFieldCompanion[FieldType <: SubmitRestProtocolField] { - val requiredFields: Seq[FieldType] - val optionalFields: Seq[FieldType] - - // Listing of all fields indexed by the field's string representation - private lazy val allFieldsMap: Map[String, FieldType] = { - (requiredFields ++ optionalFields).map { f => (f.toString, f) }.toMap - } - - /** Return the appropriate SubmitRestProtocolField from its string representation. */ - def fromString(field: String): FieldType = { - allFieldsMap.get(field).getOrElse { - throw new IllegalArgumentException(s"Unknown field $field") - } - } +class SubmitRestProtocolField[T] { + protected var value: Option[T] = None + def isSet: Boolean = value.isDefined + def getValue: T = value.getOrElse { throw new IllegalAccessException("Value not set!") } + def getValueOption: Option[T] = value + def setValue(v: T): Unit = { value = Some(v) } + def clearValue(): Unit = { value = None } + override def toString: String = value.map(_.toString).orNull } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 7899668ac5260..0b2085b5e3bf1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -17,185 +17,185 @@ package org.apache.spark.deploy.rest -import scala.collection.Map -import scala.collection.JavaConversions._ - -import org.json4s.jackson.JsonMethods._ +import com.fasterxml.jackson.annotation._ +import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.ObjectMapper import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging import org.apache.spark.util.Utils +import org.apache.spark.deploy.rest.SubmitRestProtocolAction._ + +@JsonInclude(Include.NON_NULL) +@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) +@JsonPropertyOrder(alphabetic = true) +abstract class SubmitRestProtocolMessage { + import SubmitRestProtocolMessage._ + + private val messageType = Utils.getFormattedClassName(this) + protected val action: SubmitRestProtocolAction + protected val sparkVersion = new SubmitRestProtocolField[String] + protected val message = new SubmitRestProtocolField[String] + + // Required for JSON de/serialization and not explicitly used + private def getAction: String = action.toString + private def setAction(s: String): this.type = this + + // Spark version implementation depends on whether this is a request or a response + @JsonIgnore + def getSparkVersion: String + @JsonIgnore + def setSparkVersion(s: String): this.type + + def getMessage: String = message.toString + def setMessage(s: String): this.type = setField(message, s) + + def toJson: String = { + validate() + val mapper = new ObjectMapper + val json = mapper.writeValueAsString(this) + postProcessJson(json) + } -/** - * A general message exchanged in the stable application submission REST protocol. - * - * The message is represented by a set of fields in the form of key value pairs. - * Each message must contain an ACTION field, which fully specifies the type of the message. - * For compatibility with older versions of Spark, existing fields must not be removed or - * modified, though new fields can be added as necessary. - */ -private[spark] abstract class SubmitRestProtocolMessage( - action: SubmitRestProtocolAction, - actionField: ActionField, - requiredFields: Seq[SubmitRestProtocolField]) { - - // Maintain the insert order for converting to JSON later - private val fields = new java.util.LinkedHashMap[SubmitRestProtocolField, String] - val className = Utils.getFormattedClassName(this) - - // Set the action field - fields.put(actionField, action.toString) - - /** Return all fields currently set in this message. */ - def getFields: Map[SubmitRestProtocolField, String] = fields.toMap - - /** Return the value of the given field. If the field is not present, return null. */ - def getField(key: SubmitRestProtocolField): String = getFieldOption(key).orNull + def validate(): Unit = { + assert(action != null, s"The action field is missing in $messageType!") + } - /** Return the value of the given field. If the field is not present, throw an exception. */ - def getFieldNotNull(key: SubmitRestProtocolField): String = { - getFieldOption(key).getOrElse { - throw new IllegalArgumentException(s"Field $key is not set in message $className") - } + protected def assertFieldIsSet(field: SubmitRestProtocolField[_], name: String): Unit = { + assert(field.isSet, s"The $name field is missing in $messageType!") } - /** Return the value of the given field as an option. */ - def getFieldOption(key: SubmitRestProtocolField): Option[String] = Option(fields.get(key)) + protected def setField(field: SubmitRestProtocolField[String], value: String): this.type = { + if (value == null) { field.clearValue() } else { field.setValue(value) } + this + } - /** Assign the given value to the field, overriding any existing value. */ - def setField(key: SubmitRestProtocolField, value: String): this.type = { - key.validateValue(value) - fields.put(key, value) + protected def setBooleanField( + field: SubmitRestProtocolField[Boolean], + value: String): this.type = { + if (value == null) { field.clearValue() } else { field.setValue(value.toBoolean) } this } - /** Assign the given value to the field only if the value is not null. */ - def setFieldIfNotNull(key: SubmitRestProtocolField, value: String): this.type = { - if (value != null) { - setField(key, value) - } + protected def setNumericField( + field: SubmitRestProtocolField[Int], + value: String): this.type = { + if (value == null) { field.clearValue() } else { field.setValue(value.toInt) } this } - /** - * Validate that all required fields are set and the value of the ACTION field is as expected. - * If any of these conditions are not met, throw an exception. - */ - def validate(): this.type = { - if (!fields.contains(actionField)) { - throw new IllegalArgumentException(s"The action field is missing from message $className.") - } - if (fields(actionField) != action.toString) { - throw new IllegalArgumentException( - s"Expected action $action in message $className, but actual was ${fields(actionField)}.") - } - val missingFields = requiredFields.filterNot(fields.contains) - if (missingFields.nonEmpty) { - val missingFieldsString = missingFields.mkString(", ") - throw new IllegalArgumentException( - s"The following fields are missing from message $className: $missingFieldsString.") - } + protected def setMemoryField( + field: SubmitRestProtocolField[String], + value: String): this.type = { + Utils.memoryStringToMb(value) + setField(field, value) this } - /** Return the JSON representation of this message. */ - def toJson: String = pretty(render(toJsonObject)) + private def postProcessJson(json: String): String = { + val fields = parse(json).asInstanceOf[JObject].obj + val newFields = fields.map { case (k, v) => (camelCaseToUnderscores(k), v) } + pretty(render(JObject(newFields))) + } +} + +abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + def getClientSparkVersion: String = sparkVersion.toString + def setClientSparkVersion(s: String): this.type = setField(sparkVersion, s) + override def getSparkVersion: String = getClientSparkVersion + override def setSparkVersion(s: String) = setClientSparkVersion(s) + override def validate(): Unit = { + super.validate() + assertFieldIsSet(sparkVersion, "client_spark_version") + } +} - /** - * Return a JObject that represents the JSON form of this message. - * This ignores fields with null values. - */ - protected def toJsonObject: JObject = { - val jsonFields = fields.toSeq - .filter { case (_, v) => v != null } - .map { case (k, v) => JField(k.toString, JString(v)) } - .toList - JObject(jsonFields) +abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + def getServerSparkVersion: String = sparkVersion.toString + def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s) + override def getSparkVersion: String = getServerSparkVersion + override def setSparkVersion(s: String) = setServerSparkVersion(s) + override def validate(): Unit = { + super.validate() + assertFieldIsSet(sparkVersion, "server_spark_version") } } -private[spark] object SubmitRestProtocolMessage { - import SubmitRestProtocolField._ - import SubmitRestProtocolAction._ +object SubmitRestProtocolMessage { + private val mapper = new ObjectMapper - /** - * Construct a SubmitRestProtocolMessage from its JSON representation. - * This uses the ACTION field to determine the type of the message to reconstruct. - * If such a field does not exist, throw an exception. - */ def fromJson(json: String): SubmitRestProtocolMessage = { - val jsonObject = parse(json).asInstanceOf[JObject] - val action = getAction(jsonObject).getOrElse { - throw new IllegalArgumentException(s"ACTION not found in message:\n$json") + val fields = parse(json).asInstanceOf[JObject].obj + val action = fields + .find { case (f, _) => f == "action" } + .map { case (_, v) => v.asInstanceOf[JString].s } + .getOrElse { + throw new IllegalArgumentException(s"Could not find action field in message:\n$json") + } + val clazz = SubmitRestProtocolAction.fromString(action) match { + case SUBMIT_DRIVER_REQUEST => classOf[SubmitDriverRequest] + case SUBMIT_DRIVER_RESPONSE => classOf[SubmitDriverResponse] + case KILL_DRIVER_REQUEST => classOf[KillDriverRequest] + case KILL_DRIVER_RESPONSE => classOf[KillDriverResponse] + case DRIVER_STATUS_REQUEST => classOf[DriverStatusRequest] + case DRIVER_STATUS_RESPONSE => classOf[DriverStatusResponse] + case ERROR => classOf[ErrorResponse] } - SubmitRestProtocolAction.fromString(action) match { - case SUBMIT_DRIVER_REQUEST => SubmitDriverRequestMessage.fromJsonObject(jsonObject) - case SUBMIT_DRIVER_RESPONSE => SubmitDriverResponseMessage.fromJsonObject(jsonObject) - case KILL_DRIVER_REQUEST => KillDriverRequestMessage.fromJsonObject(jsonObject) - case KILL_DRIVER_RESPONSE => KillDriverResponseMessage.fromJsonObject(jsonObject) - case DRIVER_STATUS_REQUEST => DriverStatusRequestMessage.fromJsonObject(jsonObject) - case DRIVER_STATUS_RESPONSE => DriverStatusResponseMessage.fromJsonObject(jsonObject) - case ERROR => ErrorMessage.fromJsonObject(jsonObject) + fromJson(json, clazz) + } + + def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = { + val fields = parse(json).asInstanceOf[JObject].obj + val processedFields = fields.map { case (k, v) => (underscoresToCamelCase(k), v) } + val processedJson = compact(render(JObject(processedFields))) + mapper.readValue(processedJson, clazz) + } + + private def camelCaseToUnderscores(s: String): String = { + val newString = new StringBuilder + s.foreach { c => + if (c.isUpper) { + newString.append("_" + c.toLower) + } else { + newString.append(c) + } } + newString.toString() } - /** - * Extract the value of the ACTION field in the JSON object. - */ - private def getAction(jsonObject: JObject): Option[String] = { - jsonObject.obj - .collect { case JField(k, JString(v)) if isActionField(k) => v } - .headOption + private def underscoresToCamelCase(s: String): String = { + val newString = new StringBuilder + var capitalizeNext = false + s.foreach { c => + if (c == '_') { + capitalizeNext = true + } else { + val nextChar = if (capitalizeNext) c.toUpper else c + newString.append(nextChar) + capitalizeNext = false + } + } + newString.toString() } } -/** - * Common methods used by companion objects of SubmitRestProtocolMessage's subclasses. - */ -private[spark] trait SubmitRestProtocolMessageCompanion[MessageType <: SubmitRestProtocolMessage] - extends Logging { - - import SubmitRestProtocolField._ - - /** Construct a new message of the relevant type. */ - protected def newMessage(): MessageType - - /** Return a field of the relevant type from the field's string representation. */ - protected def fieldFromString(field: String): SubmitRestProtocolField - - /** - * Populate the given field and value in the provided message. - * The default behavior only handles fields that have flat values and ignores other fields. - * If the subclass uses fields with nested values, it should override this method appropriately. - */ - protected def handleField( - message: MessageType, - field: SubmitRestProtocolField, - value: JValue): Unit = { - value match { - case JString(s) => message.setField(field, s) - case _ => logWarning( - s"Unexpected value for field $field in message ${message.className}:\n$value") +object SubmitRestProtocolRequest { + def fromJson(s: String): SubmitRestProtocolRequest = { + SubmitRestProtocolMessage.fromJson(s) match { + case req: SubmitRestProtocolRequest => req + case res: SubmitRestProtocolResponse => + throw new IllegalArgumentException(s"Message was not a request:\n$s") } } +} - /** Construct a SubmitRestProtocolMessage from the given JSON object. */ - def fromJsonObject(jsonObject: JObject): MessageType = { - val message = newMessage() - val fields = jsonObject.obj - .map { case JField(k, v) => (k, v) } - // The ACTION field is already handled on instantiation - .filter { case (k, _) => !isActionField(k) } - .flatMap { case (k, v) => - try { - Some((fieldFromString(k), v)) - } catch { - case e: IllegalArgumentException => - logWarning(s"Unexpected field $k in message ${Utils.getFormattedClassName(this)}") - None - } - } - fields.foreach { case (k, v) => handleField(message, k, v) } - message +object SubmitRestProtocolResponse { + def fromJson(s: String): SubmitRestProtocolResponse = { + SubmitRestProtocolMessage.fromJson(s) match { + case req: SubmitRestProtocolRequest => + throw new IllegalArgumentException(s"Message was not a response:\n$s") + case res: SubmitRestProtocolResponse => res + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 9bc7220eb19d3..89a2b83d2cdee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -70,9 +70,9 @@ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, * This represents the main handler used in the SubmitRestServer. */ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler with Logging { - protected def handleSubmit(request: SubmitDriverRequestMessage): SubmitDriverResponseMessage - protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage - protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage + protected def handleSubmit(request: SubmitDriverRequest): SubmitDriverResponse + protected def handleKill(request: KillDriverRequest): KillDriverResponse + protected def handleStatus(request: DriverStatusRequest): DriverStatusResponse /** * Handle a request submitted by the SubmitRestClient. @@ -85,7 +85,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi response: HttpServletResponse): Unit = { try { val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString - val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + val requestMessage = SubmitRestProtocolRequest.fromJson(requestMessageJson) val responseMessage = constructResponseMessage(requestMessage) response.setContentType("application/json") response.setCharacterEncoding("utf-8") @@ -105,7 +105,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi * If an IllegalArgumentException is thrown in the process, construct an error message instead. */ private def constructResponseMessage( - request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { + request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { // Validate the request message to ensure that it is correctly constructed. If the request // is sent via the SubmitRestClient, it should have already been validated remotely. In case // this is not true, do it again here to guard against potential NPEs. If validation fails, @@ -114,9 +114,9 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi try { request.validate() request match { - case submit: SubmitDriverRequestMessage => handleSubmit(submit) - case kill: KillDriverRequestMessage => handleKill(kill) - case status: DriverStatusRequestMessage => handleStatus(status) + case submit: SubmitDriverRequest => handleSubmit(submit) + case kill: KillDriverRequest => handleKill(kill) + case status: DriverStatusRequest => handleStatus(status) case unexpected => handleError( s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") } @@ -130,13 +130,13 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi } catch { case e: IllegalArgumentException => handleError(s"Internal server error: ${e.getMessage}") } + response } /** Construct an error message to signal the fact that an exception has been thrown. */ - private def handleError(message: String): ErrorMessage = { - import ErrorField._ - new ErrorMessage() - .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(MESSAGE, message) + private def handleError(message: String): ErrorResponse = { + new ErrorResponse() + .setSparkVersion(sparkVersion) + .setMessage(message) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f896b5072e4fa..0a1fba3bad753 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -389,6 +389,10 @@ private[spark] object JsonProtocol { * Util JSON serialization methods | * ------------------------------- */ + def arrayToJson(a: Array[String]): JValue = { + JArray(a.toList.map(JString)) + } + def mapToJson(m: Map[String, String]): JValue = { val jsonFields = m.map { case (k, v) => JField(k, JString(v)) } JObject(jsonFields.toList) @@ -795,6 +799,11 @@ private[spark] object JsonProtocol { * Util JSON deserialization methods | * --------------------------------- */ + def arrayFromJson(json: JValue): Array[String] = { + val values = json.asInstanceOf[JArray].arr + values.toArray.map(_.asInstanceOf[JString].s) + } + def mapFromJson(json: JValue): Map[String, String] = { val jsonFields = json.asInstanceOf[JObject].obj jsonFields.map { case JField(k, JString(v)) => (k, v) }.toMap diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index d7dc8234d57a3..11e49077d893a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -60,7 +60,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B test("kill empty driver") { val killResponse = client.killDriver(masterRestUrl, "driver-that-does-not-exist") - val killSuccess = killResponse.getFieldNotNull(KillDriverResponseField.SUCCESS) + val killSuccess = killResponse.getSuccess assert(killSuccess === "false") } @@ -70,11 +70,11 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val size = 500 val driverId = submitApplication(resultsFile, numbers, size) val killResponse = client.killDriver(masterRestUrl, driverId) - val killSuccess = killResponse.getFieldNotNull(KillDriverResponseField.SUCCESS) + val killSuccess = killResponse.getSuccess waitUntilFinished(driverId) val statusResponse = client.requestDriverStatus(masterRestUrl, driverId) - val statusSuccess = statusResponse.getFieldNotNull(DriverStatusResponseField.SUCCESS) - val driverState = statusResponse.getFieldNotNull(DriverStatusResponseField.DRIVER_STATE) + val statusSuccess = statusResponse.getSuccess + val driverState = statusResponse.getDriverState assert(killSuccess === "true") assert(statusSuccess === "true") assert(driverState === DriverState.KILLED.toString) @@ -83,7 +83,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B test("request status for empty driver") { val statusResponse = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist") - val statusSuccess = statusResponse.getFieldNotNull(DriverStatusResponseField.SUCCESS) + val statusSuccess = statusResponse.getSuccess assert(statusSuccess === "false") } @@ -125,7 +125,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val args = new SparkSubmitArguments(commandLineArgs) SparkSubmit.prepareSubmitEnvironment(args) val submitResponse = client.submitDriver(args) - submitResponse.getFieldNotNull(SubmitDriverResponseField.DRIVER_ID) + submitResponse.getDriverId } /** Wait until the given driver has finished running up to the specified timeout. */ @@ -134,7 +134,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val expireTime = System.currentTimeMillis + maxSeconds * 1000 while (!finished) { val statusResponse = client.requestDriverStatus(masterRestUrl, driverId) - val driverState = statusResponse.getFieldNotNull(DriverStatusResponseField.DRIVER_STATE) + val driverState = statusResponse.getDriverState finished = driverState != DriverState.SUBMITTED.toString && driverState != DriverState.RUNNING.toString diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 18091e98c0b28..a7468a02dfe83 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -17,36 +17,37 @@ package org.apache.spark.deploy.rest -import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.scalatest.FunSuite -/** - * Dummy fields and messages for testing. - */ -private abstract class DummyField extends SubmitRestProtocolField -private object DummyField extends SubmitRestProtocolFieldCompanion[DummyField] { - case object ACTION extends DummyField with ActionField - case object DUMMY_FIELD extends DummyField - case object BOOLEAN_FIELD extends DummyField with BooleanField - case object MEMORY_FIELD extends DummyField with MemoryField - case object NUMERIC_FIELD extends DummyField with NumericField - case object REQUIRED_FIELD extends DummyField - override val requiredFields = Seq(ACTION, REQUIRED_FIELD) - override val optionalFields = Seq(DUMMY_FIELD, BOOLEAN_FIELD, MEMORY_FIELD, NUMERIC_FIELD) -} -private object DUMMY_ACTION extends SubmitRestProtocolAction { - override def toString: String = "DUMMY_ACTION" -} -private class DummyMessage extends SubmitRestProtocolMessage( - DUMMY_ACTION, - DummyField.ACTION, - DummyField.requiredFields) -private object DummyMessage extends SubmitRestProtocolMessageCompanion[DummyMessage] { - protected override def newMessage() = new DummyMessage - protected override def fieldFromString(f: String) = DummyField.fromString(f) +case object DUMMY_REQUEST extends SubmitRestProtocolAction +case object DUMMY_RESPONSE extends SubmitRestProtocolAction + +class DummyRequest extends SubmitRestProtocolRequest { + protected override val action = DUMMY_REQUEST + private val active = new SubmitRestProtocolField[Boolean] + private val age = new SubmitRestProtocolField[Int] + private val name = new SubmitRestProtocolField[String] + + def getActive: String = active.toString + def getAge: String = age.toString + def getName: String = name.toString + + def setActive(s: String): this.type = setBooleanField(active, s) + def setAge(s: String): this.type = setNumericField(age, s) + def setName(s: String): this.type = setField(name, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(name, "name") + assertFieldIsSet(age, "age") + assert(age.getValue > 5, "Not old enough!") + } } +class DummyResponse extends SubmitRestProtocolResponse { + protected override val action = DUMMY_RESPONSE +} /** * Tests for the stable application submission REST protocol. @@ -65,110 +66,123 @@ class SubmitRestProtocolSuite extends FunSuite { } test("get and set fields") { - import DummyField._ - val message = new DummyMessage - // action field is already set on instantiation - assert(message.getFields.size === 1) - assert(message.getField(ACTION) === DUMMY_ACTION.toString) - // required field not set yet - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.getFieldNotNull(DUMMY_FIELD) } - intercept[IllegalArgumentException] { message.getFieldNotNull(REQUIRED_FIELD) } - message.setField(DUMMY_FIELD, "dummy value") - message.setField(BOOLEAN_FIELD, "true") - message.setField(MEMORY_FIELD, "401k") - message.setField(NUMERIC_FIELD, "401") - message.setFieldIfNotNull(REQUIRED_FIELD, null) // no-op because value is null - assert(message.getFields.size === 5) - // required field still not set - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.getFieldNotNull(REQUIRED_FIELD) } - message.setFieldIfNotNull(REQUIRED_FIELD, "dummy value") - // all required fields are now set - assert(message.getFields.size === 6) - assert(message.getField(DUMMY_FIELD) === "dummy value") - assert(message.getField(BOOLEAN_FIELD) === "true") - assert(message.getField(MEMORY_FIELD) === "401k") - assert(message.getField(NUMERIC_FIELD) === "401") - assert(message.getField(REQUIRED_FIELD) === "dummy value") - message.validate() - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(BOOLEAN_FIELD, "not T nor F") } - intercept[IllegalArgumentException] { message.setField(MEMORY_FIELD, "not memory") } - intercept[IllegalArgumentException] { message.setField(NUMERIC_FIELD, "not a number") } + val request = new DummyRequest + assert(request.getSparkVersion === null) + assert(request.getMessage === null) + assert(request.getActive === null) + assert(request.getAge === null) + assert(request.getName === null) + request.setSparkVersion("1.2.3") + request.setActive("true") + request.setAge("10") + request.setName("dolphin") + assert(request.getSparkVersion === "1.2.3") + assert(request.getMessage === null) + assert(request.getActive === "true") + assert(request.getAge === "10") + assert(request.getName === "dolphin") + // overwrite + request.setName("shark") + request.setActive("false") + assert(request.getName === "shark") + assert(request.getActive === "false") + } + + test("get and set fields with null values") { + val request = new DummyRequest + request.setSparkVersion(null) + request.setActive(null) + request.setAge(null) + request.setName(null) + request.setMessage(null) + assert(request.getSparkVersion === null) + assert(request.getMessage === null) + assert(request.getActive === null) + assert(request.getAge === null) + assert(request.getName === null) + } + + test("set fields with illegal argument") { + val request = new DummyRequest + intercept[IllegalArgumentException] { request.setActive("not-a-boolean") } + intercept[IllegalArgumentException] { request.setActive("150") } + intercept[IllegalArgumentException] { request.setAge("not-a-number") } + intercept[IllegalArgumentException] { request.setAge("true") } } - test("to and from JSON") { - import DummyField._ - val message = new DummyMessage() - .setField(DUMMY_FIELD, "dummy value") - .setField(BOOLEAN_FIELD, "true") - .setField(MEMORY_FIELD, "401k") - .setField(NUMERIC_FIELD, "401") - .setField(REQUIRED_FIELD, "dummy value") - .validate() - val expectedJson = - """ - |{ - | "ACTION" : "DUMMY_ACTION", - | "DUMMY_FIELD" : "dummy value", - | "BOOLEAN_FIELD" : "true", - | "MEMORY_FIELD" : "401k", - | "NUMERIC_FIELD" : "401", - | "REQUIRED_FIELD" : "dummy value" - |} - """.stripMargin - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - // Do not use SubmitRestProtocolMessage.fromJson here - // because DUMMY_ACTION is not a known action - val jsonObject = parse(expectedJson).asInstanceOf[JObject] - val newMessage = DummyMessage.fromJsonObject(jsonObject) - assert(newMessage.getFieldNotNull(ACTION) === "DUMMY_ACTION") - assert(newMessage.getFieldNotNull(DUMMY_FIELD) === "dummy value") - assert(newMessage.getFieldNotNull(BOOLEAN_FIELD) === "true") - assert(newMessage.getFieldNotNull(MEMORY_FIELD) === "401k") - assert(newMessage.getFieldNotNull(NUMERIC_FIELD) === "401") - assert(newMessage.getFieldNotNull(REQUIRED_FIELD) === "dummy value") - assert(newMessage.getFields.size === 6) + test("validate") { + val request = new DummyRequest + intercept[AssertionError] { request.validate() } // missing everything + request.setSparkVersion("1.4.8") + intercept[AssertionError] { request.validate() } // missing name and age + request.setName("something") + intercept[AssertionError] { request.validate() } // missing only age + request.setAge("2") + intercept[AssertionError] { request.validate() } // age too low + request.setAge("10") + request.validate() // everything is set + request.setSparkVersion(null) + intercept[AssertionError] { request.validate() } // missing only Spark version + request.setSparkVersion("1.2.3") + request.setName(null) + intercept[AssertionError] { request.validate() } // missing only name + request.setMessage("not-setting-name") + intercept[AssertionError] { request.validate() } // still missing name } - test("SubmitDriverRequestMessage") { - import SubmitDriverRequestField._ - val message = new SubmitDriverRequestMessage - intercept[IllegalArgumentException] { message.validate() } - message.setField(CLIENT_SPARK_VERSION, "1.2.3") - message.setField(MESSAGE, "Submitting them drivers.") - message.setField(APP_NAME, "SparkPie") - message.setField(APP_RESOURCE, "honey-walnut-cherry.jar") - // all required fields are now set + test("request to and from JSON") { + val request = new DummyRequest() + .setSparkVersion("1.2.3") + .setActive("true") + .setAge("25") + .setName("jung") + val json = request.toJson + assertJsonEquals(json, dummyRequestJson) + val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest]) + assert(newRequest.getSparkVersion === "1.2.3") + assert(newRequest.getClientSparkVersion === "1.2.3") + assert(newRequest.getActive === "true") + assert(newRequest.getAge === "25") + assert(newRequest.getName === "jung") + assert(newRequest.getMessage === null) + } + + test("response to and from JSON") { + val response = new DummyResponse().setSparkVersion("3.3.4") + val json = response.toJson + assertJsonEquals(json, dummyResponseJson) + val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse]) + assert(newResponse.getSparkVersion === "3.3.4") + assert(newResponse.getServerSparkVersion === "3.3.4") + assert(newResponse.getMessage === null) + } + + test("SubmitDriverRequest") { + val message = new SubmitDriverRequest + intercept[AssertionError] { message.validate() } + intercept[IllegalArgumentException] { message.setDriverCores("one hundred feet") } + intercept[IllegalArgumentException] { message.setSuperviseDriver("nope, never") } + intercept[IllegalArgumentException] { message.setTotalExecutorCores("two men") } + message.setSparkVersion("1.2.3") + message.setAppName("SparkPie") + message.setAppResource("honey-walnut-cherry.jar") message.validate() - message.setField(MAIN_CLASS, "org.apache.spark.examples.SparkPie") - message.setField(JARS, "mayonnaise.jar,ketchup.jar") - message.setField(FILES, "fireball.png") - message.setField(PY_FILES, "do-not-eat-my.py") - message.setField(DRIVER_MEMORY, "512m") - message.setField(DRIVER_CORES, "180") - message.setField(DRIVER_EXTRA_JAVA_OPTIONS, " -Dslices=5 -Dcolor=mostly_red") - message.setField(DRIVER_EXTRA_CLASS_PATH, "food-coloring.jar") - message.setField(DRIVER_EXTRA_LIBRARY_PATH, "pickle.jar") - message.setField(SUPERVISE_DRIVER, "false") - message.setField(EXECUTOR_MEMORY, "256m") - message.setField(TOTAL_EXECUTOR_CORES, "10000") - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(DRIVER_MEMORY, "more than expected") } - intercept[IllegalArgumentException] { message.setField(DRIVER_CORES, "one hundred feet") } - intercept[IllegalArgumentException] { message.setField(SUPERVISE_DRIVER, "nope, never") } - intercept[IllegalArgumentException] { message.setField(EXECUTOR_MEMORY, "less than expected") } - intercept[IllegalArgumentException] { message.setField(TOTAL_EXECUTOR_CORES, "two men") } - intercept[IllegalArgumentException] { message.setField(APP_ARGS, "anything") } - intercept[IllegalArgumentException] { message.setField(SPARK_PROPERTIES, "anything") } - intercept[IllegalArgumentException] { message.setField(ENVIRONMENT_VARIABLES, "anything") } + // optional fields + message.setMainClass("org.apache.spark.examples.SparkPie") + message.setJars("mayonnaise.jar,ketchup.jar") + message.setFiles("fireball.png") + message.setPyFiles("do-not-eat-my.py") + message.setDriverMemory("512m") + message.setDriverCores("180") + message.setDriverExtraJavaOptions(" -Dslices=5 -Dcolor=mostly_red") + message.setDriverExtraClassPath("food-coloring.jar") + message.setDriverExtraLibraryPath("pickle.jar") + message.setSuperviseDriver("false") + message.setExecutorMemory("256m") + message.setTotalExecutorCores("10000") // special fields - message.appendAppArg("two slices") - message.appendAppArg("a hint of cinnamon") + message.addAppArg("two slices") + message.addAppArg("a hint of cinnamon") message.setSparkProperty("spark.live.long", "true") message.setSparkProperty("spark.shuffle.enabled", "false") message.setEnvironmentVariable("PATH", "/dev/null") @@ -181,231 +195,234 @@ class SubmitRestProtocolSuite extends FunSuite { assert(message.getEnvironmentVariables("PATH") === "/dev/null") assert(message.getEnvironmentVariables("PYTHONPATH") === "/dev/null") // test JSON - val expectedJson = submitDriverRequestJson - assertJsonEquals(message.toJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - .asInstanceOf[SubmitDriverRequestMessage] - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, submitDriverRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmitDriverRequest]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getClientSparkVersion === "1.2.3") + assert(newMessage.getAppName === "SparkPie") + assert(newMessage.getAppResource === "honey-walnut-cherry.jar") + assert(newMessage.getMainClass === "org.apache.spark.examples.SparkPie") + assert(newMessage.getJars === "mayonnaise.jar,ketchup.jar") + assert(newMessage.getFiles === "fireball.png") + assert(newMessage.getPyFiles === "do-not-eat-my.py") + assert(newMessage.getDriverMemory === "512m") + assert(newMessage.getDriverCores === "180") + assert(newMessage.getDriverExtraJavaOptions === " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.getDriverExtraClassPath === "food-coloring.jar") + assert(newMessage.getDriverExtraLibraryPath === "pickle.jar") + assert(newMessage.getSuperviseDriver === "false") + assert(newMessage.getExecutorMemory === "256m") + assert(newMessage.getTotalExecutorCores === "10000") assert(newMessage.getAppArgs === message.getAppArgs) assert(newMessage.getSparkProperties === message.getSparkProperties) assert(newMessage.getEnvironmentVariables === message.getEnvironmentVariables) } - test("SubmitDriverResponseMessage") { - import SubmitDriverResponseField._ - val message = new SubmitDriverResponseMessage - intercept[IllegalArgumentException] { message.validate() } - message.setField(SERVER_SPARK_VERSION, "1.2.3") - message.setField(MESSAGE, "Dem driver is now submitted.") - message.setField(DRIVER_ID, "driver_123") - message.setField(SUCCESS, "true") - // all required fields are now set + test("SubmitDriverResponse") { + val message = new SubmitDriverResponse + intercept[AssertionError] { message.validate() } + intercept[IllegalArgumentException] { message.setSuccess("maybe not") } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") + message.setSuccess("true") message.validate() - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe not") } // test JSON - val expectedJson = submitDriverResponseJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[SubmitDriverResponseMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, submitDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmitDriverResponse]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getServerSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") + assert(newMessage.getSuccess === "true") } - test("KillDriverRequestMessage") { - import KillDriverRequestField._ - val message = new KillDriverRequestMessage - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - message.setField(CLIENT_SPARK_VERSION, "1.2.3") - message.setField(DRIVER_ID, "driver_123") - // all required fields are now set + test("KillDriverRequest") { + val message = new KillDriverRequest + intercept[AssertionError] { message.validate() } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") message.validate() // test JSON - val expectedJson = killDriverRequestJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[KillDriverRequestMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, killDriverRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillDriverRequest]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getClientSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") } - test("KillDriverResponseMessage") { - import KillDriverResponseField._ - val message = new KillDriverResponseMessage - intercept[IllegalArgumentException] { message.validate() } - message.setField(SERVER_SPARK_VERSION, "1.2.3") - message.setField(DRIVER_ID, "driver_123") - message.setField(SUCCESS, "true") - // all required fields are now set + test("KillDriverResponse") { + val message = new KillDriverResponse + intercept[AssertionError] { message.validate() } + intercept[IllegalArgumentException] { message.setSuccess("maybe not") } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") + message.setSuccess("true") message.validate() - message.setField(MESSAGE, "Killing dem reckless drivers.") - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe?") } // test JSON - val expectedJson = killDriverResponseJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[KillDriverResponseMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, killDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillDriverResponse]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getServerSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") + assert(newMessage.getSuccess === "true") } - test("DriverStatusRequestMessage") { - import DriverStatusRequestField._ - val message = new DriverStatusRequestMessage - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - message.setField(CLIENT_SPARK_VERSION, "1.2.3") - message.setField(DRIVER_ID, "driver_123") - // all required fields are now set + test("DriverStatusRequest") { + val message = new DriverStatusRequest + intercept[AssertionError] { message.validate() } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") message.validate() // test JSON - val expectedJson = driverStatusRequestJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[DriverStatusRequestMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, driverStatusRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[DriverStatusRequest]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getClientSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") } - test("DriverStatusResponseMessage") { - import DriverStatusResponseField._ - val message = new DriverStatusResponseMessage - intercept[IllegalArgumentException] { message.validate() } - message.setField(SERVER_SPARK_VERSION, "1.2.3") - message.setField(DRIVER_ID, "driver_123") - message.setField(SUCCESS, "true") - // all required fields are now set + test("DriverStatusResponse") { + val message = new DriverStatusResponse + intercept[AssertionError] { message.validate() } + intercept[IllegalArgumentException] { message.setSuccess("maybe") } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") + message.setSuccess("true") message.validate() - message.setField(MESSAGE, "Your driver is having some trouble...") - message.setField(DRIVER_STATE, "RUNNING") - message.setField(WORKER_ID, "worker_123") - message.setField(WORKER_HOST_PORT, "1.2.3.4:7780") - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe") } + // optional fields + message.setDriverState("RUNNING") + message.setWorkerId("worker_123") + message.setWorkerHostPort("1.2.3.4:7780") // test JSON - val expectedJson = driverStatusResponseJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[DriverStatusResponseMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, driverStatusResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[DriverStatusResponse]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getServerSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") + assert(newMessage.getSuccess === "true") } - test("ErrorMessage") { - import ErrorField._ - val message = new ErrorMessage - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - message.setField(SERVER_SPARK_VERSION, "1.2.3") - message.setField(MESSAGE, "Your wife threw an exception!") - // all required fields are now set + test("ErrorResponse") { + val message = new ErrorResponse + intercept[AssertionError] { message.validate() } + message.setSparkVersion("1.2.3") + message.setMessage("Field not found in submit request: X") message.validate() // test JSON - val expectedJson = errorJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[ErrorMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, errorJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[ErrorResponse]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getServerSparkVersion === "1.2.3") + assert(newMessage.getMessage === "Field not found in submit request: X") } + private val dummyRequestJson = + """ + |{ + | "action" : "DUMMY_REQUEST", + | "active" : "true", + | "age" : "25", + | "client_spark_version" : "1.2.3", + | "name" : "jung" + |} + """.stripMargin + + private val dummyResponseJson = + """ + |{ + | "action" : "DUMMY_RESPONSE", + | "server_spark_version" : "3.3.4" + |} + """.stripMargin + private val submitDriverRequestJson = """ |{ - | "ACTION" : "SUBMIT_DRIVER_REQUEST", - | "CLIENT_SPARK_VERSION" : "1.2.3", - | "MESSAGE" : "Submitting them drivers.", - | "APP_NAME" : "SparkPie", - | "APP_RESOURCE" : "honey-walnut-cherry.jar", - | "MAIN_CLASS" : "org.apache.spark.examples.SparkPie", - | "JARS" : "mayonnaise.jar,ketchup.jar", - | "FILES" : "fireball.png", - | "PY_FILES" : "do-not-eat-my.py", - | "DRIVER_MEMORY" : "512m", - | "DRIVER_CORES" : "180", - | "DRIVER_EXTRA_JAVA_OPTIONS" : " -Dslices=5 -Dcolor=mostly_red", - | "DRIVER_EXTRA_CLASS_PATH" : "food-coloring.jar", - | "DRIVER_EXTRA_LIBRARY_PATH" : "pickle.jar", - | "SUPERVISE_DRIVER" : "false", - | "EXECUTOR_MEMORY" : "256m", - | "TOTAL_EXECUTOR_CORES" : "10000", - | "APP_ARGS" : [ "two slices", "a hint of cinnamon" ], - | "SPARK_PROPERTIES" : { - | "spark.live.long" : "true", - | "spark.shuffle.enabled" : "false" - | }, - | "ENVIRONMENT_VARIABLES" : { - | "PATH" : "/dev/null", - | "PYTHONPATH" : "/dev/null" - | } + | "action" : "SUBMIT_DRIVER_REQUEST", + | "app_args" : "[\"two slices\",\"a hint of cinnamon\"]", + | "app_name" : "SparkPie", + | "app_resource" : "honey-walnut-cherry.jar", + | "client_spark_version" : "1.2.3", + | "driver_cores" : "180", + | "driver_extra_class_path" : "food-coloring.jar", + | "driver_extra_java_options" : " -Dslices=5 -Dcolor=mostly_red", + | "driver_extra_library_path" : "pickle.jar", + | "driver_memory" : "512m", + | "environment_variables" : "{\"PATH\":\"/dev/null\",\"PYTHONPATH\":\"/dev/null\"}", + | "executor_memory" : "256m", + | "files" : "fireball.png", + | "jars" : "mayonnaise.jar,ketchup.jar", + | "main_class" : "org.apache.spark.examples.SparkPie", + | "py_files" : "do-not-eat-my.py", + | "spark_properties" : "{\"spark.live.long\":\"true\",\"spark.shuffle.enabled\":\"false\"}", + | "supervise_driver" : "false", + | "total_executor_cores" : "10000" |} """.stripMargin private val submitDriverResponseJson = """ |{ - | "ACTION" : "SUBMIT_DRIVER_RESPONSE", - | "SERVER_SPARK_VERSION" : "1.2.3", - | "MESSAGE" : "Dem driver is now submitted.", - | "DRIVER_ID" : "driver_123", - | "SUCCESS" : "true" + | "action" : "SUBMIT_DRIVER_RESPONSE", + | "driver_id" : "driver_123", + | "server_spark_version" : "1.2.3", + | "success" : "true" |} """.stripMargin private val killDriverRequestJson = """ |{ - | "ACTION" : "KILL_DRIVER_REQUEST", - | "CLIENT_SPARK_VERSION" : "1.2.3", - | "DRIVER_ID" : "driver_123" + | "action" : "KILL_DRIVER_REQUEST", + | "client_spark_version" : "1.2.3", + | "driver_id" : "driver_123" |} """.stripMargin private val killDriverResponseJson = """ |{ - | "ACTION" : "KILL_DRIVER_RESPONSE", - | "SERVER_SPARK_VERSION" : "1.2.3", - | "DRIVER_ID" : "driver_123", - | "SUCCESS" : "true", - | "MESSAGE" : "Killing dem reckless drivers." + | "action" : "KILL_DRIVER_RESPONSE", + | "driver_id" : "driver_123", + | "server_spark_version" : "1.2.3", + | "success" : "true" |} """.stripMargin private val driverStatusRequestJson = """ |{ - | "ACTION" : "DRIVER_STATUS_REQUEST", - | "CLIENT_SPARK_VERSION" : "1.2.3", - | "DRIVER_ID" : "driver_123" + | "action" : "DRIVER_STATUS_REQUEST", + | "client_spark_version" : "1.2.3", + | "driver_id" : "driver_123" |} """.stripMargin private val driverStatusResponseJson = """ |{ - | "ACTION" : "DRIVER_STATUS_RESPONSE", - | "SERVER_SPARK_VERSION" : "1.2.3", - | "DRIVER_ID" : "driver_123", - | "SUCCESS" : "true", - | "MESSAGE" : "Your driver is having some trouble...", - | "DRIVER_STATE" : "RUNNING", - | "WORKER_ID" : "worker_123", - | "WORKER_HOST_PORT" : "1.2.3.4:7780" + | "action" : "DRIVER_STATUS_RESPONSE", + | "driver_id" : "driver_123", + | "driver_state" : "RUNNING", + | "server_spark_version" : "1.2.3", + | "success" : "true", + | "worker_host_port" : "1.2.3.4:7780", + | "worker_id" : "worker_123" |} """.stripMargin private val errorJson = """ |{ - | "ACTION" : "ERROR", - | "SERVER_SPARK_VERSION" : "1.2.3", - | "MESSAGE" : "Your wife threw an exception!" + | "action" : "ERROR", + | "message" : "Field not found in submit request: X", + | "server_spark_version" : "1.2.3" |} """.stripMargin }