diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index a8c827030a1ef..6a187b40628a2 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi trait BroadcastFactory { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + + /** + * Creates a new broadcast variable. + * + * @param value value to broadcast + * @param isLocal whether we are in local mode (single JVM process) + * @param id unique id representing this broadcast variable + */ def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit + def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index d8be649f96e5f..d49d2041a2a79 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -18,7 +18,9 @@ package org.apache.spark.broadcast import java.io._ +import java.nio.ByteBuffer +import scala.collection.JavaConversions.asJavaEnumeration import scala.reflect.ClassTag import scala.util.Random @@ -27,41 +29,87 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} /** - * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like - * protocol to do a distributed transfer of the broadcasted data to the executors. - * The mechanism is as follows. The driver divides the serializes the broadcasted data, - * divides it into smaller chunks, and stores them in the BlockManager of the driver. 
- * These chunks are reported to the BlockManagerMaster so that all the executors can - * learn the location of those chunks. The first time the broadcast variable (sent as - * part of task) is deserialized at a executor, all the chunks are fetched using - * the BlockManager. When all the chunks are fetched (initially from the driver's - * BlockManager), they are combined and deserialized to recreate the broadcasted data. - * However, the chunks are also stored in the BlockManager and reported to the - * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns - * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be - * made to other executors who already have those chunks, resulting in a distributed - * fetching. This prevents the driver from being the bottleneck in sending out multiple - * copies of the broadcast data (one per executor) as done by the - * [[org.apache.spark.broadcast.HttpBroadcast]]. + * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. + * + * The mechanism is as follows: + * + * The driver divides the serialized object into small chunks and + * stores those chunks in the BlockManager of the driver. + * + * On each executor, the executor first attempts to fetch the object from its BlockManager. If + * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or + * other executors if available. Once it gets the chunks, it puts the chunks in its own + * BlockManager, ready for other executors to fetch from. + * + * This prevents the driver from being the bottleneck in sending out multiple copies of the + * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. + * + * @param obj object to broadcast + * @param isLocal whether Spark is running in local mode (single JVM process). + * @param id A unique identifier for the broadcast variable. 
*/ private[spark] class TorrentBroadcast[T: ClassTag]( - @transient var value_ : T, isLocal: Boolean, id: Long) + obj : T, + @transient private val isLocal: Boolean, + id: Long) extends Broadcast[T](id) with Logging with Serializable { - override protected def getValue() = value_ + override protected def getValue() = _value + + /** + * Value of the broadcast object. On driver, this is set directly by the constructor. + * On executors, this is reconstructed by [[readObject]], which builds this value by reading + * blocks from the driver and/or other executors. + */ + @transient private var _value: T = obj + + /** Total number of blocks this broadcast variable contains. */ + private val numBlocks: Int = writeBlocks() private val broadcastId = BroadcastBlockId(id) - SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + /** + * Divide the object into multiple blocks and put those blocks in the block manager. + * + * @return number of blocks this broadcast variable is divided into + */ + private def writeBlocks(): Int = { + val blocks = TorrentBroadcast.blockifyObject(_value) + blocks.zipWithIndex.foreach { case (block, i) => + // TODO: Use putBytes directly. + SparkEnv.get.blockManager.putSingle( + BroadcastBlockId(id, "piece" + i), + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + } + blocks.length + } - @transient private var arrayOfBlocks: Array[TorrentBlock] = null - @transient private var totalBlocks = -1 - @transient private var totalBytes = -1 - @transient private var hasBlocks = 0 + /** Fetch torrent blocks from the driver and/or other executors. */ + private def readBlocks(): Array[Array[Byte]] = { + // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported + // to the driver, so other executors can pull these chunks from this executor as well. 
+ + val blocks = new Array[Array[Byte]](numBlocks) - if (!isLocal) { - sendBroadcast() + for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { + val pieceId = BroadcastBlockId(id, "piece" + pid) + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + blocks(pid) = x.asInstanceOf[Array[Byte]] + // Store via putSingle (as writeBlocks does) so peers' getSingle can read this piece. + SparkEnv.get.blockManager.putSingle( + pieceId, + blocks(pid), + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + blocks } /** @@ -79,26 +127,6 @@ private[spark] class TorrentBroadcast[T: ClassTag]( TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) } - private def sendBroadcast() { - val tInfo = TorrentBroadcast.blockifyObject(value_) - totalBlocks = tInfo.totalBlocks - totalBytes = tInfo.totalBytes - hasBlocks = tInfo.totalBlocks - - // Store meta-info - val metaId = BroadcastBlockId(id, "meta") - val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) - SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - - // Store individual pieces - for (i <- 0 until totalBlocks) { - val pieceId = BroadcastBlockId(id, "piece" + i) - SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - } - } - /** Used by the JVM when serializing this object. 
*/ private def writeObject(out: ObjectOutputStream) { assertValid() @@ -109,99 +137,30 @@ private[spark] class TorrentBroadcast[T: ClassTag]( private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(broadcastId) match { + SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => - value_ = x.asInstanceOf[T] + _value = x.asInstanceOf[T] case None => - val start = System.nanoTime logInfo("Started reading broadcast variable " + id) - - // Initialize @transient variables that will receive garbage values from the master. - resetWorkerVariables() - - if (receiveBroadcast()) { - value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - - /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. - * This creates a trade-off between memory usage and latency. Storing copy doubles - * the memory footprint; not storing doubles deserialization cost. Also, - * this does not need to be reported to BlockManagerMaster since other executors - * does not need to access this block (they only need to fetch the chunks, - * which are reported). 
- */ - SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - - // Remove arrayOfBlocks from memory once value_ is on local cache - resetWorkerVariables() - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 + val start = System.nanoTime() + val blocks = readBlocks() + val time = (System.nanoTime() - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - private def resetWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - } - - private def receiveBroadcast(): Boolean = { - // Receive meta-info about the size of broadcast data, - // the number of chunks it is divided into, etc. - val metaId = BroadcastBlockId(id, "meta") - var attemptId = 10 - while (attemptId > 0 && totalBlocks == -1) { - SparkEnv.get.blockManager.getSingle(metaId) match { - case Some(x) => - val tInfo = x.asInstanceOf[TorrentInfo] - totalBlocks = tInfo.totalBlocks - totalBytes = tInfo.totalBytes - arrayOfBlocks = new Array[TorrentBlock](totalBlocks) - hasBlocks = 0 - - case None => - Thread.sleep(500) - } - attemptId -= 1 - } - - if (totalBlocks == -1) { - return false - } - /* - * Fetch actual chunks of data. Note that all these chunks are stored in - * the BlockManager and reported to the master, so that other executors - * can find out and pull the chunks from this executor. - */ - val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) - for (pid <- recvOrder) { - val pieceId = BroadcastBlockId(id, "piece" + pid) - SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => - arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] - hasBlocks += 1 + _value = TorrentBroadcast.unBlockifyObject[T](blocks) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. 
SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } } - - hasBlocks == totalBlocks } - } -private[broadcast] object TorrentBroadcast extends Logging { + +private object TorrentBroadcast extends Logging { + /** Size of each block. Default value is 4MB. */ private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null @@ -223,7 +182,9 @@ private[broadcast] object TorrentBroadcast extends Logging { initialized = false } - def blockifyObject[T: ClassTag](obj: T): TorrentInfo = { + def blockifyObject[T: ClassTag](obj: T): Array[Array[Byte]] = { + // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks + // so we don't need to do the extra memory copy. 
val bos = new ByteArrayOutputStream() val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos val ser = SparkEnv.get.serializer.newInstance() @@ -231,44 +192,27 @@ private[broadcast] object TorrentBroadcast extends Logging { serOut.writeObject[T](obj).close() val byteArray = bos.toByteArray val bais = new ByteArrayInputStream(byteArray) + val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt + val blocks = new Array[Array[Byte]](numBlocks) - var blockNum = byteArray.length / BLOCK_SIZE - if (byteArray.length % BLOCK_SIZE != 0) { - blockNum += 1 - } - - val blocks = new Array[TorrentBlock](blockNum) var blockId = 0 - for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) val tempByteArray = new Array[Byte](thisBlockSize) bais.read(tempByteArray, 0, thisBlockSize) - blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blocks(blockId) = tempByteArray blockId += 1 } bais.close() - - val info = TorrentInfo(blocks, blockNum, byteArray.length) - info.hasBlocks = blockNum - info + blocks } - def unBlockifyObject[T: ClassTag]( - arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { - val retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) - } + def unBlockifyObject[T: ClassTag](blocks: Array[Array[Byte]]): T = { + val is = new SequenceInputStream( + asJavaEnumeration(blocks.iterator.map(block => new ByteArrayInputStream(block)))) + val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is - val in: InputStream = { - val arrIn = new ByteArrayInputStream(retByteArray) - if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn - } val ser = SparkEnv.get.serializer.newInstance() val serIn = ser.deserializeStream(in) val obj = serIn.readObject[T]() 
@@ -284,17 +228,3 @@ private[broadcast] object TorrentBroadcast extends Logging { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } } - -private[broadcast] case class TorrentBlock( - blockID: Int, - byteArray: Array[Byte]) - extends Serializable - -private[broadcast] case class TorrentInfo( - @transient arrayOfBlocks: Array[TorrentBlock], - totalBlocks: Int, - totalBytes: Int) - extends Serializable { - - @transient var hasBlocks = 0 -} diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 17c64455b2429..01d019821a46a 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.broadcast -import org.apache.spark.storage.{BroadcastBlockId, _} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} import org.scalatest.FunSuite +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} +import org.apache.spark.storage._ + + class BroadcastSuite extends FunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") @@ -124,12 +126,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { val numSlaves = if (distributed) 2 else 0 - def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) - // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = 
bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === 1) statuses.head match { case (bm, status) => assert(bm.executorId === "", "Block should only be on the driver") @@ -139,14 +139,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { } if (distributed) { // this file is only generated in distributed mode - assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") + assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!") } } // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === numSlaves + 1) statuses.foreach { case (_, status) => assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) @@ -157,21 +157,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. In the latter case, also verify that the broadcast file is deleted on the driver. 
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) val expectedNumBlocks = if (removeFromDriver) 0 else 1 val possiblyNot = if (removeFromDriver) "" else " not" assert(statuses.size === expectedNumBlocks, "Block should%s be unpersisted on the driver".format(possiblyNot)) if (distributed && removeFromDriver) { // this file is only generated in distributed mode - assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + assert(!HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file should%s be deleted".format(possiblyNot)) } } - testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -185,67 +185,57 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) { val numSlaves = if (distributed) 2 else 0 - def getBlockIds(id: Long) = { - val broadcastBlockId = BroadcastBlockId(id) - val metaBlockId = BroadcastBlockId(id, "meta") - // Assume broadcast value is small enough to fit into 1 piece - val pieceBlockId = BroadcastBlockId(id, "piece0") - if (distributed) { - // the metadata and piece blocks are generated only in distributed mode - Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) - } else { - Seq[BroadcastBlockId](broadcastBlockId) + // Verify that blocks are persisted only on the driver + def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + 
assert(statuses.size === 0) + + blockId = BroadcastBlockId(broadcastId, "piece0") + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK_SER) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") } } - // Verify that blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + if (distributed) { + assert(statuses.size === numSlaves) + } else { assert(statuses.size === 1) - statuses.head match { case (bm, status) => - assert(bm.executorId === "", "Block should only be on the driver") - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store on the driver") - assert(status.diskSize === 0, "Block should not be in disk store on the driver") - } } - } - // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - if (blockId.field == "meta") { - // Meta data is only on the driver - assert(statuses.size === 1) - statuses.head match { case (bm, _) => assert(bm.executorId === "") } - } else { - // Other blocks are on both the executors and the driver - assert(statuses.size === numSlaves + 1, - 
blockId + " has " + statuses.size + " statuses: " + statuses.mkString(",")) - statuses.foreach { case (_, status) => - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store") - assert(status.diskSize === 0, "Block should not be in disk store") - } - } + blockId = BroadcastBlockId(broadcastId, "piece0") + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + if (distributed) { + assert(statuses.size === numSlaves + 1) + } else { + assert(statuses.size === 1) } } // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. - def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - val expectedNumBlocks = if (removeFromDriver) 0 else 1 - val possiblyNot = if (removeFromDriver) "" else " not" - blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === expectedNumBlocks, - "Block should%s be unpersisted on the driver".format(possiblyNot)) - } + def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var expectedNumBlocks = if (removeFromDriver) 0 else if (distributed) 0 else 1 + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === expectedNumBlocks) + + blockId = BroadcastBlockId(broadcastId, "piece0") + expectedNumBlocks = if (removeFromDriver) 0 else 1 + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -262,10 +252,9 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { distributed: Boolean, numSlaves: Int, // used only when distributed = true broadcastConf: SparkConf, - getBlockIds: Long 
=> Seq[BroadcastBlockId], - afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, - afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, - afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterCreation: (Long, BlockManagerMaster) => Unit, + afterUsingBroadcast: (Long, BlockManagerMaster) => Unit, + afterUnpersist: (Long, BlockManagerMaster) => Unit, removeFromDriver: Boolean) { sc = if (distributed) { @@ -278,15 +267,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Create broadcast variable val broadcast = sc.broadcast(list) - val blocks = getBlockIds(broadcast.id) - afterCreation(blocks, blockManagerMaster) + afterCreation(broadcast.id, blockManagerMaster) // Use broadcast variable on all executors val partitions = 10 assert(partitions > numSlaves) val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) - afterUsingBroadcast(blocks, blockManagerMaster) + afterUsingBroadcast(broadcast.id, blockManagerMaster) // Unpersist broadcast if (removeFromDriver) { @@ -294,7 +282,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { } else { broadcast.unpersist(blocking = true) } - afterUnpersist(blocks, blockManagerMaster) + afterUnpersist(broadcast.id, blockManagerMaster) // If the broadcast is removed from driver, all subsequent uses of the broadcast variable // should throw SparkExceptions. Otherwise, the result should be the same as before.