diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 36fdffede89e7..fc028c3becad7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -47,6 +47,8 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream * This prevents the driver from being the bottleneck in sending out multiple copies of the * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. * + * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. + * * @param obj object to broadcast * @param id A unique identifier for the broadcast variable. */ @@ -116,18 +118,19 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) // First try getLocalBytes because there is a chance that previous attempts to fetch the // broadcast blocks have already fetched some of the blocks. In that case, some blocks // would be available locally (on this executor). - val block: ByteBuffer = bm.getLocalBytes(pieceId).getOrElse { - bm.getRemoteBytes(pieceId).map { block => - // If we found the block from remote executors/driver's BlockManager, put the block - // in this executor's BlockManager. - SparkEnv.get.blockManager.putBytes( - pieceId, - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) - block - }.getOrElse(throw new SparkException(s"Failed to get $pieceId of $broadcastId")) + def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId) + def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block => + // If we found the block from remote executors/driver's BlockManager, put the block + // in this executor's BlockManager. + SparkEnv.get.blockManager.putBytes( + pieceId, + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + block } + val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse( + throw new SparkException(s"Failed to get $pieceId of $broadcastId")) blocks(pid) = block } blocks