Unit tests for OutputCommitCoordinator
mccheah committed Jan 22, 2015
1 parent 6e6f748 commit bc80770
Showing 4 changed files with 219 additions and 19 deletions.
@@ -17,13 +17,14 @@

package org.apache.spark.scheduler

import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration

import akka.actor.{PoisonPill, ActorRef, Actor}

import org.apache.spark.Logging
import org.apache.spark.util.{AkkaUtils, ActorLogReceive}

import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration

private[spark] sealed trait OutputCommitCoordinationMessage

private[spark] case class StageStarted(stage: Int) extends OutputCommitCoordinationMessage
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.scheduler

import scala.util.control.NonFatal

class DAGSchedulerSingleThreadedProcessLoop(dagScheduler: DAGScheduler)
extends DAGSchedulerEventProcessLoop(dagScheduler) {

override def post(event: DAGSchedulerEvent): Unit = {
try {
// Forward event to `onReceive` directly to avoid processing event asynchronously.
onReceive(event)
} catch {
case NonFatal(e) => onError(e)
}
}
}
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler

import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map}
import scala.language.reflectiveCalls
import scala.util.control.NonFatal

import org.scalatest.{BeforeAndAfter, FunSuiteLike}
import org.scalatest.concurrent.Timeouts
@@ -32,19 +31,6 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
import org.apache.spark.util.CallSite
import org.apache.spark.executor.TaskMetrics

class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
extends DAGSchedulerEventProcessLoop(dagScheduler) {

override def post(event: DAGSchedulerEvent): Unit = {
try {
// Forward event to `onReceive` directly to avoid processing event asynchronously.
onReceive(event)
} catch {
case NonFatal(e) => onError(e)
}
}
}

/**
* An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
* preferredLocations (if any) that are passed to them. They are deliberately not executable
@@ -171,7 +157,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
runLocallyWithinThread(job)
}
}
dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
dagEventProcessLoopTester = new DAGSchedulerSingleThreadedProcessLoop(scheduler)
}

override def afterAll() {
@@ -399,7 +385,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
runLocallyWithinThread(job)
}
}
dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler)
dagEventProcessLoopTester = new DAGSchedulerSingleThreadedProcessLoop(noKillScheduler)
val jobId = submit(new MyRDD(sc, 1, Nil), Array(0))
cancel(jobId)
// Because the job wasn't actually cancelled, we shouldn't have received a failure message.
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.scheduler

import java.io.{ObjectInputStream, ObjectOutputStream, IOException}

import scala.collection.mutable

import org.scalatest.concurrent.Timeouts
import org.scalatest.{BeforeAndAfter, FunSuiteLike}

import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter}
import org.mockito.Mockito._

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rdd.FakeOutputCommitter

/**
* Unit tests for the output commit coordination functionality. Overrides the TaskSchedulerImpl
* to just run the tasks directly and send completion or error messages back to the
* DAG scheduler.
*/
class OutputCommitCoordinatorSuite
extends FunSuiteLike
with BeforeAndAfter
with LocalSparkContext
with Timeouts {

val conf = new SparkConf().set("spark.localExecution.enabled", "true")

var taskScheduler: TaskSchedulerImpl = null
var dagScheduler: DAGScheduler = null
var dagSchedulerEventProcessLoop: DAGSchedulerEventProcessLoop = null
var accum: Accumulator[Int] = null
var accumId: Long = 0

before {
sc = new SparkContext("local", "Output Commit Coordinator Suite")
accum = sc.accumulator[Int](0)
Accumulators.register(accum, true)
accumId = accum.id

taskScheduler = new TaskSchedulerImpl(sc, 4, true) {
override def submitTasks(taskSet: TaskSet) {
// Instead of submitting a task to some executor, just run the task directly.
// Make two attempts. The first may or may not succeed. If the first
// succeeds then the second is redundant and should be handled
// accordingly by OutputCommitCoordinator. Otherwise the second
// should not be blocked from succeeding.
execTasks(taskSet, 0)
execTasks(taskSet, 1)
}

private def execTasks(taskSet: TaskSet, attemptNumber: Int) {
var taskIndex = 0
taskSet.tasks.foreach(t => {
val tid = newTaskId
val taskInfo = new TaskInfo(tid, taskIndex, 0, System.currentTimeMillis, "0",
"localhost", TaskLocality.NODE_LOCAL, false)
taskIndex += 1
// Track the successful commits in an accumulator. However, we can't just invoke
// accum += 1 since this unit test circumvents the usual accumulator updating
// infrastructure. So just send the accumulator update manually.
val accumUpdates = new mutable.HashMap[Long, Any]
try {
accumUpdates(accumId) = t.run(attemptNumber, attemptNumber)
dagSchedulerEventProcessLoop.post(
new CompletionEvent(t, Success, 0, accumUpdates, taskInfo, new TaskMetrics))
} catch {
case e: Throwable =>
dagSchedulerEventProcessLoop.post(new CompletionEvent(t, new ExceptionFailure(e,
Option.empty[TaskMetrics]), 1, accumUpdates, taskInfo, new TaskMetrics))
}
})
}
}

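// Wire the direct-execution task scheduler and a fresh DAGScheduler into the SparkContext.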
dagScheduler = new DAGScheduler(sc, taskScheduler)
taskScheduler.setDAGScheduler(dagScheduler)
sc.dagScheduler = dagScheduler
dagSchedulerEventProcessLoop = new DAGSchedulerSingleThreadedProcessLoop(dagScheduler)
}

/**
* Function that constructs a SparkHadoopWriter with a mock committer and runs its commit
*/
private class OutputCommittingFunctionWithAccumulator(var accum: Accumulator[Int])
extends ((TaskContext, Iterator[Int]) => Int) with Serializable {

def apply(ctxt: TaskContext, it: Iterator[Int]): Int = {
val outputCommitter = new FakeOutputCommitter {
override def commitTask(taskAttemptContext: TaskAttemptContext) {
super.commitTask(taskAttemptContext)
}
}
runCommitWithProvidedCommitter(ctxt, it, outputCommitter)
}

protected def runCommitWithProvidedCommitter(
ctxt: TaskContext,
it: Iterator[Int],
outputCommitter: OutputCommitter): Int = {
def jobConf = new JobConf {
override def getOutputCommitter(): OutputCommitter = outputCommitter
}
val sparkHadoopWriter = new SparkHadoopWriter(jobConf) {
override def newTaskAttemptContext(
conf: JobConf,
attemptId: TaskAttemptID): TaskAttemptContext = {
mock(classOf[TaskAttemptContext])
}
}
sparkHadoopWriter.setup(ctxt.stageId, ctxt.partitionId, ctxt.attemptNumber)
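// Attempt the commit; with output commit coordination, only one attempt per partition
// is expected to actually be allowed to commit.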
sparkHadoopWriter.commit
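// FakeOutputCommitter.ran records whether commitTask actually executed; return 1 for a
// successful commit (counted via the accumulator), 0 otherwise.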
if (FakeOutputCommitter.ran) {
FakeOutputCommitter.ran = false
1
} else {
0
}
}

@throws(classOf[IOException])
private def writeObject(out: ObjectOutputStream) {
out.writeObject(accum)
}

@throws(classOf[IOException])
private def readObject(in: ObjectInputStream) {
accum = in.readObject.asInstanceOf[Accumulator[Int]]
}
}

/**
* Function that will explicitly fail to commit on the first attempt
*/
private class FailFirstTimeCommittingFunctionWithAccumulator(accum: Accumulator[Int])
extends OutputCommittingFunctionWithAccumulator(accum) {
override def apply(ctxt: TaskContext, it: Iterator[Int]): Int = {
if (ctxt.attemptNumber == 0) {
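// First attempt: commit with a committer whose commitTask always throws,
// simulating a failed commit.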
val outputCommitter = new FakeOutputCommitter {
override def commitTask(taskAttemptContext: TaskAttemptContext) {
throw new RuntimeException
}
}
runCommitWithProvidedCommitter(ctxt, it, outputCommitter)
} else {
super.apply(ctxt, it)
}
}
}

test("Only one of two duplicate commit tasks should commit") {
val rdd = sc.parallelize(1 to 10, 10)
sc.runJob(rdd, new OutputCommittingFunctionWithAccumulator(accum))
assert(accum.value === 10)
}

test("If commit fails, if task is retried it should not be locked, and will succeed.") {
val rdd = sc.parallelize(Seq(1), 1)
sc.runJob(rdd, new FailFirstTimeCommittingFunctionWithAccumulator(accum))
assert(accum.value == 1)
}
}
