fetch failure integration testing (#46)
turboFei authored Dec 27, 2024
1 parent 4d96e29 commit c35172c
Showing 3 changed files with 84 additions and 70 deletions.
File 1 of 3
@@ -17,22 +17,19 @@

package org.apache.celeborn.tests.spark

import java.io.{File, IOException}
import java.io.IOException
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.shuffle.celeborn.{SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.celeborn.service.deploy.worker.Worker

class CelebornFetchFailureSuite extends AnyFunSuite
with SparkTestBase
@@ -46,57 +43,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite
System.gc()
}

var workerDirs: Seq[String] = Seq.empty

override def createWorker(map: Map[String, String]): Worker = {
val storageDir = createTmpDir()
this.synchronized {
workerDirs = workerDirs :+ storageDir
}
super.createWorker(map, storageDir)
}

class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook {
var executed: AtomicBoolean = new AtomicBoolean(false)
val lock = new Object

override def exec(
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): Unit = {
if (executed.get() == true) return

lock.synchronized {
handle match {
case h: CelebornShuffleHandle[_, _, _] => {
val appUniqueId = h.appUniqueId
val shuffleClient = ShuffleClient.get(
h.appUniqueId,
h.lifecycleManagerHost,
h.lifecycleManagerPort,
conf,
h.userIdentifier,
h.extension)
val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
val allFiles = workerDirs.map(dir => {
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
})
val datafile = allFiles.filter(_.exists())
.flatMap(_.listFiles().iterator).headOption
datafile match {
case Some(file) => file.delete()
case None => throw new RuntimeException("unexpected, there must be some data file" +
s" under ${workerDirs.mkString(",")}")
}
}
case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
}
executed.set(true)
}
}
}

test("celeborn spark integration test - Fetch Failure") {
if (Spark3OrNewer) {
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
@@ -111,7 +57,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

val value = Range(1, 10000).mkString(",")
@@ -184,7 +130,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

import sparkSession.implicits._
@@ -215,7 +161,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

val sc = sparkSession.sparkContext
@@ -255,7 +201,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()

val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

val sc = sparkSession.sparkContext
File 2 of 3
@@ -17,19 +17,26 @@

package org.apache.celeborn.tests.spark

import java.io.File
import java.util.concurrent.atomic.AtomicBoolean

import scala.util.Random

import org.apache.spark.SPARK_VERSION
import org.apache.spark.SparkConf
import org.apache.spark.{SPARK_VERSION, SparkConf, TaskContext}
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkUtils}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.funsuite.AnyFunSuite

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf._
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.celeborn.service.deploy.MiniClusterFeature
import org.apache.celeborn.service.deploy.worker.Worker

trait SparkTestBase extends AnyFunSuite
with Logging with MiniClusterFeature with BeforeAndAfterAll with BeforeAndAfterEach {
@@ -52,6 +59,16 @@ trait SparkTestBase extends AnyFunSuite
shutdownMiniCluster()
}

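// Storage dirs of every worker launched by this suite; the fetch-failure
// hook below scans these to locate shuffle data files on disk.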
var workerDirs: Seq[String] = Seq.empty

override def createWorker(map: Map[String, String]): Worker = {
val storageDir = createTmpDir()
this.synchronized {
workerDirs = workerDirs :+ storageDir
}
super.createWorker(map, storageDir)
}

def updateSparkConf(sparkConf: SparkConf, mode: ShuffleMode): SparkConf = {
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.set(
@@ -98,4 +115,45 @@ trait SparkTestBase extends AnyFunSuite
val outMap = result.collect().map(row => row.getString(0) -> row.getLong(1)).toMap
outMap
}

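// ShuffleManagerHook that fires once, on the first shuffle-reader creation:
// it deletes one shuffle data file from a worker's storage dir so that the
// reading stage hits a fetch failure and Spark's recovery path is exercised.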
class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends ShuffleManagerHook {
var executed: AtomicBoolean = new AtomicBoolean(false)
val lock = new Object

override def exec(
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): Unit = {
if (executed.get()) return

lock.synchronized {
handle match {
case h: CelebornShuffleHandle[_, _, _] => {
val appUniqueId = h.appUniqueId
val shuffleClient = ShuffleClient.get(
h.appUniqueId,
h.lifecycleManagerHost,
h.lifecycleManagerPort,
conf,
h.userIdentifier,
h.extension)
val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
val allFiles = workerDirs.map(dir => {
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
})
val datafile = allFiles.filter(_.exists())
.flatMap(_.listFiles().iterator).sortBy(_.getName).headOption
datafile match {
case Some(file) => file.delete()
case None => throw new RuntimeException("unexpected, there must be some data file" +
s" under ${workerDirs.mkString(",")}")
}
}
case _ => throw new RuntimeException("unexpected, only CelebornShuffleHandle is supported here")
}
executed.set(true)
}
}
}
}
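For context, a minimal sketch of how a test in a suite mixing in SparkTestBase (for example CelebornFetchFailureSuite above) can drive this hook end to end. This is an illustration reusing only helpers visible in this diff, not code from the commit itself; the app name and the job are arbitrary assumptions:

// Hypothetical test body: run one shuffle job with the hook registered and
// expect the job to survive the injected fetch failure via stage recompute.
val sparkConf = updateSparkConf(
  new SparkConf().setAppName("hook-demo").setMaster("local[2,3]"),
  ShuffleMode.HASH)
val sparkSession = SparkSession.builder()
  .config(sparkConf)
  .config(
    "spark.shuffle.manager",
    "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
  .getOrCreate()
val hook = new ShuffleReaderFetchFailureGetHook(
  SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf))
TestCelebornShuffleManager.registerReaderGetHook(hook)
// Ten distinct keys go in, ten groups come out, even though the hook
// deleted one shuffle data file during the first reader creation.
assert(sparkSession.sparkContext.parallelize(1 to 1000, 2)
  .groupBy(_ % 10).count() == 10)
sparkSession.stop()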
File 3 of 3
@@ -17,6 +17,8 @@

package org.apache.spark.shuffle.celeborn

import scala.collection.JavaConverters._

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.sql.SparkSession
@@ -54,13 +56,19 @@ class SparkUtilsSuite extends AnyFunSuite
"org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
.getOrCreate()

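// Register the fetch-failure hook so the job below actually reports a
// shuffle fetch failure for SparkUtils' bookkeeping to observe.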
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)

try {
val sc = sparkSession.sparkContext
val jobThread = new Thread {
override def run(): Unit = {
try {
sc.parallelize(1 to 100, 2)
.repartition(1)
val value = Range(1, 10000).mkString(",")
sc.parallelize(1 to 10000, 2)
.map { i => (i, value) }
.groupByKey(10)
.mapPartitions { iter =>
Thread.sleep(3000)
iter
Expand All @@ -73,13 +81,15 @@ class SparkUtilsSuite extends AnyFunSuite
jobThread.start()

val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
eventually(timeout(3.seconds), interval(100.milliseconds)) {
val taskId = 0
val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId)
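// Poll until the hook has fired and the fetch failure has been reported,
// then look up the reporting task's TaskSetManager and verify its attempts.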
eventually(timeout(30.seconds), interval(0.milliseconds)) {
assert(hook.executed.get())
val reportedTaskId =
SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap(
_.asScala).head
val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId)
assert(taskSetManager != null)
assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1)
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId))
assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1)
assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1)
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId))
}

sparkSession.sparkContext.cancelAllJobs()