From c685c43aa161bf352b20c72b30ce107af547fcd9 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 12 Jan 2024 13:49:24 -0800
Subject: [PATCH] [SPARK-46700][CORE] Count the last spilling for the shuffle
 disk spilling bytes metric

### What changes were proposed in this pull request?

This PR fixes a long-standing bug in ShuffleExternalSorter around the
"spilled disk bytes" metric. When we close the sorter, we spill the remaining
data in the buffer with the flag `isLastFile = true`, which means the spilling
does not increase the "spilled disk bytes" metric. This makes sense if the
sorter has never spilled before: the final spill file then becomes the final
shuffle output file, and the "spilled disk bytes" metric should stay 0.
However, if spilling did happen earlier, today we simply fail to count the
final spill file in the "spilled disk bytes" metric.

This PR fixes the issue by setting that flag when closing the sorter only if
this is the first spilling.

### Why are the changes needed?

Make the metric accurate.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44709 from cloud-fan/shuffle.

Authored-by: Wenchen Fan
Signed-off-by: Dongjoon Hyun
(cherry picked from commit 4ea374257c1fdb276abcd6b953ba042593e4d5a3)
Signed-off-by: Dongjoon Hyun
---
 .../shuffle/sort/ShuffleExternalSorter.java   | 34 +++++++++++--------
 .../shuffle/sort/UnsafeShuffleWriter.java     |  6 ----
 .../sort/UnsafeShuffleWriterSuite.java        | 20 +++++++----
 3 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index a82f691d085d4..b097089282ce3 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -150,11 +150,21 @@ public long[] getChecksums() {
    * Sorts the in-memory records and writes the sorted records to an on-disk file.
    * This method does not free the sort data structures.
    *
-   * @param isLastFile if true, this indicates that we're writing the final output file and that the
-   *                   bytes written should be counted towards shuffle spill metrics rather than
-   *                   shuffle write metrics.
+   * @param isFinalFile if true, this indicates that we're writing the final output file and that
+   *                    the bytes written should be counted towards shuffle write metrics rather
+   *                    than shuffle spill metrics.
    */
-  private void writeSortedFile(boolean isLastFile) {
+  private void writeSortedFile(boolean isFinalFile) {
+    // Only emit the log if this is an actual spilling.
+    if (!isFinalFile) {
+      logger.info(
+        "Task {} on Thread {} spilling sort data of {} to disk ({} {} so far)",
+        taskContext.taskAttemptId(),
+        Thread.currentThread().getId(),
+        Utils.bytesToString(getMemoryUsage()),
+        spills.size(),
+        spills.size() != 1 ? " times" : " time");
+    }
 
     // This call performs the actual sort.
     final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
@@ -167,13 +177,14 @@ private void writeSortedFile(boolean isLastFile) {
 
     final ShuffleWriteMetricsReporter writeMetricsToUse;
 
-    if (isLastFile) {
+    if (isFinalFile) {
       // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
       writeMetricsToUse = writeMetrics;
     } else {
       // We're spilling, so bytes written should be counted towards spill rather than write.
       // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
       // them towards shuffle bytes written.
+      // The actual shuffle bytes written will be counted when we merge the spill files.
       writeMetricsToUse = new ShuffleWriteMetrics();
     }
 
@@ -246,7 +257,7 @@ private void writeSortedFile(boolean isLastFile) {
       spills.add(spillInfo);
     }
 
-    if (!isLastFile) {  // i.e. this is a spill file
+    if (!isFinalFile) {  // i.e. this is a spill file
       // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
       // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
       // relies on its `recordWritten()` method being called in order to trigger periodic updates to
@@ -281,12 +292,6 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
       return 0L;
     }
 
-    logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
-      Thread.currentThread().getId(),
-      Utils.bytesToString(getMemoryUsage()),
-      spills.size(),
-      spills.size() > 1 ? " times" : " time");
-
     writeSortedFile(false);
     final long spillSize = freeMemory();
     inMemSorter.reset();
@@ -440,8 +445,9 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
    */
   public SpillInfo[] closeAndGetSpills() throws IOException {
     if (inMemSorter != null) {
-      // Do not count the final file towards the spill count.
-      writeSortedFile(true);
+      // Here we are spilling the remaining data in the buffer. If there was no spill before,
+      // this final spill file will be the final shuffle output file.
+      writeSortedFile(/* isFinalFile = */ spills.isEmpty());
       freeMemory();
       inMemSorter.free();
       inMemSorter = null;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 9c54184105951..d5b4eb138b1a6 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -327,12 +327,6 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep
         logger.debug("Using slow merge");
         mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
       }
-      // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
-      // in-memory records, we write out the in-memory records to a file but do not count that
-      // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
-      // to be counted as shuffle write, but this will lead to double-counting of the final
-      // SpillInfo's bytes.
-      writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
       partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
     } catch (Exception e) {
       try {
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index d3aa93549a83a..1fa17b908699f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -69,6 +69,7 @@ public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper {
   File tempDir;
   long[] partitionSizesInMergedFile;
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
+  long totalSpilledDiskBytes = 0;
   SparkConf conf;
   final Serializer serializer =
     new KryoSerializer(new SparkConf().set("spark.kryo.unsafe", "false"));
@@ -96,6 +97,7 @@ public void setUp() throws Exception {
     mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
     partitionSizesInMergedFile = null;
     spillFilesCreated.clear();
+    totalSpilledDiskBytes = 0;
     conf = new SparkConf()
       .set(package$.MODULE$.BUFFER_PAGESIZE().key(), "1m")
       .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), false)
@@ -160,7 +162,11 @@ public void setUp() throws Exception {
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
       TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
-      File file = File.createTempFile("spillFile", ".spill", tempDir);
+      File file = spy(File.createTempFile("spillFile", ".spill", tempDir));
+      when(file.delete()).thenAnswer(inv -> {
+        totalSpilledDiskBytes += file.length();
+        return inv.callRealMethod();
+      });
       spillFilesCreated.add(file);
       return Tuple2$.MODULE$.apply(blockId, file);
     });
@@ -284,6 +290,9 @@ public void writeWithoutSpilling() throws Exception {
     final Option<MapStatus> mapStatus = writer.stop(true);
     assertTrue(mapStatus.isDefined());
     assertTrue(mergedOutputFile.exists());
+    // Even if there is no spill, the sorter still writes its data to a spill file at the end,
+    // which will become the final shuffle file.
+    assertEquals(1, spillFilesCreated.size());
 
     long sumOfPartitionSizes = 0;
     for (long size: partitionSizesInMergedFile) {
@@ -425,9 +434,8 @@ private void testMergingSpills(
     assertSpillFilesWereCleanedUp();
     ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
-    assertTrue(taskMetrics.diskBytesSpilled() > 0L);
-    assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length());
     assertTrue(taskMetrics.memoryBytesSpilled() > 0L);
+    assertEquals(totalSpilledDiskBytes, taskMetrics.diskBytesSpilled());
     assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
   }
 
@@ -517,9 +525,8 @@ public void writeEnoughDataToTriggerSpill() throws Exception {
     assertSpillFilesWereCleanedUp();
     ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
-    assertTrue(taskMetrics.diskBytesSpilled() > 0L);
-    assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length());
     assertTrue(taskMetrics.memoryBytesSpilled() > 0L);
+    assertEquals(totalSpilledDiskBytes, taskMetrics.diskBytesSpilled());
     assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
   }
 
@@ -550,9 +557,8 @@ private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exc
     assertSpillFilesWereCleanedUp();
     ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics();
     assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten());
-    assertTrue(taskMetrics.diskBytesSpilled() > 0L);
-    assertTrue(taskMetrics.diskBytesSpilled() < mergedOutputFile.length());
     assertTrue(taskMetrics.memoryBytesSpilled() > 0L);
+    assertEquals(totalSpilledDiskBytes, taskMetrics.diskBytesSpilled());
     assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
   }
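
Reviewer note: for readers who want the accounting rule outside the diff
context, here is a minimal standalone sketch of the logic this patch
establishes. The class and counters below (MetricsRoutingSketch,
shuffleBytesWritten, diskBytesSpilled, spillFileSizes) are simplified
stand-ins, not Spark's actual classes; the real sorter routes bytes through
ShuffleWriteMetricsReporter objects and SpillInfo records rather than plain
longs.

import java.util.ArrayList;
import java.util.List;

class MetricsRoutingSketch {
  long shuffleBytesWritten = 0;  // the "shuffle bytes written" metric
  long diskBytesSpilled = 0;     // the "spilled disk bytes" metric
  final List<Long> spillFileSizes = new ArrayList<>();

  // Mirrors writeSortedFile(isFinalFile): the final file is shuffle output,
  // every other file is a spill.
  void writeSortedFile(boolean isFinalFile, long bytes) {
    if (isFinalFile) {
      shuffleBytesWritten += bytes;
    } else {
      diskBytesSpilled += bytes;
      spillFileSizes.add(bytes);
    }
  }

  // Mirrors closeAndGetSpills(). Before SPARK-46700 this always passed true,
  // so when earlier spills existed the closing write was never counted as
  // spilled bytes. After the fix it is the "final file" only when it is the
  // sorter's first and only file.
  void close(long remainingBytes) {
    writeSortedFile(/* isFinalFile = */ spillFileSizes.isEmpty(), remainingBytes);
  }
}

Seen this way, removing decBytesWritten(...) in UnsafeShuffleWriter.java
follows directly: the old code counted the closing write as shuffle output,
the merge then counted those bytes again, and the subtraction undid the
double count. Once the closing write is treated as a regular spill, that
correction would itself undercount the write metric.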
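The suite's spy-on-delete trick deserves a note as well: it captures each
spill file's size at the moment the file is deleted, the one point where the
length is guaranteed final. Below is a self-contained sketch of the pattern,
assuming only Mockito on the classpath; the helper class and method names are
hypothetical, not part of the suite. It uses doAnswer(...).when(...) rather
than the suite's when(...).thenAnswer(...) because, on a spy, the latter
invokes the real delete() once during stubbing; that is harmless in the suite
(the just-created temp file is recreated when the writer first opens it) but
can surprise elsewhere.

import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;

import java.io.File;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical helper: wraps temp files so every deletion records the
// file's on-disk size before the file disappears.
class SpillSizeTracker {
  final AtomicLong totalSpilledDiskBytes = new AtomicLong();

  File createTracked(String prefix, String suffix, File dir) throws IOException {
    File file = spy(File.createTempFile(prefix, suffix, dir));
    // Capture the length at deletion time: spill files reach their final
    // size only after writing finishes, and they are deleted after the merge.
    doAnswer(inv -> {
      totalSpilledDiskBytes.addAndGet(file.length());
      return inv.callRealMethod();
    }).when(file).delete();
    return file;
  }
}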