diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 3cce7717d8fb9..6173fd3a69fc7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -62,11 +62,11 @@ private[spark] class TorrentBroadcast[T: ClassTag]( */ @transient private var _value: T = obj + private val broadcastId = BroadcastBlockId(id) + /** Total number of blocks this broadcast variable contains. */ private val numBlocks: Int = writeBlocks() - private val broadcastId = BroadcastBlockId(id) - override protected def getValue() = _value /** @@ -75,15 +75,23 @@ private[spark] class TorrentBroadcast[T: ClassTag]( * @return number of blocks this broadcast variable is divided into */ private def writeBlocks(): Int = { - val blocks = TorrentBroadcast.blockifyObject(_value) - blocks.zipWithIndex.foreach { case (block, i) => - SparkEnv.get.blockManager.putBytes( - BroadcastBlockId(id, "piece" + i), - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) + // For local mode, just put the object in the BlockManager so we can find it later. + SparkEnv.get.blockManager.putSingle( + broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + + if (!isLocal) { + val blocks = TorrentBroadcast.blockifyObject(_value) + blocks.zipWithIndex.foreach { case (block, i) => + SparkEnv.get.blockManager.putBytes( + BroadcastBlockId(id, "piece" + i), + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + } + blocks.length + } else { + 0 } - blocks.length } /** Fetch torrent blocks from the driver and/or other executors. */ @@ -91,24 +99,33 @@ private[spark] class TorrentBroadcast[T: ClassTag]( // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported // to the driver, so other executors can pull these chunks from this executor as well. val blocks = new Array[ByteBuffer](numBlocks) + val bm = SparkEnv.get.blockManager for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { val pieceId = BroadcastBlockId(id, "piece" + pid) - // Note that we use getBytes rather than getRemoteBytes here because there is a chance - // that previous attempts to fetch the broadcast blocks have already fetched some of the - // blocks. In that case, some blocks would be available locally (on this executor). - SparkEnv.get.blockManager.getBytes(pieceId) match { - case Some(block) => - blocks(pid) = block - SparkEnv.get.blockManager.putBytes( - pieceId, - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + // First try getLocalBytes because there is a chance that previous attempts to fetch the + // broadcast blocks have already fetched some of the blocks. In that case, some blocks + // would be available locally (on this executor). + var blockOpt = bm.getLocalBytes(pieceId) + if (!blockOpt.isDefined) { + blockOpt = bm.getRemoteBytes(pieceId) + blockOpt match { + case Some(block) => + // If we found the block from remote executors/driver's BlockManager, put the block + // in this executor's BlockManager. + SparkEnv.get.blockManager.putBytes( + pieceId, + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } } + // If we get here, the option is defined. + blocks(pid) = blockOpt.get } blocks } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1d37a29ee0b21..e4c3d58905e7f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -517,16 +517,6 @@ private[spark] class BlockManager( None } - def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val local = getLocalBytes(blockId) - if (local.isDefined) { - local - } else { - val remote = getRemoteBytes(blockId) - remote - } - } - /** * Get a block from the block manager (either local or remote). */ diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 01d019821a46a..978a6ded80829 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -189,17 +189,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { var blockId = BroadcastBlockId(broadcastId) var statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === 0) + assert(statuses.size === 1) blockId = BroadcastBlockId(broadcastId, "piece0") statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === 1) - statuses.head match { case (bm, status) => - assert(bm.executorId === "", "Block should only be on the driver") - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK_SER) - assert(status.memSize > 0, "Block should be in memory store on the driver") - assert(status.diskSize === 0, "Block should not be in disk store on the driver") - } + assert(statuses.size === (if (distributed) 1 else 0)) } // Verify that blocks are persisted in both the executors and the driver @@ -207,7 +201,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { var blockId = BroadcastBlockId(broadcastId) var statuses = bmm.getBlockStatus(blockId, askSlaves = true) if (distributed) { - assert(statuses.size === numSlaves) + assert(statuses.size === numSlaves + 1) } else { assert(statuses.size === 1) } @@ -217,7 +211,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { if (distributed) { assert(statuses.size === numSlaves + 1) } else { - assert(statuses.size === 1) + assert(statuses.size === 0) } } @@ -225,12 +219,12 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // is true. def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { var blockId = BroadcastBlockId(broadcastId) - var expectedNumBlocks = if (removeFromDriver) 0 else if (distributed) 0 else 1 + var expectedNumBlocks = if (removeFromDriver) 0 else 1 var statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === expectedNumBlocks) blockId = BroadcastBlockId(broadcastId, "piece0") - expectedNumBlocks = if (removeFromDriver) 0 else 1 + expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1 statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === expectedNumBlocks) }