From b71d05e3bde1082fedd06bb1f05038f46f1b78ff Mon Sep 17 00:00:00 2001 From: mingji Date: Fri, 17 Jan 2025 00:08:55 +0800 Subject: [PATCH] [CELEBORN-1838] Interrupt spark task should not report fetch failure --- .../flink/client/FlinkShuffleClientImpl.java | 14 +- .../celeborn/CelebornShuffleReader.scala | 65 ++++++-- .../celeborn/CelebornShuffleReaderSuite.scala | 79 ++++++++++ .../celeborn/client/DummyShuffleClient.java | 4 + .../celeborn/client/ShuffleClientImpl.java | 28 ++-- .../celeborn/client/ShuffleClientSuiteJ.java | 146 +++++++++++++++++- .../apache/celeborn/common/CelebornConf.scala | 11 ++ ...te.scala => CelebornStageRerunSuite.scala} | 4 +- 8 files changed, 315 insertions(+), 36 deletions(-) create mode 100644 client-spark/spark-3-4/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala rename client/src/{test => main}/java/org/apache/celeborn/client/DummyShuffleClient.java (97%) rename tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/{CelebornShuffleLostSuite.scala => CelebornStageRerunSuite.scala} (96%) diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java index bdf64691ba8..7c909c7476b 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/client/FlinkShuffleClientImpl.java @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import scala.Tuple2; +import scala.Tuple3; import scala.reflect.ClassTag$; import com.google.common.annotations.VisibleForTesting; @@ -265,9 +265,9 @@ public ReduceFileGroups updateFileGroup( int shuffleId, int partitionId, boolean isSegmentGranularityVisible) throws CelebornIOException { ReduceFileGroups reduceFileGroups = - reduceFileGroupsMap.computeIfAbsent( - shuffleId, (id) -> Tuple2.apply(new ReduceFileGroups(), null)) - ._1; + reduceFileGroupsMap + .computeIfAbsent(shuffleId, (id) -> Tuple3.apply(new ReduceFileGroups(), null, null)) + ._1(); if (reduceFileGroups.partitionIds != null && reduceFileGroups.partitionIds.contains(partitionId)) { logger.debug( @@ -281,12 +281,12 @@ public ReduceFileGroups updateFileGroup( Utils.makeReducerKey(shuffleId, partitionId)); } else { // refresh file groups - Tuple2 fileGroups = + Tuple3 fileGroups = loadFileGroupInternal(shuffleId, isSegmentGranularityVisible); - ReduceFileGroups newGroups = fileGroups._1; + ReduceFileGroups newGroups = fileGroups._1(); if (newGroups == null) { throw new CelebornIOException( - loadFileGroupException(shuffleId, partitionId, fileGroups._2)); + loadFileGroupException(shuffleId, partitionId, fileGroups._2())); } else if (!newGroups.partitionIds.contains(partitionId)) { throw new CelebornIOException( String.format( diff --git a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index ad802f8cef9..05ae4c808ae 100644 --- a/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3-4/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -18,12 +18,14 @@ package org.apache.spark.shuffle.celeborn import java.io.IOException +import java.nio.file.Files import java.util -import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ +import com.google.common.annotations.VisibleForTesting import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext} import org.apache.spark.celeborn.ExceptionMakerHelper import org.apache.spark.internal.Logging @@ -33,14 +35,14 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient} import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback} import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException} import org.apache.celeborn.common.network.client.TransportClient import org.apache.celeborn.common.network.protocol.TransportMessage -import org.apache.celeborn.common.protocol.{MessageType, PartitionLocation, PbOpenStreamList, PbOpenStreamListResponse, PbStreamHandler} +import org.apache.celeborn.common.protocol._ import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils} @@ -56,18 +58,34 @@ class CelebornShuffleReader[K, C]( shuffleIdTracker: ExecutorShuffleIdTracker) extends ShuffleReader[K, C] with Logging { - private val dep = handle.dependency - private val shuffleClient = ShuffleClient.get( - handle.appUniqueId, - handle.lifecycleManagerHost, - handle.lifecycleManagerPort, - conf, - handle.userIdentifier, - handle.extension) + val mockReader = conf.testMockShuffleReader + + private val dep = + if (mockReader) { + null + } else { + handle.dependency + } + + @VisibleForTesting + val shuffleClient = + if (mockReader) { + new DummyShuffleClient(conf, Files.createTempFile("test", "mockfile").toFile) + } else { + ShuffleClient.get( + handle.appUniqueId, + handle.lifecycleManagerHost, + handle.lifecycleManagerPort, + conf, + handle.userIdentifier, + handle.extension) + } private val exceptionRef = new AtomicReference[IOException] - private val throwsFetchFailure = handle.throwsFetchFailure - private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context) + private val throwsFetchFailure = + if (mockReader) conf.clientStageRerunEnabled else handle.throwsFetchFailure + private val encodedAttemptId = + if (mockReader) 0 else SparkCommonUtils.getEncodedAttemptNumber(context) override def read(): Iterator[Product2[K, C]] = { @@ -111,7 +129,9 @@ class CelebornShuffleReader[K, C]( fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition) } catch { case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) => - handleFetchExceptions(handle.shuffleId, shuffleId, 0, ce) + // if a task is interrupted, should not report fetch failure + // if a task update file group timeout, should not report fetch failure + checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce) case e: Throwable => throw e } @@ -370,7 +390,22 @@ class CelebornShuffleReader[K, C]( } } - private def handleFetchExceptions( + @VisibleForTesting + def checkAndReportFetchFailureForUpdateFileGroupFailure( + celebornShuffleId: Int, + ce: Throwable): Unit = { + if (ce.getCause != null && + ce.getCause.isInstanceOf[InterruptedException] || ce.getCause.isInstanceOf[ + TimeoutException]) { + logWarning(s"fetch shuffle ${celebornShuffleId} timeout or interrupt", ce) + throw ce + } else { + handleFetchExceptions(handle.shuffleId, celebornShuffleId, 0, ce) + } + } + + @VisibleForTesting + def handleFetchExceptions( appShuffleId: Int, shuffleId: Int, partitionId: Int, diff --git a/client-spark/spark-3-4/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala b/client-spark/spark-3-4/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala new file mode 100644 index 00000000000..dada6e8df05 --- /dev/null +++ b/client-spark/spark-3-4/src/test/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReaderSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.celeborn + +import java.util.concurrent.TimeoutException + +import org.apache.spark.TaskContext +import org.apache.spark.shuffle.ShuffleReadMetricsReporter +import org.mockito.Mockito +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.DummyShuffleClient +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.exception.CelebornIOException + +class CelebornShuffleReaderSuite extends AnyFunSuite { + + /** + * Due to spark limitations, spark local mode can not test speculation tasks , + * test the method `checkAndReportFetchFailureForUpdateFileGroupFailure` + */ + test("CELEBORN-1838 test check report fetch failure exceptions ") { + val handler = Mockito.mock(classOf[CelebornShuffleHandle[Int, Int, Int]]) + val context = Mockito.mock(classOf[TaskContext]) + val metricReporter = Mockito.mock(classOf[ShuffleReadMetricsReporter]) + val conf = new CelebornConf() + conf.set("celeborn.test.client.mockShuffleReader", "true") + conf.set("celeborn.client.spark.stageRerun.enabled", "true") + + val shuffleReader = + new CelebornShuffleReader[Int, Int](handler, 0, 0, 0, 0, context, conf, metricReporter, null) + + val exception1: Throwable = new CelebornIOException("test1", new InterruptedException("test1")) + val exception2: Throwable = new CelebornIOException("test2", new TimeoutException("test2")) + val exception3: Throwable = new CelebornIOException("test3") + val exception4: Throwable = new CelebornIOException("test4") + + try { + shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception1) + } catch { + case _: Throwable => + } + try { + shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception2) + } catch { + case _: Throwable => + } + try { + shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception3) + } catch { + case _: Throwable => + } + assert( + shuffleReader.shuffleClient.asInstanceOf[DummyShuffleClient].fetchFailureCount.get() === 1) + try { + shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception4) + } catch { + case _: Throwable => + } + assert( + shuffleReader.shuffleClient.asInstanceOf[DummyShuffleClient].fetchFailureCount.get() === 2) + + } +} diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java similarity index 97% rename from client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java rename to client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 77a9c784c4a..2b933181faa 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -55,6 +56,8 @@ public class DummyShuffleClient extends ShuffleClient { private final Map> reducePartitionMap = new HashMap<>(); + public AtomicInteger fetchFailureCount = new AtomicInteger(); + public DummyShuffleClient(CelebornConf conf, File file) throws Exception { this.os = new BufferedOutputStream(new FileOutputStream(file)); this.conf = conf; @@ -181,6 +184,7 @@ public int getShuffleId( @Override public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) { + fetchFailureCount.incrementAndGet(); return true; } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 48412f81f12..8dbccb00b49 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -26,6 +26,7 @@ import java.util.concurrent.TimeUnit; import scala.Tuple2; +import scala.Tuple3; import scala.reflect.ClassTag$; import com.google.common.annotations.VisibleForTesting; @@ -170,7 +171,7 @@ public void update(ReduceFileGroups fileGroups) { } // key: shuffleId - protected final Map> reduceFileGroupsMap = + protected final Map> reduceFileGroupsMap = JavaUtils.newConcurrentHashMap(); public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier userIdentifier) { @@ -1742,11 +1743,12 @@ public boolean cleanupShuffle(int shuffleId) { return true; } - protected Tuple2 loadFileGroupInternal( + protected Tuple3 loadFileGroupInternal( int shuffleId, boolean isSegmentGranularityVisible) { { long getReducerFileGroupStartTime = System.nanoTime(); String exceptionMsg = null; + Exception exception = null; try { if (lifecycleManagerRef == null) { exceptionMsg = "Driver endpoint is null!"; @@ -1768,9 +1770,10 @@ protected Tuple2 loadFileGroupInternal( shuffleId, TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime), response.fileGroup().size()); - return Tuple2.apply( + return Tuple3.apply( new ReduceFileGroups( response.fileGroup(), response.attempts(), response.partitionIds()), + null, null); case SHUFFLE_NOT_REGISTERED: logger.warn( @@ -1779,9 +1782,10 @@ protected Tuple2 loadFileGroupInternal( response.status(), shuffleId); // return empty result - return Tuple2.apply( + return Tuple3.apply( new ReduceFileGroups( response.fileGroup(), response.attempts(), response.partitionIds()), + null, null); case STAGE_END_TIME_OUT: case SHUFFLE_DATA_LOST: @@ -1800,8 +1804,9 @@ protected Tuple2 loadFileGroupInternal( } logger.error("Exception raised while call GetReducerFileGroup for {}.", shuffleId, e); exceptionMsg = e.getMessage(); + exception = e; } - return Tuple2.apply(null, exceptionMsg); + return Tuple3.apply(null, exceptionMsg, exception); } } @@ -1814,21 +1819,22 @@ public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId) public ReduceFileGroups updateFileGroup( int shuffleId, int partitionId, boolean isSegmentGranularityVisible) throws CelebornIOException { - Tuple2 fileGroupTuple = + Tuple3 fileGroupTuple = reduceFileGroupsMap.compute( shuffleId, (id, existsTuple) -> { - if (existsTuple == null || existsTuple._1 == null) { + if (existsTuple == null || existsTuple._1() == null) { return loadFileGroupInternal(shuffleId, isSegmentGranularityVisible); } else { return existsTuple; } }); - if (fileGroupTuple._1 == null) { + if (fileGroupTuple._1() == null) { throw new CelebornIOException( - loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2))); + loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2())), + fileGroupTuple._3()); } else { - return fileGroupTuple._1; + return fileGroupTuple._1(); } } @@ -1899,7 +1905,7 @@ public CelebornInputStream readPartition( } @VisibleForTesting - public Map> getReduceFileGroupsMap() { + public Map> getReduceFileGroupsMap() { return reduceFileGroupsMap; } diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java index 5256ae0fb0c..4f9c1e2ab60 100644 --- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java @@ -26,13 +26,20 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.GenericFutureListener; import org.apache.commons.lang3.RandomStringUtils; +import org.junit.Assert; import org.junit.Test; import org.apache.celeborn.client.compress.Compressor; @@ -43,7 +50,8 @@ import org.apache.celeborn.common.network.client.TransportClientFactory; import org.apache.celeborn.common.protocol.CompressionCodec; import org.apache.celeborn.common.protocol.PartitionLocation; -import org.apache.celeborn.common.protocol.message.ControlMessages.*; +import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse$; +import org.apache.celeborn.common.protocol.message.ControlMessages.RegisterShuffleResponse$; import org.apache.celeborn.common.protocol.message.StatusCode; import org.apache.celeborn.common.rpc.RpcEndpointRef; @@ -397,4 +405,140 @@ public Void get(long timeout, TimeUnit unit) { shuffleClient.dataClientFactory = clientFactory; return conf; } + + @Test + public void testUpdateReducerFileGroupInterrupted() throws InterruptedException { + CelebornConf conf = new CelebornConf(); + conf.set("celeborn.client.spark.stageRerun.enabled", "true"); + Map> locations = new HashMap<>(); + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + Thread.sleep(60 * 1000); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + AtomicReference exceptionRef = new AtomicReference<>(); + Thread thread = + new Thread( + new Runnable() { + @Override + public void run() { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + exceptionRef.set(e); + } + } + }); + + thread.start(); + Thread.sleep(1000); + thread.interrupt(); + Thread.sleep(1000); + + Exception exception = exceptionRef.get(); + Assert.assertTrue(exception.getCause() instanceof InterruptedException); + } + + @Test + public void testUpdateReducerFileGroupTimeout() throws InterruptedException { + CelebornConf conf = new CelebornConf(); + conf.set("celeborn.client.rpc.getReducerFileGroup.askTimeout", "1ms"); + LifecycleManager lifecycleManager = new LifecycleManager("APP", conf); + Map> locations = new HashMap<>(); + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + Thread.sleep(60 * 1000); + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SUCCESS, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(lifecycleManager.self()); + + AtomicReference exceptionRef = new AtomicReference<>(); + Thread thread = + new Thread( + new Runnable() { + @Override + public void run() { + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + exceptionRef.set(e); + } + } + }); + + thread.start(); + Thread.sleep(5000); + + Exception exception = exceptionRef.get(); + Assert.assertTrue(exception.getCause() instanceof TimeoutException); + } + + @Test + public void testUpdateReducerFileGroupNonFetchFailureExceptions() { + CelebornConf conf = new CelebornConf(); + conf.set("celeborn.client.spark.stageRerun.enabled", "true"); + Map> locations = new HashMap<>(); + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SHUFFLE_NOT_REGISTERED, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + Assert.assertTrue(e.getCause() == null); + } + + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.STAGE_END_TIME_OUT, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + Assert.assertTrue(e.getCause() == null); + } + + when(endpointRef.askSync(any(), any(), any())) + .thenAnswer( + t -> { + return GetReducerFileGroupResponse$.MODULE$.apply( + StatusCode.SHUFFLE_DATA_LOST, locations, new int[0], Collections.emptySet()); + }); + + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + try { + shuffleClient.updateFileGroup(0, 0); + } catch (CelebornIOException e) { + Assert.assertTrue(e.getCause() == null); + } + } } diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 80d31d74750..e3e5a9c5a31 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1341,6 +1341,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def testMockCommitFilesFailure: Boolean = get(TEST_MOCK_COMMIT_FILES_FAILURE) def testMockShuffleLost: Boolean = get(TEST_CLIENT_MOCK_SHUFFLE_LOST) def testMockShuffleLostShuffle: Int = get(TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE) + def testMockShuffleReader: Boolean = get(TEST_CLIENT_MOCK_SHUFFLE_READER) def testPushPrimaryDataTimeout: Boolean = get(TEST_CLIENT_PUSH_PRIMARY_DATA_TIMEOUT) def testPushReplicaDataTimeout: Boolean = get(TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT) def testRetryRevive: Boolean = get(TEST_CLIENT_RETRY_REVIVE) @@ -4352,6 +4353,16 @@ object CelebornConf extends Logging { .intConf .createWithDefault(0) + val TEST_CLIENT_MOCK_SHUFFLE_READER: ConfigEntry[Boolean] = + buildConf("celeborn.test.client.mockShuffleReader") + .internal + .categories("test", "client") + .doc("Mock shuffle reader for shuffle") + .version("0.5.3") + .internal + .booleanConf + .createWithDefault(false) + val CLIENT_PUSH_REPLICATE_ENABLED: ConfigEntry[Boolean] = buildConf("celeborn.client.push.replicate.enabled") .withAlternative("celeborn.push.replicate.enabled") diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornStageRerunSuite.scala similarity index 96% rename from tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala rename to tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornStageRerunSuite.scala index a281196c77a..04345593c47 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornShuffleLostSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornStageRerunSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.funsuite.AnyFunSuite import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.common.protocol.ShuffleMode -class CelebornShuffleLostSuite extends AnyFunSuite +class CelebornStageRerunSuite extends AnyFunSuite with SparkTestBase with BeforeAndAfterEach { @@ -37,7 +37,7 @@ class CelebornShuffleLostSuite extends AnyFunSuite System.gc() } - test("celeborn shuffle data lost - hash") { + test("stage rerun for data lost - hash") { val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]") val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() val combineResult = combine(sparkSession)