diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6a16a31654630..59a63696b26b1 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -240,7 +240,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)
 
   // Create the Spark execution environment (cache, map output tracker, etc)
-  private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+
+  // This function allows components created by SparkEnv to be mocked in unit tests:
+  private[spark] def createSparkEnv(
+      conf: SparkConf,
+      isLocal: Boolean,
+      listenerBus: LiveListenerBus): SparkEnv = {
+    SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+  }
+
+  private[spark] val env = createSparkEnv(conf, isLocal, listenerBus)
   SparkEnv.set(env)
 
   // Used to store a URL for each static file/jar together with the file's local timestamp
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index f0edf5d9cb280..d95a176d18928 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -154,7 +154,8 @@ object SparkEnv extends Logging {
   private[spark] def createDriverEnv(
       conf: SparkConf,
       isLocal: Boolean,
-      listenerBus: LiveListenerBus): SparkEnv = {
+      listenerBus: LiveListenerBus,
+      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
     assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
     assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
     val hostname = conf.get("spark.driver.host")
@@ -166,7 +167,8 @@ object SparkEnv extends Logging {
       port,
       isDriver = true,
       isLocal = isLocal,
-      listenerBus = listenerBus
+      listenerBus = listenerBus,
+      mockOutputCommitCoordinator = mockOutputCommitCoordinator
     )
   }
 
@@ -205,7 +207,8 @@ object SparkEnv extends Logging {
       isDriver: Boolean,
       isLocal: Boolean,
       listenerBus: LiveListenerBus = null,
-      numUsableCores: Int = 0): SparkEnv = {
+      numUsableCores: Int = 0,
+      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
 
     // Listener bus is only used on the driver
     if (isDriver) {
@@ -353,10 +356,13 @@ object SparkEnv extends Logging {
         "levels using the RDD.persist() method instead.")
     }
 
-    val outputCommitCoordinator = new OutputCommitCoordinator(conf)
+    val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
+      new OutputCommitCoordinator(conf)
+    }
     val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
       new OutputCommitCoordinatorActor(outputCommitCoordinator))
     outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)
+
     new SparkEnv(
       executorId,
       actorSystem,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
index c454fc6f220a7..6c44e6a24ae07 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -119,7 +119,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
     authorizedCommittersByStage.clear
   }
 
-  private def handleAskPermissionToCommit(
+  // Marked private[scheduler] instead of private so this can be mocked in tests
+  private[scheduler] def handleAskPermissionToCommit(
       stage: StageId,
       task: TaskId,
       attempt: TaskAttemptId): Boolean = synchronized {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index db3d89945d0a6..9ffc7e01037a8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.io.File
+import java.util.concurrent.TimeoutException
 
 import org.mockito.Matchers
 import org.mockito.Mockito._
@@ -28,9 +29,13 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
 import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter}
 
 import org.apache.spark._
-import org.apache.spark.rdd.FakeOutputCommitter
+import org.apache.spark.rdd.{RDD, FakeOutputCommitter}
 import org.apache.spark.util.Utils
 
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
 /**
  * Unit tests for the output commit coordination functionality.
  *
@@ -61,13 +66,23 @@ import org.apache.spark.util.Utils
  */
 class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
 
-  var dagScheduler: DAGScheduler = null
+  var outputCommitCoordinator: OutputCommitCoordinator = null
   var tempDir: File = null
   var sc: SparkContext = null
 
   before {
-    sc = new SparkContext("local[4]", classOf[OutputCommitCoordinatorSuite].getSimpleName)
     tempDir = Utils.createTempDir()
+    sc = new SparkContext("local[4]", classOf[OutputCommitCoordinatorSuite].getSimpleName) {
+      override private[spark] def createSparkEnv(
+          conf: SparkConf,
+          isLocal: Boolean,
+          listenerBus: LiveListenerBus): SparkEnv = {
+        outputCommitCoordinator = spy(new OutputCommitCoordinator(conf))
+        // Use Mockito.spy() to maintain the default infrastructure everywhere else.
+        // This mocking allows us to control the coordinator responses in test cases.
+        SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator))
+      }
+    }
 
     // Use Mockito.spy() to maintain the default infrastructure everywhere else
     val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl])
@@ -109,6 +124,7 @@ class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
   after {
     sc.stop()
     tempDir.delete()
+    outputCommitCoordinator = null
   }
 
   test("Only one of two duplicate commit tasks should commit") {
@@ -124,6 +140,20 @@ class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
       0 until rdd.partitions.size, allowLocal = false)
     assert(tempDir.list().size === 1)
   }
+
+  test("Job should not complete if all commits are denied") {
+    doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
+      Matchers.any(), Matchers.any(), Matchers.any())
+    val rdd: RDD[Int] = sc.parallelize(Seq(1), 1)
+    def resultHandler(x: Int, y: Unit): Unit = {}
+    val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd,
+      OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully,
+      0 until rdd.partitions.size, resultHandler, 0)
+    intercept[TimeoutException] {
+      Await.result(futureAction, 5 seconds)
+    }
+    assert(tempDir.list().size === 0)
+  }
 }
 
 /**
@@ -145,11 +175,13 @@ private case class OutputCommitFunctions(tempDirPath: String) {
     }
   }
 
-  def commitSuccessfully(ctx: TaskContext, iter: Iterator[Int]): Unit = {
+  def commitSuccessfully(iter: Iterator[Int]): Unit = {
+    val ctx = TaskContext.get()
     runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter)
   }
 
-  def failFirstCommitAttempt(ctx: TaskContext, iter: Iterator[Int]): Unit = {
+  def failFirstCommitAttempt(iter: Iterator[Int]): Unit = {
+    val ctx = TaskContext.get()
     runCommitWithProvidedCommitter(ctx, iter,
       if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter)
  }
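The sketch below is not part of the patch; it only consolidates, in one place, how the pieces above fit together: the new SparkContext.createSparkEnv hook, the mockOutputCommitCoordinator parameter threaded through SparkEnv.createDriverEnv, and the relaxed private[scheduler] visibility of handleAskPermissionToCommit. It assumes the code is compiled inside the org.apache.spark.scheduler test sources (as the suite is), since the overridden and stubbed members are private[spark]/private[scheduler]; the object name and app name are illustrative only.

package org.apache.spark.scheduler

import org.mockito.Matchers
import org.mockito.Mockito._

import org.apache.spark.{SparkConf, SparkContext, SparkEnv}

// Hypothetical example object, not part of the patch.
object DeniedCommitsExample {
  def main(args: Array[String]): Unit = {
    // Holds the spied coordinator created inside the overridden hook below.
    var outputCommitCoordinator: OutputCommitCoordinator = null

    // The new createSparkEnv hook lets test code wrap the real coordinator in a
    // Mockito spy before SparkEnv is constructed around it.
    val sc = new SparkContext("local[4]", "denied-commits-example") {
      override private[spark] def createSparkEnv(
          conf: SparkConf,
          isLocal: Boolean,
          listenerBus: LiveListenerBus): SparkEnv = {
        outputCommitCoordinator = spy(new OutputCommitCoordinator(conf))
        SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator))
      }
    }

    // The private[scheduler] visibility is what makes this stubbing possible:
    // every permission request is denied, so a committing task never completes
    // and callers time out, as the new test asserts via TimeoutException.
    doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
      Matchers.any(), Matchers.any(), Matchers.any())

    sc.stop()
  }
}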