diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
index ba1f89d099838..4fda87ab57c49 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -17,13 +17,14 @@
package org.apache.spark.shuffle.unsafe;
-import org.junit.Assert;
import org.junit.Test;
+import static org.junit.Assert.*;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
public class PackedRecordPointerSuite {
@@ -36,8 +37,8 @@ public void heap() {
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42);
PackedRecordPointer packedPointer = new PackedRecordPointer();
packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
- Assert.assertEquals(360, packedPointer.getPartitionId());
- Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer());
+ assertEquals(360, packedPointer.getPartitionId());
+ assertEquals(addressInPage1, packedPointer.getRecordPointer());
memoryManager.cleanUpAllAllocatedMemory();
}
@@ -50,8 +51,43 @@ public void offHeap() {
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, 42);
PackedRecordPointer packedPointer = new PackedRecordPointer();
packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
- Assert.assertEquals(360, packedPointer.getPartitionId());
- Assert.assertEquals(addressInPage1, packedPointer.getRecordPointer());
+ assertEquals(360, packedPointer.getPartitionId());
+ assertEquals(addressInPage1, packedPointer.getRecordPointer());
memoryManager.cleanUpAllAllocatedMemory();
}
+
+ @Test
+ public void maximumPartitionIdCanBeEncoded() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+ assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+ }
+
+ @Test
+ public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ try {
+ // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
+ packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+ assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId());
+ } catch (AssertionError e ) {
+ // pass
+ }
+ }
+
+ @Test
+ public void maximumOffsetInPageCanBeEncoded() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+ packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+ assertEquals(address, packedPointer.getRecordPointer());
+ }
+
+ @Test
+ public void offsetsPastMaxOffsetInPageWillOverflow() {
+ PackedRecordPointer packedPointer = new PackedRecordPointer();
+ long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+ packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+ assertEquals(0, packedPointer.getRecordPointer());
+ }
}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 8451e8d9a9785..61511de6a5219 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -35,10 +35,10 @@
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
-import static org.junit.Assert.*;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 5b0733206b2bc..9e151fc7a9141 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -42,6 +42,10 @@
com.google.code.findbugs
jsr305
+
+ com.google.guava
+ guava
+
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 2aacf637eb6a4..cfd54035bee99 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -19,6 +19,7 @@
import java.util.*;
+import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -169,8 +170,13 @@ public void free(MemoryBlock memory) {
* This address will remain valid as long as the corresponding page has not been freed.
*/
public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
- assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
- return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+ return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+ }
+
+ @VisibleForTesting
+ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+ assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+ return (((long) pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
}
/**