Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some improvements in SortShuffleRead #4

Merged
merged 5 commits into from
Nov 13, 2014
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,14 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
}

/** Release numBytes bytes for the current thread. */
def release(numBytes: Long): Unit = release(numBytes, Thread.currentThread().getId)

/** Release numBytes bytes for the specific thread. */
def release(numBytes: Long, tid: Long): Unit = synchronized {
val curMem = threadMemory.getOrElse(tid, 0L)
def release(numBytes: Long): Unit = synchronized {
val threadId = Thread.currentThread().getId
val curMem = threadMemory.getOrElse(threadId, 0L)
if (curMem < numBytes) {
throw new SparkException(
s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
}
threadMemory(tid) -= numBytes
threadMemory(threadId) -= numBytes
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.spark.shuffle.sort

import java.io.{BufferedOutputStream, FileOutputStream}
import java.io.FileOutputStream
import java.util.Comparator

import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
import org.apache.spark.executor.ShuffleWriteMetrics

import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import org.apache.spark._
Expand Down Expand Up @@ -59,9 +61,6 @@ private[spark] class SortShuffleReader[K, C](
/** Shuffle block fetcher iterator */
private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = _

/** Number of bytes left to fetch */
private var unfetchedBytes: Long = _

private val dep = handle.dependency
private val conf = SparkEnv.get.conf
private val blockManager = SparkEnv.get.blockManager
Expand All @@ -70,11 +69,23 @@ private[spark] class SortShuffleReader[K, C](

private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

/** ArrayBuffer to store in-memory shuffle blocks */
private val inMemoryBlocks = new Queue[MemoryShuffleBlock]()
/** Number of bytes spilled in memory and on disk */
private var _memoryBytesSpilled: Long = 0L
private var _diskBytesSpilled: Long = 0L

/** Queue to store in-memory shuffle blocks */
private val inMemoryBlocks = new mutable.Queue[MemoryShuffleBlock]()

/** Number of bytes left to fetch */
private var unfetchedBytes: Long = 0L

/** Manage the BlockManagerId and related shuffle blocks */
private var statuses: Array[(BlockManagerId, Long)] = _
/**
 * Maintains the mapping from each shuffle block to its location and size. We need to track
 * the originally requested sizes because the requested shuffle block size is not equal to
 * the fetched result size (block sizes are compressed when reported), so keeping this map
 * ensures the correctness of our accounting.
 */
private val shuffleBlockMap = new mutable.HashMap[ShuffleBlockId, (BlockManagerId, Long)]()

/** keyComparator for mergeSort; if keyOrdering is not available,
 * the hashcode of the key is used for comparison */
Expand All @@ -89,42 +100,51 @@ private[spark] class SortShuffleReader[K, C](
/** A merge thread to merge on-disk blocks */
private val tieredMerger = new TieredDiskMerger(conf, dep, keyComparator, context)

def memoryBytesSpilled: Long = _memoryBytesSpilled

def diskBytesSpilled: Long = _diskBytesSpilled + tieredMerger.diskBytesSpilled

override def read(): Iterator[Product2[K, C]] = {
tieredMerger.start()

computeShuffleBlocks()

for ((blockId, blockOption) <- fetchRawBlocks()) {
val blockData = blockOption match {
case Success(block) => block
case Failure(e) =>
blockId match {
case ShuffleBlockId (shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException (address, shufId.toInt, mapId.toInt, startPartition,
case b @ ShuffleBlockId(shuffleId, mapId, _) =>
val address = shuffleBlockMap(b)._1
throw new FetchFailedException (address, shuffleId.toInt, mapId.toInt, startPartition,
Utils.exceptionString (e))
case _ =>
throw new SparkException (
s"Failed to get block $blockId, which is not a shuffle block", e)
}
}

shuffleRawBlockFetcherItr.currentResult = null

// Try to fit block in memory. If this fails, merge in-memory blocks to disk.
val blockSize = blockData.size
val granted = shuffleMemoryManager.tryToAcquire(blockSize)
val block = MemoryShuffleBlock(blockId, blockData)
if (granted < blockSize) {
logInfo(s"Granted $granted memory is not enough to store shuffle block ($blockSize), " +
s"spilling in-memory blocks to release the memory")
if (granted >= blockSize) {
inMemoryBlocks += MemoryShuffleBlock(blockId, blockData)
} else {
logInfo(s"Granted $granted memory is not enough to store shuffle block id $blockId, " +
s"block size $blockSize, spilling in-memory blocks to release the memory")

shuffleMemoryManager.release(granted)
spillInMemoryBlocks(block)
} else {
inMemoryBlocks += block
}

unfetchedBytes -= blockData.size()
shuffleRawBlockFetcherItr.currentResult = null
unfetchedBytes -= shuffleBlockMap(blockId.asInstanceOf[ShuffleBlockId])._2
}
assert(unfetchedBytes == 0)

// Make sure all the blocks have been fetched.
assert(unfetchedBytes == 0L)

tieredMerger.doneRegisteringOnDiskBlocks()

Expand All @@ -133,9 +153,13 @@ private[spark] class SortShuffleReader[K, C](
val mergedItr =
MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator)

// Update the spilled info.
context.taskMetrics().memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics().diskBytesSpilled += diskBytesSpilled

// Release the in-memory block when iteration is completed.
val completionItr = CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
mergedItr, () => {
mergedItr, {
inMemoryBlocks.foreach { block =>
block.blockData.release()
shuffleMemoryManager.release(block.blockData.size)
Expand All @@ -146,48 +170,93 @@ private[spark] class SortShuffleReader[K, C](
new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2)))
}

def spillInMemoryBlocks(tippingBlock: MemoryShuffleBlock): Unit = {
private def spillInMemoryBlocks(tippingBlock: MemoryShuffleBlock): Unit = {
// Write merged blocks to disk
val (tmpBlockId, file) = blockManager.diskBlockManager.createTempShuffleBlock()
val fos = new FileOutputStream(file)
val bos = new BufferedOutputStream(fos, fileBufferSize)

// If the remaining unfetched data would fit inside our current allocation, we don't want to
// waste time spilling blocks beyond the space needed for it.
// We use the requested size to calculate the remaining bytes to spill in order to ensure
// correctness, since the requested size differs slightly from the result block size
// because of size compression.
var bytesToSpill = unfetchedBytes
val blocksToSpill = new ArrayBuffer[MemoryShuffleBlock]()
val blocksToSpill = new mutable.ArrayBuffer[MemoryShuffleBlock]()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Judging from other files, Spark convention appears to be using "ArrayBuffer", not "mutable.ArrayBuffer"

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, IDEA complains about the previous code, so I changed it. I will revert to the previous style.

blocksToSpill += tippingBlock
bytesToSpill -= tippingBlock.blockData.size
bytesToSpill -= shuffleBlockMap(tippingBlock.blockId.asInstanceOf[ShuffleBlockId])._2
while (bytesToSpill > 0 && !inMemoryBlocks.isEmpty) {
val block = inMemoryBlocks.dequeue()
blocksToSpill += block
bytesToSpill -= block.blockData.size
bytesToSpill -= shuffleBlockMap(block.blockId.asInstanceOf[ShuffleBlockId])._2
}

if (blocksToSpill.size > 1) {
_memoryBytesSpilled += blocksToSpill.map(_.blockData.size()).sum

if (inMemoryBlocks.size > 1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be "if (blocksToSpill.size > 1)", right?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, my fault, I will modify it.

val itrGroup = inMemoryBlocksToIterators(blocksToSpill)
val partialMergedItr =
MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
blockManager.dataSerializeStream(tmpBlockId, bos, partialMergedItr, ser)
val curWriteMetrics = new ShuffleWriteMetrics()
var writer =
blockManager.getDiskWriter(tmpBlockId, file, ser, fileBufferSize, curWriteMetrics)
var success = false

try {
partialMergedItr.foreach(p => writer.write(p))
success = true
} finally {
if (!success) {
if (writer != null) {
writer.revertPartialWritesAndClose()
writer = null
}
if (file.exists()) {
file.delete()
}
} else {
writer.commitAndClose()
writer = null
}
}
_diskBytesSpilled += curWriteMetrics.shuffleBytesWritten

} else {
val fos = new FileOutputStream(file)
val buffer = blocksToSpill.map(_.blockData.nioByteBuffer()).head
val channel = fos.getChannel
while (buffer.hasRemaining) {
channel.write(buffer)
var channel = fos.getChannel
var success = false

try {
while (buffer.hasRemaining) {
channel.write(buffer)
}
success = true
} finally {
if (channel != null) {
channel.close()
channel = null
}
if (!success) {
if (file.exists()) {
file.delete()
}
} else {
_diskBytesSpilled = file.length()
}
}
channel.close()
}

tieredMerger.registerOnDiskBlock(tmpBlockId, file)

logInfo(s"Merged ${blocksToSpill.size} in-memory blocks into file ${file.getName}")
logInfo(s"Merged ${inMemoryBlocks.size} in-memory blocks into file ${file.getName}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is blocksToSpill not right here? inMemoryBlocks will hold the blocks that we haven't spilled, right?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I forgot to change this, will modify it.


for (block <- blocksToSpill) {
for (block <- inMemoryBlocks) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be blocksToSpill as well, right?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's right — sorry, I forgot to change that.

block.blockData.release()
if (block != tippingBlock) {
shuffleMemoryManager.release(block.blockData.size)
}
}

inMemoryBlocks.clear()
}

private def inMemoryBlocksToIterators(blocks: Seq[MemoryShuffleBlock])
Expand All @@ -198,33 +267,42 @@ private[spark] class SortShuffleReader[K, C](
}
}

private def fetchRawBlocks(): Iterator[(BlockId, Try[ManagedBuffer])] = {
statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(handle.shuffleId, startPartition)
private def computeShuffleBlocks(): Unit = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method should have a header comment explaining what it does.

val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(handle.shuffleId, startPartition)

val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]()
val splitsByAddress = new mutable.HashMap[BlockManagerId, mutable.ArrayBuffer[(Int, Long)]]()
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
splitsByAddress.getOrElseUpdate(address, mutable.ArrayBuffer()) += ((index, size))
}

val blocksByAddress = splitsByAddress.toSeq.map { case (address, splits) =>
val blocks = splits.map { s =>
(ShuffleBlockId(handle.shuffleId, s._1, startPartition), s._2)
splitsByAddress.foreach { case (id, blocks) =>
blocks.foreach { case (idx, len) =>
shuffleBlockMap.put(ShuffleBlockId(handle.shuffleId, idx, startPartition), (id, len))
unfetchedBytes += len
}
(address, blocks.toSeq)
}
unfetchedBytes = blocksByAddress.flatMap(a => a._2.map(b => b._2)).sum
}

private def fetchRawBlocks(): Iterator[(BlockId, Try[ManagedBuffer])] = {
val blocksByAddress = new mutable.HashMap[BlockManagerId,
mutable.ArrayBuffer[(ShuffleBlockId, Long)]]()

shuffleBlockMap.foreach { case (block, (id, len)) =>
blocksByAddress.getOrElseUpdate(id,
mutable.ArrayBuffer[(ShuffleBlockId, Long)]()) += ((block, len))
}

shuffleRawBlockFetcherItr = new ShuffleRawBlockFetcherIterator(
context,
SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
blocksByAddress.toSeq,
conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)

val completionItr = CompletionIterator[
(BlockId, Try[ManagedBuffer]),
Iterator[(BlockId, Try[ManagedBuffer])]](shuffleRawBlockFetcherItr,
() => context.taskMetrics.updateShuffleReadMetrics())
context.taskMetrics.updateShuffleReadMetrics())

new InterruptibleIterator[(BlockId, Try[ManagedBuffer])](context, completionItr)
}
Expand Down
Loading