diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 4479e6875a731..f5fe6947c179e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -43,6 +43,7 @@ import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.serializer.{Serializer, SerializerInstance}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.util._
+import org.apache.spark.util.collection.ReferenceCounter
 
 private[spark] sealed trait BlockValues
 private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
@@ -161,6 +162,8 @@ private[spark] class BlockManager(
    * loaded yet. */
   private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
 
+  private val referenceCounts = new ReferenceCounter[BlockId]
+
   /**
    * Initializes the BlockManager with the given appId. This is not performed in the constructor as
    * the appId may not be known at BlockManager instantiation time (in particular for the driver,
@@ -414,7 +417,11 @@
    */
   def getLocal(blockId: BlockId): Option[BlockResult] = {
     logDebug(s"Getting local block $blockId")
-    doGetLocal(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]]
+    val res = doGetLocal(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]]
+    if (res.isDefined) {
+      referenceCounts.retain(blockId)
+    }
+    res
   }
 
   /**
@@ -424,7 +431,7 @@
     logDebug(s"Getting local block $blockId as bytes")
     // As an optimization for map output fetches, if the block is for a shuffle, return it
     // without acquiring a lock; the disk store never deletes (recent) items so this should work
-    if (blockId.isShuffle) {
+    val res = if (blockId.isShuffle) {
       val shuffleBlockResolver = shuffleManager.shuffleBlockResolver
       // TODO: This should gracefully handle case where local block is not available. Currently
       // downstream code will throw an exception.
@@ -433,6 +440,10 @@
     } else {
       doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
     }
+    if (res.isDefined) {
+      referenceCounts.retain(blockId)
+    }
+    res
   }
 
   private def doGetLocal(blockId: BlockId, asBlockResult: Boolean): Option[Any] = {
@@ -564,7 +575,11 @@
    */
   def getRemote(blockId: BlockId): Option[BlockResult] = {
     logDebug(s"Getting remote block $blockId")
-    doGetRemote(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]]
+    val res = doGetRemote(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]]
+    if (res.isDefined) {
+      referenceCounts.retain(blockId)
+    }
+    res
   }
 
   /**
@@ -572,7 +587,11 @@
    */
   def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
     logDebug(s"Getting remote block $blockId as bytes")
-    doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+    val res = doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+    if (res.isDefined) {
+      referenceCounts.retain(blockId)
+    }
+    res
   }
 
   /**
@@ -642,6 +661,17 @@
     None
   }
 
+  /**
+   * Release one reference to the given block.
+   */
+  def release(blockId: BlockId): Unit = {
+    referenceCounts.release(blockId)
+  }
+
+  private[storage] def getReferenceCount(blockId: BlockId): Int = {
+    referenceCounts.getReferenceCount(blockId)
+  }
+
   def putIterator(
       blockId: BlockId,
       values: Iterator[Any],
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index bdab8c2332fae..e97f5419b2f30 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -213,6 +213,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
   }
 
   override def remove(blockId: BlockId): Boolean = memoryManager.synchronized {
+    val referenceCount = blockManager.getReferenceCount(blockId)
+    if (referenceCount != 0) {
+      throw new IllegalStateException(
+        s"Cannot free block $blockId since it is still referenced $referenceCount times")
+    }
     val entry = entries.synchronized { entries.remove(blockId) }
     if (entry != null) {
       memoryManager.releaseStorageMemory(entry.size)
@@ -425,6 +430,10 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
     var freedMemory = 0L
     val rddToAdd = blockId.flatMap(getRddId)
     val selectedBlocks = new ArrayBuffer[BlockId]
+    def blockIsEvictable(blockId: BlockId): Boolean = {
+      blockManager.getReferenceCount(blockId) == 0 &&
+        (rddToAdd.isEmpty || rddToAdd != getRddId(blockId))
+    }
     // This is synchronized to ensure that the set of entries is not changed
     // (because of getValue or getBytes) while traversing the iterator, as that
     // can lead to exceptions.
@@ -433,7 +442,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo
       while (freedMemory < space && iterator.hasNext) {
         val pair = iterator.next()
         val blockId = pair.getKey
-        if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) {
+        if (blockIsEvictable(blockId)) {
           selectedBlocks += blockId
           freedMemory += pair.getValue.size
         }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ReferenceCounter.scala b/core/src/main/scala/org/apache/spark/util/collection/ReferenceCounter.scala
index 091a902988d12..6823c1f9339ec 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ReferenceCounter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ReferenceCounter.scala
@@ -55,15 +55,15 @@
   /**
    * Increments the given object's reference count for the current task.
    */
-  def retain(obj: T): Unit = {
-    retainForTask(TaskContext.get().taskAttemptId(), obj)
-  }
+  def retain(obj: T): Unit = retainForTask(currentTaskAttemptId, obj)
 
   /**
    * Decrements the given object's reference count for the current task.
    */
-  def release(obj: T): Unit = {
-    releaseForTask(TaskContext.get().taskAttemptId(), obj)
+  def release(obj: T): Unit = releaseForTask(currentTaskAttemptId, obj)
+
+  private def currentTaskAttemptId: TaskAttemptId = {
+    Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 21db3b1c9ffbd..3ac33dab45a82 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -174,8 +174,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
 
     // Checking whether blocks are in memory
     assert(store.getSingle("a1").isDefined, "a1 was not in store")
+    store.release("a1")
     assert(store.getSingle("a2").isDefined, "a2 was not in store")
+    store.release("a2")
     assert(store.getSingle("a3").isDefined, "a3 was not in store")
+    store.release("a3")
 
     // Checking whether master knows about the blocks or not
     assert(master.getLocations("a1").size > 0, "master was not told about a1")
@@ -223,8 +226,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(memStatus._1 == 20000L, "total memory " + memStatus._1 + " should equal 20000")
     assert(memStatus._2 <= 12000L, "remaining memory " + memStatus._2 + " should <= 12000")
     assert(store.getSingle("a1-to-remove").isDefined, "a1 was not in store")
+    store.release("a1-to-remove")
     assert(store.getSingle("a2-to-remove").isDefined, "a2 was not in store")
+    store.release("a2-to-remove")
     assert(store.getSingle("a3-to-remove").isDefined, "a3 was not in store")
+    store.release("a3-to-remove")
 
     // Checking whether master knows about the blocks or not
     assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1")
@@ -313,9 +319,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     // verify whether the blocks exist in both the stores
     Seq(driverStore, executorStore).foreach { case s =>
       s.getLocal(broadcast0BlockId) should not be (None)
+      s.release(broadcast0BlockId)
       s.getLocal(broadcast1BlockId) should not be (None)
+      s.release(broadcast1BlockId)
       s.getLocal(broadcast2BlockId) should not be (None)
+      s.release(broadcast2BlockId)
       s.getLocal(broadcast2BlockId2) should not be (None)
+      s.release(broadcast2BlockId2)
     }
 
     // remove broadcast 0 block only from executors
@@ -324,17 +334,23 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     // only broadcast 0 block should be removed from the executor store
     executorStore.getLocal(broadcast0BlockId) should be (None)
     executorStore.getLocal(broadcast1BlockId) should not be (None)
+    executorStore.release(broadcast1BlockId)
     executorStore.getLocal(broadcast2BlockId) should not be (None)
+    executorStore.release(broadcast2BlockId)
 
     // nothing should be removed from the driver store
     driverStore.getLocal(broadcast0BlockId) should not be (None)
+    driverStore.release(broadcast0BlockId)
     driverStore.getLocal(broadcast1BlockId) should not be (None)
+    driverStore.release(broadcast1BlockId)
     driverStore.getLocal(broadcast2BlockId) should not be (None)
+    driverStore.release(broadcast2BlockId)
 
     // remove broadcast 0 block from the driver as well
     master.removeBroadcast(0, removeFromMaster = true, blocking = true)
     driverStore.getLocal(broadcast0BlockId) should be (None)
     driverStore.getLocal(broadcast1BlockId) should not be (None)
+    driverStore.release(broadcast1BlockId)
 
     // remove broadcast 1 block from both the stores asynchronously
     // and verify all broadcast 1 blocks have been removed
@@ -505,38 +521,30 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
   }
 
   test("in-memory LRU storage") {
-    store = makeBlockManager(12000)
-    val a1 = new Array[Byte](4000)
-    val a2 = new Array[Byte](4000)
-    val a3 = new Array[Byte](4000)
-    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
-    store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
-    store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY)
-    assert(store.getSingle("a2").isDefined, "a2 was not in store")
-    assert(store.getSingle("a3").isDefined, "a3 was not in store")
-    assert(store.getSingle("a1") === None, "a1 was in store")
-    assert(store.getSingle("a2").isDefined, "a2 was not in store")
-    // At this point a2 was gotten last, so LRU will getSingle rid of a3
-    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
-    assert(store.getSingle("a1").isDefined, "a1 was not in store")
-    assert(store.getSingle("a2").isDefined, "a2 was not in store")
-    assert(store.getSingle("a3") === None, "a3 was in store")
+    testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY)
   }
 
   test("in-memory LRU storage with serialization") {
+    testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY_SER)
+  }
+
+  private def testInMemoryLRUStorage(storageLevel: StorageLevel): Unit = {
     store = makeBlockManager(12000)
     val a1 = new Array[Byte](4000)
     val a2 = new Array[Byte](4000)
     val a3 = new Array[Byte](4000)
-    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER)
-    store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER)
-    store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER)
+    store.putSingle("a1", a1, storageLevel)
+    store.putSingle("a2", a2, storageLevel)
+    store.putSingle("a3", a3, storageLevel)
     assert(store.getSingle("a2").isDefined, "a2 was not in store")
+    store.release("a2")
     assert(store.getSingle("a3").isDefined, "a3 was not in store")
+    store.release("a3")
     assert(store.getSingle("a1") === None, "a1 was in store")
     assert(store.getSingle("a2").isDefined, "a2 was not in store")
+    store.release("a2")
     // At this point a2 was gotten last, so LRU will getSingle rid of a3
-    store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER)
+    store.putSingle("a1", a1, storageLevel)
     assert(store.getSingle("a1").isDefined, "a1 was not in store")
     assert(store.getSingle("a2").isDefined, "a2 was not in store")
     assert(store.getSingle("a3") === None, "a3 was in store")
@@ -618,62 +626,38 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
   }
 
   test("disk and memory storage") {
-    store = makeBlockManager(12000)
-    val a1 = new Array[Byte](4000)
-    val a2 = new Array[Byte](4000)
-    val a3 = new Array[Byte](4000)
-    store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK)
-    store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK)
-    store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK)
-    assert(store.getSingle("a2").isDefined, "a2 was not in store")
-    assert(store.getSingle("a3").isDefined, "a3 was not in store")
-    assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
-    assert(store.getSingle("a1").isDefined, "a1 was not in store")
-    assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store")
+    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getSingle)
   }
 
   test("disk and memory storage with getLocalBytes") {
-    store = makeBlockManager(12000)
-    val a1 = new Array[Byte](4000)
-    val a2 = new Array[Byte](4000)
-    val a3 = new Array[Byte](4000)
-    store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK)
-    store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK)
-    store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK)
-    assert(store.getLocalBytes("a2").isDefined, "a2 was not in store")
-    assert(store.getLocalBytes("a3").isDefined, "a3 was not in store")
-    assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
-    assert(store.getLocalBytes("a1").isDefined, "a1 was not in store")
-    assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store")
+    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getLocalBytes)
   }
 
   test("disk and memory storage with serialization") {
-    store = makeBlockManager(12000)
-    val a1 = new Array[Byte](4000)
-    val a2 = new Array[Byte](4000)
-    val a3 = new Array[Byte](4000)
-    store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER)
-    store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER)
-    store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER)
-    assert(store.getSingle("a2").isDefined, "a2 was not in store")
-    assert(store.getSingle("a3").isDefined, "a3 was not in store")
-    assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
-    assert(store.getSingle("a1").isDefined, "a1 was not in store")
-    assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store")
+    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getSingle)
   }
 
   test("disk and memory storage with serialization and getLocalBytes") {
+    testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getLocalBytes)
+  }
+
+  def testDiskAndMemoryStorage(
+      storageLevel: StorageLevel,
+      accessMethod: BlockManager => BlockId => Option[_]): Unit = {
     store = makeBlockManager(12000)
     val a1 = new Array[Byte](4000)
     val a2 = new Array[Byte](4000)
     val a3 = new Array[Byte](4000)
-    store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER)
-    store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER)
-    store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER)
-    assert(store.getLocalBytes("a2").isDefined, "a2 was not in store")
-    assert(store.getLocalBytes("a3").isDefined, "a3 was not in store")
-    assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store")
-    assert(store.getLocalBytes("a1").isDefined, "a1 was not in store")
+    store.putSingle("a1", a1, storageLevel)
+    store.putSingle("a2", a2, storageLevel)
+    store.putSingle("a3", a3, storageLevel)
+    assert(accessMethod(store)("a2").isDefined, "a2 was not in store")
+    store.release("a2")
+    assert(accessMethod(store)("a3").isDefined, "a3 was not in store")
+    store.release("a3")
+    assert(store.memoryStore.getValues("a1").isEmpty, "a1 was in memory store")
+    assert(accessMethod(store)("a1").isDefined, "a1 was not in store")
+    store.release("a1")
     assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store")
   }
 
@@ -689,14 +673,20 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     store.putSingle("a3", a3, StorageLevel.DISK_ONLY)
     // At this point LRU should not kick in because a3 is only on disk
     assert(store.getSingle("a1").isDefined, "a1 was not in store")
+    store.release("a1")
     assert(store.getSingle("a2").isDefined, "a2 was not in store")
+    store.release("a2")
     assert(store.getSingle("a3").isDefined, "a3 was not in store")
+    store.release("a3")
     // Now let's add in a4, which uses both disk and memory; a1 should drop out
     store.putSingle("a4", a4, StorageLevel.MEMORY_AND_DISK_SER)
     assert(store.getSingle("a1") == None, "a1 was in store")
     assert(store.getSingle("a2").isDefined, "a2 was not in store")
+    store.release("a2")
     assert(store.getSingle("a3").isDefined, "a3 was not in store")
+    store.release("a3")
     assert(store.getSingle("a4").isDefined, "a4 was not in store")
+    store.release("a4")
   }
 
   test("in-memory LRU with streams") {
@@ -709,17 +699,27 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     store.putIterator("list3", list3.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
     assert(store.get("list2").isDefined, "list2 was not in store")
     assert(store.get("list2").get.data.size === 2)
+    store.release("list2")
+    store.release("list2")
     assert(store.get("list3").isDefined, "list3 was not in store")
     assert(store.get("list3").get.data.size === 2)
+    store.release("list3")
+    store.release("list3")
     assert(store.get("list1") === None, "list1 was in store")
     assert(store.get("list2").isDefined, "list2 was not in store")
     assert(store.get("list2").get.data.size === 2)
+    store.release("list2")
+    store.release("list2")
     // At this point list2 was gotten last, so LRU will getSingle rid of list3
     store.putIterator("list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
     assert(store.get("list1").isDefined, "list1 was not in store")
     assert(store.get("list1").get.data.size === 2)
+    store.release("list1")
+    store.release("list1")
     assert(store.get("list2").isDefined, "list2 was not in store")
     assert(store.get("list2").get.data.size === 2)
+    store.release("list2")
+    store.release("list2")
     assert(store.get("list3") === None, "list1 was in store")
   }
 
@@ -739,25 +739,43 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     // At this point LRU should not kick in because list3 is only on disk
     assert(store.get("list1").isDefined, "list1 was not in store")
     assert(store.get("list1").get.data.size === 2)
+    store.release("list1")
+    store.release("list1")
     assert(store.get("list2").isDefined, "list2 was not in store")
     assert(store.get("list2").get.data.size === 2)
+    store.release("list2")
+    store.release("list2")
     assert(store.get("list3").isDefined, "list3 was not in store")
     assert(store.get("list3").get.data.size === 2)
+    store.release("list3")
+    store.release("list3")
     assert(store.get("list1").isDefined, "list1 was not in store")
     assert(store.get("list1").get.data.size === 2)
+    store.release("list1")
+    store.release("list1")
     assert(store.get("list2").isDefined, "list2 was not in store")
     assert(store.get("list2").get.data.size === 2)
+    store.release("list2")
+    store.release("list2")
     assert(store.get("list3").isDefined, "list3 was not in store")
     assert(store.get("list3").get.data.size === 2)
+    store.release("list3")
+    store.release("list3")
     // Now let's add in list4, which uses both disk and memory; list1 should drop out
     store.putIterator("list4", list4.iterator, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)
     assert(store.get("list1") === None, "list1 was in store")
     assert(store.get("list2").isDefined, "list2 was not in store")
     assert(store.get("list2").get.data.size === 2)
+    store.release("list2")
+    store.release("list2")
     assert(store.get("list3").isDefined, "list3 was not in store")
     assert(store.get("list3").get.data.size === 2)
+    store.release("list3")
+    store.release("list3")
    assert(store.get("list4").isDefined, "list4 was not in store")
     assert(store.get("list4").get.data.size === 2)
+    store.release("list4")
+    store.release("list4")
   }
 
   test("negative byte values in ByteBufferInputStream") {
@@ -1059,6 +1077,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     store.putSingle(rdd(1, 0), new Array[Byte](4000), StorageLevel.MEMORY_ONLY)
     // Access rdd_1_0 to ensure it's not least recently used.
     assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store")
+    store.release(rdd(1, 0))
     // According to the same-RDD rule, rdd_1_0 should be replaced here.
     store.putSingle(rdd(0, 1), new Array[Byte](4000), StorageLevel.MEMORY_ONLY)
     // rdd_1_0 should have been replaced, even it's not least recently used.
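
Note for readers of this patch: `retainForTask`, `releaseForTask`, `getReferenceCount`, and the `TaskAttemptId` alias are used in the hunks above but defined outside them. The standalone Scala sketch below is not Spark's implementation; the class `SimpleReferenceCounter` and its internals are illustrative assumptions about how per-task reference counting of this kind could work, paired with the retain-on-get / release-when-done protocol that the `BlockManager` changes introduce.

import scala.collection.mutable

// Hypothetical stand-in for the patch's ReferenceCounter: counts are keyed by
// (taskAttemptId, obj) so each task's outstanding references are tracked
// separately, while getReferenceCount reports the total across all tasks.
class SimpleReferenceCounter[T] {
  private type TaskAttemptId = Long
  private val counts = mutable.Map.empty[(TaskAttemptId, T), Int].withDefaultValue(0)

  def retainForTask(taskAttemptId: TaskAttemptId, obj: T): Unit = synchronized {
    counts((taskAttemptId, obj)) += 1
  }

  def releaseForTask(taskAttemptId: TaskAttemptId, obj: T): Unit = synchronized {
    val key = (taskAttemptId, obj)
    counts(key) -= 1
    if (counts(key) <= 0) counts.remove(key)
  }

  // Total number of outstanding references to obj, summed over all tasks.
  def getReferenceCount(obj: T): Int = synchronized {
    counts.collect { case ((_, o), n) if o == obj => n }.sum
  }
}

object ReferenceCounterDemo {
  def main(args: Array[String]): Unit = {
    val rc = new SimpleReferenceCounter[String]
    val taskId = -1L // the patch falls back to -1L when no TaskContext is set

    // Mirrors the protocol introduced by the patch: a successful get*() retains
    // the block, and the caller must release() it when done.
    rc.retainForTask(taskId, "rdd_0_0")          // e.g. after getLocal returned Some(...)
    assert(rc.getReferenceCount("rdd_0_0") == 1) // pinned: not evictable
    rc.releaseForTask(taskId, "rdd_0_0")         // caller is finished with the block
    assert(rc.getReferenceCount("rdd_0_0") == 0) // evictable again
  }
}

The eviction-side consequence is visible in the MemoryStore hunks above: blockIsEvictable only selects blocks whose reference count is zero, and remove throws if a referenced block is freed, so a caller that obtains a block via get* but never calls release keeps that block pinned in memory.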