[SPARK-2583] ConnectionManager error reporting #1758

Closed · wants to merge 23 commits
Commits (23)
e2b8c4a  Modify to propagete error using ConnectionManager (sarutak, Jul 19, 2014)
6635467  Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 (sarutak, Jul 19, 2014)
717c9c3  Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 (sarutak, Jul 20, 2014)
4117b8f  Modified ConnectionManager to be alble to handle error during process… (sarutak, Jul 20, 2014)
12d3de8  Added BlockFetcherIteratorSuite.scala (sarutak, Jul 20, 2014)
ffaa83d  Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 (sarutak, Jul 22, 2014)
0654128  Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 (sarutak, Jul 23, 2014)
281589c  Add a test case to BlockFetcherIteratorSuite.scala for fetching block… (sarutak, Jul 23, 2014)
e579302  Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 (sarutak, Jul 23, 2014)
22d7ebd  Add test cases to BlockManagerSuite for SPARK-2583 (sarutak, Jul 24, 2014)
2a18d6b  Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 (sarutak, Jul 24, 2014)
326a17f  Add test cases to ConnectionManagerSuite.scala for SPARK-2583 (sarutak, Jul 24, 2014)
e7d9aa6  rebase to master (sarutak, Jul 28, 2014)
9dfd0d8  Merge branch 'master' of git://git.apache.org/spark into SPARK-2583 (sarutak, Jul 29, 2014)
ee91bb7  Modified BufferMessage.scala to keep the spark code style (sarutak, Jul 29, 2014)
7399c6b  Merge remote-tracking branch 'origin/pr/1490' into connection-manager… (JoshRosen, Aug 1, 2014)
f1cd1bb  Clean up @sarutak's PR #1490 for [SPARK-2583]: ConnectionManager erro… (JoshRosen, Aug 4, 2014)
c01c450  Return Try[Message] from sendMessageReliablySync. (JoshRosen, Aug 4, 2014)
a2f745c  Remove sendMessageReliablySync; callers can wait themselves. (JoshRosen, Aug 5, 2014)
659521f  Include previous exception when throwing new one (JoshRosen, Aug 6, 2014)
b8bb4d4  Fix manager.id vs managerServer.id typo that broke security tests. (JoshRosen, Aug 6, 2014)
83673de  Error ACKs should trigger IOExceptions, so catch only those exception… (JoshRosen, Aug 6, 2014)
68620cb  Fix test in BlockFetcherIteratorSuite: (JoshRosen, Aug 6, 2014)
Files changed

core/src/main/scala/org/apache/spark/network/BufferMessage.scala

@@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
typ, id, size, newBuffer.remaining, ackId,
hasError, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
core/src/main/scala/org/apache/spark/network/ConnectionManager.scala (83 additions, 60 deletions)

@@ -17,6 +17,7 @@

package org.apache.spark.network

import java.io.IOException
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
@@ -41,16 +41,26 @@ import org.apache.spark.util.{SystemClock, Utils}
private[spark] class ConnectionManager(port: Int, conf: SparkConf,
securityManager: SecurityManager) extends Logging {

/**
* Used by sendMessageReliably to track messages being sent.
* @param message the message that was sent
* @param connectionManagerId the connection manager that sent this message
* @param completionHandler callback that's invoked when the send has completed or failed
*/
Comment from the PR author (Contributor):
You'll notice that I removed a bunch of fields here. attempted was never read anywhere, and acked implied ackMessage != None.
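
Editor's note: to make the new contract concrete, here is a minimal illustrative sketch, not part of the patch. It uses a stand-in for MessageStatus in which the Option passed to markDone is the only state, so a None ack means failure and a Some ack means success.

```scala
// Illustrative stand-in for the slimmed-down MessageStatus; the real class is an
// inner class of org.apache.spark.network.ConnectionManager. Success vs. failure is
// carried entirely by the Option handed to markDone, which is why the separate
// attempted/acked flags could be dropped.
class StatusSketch(completionHandler: StatusSketch => Unit) {
  var ackMessage: Option[String] = None // None => never ACK'd or the send failed

  def markDone(ack: Option[String]): Unit = this.synchronized {
    ackMessage = ack
    completionHandler(this)
  }
}

object StatusSketchDemo extends App {
  def report(s: StatusSketch): Unit =
    println(if (s.ackMessage.isDefined) "acked" else "failed")

  new StatusSketch(report).markDone(Some("ACK")) // prints: acked
  new StatusSketch(report).markDone(None)        // prints: failed
}
```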

class MessageStatus(
val message: Message,
val connectionManagerId: ConnectionManagerId,
completionHandler: MessageStatus => Unit) {

/** This is non-None if message has been ack'd */
var ackMessage: Option[Message] = None
var attempted = false
var acked = false

def markDone() { completionHandler(this) }
def markDone(ackMessage: Option[Message]) {
this.synchronized {
this.ackMessage = ackMessage
completionHandler(this)
}
}
}

private val selector = SelectorProvider.provider.openSelector()
@@ -434,11 +445,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
.foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
status.markDone(None)
})

messageStatuses.retain((i, status) => {
@@ -467,11 +474,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
for (s <- messageStatuses.values
if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
}
s.markDone(None)
}

messageStatuses.retain((i, status) => {
@@ -539,13 +542,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
securityMsg.getConnectionId.toString)
val message = securityMsgResp.toBufferMessage
if (message == null) throw new Exception("Error creating security message")
if (message == null) throw new IOException("Error creating security message")
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
} catch {
case e: Exception => {
logError("Error handling sasl client authentication", e)
waitingConn.close()
throw new Exception("Error evaluating sasl response: " + e)
throw new IOException("Error evaluating sasl response: ", e)
}
}
}
@@ -653,34 +656,39 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
}
}
}
sentMessageStatus.synchronized {
sentMessageStatus.ackMessage = Some(message)
sentMessageStatus.attempted = true
sentMessageStatus.acked = true
sentMessageStatus.markDone()
}
sentMessageStatus.markDone(Some(message))
} else {
val ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logDebug("Not calling back as callback is null")
None
}
var ackMessage : Option[Message] = None
try {
ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logDebug("Not calling back as callback is null")
None
}

if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
}
}
} catch {
case e: Exception => {
logError(s"Exception was thrown while processing message", e)
val m = Message.createBufferMessage(bufferMessage.id)
m.hasError = true
ackMessage = Some(m)
}
} finally {
sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}

sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
}
case _ => throw new Exception("Unknown type message received")
@@ -792,11 +800,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
case Some(msgStatus) => {
messageStatuses -= message.id
logInfo("Notifying " + msgStatus.connectionManagerId)
msgStatus.synchronized {
msgStatus.attempted = true
msgStatus.acked = false
msgStatus.markDone()
}
msgStatus.markDone(None)
}
case None => {
logError("no messageStatus for failed message id: " + message.id)
@@ -815,23 +819,35 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
selector.wakeup()
}

/**
* Send a message and block until an acknowldgment is received or an error occurs.
* @param connectionManagerId the message's destination
* @param message the message being sent
* @return a Future that either returns the acknowledgment message or captures an exception.
*/
def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
: Future[Option[Message]] = {
val promise = Promise[Option[Message]]
val status = new MessageStatus(
message, connectionManagerId, s => promise.success(s.ackMessage))
: Future[Message] = {
Comment from the PR author (Contributor):
Since we now signal failures via Futures, I changed this into Future[Message] instead of Future[Some[Message]].
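
Editor's note: a hedged sketch of what this means for callers, not taken from the diff. The helper name and its arguments are placeholders; the point is that remote errors and lost connections now arrive as a failed Future rather than as an in-band None.

```scala
import java.io.IOException
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}
import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, Message}

// Hypothetical caller: only the Future plumbing is the point here.
def sendAndLog(manager: ConnectionManager, dest: ConnectionManagerId, msg: Message): Unit =
  manager.sendMessageReliably(dest, msg).onComplete {
    case Success(ack)            => println("ACK received for message " + msg.id)
    case Failure(e: IOException) => println("send failed or remote signalled an error: " + e)
    case Failure(e)              => println("unexpected failure: " + e)
  }
```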

val promise = Promise[Message]()
val status = new MessageStatus(message, connectionManagerId, s => {
s.ackMessage match {
case None => // Indicates a failure where we either never sent or never got ACK'd
promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
case Some(ackMessage) =>
if (ackMessage.hasError) {
promise.failure(
new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
} else {
promise.success(ackMessage)
}
}
})
messageStatuses.synchronized {
messageStatuses += ((message.id, status))
}
sendMessage(connectionManagerId, message)
promise.future
}

def sendMessageReliablySync(connectionManagerId: ConnectionManagerId,
message: Message): Option[Message] = {
Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
}

def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
onReceiveCallback = callback
}
@@ -854,6 +870,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,


private[spark] object ConnectionManager {
import ExecutionContext.Implicits.global

def main(args: Array[String]) {
val conf = new SparkConf
@@ -888,7 +905,7 @@ private[spark] object ConnectionManager {

(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliablySync(manager.id, bufferMessage)
Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
})
println("--------------------------")
println()
@@ -909,8 +926,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis

@@ -944,8 +963,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis

@@ -974,8 +995,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis
Thread.sleep(1000)
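
Editor's note: since sendMessageReliablySync was removed (commit a2f745c), the benchmark code above blocks on the Future itself. A small sketch of the two waiting styles it now mixes, as an illustration rather than an excerpt from the patch:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import org.apache.spark.network.Message

// Block for the ACK and let a failed Future rethrow (the old "Sync" behaviour):
def waitForAck(f: Future[Message]): Message =
  Await.result(f, Duration.Inf)

// Or attach a failure callback and only wait for completion, as the throughput
// loops above do; onFailure was the idiomatic callback in the Scala 2.10/2.11 era.
def waitAndReport(f: Future[Message]): Unit = {
  f.onFailure { case e => println("Failed due to " + e) }
  Await.ready(f, 1.second)
}
```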
core/src/main/scala/org/apache/spark/network/Message.scala (2 additions)

@@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var startTime = -1L
var finishTime = -1L
var isSecurityNeg = false
var hasError = false

def size: Int

@@ -87,6 +88,7 @@ private[spark] object Message {
case BUFFER_MESSAGE => new BufferMessage(header.id,
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.hasError = header.hasError
newMessage.senderAddress = header.address
newMessage
}
core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala

@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
val hasError: Boolean,
val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
@@ -41,6 +42,7 @@
putInt(totalSize).
putInt(chunkSize).
putInt(other).
put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
putInt(securityNeg).
putInt(ip.size).
put(ip).
@@ -56,7 +58,7 @@


private[spark] object MessageChunkHeader {
val HEADER_SIZE = 44
val HEADER_SIZE = 45

def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
@@ -67,13 +69,14 @@
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
val hasError = buffer.get() != 0
val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
new InetSocketAddress(ip, port))
}
}
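
Editor's note: the HEADER_SIZE bump from 44 to 45 is exactly the one byte added for hasError. A self-contained sketch of that round trip, illustrative only and not the real header codec:

```scala
import java.nio.ByteBuffer

object HasErrorByteDemo extends App {
  // Encode the flag the same way the header does: a single byte, 1 for true, 0 for false.
  def putHasError(buf: ByteBuffer, hasError: Boolean): ByteBuffer =
    buf.put(if (hasError) 1.toByte else 0.toByte)

  // Decode it the same way MessageChunkHeader.create does: any non-zero byte is true.
  def getHasError(buf: ByteBuffer): Boolean = buf.get() != 0

  val buf = ByteBuffer.allocate(1)
  putHasError(buf, hasError = true)
  buf.flip()
  println(getHasError(buf)) // true: the error flag survives serialization
}
```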
core/src/main/scala/org/apache/spark/network/SenderTest.scala

@@ -20,6 +20,10 @@ package org.apache.spark.network
import java.nio.ByteBuffer
import org.apache.spark.{SecurityManager, SparkConf}

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.util.Try

private[spark] object SenderTest {
def main(args: Array[String]) {

@@ -51,7 +55,8 @@ private[spark] object SenderTest {
val dataMessage = Message.createBufferMessage(buffer.duplicate)
val startTime = System.currentTimeMillis
/* println("Started timer at " + startTime) */
val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage)
val responseStr: String = Try(Await.result(promise, Duration.Inf))
.map { response =>
val buffer = response.asInstanceOf[BufferMessage].buffers(0)
new String(buffer.array, "utf-8")
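
Editor's note: the diff view truncates this expression. As a hedged sketch of its assumed shape (the fallback string and helper name are placeholders, not the actual SenderTest code), the Try / Await / map combination reads like this:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
import scala.util.Try

// Block for the reply and turn any failure into a fallback string instead of
// letting the exception escape the sender loop.
def describeReply(reply: Future[Array[Byte]]): String =
  Try(Await.result(reply, Duration.Inf))
    .map(bytes => new String(bytes, "utf-8"))
    .getOrElse("none")
```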
core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

@@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import scala.collection.mutable.Queue
import scala.util.{Failure, Success}

import io.netty.buffer.ByteBuf

@@ -118,8 +119,8 @@ object BlockFetcherIterator {
bytesInFlight += req.size
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
future.onSuccess {
case Some(message) => {
future.onComplete {
case Success(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
@@ -135,8 +136,8 @@
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
logError("Could not get block(s) from " + cmId)
case Failure(exception) => {
logError("Could not get block(s) from " + cmId, exception)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
}
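
Editor's note: because failures now complete the Future with an exception instead of a successful None, the handler above switches from onSuccess to onComplete. A minimal self-contained sketch of that pattern, with illustrative values rather than Spark types:

```scala
import java.io.IOException
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.util.{Failure, Success}

object FetchCallbackSketch extends App {
  val ok: Future[String]     = Future.successful("block bytes")
  val broken: Future[String] = Future.failed(new IOException("remote error ACK"))

  Seq(ok, broken).foreach(_.onComplete {
    case Success(bytes) => println("deserialize and enqueue: " + bytes)
    case Failure(e)     => println("mark the request's blocks as failed: " + e.getMessage)
  })
  Thread.sleep(100) // give the callbacks a moment before the demo JVM exits
}
```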