Some improvements in SortShuffleRead #4
Changes from 4 commits: 2585658, 728f4f0, d5a5cb7, 1189f0e, 4a5711b
@@ -17,10 +17,12 @@

package org.apache.spark.shuffle.sort

import java.io.{BufferedOutputStream, FileOutputStream}
import java.io.FileOutputStream
import java.util.Comparator

import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
import org.apache.spark.executor.ShuffleWriteMetrics

import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import org.apache.spark._
@@ -59,9 +61,6 @@ private[spark] class SortShuffleReader[K, C](

/** Shuffle block fetcher iterator */
private var shuffleRawBlockFetcherItr: ShuffleRawBlockFetcherIterator = _

/** Number of bytes left to fetch */
private var unfetchedBytes: Long = _

private val dep = handle.dependency
private val conf = SparkEnv.get.conf
private val blockManager = SparkEnv.get.blockManager
@@ -70,11 +69,23 @@ private[spark] class SortShuffleReader[K, C](

private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024

/** ArrayBuffer to store in-memory shuffle blocks */
private val inMemoryBlocks = new Queue[MemoryShuffleBlock]()
/** Number of bytes spilled in memory and on disk */
private var _memoryBytesSpilled: Long = 0L
private var _diskBytesSpilled: Long = 0L

/** Queue to store in-memory shuffle blocks */
private val inMemoryBlocks = new mutable.Queue[MemoryShuffleBlock]()

/** number of bytes left to fetch */
private var unfetchedBytes: Long = 0L

/** Manage the BlockManagerId and related shuffle blocks */
private var statuses: Array[(BlockManagerId, Long)] = _
/**
 * Maintain the relation between shuffle block and its size. The reason we should maintain this
 * is that the request shuffle block size is not equal to the result size because of
 * compression of size. So here we should maintain this make sure the correctness of our
 * algorithm.
 */
private val shuffleBlockMap = new mutable.HashMap[ShuffleBlockId, (BlockManagerId, Long)]()
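Why the map keys each block to its requested size rather than the size of the buffer that eventually arrives: map output sizes are reported to the reducer in a compressed, lossy encoding, so the size used when requesting a block can differ slightly from the number of bytes actually fetched. The toy codec below illustrates such a log-scale encoding; it only mimics the idea behind Spark's MapStatus size compression and is not its actual code. Any bookkeeping built from the requested sizes (unfetchedBytes here) therefore has to be decremented with those same requested sizes.

    object ToySizeCodec {
      private val base = 1.1

      // Encode a size as a single byte on a log scale (lossy).
      def compress(size: Long): Byte =
        if (size <= 1L) 0.toByte
        else math.min(255, math.ceil(math.log(size.toDouble) / math.log(base)).toInt).toByte

      // Decode back to an approximate size.
      def decompress(b: Byte): Long = math.pow(base, (b & 0xFF).toDouble).toLong

      def main(args: Array[String]): Unit = {
        val actual = 123456L
        val requested = decompress(compress(actual))
        // requested only approximates actual, hence the separate bookkeeping per block id
        println(s"actual=$actual requested=$requested")
      }
    }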

/** keyComparator for mergeSort, id keyOrdering is not available,
 * using hashcode of key to compare */
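For readers unfamiliar with the fallback mentioned in the comment above: when no keyOrdering is supplied, a merge sort only needs some consistent total order that brings equal keys together, so comparing hash codes is sufficient. A minimal sketch of such a comparator (illustrative, not the project's exact implementation):

    import java.util.Comparator

    // Orders keys by hash code only. This is enough to let a merge sort place
    // equal keys next to each other, but it is not a semantic ordering.
    class HashKeyComparator[K] extends Comparator[K] {
      override def compare(a: K, b: K): Int = {
        val h1 = if (a == null) 0 else a.hashCode()
        val h2 = if (b == null) 0 else b.hashCode()
        if (h1 < h2) -1 else if (h1 == h2) 0 else 1
      }
    }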
@@ -89,42 +100,51 @@ private[spark] class SortShuffleReader[K, C](

/** A merge thread to merge on-disk blocks */
private val tieredMerger = new TieredDiskMerger(conf, dep, keyComparator, context)

def memoryBytesSpilled: Long = _memoryBytesSpilled

def diskBytesSpilled: Long = _diskBytesSpilled + tieredMerger.diskBytesSpilled

override def read(): Iterator[Product2[K, C]] = {
tieredMerger.start()

computeShuffleBlocks()

for ((blockId, blockOption) <- fetchRawBlocks()) {
val blockData = blockOption match {
case Success(block) => block
case Failure(e) =>
blockId match {
case ShuffleBlockId (shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException (address, shufId.toInt, mapId.toInt, startPartition,
case b @ ShuffleBlockId(shuffleId, mapId, _) =>
val address = shuffleBlockMap(b)._1
throw new FetchFailedException (address, shuffleId.toInt, mapId.toInt, startPartition,
Utils.exceptionString (e))
case _ =>
throw new SparkException (
s"Failed to get block $blockId, which is not a shuffle block", e)
}
}

shuffleRawBlockFetcherItr.currentResult = null

// Try to fit block in memory. If this fails, merge in-memory blocks to disk.
val blockSize = blockData.size
val granted = shuffleMemoryManager.tryToAcquire(blockSize)
val block = MemoryShuffleBlock(blockId, blockData)
if (granted < blockSize) {
logInfo(s"Granted $granted memory is not enough to store shuffle block ($blockSize), " +
s"spilling in-memory blocks to release the memory")
if (granted >= blockSize) {
inMemoryBlocks += MemoryShuffleBlock(blockId, blockData)
} else {
logInfo(s"Granted $granted memory is not enough to store shuffle block id $blockId, " +
s"block size $blockSize, spilling in-memory blocks to release the memory")

shuffleMemoryManager.release(granted)
spillInMemoryBlocks(block)
} else {
inMemoryBlocks += block
}
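The acquire-or-spill logic above is easy to get wrong: tryToAcquire may grant less than was asked for, and that partial grant must be handed back before spilling, otherwise the memory leaks from the shuffle memory pool. A self-contained toy sketch of the pattern (ToyMemoryManager is an illustrative stand-in, not Spark's ShuffleMemoryManager):

    object AcquireOrSpillSketch {
      // Illustrative stand-in for a per-task shuffle memory pool.
      class ToyMemoryManager(private var free: Long) {
        def tryToAcquire(bytes: Long): Long = {
          val granted = math.min(bytes, free)
          free -= granted
          granted
        }
        def release(bytes: Long): Unit = { free += bytes }
      }

      // Returns true if the block fits in memory; otherwise hands the partial
      // grant back and runs the caller-supplied spill action to free space.
      def storeOrSpill(mm: ToyMemoryManager, blockSize: Long)(spill: => Unit): Boolean = {
        val granted = mm.tryToAcquire(blockSize)
        if (granted >= blockSize) {
          true
        } else {
          mm.release(granted) // give the partial grant back before spilling
          spill
          false
        }
      }
    }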

unfetchedBytes -= blockData.size()
shuffleRawBlockFetcherItr.currentResult = null
unfetchedBytes -= shuffleBlockMap(blockId.asInstanceOf[ShuffleBlockId])._2
}
assert(unfetchedBytes == 0)

// Make sure all the blocks have been fetched.
assert(unfetchedBytes == 0L)

tieredMerger.doneRegisteringOnDiskBlocks()
@@ -133,9 +153,13 @@ private[spark] class SortShuffleReader[K, C](

val mergedItr =
MergeUtil.mergeSort(finalItrGroup, keyComparator, dep.keyOrdering, dep.aggregator)

// Update the spilled info.
context.taskMetrics().memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics().diskBytesSpilled += diskBytesSpilled

// Release the in-memory block when iteration is completed.
val completionItr = CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
mergedItr, () => {
mergedItr, {
inMemoryBlocks.foreach { block =>
block.blockData.release()
shuffleMemoryManager.release(block.blockData.size)
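The change from "mergedItr, () => { ... }" to "mergedItr, { ... }" matters because the completion callback is taken by name (as in Spark's CompletionIterator.apply, whose completion function is a by-name Unit parameter): passing a function literal to a by-name Unit parameter merely constructs a Function0 and discards it, so the cleanup body would never execute. A small standalone illustration of the pitfall:

    object ByNameCallbackSketch {
      // Takes its callback by name, like the completion function parameter.
      def runOnCompletion(callback: => Unit): Unit = callback

      def main(args: Array[String]): Unit = {
        runOnCompletion { println("cleanup runs") }       // the body executes
        runOnCompletion(() => println("never printed"))   // builds a Function0 and discards it
      }
    }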
@@ -146,48 +170,93 @@ private[spark] class SortShuffleReader[K, C](

new InterruptibleIterator(context, completionItr.map(p => (p._1, p._2)))
}

def spillInMemoryBlocks(tippingBlock: MemoryShuffleBlock): Unit = {
private def spillInMemoryBlocks(tippingBlock: MemoryShuffleBlock): Unit = {
// Write merged blocks to disk
val (tmpBlockId, file) = blockManager.diskBlockManager.createTempShuffleBlock()
val fos = new FileOutputStream(file)
val bos = new BufferedOutputStream(fos, fileBufferSize)

// If the remaining unfetched data would fit inside our current allocation, we don't want to
// waste time spilling blocks beyond the space needed for it.
// We use the request size to calculate the remaining spilled size to make sure the
// correctness, since the request size is slightly different from result block size because
// of size compression.
var bytesToSpill = unfetchedBytes
val blocksToSpill = new ArrayBuffer[MemoryShuffleBlock]()
val blocksToSpill = new mutable.ArrayBuffer[MemoryShuffleBlock]()
blocksToSpill += tippingBlock
bytesToSpill -= tippingBlock.blockData.size
bytesToSpill -= shuffleBlockMap(tippingBlock.blockId.asInstanceOf[ShuffleBlockId])._2
while (bytesToSpill > 0 && !inMemoryBlocks.isEmpty) {
val block = inMemoryBlocks.dequeue()
blocksToSpill += block
bytesToSpill -= block.blockData.size
bytesToSpill -= shuffleBlockMap(block.blockId.asInstanceOf[ShuffleBlockId])._2
}

if (blocksToSpill.size > 1) {
_memoryBytesSpilled += blocksToSpill.map(_.blockData.size()).sum

if (inMemoryBlocks.size > 1) {
Review comment: this should be "if (blocksToSpill.size > 1)", right?
Reply: Yes, my fault, I will modify it.

val itrGroup = inMemoryBlocksToIterators(blocksToSpill)
val partialMergedItr =
MergeUtil.mergeSort(itrGroup, keyComparator, dep.keyOrdering, dep.aggregator)
blockManager.dataSerializeStream(tmpBlockId, bos, partialMergedItr, ser)
val curWriteMetrics = new ShuffleWriteMetrics()
var writer =
blockManager.getDiskWriter(tmpBlockId, file, ser, fileBufferSize, curWriteMetrics)
var success = false

try {
partialMergedItr.foreach(p => writer.write(p))
success = true
} finally {
if (!success) {
if (writer != null) {
writer.revertPartialWritesAndClose()
writer = null
}
if (file.exists()) {
file.delete()
}
} else {
writer.commitAndClose()
writer = null
}
}
_diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
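The try/finally above follows a commit-or-revert discipline: either every record of the partial merge reaches the temp file and the writer commits, or the partial write is rolled back and the temp file deleted. A toy, self-contained rendering of the same pattern (the SpillWriter trait is an illustrative stand-in, not Spark's disk writer API):

    import java.io.File

    object CommitOrRevertSketch {
      // Illustrative stand-in for a disk writer that can either keep or undo
      // everything written through it.
      trait SpillWriter {
        def write(record: Any): Unit
        def commitAndClose(): Unit
        def revertPartialWritesAndClose(): Unit
      }

      def writeAll(records: Iterator[Any], writer: SpillWriter, file: File): Unit = {
        var success = false
        try {
          records.foreach(r => writer.write(r))
          success = true
        } finally {
          if (success) {
            writer.commitAndClose()
          } else {
            writer.revertPartialWritesAndClose()
            if (file.exists()) file.delete() // do not leave a half-written temp file behind
          }
        }
      }
    }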

} else {
val fos = new FileOutputStream(file)
val buffer = blocksToSpill.map(_.blockData.nioByteBuffer()).head
val channel = fos.getChannel
while (buffer.hasRemaining) {
channel.write(buffer)
var channel = fos.getChannel
var success = false

try {
while (buffer.hasRemaining) {
channel.write(buffer)
}
success = true
} finally {
if (channel != null) {
channel.close()
channel = null
}
if (!success) {
if (file.exists()) {
file.delete()
}
} else {
_diskBytesSpilled = file.length()
}
}
channel.close()
}

tieredMerger.registerOnDiskBlock(tmpBlockId, file)

logInfo(s"Merged ${blocksToSpill.size} in-memory blocks into file ${file.getName}")
logInfo(s"Merged ${inMemoryBlocks.size} in-memory blocks into file ${file.getName}")
Review comment: Is blocksToSpill not right here? inMemoryBlocks will hold the blocks that we haven't spilled, right?
Reply: Oops, I forgot to change this, will modify it.

for (block <- blocksToSpill) {
for (block <- inMemoryBlocks) {
Review comment: This should be blocksToSpill as well, right?
Reply: That's right, sorry to forget that.

block.blockData.release()
if (block != tippingBlock) {
shuffleMemoryManager.release(block.blockData.size)
}
}

inMemoryBlocks.clear()
}
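Taking the three review comments in spillInMemoryBlocks together (the size guard, the merge log message, and the release loop), the intent is apparently that all three places operate on blocksToSpill rather than inMemoryBlocks. A toy sketch of that corrected shape, with stand-in types so the fragment compiles on its own; it reflects the reviewers' suggestions, not the committed code:

    object SpillTailSketch {
      case class Block(size: Long)

      def main(args: Array[String]): Unit = {
        val blocksToSpill = Seq(Block(10), Block(20))                  // blocks chosen for this spill
        val inMemoryBlocks = scala.collection.mutable.Queue[Block]()   // what is still held in memory, for contrast
        val tippingBlock = blocksToSpill.head

        // 1. Guard on the blocks actually being spilled, not on what remains queued.
        if (blocksToSpill.size > 1) {
          // merge-sort the spilled blocks before writing them out
        }

        // 2. Log how many blocks went into the spill file.
        println(s"Merged ${blocksToSpill.size} in-memory blocks")

        // 3. Release the memory of the spilled blocks; the tipping block's grant
        //    was already released by the caller, so it is skipped.
        for (block <- blocksToSpill if block != tippingBlock) {
          // shuffleMemoryManager.release(block.size) in the real code
        }
      }
    }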

private def inMemoryBlocksToIterators(blocks: Seq[MemoryShuffleBlock])
@@ -198,33 +267,42 @@ private[spark] class SortShuffleReader[K, C](

}
}

private def fetchRawBlocks(): Iterator[(BlockId, Try[ManagedBuffer])] = {
statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(handle.shuffleId, startPartition)
private def computeShuffleBlocks(): Unit = {
Review comment: This method should have a header comment explaining what it does.
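One possible wording for that header comment, inferred from the method body below; this is only a suggestion, not the author's text:

    /**
     * Look up this shuffle's map output statuses for the target reduce partition,
     * record each ShuffleBlockId together with its location (BlockManagerId) and
     * requested size in shuffleBlockMap, and accumulate the total into unfetchedBytes.
     */
    private def computeShuffleBlocks(): Unit = {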

val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(handle.shuffleId, startPartition)

val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]()
val splitsByAddress = new mutable.HashMap[BlockManagerId, mutable.ArrayBuffer[(Int, Long)]]()
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
splitsByAddress.getOrElseUpdate(address, mutable.ArrayBuffer()) += ((index, size))
}

val blocksByAddress = splitsByAddress.toSeq.map { case (address, splits) =>
val blocks = splits.map { s =>
(ShuffleBlockId(handle.shuffleId, s._1, startPartition), s._2)
splitsByAddress.foreach { case (id, blocks) =>
blocks.foreach { case (idx, len) =>
shuffleBlockMap.put(ShuffleBlockId(handle.shuffleId, idx, startPartition), (id, len))
unfetchedBytes += len
}
(address, blocks.toSeq)
}
unfetchedBytes = blocksByAddress.flatMap(a => a._2.map(b => b._2)).sum
}

private def fetchRawBlocks(): Iterator[(BlockId, Try[ManagedBuffer])] = {
val blocksByAddress = new mutable.HashMap[BlockManagerId,
mutable.ArrayBuffer[(ShuffleBlockId, Long)]]()

shuffleBlockMap.foreach { case (block, (id, len)) =>
blocksByAddress.getOrElseUpdate(id,
mutable.ArrayBuffer[(ShuffleBlockId, Long)]()) += ((block, len))
}

shuffleRawBlockFetcherItr = new ShuffleRawBlockFetcherIterator(
context,
SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
blocksByAddress.toSeq,
conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)

val completionItr = CompletionIterator[
(BlockId, Try[ManagedBuffer]),
Iterator[(BlockId, Try[ManagedBuffer])]](shuffleRawBlockFetcherItr,
() => context.taskMetrics.updateShuffleReadMetrics())
context.taskMetrics.updateShuffleReadMetrics())

new InterruptibleIterator[(BlockId, Try[ManagedBuffer])](context, completionItr)
}
Review comment: Judging from other files, Spark convention appears to be using "ArrayBuffer", not "mutable.ArrayBuffer".
Reply: Alright, IDEA complains about the previous code, so I changed. I will return to the previous one.
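For reference, the two import styles under discussion differ only in which name is brought into scope; both compile, so the choice is purely a project convention. A minimal illustration:

    object ImportStyleSketch {
      // Style the reviewer points to: import the collection class directly.
      import scala.collection.mutable.ArrayBuffer
      val direct = new ArrayBuffer[Int]()

      // Style currently in the diff: import the package and qualify each use.
      import scala.collection.mutable
      val qualified = new mutable.ArrayBuffer[Int]()
    }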