address comments
cloud-fan committed Apr 26, 2017
1 parent 0616716 commit 4fcee4f
Showing 6 changed files with 14 additions and 15 deletions.

@@ -306,7 +306,7 @@ private[spark] object TaskMetrics extends Logging {
   def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
     val tm = new TaskMetrics
     for (acc <- accums) {
-      val name = AccumulatorContext.get(acc.id).flatMap(_.name)
+      val name = acc.name
       if (name.isDefined && tm.nameToAccums.contains(name.get)) {
         val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]]
         tmAcc.metadata = acc.metadata
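
Note on this hunk: `acc.name` now performs the driver-side `AccumulatorContext` lookup itself (see the `AccumulatorV2.scala` hunk below), so the updates passed in must come from accumulators registered on the driver. A minimal round-trip sketch under that assumption (the run-time value 42 is arbitrary, and these are Spark-internal `private[spark]` APIs, so read it as code inside the Spark code base):

    // Rebuilding TaskMetrics on the driver from a task's accumulator updates.
    // TaskMetrics.registered, unlike TaskMetrics.empty, registers every internal
    // accumulator with AccumulatorContext, which the new acc.name lookup relies on.
    val original = TaskMetrics.registered
    original.setExecutorRunTime(42L)

    val updates: Seq[AccumulatorV2[_, _]] = original.accumulators()
    val rebuilt = TaskMetrics.fromAccumulators(updates)
    assert(rebuilt.executorRunTime == 42L)
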
@@ -27,7 +27,7 @@
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.{AccumulatorContext, LongAccumulator, ThreadUtils, Utils}
+import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}

 /**
  * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
@@ -100,8 +100,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
             // We need to do this here on the driver because if we did this on the executors then
             // we would have to serialize the result again after updating the size.
             result.accumUpdates = result.accumUpdates.map { a =>
-              val accName = AccumulatorContext.get(a.id).flatMap(_.name)
-              if (accName == Some(InternalAccumulator.RESULT_SIZE)) {
+              if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
                 val acc = a.asInstanceOf[LongAccumulator]
                 assert(acc.sum == 0L, "task result size should not have been set on the executors")
                 acc.setValue(size.toLong)
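
A small style aside on the simplified check (my note, not part of the commit): comparing an `Option` against `Some(...)` works, but `Option.contains` expresses the same test without constructing a wrapper:

    // Equivalent forms; both are true only when the update is the result-size metric.
    a.name == Some(InternalAccumulator.RESULT_SIZE)
    a.name.contains(InternalAccumulator.RESULT_SIZE)
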
core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala (8 changes: 6 additions & 2 deletions)
@@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
    * Returns the name of this accumulator, can only be called after registration.
    */
   final def name: Option[String] = {
-    assertMetadataNotNull()
-    metadata.name
+    if (atDriverSide) {
+      AccumulatorContext.get(id).flatMap(_.metadata.name)
+    } else {
+      assertMetadataNotNull()
+      metadata.name
+    }
   }

 /**
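
This hunk is the heart of the commit: `name` itself now branches on `atDriverSide` and resolves through `AccumulatorContext`, which is what lets the call sites above collapse to plain `acc.name`. A hedged usage sketch (it assumes an active `SparkContext` in `sc`, and Spark-internal access, since `AccumulatorContext` is `private[spark]`; the accumulator name "bytesRead" is just an example):

    import org.apache.spark.util.AccumulatorContext

    val acc = sc.longAccumulator("bytesRead")  // registered on the driver at creation

    // The old explicit lookup and the new accessor now agree on the driver:
    val oldStyle = AccumulatorContext.get(acc.id).flatMap(_.name)
    val newStyle = acc.name
    assert(oldStyle == newStyle && newStyle == Some("bytesRead"))
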
@@ -36,7 +36,7 @@
 import org.scalatest.concurrent.Eventually._
 import org.apache.spark._
 import org.apache.spark.storage.TaskResultBlockId
 import org.apache.spark.TestUtils.JavaSourceFromString
-import org.apache.spark.util.{AccumulatorContext, MutableURLClassLoader, RpcUtils, Utils}
+import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}


 /**
@@ -242,12 +242,8 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
     assert(resultGetter.taskResults.size === 1)
     val resBefore = resultGetter.taskResults.head
     val resAfter = captor.getValue
-    val resSizeBefore = resBefore.accumUpdates.find { acc =>
-      AccumulatorContext.get(acc.id).flatMap(_.name) == Some(RESULT_SIZE)
-    }.map(_.value)
-    val resSizeAfter = resAfter.accumUpdates.find { acc =>
-      AccumulatorContext.get(acc.id).flatMap(_.name) == Some(RESULT_SIZE)
-    }.map(_.value)
+    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
+    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
     assert(resSizeBefore.exists(_ == 0L))
     assert(resSizeAfter.exists(_.toString.toLong > 0L))
   }
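
The same simplification as in `TaskResultGetter`. If one more squeeze were wanted (my sketch, not the commit's), `find` plus `map` folds into `collectFirst`:

    // Equivalent to .find(_.name == Some(RESULT_SIZE)).map(_.value)
    val resSizeBefore = resBefore.accumUpdates
      .collectFirst { case acc if acc.name == Some(RESULT_SIZE) => acc.value }
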
@@ -293,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val execId = "exe-1"

     def makeTaskMetrics(base: Int): TaskMetrics = {
-      val taskMetrics = TaskMetrics.empty
+      val taskMetrics = TaskMetrics.registered
       val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
       val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics
       val inputMetrics = taskMetrics.inputMetrics
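
Why the switch: `TaskMetrics.empty` creates its internal accumulators without registering them, so under the new driver-side lookup their `name` resolves to `None` and code under test that matches metrics by name no longer sees them; `TaskMetrics.registered` registers each accumulator with `AccumulatorContext` first. A sketch of the difference (Spark-internal APIs, illustrative only):

    import org.apache.spark.executor.TaskMetrics
    import org.apache.spark.util.AccumulatorContext

    // Only the registered variant is visible to driver-side lookups by id.
    val tm = TaskMetrics.registered
    assert(tm.internalAccums.forall(a => AccumulatorContext.get(a.id).isDefined))
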
@@ -830,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
       hasHadoopInput: Boolean,
       hasOutput: Boolean,
       hasRecords: Boolean = true) = {
-    val t = TaskMetrics.empty
+    val t = TaskMetrics.registered
     // Set CPU times same as wall times for testing purpose
     t.setExecutorDeserializeTime(a)
     t.setExecutorDeserializeCpuTime(a)
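
The same `empty` → `registered` substitution as in `JobProgressListenerSuite` above, presumably for the same reason: this suite builds `TaskMetrics` on the driver, and name resolution now requires the accumulators to be registered with `AccumulatorContext`.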
