diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index ac9321665c251..ddf6b6e5404c1 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,6 +17,7 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ @@ -41,16 +42,26 @@ import org.apache.spark.util.{SystemClock, Utils} private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManager: SecurityManager) extends Logging { + /** + * Used by sendMessageReliably to track messages being sent. + * @param message the message that was sent + * @param connectionManagerId the connection manager that sent this message + * @param completionHandler callback that's invoked when the send has completed or failed + */ class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, completionHandler: MessageStatus => Unit) { + /** This is non-None if message has been ack'd */ var ackMessage: Option[Message] = None - var attempted = false - var acked = false - def markDone() { completionHandler(this) } + def markDone(ackMessage: Option[Message]) { + this.synchronized { + this.ackMessage = ackMessage + completionHandler(this) + } + } } private val selector = SelectorProvider.provider.openSelector() @@ -434,11 +445,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + status.markDone(None) }) messageStatuses.retain((i, status) => { @@ -467,11 +474,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() - } + s.markDone(None) } messageStatuses.retain((i, status) => { @@ -539,13 +542,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString) val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security message") + if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) } catch { case e: Exception => { logError("Error handling sasl client authentication", e) waitingConn.close() - throw new Exception("Error evaluating sasl response: " + e) + throw new IOException("Error evaluating sasl response: " + e) } } } @@ -653,12 +656,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } } - sentMessageStatus.synchronized { - sentMessageStatus.ackMessage = Some(message) - sentMessageStatus.attempted = true - sentMessageStatus.acked = true - sentMessageStatus.markDone() - } + sentMessageStatus.markDone(Some(message)) } else { var ackMessage : Option[Message] = None try { @@ -681,7 +679,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } catch { case e: Exception => { - logError(s"Exception was thrown during processing message", e) + logError(s"Exception was thrown while processing message", e) val m = Message.createBufferMessage(bufferMessage.id) m.hasError = true ackMessage = Some(m) @@ -802,11 +800,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, case Some(msgStatus) => { messageStatuses -= message.id logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.synchronized { - msgStatus.attempted = true - msgStatus.acked = false - msgStatus.markDone() - } + msgStatus.markDone(None) } case None => { logError("no messageStatus for failed message id: " + message.id) @@ -825,11 +819,28 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, selector.wakeup() } + /** + * Send a message and block until an acknowldgment is received or an error occurs. + * @param connectionManagerId the message's destination + * @param message the message being sent + * @return a Future that either returns the acknowledgment message or captures an exception. + */ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Option[Message]] = { - val promise = Promise[Option[Message]] - val status = new MessageStatus( - message, connectionManagerId, s => promise.success(s.ackMessage)) + : Future[Message] = { + val promise = Promise[Message]() + val status = new MessageStatus(message, connectionManagerId, s => { + s.ackMessage match { + case None => // Indicates a failure where we either never sent or never got ACK'd + promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) + case Some(ackMessage) => + if (ackMessage.hasError) { + promise.failure( + new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + } else { + promise.success(ackMessage) + } + } + }) messageStatuses.synchronized { messageStatuses += ((message.id, status)) } @@ -838,7 +849,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, - message: Message): Option[Message] = { + message: Message): Message = { Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) } @@ -864,6 +875,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, private[spark] object ConnectionManager { + import ExecutionContext.Implicits.global def main(args: Array[String]) { val conf = new SparkConf @@ -919,8 +931,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -954,8 +968,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -984,8 +1000,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis Thread.sleep(1000) diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index b8ea7c2cff9a2..f15eaef80d3e1 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.network import java.nio.ByteBuffer +import scala.util.Try import org.apache.spark.{SecurityManager, SparkConf} private[spark] object SenderTest { @@ -51,7 +52,7 @@ private[spark] object SenderTest { val dataMessage = Message.createBufferMessage(buffer.duplicate) val startTime = System.currentTimeMillis /* println("Started timer at " + startTime) */ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) + val responseStr = Try(manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)) .map { response => val buffer = response.asInstanceOf[BufferMessage].buffers(0) new String(buffer.array, "utf-8") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 59c97e1dece4a..f05fadf5c26da 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue +import scala.util.{Failure, Success} import io.netty.buffer.ByteBuf @@ -118,31 +119,24 @@ object BlockFetcherIterator { bytesInFlight += req.size val sizeMap = req.blocks.toMap // so we can look up the size of each blockID val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { + future.onComplete { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] - if (bufferMessage.hasError) { - logError("Could not get block(s) from " + cmId) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } else { - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new SparkException( + "Unexpected message " + blockMessage.getType + " received from " + cmId) } + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + results.put(new FetchResult(blockId, sizeMap(blockId), + () => dataDeserialize(blockId, blockMessage.getData, serializer))) + _remoteBytesRead += networkSize + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } - case None => { + case Failure(exception) => { logError("Could not get block(s) from " + cmId) for ((blockId, size) <- req.blocks) { results.put(new FetchResult(blockId, -1, null)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 9c459cd3a3437..518cc5e419102 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -23,6 +23,8 @@ import org.apache.spark.Logging import org.apache.spark.network._ import org.apache.spark.util.Utils +import scala.util.{Failure, Success, Try} + /** * A network interface for BlockManager. Each slave should have one * BlockManagerWorker. @@ -115,9 +117,9 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - resultMessage.isDefined + val resultMessage = Try(connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage)) + resultMessage.isSuccess } def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { @@ -125,10 +127,10 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) + val responseMessage = Try(connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage)) responseMessage match { - case Some(message) => { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] logDebug("Response message received " + bufferMessage) BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { @@ -136,7 +138,7 @@ private[spark] object BlockManagerWorker extends Logging { return blockMessage.getData }) } - case None => logDebug("No response message received") + case Failure(exception) => logDebug("No response message received") } null } diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 162d0a016c136..03f36fe50d9c4 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.FunSuite import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Try /** * Test the ConnectionManager with various security settings. @@ -209,7 +210,6 @@ class ConnectionManagerSuite extends FunSuite { }).foreach(f => { try { val g = Await.result(f, 1 second) - if (!g.isDefined) assert(false) else assert(true) } catch { case e: Exception => { assert(false) @@ -240,9 +240,8 @@ class ConnectionManagerSuite extends FunSuite { val future = manager.sendMessageReliably(managerServer.id, bufferMessage) - val message = Await.result(future, 1 second) - assert(message.isDefined) - assert(message.get.hasError) + val message = Try(Await.result(future, 1 second)) + assert(message.isFailure) manager.stop() managerServer.stop() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index db867f21e0cb9..0bcf42c981cea 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -148,8 +148,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { val f = future { val message = Message.createBufferMessage(0) message.hasError = true - val someMessage = Some(message) - someMessage + message } when(connManager.sendMessageReliably(any(), any())).thenReturn(f) @@ -204,10 +203,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { buffer.flip() arrayBuffer += buffer - val someMessage = Some(Message.createBufferMessage(arrayBuffer)) - val f = future { - someMessage + Message.createBufferMessage(arrayBuffer) } when(connManager.sendMessageReliably(any(), any())).thenReturn(f)