[CELEBORN-1838] Interrupt spark task should not report fetch failure
FMX committed Jan 17, 2025 · 1 parent ad93381 · commit b71d05e
Showing 8 changed files with 315 additions and 36 deletions.
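Summary of the change: previously, any CelebornIOException raised while updating the reduce file group was routed to handleFetchExceptions, which can report a fetch failure to the driver and trigger a stage rerun. This commit threads the original exception through the file-group cache so the Spark reader can tell an interrupted or timed-out task apart from a genuine failure and simply rethrow. A minimal sketch of the new rule, with simplified wiring (reportFetchFailure is a hypothetical stand-in for handleFetchExceptions, not the committed code):

import java.util.concurrent.TimeoutException

def onUpdateFileGroupFailure(ce: Throwable): Unit = {
  val cause = ce.getCause
  if (cause != null &&
    (cause.isInstanceOf[InterruptedException] || cause.isInstanceOf[TimeoutException])) {
    // Task killed (e.g. during speculation) or RPC timed out: rethrow so the task
    // fails without FetchFailed semantics and no stage rerun is triggered.
    throw ce
  } else {
    reportFetchFailure() // hypothetical stand-in for handleFetchExceptions(...)
  }
}

def reportFetchFailure(): Unit = () // placeholder for illustration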
@@ -24,7 +24,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import scala.Tuple2;
import scala.Tuple3;
import scala.reflect.ClassTag$;

import com.google.common.annotations.VisibleForTesting;
@@ -265,9 +265,9 @@ public ReduceFileGroups updateFileGroup(
int shuffleId, int partitionId, boolean isSegmentGranularityVisible)
throws CelebornIOException {
ReduceFileGroups reduceFileGroups =
reduceFileGroupsMap.computeIfAbsent(
shuffleId, (id) -> Tuple2.apply(new ReduceFileGroups(), null))
._1;
reduceFileGroupsMap
.computeIfAbsent(shuffleId, (id) -> Tuple3.apply(new ReduceFileGroups(), null, null))
._1();
if (reduceFileGroups.partitionIds != null
&& reduceFileGroups.partitionIds.contains(partitionId)) {
logger.debug(
@@ -281,12 +281,12 @@ public ReduceFileGroups updateFileGroup(
Utils.makeReducerKey(shuffleId, partitionId));
} else {
// refresh file groups
Tuple2<ReduceFileGroups, String> fileGroups =
Tuple3<ReduceFileGroups, String, Exception> fileGroups =
loadFileGroupInternal(shuffleId, isSegmentGranularityVisible);
ReduceFileGroups newGroups = fileGroups._1;
ReduceFileGroups newGroups = fileGroups._1();
if (newGroups == null) {
throw new CelebornIOException(
loadFileGroupException(shuffleId, partitionId, fileGroups._2));
loadFileGroupException(shuffleId, partitionId, fileGroups._2()));
} else if (!newGroups.partitionIds.contains(partitionId)) {
throw new CelebornIOException(
String.format(
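This first (Java) client diff only widens the cached tuple and still consumes just _1() and _2(); the new third slot is populated by loadFileGroupInternal and consumed by ShuffleClientImpl.updateFileGroup further down. A sketch of the cache entry's three states, with illustrative names (not the committed code):

import org.apache.celeborn.common.exception.CelebornIOException

// Tuple3[ReduceFileGroups, String, Exception] states after this change:
//   successful load:        (fileGroups, null, null)
//   shuffle not registered: (emptyFileGroups, null, null)
//   failed load:            (null, errorMessage, originalException)
type Entry = (Object /* ReduceFileGroups */, String, Exception)

// The original exception rides along as the cause, which the Spark reader inspects.
def failureToThrow(entry: Entry, detail: String): CelebornIOException =
  new CelebornIOException(detail, entry._3)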
@@ -18,12 +18,14 @@
package org.apache.spark.shuffle.celeborn

import java.io.IOException
import java.nio.file.Files
import java.util
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._

import com.google.common.annotations.VisibleForTesting
import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.internal.Logging
@@ -33,14 +35,14 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient}
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
import org.apache.celeborn.common.network.client.TransportClient
import org.apache.celeborn.common.network.protocol.TransportMessage
import org.apache.celeborn.common.protocol.{MessageType, PartitionLocation, PbOpenStreamList, PbOpenStreamListResponse, PbStreamHandler}
import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils}

@@ -56,18 +58,34 @@ class CelebornShuffleReader[K, C](
shuffleIdTracker: ExecutorShuffleIdTracker)
extends ShuffleReader[K, C] with Logging {

private val dep = handle.dependency
private val shuffleClient = ShuffleClient.get(
handle.appUniqueId,
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
conf,
handle.userIdentifier,
handle.extension)
val mockReader = conf.testMockShuffleReader

private val dep =
if (mockReader) {
null
} else {
handle.dependency
}

@VisibleForTesting
val shuffleClient =
if (mockReader) {
new DummyShuffleClient(conf, Files.createTempFile("test", "mockfile").toFile)
} else {
ShuffleClient.get(
handle.appUniqueId,
handle.lifecycleManagerHost,
handle.lifecycleManagerPort,
conf,
handle.userIdentifier,
handle.extension)
}

private val exceptionRef = new AtomicReference[IOException]
private val throwsFetchFailure = handle.throwsFetchFailure
private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context)
private val throwsFetchFailure =
if (mockReader) conf.clientStageRerunEnabled else handle.throwsFetchFailure
private val encodedAttemptId =
if (mockReader) 0 else SparkCommonUtils.getEncodedAttemptNumber(context)

override def read(): Iterator[Product2[K, C]] = {

@@ -111,7 +129,9 @@ class CelebornShuffleReader[K, C](
fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
} catch {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
handleFetchExceptions(handle.shuffleId, shuffleId, 0, ce)
// If the task was interrupted, do not report a fetch failure.
// If updating the file group timed out, do not report a fetch failure.
checkAndReportFetchFailureForUpdateFileGroupFailure(shuffleId, ce)
case e: Throwable => throw e
}

@@ -370,7 +390,22 @@
}
}

private def handleFetchExceptions(
@VisibleForTesting
def checkAndReportFetchFailureForUpdateFileGroupFailure(
celebornShuffleId: Int,
ce: Throwable): Unit = {
if (ce.getCause != null &&
  (ce.getCause.isInstanceOf[InterruptedException] ||
    ce.getCause.isInstanceOf[TimeoutException])) {
logWarning(s"fetch shuffle ${celebornShuffleId} timeout or interrupt", ce)
throw ce
} else {
handleFetchExceptions(handle.shuffleId, celebornShuffleId, 0, ce)
}
}

@VisibleForTesting
def handleFetchExceptions(
appShuffleId: Int,
shuffleId: Int,
partitionId: Int,
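The body of handleFetchExceptions is truncated above, but the new test below confirms one thing about it: routing an exception there ends up calling ShuffleClient.reportShuffleFetchFailure (the DummyShuffleClient counter increments once per reported failure). A hedged sketch of that interaction only, not the committed body (surfacing a FetchFailedException so Spark reruns the map stage is standard Spark behavior, reduced to a comment here):

import org.apache.celeborn.client.ShuffleClient

// Sketch under stated assumptions: only the reportShuffleFetchFailure call is
// confirmed by the test below; everything else is simplified.
def reportAndFail(
    client: ShuffleClient,
    throwsFetchFailure: Boolean,
    appShuffleId: Int,
    shuffleId: Int,
    taskId: Long,
    ce: Throwable): Unit = {
  if (throwsFetchFailure) {
    client.reportShuffleFetchFailure(appShuffleId, shuffleId, taskId)
    // ...the real reader would then throw a FetchFailedException to Spark
  }
  throw ce
}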
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn

import java.util.concurrent.TimeoutException

import org.apache.spark.TaskContext
import org.apache.spark.shuffle.ShuffleReadMetricsReporter
import org.mockito.Mockito
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.DummyShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.CelebornIOException

class CelebornShuffleReaderSuite extends AnyFunSuite {

/**
 * Due to Spark limitations, local mode cannot run speculative tasks, so this suite
 * exercises `checkAndReportFetchFailureForUpdateFileGroupFailure` directly.
 */
test("CELEBORN-1838 test check report fetch failure exceptions ") {
val handler = Mockito.mock(classOf[CelebornShuffleHandle[Int, Int, Int]])
val context = Mockito.mock(classOf[TaskContext])
val metricReporter = Mockito.mock(classOf[ShuffleReadMetricsReporter])
val conf = new CelebornConf()
conf.set("celeborn.test.client.mockShuffleReader", "true")
conf.set("celeborn.client.spark.stageRerun.enabled", "true")

val shuffleReader =
new CelebornShuffleReader[Int, Int](handler, 0, 0, 0, 0, context, conf, metricReporter, null)

val exception1: Throwable = new CelebornIOException("test1", new InterruptedException("test1"))
val exception2: Throwable = new CelebornIOException("test2", new TimeoutException("test2"))
val exception3: Throwable = new CelebornIOException("test3")
val exception4: Throwable = new CelebornIOException("test4")

try {
shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception1)
} catch {
case _: Throwable =>
}
try {
shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception2)
} catch {
case _: Throwable =>
}
try {
shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception3)
} catch {
case _: Throwable =>
}
assert(
shuffleReader.shuffleClient.asInstanceOf[DummyShuffleClient].fetchFailureCount.get() === 1)
try {
shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception4)
} catch {
case _: Throwable =>
}
assert(
shuffleReader.shuffleClient.asInstanceOf[DummyShuffleClient].fetchFailureCount.get() === 2)

}
}
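The suite above swallows every Throwable and asserts only on the failure counter. A tighter variant (a sketch using the same fixtures, inside the same test) would also pin down that the interrupt and timeout cases rethrow the original exception:

// Sketch: intercept comes with AnyFunSuite's assertions.
val rethrown = intercept[CelebornIOException] {
  shuffleReader.checkAndReportFetchFailureForUpdateFileGroupFailure(0, exception1)
}
assert(rethrown.getCause.isInstanceOf[InterruptedException])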
@@ -29,6 +29,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -55,6 +56,8 @@ public class DummyShuffleClient extends ShuffleClient {
private final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>> reducePartitionMap =
new HashMap<>();

public AtomicInteger fetchFailureCount = new AtomicInteger();

public DummyShuffleClient(CelebornConf conf, File file) throws Exception {
this.os = new BufferedOutputStream(new FileOutputStream(file));
this.conf = conf;
@@ -181,6 +184,7 @@ public int getShuffleId(

@Override
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
fetchFailureCount.incrementAndGet();
return true;
}

@@ -26,6 +26,7 @@
import java.util.concurrent.TimeUnit;

import scala.Tuple2;
import scala.Tuple3;
import scala.reflect.ClassTag$;

import com.google.common.annotations.VisibleForTesting;
@@ -170,7 +171,7 @@ public void update(ReduceFileGroups fileGroups) {
}

// key: shuffleId
protected final Map<Integer, Tuple2<ReduceFileGroups, String>> reduceFileGroupsMap =
protected final Map<Integer, Tuple3<ReduceFileGroups, String, Exception>> reduceFileGroupsMap =
JavaUtils.newConcurrentHashMap();

public ShuffleClientImpl(String appUniqueId, CelebornConf conf, UserIdentifier userIdentifier) {
@@ -1742,11 +1743,12 @@ public boolean cleanupShuffle(int shuffleId) {
return true;
}

protected Tuple2<ReduceFileGroups, String> loadFileGroupInternal(
protected Tuple3<ReduceFileGroups, String, Exception> loadFileGroupInternal(
int shuffleId, boolean isSegmentGranularityVisible) {
{
long getReducerFileGroupStartTime = System.nanoTime();
String exceptionMsg = null;
Exception exception = null;
try {
if (lifecycleManagerRef == null) {
exceptionMsg = "Driver endpoint is null!";
@@ -1768,9 +1770,10 @@ protected Tuple2&lt;ReduceFileGroups, String&gt; loadFileGroupInternal(
shuffleId,
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - getReducerFileGroupStartTime),
response.fileGroup().size());
return Tuple2.apply(
return Tuple3.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds()),
null,
null);
case SHUFFLE_NOT_REGISTERED:
logger.warn(
@@ -1779,9 +1782,10 @@ protected Tuple2&lt;ReduceFileGroups, String&gt; loadFileGroupInternal(
response.status(),
shuffleId);
// return empty result
return Tuple2.apply(
return Tuple3.apply(
new ReduceFileGroups(
response.fileGroup(), response.attempts(), response.partitionIds()),
null,
null);
case STAGE_END_TIME_OUT:
case SHUFFLE_DATA_LOST:
@@ -1800,8 +1804,9 @@ protected Tuple2&lt;ReduceFileGroups, String&gt; loadFileGroupInternal(
}
logger.error("Exception raised while call GetReducerFileGroup for {}.", shuffleId, e);
exceptionMsg = e.getMessage();
exception = e;
}
return Tuple2.apply(null, exceptionMsg);
return Tuple3.apply(null, exceptionMsg, exception);
}
}

@@ -1814,21 +1819,22 @@ public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
public ReduceFileGroups updateFileGroup(
int shuffleId, int partitionId, boolean isSegmentGranularityVisible)
throws CelebornIOException {
Tuple2<ReduceFileGroups, String> fileGroupTuple =
Tuple3<ReduceFileGroups, String, Exception> fileGroupTuple =
reduceFileGroupsMap.compute(
shuffleId,
(id, existsTuple) -> {
if (existsTuple == null || existsTuple._1 == null) {
if (existsTuple == null || existsTuple._1() == null) {
return loadFileGroupInternal(shuffleId, isSegmentGranularityVisible);
} else {
return existsTuple;
}
});
if (fileGroupTuple._1 == null) {
if (fileGroupTuple._1() == null) {
throw new CelebornIOException(
loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2)));
loadFileGroupException(shuffleId, partitionId, (fileGroupTuple._2())),
fileGroupTuple._3());
} else {
return fileGroupTuple._1;
return fileGroupTuple._1();
}
}

@@ -1899,7 +1905,7 @@ public CelebornInputStream readPartition(
}

@VisibleForTesting
public Map<Integer, Tuple2<ReduceFileGroups, String>> getReduceFileGroupsMap() {
public Map<Integer, Tuple3<ReduceFileGroups, String, Exception>> getReduceFileGroupsMap() {
return reduceFileGroupsMap;
}

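Two properties of the ShuffleClientImpl change are worth noting. First, a failed load caches (null, errorMessage, exception), and because _1() is null the compute in updateFileGroup reloads on the next call rather than serving the cached failure forever. Second, the original exception becomes the cause of the thrown CelebornIOException, which is the contract the Spark reader relies on. A sketch of that caller-visible contract (illustrative only):

import java.util.concurrent.TimeoutException

import org.apache.celeborn.common.exception.CelebornIOException

// Hedged sketch: the attached cause lets the reader tell transient
// interruptions apart from genuine data loss.
def classify(e: CelebornIOException): String = e.getCause match {
  case _: InterruptedException => "task interrupted: rethrow, no fetch failure report"
  case _: TimeoutException     => "GetReducerFileGroup timed out: rethrow, no report"
  case _                       => "real failure: report fetch failure (stage may rerun)"
}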
(The remaining changed files in this commit are not rendered on this page.)