-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-12757] Add block-level read/write locks to BlockManager #10705
Changes from 8 commits
5d130e4
423faab
1ee665f
76cfebd
7265784
2fb8c89
7cad770
8ae88b0
c1a8d85
575a47b
0ba8318
feb1172
90cf403
7f28910
12ed084
43e50ed
1b18226
8d45da6
8a52f58
36253df
77d8c5c
e37f003
2cf8157
150c6e1
1adbdb9
1828757
76fc9f5
2942b24
62f6671
47f3174
4591308
77939c2
a0c5bb3
d40e010
3f29595
ef7d885
e8d6ec8
9c8d530
f3fc298
dd6358c
ec8cc24
6134989
c9726c2
c629f26
fc19cfd
0aa2392
b273422
7639e03
27e98a3
b72cd7b
5e23177
f0b6d71
e549f2f
6d09400
717c476
e07b62d
0c08731
55b5b19
3a12480
25b09d7
bcb8318
4e11d00
ed44f45
7c74591
a401adc
504986f
66202f2
4f620a4
8547841
99c460c
c94984e
ac2b73f
39b1185
6502047
1d903ff
24dbc3d
745c1f9
5cfbbdb
9427576
3d377b5
07e0e37
697eba2
f5f089d
0b7281b
68b9e83
a5ef11b
b9d6e18
5df7284
eab288c
06ebef5
0628a33
b963178
9becde3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,7 @@ import org.apache.spark.rpc.RpcEnv | |
import org.apache.spark.serializer.{Serializer, SerializerInstance} | ||
import org.apache.spark.shuffle.ShuffleManager | ||
import org.apache.spark.util._ | ||
import org.apache.spark.util.collection.ReferenceCounter | ||
|
||
private[spark] sealed trait BlockValues | ||
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues | ||
|
@@ -161,6 +162,8 @@ private[spark] class BlockManager( | |
* loaded yet. */ | ||
private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) | ||
|
||
private val referenceCounts = new ReferenceCounter[BlockId] | ||
|
||
/** | ||
* Initializes the BlockManager with the given appId. This is not performed in the constructor as | ||
* the appId may not be known at BlockManager instantiation time (in particular for the driver, | ||
|
@@ -414,7 +417,11 @@ private[spark] class BlockManager( | |
*/ | ||
def getLocal(blockId: BlockId): Option[BlockResult] = { | ||
logDebug(s"Getting local block $blockId") | ||
doGetLocal(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]] | ||
val res = doGetLocal(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A shortened version of this should be a class comment. Describe what pinning means and what are the API semantics. "we should add a pin-counting mechanism to track which blocks/pages are being read in order to prevent them from being evicted prematurely. I propose to do this in two phases: first, add a safe, conservative approach in which all BlockManager.get*() calls implicitly increment the pin count of blocks and where tasks' pins are automatically freed upon task completion (this PR)" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this should also document the thread safety guarantees wrt to pinning. |
||
if (res.isDefined) { | ||
referenceCounts.retain(blockId) | ||
} | ||
res | ||
} | ||
|
||
/** | ||
|
@@ -424,7 +431,7 @@ private[spark] class BlockManager( | |
logDebug(s"Getting local block $blockId as bytes") | ||
// As an optimization for map output fetches, if the block is for a shuffle, return it | ||
// without acquiring a lock; the disk store never deletes (recent) items so this should work | ||
if (blockId.isShuffle) { | ||
val res = if (blockId.isShuffle) { | ||
val shuffleBlockResolver = shuffleManager.shuffleBlockResolver | ||
// TODO: This should gracefully handle case where local block is not available. Currently | ||
// downstream code will throw an exception. | ||
|
@@ -433,6 +440,10 @@ private[spark] class BlockManager( | |
} else { | ||
doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] | ||
} | ||
if (res.isDefined) { | ||
referenceCounts.retain(blockId) | ||
} | ||
res | ||
} | ||
|
||
private def doGetLocal(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { | ||
|
@@ -564,15 +575,23 @@ private[spark] class BlockManager( | |
*/ | ||
def getRemote(blockId: BlockId): Option[BlockResult] = { | ||
logDebug(s"Getting remote block $blockId") | ||
doGetRemote(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]] | ||
val res = doGetRemote(blockId, asBlockResult = true).asInstanceOf[Option[BlockResult]] | ||
if (res.isDefined) { | ||
referenceCounts.retain(blockId) | ||
} | ||
res | ||
} | ||
|
||
/** | ||
* Get block from remote block managers as serialized bytes. | ||
*/ | ||
def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { | ||
logDebug(s"Getting remote block $blockId as bytes") | ||
doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] | ||
val res = doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] | ||
if (res.isDefined) { | ||
referenceCounts.retain(blockId) | ||
} | ||
res | ||
} | ||
|
||
/** | ||
|
@@ -642,6 +661,17 @@ private[spark] class BlockManager( | |
None | ||
} | ||
|
||
/** | ||
* Release one reference to the given block. | ||
*/ | ||
def release(blockId: BlockId): Unit = { | ||
referenceCounts.release(blockId) | ||
} | ||
|
||
private[storage] def getReferenceCount(blockId: BlockId): Int = { | ||
referenceCounts.getReferenceCount(blockId) | ||
} | ||
|
||
def putIterator( | ||
blockId: BlockId, | ||
values: Iterator[Any], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -213,6 +213,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo | |
} | ||
|
||
override def remove(blockId: BlockId): Boolean = memoryManager.synchronized { | ||
val referenceCount = blockManager.getReferenceCount(blockId) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What are the semantics here? It seems reasonable for another thread to get this block. Who calls remove? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In the second case, we'll never hit the error message because the MemoryStore won't try to evict blocks with non-zero pin/reference counts. We do have to worry about the first case: if we try to force-remove a block while a task is still reading it then the removal should fail with an error. |
||
if (referenceCount != 0) { | ||
throw new IllegalStateException( | ||
s"Cannot free block $blockId since it is still referenced $referenceCount times") | ||
} | ||
val entry = entries.synchronized { entries.remove(blockId) } | ||
if (entry != null) { | ||
memoryManager.releaseStorageMemory(entry.size) | ||
|
@@ -425,6 +430,10 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo | |
var freedMemory = 0L | ||
val rddToAdd = blockId.flatMap(getRddId) | ||
val selectedBlocks = new ArrayBuffer[BlockId] | ||
def blockIsEvictable(blockId: BlockId): Boolean = { | ||
blockManager.getReferenceCount(blockId) == 0 && | ||
(rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) | ||
} | ||
// This is synchronized to ensure that the set of entries is not changed | ||
// (because of getValue or getBytes) while traversing the iterator, as that | ||
// can lead to exceptions. | ||
|
@@ -433,7 +442,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, memoryManager: Memo | |
while (freedMemory < space && iterator.hasNext) { | ||
val pair = iterator.next() | ||
val blockId = pair.getKey | ||
if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { | ||
if (blockIsEvictable(blockId)) { | ||
selectedBlocks += blockId | ||
freedMemory += pair.getValue.size | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.util.collection | ||
|
||
import scala.collection.JavaConverters._ | ||
|
||
import com.google.common.cache.{CacheBuilder, CacheLoader} | ||
import com.google.common.collect.ConcurrentHashMultiset | ||
|
||
import org.apache.spark.TaskContext | ||
|
||
/** | ||
* Thread-safe collection for maintaining both global and per-task reference counts for objects. | ||
*/ | ||
/**
 * Thread-safe collection for maintaining both global and per-task reference (pin) counts
 * for objects.
 *
 * This backs block pinning in the BlockManager: `BlockManager.get*()` calls implicitly
 * retain (pin) a block so that it cannot be evicted while a task is reading it, and the
 * pin must later be released. A global count of zero means the object is unreferenced and
 * is safe to evict. Per-task counts are tracked separately so that all of a task's
 * outstanding references can be released in bulk when the task completes
 * (see [[releaseAllReferencesForTask]]).
 *
 * Thread-safety: all methods may be called concurrently without external synchronization.
 * Global counts live in a Guava `ConcurrentHashMultiset` and per-task counts in a loading
 * cache of per-task multisets, so individual retain/release operations are atomic, though
 * `getReferenceCount` is only a point-in-time snapshot.
 */
private[spark] class ReferenceCounter[T] {

  private type TaskAttemptId = Long

  /**
   * Total references across all tasks.
   */
  private[this] val allReferences = ConcurrentHashMultiset.create[T]()

  /**
   * Total references per task. Used to auto-release references upon task completion.
   */
  private[this] val referencesByTask = {
    // We need to explicitly box as java.lang.Long to avoid a type mismatch error:
    val loader = new CacheLoader[java.lang.Long, ConcurrentHashMultiset[T]] {
      override def load(t: java.lang.Long) = ConcurrentHashMultiset.create[T]()
    }
    CacheBuilder.newBuilder().build(loader)
  }

  /**
   * Returns the total reference count, across all tasks, for the given object.
   */
  def getReferenceCount(obj: T): Int = allReferences.count(obj)

  /**
   * Increments the given object's reference count for the current task.
   */
  def retain(obj: T): Unit = retainForTask(currentTaskAttemptId, obj)

  /**
   * Decrements the given object's reference count for the current task.
   */
  def release(obj: T): Unit = releaseForTask(currentTaskAttemptId, obj)

  // -1L is used as a sentinel task id for calls made outside of a running task
  // (e.g. on the driver or from unit tests), so those references are still tracked.
  private def currentTaskAttemptId: TaskAttemptId = {
    Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
  }

  /**
   * Increments the given object's reference count for the given task.
   */
  def retainForTask(taskAttemptId: TaskAttemptId, obj: T): Unit = {
    referencesByTask.get(taskAttemptId).add(obj)
    allReferences.add(obj)
  }

  /**
   * Decrements the given object's reference count for the given task.
   *
   * @throws IllegalStateException if the task releases the object more times than it
   *                               retained it. The global count is left untouched in that
   *                               case, so a buggy double-release cannot corrupt the
   *                               eviction bookkeeping for other tasks.
   */
  def releaseForTask(taskAttemptId: TaskAttemptId, obj: T): Unit = {
    val countsForTask = referencesByTask.get(taskAttemptId)
    // ConcurrentHashMultiset.remove returns the count *before* removal (and never drives
    // the stored count below zero), so `- 1` yields the logical new count, which may be
    // negative on a double-release.
    val newReferenceCountForTask: Int = countsForTask.remove(obj, 1) - 1
    if (newReferenceCountForTask < 0) {
      throw doubleReleaseError(taskAttemptId, obj)
    }
    // Only decrement the global count once the per-task decrement is known to be valid;
    // decrementing it first would corrupt the global count before the error is raised.
    val newTotalReferenceCount: Int = allReferences.remove(obj, 1) - 1
    if (newTotalReferenceCount < 0) {
      throw doubleReleaseError(taskAttemptId, obj)
    }
  }

  // Builds the shared over-release error so the message stays consistent across checks.
  private def doubleReleaseError(taskAttemptId: TaskAttemptId, obj: T): IllegalStateException = {
    new IllegalStateException(
      s"Task $taskAttemptId released object $obj more times than it was retained")
  }

  /**
   * Release all references held by the given task, clearing that task's reference
   * bookkeeping structures and updating the global reference counts. This method should be
   * called at the end of a task (either by a task completion handler or in
   * `TaskRunner.run()`).
   */
  def releaseAllReferencesForTask(taskAttemptId: TaskAttemptId): Unit = {
    val referenceCounts = referencesByTask.get(taskAttemptId)
    // Drop the per-task multiset first so the cache entry is not leaked even if the
    // consistency check below fails.
    referencesByTask.invalidate(taskAttemptId)
    referenceCounts.entrySet().iterator().asScala.foreach { entry =>
      val obj = entry.getElement
      val taskRefCount = entry.getCount
      val newRefCount = allReferences.remove(obj, taskRefCount) - taskRefCount
      if (newRefCount < 0) {
        throw doubleReleaseError(taskAttemptId, obj)
      }
    }
  }

  /**
   * Return the number of map entries in this reference counter's internal data structures.
   * This is used in unit tests in order to detect memory leaks.
   */
  private[collection] def getNumberOfMapEntries: Long = {
    allReferences.size() +
      referencesByTask.size() +
      referencesByTask.asMap().asScala.map(_._2.size()).sum
  }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment in this class that explains the ref counting mechanism? It can be a shorter version of the commit message.
Specifically:
What are the invariants? (explain get()) Need to call release. What does it mean if it is 0?
I slightly prefer pin count over ref count (the block manager has a reference but it is unpinned)