diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index bb77b5bf6de2a..40c2cc806e87a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -22,12 +22,10 @@ public final class RecordBinaryComparator extends RecordComparator { - // TODO(jiangxb) Add test suite for this. @Override public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { int i = 0; - int res = 0; // If the arrays have different length, the longer one is larger. if (leftLen != rightLen) { @@ -40,27 +38,33 @@ public int compare( // check if stars align and we can get both offsets to be aligned if ((leftOff % 8) == (rightOff % 8)) { while ((leftOff + i) % 8 != 0 && i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - res = (int) ((Platform.getLong(leftObj, leftOff + i) - - Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); - if (res != 0) return res; + final long v1 = Platform.getLong(leftObj, leftOff + i); + final long v2 = Platform.getLong(rightObj, rightOff + i); + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 8; } } // this will finish off the unaligned comparisons, or do the entire aligned comparison // whichever is needed. while (i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java index a19ddbdbadba2..97f3dc588ecc5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -253,4 +253,70 @@ public void testBinaryComparatorForNullColumns() throws Exception { assert(compare(0, 0) == 0); assert(compare(0, 1) > 0); } + + @Test + public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 11L + Integer.MAX_VALUE); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, Long.MIN_VALUE); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, 0); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } }