Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/pr/1490' into connection-manager…
Browse files Browse the repository at this point in the history
…-fixes

Conflicts:
	core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
	core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
  • Loading branch information
JoshRosen committed Aug 1, 2014
2 parents 78f2af5 + ee91bb7 commit 7399c6b
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
Expand All @@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
typ, id, size, newBuffer.remaining, ackId,
hasError, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
Expand All @@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,27 +660,37 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
sentMessageStatus.markDone()
}
} else {
val ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logDebug("Not calling back as callback is null")
None
}
var ackMessage : Option[Message] = None
try {
ackMessage = if (onReceiveCallback != null) {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logDebug("Not calling back as callback is null")
None
}

if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
}
}
} catch {
case e: Exception => {
logError(s"Exception was thrown during processing message", e)
val m = Message.createBufferMessage(bufferMessage.id)
m.hasError = true
ackMessage = Some(m)
}
} finally {
sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}

sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
}
case _ => throw new Exception("Unknown type message received")
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/network/Message.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var startTime = -1L
var finishTime = -1L
var isSecurityNeg = false
var hasError = false

def size: Int

Expand Down Expand Up @@ -87,6 +88,7 @@ private[spark] object Message {
case BUFFER_MESSAGE => new BufferMessage(header.id,
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.hasError = header.hasError
newMessage.senderAddress = header.address
newMessage
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
val hasError: Boolean,
val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
Expand All @@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader(
putInt(totalSize).
putInt(chunkSize).
putInt(other).
put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
putInt(securityNeg).
putInt(ip.size).
put(ip).
Expand All @@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader(


private[spark] object MessageChunkHeader {
val HEADER_SIZE = 44
val HEADER_SIZE = 45

def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
Expand All @@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader {
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
val hasError = buffer.get() != 0
val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
new InetSocketAddress(ip, port))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,25 @@ object BlockFetcherIterator {
future.onSuccess {
case Some(message) => {
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
if (bufferMessage.hasError) {
logError("Could not get block(s) from " + cmId)
for ((blockId, size) <- req.blocks) {
results.put(new FetchResult(blockId, -1, null))
}
} else {
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
throw new SparkException(
"Unexpected message " + blockMessage.getType + " received from " + cmId)
}
val blockId = blockMessage.getId
val networkSize = blockMessage.getData.limit()
results.put(new FetchResult(blockId, sizeMap(blockId),
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += networkSize
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
val blockId = blockMessage.getId
val networkSize = blockMessage.getData.limit()
results.put(new FetchResult(blockId, sizeMap(blockId),
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += networkSize
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
case None => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,19 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
case e: Exception => logError("Exception handling buffer message", e)
None
case e: Exception => {
logError("Exception handling buffer message", e)
val errorMessage = Message.createBufferMessage(msg.id)
errorMessage.hasError = true
Some(errorMessage)
}
}
}
case otherMessage: Any => {
logError("Unknown type message received: " + otherMessage)
None
val errorMessage = Message.createBufferMessage(msg.id)
errorMessage.hasError = true
Some(errorMessage)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,31 @@ class ConnectionManagerSuite extends FunSuite {
managerServer.stop()
}

test("Ack error message") {
val conf = new SparkConf
conf.set("spark.authenticate", "false")
val securityManager = new SecurityManager(conf)
val manager = new ConnectionManager(0, conf, securityManager)
val managerServer = new ConnectionManager(0, conf, securityManager)
managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
throw new Exception
})

val size = 10 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
val bufferMessage = Message.createBufferMessage(buffer)

val future = manager.sendMessageReliably(managerServer.id, bufferMessage)

val message = Await.result(future, 1 second)
assert(message.isDefined)
assert(message.get.hasError)

manager.stop()
managerServer.stop()

}

}

Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@

package org.apache.spark.storage

import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global

import org.scalatest.{FunSuite, Matchers}
import org.scalatest.PrivateMethodTester._

import org.mockito.Mockito._
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.stubbing.Answer
import org.mockito.invocation.InvocationOnMock

import org.apache.spark._
import org.apache.spark.storage.BlockFetcherIterator._
import org.apache.spark.network.{ConnectionManager, ConnectionManagerId,
Message}
import org.apache.spark.network.{ConnectionManager, Message}

class BlockFetcherIteratorSuite extends FunSuite with Matchers {

Expand Down Expand Up @@ -137,4 +140,95 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined")
}

test("block fetch from remote fails using BasicBlockFetcherIterator") {
val blockManager = mock(classOf[BlockManager])
val connManager = mock(classOf[ConnectionManager])
when(blockManager.connectionManager).thenReturn(connManager)

val f = future {
val message = Message.createBufferMessage(0)
message.hasError = true
val someMessage = Some(message)
someMessage
}
when(connManager.sendMessageReliably(any(),
any())).thenReturn(f)
when(blockManager.futureExecContext).thenReturn(global)

when(blockManager.blockManagerId).thenReturn(
BlockManagerId("test-client", "test-client", 1, 0))
when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)

val blId1 = ShuffleBlockId(0,0,0)
val blId2 = ShuffleBlockId(0,1,0)
val bmId = BlockManagerId("test-server", "test-server",1 , 0)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(bmId, Seq((blId1, 1L), (blId2, 1L)))
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)

iterator.initialize()
iterator.foreach{
case (_, r) => {
(!r.isDefined) should be(true)
}
}
}

test("block fetch from remote succeed using BasicBlockFetcherIterator") {
val blockManager = mock(classOf[BlockManager])
val connManager = mock(classOf[ConnectionManager])
when(blockManager.connectionManager).thenReturn(connManager)

val blId1 = ShuffleBlockId(0,0,0)
val blId2 = ShuffleBlockId(0,1,0)
val buf1 = ByteBuffer.allocate(4)
val buf2 = ByteBuffer.allocate(4)
buf1.putInt(1)
buf1.flip()
buf2.putInt(1)
buf2.flip()
val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1))
val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2))
val blockMessageArray = new BlockMessageArray(
Seq(blockMessage1, blockMessage2))

val bufferMessage = blockMessageArray.toBufferMessage
val buffer = ByteBuffer.allocate(bufferMessage.size)
val arrayBuffer = new ArrayBuffer[ByteBuffer]
bufferMessage.buffers.foreach{ b =>
buffer.put(b)
}
buffer.flip()
arrayBuffer += buffer

val someMessage = Some(Message.createBufferMessage(arrayBuffer))

val f = future {
someMessage
}
when(connManager.sendMessageReliably(any(),
any())).thenReturn(f)
when(blockManager.futureExecContext).thenReturn(global)

when(blockManager.blockManagerId).thenReturn(
BlockManagerId("test-client", "test-client", 1, 0))
when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)

val bmId = BlockManagerId("test-server", "test-server",1 , 0)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(bmId, Seq((blId1, 1L), (blId2, 1L)))
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
iterator.initialize()
iterator.foreach{
case (_, r) => {
(r.isDefined) should be(true)
}
}
}
}
Loading

0 comments on commit 7399c6b

Please sign in to comment.