Skip to content

Commit

Permalink
Import my original tests and get them to pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 6, 2015
1 parent d5d3106 commit 2bd8c9a
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ public UnsafeSorterSpillWriter(
}

public void write(
Object baseObject,
long baseOffset,
int recordLength,
long keyPrefix) throws IOException {
Object baseObject,
long baseOffset,
int recordLength,
long keyPrefix) throws IOException {
dos.writeInt(recordLength);
dos.writeLong(keyPrefix);
PlatformDependent.copyMemory(
Expand All @@ -72,7 +72,6 @@ public void write(
PlatformDependent.BYTE_ARRAY_OFFSET,
recordLength);
writer.write(arr, 0, recordLength);
// TODO: add a test that detects whether we leave this call out:
writer.recordWritten();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,14 @@ private[spark] class DiskBlockObjectWriter(
recordWritten()
}

override def write(b: Int): Unit = throw new UnsupportedOperationException()
/**
 * Writes a single byte to `bs`, calling `open()` first if this writer has not been
 * initialized yet.
 */
override def write(b: Int): Unit = {
// TODO: re-enable the `throw new UnsupportedOperationException()` here
if (!initialized) {
open()
}

bs.write(b)
}

override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
if (!initialized) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection.unsafe.sort;

import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.UUID;

import scala.Tuple2;
import scala.Tuple2$;
import scala.runtime.AbstractFunction1;

import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.junit.Assert.*;
import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.AdditionalAnswers.returnsSecondArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;

/**
 * Tests for {@link UnsafeExternalSorter}, run against mocked shuffle-memory and block-manager
 * collaborators. The mocked block manager hands out real {@link DiskBlockObjectWriter}
 * instances that write to temporary spill files under {@link #tempDir}.
 */
public class UnsafeExternalSorterSuite {

  final TaskMemoryManager memoryManager =
    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
  // Compute key prefixes based on the records' partition ids
  // NOTE(review): nothing in this suite reads this partitioner yet; insertNumber() passes the
  // inserted value itself as the key prefix.
  final HashPartitioner hashPartitioner = new HashPartitioner(4);
  // Use integer comparison for comparing prefixes (which are partition ids, in this case)
  final PrefixComparator prefixComparator = new PrefixComparator() {
    @Override
    public int compare(long prefix1, long prefix2) {
      // Integer.compare cannot overflow, unlike subtracting the two values.
      return Integer.compare((int) prefix1, (int) prefix2);
    }
  };
  // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
  // use a dummy comparator
  final RecordComparator recordComparator = new RecordComparator() {
    @Override
    public int compare(
        Object leftBaseObject,
        long leftBaseOffset,
        Object rightBaseObject,
        long rightBaseOffset) {
      return 0;
    }
  };

  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;

  /** Directory in which the mocked disk block manager creates spill files. */
  File tempDir;

  /**
   * Stand-in for the compression function handed to {@link DiskBlockObjectWriter}: returns the
   * stream it is given, unchanged.
   */
  private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
    @Override
    public OutputStream apply(OutputStream stream) {
      return stream;
    }
  }

  /**
   * Initializes the annotated mocks and stubs the calls the sorter needs in order to spill:
   * memory acquisition, temp-block creation, disk-writer creation and compression wrapping.
   */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);
    // Create a dedicated directory for this test's spill files. The previous code used
    // Utils.createTempDir$default$1(), which is only the default *root* argument (the system
    // temp dir), so spill files ended up directly in the shared temp directory.
    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
    // Replaces the annotation-created mock with a plain Mockito mock.
    taskContext = mock(TaskContext.class);
    when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
    // Grant every memory request in full.
    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
    when(diskBlockManager.createTempLocalBlock()).thenAnswer(
      new Answer<Tuple2<TempLocalBlockId, File>>() {
        @Override
        public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock)
            throws Throwable {
          TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
          File file = File.createTempFile("spillFile", ".spill", tempDir);
          return Tuple2$.MODULE$.apply(blockId, file);
        }
      });
    when(blockManager.getDiskWriter(
      any(BlockId.class),
      any(File.class),
      any(SerializerInstance.class),
      anyInt(),
      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
        @Override
        public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
          Object[] args = invocationOnMock.getArguments();

          // Build a real writer from the arguments the sorter passed to getDiskWriter().
          return new DiskBlockObjectWriter(
            (BlockId) args[0],
            (File) args[1],
            (SerializerInstance) args[2],
            (Integer) args[3],
            new CompressStream(),
            false,
            (ShuffleWriteMetrics) args[4]
          );
        }
      });
    // Hand back the input stream as-is when asked to wrap it for compression.
    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
      .then(returnsSecondArg());
  }

  /**
   * Inserts a single 4-byte record holding {@code value}, using {@code value} as its key prefix.
   */
  private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
    final int[] arr = new int[] { value };
    sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
  }

  /**
   * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
   */
  @Test
  public void testSortingOnlyByPartitionId() throws Exception {

    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
      memoryManager,
      shuffleMemoryManager,
      blockManager,
      taskContext,
      recordComparator,
      prefixComparator,
      1024,
      new SparkConf());

    // Three records are inserted before the spill and two after it, so the sorted iterator
    // has to combine both sets of records.
    insertNumber(sorter, 5);
    insertNumber(sorter, 1);
    insertNumber(sorter, 3);
    sorter.spill();
    insertNumber(sorter, 4);
    insertNumber(sorter, 2);

    UnsafeSorterIterator iter = sorter.getSortedIterator();

    final long[] expectedPrefixes = {1, 2, 3, 4, 5};
    for (long expected : expectedPrefixes) {
      iter.loadNext();
      assertEquals(expected, iter.getKeyPrefix());
    }
    assertFalse(iter.hasNext());
    // TODO: check that the values are also read back properly.

    // TODO: test for cleanup:
    // assert(tempDir.isEmpty)
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection.unsafe.sort;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import org.junit.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;

import org.apache.spark.HashPartitioner;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;

/**
 * Tests for {@link UnsafeInMemorySorter}. Records are laid out in a data page as an 8-byte
 * length followed by that many bytes of UTF-8 string data.
 */
public class UnsafeInMemorySorterSuite {

  /**
   * Decodes the record starting at {@code baseOffset}: reads the 8-byte length stored there,
   * then copies that many bytes from just after it and decodes them as UTF-8.
   */
  private static String getStringFromDataPage(Object baseObject, long baseOffset) {
    final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
    final byte[] strBytes = new byte[strLength];
    PlatformDependent.copyMemory(
      baseObject,
      baseOffset + 8,
      strBytes,
      PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
    // Decode with the same charset the records were encoded with, rather than the platform
    // default charset.
    return new String(strBytes, StandardCharsets.UTF_8);
  }

  /** A sorter with no inserted records must yield an empty sorted iterator. */
  @Test
  public void testSortingEmptyInput() {
    final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
      mock(RecordComparator.class),
      mock(PrefixComparator.class),
      100);
    final UnsafeSorterIterator iter = sorter.getSortedIterator();
    // Use a JUnit assertion: the Java `assert` keyword is a no-op unless the JVM runs with -ea.
    assertFalse(iter.hasNext());
  }

  /**
   * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
   */
  @Test
  public void testSortingOnlyByPartitionId() throws Exception {
    final String[] dataToSort = new String[] {
      "Boba",
      "Pearls",
      "Tapioca",
      "Taho",
      "Condensed Milk",
      "Jasmine",
      "Milk Tea",
      "Lychee",
      "Mango"
    };
    final TaskMemoryManager memoryManager =
      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
    final Object baseObject = dataPage.getBaseObject();
    // Write the records into the data page:
    long position = dataPage.getBaseOffset();
    for (String str : dataToSort) {
      final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8);
      PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length);
      position += 8;
      PlatformDependent.copyMemory(
        strBytes,
        PlatformDependent.BYTE_ARRAY_OFFSET,
        baseObject,
        position,
        strBytes.length);
      position += strBytes.length;
    }
    // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
    // use a dummy comparator
    final RecordComparator recordComparator = new RecordComparator() {
      @Override
      public int compare(
          Object leftBaseObject,
          long leftBaseOffset,
          Object rightBaseObject,
          long rightBaseOffset) {
        return 0;
      }
    };
    // Compute key prefixes based on the records' partition ids
    final HashPartitioner hashPartitioner = new HashPartitioner(4);
    // Use integer comparison for comparing prefixes (which are partition ids, in this case)
    final PrefixComparator prefixComparator = new PrefixComparator() {
      @Override
      public int compare(long prefix1, long prefix2) {
        // Integer.compare cannot overflow, unlike subtracting the two values.
        return Integer.compare((int) prefix1, (int) prefix2);
      }
    };
    UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
      prefixComparator, dataToSort.length);
    // Given a page of records, insert those records into the sorter one-by-one:
    position = dataPage.getBaseOffset();
    for (int i = 0; i < dataToSort.length; i++) {
      // position now points to the start of a record (which holds its length).
      final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position);
      final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
      final String str = getStringFromDataPage(baseObject, position);
      final int partitionId = hashPartitioner.getPartition(str);
      sorter.insertRecord(address, partitionId);
      position += 8 + recordLength;
    }
    final UnsafeSorterIterator iter = sorter.getSortedIterator();
    int iterLength = 0;
    long prevPrefix = -1;
    while (iter.hasNext()) {
      iter.loadNext();
      final String str = getStringFromDataPage(iter.getBaseObject(), iter.getBaseOffset());
      final long keyPrefix = iter.getKeyPrefix();
      // Every record read back must be one of the inserted strings (a linear membership scan,
      // so the input array does not need to be sorted first)...
      assertTrue(Arrays.asList(dataToSort).contains(str));
      // ...and prefixes must come back in non-decreasing order.
      assertThat(keyPrefix, greaterThanOrEqualTo(prevPrefix));
      prevPrefix = keyPrefix;
      iterLength++;
    }
    assertEquals(dataToSort.length, iterLength);
  }
}

0 comments on commit 2bd8c9a

Please sign in to comment.