WIP towards understanding destruction.
JoshRosen committed Mar 15, 2016
1 parent 79b1a6a commit 7dbcd5a
Showing 8 changed files with 70 additions and 62 deletions.
@@ -106,7 +106,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
blocks.zipWithIndex.foreach { case (block, i) =>
val pieceId = BroadcastBlockId(id, "piece" + i)
val bytes = new ChunkedByteBuffer(block)
val bytes = new ChunkedByteBuffer(block.duplicate())
if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
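Because the ChunkedByteBuffer constructor no longer defensively duplicates and rewinds its chunks (see the ChunkedByteBuffer hunks below), call sites like this one — and the Executor and streaming handlers further down — now pass block.duplicate() so the wrapper gets its own position and limit. A minimal sketch of that java.nio behaviour, using only the JDK (object and variable names are illustrative, not from the patch):

import java.nio.ByteBuffer

object DuplicateSketch {
  def main(args: Array[String]): Unit = {
    val callerView = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
    // duplicate() shares the backing bytes but carries an independent position and
    // limit, so whatever the ChunkedByteBuffer does with its copy cannot disturb
    // the caller's view of the same bytes.
    val handedOff = callerView.duplicate()
    handedOff.get() // advances only the duplicate's position
    assert(handedOff.position() == 1)
    assert(callerView.position() == 0)
  }
}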
@@ -298,7 +298,7 @@ private[spark] class Executor(
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId,
new ChunkedByteBuffer(serializedDirectResult),
new ChunkedByteBuffer(serializedDirectResult.duplicate()),
StorageLevel.MEMORY_AND_DISK_SER)
logInfo(
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
@@ -418,15 +418,15 @@ private[spark] class BlockManager(
val iter: Iterator[Any] = if (level.deserialized) {
memoryStore.getValues(blockId).get
} else {
dataDeserialize(blockId, memoryStore.getBytes(blockId).get)
dataDeserializeStream(blockId, memoryStore.getBytes(blockId).get.toInputStream)
}
val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
} else if (level.useDisk && diskStore.contains(blockId)) {
val iterToReturn: Iterator[Any] = {
val diskBytes = diskStore.getBytes(blockId)
if (level.deserialized) {
val diskValues = dataDeserialize(blockId, diskBytes)
val diskValues = dataDeserializeStream(blockId, diskBytes.toDestructiveInputStream)
maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
} else {
dataDeserialize(blockId, maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes))
@@ -505,7 +505,8 @@
*/
def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
getRemoteBytes(blockId).map { data =>
new BlockResult(dataDeserialize(blockId, data), DataReadMethod.Network, data.limit)
val values = dataDeserializeStream(blockId, data.toInputStream)
new BlockResult(values, DataReadMethod.Network, data.limit)
}
}

@@ -750,7 +751,7 @@ private[spark] class BlockManager(
// Put it in memory first, even if it also has useDisk set to true;
// We will drop it to disk later if the memory store can't hold it.
val putSucceeded = if (level.deserialized) {
val values = dataDeserialize(blockId, bytes)
val values = dataDeserializeStream(blockId, bytes.toInputStream)
memoryStore.putIterator(blockId, values, level) match {
case Right(_) => true
case Left(iter) =>
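Note on the BlockManager hunks above: dataDeserialize took a fully materialized ChunkedByteBuffer, while dataDeserializeStream takes an InputStream, so values are decoded only as the iterator is pulled. The memory-store and remote paths use the non-destructive toInputStream (those bytes must stay cached), whereas the disk path uses toDestructiveInputStream so already-consumed chunks can be released immediately. A toy, self-contained sketch of stream-backed lazy deserialization using plain JDK serialization (Spark's own serializers differ; names here are illustrative):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object StreamDeserializeSketch {
  def main(args: Array[String]): Unit = {
    // Serialize a few values into one buffer.
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    (1 to 3).foreach(i => oos.writeObject(Integer.valueOf(i)))
    oos.close()

    // Deserialize lazily: each value is decoded only when the iterator is pulled,
    // so the consumer never needs the whole decoded collection in memory at once.
    val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    val values: Iterator[AnyRef] = (1 to 3).iterator.map(_ => ois.readObject())
    values.foreach(println)
  }
}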
@@ -188,6 +188,7 @@ private[spark] class MemoryStore(
val entry = if (level.deserialized) {
new DeserializedMemoryEntry(arrayValues, SizeEstimator.estimate(arrayValues))
} else {
// TODO(josh): incrementally serialize
val bytes = new ChunkedByteBuffer(blockManager.dataSerialize(blockId, arrayValues.iterator))
new SerializedMemoryEntry(bytes, bytes.limit)
}
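The TODO above flags a follow-up: dataSerialize currently materializes the whole serialized block before the SerializedMemoryEntry is built. One possible direction, sketched below purely for illustration (not part of this patch), is an OutputStream that accumulates fixed-size chunks, so a serializer could write an iterator of values without ever producing one contiguous array:

import java.io.OutputStream
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

// Illustrative sketch only: collect serialized output into fixed-size ByteBuffer chunks.
class ChunkedOutputStreamSketch(chunkSize: Int) extends OutputStream {
  private val chunks = ArrayBuffer[ByteBuffer]()
  private var current: ByteBuffer = _

  override def write(b: Int): Unit = {
    if (current == null || !current.hasRemaining) {
      current = ByteBuffer.allocate(chunkSize) // start a new chunk once the last one fills
      chunks += current
    }
    current.put(b.toByte)
  }

  // Return read-ready duplicates; the last chunk's limit is however much was written.
  def toChunks: Array[ByteBuffer] = chunks.map { chunk =>
    val dup = chunk.duplicate()
    dup.flip()
    dup
  }.toArray
}

Wrapping the resulting chunks in a ChunkedByteBuffer would avoid the single large allocation; this sketch makes no claim about the exact API Spark would use for that.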
@@ -27,32 +27,34 @@ import io.netty.buffer.{ByteBuf, Unpooled}
import org.apache.spark.network.util.ByteArrayWritableChannel
import org.apache.spark.storage.BlockManager

private[spark] class ChunkedByteBuffer(_chunks: Array[ByteBuffer]) {
private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
require(chunks != null, "chunks must not be null")
require(chunks.nonEmpty, "Cannot create a ChunkedByteBuffer with no chunks")
require(chunks.forall(_.limit() > 0), "chunks must be non-empty")
require(chunks.forall(_.position() == 0), "chunks' positions must be 0")

require(_chunks.nonEmpty, "Cannot create a ChunkedByteBuffer with no chunks")
require(_chunks.forall(_.limit() > 0), "chunks must be non-empty")
val limit: Long = chunks.map(_.limit().asInstanceOf[Long]).sum

def this(byteBuffer: ByteBuffer) = {
this(Array(byteBuffer))
}

private[this] val chunks: Array[ByteBuffer] = {
_chunks.map(_.duplicate().rewind().asInstanceOf[ByteBuffer]) // doesn't actually copy bytes
}

val limit: Long = chunks.map(_.limit().asInstanceOf[Long]).sum

def writeFully(channel: WritableByteChannel): Unit = {
assertNotDisposed()
for (bytes <- getChunks()) {
while (bytes.remaining > 0) {
channel.write(bytes)
}
}
}

def toNetty: ByteBuf = Unpooled.wrappedBuffer(getChunks(): _*)
def toNetty: ByteBuf = {
assertNotDisposed()
Unpooled.wrappedBuffer(getChunks(): _*)
}

def toArray: Array[Byte] = {
assertNotDisposed()
if (limit >= Integer.MAX_VALUE) {
throw new UnsupportedOperationException(
s"cannot call toArray because buffer size ($limit bytes) exceeds maximum array size")
@@ -63,39 +65,55 @@ private[spark] class ChunkedByteBuffer(_chunks: Array[ByteBuffer]) {
byteChannel.getData
}

def toInputStream(dispose: Boolean): InputStream = new ChunkedByteBufferInputStream(this, dispose)
def toInputStream: InputStream = {
assertNotDisposed()
new ChunkedByteBufferInputStream(getChunks().iterator)
}

def getChunks(): Array[ByteBuffer] = chunks.map(_.duplicate())
def toDestructiveInputStream: InputStream = {
val is = new ChunkedByteBufferInputStream(chunks.iterator)
chunks = null
is
}

def getChunks(): Array[ByteBuffer] = {
assertNotDisposed()
chunks.map(_.duplicate())
}

def copy(): ChunkedByteBuffer = {
assertNotDisposed()
val copiedChunks = getChunks().map { chunk =>
// TODO: accept an allocator in this copy method, etc.
// TODO: accept an allocator in this copy method to integrate with mem. accounting systems
val newChunk = ByteBuffer.allocate(chunk.limit())
newChunk.put(chunk)
newChunk.flip()
newChunk
}
new ChunkedByteBuffer(copiedChunks)
}

def dispose(): Unit = {
assertNotDisposed()
chunks.foreach(BlockManager.dispose)
chunks = null
}
}

private def assertNotDisposed(): Unit = {
if (chunks == null) {
throw new IllegalStateException("Cannot call methods on a disposed ChunkedByteBuffer")
}
}
}

// TODO(josh): implement dispose

private class ChunkedByteBufferInputStream(
chunkedBuffer: ChunkedByteBuffer,
dispose: Boolean = false) extends InputStream {
private class ChunkedByteBufferInputStream(chunks: Iterator[ByteBuffer]) extends InputStream {

private[this] val chunksIterator: Iterator[ByteBuffer] = chunkedBuffer.getChunks().iterator
private[this] var currentChunk: ByteBuffer = chunksIterator.next()
assert(currentChunk.position() == 0)
private[this] var currentChunk: ByteBuffer = chunks.next()

override def available(): Int = {
while (!currentChunk.hasRemaining && chunksIterator.hasNext) {
currentChunk = chunksIterator.next()
assert(currentChunk.position() == 0)
while (!currentChunk.hasRemaining && chunks.hasNext) {
BlockManager.dispose(currentChunk)
currentChunk = chunks.next()
}
currentChunk.remaining()
}
@@ -111,13 +129,15 @@ private class ChunkedByteBufferInputStream(
// }

override def read(): Int = {
if (!currentChunk.hasRemaining && chunksIterator.hasNext) {
currentChunk = chunksIterator.next()
assert(currentChunk.position() == 0)
if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
BlockManager.dispose(currentChunk)
currentChunk = chunks.next()
}
if (currentChunk.hasRemaining) {
if (currentChunk != null && currentChunk.hasRemaining) {
UnsignedBytes.toInt(currentChunk.get())
} else {
BlockManager.dispose(currentChunk)
currentChunk = null
-1
}
}
@@ -128,6 +148,13 @@
// override def read(b: Array[Byte], off: Int, len: Int): Int = super.read(b, off, len)

override def close(): Unit = {
if (currentChunk != null) {
BlockManager.dispose(currentChunk)
while (chunks.hasNext) {
currentChunk = chunks.next()
BlockManager.dispose(currentChunk)
}
}
currentChunk = null
}
}
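For readers outside the Spark tree, here is a self-contained sketch of the destructive pattern the new ChunkedByteBufferInputStream implements above: an InputStream over an iterator of chunks that hands each chunk to a release callback as soon as it is exhausted, rather than keeping every chunk alive until the whole stream has been read. (Spark's version calls BlockManager.dispose, which cleans up direct and memory-mapped buffers.) The class and callback names below are illustrative:

import java.io.InputStream
import java.nio.ByteBuffer

class ReleasingChunkInputStream(
    chunks: Iterator[ByteBuffer],
    release: ByteBuffer => Unit) extends InputStream {

  private[this] var current: ByteBuffer = if (chunks.hasNext) chunks.next() else null

  override def read(): Int = {
    while (current != null && !current.hasRemaining) {
      release(current) // free the exhausted chunk eagerly
      current = if (chunks.hasNext) chunks.next() else null
    }
    if (current == null) -1 else current.get() & 0xFF // InputStream contract: unsigned byte or -1
  }

  override def close(): Unit = {
    if (current != null) {
      release(current)
      current = null
    }
    chunks.foreach(release) // release anything that was never reached
  }
}

object ReleasingChunkInputStreamDemo {
  def main(args: Array[String]): Unit = {
    val chunks = Array(ByteBuffer.wrap(Array[Byte](1, 2)), ByteBuffer.wrap(Array[Byte](3)))
    val in = new ReleasingChunkInputStream(
      chunks.iterator, b => println(s"released chunk of ${b.capacity()} bytes"))
    Iterator.continually(in.read()).takeWhile(_ != -1).foreach(b => println(s"read $b"))
    in.close()
  }
}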
@@ -39,30 +39,6 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
}
}

test("constructor duplicates chunks") {
val byteBuffer = ByteBuffer.allocate(8)
byteBuffer.limit(4)
val chunkedByteBuffer = new ChunkedByteBuffer(Array(byteBuffer))
assert(chunkedByteBuffer.limit === 4)
assert(chunkedByteBuffer.getChunks().head.limit() === 4)
// Changing the original ByteBuffer's position and limit does not affect the ChunkedByteBuffer:
byteBuffer.limit(8)
byteBuffer.position(4)
assert(chunkedByteBuffer.limit === 4)
assert(chunkedByteBuffer.getChunks().head.limit() === 4)
assert(chunkedByteBuffer.getChunks().head.position() === 0)
}

test("constructor rewinds chunks") {
val byteBuffer = ByteBuffer.allocate(8)
byteBuffer.get()
byteBuffer.get()
assert(byteBuffer.position() === 2)
val chunkedByteBuffer = new ChunkedByteBuffer(Array(byteBuffer))
assert(chunkedByteBuffer.limit === 8)
assert(chunkedByteBuffer.getChunks().head.position() === 0)
}

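The two constructor tests above are removed because the constructor no longer duplicates or rewinds chunks; that responsibility moved to the callers. A hypothetical test of the contract that replaces them — the new require that every chunk's position be 0 — is sketched below; it is not part of this commit:

test("constructor rejects chunks whose position is not 0") {
  val byteBuffer = ByteBuffer.allocate(8)
  byteBuffer.get() // advance the position past 0
  assert(byteBuffer.position() === 1)
  intercept[IllegalArgumentException] {
    new ChunkedByteBuffer(Array(byteBuffer))
  }
}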
test("getChunks() duplicates chunks") {
val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8)))
chunkedByteBuffer.getChunks().head.position(4)
@@ -157,7 +157,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
logInfo(s"Read partition data of $this from write ahead log, record handle " +
partition.walRecordHandle)
if (storeInBlockManager) {
blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead), storageLevel)
blockManager.putBytes(blockId, new ChunkedByteBuffer(dataRead.duplicate()), storageLevel)
logDebug(s"Stored partition data of $this into block manager with level $storageLevel")
dataRead.rewind()
}
@@ -85,7 +85,7 @@ private[streaming] class BlockManagerBasedBlockHandler(
putResult
case ByteBufferBlock(byteBuffer) =>
blockManager.putBytes(
blockId, new ChunkedByteBuffer(byteBuffer), storageLevel, tellMaster = true)
blockId, new ChunkedByteBuffer(byteBuffer.duplicate()), storageLevel, tellMaster = true)
case o =>
throw new SparkException(
s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}")
@@ -187,7 +187,10 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
// Store the block in block manager
val storeInBlockManagerFuture = Future {
val putSucceeded = blockManager.putBytes(
blockId, new ChunkedByteBuffer(serializedBlock), effectiveStorageLevel, tellMaster = true)
blockId,
new ChunkedByteBuffer(serializedBlock.duplicate()),
effectiveStorageLevel,
tellMaster = true)
if (!putSucceeded) {
throw new SparkException(
s"Could not store $blockId to block manager with storage level $storageLevel")
