More work
mateiz committed Jul 30, 2014
1 parent 3a56341 commit bbf359d
Showing 7 changed files with 258 additions and 23 deletions.
org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -17,9 +17,12 @@

package org.apache.spark.shuffle.sort

import java.io.{DataInputStream, FileInputStream}

import org.apache.spark.shuffle._
import org.apache.spark.{TaskContext, ShuffleDependency}
import org.apache.spark.shuffle.hash.HashShuffleReader
import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId}

private[spark] class SortShuffleManager extends ShuffleManager {
/**
@@ -57,4 +60,21 @@ private[spark] class SortShuffleManager extends ShuffleManager {

/** Shut down this ShuffleManager. */
override def stop(): Unit = {}

/** Get the location of a block in a map output file. Uses the index file we create for it. */
def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
// The block is actually going to be a range of a single map output file for this map, so
// find the segment's offset and length using the index file we wrote for that map output.
val realId = ShuffleBlockId(blockId.shuffleId, blockId.mapId, 0)
val indexFile = diskManager.getFile(realId.name + ".index")
val in = new DataInputStream(new FileInputStream(indexFile))
try {
in.skip(blockId.reduceId * 8)
val offset = in.readLong()
val nextOffset = in.readLong()
new FileSegment(diskManager.getFile(realId), offset, nextOffset - offset)
} finally {
in.close()
}
}
}
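
For reference, a minimal standalone sketch (file name and sizes below are made up, not part of this commit) of the index-file convention that getBlockLocation relies on: the writer stores numPartitions + 1 longs, and a reduce block's segment is recovered by skipping reduceId * 8 bytes and reading two consecutive offsets.

import java.io.{DataInputStream, DataOutputStream, File, FileInputStream, FileOutputStream}

object IndexFileExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical map output with three reduce partitions of 10, 0 and 5 bytes.
    val lengths = Array[Long](10, 0, 5)
    val offsets = lengths.scanLeft(0L)(_ + _)  // Array(0, 10, 10, 15)

    // Write the index file as numPartitions + 1 longs.
    val indexFile = File.createTempFile("shuffle_0_0_0", ".index")
    val out = new DataOutputStream(new FileOutputStream(indexFile))
    try {
      offsets.foreach(o => out.writeLong(o))
    } finally {
      out.close()
    }

    // Locate reduce partition 2 the same way getBlockLocation does: skip reduceId * 8
    // bytes, then read the partition's start offset and the next partition's offset.
    val reduceId = 2
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      in.skip(reduceId * 8)
      val offset = in.readLong()
      val nextOffset = in.readLong()
      println(s"reduce $reduceId -> offset=$offset, length=${nextOffset - offset}")  // offset=10, length=5
    } finally {
      in.close()
    }
  }
}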
org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -18,10 +18,14 @@
package org.apache.spark.shuffle.sort

import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.{SparkEnv, Logging, TaskContext}
import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.storage.ShuffleBlockId
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.executor.ShuffleWriteMetrics
import java.io.{BufferedOutputStream, FileOutputStream, DataOutputStream}

private[spark] class SortShuffleWriter[K, V, C](
handle: BaseShuffleHandle[K, V, C],
@@ -30,17 +34,24 @@ private[spark] class SortShuffleWriter[K, V, C](
extends ShuffleWriter[K, V] with Logging {

private val dep = handle.dependency
private val numOutputPartitions = dep.partitioner.numPartitions
private val numPartitions = dep.partitioner.numPartitions
private val metrics = context.taskMetrics

private val blockManager = SparkEnv.get.blockManager
private val shuffleBlockManager = blockManager.shuffleBlockManager
private val diskBlockManager = blockManager.diskBlockManager
private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))

private val conf = SparkEnv.get.conf
private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024

private var sorter: ExternalSorter[K, V, _] = null

private var stopping = false
private var mapStatus: MapStatus = null

/** Write a bunch of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
var sorter: ExternalSorter[K, V, _] = null

val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
if (dep.mapSideCombine) {
if (!dep.aggregator.isDefined) {
@@ -58,13 +69,81 @@ private[spark] class SortShuffleWriter[K, V, C](
}
}

// Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
// serve different ranges of this file using an index file that we create at the end.
val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
val shuffleFile = blockManager.diskBlockManager.getFile(blockId)

// Track location of each range in the output file
val offsets = new Array[Long](numPartitions + 1)
val lengths = new Array[Long](numPartitions)

// Statistics
var totalBytes = 0L
var totalTime = 0L

for ((id, elements) <- partitions) {
if (elements.hasNext) {
val writer = blockManager.getDiskWriter(blockId, shuffleFile, ser, fileBufferSize)
for (elem <- elements) {
writer.write(elem)
}
writer.commit()
writer.close()
val segment = writer.fileSegment()
offsets(id + 1) = segment.offset + segment.length
lengths(id) = segment.length
totalTime += writer.timeWriting()
totalBytes += segment.length
} else {
// Don't create a new writer to avoid writing any headers and things like that
offsets(id + 1) = offsets(id)
}
}

val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
shuffleMetrics.shuffleWriteTime = totalTime
context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)

// Write an index file with the offsets of each block, plus a final offset at the end for the
// end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
// out where each block begins and ends.

val diskBlockManager = blockManager.diskBlockManager
val indexFile = diskBlockManager.getFile(blockId.name + ".index")
val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
try {
var i = 0
while (i < numPartitions + 1) {
out.writeLong(offsets(i))
i += 1
}
} finally {
out.close()
}

???
mapStatus = new MapStatus(blockManager.blockManagerId,
lengths.map(MapOutputTracker.compressSize))

// TODO: keep track of our file in a way that can be cleaned up later
}

/** Close this writer, passing along whether the map completed */
override def stop(success: Boolean): Option[MapStatus] = ???
override def stop(success: Boolean): Option[MapStatus] = {
try {
if (stopping) {
return None
}
stopping = true
if (success) {
return Option(mapStatus)
} else {
// TODO: clean up our file
return None
}
} finally {
// TODO: sorter.stop()
}
}
}
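
To make the offsets/lengths bookkeeping in write() above concrete, here is a toy standalone run (the segment sizes are invented): each non-empty partition advances the next offset by its segment length, empty partitions simply carry the previous offset forward, and the result is always numPartitions + 1 monotone offsets ready to be written to the index file.

object OffsetBookkeepingExample {
  def main(args: Array[String]): Unit = {
    // Serialized sizes, in bytes, of four hypothetical reduce partitions.
    val segmentLengths = Seq(120L, 0L, 45L, 0L)
    val numPartitions = segmentLengths.length

    val offsets = new Array[Long](numPartitions + 1)
    val lengths = new Array[Long](numPartitions)

    for ((len, id) <- segmentLengths.zipWithIndex) {
      if (len > 0) {
        // A real segment was written: it starts where the previous one ended.
        offsets(id + 1) = offsets(id) + len
        lengths(id) = len
      } else {
        // Empty partition: no writer is opened, the offset just carries forward.
        offsets(id + 1) = offsets(id)
      }
    }

    println(offsets.mkString(", "))  // 0, 120, 120, 165, 165
    println(lengths.mkString(", "))  // 120, 0, 45, 0
  }
}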
org/apache/spark/storage/DiskBlockManager.scala
@@ -21,10 +21,11 @@ import java.io.File
import java.text.SimpleDateFormat
import java.util.{Date, Random, UUID}

import org.apache.spark.Logging
import org.apache.spark.{SparkEnv, Logging}
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
import org.apache.spark.util.Utils
import org.apache.spark.shuffle.sort.SortShuffleManager

/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -54,12 +55,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
addShutdownHook()

/**
* Returns the physical file segment in which the given BlockId is located.
* If the BlockId has been mapped to a specific FileSegment, that will be returned.
* Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
* Returns the physical file segment in which the given BlockId is located. If the BlockId has
* been mapped to a specific FileSegment by the shuffle layer, that will be returned.
* Otherwise, we assume the Block is mapped to the whole file identified by the BlockId.
*/
def getBlockLocation(blockId: BlockId): FileSegment = {
if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
if (blockId.isShuffle && SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]) {
val sortShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[SortShuffleManager]
sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this)
} else if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
} else {
val file = getFile(blockId.name)
org/apache/spark/util/collection/ExternalSorter.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util.collection

import java.io._
import java.util.Comparator

import scala.collection.mutable.ArrayBuffer

@@ -88,6 +89,13 @@ private[spark] class ExternalSorter[K, V, C](
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}

// For now, just compare them by partition; later we can compare by key as well
private val comparator = new Comparator[((Int, K), C)] {
override def compare(a: ((Int, K), C), b: ((Int, K), C)): Int = {
a._1._1 - b._1._1
}
}
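
A tiny standalone check (keys and values invented) of what this partition-only comparator does when used to sort ((partition, key), value) records: the records end up grouped by partition ID, with no ordering guarantee among keys inside a partition.

import java.util.Comparator

object PartitionComparatorExample {
  def main(args: Array[String]): Unit = {
    val comparator = new Comparator[((Int, String), Int)] {
      // Compare only by partition ID, as in the sketch above.
      override def compare(a: ((Int, String), Int), b: ((Int, String), Int)): Int =
        a._1._1 - b._1._1
    }

    val records = Array(((2, "b"), 1), ((0, "c"), 2), ((1, "a"), 3), ((0, "d"), 4))
    java.util.Arrays.sort(records, comparator)

    println(records.map(_._1._1).mkString(", "))  // 0, 0, 1, 2
  }
}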

// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
@@ -192,7 +200,7 @@ private[spark] class ExternalSorter[K, V, C](
}

try {
val it = collection.iterator // TODO: destructiveSortedIterator(comparator)
val it = collection.destructiveSortedIterator(comparator)
while (it.hasNext) {
val elem = it.next()
val partitionId = elem._1._1
@@ -232,11 +240,22 @@
* inside each partition. This can be used to either write out a new file or return data to
* the user.
*/
def merge(spills: Seq[SpilledFile]): Iterator[(Int, Iterator[Product2[K, C]])] = {
def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
// TODO: merge intermediate results if they are sorted by the comparator
val readers = spills.map(new SpillReader(_))
val inMemBuffered = inMemory.buffered
(0 until numPartitions).iterator.map { p =>
(p, readers.iterator.flatMap(_.readNextPartition()))
val inMemIterator = new Iterator[(K, C)] {
override def hasNext: Boolean = {
inMemBuffered.hasNext && inMemBuffered.head._1._1 == p
}
override def next(): (K, C) = {
val elem = inMemBuffered.next()
(elem._1._2, elem._2)
}
}
(p, readers.iterator.flatMap(_.readNextPartition()) ++ inMemIterator)
}
}
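
The in-memory side of merge() leans on a buffered iterator: because the data is already sorted by partition, each per-partition iterator just peeks at the head and stops as soon as a different partition ID appears. A small standalone sketch of that trick (data invented), consumed strictly in partition order as the surrounding code expects:

object BufferedPartitionIteratorExample {
  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    // ((partition, key), value) records, pre-sorted by partition ID.
    val inMemory = Iterator(((0, "a"), 1), ((0, "b"), 2), ((2, "c"), 3))
    val inMemBuffered = inMemory.buffered

    val byPartition = (0 until numPartitions).iterator.map { p =>
      val it = new Iterator[(String, Int)] {
        // Peek at the head; stop once the head belongs to a later partition.
        override def hasNext: Boolean =
          inMemBuffered.hasNext && inMemBuffered.head._1._1 == p
        override def next(): (String, Int) = {
          val elem = inMemBuffered.next()
          (elem._1._2, elem._2)
        }
      }
      (p, it)
    }

    // Partitions must be drained in order since they share one buffered iterator.
    byPartition.foreach { case (p, it) => println(s"partition $p: ${it.toList}") }
    // partition 0: List((a,1), (b,2))
    // partition 1: List()
    // partition 2: List((c,3))
  }
}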

@@ -301,6 +320,11 @@ private[spark] class ExternalSorter[K, V, C](
}
val k = deserStream.readObject().asInstanceOf[K]
val c = deserStream.readObject().asInstanceOf[C]
if (partitionId == numPartitions - 1 &&
indexInPartition == spill.elementsPerPartition(partitionId) - 1) {
finished = true
deserStream.close()
}
(k, c)
} catch {
case e: EOFException =>
@@ -319,6 +343,9 @@ private[spark] class ExternalSorter[K, V, C](
override def hasNext: Boolean = {
if (nextItem == null) {
nextItem = readNextItem()
if (nextItem == null) {
return false
}
}
// Check that we're still in the right partition; will be numPartitions at EOF
partitionId == myPartition
@@ -328,7 +355,9 @@ private[spark] class ExternalSorter[K, V, C](
if (!hasNext) {
throw new NoSuchElementException
}
nextItem
val item = nextItem
nextItem = null
item
}
}
}
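
The two fixes above restore the usual read-ahead iterator contract: hasNext pre-fetches and caches one item, and next() clears the cache so the following hasNext fetches a fresh one. A standalone sketch of the same pattern over a plain BufferedReader (not the spill format used here):

import java.io.{BufferedReader, StringReader}

object ReadAheadIteratorExample {
  def main(args: Array[String]): Unit = {
    val reader = new BufferedReader(new StringReader("one\ntwo\nthree"))

    val lines = new Iterator[String] {
      private var nextItem: String = null   // cached look-ahead; null means "not fetched"

      override def hasNext: Boolean = {
        if (nextItem == null) {
          nextItem = reader.readLine()      // returns null at end of stream
          if (nextItem == null) {
            return false
          }
        }
        true
      }

      override def next(): String = {
        if (!hasNext) {
          throw new NoSuchElementException
        }
        val item = nextItem
        nextItem = null                     // clear the cache so hasNext reads ahead again
        item
      }
    }

    println(lines.toList)  // List(one, two, three)
  }
}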
@@ -337,11 +366,16 @@
* Return an iterator over all the data written to this object, grouped by partition. For each
* partition we then have an iterator over its contents, and these are expected to be accessed
* in order (you can't "skip ahead" to one partition without reading the previous one).
* Guaranteed to return a key-value pair for each partition, in order of partition ID.
*
* For now, we just merge all the spilled files in one pass, but this can be modified to
* support hierarchical merging.
*/
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = merge(spills)
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
val usingMap = aggregator.isDefined
val collection: SizeTrackingCollection[((Int, K), C)] = if (usingMap) map else buffer
merge(spills, collection.destructiveSortedIterator(comparator))
}
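
A hedged sketch (the helper and sample data below are invented) of how a caller is expected to consume this partition-grouped iterator: partitions arrive in increasing partition ID, and each inner iterator must be drained before advancing, matching the "no skipping ahead" caveat in the doc comment.

object ConsumePartitionedIteratorExample {
  // Drain a partition-grouped iterator strictly in order, e.g. to write one output
  // region per partition. `writePartition` stands in for a real disk writer.
  def consume[K, C](partitions: Iterator[(Int, Iterator[Product2[K, C]])])
                   (writePartition: (Int, Seq[Product2[K, C]]) => Unit): Unit = {
    for ((p, elements) <- partitions) {
      // Fully consume the inner iterator before moving to the next partition.
      writePartition(p, elements.toVector)
    }
  }

  def main(args: Array[String]): Unit = {
    val data: Iterator[(Int, Iterator[Product2[String, Int]])] =
      Iterator((0, Iterator(("a", 1))), (1, Iterator.empty), (2, Iterator(("b", 2))))
    consume(data) { (p, elems) => println(s"partition $p -> $elems") }
  }
}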

/**
* Return an iterator over all the data written to this object.
org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -369,10 +369,3 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
}

}

/**
* A dummy class that always returns the same hash code, to easily test hash collisions
*/
case class FixedHashObject(v: Int, h: Int) extends Serializable {
override def hashCode(): Int = h
}
org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

import org.scalatest.FunSuite

import org.apache.spark.{SparkContext, SparkConf, LocalSparkContext}
import org.apache.spark.SparkContext._
import scala.collection.mutable.ArrayBuffer

class ExternalSorterSuite extends FunSuite with LocalSparkContext {

test("spilling in local cluster") {
val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
conf.set("spark.shuffle.memoryFraction", "0.001")
conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

// reduceByKey - should spill ~8 times
val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
val resultA = rddA.reduceByKey(math.max).collect()
assert(resultA.length == 50000)
resultA.foreach { case(k, v) =>
k match {
case 0 => assert(v == 1)
case 25000 => assert(v == 50001)
case 49999 => assert(v == 99999)
case _ =>
}
}

// groupByKey - should spill ~17 times
val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
val resultB = rddB.groupByKey().collect()
assert(resultB.length == 25000)
resultB.foreach { case(i, seq) =>
i match {
case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
case _ =>
}
}

// cogroup - should spill ~7 times
val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
val resultC = rddC1.cogroup(rddC2).collect()
assert(resultC.length == 10000)
resultC.foreach { case(i, (seq1, seq2)) =>
i match {
case 0 =>
assert(seq1.toSet == Set[Int](0))
assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
case 5000 =>
assert(seq1.toSet == Set[Int](5000))
assert(seq2.toSet == Set[Int]())
case 9999 =>
assert(seq1.toSet == Set[Int](9999))
assert(seq2.toSet == Set[Int]())
case _ =>
}
}
}
}