diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e6adc40f83fa3..82d4b007589a6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -418,7 +418,7 @@ private[spark] class BlockManager(
         val iter: Iterator[Any] = if (level.deserialized) {
           memoryStore.getValues(blockId).get
         } else {
-          dataDeserializeStream(blockId, memoryStore.getBytes(blockId).get.toInputStream)
+          dataDeserialize(blockId, memoryStore.getBytes(blockId).get)
         }
         val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
         Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
@@ -426,7 +426,7 @@ private[spark] class BlockManager(
         val iterToReturn: Iterator[Any] = {
           val diskBytes = diskStore.getBytes(blockId)
           if (level.deserialized) {
-            val diskValues = dataDeserializeStream(blockId, diskBytes.toDestructiveInputStream)
+            val diskValues = dataDeserialize(blockId, diskBytes)
             maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
           } else {
             dataDeserialize(blockId, maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes))
@@ -505,8 +505,7 @@ private[spark] class BlockManager(
    */
   def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
     getRemoteBytes(blockId).map { data =>
-      val values = dataDeserializeStream(blockId, data.toInputStream)
-      new BlockResult(values, DataReadMethod.Network, data.limit)
+      new BlockResult(dataDeserialize(blockId, data), DataReadMethod.Network, data.limit)
     }
   }

@@ -751,7 +750,7 @@ private[spark] class BlockManager(
       // Put it in memory first, even if it also has useDisk set to true;
       // We will drop it to disk later if the memory store can't hold it.
       val putSucceeded = if (level.deserialized) {
-        val values = dataDeserializeStream(blockId, bytes.toInputStream)
+        val values = dataDeserialize(blockId, bytes)
         memoryStore.putIterator(blockId, values, level) match {
           case Right(_) => true
           case Left(iter) =>
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index c9f8e6578c936..5809548c4a10d 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -40,7 +40,6 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
   }

   def writeFully(channel: WritableByteChannel): Unit = {
-    assertNotDisposed()
     for (bytes <- getChunks()) {
       while (bytes.remaining > 0) {
         channel.write(bytes)
@@ -49,12 +48,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
     }
   }

   def toNetty: ByteBuf = {
-    assertNotDisposed()
     Unpooled.wrappedBuffer(getChunks(): _*)
   }

   def toArray: Array[Byte] = {
-    assertNotDisposed()
     if (limit >= Integer.MAX_VALUE) {
       throw new UnsupportedOperationException(
         s"cannot call toArray because buffer size ($limit bytes) exceeds maximum array size")
@@ -65,24 +62,15 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
     byteChannel.getData
   }

-  def toInputStream: InputStream = {
-    assertNotDisposed()
-    new ChunkedByteBufferInputStream(getChunks().iterator)
-  }
-
-  def toDestructiveInputStream: InputStream = {
-    val is = new ChunkedByteBufferInputStream(chunks.iterator)
-    chunks = null
-    is
+  def toInputStream(dispose: Boolean = false): InputStream = {
+    new ChunkedByteBufferInputStream(this, dispose)
   }

   def getChunks(): Array[ByteBuffer] = {
-    assertNotDisposed()
     chunks.map(_.duplicate())
   }

   def copy(): ChunkedByteBuffer = {
-    assertNotDisposed()
     val copiedChunks = getChunks().map { chunk =>
       // TODO: accept an allocator in this copy method to integrate with mem. accounting systems
       val newChunk = ByteBuffer.allocate(chunk.limit())
@@ -93,41 +81,29 @@
     new ChunkedByteBuffer(copiedChunks)
   }

+  /**
+   * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
+   * might cause errors if one attempts to read from the unmapped buffer, but it's better than
+   * waiting for the GC to find it because that could lead to huge numbers of open files. There's
+   * unfortunately no standard API to do this.
+   */
   def dispose(): Unit = {
-    assertNotDisposed()
     chunks.foreach(BlockManager.dispose)
-    chunks = null
-  }
-
-  private def assertNotDisposed(): Unit = {
-    if (chunks == null) {
-      throw new IllegalStateException("Cannot call methods on a disposed ChunkedByteBuffer")
-    }
   }
 }

-private class ChunkedByteBufferInputStream(chunks: Iterator[ByteBuffer]) extends InputStream {
+/**
+ * Reads data from a ChunkedByteBuffer, and optionally cleans it up using BlockManager.dispose()
+ * at the end of the stream (e.g. to close a memory-mapped file).
+ */
+private class ChunkedByteBufferInputStream(
+    var chunkedByteBuffer: ChunkedByteBuffer,
+    dispose: Boolean)
+  extends InputStream {

+  private[this] var chunks = chunkedByteBuffer.getChunks().iterator
   private[this] var currentChunk: ByteBuffer = chunks.next()

-  override def available(): Int = {
-    while (!currentChunk.hasRemaining && chunks.hasNext) {
-      BlockManager.dispose(currentChunk)
-      currentChunk = chunks.next()
-    }
-    currentChunk.remaining()
-  }
-
-// override def skip(n: Long): Long = {
-//   // TODO(josh): check contract
-//   var i = n
-//   while (i > 0) {
-//     read()
-//     i -= 1
-//   }
-//   n
-// }
-
   override def read(): Int = {
     if (currentChunk != null && !currentChunk.hasRemaining && chunks.hasNext) {
       BlockManager.dispose(currentChunk)
@@ -136,25 +112,24 @@ private class ChunkedByteBufferInputStream(chunks: Iterator[ByteBuffer]) extends
     if (currentChunk != null && currentChunk.hasRemaining) {
       UnsignedBytes.toInt(currentChunk.get())
     } else {
-      BlockManager.dispose(currentChunk)
-      currentChunk = null
+      close()
       -1
     }
   }

 // TODO(josh): implement
 // override def read(b: Array[Byte]): Int = super.read(b)
-//
-// override def read(b: Array[Byte], off: Int, len: Int): Int = super.read(b, off, len)
+// override def skip(n: Long): Long = super.skip(n)

   override def close(): Unit = {
     if (currentChunk != null) {
-      BlockManager.dispose(currentChunk)
-      while (chunks.hasNext) {
-        currentChunk = chunks.next()
-        BlockManager.dispose(currentChunk)
+      if (dispose) {
+        chunkedByteBuffer.dispose()
       }
     }
+    chunkedByteBuffer = null
+    chunks = null
     currentChunk = null
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
index 44b0f16693cc8..952b3f0924eed 100644
--- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -73,17 +73,18 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
     }
   }

-  // TODO(josh) test dispose behavior
   test("toInputStream()") {
     val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte))
     val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte))
     val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes1, bytes2))
     assert(chunkedByteBuffer.limit === bytes1.limit() + bytes2.limit())

-    val inputStream = chunkedByteBuffer.toInputStream(false)
+    val inputStream = chunkedByteBuffer.toInputStream(dispose = false)
     val bytesFromStream = new Array[Byte](chunkedByteBuffer.limit.toInt)
     ByteStreams.readFully(inputStream, bytesFromStream)
     assert(bytesFromStream === bytes1.array() ++ bytes2.array())
     assert(chunkedByteBuffer.getChunks().head.position() === 0)
   }
+
+  // TODO(josh): figure out how to test the dispose=true case.
 }
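Note (not part of the patch): a minimal sketch of how the refactored ChunkedByteBuffer.toInputStream(dispose = ...) API introduced above is meant to be consumed. It assumes the post-patch ChunkedByteBuffer and Guava's ByteStreams are on the classpath; the readAll helper, the object name, and the package placement (needed because ChunkedByteBuffer is private[spark]) are hypothetical.

package org.apache.spark.util.io

import java.nio.ByteBuffer

import com.google.common.io.ByteStreams

object ChunkedByteBufferReadExample {

  // Hypothetical helper: drain a ChunkedByteBuffer through its InputStream view.
  // Passing dispose = true asks the stream to dispose() the underlying buffer when
  // the stream is closed (e.g. to unmap memory-mapped chunks); dispose = false
  // leaves the buffer usable afterwards.
  def readAll(buffer: ChunkedByteBuffer, disposeAtEnd: Boolean): Array[Byte] = {
    val in = buffer.toInputStream(dispose = disposeAtEnd)
    try {
      val out = new Array[Byte](buffer.limit.toInt)
      ByteStreams.readFully(in, out)
      out
    } finally {
      in.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val buffer = new ChunkedByteBuffer(Array(
      ByteBuffer.wrap(Array.tabulate(256)(_.toByte)),
      ByteBuffer.wrap(Array.tabulate(128)(_.toByte))))

    // Non-destructive read: the chunks remain valid after the stream is closed.
    val bytes = readAll(buffer, disposeAtEnd = false)
    assert(bytes.length == buffer.limit.toInt)
  }
}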