From 4fcee4ff3452b9857ef8fbe8b5a5f7035a9bc384 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 25 Apr 2017 22:28:48 +0800
Subject: [PATCH] address comments

---
 .../scala/org/apache/spark/executor/TaskMetrics.scala  |  2 +-
 .../org/apache/spark/scheduler/TaskResultGetter.scala  |  5 ++---
 .../scala/org/apache/spark/util/AccumulatorV2.scala    |  8 ++++++--
 .../apache/spark/scheduler/TaskResultGetterSuite.scala | 10 +++-------
 .../spark/ui/jobs/JobProgressListenerSuite.scala       |  2 +-
 .../org/apache/spark/util/JsonProtocolSuite.scala      |  2 +-
 6 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 54c75f3b7e9b2..45c0a98eb0a1f 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -306,7 +306,7 @@ private[spark] object TaskMetrics extends Logging {
   def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
     val tm = new TaskMetrics
     for (acc <- accums) {
-      val name = AccumulatorContext.get(acc.id).flatMap(_.name)
+      val name = acc.name
       if (name.isDefined && tm.nameToAccums.contains(name.get)) {
         val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]]
         tmAcc.metadata = acc.metadata
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 0e00fc257bbc0..a284f7956cd31 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -27,7 +27,7 @@ import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.{AccumulatorContext, LongAccumulator, ThreadUtils, Utils}
+import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}
 
 /**
  * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
@@ -100,8 +100,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
           // We need to do this here on the driver because if we did this on the executors then
           // we would have to serialize the result again after updating the size.
           result.accumUpdates = result.accumUpdates.map { a =>
-            val accName = AccumulatorContext.get(a.id).flatMap(_.name)
-            if (accName == Some(InternalAccumulator.RESULT_SIZE)) {
+            if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
               val acc = a.asInstanceOf[LongAccumulator]
               assert(acc.sum == 0L, "task result size should not have been set on the executors")
               acc.setValue(size.toLong)
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 31d064a08f04a..a65ec75cc5db6 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
    * Returns the name of this accumulator, can only be called after registration.
    */
   final def name: Option[String] = {
-    assertMetadataNotNull()
-    metadata.name
+    if (atDriverSide) {
+      AccumulatorContext.get(id).flatMap(_.metadata.name)
+    } else {
+      assertMetadataNotNull()
+      metadata.name
+    }
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index f86e84457ba8b..3e55d399e9df9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -36,7 +36,7 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark._
 import org.apache.spark.storage.TaskResultBlockId
 import org.apache.spark.TestUtils.JavaSourceFromString
-import org.apache.spark.util.{AccumulatorContext, MutableURLClassLoader, RpcUtils, Utils}
+import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}
 
 
 /**
@@ -242,12 +242,8 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
     assert(resultGetter.taskResults.size === 1)
     val resBefore = resultGetter.taskResults.head
     val resAfter = captor.getValue
-    val resSizeBefore = resBefore.accumUpdates.find { acc =>
-      AccumulatorContext.get(acc.id).flatMap(_.name) == Some(RESULT_SIZE)
-    }.map(_.value)
-    val resSizeAfter = resAfter.accumUpdates.find { acc =>
-      AccumulatorContext.get(acc.id).flatMap(_.name) == Some(RESULT_SIZE)
-    }.map(_.value)
+    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
+    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
     assert(resSizeBefore.exists(_ == 0L))
     assert(resSizeAfter.exists(_.toString.toLong > 0L))
   }
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 93964a2d56743..48be3be81755a 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -293,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val execId = "exe-1"
 
     def makeTaskMetrics(base: Int): TaskMetrics = {
-      val taskMetrics = TaskMetrics.empty
+      val taskMetrics = TaskMetrics.registered
       val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
       val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics
       val inputMetrics = taskMetrics.inputMetrics
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index a64dbeae47294..a77c8e3cab4e8 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -830,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
       hasHadoopInput: Boolean,
       hasOutput: Boolean,
       hasRecords: Boolean = true) = {
-    val t = TaskMetrics.empty
+    val t = TaskMetrics.registered
     // Set CPU times same as wall times for testing purpose
     t.setExecutorDeserializeTime(a)
     t.setExecutorDeserializeCpuTime(a)
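
A note on the core change, for reviewers: AccumulatorV2.name now resolves
driver-side through AccumulatorContext, so call sites (TaskResultGetter,
TaskMetrics.fromAccumulators, the suites) no longer repeat
AccumulatorContext.get(id).flatMap(_.name) themselves. The sketch below is a
minimal, self-contained model of that lookup pattern; Registry, SketchAcc,
and Demo are invented names for illustration, not Spark APIs, and the real
accumulator metadata is simplified down to a single Option[String].

import java.util.concurrent.ConcurrentHashMap

// Stand-in for AccumulatorContext: a driver-side registry mapping
// accumulator ids to the originally registered instances.
object Registry {
  private val originals = new ConcurrentHashMap[Long, SketchAcc]()
  def register(a: SketchAcc): Unit = originals.put(a.id, a)
  def get(id: Long): Option[SketchAcc] = Option(originals.get(id))
}

// Stand-in for AccumulatorV2, keeping only what the sketch needs.
class SketchAcc(val id: Long, val registeredName: Option[String], val atDriverSide: Boolean) {
  // On the driver, resolve the name through the registry: a copy
  // deserialized from a task result may not carry the name it was
  // registered under. Like the patch (which reads _.metadata.name),
  // we read the original's registeredName field directly rather than
  // calling name on it, which would recurse.
  def name: Option[String] =
    if (atDriverSide) Registry.get(id).flatMap(_.registeredName)
    else registeredName
}

object Demo extends App {
  val original = new SketchAcc(1L, Some("internal.metrics.resultSize"), atDriverSide = true)
  Registry.register(original)
  // An update that came back from an executor without its registered name:
  val update = new SketchAcc(1L, None, atDriverSide = true)
  assert(update.name == Some("internal.metrics.resultSize"))
  println(s"resolved: ${update.name}")
}

The test-side switches from TaskMetrics.empty to TaskMetrics.registered follow
from the same rule: once name consults the driver-side registry, accumulators
used in tests must actually be registered there for their names to resolve.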