[SPARK-1745] Move interrupted flag from TaskContext constructor (minor) #675

Closed · wants to merge 4 commits
20 changes: 11 additions & 9 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
```diff
@@ -28,21 +28,23 @@ import org.apache.spark.executor.TaskMetrics
  */
 @DeveloperApi
 class TaskContext(
-  val stageId: Int,
-  val partitionId: Int,
-  val attemptId: Long,
-  val runningLocally: Boolean = false,
-  @volatile var interrupted: Boolean = false,
-  private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty
-) extends Serializable {
+    val stageId: Int,
+    val partitionId: Int,
+    val attemptId: Long,
+    val runningLocally: Boolean = false,
+    private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
+  extends Serializable {
 
   @deprecated("use partitionId", "0.8.1")
   def splitId = partitionId
 
   // List of callback functions to execute when the task completes.
   @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
 
-  // Set to true when the task is completed, before the onCompleteCallbacks are executed.
+  // Whether the corresponding task has been killed.
+  @volatile var interrupted: Boolean = false
+
+  // Whether the task has completed, before the onCompleteCallbacks are executed.
   @volatile var completed: Boolean = false
 
   /**
@@ -58,6 +60,6 @@ class TaskContext(
   def executeOnCompleteCallbacks() {
     completed = true
     // Process complete callbacks in the reverse order of registration
-    onCompleteCallbacks.reverse.foreach{_()}
+    onCompleteCallbacks.reverse.foreach { _() }
   }
 }
```
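
The substance of the hunk above: `interrupted` leaves the constructor and becomes a plain `@volatile var` defaulting to `false`, so constructing a context never mentions interruption and only the kill path writes the flag. A minimal usage sketch against the refactored class, assuming this PR's spark-core is on the classpath (the kill-side and poll-side code is illustrative, not quoted from Spark):

```scala
import org.apache.spark.TaskContext

object InterruptSketch extends App {
  // Construction no longer mentions interruption; the flag simply defaults to false.
  val context = new TaskContext(stageId = 0, partitionId = 0, attemptId = 0L)

  // Killing a task is now a plain field write, visible across threads...
  context.interrupted = true

  // ...which long-running task code can poll cooperatively.
  if (context.interrupted) {
    println("task was killed; bailing out")
  }
}
```

Keeping the field `@volatile` preserves the old guarantee: a write from the thread that kills the task is immediately visible to the task thread polling the flag.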
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
```diff
@@ -23,7 +23,6 @@ import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.mutable.HashMap
-import scala.util.Try
 
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
@@ -70,7 +69,7 @@ private[spark] object ShuffleMapTask {
   }
 
   // Since both the JarSet and FileSet have the same format this is used for both.
-  def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = {
+  def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
     val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
     val objIn = new ObjectInputStream(in)
     val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
```
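
For context, `deserializeFileSet` undoes a GZIP-compressed Java serialization of `(path, timestamp)` pairs; as the comment notes, the same format serves both the JarSet and the FileSet. A self-contained round-trip sketch, where `serializeFileSet` and the `FileSetRoundTrip` object are hypothetical counterparts written here for illustration (the real writer in Spark is not shown in this diff):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.HashMap

object FileSetRoundTrip {

  // Illustrative writer: GZIP-compress a Java-serialized Array[(String, Long)].
  def serializeFileSet(set: HashMap[String, Long]): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    objOut.writeObject(set.toArray)
    objOut.close() // also flushes the GZIP trailer
    out.toByteArray
  }

  // Mirrors the method in the diff: decompress, deserialize, rebuild the map.
  def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
    val objIn = new ObjectInputStream(in)
    val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
    objIn.close()
    HashMap(set.toSeq: _*)
  }

  def main(args: Array[String]): Unit = {
    val files = HashMap("a.jar" -> 1L, "b.txt" -> 2L)
    assert(deserializeFileSet(serializeFileSet(files)) == files)
  }
}
```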
2 changes: 1 addition & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
```diff
@@ -597,7 +597,7 @@ public void persist() {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0, false, false, new TaskMetrics());
+    TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics());
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
```
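
Worth noting: Scala default arguments are invisible at Java call sites, so this Java test must pass every constructor argument positionally. Dropping `interrupted` shrinks the arity from six to five, which is why even this minor refactor is a source-incompatible change for any Java code constructing `TaskContext` directly.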
10 changes: 3 additions & 7 deletions core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
```diff
@@ -23,7 +23,6 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
 import org.scalatest.mock.EasyMockSugar
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.storage._
 
 // TODO: Test the CacheManager's thread-safety aspects
@@ -59,8 +58,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
   }
 
   whenExecuting(blockManager) {
-    val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-      taskMetrics = TaskMetrics.empty)
+    val context = new TaskContext(0, 0, 0)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(1, 2, 3, 4))
   }
@@ -72,8 +70,7 @@
   }
 
   whenExecuting(blockManager) {
-    val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-      taskMetrics = TaskMetrics.empty)
+    val context = new TaskContext(0, 0, 0)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(5, 6, 7))
   }
@@ -86,8 +83,7 @@
   }
 
   whenExecuting(blockManager) {
-    val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
-      taskMetrics = TaskMetrics.empty)
+    val context = new TaskContext(0, 0, 0, runningLocally = true)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(1, 2, 3, 4))
   }
```
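
These test-side shrinkages are Scala default and named parameters at work: with `runningLocally` and `taskMetrics` defaulted and `interrupted` gone from the signature entirely, each test spells out only what it cares about (the same one-line simplification shows up in PipedRDDSuite below). A self-contained illustration, where `Ctx` is a hypothetical stand-in mirroring the refactored constructor:

```scala
object TaskContextDefaultsDemo extends App {
  // Hypothetical stand-in for the refactored TaskContext.
  class Ctx(
      val stageId: Int,
      val partitionId: Int,
      val attemptId: Long,
      val runningLocally: Boolean = false) {
    // Matches the field the PR moves out of the constructor.
    @volatile var interrupted: Boolean = false
  }

  val plain = new Ctx(0, 0, 0)                         // all defaults apply
  val local = new Ctx(0, 0, 0, runningLocally = true)  // override one by name

  assert(!plain.runningLocally && !plain.interrupted)
  assert(local.runningLocally)
}
```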
4 changes: 1 addition & 3 deletions core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
```diff
@@ -178,14 +178,12 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
       }
       val hadoopPart1 = generateFakeHadoopPartition()
       val pipedRdd = new PipedRDD(nums, "printenv " + varName)
-      val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
-        taskMetrics = TaskMetrics.empty)
+      val tContext = new TaskContext(0, 0, 0)
       val rddIter = pipedRdd.compute(hadoopPart1, tContext)
       val arr = rddIter.toArray
       assert(arr(0) == "/some/path")
     } else {
       // printenv isn't available so just pass the test
-      assert(true)
     }
   }
 
```