Handle null keys in hash-based comparator, and add tests for collisions
mateiz committed Jul 30, 2014
1 parent ef4e397 commit ba7db7f
Showing 2 changed files with 137 additions and 4 deletions.
ExternalSorter.scala
@@ -40,7 +40,7 @@ import org.apache.spark.storage.BlockId
* @param aggregator optional Aggregator with combine functions to use for merging data
* @param partitioner optional partitioner; if given, sort by partition ID and then key
* @param ordering optional ordering to sort keys within each partition
- * @param serializer serializer to use
+ * @param serializer serializer to use when spilling to disk
*/
private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
@@ -95,7 +95,11 @@ private[spark] class ExternalSorter[K, V, C](
// non-equal keys also have this, so we need to do a later pass to find truly equal keys).
// Note that we ignore this if no aggregator is given.
private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
-    override def compare(a: K, b: K): Int = a.hashCode() - b.hashCode()
+    override def compare(a: K, b: K): Int = {
+      val h1 = if (a == null) 0 else a.hashCode()
+      val h2 = if (b == null) 0 else b.hashCode()
+      h1 - h2
+    }
})
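  // Caveat: subtracting hash codes can overflow Int when h1 and h2 have opposite
  // signs. An overflow-safe variant (a sketch, not part of this commit) would
  // replace the subtraction with an explicit three-way comparison:
  //   if (h1 < h2) -1 else if (h1 == h2) 0 else 1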

private val sortWithinPartitions = ordering.isDefined || aggregator.isDefined
@@ -215,7 +219,6 @@ private[spark] class ExternalSorter[K, V, C](
val batchSizes = new ArrayBuffer[Long]

// How many elements we have in each partition
-    // TODO: this could become a sparser data structure
val elementsPerPartition = new Array[Long](numPartitions)

// Flush the disk writer's contents to disk, and update relevant variables
ExternalSorterSuite.scala
@@ -17,11 +17,12 @@

package org.apache.spark.util.collection

import scala.collection.mutable.ArrayBuffer

import org.scalatest.FunSuite

import org.apache.spark._
import org.apache.spark.SparkContext._
import scala.Some

class ExternalSorterSuite extends FunSuite with LocalSparkContext {
test("spilling in local cluster") {
@@ -332,4 +333,133 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
}).toSeq
assert(results === expected)
}

test("spilling with hash collisions") {
val conf = new SparkConf(true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

def createCombiner(i: String) = ArrayBuffer[String](i)
def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) =
buffer1 ++= buffer2

val agg = new Aggregator[String, String, ArrayBuffer[String]](
createCombiner _, mergeValue _, mergeCombiners _)

val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
Some(agg), None, None, None)

val collisionPairs = Seq(
("Aa", "BB"), // 2112
("to", "v1"), // 3707
("variants", "gelato"), // -1249574770
("Teheran", "Siblings"), // 231609873
("misused", "horsemints"), // 1069518484
("isohel", "epistolaries"), // -1179291542
("righto", "buzzards"), // -931102253
("hierarch", "crinolines"), // -1732884796
("inwork", "hypercatalexes"), // -1183663690
("wainages", "presentencing"), // 240183619
("trichothecenes", "locular"), // 339006536
("pomatoes", "eructation") // 568647356
)

collisionPairs.foreach { case (w1, w2) =>
// String.hashCode is documented to use a specific algorithm, but check just in case
assert(w1.hashCode === w2.hashCode)
}
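    // String.hashCode folds characters as h = 31 * h + c, so two-character strings
    // collide exactly when 31 * c1 + c2 matches; a worked check for the first pair:
    //   "Aa": 31 * 'A' + 'a' = 31 * 65 + 97 = 2112
    //   "BB": 31 * 'B' + 'B' = 31 * 66 + 66 = 2112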

val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)

sorter.write(toInsert)

// A map of collision pairs in both directions
val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap

// Avoid map.size or map.iterator.length because this destructively sorts the underlying map
var count = 0

val it = sorter.iterator
while (it.hasNext) {
val kv = it.next()
val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1))
assert(kv._2.equals(expectedValue))
count += 1
}
assert(count === 100000 + collisionPairs.size * 2)
}

test("spilling with many hash collisions") {
val conf = new SparkConf(true)
conf.set("spark.shuffle.memoryFraction", "0.0001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)

// Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
// problems if the map fails to group together the objects with the same code (SPARK-2043).
val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1)
sorter.write(toInsert.iterator)

val it = sorter.iterator
var count = 0
while (it.hasNext) {
val kv = it.next()
assert(kv._2 === 10)
count += 1
}
assert(count === 10000)
}
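
  // FixedHashObject is a small helper defined elsewhere in the test sources; the
  // shape the test above assumes is roughly (a sketch, not the exact definition):
  //   case class FixedHashObject(v: Int, h: Int) extends Serializable {
  //     override def hashCode(): Int = h  // fixed hash code to force collisions
  //   }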

test("spilling with hash collisions using the Int.MaxValue key") {
val conf = new SparkConf(true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

def createCombiner(i: Int) = ArrayBuffer[Int](i)
def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i
def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2

val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)

sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))

val it = sorter.iterator
while (it.hasNext) {
// Should not throw NoSuchElementException
it.next()
}
}

test("spilling with null keys and values") {
val conf = new SparkConf(true)
conf.set("spark.shuffle.memoryFraction", "0.001")
sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

def createCombiner(i: String) = ArrayBuffer[String](i)
def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2

val agg = new Aggregator[String, String, ArrayBuffer[String]](
createCombiner, mergeValue, mergeCombiners)

val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
Some(agg), None, None, None)

sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
(null.asInstanceOf[String], "1"),
("1", null.asInstanceOf[String]),
(null.asInstanceOf[String], null.asInstanceOf[String])
))

val it = sorter.iterator
while (it.hasNext) {
// Should not throw NullPointerException
it.next()
}
}
}
