Skip to content

Commit

Permalink
revert ut
Browse files Browse the repository at this point in the history
  • Loading branch information
turboFei committed Nov 19, 2024
1 parent e86c9f0 commit b087eb1
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public static void cancelShuffle(int shuffleId, String reason) {
.defaultAlwaysNull()
.build();

public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
TaskSchedulerImpl taskScheduler =
(TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
Expand All @@ -238,7 +238,7 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
.asJavaCollection().stream()
.anyMatch(
ti -> {
if ((ti.running() || ti.successful())
if ((!ti.finished() || ti.successful())
&& ti.attemptNumber() != taskInfo.attemptNumber()) {
LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ public static void cancelShuffle(int shuffleId, String reason) {
.defaultAlwaysNull()
.build();

public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
public static synchronized boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
if (SparkContext$.MODULE$.getActive().nonEmpty()) {
TaskSchedulerImpl taskScheduler =
(TaskSchedulerImpl) SparkContext$.MODULE$.getActive().get().taskScheduler();
Expand All @@ -354,7 +354,7 @@ public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
.asJavaCollection().stream()
.anyMatch(
ti -> {
if ((ti.running() || ti.successful())
if ((!ti.finished() || ti.successful())
&& ti.attemptNumber() != taskInfo.attemptNumber()) {
LOG.info("Another attempt of task {} is running: {}.", taskInfo, ti);
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SparkSchedulerHelper
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.sql.SparkSession
Expand Down Expand Up @@ -57,8 +56,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
super.createWorker(map, storageDir)
}

class ShuffleReaderGetHook(conf: CelebornConf, speculation: Boolean = false)
extends ShuffleManagerHook {
class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook {
var executed: AtomicBoolean = new AtomicBoolean(false)
val lock = new Object

Expand All @@ -67,10 +65,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite
startPartition: Int,
endPartition: Int,
context: TaskContext): Unit = {
val taskIndex = SparkSchedulerHelper.getTaskIndex(context.taskAttemptId())
if (speculation && taskIndex == 0) {
Thread.sleep(3000) // sleep for speculation
}
if (executed.get() == true) return

lock.synchronized {
Expand All @@ -88,19 +82,17 @@ class CelebornFetchFailureSuite extends AnyFunSuite
val allFiles = workerDirs.map(dir => {
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
})
val datafiles = allFiles.filter(_.exists())
if (datafiles.nonEmpty) {
if (taskIndex == 0) { // only cleanup the data file in the task with index 0
datafiles.foreach(_.delete())
executed.set(true)
}
} else {
throw new RuntimeException("unexpected, there must be some data file" +
s" under ${workerDirs.mkString(",")}")
val datafile = allFiles.filter(_.exists())
.flatMap(_.listFiles().iterator).headOption
datafile match {
case Some(file) => file.delete()
case None => throw new RuntimeException("unexpected, there must be some data file" +
s" under ${workerDirs.mkString(",")}")
}
}
case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
}
executed.set(true)
}
}
}
Expand Down Expand Up @@ -447,61 +439,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite
}
}

// Runs a two-partition groupByKey with Spark speculation enabled and a reader hook created with
// `speculation = true`, then checks that the fetch-failure pre-check saw another attempt of the
// failing task as running or successful at least once, and that the job still returns the
// complete, correct result.
test("celeborn spark integration test - do not rerun stage if task another attempt is running") {
  val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
  val sparkSession = SparkSession.builder()
    .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
    .config("spark.sql.shuffle.partitions", 2)
    .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
    .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
    .config(
      "spark.shuffle.manager",
      "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
    .config("spark.speculation", "true")
    .config("spark.speculation.multiplier", "2")
    .config("spark.speculation.quantile", "0")
    .getOrCreate()

  // Always stop the session, even when an assertion below fails, so a failure here does not
  // leave an active SparkContext behind for the tests that run afterwards.
  try {
    val shuffleMgr = SparkContextHelper.env
      .shuffleManager
      .asInstanceOf[TestCelebornShuffleManager]
    // The pre-check callback may be invoked on a thread other than the test thread
    // (NOTE(review): confirm which thread LifecycleManager calls it on), so the flag is an
    // AtomicBoolean rather than a plain captured var to guarantee the write is visible here.
    val preventUnnecessaryStageRerun = new AtomicBoolean(false)
    val lifecycleManager = shuffleMgr.getLifecycleManager
    lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
      new java.util.function.Function[java.lang.Long, Boolean] {
        // Returns false (i.e. the pre-check does not pass) when another attempt of the task is
        // running or has succeeded, and records that this happened.
        override def apply(taskId: java.lang.Long): Boolean = {
          val anotherRunningOrSuccessful =
            SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId)
          if (anotherRunningOrSuccessful) {
            preventUnnecessaryStageRerun.set(true)
          }
          !anotherRunningOrSuccessful
        }
      })

    val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
    val hook = new ShuffleReaderGetHook(celebornConf, speculation = true)
    TestCelebornShuffleManager.registerReaderGetHook(hook)

    val value = Range(1, 10000).mkString(",")
    val tuples = sparkSession.sparkContext.parallelize(1 to 10000, 2)
      .map { i => (i, value) }.groupByKey(2).collect()

    // verify result
    assert(hook.executed.get())
    assert(preventUnnecessaryStageRerun.get())
    assert(tuples.length == 10000)
    for (elem <- tuples) {
      assert(elem._2.mkString(",").equals(value))
    }

    shuffleMgr.unregisterShuffle(0)
    assert(lifecycleManager.getUnregisterShuffleTime().containsKey(0))
    assert(lifecycleManager.getUnregisterShuffleTime().containsKey(1))
  } finally {
    sparkSession.stop()
  }
}

private def findAppShuffleId(rdd: RDD[_]): Int = {
val deps = rdd.dependencies
if (deps.size != 1 && !deps.head.isInstanceOf[ShuffleDependency[_, _, _]]) {
Expand Down

This file was deleted.

0 comments on commit b087eb1

Please sign in to comment.