Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into dt-robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Aug 7, 2014
2 parents 7a61f7b + 4201d27 commit 4dc449a
Show file tree
Hide file tree
Showing 13 changed files with 407 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
Expand All @@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
typ, id, size, newBuffer.remaining, ackId,
hasError, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
Expand All @@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
Expand Down
143 changes: 83 additions & 60 deletions core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.network

import java.io.IOException
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
Expand Down Expand Up @@ -45,16 +46,26 @@ private[spark] class ConnectionManager(
name: String = "Connection manager")
extends Logging {

/**
* Used by sendMessageReliably to track messages being sent.
* @param message the message that was sent
* @param connectionManagerId the connection manager that sent this message
* @param completionHandler callback that's invoked when the send has completed or failed
*/
class MessageStatus(
val message: Message,
val connectionManagerId: ConnectionManagerId,
completionHandler: MessageStatus => Unit) {

/** This is non-None if message has been ack'd */
var ackMessage: Option[Message] = None
var attempted = false
var acked = false

def markDone() { completionHandler(this) }
def markDone(ackMessage: Option[Message]) {
this.synchronized {
this.ackMessage = ackMessage
completionHandler(this)
}
}
}

private val selector = SelectorProvider.provider.openSelector()
Expand Down Expand Up @@ -442,11 +453,7 @@ private[spark] class ConnectionManager(
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
.foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
status.markDone(None)
})

messageStatuses.retain((i, status) => {
Expand Down Expand Up @@ -475,11 +482,7 @@ private[spark] class ConnectionManager(
for (s <- messageStatuses.values
if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
}
s.markDone(None)
}

messageStatuses.retain((i, status) => {
Expand Down Expand Up @@ -547,13 +550,13 @@ private[spark] class ConnectionManager(
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
securityMsg.getConnectionId.toString)
val message = securityMsgResp.toBufferMessage
if (message == null) throw new Exception("Error creating security message")
if (message == null) throw new IOException("Error creating security message")
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
} catch {
case e: Exception => {
logError("Error handling sasl client authentication", e)
waitingConn.close()
throw new Exception("Error evaluating sasl response: " + e)
throw new IOException("Error evaluating sasl response: ", e)
}
}
}
Expand Down Expand Up @@ -661,34 +664,39 @@ private[spark] class ConnectionManager(
}
}
}
sentMessageStatus.synchronized {
sentMessageStatus.ackMessage = Some(message)
sentMessageStatus.attempted = true
sentMessageStatus.acked = true
sentMessageStatus.markDone()
}
sentMessageStatus.markDone(Some(message))
} else {
val ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logDebug("Not calling back as callback is null")
None
}
var ackMessage : Option[Message] = None
try {
ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logDebug("Not calling back as callback is null")
None
}

if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
}
}
} catch {
case e: Exception => {
logError(s"Exception was thrown while processing message", e)
val m = Message.createBufferMessage(bufferMessage.id)
m.hasError = true
ackMessage = Some(m)
}
} finally {
sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}

sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
}
case _ => throw new Exception("Unknown type message received")
Expand Down Expand Up @@ -800,11 +808,7 @@ private[spark] class ConnectionManager(
case Some(msgStatus) => {
messageStatuses -= message.id
logInfo("Notifying " + msgStatus.connectionManagerId)
msgStatus.synchronized {
msgStatus.attempted = true
msgStatus.acked = false
msgStatus.markDone()
}
msgStatus.markDone(None)
}
case None => {
logError("no messageStatus for failed message id: " + message.id)
Expand All @@ -823,23 +827,35 @@ private[spark] class ConnectionManager(
selector.wakeup()
}

/**
* Send a message and block until an acknowldgment is received or an error occurs.
* @param connectionManagerId the message's destination
* @param message the message being sent
* @return a Future that either returns the acknowledgment message or captures an exception.
*/
def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
: Future[Option[Message]] = {
val promise = Promise[Option[Message]]
val status = new MessageStatus(
message, connectionManagerId, s => promise.success(s.ackMessage))
: Future[Message] = {
val promise = Promise[Message]()
val status = new MessageStatus(message, connectionManagerId, s => {
s.ackMessage match {
case None => // Indicates a failure where we either never sent or never got ACK'd
promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
case Some(ackMessage) =>
if (ackMessage.hasError) {
promise.failure(
new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
} else {
promise.success(ackMessage)
}
}
})
messageStatuses.synchronized {
messageStatuses += ((message.id, status))
}
sendMessage(connectionManagerId, message)
promise.future
}

def sendMessageReliablySync(connectionManagerId: ConnectionManagerId,
message: Message): Option[Message] = {
Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
}

def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
onReceiveCallback = callback
}
Expand All @@ -862,6 +878,7 @@ private[spark] class ConnectionManager(


private[spark] object ConnectionManager {
import ExecutionContext.Implicits.global

def main(args: Array[String]) {
val conf = new SparkConf
Expand Down Expand Up @@ -896,7 +913,7 @@ private[spark] object ConnectionManager {

(0 until count).map(i => {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliablySync(manager.id, bufferMessage)
Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
})
println("--------------------------")
println()
Expand All @@ -917,8 +934,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis

Expand Down Expand Up @@ -952,8 +971,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis

Expand Down Expand Up @@ -982,8 +1003,10 @@ private[spark] object ConnectionManager {
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
manager.sendMessageReliably(manager.id, bufferMessage)
}).foreach(f => {
val g = Await.result(f, 1 second)
if (!g.isDefined) println("Failed")
f.onFailure {
case e => println("Failed due to " + e)
}
Await.ready(f, 1 second)
})
val finishTime = System.currentTimeMillis
Thread.sleep(1000)
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/network/Message.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var startTime = -1L
var finishTime = -1L
var isSecurityNeg = false
var hasError = false

def size: Int

Expand Down Expand Up @@ -87,6 +88,7 @@ private[spark] object Message {
case BUFFER_MESSAGE => new BufferMessage(header.id,
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.hasError = header.hasError
newMessage.senderAddress = header.address
newMessage
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
val hasError: Boolean,
val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
Expand All @@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader(
putInt(totalSize).
putInt(chunkSize).
putInt(other).
put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
putInt(securityNeg).
putInt(ip.size).
put(ip).
Expand All @@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader(


private[spark] object MessageChunkHeader {
val HEADER_SIZE = 44
val HEADER_SIZE = 45

def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
Expand All @@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader {
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
val hasError = buffer.get() != 0
val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
new InetSocketAddress(ip, port))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ package org.apache.spark.network
import java.nio.ByteBuffer
import org.apache.spark.{SecurityManager, SparkConf}

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.util.Try

private[spark] object SenderTest {
def main(args: Array[String]) {

Expand Down Expand Up @@ -51,7 +55,8 @@ private[spark] object SenderTest {
val dataMessage = Message.createBufferMessage(buffer.duplicate)
val startTime = System.currentTimeMillis
/* println("Started timer at " + startTime) */
val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage)
val responseStr: String = Try(Await.result(promise, Duration.Inf))
.map { response =>
val buffer = response.asInstanceOf[BufferMessage].buffers(0)
new String(buffer.array, "utf-8")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import scala.collection.mutable.Queue
import scala.util.{Failure, Success}

import io.netty.buffer.ByteBuf

Expand Down Expand Up @@ -118,8 +119,8 @@ object BlockFetcherIterator {
bytesInFlight += req.size
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
future.onSuccess {
case Some(message) => {
future.onComplete {
case Success(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
Expand All @@ -135,8 +136,8 @@ object BlockFetcherIterator {
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
logError("Could not get block(s) from " + cmId)
case Failure(exception) => {
logError("Could not get block(s) from " + cmId, exception)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
}
Expand Down
Loading

0 comments on commit 4dc449a

Please sign in to comment.