Add test to ensure that a job that denies all commits cannot complete successfully.

Currently, the job just hangs.

This test was added after I noticed that our previous tests would still pass
after commenting out the "throw CommitDeniedException" line in SparkHadoopWriter.
JoshRosen committed Feb 4, 2015
1 parent 97da5fe commit ede7590
Showing 4 changed files with 59 additions and 11 deletions.
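For context on the message above: a task asks the driver-side OutputCommitCoordinator for permission before committing its output, and a denial must fail the task attempt rather than let it report success. The following is a minimal sketch of that control flow, with simplified names standing in for the exact SparkHadoopWriter code:

// Simplified sketch of the task-side commit decision; names are illustrative,
// not the literal Spark source.
def performCommit(
    committer: OutputCommitter,
    taskContext: TaskAttemptContext,
    stageId: Int,
    partitionId: Int,
    attemptNumber: Int): Unit = {
  val coordinator = SparkEnv.get.outputCommitCoordinator
  if (coordinator.canCommit(stageId, partitionId, attemptNumber)) {
    committer.commitTask(taskContext)
  } else {
    // If this throw is removed, a denied attempt finishes "successfully" without
    // committing anything: exactly the regression the new test catches.
    throw new CommitDeniedException(
      s"Commit denied: stage=$stageId partition=$partitionId attempt=$attemptNumber",
      stageId, partitionId, attemptNumber)
  }
}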
core/src/main/scala/org/apache/spark/SparkContext.scala (10 additions, 1 deletion)
@@ -240,7 +240,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient
conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)

// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)

// This function allows components created by SparkEnv to be mocked in unit tests:
private[spark] def createSparkEnv(
conf: SparkConf,
isLocal: Boolean,
listenerBus: LiveListenerBus): SparkEnv = {
SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
}

private[spark] val env = createSparkEnv(conf, isLocal, listenerBus)
SparkEnv.set(env)

// Used to store a URL for each static file/jar together with the file's local timestamp
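The createSparkEnv indirection added above is a factory-method seam: production behavior is unchanged, while a test can subclass SparkContext and install its own driver environment. Roughly, where testCoordinator is a placeholder for whatever coordinator the test wants injected (the suite changed below does exactly this):

// Sketch: override the factory seam in a test to inject a custom coordinator.
val sc = new SparkContext("local[4]", "test") {
  override private[spark] def createSparkEnv(
      conf: SparkConf,
      isLocal: Boolean,
      listenerBus: LiveListenerBus): SparkEnv = {
    SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(testCoordinator))
  }
}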
core/src/main/scala/org/apache/spark/SparkEnv.scala (10 additions, 4 deletions)
@@ -154,7 +154,8 @@ object SparkEnv extends Logging {
   private[spark] def createDriverEnv(
       conf: SparkConf,
       isLocal: Boolean,
-      listenerBus: LiveListenerBus): SparkEnv = {
+      listenerBus: LiveListenerBus,
+      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
     assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
     assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
     val hostname = conf.get("spark.driver.host")
@@ -166,7 +167,8 @@ object SparkEnv extends Logging {
       port,
       isDriver = true,
       isLocal = isLocal,
-      listenerBus = listenerBus
+      listenerBus = listenerBus,
+      mockOutputCommitCoordinator = mockOutputCommitCoordinator
     )
   }

@@ -205,7 +207,8 @@ object SparkEnv extends Logging {
       isDriver: Boolean,
       isLocal: Boolean,
       listenerBus: LiveListenerBus = null,
-      numUsableCores: Int = 0): SparkEnv = {
+      numUsableCores: Int = 0,
+      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
 
     // Listener bus is only used on the driver
     if (isDriver) {
@@ -353,10 +356,13 @@ object SparkEnv extends Logging {
"levels using the RDD.persist() method instead.")
}

val outputCommitCoordinator = new OutputCommitCoordinator(conf)
val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
new OutputCommitCoordinator(conf)
}
val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
new OutputCommitCoordinatorActor(outputCommitCoordinator))
outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)

new SparkEnv(
executorId,
actorSystem,
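Threading an Option with a None default through createDriverEnv and create keeps every existing call site source-compatible while giving tests a hook to hand SparkEnv a pre-built coordinator. A minimal sketch of the injection, assuming a Mockito spy as in the suite below:

import org.mockito.Mockito.spy

// Sketch: let SparkEnv wire everything up as usual, but with a spied coordinator.
val coordinator = spy(new OutputCommitCoordinator(conf))
val env = SparkEnv.createDriverEnv(
  conf, isLocal = true, listenerBus = listenerBus,
  mockOutputCommitCoordinator = Some(coordinator))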
core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala (2 additions, 1 deletion)
@@ -119,7 +119,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
     authorizedCommittersByStage.clear
   }
 
-  private def handleAskPermissionToCommit(
+  // Marked private[scheduler] instead of private so this can be mocked in tests
+  private[scheduler] def handleAskPermissionToCommit(
       stage: StageId,
       task: TaskId,
       attempt: TaskAttemptId): Boolean = synchronized {
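Relaxing the modifier to private[scheduler] matters because OutputCommitCoordinatorSuite lives in org.apache.spark.scheduler, and Mockito can only stub a method that the stubbing code can see. The new test's stubbing, shown in full in the suite diff below, looks like:

import org.mockito.Matchers
import org.mockito.Mockito.doReturn

// Deny every commit request, regardless of stage, task, or attempt.
doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
  Matchers.any(), Matchers.any(), Matchers.any())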
core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala (37 additions, 5 deletions)
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.io.File
+import java.util.concurrent.TimeoutException
 
 import org.mockito.Matchers
 import org.mockito.Mockito._
@@ -28,9 +29,13 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
 import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter}
 
 import org.apache.spark._
-import org.apache.spark.rdd.FakeOutputCommitter
+import org.apache.spark.rdd.{RDD, FakeOutputCommitter}
 import org.apache.spark.util.Utils
 
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
 /**
  * Unit tests for the output commit coordination functionality.
  *
@@ -61,13 +66,23 @@
  */
 class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
 
   var dagScheduler: DAGScheduler = null
   var outputCommitCoordinator: OutputCommitCoordinator = null
   var tempDir: File = null
   var sc: SparkContext = null
 
   before {
-    sc = new SparkContext("local[4]", classOf[OutputCommitCoordinatorSuite].getSimpleName)
     tempDir = Utils.createTempDir()
+    sc = new SparkContext("local[4]", classOf[OutputCommitCoordinatorSuite].getSimpleName) {
+      override private[spark] def createSparkEnv(
+          conf: SparkConf,
+          isLocal: Boolean,
+          listenerBus: LiveListenerBus): SparkEnv = {
+        outputCommitCoordinator = spy(new OutputCommitCoordinator(conf))
+        // Use Mockito.spy() to maintain the default infrastructure everywhere else.
+        // This mocking allows us to control the coordinator responses in test cases.
+        SparkEnv.createDriverEnv(conf, isLocal, listenerBus, Some(outputCommitCoordinator))
+      }
+    }
     // Use Mockito.spy() to maintain the default infrastructure everywhere else
     val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl])

@@ -109,6 +124,7 @@ class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
   after {
     sc.stop()
     tempDir.delete()
+    outputCommitCoordinator = null
   }
 
   test("Only one of two duplicate commit tasks should commit") {
@@ -124,6 +140,20 @@ class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
       0 until rdd.partitions.size, allowLocal = false)
     assert(tempDir.list().size === 1)
   }
+
+  test("Job should not complete if all commits are denied") {
+    doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
+      Matchers.any(), Matchers.any(), Matchers.any())
+    val rdd: RDD[Int] = sc.parallelize(Seq(1), 1)
+    def resultHandler(x: Int, y: Unit): Unit = {}
+    val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd,
+      OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully,
+      0 until rdd.partitions.size, resultHandler, 0)
+    intercept[TimeoutException] {
+      Await.result(futureAction, 5 seconds)
+    }
+    assert(tempDir.list().size === 0)
+  }
 }

@@ -145,11 +175,13 @@ private case class OutputCommitFunctions(tempDirPath: String) {
     }
   }
 
-  def commitSuccessfully(ctx: TaskContext, iter: Iterator[Int]): Unit = {
+  def commitSuccessfully(iter: Iterator[Int]): Unit = {
+    val ctx = TaskContext.get()
     runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter)
   }
 
-  def failFirstCommitAttempt(ctx: TaskContext, iter: Iterator[Int]): Unit = {
+  def failFirstCommitAttempt(iter: Iterator[Int]): Unit = {
+    val ctx = TaskContext.get()
     runCommitWithProvidedCommitter(ctx, iter,
       if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter)
   }
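A note on the two signature changes above: SparkContext.submitJob, unlike the runJob overloads used by the other test, takes a plain Iterator[T] => U partition function, so the helpers now fetch their TaskContext inside the task via TaskContext.get() rather than receiving it as a parameter. For reference, submitJob's shape, paraphrased from the SparkContext API of this vintage:

def submitJob[T, U, R](
    rdd: RDD[T],
    processPartition: Iterator[T] => U,
    partitions: Seq[Int],
    resultHandler: (Int, U) => Unit,
    resultFunc: => R): SimpleFutureAction[R]

Because every commit is denied, the submitted job can never finish; Await.result on the returned SimpleFutureAction therefore times out, and intercept[TimeoutException] turns that hang into a passing assertion, with the empty tempDir confirming that nothing was committed.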
