Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-4084] Reuse sort key in Sorter #2937

Closed
wants to merge 13 commits into from
66 changes: 44 additions & 22 deletions core/src/main/java/org/apache/spark/util/collection/Sorter.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@
* See the method comment on sort() for more details.
*
* This has been kept in Java with the original style in order to match very closely with the
* Anroid source code, and thus be easy to verify correctness.
* Android source code, and thus be easy to verify correctness.
*
* The purpose of the port is to generalize the interface to the sort to accept input data formats
* besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
* uses this to sort an Array with alternating elements of the form [key, value, key, value].
* This generalization comes with minimal overhead -- see SortDataFormat for more information.
*
* We allow key reuse to prevent creating many key objects -- see SortDataFormat.
*
* @see org.apache.spark.util.collection.SortDataFormat
*/
class Sorter<K, Buffer> {

Expand Down Expand Up @@ -162,10 +166,13 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super
if (start == lo)
start++;

K key0 = s.newKey();
K key1 = s.newKey();

Buffer pivotStore = s.allocate(1);
for ( ; start < hi; start++) {
s.copyElement(a, start, pivotStore, 0);
K pivot = s.getKey(pivotStore, 0);
K pivot = s.getKey(pivotStore, 0, key0);

// Set left (and right) to the index where a[start] (pivot) belongs
int left = lo;
Expand All @@ -178,7 +185,7 @@ private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super
*/
while (left < right) {
int mid = (left + right) >>> 1;
if (c.compare(pivot, s.getKey(a, mid)) < 0)
if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)
right = mid;
else
left = mid + 1;
Expand Down Expand Up @@ -235,13 +242,16 @@ private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? supe
if (runHi == hi)
return 1;

K key1 = s.newKey();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

key0, key1 pls

K key2 = s.newKey();

// Find end of run, and reverse range if descending
if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending
while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0)
if (c.compare(s.getKey(a, runHi++, key1), s.getKey(a, lo, key2)) < 0) { // Descending
while (runHi < hi && c.compare(s.getKey(a, runHi, key1), s.getKey(a, runHi - 1, key2)) < 0)
runHi++;
reverseRange(a, lo, runHi);
} else { // Ascending
while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0)
while (runHi < hi && c.compare(s.getKey(a, runHi, key1), s.getKey(a, runHi - 1, key2)) >= 0)
runHi++;
}

Expand Down Expand Up @@ -468,11 +478,13 @@ private void mergeAt(int i) {
}
stackSize--;

K key0 = s.newKey();

/*
* Find where the first element of run2 goes in run1. Prior elements
* in run1 can be ignored (because they're already in place).
*/
int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c);
int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);
assert k >= 0;
base1 += k;
len1 -= k;
Expand All @@ -483,7 +495,7 @@ private void mergeAt(int i) {
* Find where the last element of run1 goes in run2. Subsequent elements
* in run2 can be ignored (because they're already in place).
*/
len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c);
len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);
assert len2 >= 0;
if (len2 == 0)
return;
Expand Down Expand Up @@ -517,10 +529,12 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
assert len > 0 && hint >= 0 && hint < len;
int lastOfs = 0;
int ofs = 1;
if (c.compare(key, s.getKey(a, base + hint)) > 0) {
K key0 = s.newKey();

if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) {
// Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
int maxOfs = len - hint;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) {
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
Expand All @@ -535,7 +549,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
} else { // key <= a[base + hint]
// Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
final int maxOfs = hint + 1;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) {
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
Expand All @@ -560,7 +574,7 @@ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<
while (lastOfs < ofs) {
int m = lastOfs + ((ofs - lastOfs) >>> 1);

if (c.compare(key, s.getKey(a, base + m)) > 0)
if (c.compare(key, s.getKey(a, base + m, key0)) > 0)
lastOfs = m + 1; // a[base + m] < key
else
ofs = m; // key <= a[base + m]
Expand All @@ -587,10 +601,12 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator

int ofs = 1;
int lastOfs = 0;
if (c.compare(key, s.getKey(a, base + hint)) < 0) {
K key1 = s.newKey();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

key0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an input parameter called key. So this one became key1.


if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) {
// Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
int maxOfs = hint + 1;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) {
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
Expand All @@ -606,7 +622,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator
} else { // a[b + hint] <= key
// Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
int maxOfs = len - hint;
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) {
while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {
lastOfs = ofs;
ofs = (ofs << 1) + 1;
if (ofs <= 0) // int overflow
Expand All @@ -630,7 +646,7 @@ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator
while (lastOfs < ofs) {
int m = lastOfs + ((ofs - lastOfs) >>> 1);

if (c.compare(key, s.getKey(a, base + m)) < 0)
if (c.compare(key, s.getKey(a, base + m, key1)) < 0)
ofs = m; // key < a[b + m]
else
lastOfs = m + 1; // a[b + m] <= key
Expand Down Expand Up @@ -679,6 +695,9 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
return;
}

K key0 = s.newKey();
K key1 = s.newKey();

Comparator<? super K> c = this.c; // Use local variable for performance
int minGallop = this.minGallop; // " " " " "
outer:
Expand All @@ -692,7 +711,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
*/
do {
assert len1 > 1 && len2 > 0;
if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) {
if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) {
s.copyElement(a, cursor2++, a, dest++);
count2++;
count1 = 0;
Expand All @@ -714,7 +733,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
*/
do {
assert len1 > 1 && len2 > 0;
count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c);
count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c);
if (count1 != 0) {
s.copyRange(tmp, cursor1, a, dest, count1);
dest += count1;
Expand All @@ -727,7 +746,7 @@ private void mergeLo(int base1, int len1, int base2, int len2) {
if (--len2 == 0)
break outer;

count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c);
count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);
if (count2 != 0) {
s.copyRange(a, cursor2, a, dest, count2);
dest += count2;
Expand Down Expand Up @@ -784,6 +803,9 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
int cursor2 = len2 - 1; // Indexes into tmp array
int dest = base2 + len2 - 1; // Indexes into a

K key0 = s.newKey();
K key1 = s.newKey();

// Move last element of first run and deal with degenerate cases
s.copyElement(a, cursor1--, a, dest--);
if (--len1 == 0) {
Expand Down Expand Up @@ -811,7 +833,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
*/
do {
assert len1 > 0 && len2 > 1;
if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) {
if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) {
s.copyElement(a, cursor1--, a, dest--);
count1++;
count2 = 0;
Expand All @@ -833,7 +855,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
*/
do {
assert len1 > 0 && len2 > 1;
count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c);
count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c);
if (count1 != 0) {
dest -= count1;
cursor1 -= count1;
Expand All @@ -846,7 +868,7 @@ private void mergeHi(int base1, int len1, int base2, int len2) {
if (--len2 == 1)
break outer;

count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c);
count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c);
if (count2 != 0) {
dest -= count2;
cursor2 -= count2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,20 @@ import scala.reflect.ClassTag
*/
// TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity.
private[spark] trait SortDataFormat[K, Buffer] extends Any {

/** Creates a new mutable key for reuse. */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a note that this must be implemented only if you also override getKey(data, pos, reuse).

protected def newKey(): K = null.asInstanceOf[K]

/** Return the sort key for the element at the given index. */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment that this is ONLY invoked by the default getKey(data: Buffer, pos: Int, reuse: K) method. That is, you should not call this from outside.

protected def getKey(data: Buffer, pos: Int): K

/**
* Returns the sort key for the element at the given index and reuse the input key if possible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a note that the default implementation simply ignores the reuse parameter and invokes the other method. Also give the precondition that the "reused" key will have initially been constructed via newKey().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

*/
protected def getKey(data: Buffer, pos: Int, reuse: K): K = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's somewhat surprising that a universal trait can have a default implementation, but maybe we can convert this to an abstract class to ensure it's still compiled into simple Java bytecode.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also funny that Java ignores the protectedness, maybe we can upgrade all actually-public methods to public (that's everything but getKey(data: Buffer, pos: Int), which is only used internally)

getKey(data, pos)
}

/** Swap two elements. */
protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ class SorterSuite extends FunSuite {
val rand = new XORShiftRandom(123)
val data0 = Array.tabulate[Int](10000) { i => rand.nextInt() }
val data1 = data0.clone()
val data2 = data0.clone()

Arrays.sort(data0)
new Sorter(new IntArraySortDataFormat).sort(data1, 0, data1.length, Ordering.Int)
new Sorter(new KeyReuseIntArraySortDataFormat)
.sort(data2, 0, data2.length, Ordering[IntWrapper])

data0.zip(data1).foreach { case (x, y) => assert(x === y) }
assert(data0.view === data1.view)
assert(data0.view === data2.view)
}

test("KVArraySorter") {
Expand Down Expand Up @@ -137,12 +141,7 @@ class SorterSuite extends FunSuite {
}
}


/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */
class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] {
override protected def getKey(data: Array[Int], pos: Int): Int = {
data(pos)
}
abstract class AbstractIntArraySortDataFormat[K] extends SortDataFormat[K, Array[Int]] {

override protected def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = {
val tmp = data(pos0)
Expand All @@ -165,3 +164,39 @@ class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] {
new Array[Int](length)
}
}

/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */
class IntArraySortDataFormat extends AbstractIntArraySortDataFormat[Int] {

override protected def getKey(data: Array[Int], pos: Int): Int = {
data(pos)
}
}

/** Wrapper of Int for key reuse. */
class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] {
override def compare(that: IntWrapper): Int = {
key.compareTo(that.key)
}
}

/** SortDataFormat for Array[Int] with reused keys. */
class KeyReuseIntArraySortDataFormat extends AbstractIntArraySortDataFormat[IntWrapper] {

override protected def newKey(): IntWrapper = {
new IntWrapper()
}

override protected def getKey(data: Array[Int], pos: Int, reuse: IntWrapper): IntWrapper = {
if (reuse == null) {
new IntWrapper(data(pos))
} else {
reuse.key = data(pos)
reuse
}
}

override protected def getKey(data: Array[Int], pos: Int): IntWrapper = {
getKey(data, pos, null)
}
}