Some cleanup in OutputCommitCoordinatorSuite
I found having classes that extended Function to be more confusing than
treating a method as a function.  I also pulled the nested classes out
of the test suite in order to remove the special serialization logic
that was needed to avoid serializing the entire suite.
JoshRosen committed Feb 3, 2015
1 parent a7c0e29 commit f582574
Showing 1 changed file with 54 additions and 83 deletions.
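
To make the contrast in the commit message concrete, here is a minimal, self-contained Scala sketch. It is not taken from this commit, and the names FunctionStyleSketch, DoubleIt, double and applyTo are made up for illustration; it compares a class that extends a Function type with a plain method passed as a function value via eta-expansion (double _).

object FunctionStyleSketch {

  // Style the commit moves away from: a class that extends a Function type.
  // If a class like this is nested inside a test suite, it captures the enclosing
  // suite instance, so serializing the function would drag the whole suite along
  // (which is what the removed writeObject/readObject workaround guarded against).
  class DoubleIt extends (Int => Int) with Serializable {
    def apply(x: Int): Int = x * 2
  }

  // Style the commit moves toward: a plain method. Writing `double _` eta-expands
  // the method into a function value without defining an extra class.
  def double(x: Int): Int = x * 2

  def applyTo(n: Int, f: Int => Int): Int = f(n)

  def main(args: Array[String]): Unit = {
    println(applyTo(21, new DoubleIt)) // 42, via the Function-extending class
    println(applyTo(21, double _))     // 42, via the method treated as a function
  }
}

In the diff below, the same idea appears as OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _ passed to sc.runJob: the helper methods now live on a small top-level case class, so serializing the function captures only that case class rather than the enclosing suite, and the custom writeObject/readObject logic can be deleted.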
@@ -17,13 +17,12 @@
 
 package org.apache.spark.scheduler
 
-import java.io.{File, ObjectInputStream, ObjectOutputStream, IOException}
+import java.io.File
 
 import org.mockito.Matchers
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
-import org.scalatest.concurrent.Timeouts
 import org.scalatest.{BeforeAndAfter, FunSuite}
 
 import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter}
@@ -32,8 +31,6 @@ import org.apache.spark._
 import org.apache.spark.rdd.FakeOutputCommitter
 import org.apache.spark.util.Utils
 
-import scala.collection.mutable.ArrayBuffer
-
 /**
  * Unit tests for the output commit coordination functionality.
  *
@@ -62,29 +59,21 @@ import scala.collection.mutable.ArrayBuffer
  * increments would be captured even though the commit in both tasks was executed
  * erroneously.
  */
-class OutputCommitCoordinatorSuite
-  extends FunSuite
-  with BeforeAndAfter
-  with Timeouts {
-
-  val conf = new SparkConf()
-    .set("spark.localExecution.enabled", "true")
+class OutputCommitCoordinatorSuite extends FunSuite with BeforeAndAfter {
 
   var dagScheduler: DAGScheduler = null
   var tempDir: File = null
-  var tempDirPath: String = null
   var sc: SparkContext = null
 
   before {
-    sc = new SparkContext("local[4]", "Output Commit Coordinator Suite")
+    sc = new SparkContext("local[4]", classOf[OutputCommitCoordinatorSuite].getSimpleName)
     tempDir = Utils.createTempDir()
-    tempDirPath = tempDir.getAbsolutePath()
     // Use Mockito.spy() to maintain the default infrastructure everywhere else
     val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl])
 
     doAnswer(new Answer[Unit]() {
       override def answer(invoke: InvocationOnMock): Unit = {
-        // Submit the tasks, then, force the task scheduler to dequeue the
+        // Submit the tasks, then force the task scheduler to dequeue the
         // speculated task
         invoke.callRealMethod()
         mockTaskScheduler.backend.reviveOffers()
@@ -94,17 +83,17 @@ class OutputCommitCoordinatorSuite
     doAnswer(new Answer[TaskSetManager]() {
       override def answer(invoke: InvocationOnMock): TaskSetManager = {
         val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet]
-        return new TaskSetManager(mockTaskScheduler, taskSet, 4) {
+        new TaskSetManager(mockTaskScheduler, taskSet, 4) {
           var hasDequeuedSpeculatedTask = false
           override def dequeueSpeculativeTask(
               execId: String,
               host: String,
               locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = {
             if (!hasDequeuedSpeculatedTask) {
               hasDequeuedSpeculatedTask = true
-              return Some(0, TaskLocality.PROCESS_LOCAL)
+              Some(0, TaskLocality.PROCESS_LOCAL)
             } else {
-              return None
+              None
             }
           }
         }
@@ -122,82 +111,64 @@ class OutputCommitCoordinatorSuite
     tempDir.delete()
   }
 
-  /**
-   * Function that constructs a SparkHadoopWriter with a mock committer and runs its commit
-   */
-  private class OutputCommittingFunction(private var tempDirPath: String)
-    extends ((TaskContext, Iterator[Int]) => Int) with Serializable {
+  test("Only one of two duplicate commit tasks should commit") {
+    val rdd = sc.parallelize(Seq(1), 1)
+    sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _,
+      0 until rdd.partitions.size, allowLocal = false)
+    assert(tempDir.list().size === 1)
+  }
 
-    def apply(ctxt: TaskContext, it: Iterator[Int]): Int = {
-      val outputCommitter = new FakeOutputCommitter {
-        override def commitTask(context: TaskAttemptContext) : Unit = {
-          Utils.createDirectory(tempDirPath)
-        }
-      }
-      runCommitWithProvidedCommitter(ctxt, it, outputCommitter)
-    }
+  test("If commit fails, if task is retried it should not be locked, and will succeed.") {
+    val rdd = sc.parallelize(Seq(1), 1)
+    sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _,
+      0 until rdd.partitions.size, allowLocal = false)
+    assert(tempDir.list().size === 1)
+  }
+}
 
-    protected def runCommitWithProvidedCommitter(
-        ctxt: TaskContext,
-        it: Iterator[Int],
-        outputCommitter: OutputCommitter): Int = {
-      def jobConf = new JobConf {
-        override def getOutputCommitter(): OutputCommitter = outputCommitter
-      }
-      val sparkHadoopWriter = new SparkHadoopWriter(jobConf) {
-        override def newTaskAttemptContext(
-            conf: JobConf,
-            attemptId: TaskAttemptID): TaskAttemptContext = {
-          mock(classOf[TaskAttemptContext])
-        }
-      }
-      sparkHadoopWriter.setup(ctxt.stageId, ctxt.partitionId, ctxt.attemptNumber)
-      sparkHadoopWriter.commit
-      0
-    }
+/**
+ * Class with methods that can be passed to runJob to test commits with a mock committer.
+ */
+private case class OutputCommitFunctions(tempDirPath: String) {
 
-    // Need this otherwise the entire test suite attempts to be serialized
-    @throws(classOf[IOException])
-    private def writeObject(out: ObjectOutputStream): Unit = {
-      out.writeUTF(tempDirPath)
+  // Mock output committer that simulates a successful commit (after commit is authorized)
+  private def successfulOutputCommitter = new FakeOutputCommitter {
+    override def commitTask(context: TaskAttemptContext): Unit = {
+      Utils.createDirectory(tempDirPath)
     }
+  }
 
-    @throws(classOf[IOException])
-    private def readObject(in: ObjectInputStream): Unit = {
-      tempDirPath = in.readUTF()
+  // Mock output committer that simulates a failed commit (after commit is authorized)
+  private def failingOutputCommitter = new FakeOutputCommitter {
+    override def commitTask(taskAttemptContext: TaskAttemptContext) {
+      throw new RuntimeException
    }
   }
 
-  /**
-   * Function that will explicitly fail to commit on the first attempt
-   */
-  private class FailFirstTimeCommittingFunction(private var tempDirPath: String)
-    extends OutputCommittingFunction(tempDirPath) {
-    override def apply(ctxt: TaskContext, it: Iterator[Int]): Int = {
-      if (ctxt.attemptNumber == 0) {
-        val outputCommitter = new FakeOutputCommitter {
-          override def commitTask(taskAttemptContext: TaskAttemptContext) {
-            throw new RuntimeException
-          }
-        }
-        runCommitWithProvidedCommitter(ctxt, it, outputCommitter)
-      } else {
-        super.apply(ctxt, it)
-      }
-    }
+  def commitSuccessfully(ctx: TaskContext, iter: Iterator[Int]): Unit = {
+    runCommitWithProvidedCommitter(ctx, iter, successfulOutputCommitter)
   }
 
-  test("Only one of two duplicate commit tasks should commit") {
-    val rdd = sc.parallelize(Seq(1), 1)
-    sc.runJob(rdd, new OutputCommittingFunction(tempDirPath),
-      0 until rdd.partitions.size, allowLocal = true)
-    assert(tempDir.list().size === 1)
+  def failFirstCommitAttempt(ctx: TaskContext, iter: Iterator[Int]): Unit = {
+    runCommitWithProvidedCommitter(ctx, iter,
+      if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter)
   }
 
-  test("If commit fails, if task is retried it should not be locked, and will succeed.") {
-    val rdd = sc.parallelize(Seq(1), 1)
-    sc.runJob(rdd, new FailFirstTimeCommittingFunction(tempDirPath),
-      0 until rdd.partitions.size, allowLocal = true)
-    assert(tempDir.list().size === 1)
+  private def runCommitWithProvidedCommitter(
+      ctx: TaskContext,
+      iter: Iterator[Int],
+      outputCommitter: OutputCommitter): Unit = {
+    def jobConf = new JobConf {
+      override def getOutputCommitter(): OutputCommitter = outputCommitter
+    }
+    val sparkHadoopWriter = new SparkHadoopWriter(jobConf) {
+      override def newTaskAttemptContext(
+        conf: JobConf,
+        attemptId: TaskAttemptID): TaskAttemptContext = {
+        mock(classOf[TaskAttemptContext])
+      }
+    }
+    sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber)
+    sparkHadoopWriter.commit()
   }
 }
