Skip to content

Commit

Permalink
[SPARK-19450] Replace askWithRetry with askSync.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

`askSync` is already added in `RpcEndpointRef` (see SPARK-19347 and apache#16690 (comment)) and `askWithRetry` is marked as deprecated.
As mentioned SPARK-18113(apache#16503 (comment)):

>askWithRetry is basically an unneeded API, and a leftover from the akka days that doesn't make sense anymore. It's prone to cause deadlocks (exactly because it's blocking), it imposes restrictions on the caller (e.g. idempotency) and other things that people generally don't pay that much attention to when using it.

Since `askWithRetry` is just used inside spark and not in user logic. It might make sense to replace all of them with `askSync`.

## How was this patch tested?
This PR doesn't change code logic, existing unit test can cover.

Author: jinxing <[email protected]>

Closes apache#16790 from jinxing64/SPARK-19450.
  • Loading branch information
jinxing authored and srowen committed Feb 19, 2017
1 parent df3cbe3 commit ba8912e
Show file tree
Hide file tree
Showing 24 changed files with 58 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
*/
protected def askTracker[T: ClassTag](message: Any): T = {
try {
trackerEndpoint.askWithRetry[T](message)
trackerEndpoint.askSync[T](message)
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ class SparkContext(config: SparkConf) extends Logging {
Some(Utils.getThreadDump())
} else {
val endpointRef = env.blockManager.master.getExecutorEndpointRef(executorId).get
Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump))
Some(endpointRef.askSync[Array[ThreadStackTrace]](TriggerThreadDump))
}
} catch {
case e: Exception =>
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/deploy/Client.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private class ClientEndpoint(
Thread.sleep(5000)
logInfo("... polling master for driver state")
val statusResponse =
activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId))
activeMasterEndpoint.askSync[DriverStatusResponse](RequestDriverStatus(driverId))
if (statusResponse.found) {
logInfo(s"State of $driverId is ${statusResponse.state.get}")
// Worker node, if present
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ private[deploy] object Master extends Logging {
val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr)
val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME,
new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf))
val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest)
val portsResponse = masterEndpoint.askSync[BoundPortsResponse](BoundPortsRequest)
(rpcEnv, portsResponse.webUIPort, portsResponse.restPort)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
val state = master.askWithRetry[MasterStateResponse](RequestMasterState)
val state = master.askSync[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId)
.getOrElse(state.completedApps.find(_.id == appId).orNull)
if (app == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private val master = parent.masterEndpointRef

def getMasterState: MasterStateResponse = {
master.askWithRetry[MasterStateResponse](RequestMasterState)
master.askSync[MasterStateResponse](RequestMasterState)
}

override def renderJson(request: HttpServletRequest): JValue = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef,
extends KillRequestServlet {

protected def handleKill(submissionId: String): KillSubmissionResponse = {
val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse](
val response = masterEndpoint.askSync[DeployMessages.KillDriverResponse](
DeployMessages.RequestKillDriver(submissionId))
val k = new KillSubmissionResponse
k.serverSparkVersion = sparkVersion
Expand All @@ -89,7 +89,7 @@ private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRe
extends StatusRequestServlet {

protected def handleStatus(submissionId: String): SubmissionStatusResponse = {
val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse](
val response = masterEndpoint.askSync[DeployMessages.DriverStatusResponse](
DeployMessages.RequestDriverStatus(submissionId))
val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) }
val d = new SubmissionStatusResponse
Expand Down Expand Up @@ -174,7 +174,7 @@ private[rest] class StandaloneSubmitRequestServlet(
requestMessage match {
case submitRequest: CreateSubmissionRequest =>
val driverDescription = buildDriverDescription(submitRequest)
val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse](
val response = masterEndpoint.askSync[DeployMessages.SubmitDriverResponse](
DeployMessages.RequestSubmitDriver(driverDescription))
val submitResponse = new CreateSubmissionResponse
submitResponse.serverSparkVersion = sparkVersion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") {
private val workerEndpoint = parent.worker.self

override def renderJson(request: HttpServletRequest): JValue = {
val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState)
val workerState = workerEndpoint.askSync[WorkerStateResponse](RequestWorkerState)
JsonProtocol.writeWorkerState(workerState)
}

def render(request: HttpServletRequest): Seq[Node] = {
val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState)
val workerState = workerEndpoint.askSync[WorkerStateResponse](RequestWorkerState)

val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs")
val runningExecutors = workerState.executors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
new SecurityManager(executorConf),
clientMode = true)
val driver = fetcher.setupEndpointRefByURI(driverUrl)
val cfg = driver.askWithRetry[SparkAppConfig](RetrieveSparkAppConfig)
val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig)
val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ private[spark] class Executor(

val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId)
try {
val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
val response = heartbeatReceiverRef.askSync[HeartbeatResponse](
message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s"))
if (response.reregisterBlockManager) {
logInfo("Told to re-register on heartbeat")
Expand Down
60 changes: 0 additions & 60 deletions core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,64 +92,4 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf)
timeout.awaitResult(future)
}

/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
* default timeout, throw a SparkException if this fails even after the default number of retries.
* The default `timeout` will be used in every trial of calling `sendWithReply`. Because this
* method retries, the message handling in the receiver side should be idempotent.
*
* Note: this is a blocking action which may cost a lot of time, so don't call it in a message
* loop of [[RpcEndpoint]].
*
* @param message the message to send
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
@deprecated("use 'askSync' instead.", "2.2.0")
def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout)

/**
* Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a
* specified timeout, throw a SparkException if this fails even after the specified number of
* retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method
* retries, the message handling in the receiver side should be idempotent.
*
* Note: this is a blocking action which may cost a lot of time, so don't call it in a message
* loop of [[RpcEndpoint]].
*
* @param message the message to send
* @param timeout the timeout duration
* @tparam T type of the reply message
* @return the reply message from the corresponding [[RpcEndpoint]]
*/
@deprecated("use 'askSync' instead.", "2.2.0")
def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
// TODO: Consider removing multiple attempts
var attempts = 0
var lastException: Exception = null
while (attempts < maxRetries) {
attempts += 1
try {
val future = ask[T](message, timeout)
val result = timeout.awaitResult(future)
if (result == null) {
throw new SparkException("RpcEndpoint returned null")
}
return result
} catch {
case ie: InterruptedException => throw ie
case e: Exception =>
lastException = e
logWarning(s"Error sending message [message = $message] in $attempts attempts", e)
}

if (attempts < maxRetries) {
Thread.sleep(retryWaitMs)
}
}

throw new SparkException(
s"Error sending message [message = $message]", lastException)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class DAGScheduler(
accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])],
blockManagerId: BlockManagerId): Boolean = {
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates))
blockManagerMaster.driverEndpoint.askWithRetry[Boolean](
blockManagerMaster.driverEndpoint.askSync[Boolean](
BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
try {
if (driverEndpoint != null) {
logInfo("Shutting down all executors")
driverEndpoint.askWithRetry[Boolean](StopExecutors)
driverEndpoint.askSync[Boolean](StopExecutors)
}
} catch {
case e: Exception =>
Expand All @@ -384,7 +384,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
stopExecutors()
try {
if (driverEndpoint != null) {
driverEndpoint.askWithRetry[Boolean](StopDriver)
driverEndpoint.askSync[Boolean](StopDriver)
}
} catch {
case e: Exception =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class BlockManagerMaster(
maxMemSize: Long,
slaveEndpoint: RpcEndpointRef): BlockManagerId = {
logInfo(s"Registering BlockManager $blockManagerId")
val updatedId = driverEndpoint.askWithRetry[BlockManagerId](
val updatedId = driverEndpoint.askSync[BlockManagerId](
RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
logInfo(s"Registered BlockManager $updatedId")
updatedId
Expand All @@ -72,20 +72,20 @@ class BlockManagerMaster(
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): Boolean = {
val res = driverEndpoint.askWithRetry[Boolean](
val res = driverEndpoint.askSync[Boolean](
UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
logDebug(s"Updated info of block $blockId")
res
}

/** Get locations of the blockId from the driver */
def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId))
}

/** Get locations of multiple blockIds from the driver */
def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]](
driverEndpoint.askSync[IndexedSeq[Seq[BlockManagerId]]](
GetLocationsMultipleBlockIds(blockIds))
}

Expand All @@ -99,24 +99,24 @@ class BlockManagerMaster(

/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId))
driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId))
}

def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = {
driverEndpoint.askWithRetry[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId))
driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId))
}

/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
*/
def removeBlock(blockId: BlockId) {
driverEndpoint.askWithRetry[Boolean](RemoveBlock(blockId))
driverEndpoint.askSync[Boolean](RemoveBlock(blockId))
}

/** Remove all blocks belonging to the given RDD. */
def removeRdd(rddId: Int, blocking: Boolean) {
val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId))
val future = driverEndpoint.askSync[Future[Seq[Int]]](RemoveRdd(rddId))
future.onFailure {
case e: Exception =>
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e)
Expand All @@ -128,7 +128,7 @@ class BlockManagerMaster(

/** Remove all blocks belonging to the given shuffle. */
def removeShuffle(shuffleId: Int, blocking: Boolean) {
val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
val future = driverEndpoint.askSync[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
future.onFailure {
case e: Exception =>
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e)
Expand All @@ -140,7 +140,7 @@ class BlockManagerMaster(

/** Remove all blocks belonging to the given broadcast. */
def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](
val future = driverEndpoint.askSync[Future[Seq[Int]]](
RemoveBroadcast(broadcastId, removeFromMaster))
future.onFailure {
case e: Exception =>
Expand All @@ -159,11 +159,11 @@ class BlockManagerMaster(
* amount of remaining memory.
*/
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
driverEndpoint.askWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
driverEndpoint.askSync[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}

def getStorageStatus: Array[StorageStatus] = {
driverEndpoint.askWithRetry[Array[StorageStatus]](GetStorageStatus)
driverEndpoint.askSync[Array[StorageStatus]](GetStorageStatus)
}

/**
Expand All @@ -184,7 +184,7 @@ class BlockManagerMaster(
* master endpoint for a response to a prior message.
*/
val response = driverEndpoint.
askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
askSync[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
val (blockManagerIds, futures) = response.unzip
implicit val sameThread = ThreadUtils.sameThread
val cbf =
Expand Down Expand Up @@ -214,7 +214,7 @@ class BlockManagerMaster(
filter: BlockId => Boolean,
askSlaves: Boolean): Seq[BlockId] = {
val msg = GetMatchingBlockIds(filter, askSlaves)
val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg)
val future = driverEndpoint.askSync[Future[Seq[BlockId]]](msg)
timeout.awaitResult(future)
}

Expand All @@ -223,7 +223,7 @@ class BlockManagerMaster(
* since they are not reported the master.
*/
def hasCachedBlocks(executorId: String): Boolean = {
driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId))
driverEndpoint.askSync[Boolean](HasCachedBlocks(executorId))
}

/** Stop the driver endpoint, called only on the Spark driver node */
Expand All @@ -237,7 +237,7 @@ class BlockManagerMaster(

/** Send a one-way message to the master endpoint, to which we expect it to reply with true. */
private def tell(message: Any) {
if (!driverEndpoint.askWithRetry[Boolean](message)) {
if (!driverEndpoint.askSync[Boolean](message)) {
throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ class StandaloneDynamicAllocationSuite

/** Get the Master state */
private def getMasterState: MasterStateResponse = {
master.self.askWithRetry[MasterStateResponse](RequestMasterState)
master.self.askSync[MasterStateResponse](RequestMasterState)
}

/** Get the applications that are active from Master */
Expand Down Expand Up @@ -620,7 +620,7 @@ class StandaloneDynamicAllocationSuite
when(endpointRef.address).thenReturn(mockAddress)
val message = RegisterExecutor(id, endpointRef, "localhost", 10, Map.empty)
val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend]
backend.driverEndpoint.askWithRetry[Boolean](message)
backend.driverEndpoint.askSync[Boolean](message)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class AppClientSuite

/** Get the Master state */
private def getMasterState: MasterStateResponse = {
master.self.askWithRetry[MasterStateResponse](RequestMasterState)
master.self.askSync[MasterStateResponse](RequestMasterState)
}

/** Get the applications that are active from Master */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ class MasterSuite extends SparkFunSuite
val master = makeMaster()
master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master)
eventually(timeout(10.seconds)) {
val masterState = master.self.askWithRetry[MasterStateResponse](RequestMasterState)
val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
assert(masterState.status === RecoveryState.ALIVE, "Master is not alive")
}

Expand Down
Loading

0 comments on commit ba8912e

Please sign in to comment.