From b500538bbb0c7c76eb13d007246f7ed3c02f6d47 Mon Sep 17 00:00:00 2001
From: Sumedh Wale
Date: Mon, 28 Nov 2016 16:21:36 +0530
Subject: [PATCH] [SNAP-1136] Kryo closure serialization support and optimizations (#27)

- added back the configurable closure serializer in Spark that was removed in
  SPARK-12414; some minor changes taken from the closed Spark PR
  https://github.com/apache/spark/pull/6361
- added optimized Kryo serialization for multiple classes; currently the
  registration and string-sharing fix for Kryo
  (https://github.com/EsotericSoftware/kryo/issues/128) is only in the
  SnappyData-layer PooledKryoSerializer implementation; the classes providing
  maximum benefit now implement KryoSerializable, notably Accumulators and
  *Metrics
- use the closure serializer for Netty messaging too instead of a fixed
  JavaSerializer
- updated Kryo to 4.0.0 to get the fix for kryo#342
- actually fix the scalastyle errors introduced by d80ef1b4
- set the ordering field on Kryo deserialization in GenerateOrdering
- removed the warning when a non-closure is passed for cleaning

Conflicts:
	core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
	core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
	core/src/main/scala/org/apache/spark/scheduler/Task.scala
	core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
	core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
	core/src/main/scala/org/apache/spark/storage/BlockId.scala
	core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
	sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
---
 build.gradle                                  |   3 +-
 common/unsafe/build.gradle                    |   5 +-
 core/build.gradle                             |   9 +-
 .../scala/org/apache/spark/SparkEnv.scala     |  58 ++++++---
 .../apache/spark/executor/InputMetrics.scala  |  20 +++-
 .../apache/spark/executor/OutputMetrics.scala |  20 +++-
 .../spark/executor/ShuffleReadMetrics.scala   |  28 ++++-
 .../spark/executor/ShuffleWriteMetrics.scala  |  22 +++-
 .../apache/spark/executor/TaskMetrics.scala   |  39 ++++++-
 .../netty/NettyBlockTransferService.scala     |   7 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala |   5 +-
 .../spark/rdd/ZippedPartitionsRDD.scala       |  30 ++++-
 .../org/apache/spark/rpc/RpcEndpointRef.scala |  10 +-
 .../scala/org/apache/spark/rpc/RpcEnv.scala   |   4 +
 .../apache/spark/rpc/netty/NettyRpcEnv.scala  |  58 +++++++--
 .../apache/spark/scheduler/ResultTask.scala   |  27 ++++-
 .../spark/scheduler/ShuffleMapTask.scala      |  19 ++-
 .../org/apache/spark/scheduler/Task.scala     |  53 +++++++--
 .../spark/scheduler/TaskDescription.scala     |  65 +++++++---
 .../spark/scheduler/TaskSetManager.scala      |  10 +-
 .../cluster/CoarseGrainedClusterMessage.scala |  28 ++++-
 .../org/apache/spark/storage/BlockId.scala    |  19 ++-
 .../spark/storage/BlockManagerMessages.scala  |  55 ++++++++-
 .../org/apache/spark/util/AccumulatorV2.scala | 110 +++++++++++++++++-
 .../apache/spark/util/ClosureCleaner.scala    |   2 +-
 .../spark/util/SerializableBuffer.scala       |  23 +++-
 .../apache/spark/util/collection/BitSet.scala |  32 ++++-
 .../codegen/GenerateOrdering.scala            |   5 +-
 .../sql/execution/metric/SQLMetrics.scala     |  21 +++-
 29 files changed, 666 insertions(+), 121 deletions(-)

diff --git a/build.gradle b/build.gradle
index 2f3f420315a64..01d46b079a504 100644
--- a/build.gradle
+++ b/build.gradle
@@ -61,7 +61,8 @@ allprojects {
     javaxServletVersion = '3.1.0'
guavaVersion = '14.0.1' hiveVersion = '1.2.1.spark2' - chillVersion = '0.8.0' + chillVersion = '0.8.1' + kryoVersion = '4.0.0' nettyVersion = '3.8.0.Final' nettyAllVersion = '4.0.29.Final' derbyVersion = '10.12.1.1' diff --git a/common/unsafe/build.gradle b/common/unsafe/build.gradle index ee2347c9eb872..b14fed1ab31d7 100644 --- a/common/unsafe/build.gradle +++ b/common/unsafe/build.gradle @@ -20,7 +20,10 @@ description = 'Spark Project Unsafe' dependencies { compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) - compile group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion + compile group: 'com.esotericsoftware', name: 'kryo-shaded', version: kryoVersion + compile(group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion) { + exclude(group: 'com.esotericsoftware', module: 'kryo-shaded') + } compile group: 'com.google.code.findbugs', name: 'jsr305', version: jsr305Version compile group: 'com.google.guava', name: 'guava', version: guavaVersion diff --git a/core/build.gradle b/core/build.gradle index 1caee72201e40..ebeff567df64d 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -40,8 +40,13 @@ dependencies { exclude(group: 'org.apache.avro', module: 'avro-ipc') } compile group: 'com.google.guava', name: 'guava', version: guavaVersion - compile group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion - compile group: 'com.twitter', name: 'chill-java', version: chillVersion + compile group: 'com.esotericsoftware', name: 'kryo-shaded', version: kryoVersion + compile(group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion) { + exclude(group: 'com.esotericsoftware', module: 'kryo-shaded') + } + compile(group: 'com.twitter', name: 'chill-java', version: chillVersion) { + exclude(group: 'com.esotericsoftware', module: 'kryo-shaded') + } compile group: 'org.apache.xbean', name: 'xbean-asm5-shaded', version: '4.4' // explicitly include netty from akka-remote to not let zookeeper override it compile group: 'io.netty', name: 'netty', version: nettyVersion diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 2ff379cced1a8..8a9d8a9292b34 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -170,6 +170,43 @@ object SparkEnv extends Logging { env } + // Create an instance of the class with the given name, possibly initializing it with our conf + def instantiateClass[T](className: String, conf: SparkConf, + isDriver: Boolean): T = { + val cls = Utils.classForName(className) + // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just + // SparkConf, then one taking no arguments + try { + cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) + .newInstance(conf, new java.lang.Boolean(isDriver)) + .asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + cls.getConstructor().newInstance().asInstanceOf[T] + } + } + } + + def getClosureSerializer(conf: SparkConf, doLog: Boolean = false): Serializer = { + val defaultClosureSerializerClass = classOf[JavaSerializer].getName + val closureSerializerClass = conf.get("spark.closure.serializer", + defaultClosureSerializerClass) + val closureSerializer = instantiateClass[Serializer]( + closureSerializerClass, conf, isDriver = 
false) + if (doLog) { + if (closureSerializerClass != defaultClosureSerializerClass) { + logInfo(s"Using non-default closure serializer: $closureSerializerClass") + } else { + logDebug(s"Using closure serializer: $closureSerializerClass") + } + } + closureSerializer + } + /** * Create a SparkEnv for the driver. */ @@ -272,26 +309,9 @@ object SparkEnv extends Logging { conf.set("spark.driver.port", rpcEnv.address.port.toString) } - // Create an instance of the class with the given name, possibly initializing it with our conf def instantiateClass[T](className: String): T = { - val cls = Utils.classForName(className) - // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just - // SparkConf, then one taking no arguments - try { - cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) - .newInstance(conf, new java.lang.Boolean(isDriver)) - .asInstanceOf[T] - } catch { - case _: NoSuchMethodException => - try { - cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] - } catch { - case _: NoSuchMethodException => - cls.getConstructor().newInstance().asInstanceOf[T] - } - } + SparkEnv.instantiateClass(className, conf, isDriver) } - // Create an instance of the class named by the given SparkConf property, or defaultClassName // if the property is not set, possibly initializing it with our conf def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = { @@ -304,7 +324,7 @@ object SparkEnv extends Logging { val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey) - val closureSerializer = new JavaSerializer(conf) + val closureSerializer = getClosureSerializer(conf, doLog = true) def registerOrLookupEndpoint( name: String, endpointCreator: => RpcEndpoint): diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index 3d15f3a0396e1..1647b06ce0481 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -17,6 +17,10 @@ package org.apache.spark.executor +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.LongAccumulator @@ -39,7 +43,7 @@ object DataReadMethod extends Enumeration with Serializable { * A collection of accumulators that represents metrics about reading data from external systems. 
*/ @DeveloperApi -class InputMetrics private[spark] () extends Serializable { +class InputMetrics private[spark] () extends Serializable with KryoSerializable { private[executor] val _bytesRead = new LongAccumulator private[executor] val _recordsRead = new LongAccumulator @@ -56,4 +60,18 @@ class InputMetrics private[spark] () extends Serializable { private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v) private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) + + override def write(kryo: Kryo, output: Output): Unit = { + _bytesRead.write(kryo, output) + _recordsRead.write(kryo, output) + } + + override final def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + _bytesRead.read(kryo, input, context) + _recordsRead.read(kryo, input, context) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala index dada9697c1cf9..418a831c7555f 100644 --- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -17,6 +17,10 @@ package org.apache.spark.executor +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.LongAccumulator @@ -38,7 +42,7 @@ object DataWriteMethod extends Enumeration with Serializable { * A collection of accumulators that represents metrics about writing data to external systems. */ @DeveloperApi -class OutputMetrics private[spark] () extends Serializable { +class OutputMetrics private[spark] () extends Serializable with KryoSerializable { private[executor] val _bytesWritten = new LongAccumulator private[executor] val _recordsWritten = new LongAccumulator @@ -54,4 +58,18 @@ class OutputMetrics private[spark] () extends Serializable { private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v) private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v) + + override def write(kryo: Kryo, output: Output): Unit = { + _bytesWritten.write(kryo, output) + _recordsWritten.write(kryo, output) + } + + override final def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + _bytesWritten.read(kryo, input, context) + _recordsWritten.read(kryo, input, context) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 5a435f2f9a1ba..6b49c8cbeb629 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -17,6 +17,10 @@ package org.apache.spark.executor +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.{DoubleAccumulator, LongAccumulator} @@ -27,7 +31,7 @@ import org.apache.spark.util.{DoubleAccumulator, LongAccumulator} * Operations are not thread-safe. 
*/ @DeveloperApi -class ShuffleReadMetrics private[spark] () extends Serializable { +class ShuffleReadMetrics private[spark] () extends Serializable with KryoSerializable { private[executor] val _remoteBlocksFetched = new LongAccumulator private[executor] val _localBlocksFetched = new LongAccumulator private[executor] val _remoteBytesRead = new LongAccumulator @@ -121,6 +125,28 @@ class ShuffleReadMetrics private[spark] () extends Serializable { _recordsRead.add(metric.recordsRead) } } + + override def write(kryo: Kryo, output: Output): Unit = { + _remoteBlocksFetched.write(kryo, output) + _localBlocksFetched.write(kryo, output) + _remoteBytesRead.write(kryo, output) + _localBytesRead.write(kryo, output) + _fetchWaitTime.write(kryo, output) + _recordsRead.write(kryo, output) + } + + override final def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + _remoteBlocksFetched.read(kryo, input, context) + _localBlocksFetched.read(kryo, input, context) + _remoteBytesRead.read(kryo, input, context) + _localBytesRead.read(kryo, input, context) + _fetchWaitTime.read(kryo, input, context) + _recordsRead.read(kryo, input, context) + } } /** diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index ada2e1bc08593..f6aaf90d93b9c 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -17,6 +17,10 @@ package org.apache.spark.executor +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.LongAccumulator @@ -27,7 +31,7 @@ import org.apache.spark.util.LongAccumulator * Operations are not thread-safe. */ @DeveloperApi -class ShuffleWriteMetrics private[spark] () extends Serializable { +class ShuffleWriteMetrics private[spark] () extends Serializable with KryoSerializable { private[executor] val _bytesWritten = new LongAccumulator private[executor] val _recordsWritten = new LongAccumulator private[executor] val _writeTime = new LongAccumulator @@ -57,6 +61,22 @@ class ShuffleWriteMetrics private[spark] () extends Serializable { _recordsWritten.setValue(recordsWritten - v) } + override def write(kryo: Kryo, output: Output): Unit = { + _bytesWritten.write(kryo, output) + _recordsWritten.write(kryo, output) + _writeTime.write(kryo, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + _bytesWritten.read(kryo, input, context) + _recordsWritten.read(kryo, input, context) + _writeTime.read(kryo, input, context) + } + // Legacy methods for backward compatibility. // TODO: remove these once we make this class private. 
@deprecated("use bytesWritten instead", "2.0.0") diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dc6dc6878f567..3facda3df95bc 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -20,6 +20,9 @@ package org.apache.spark.executor import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging @@ -42,7 +45,7 @@ import org.apache.spark.util._ * be sent to the driver. */ @DeveloperApi -class TaskMetrics private[spark] () extends Serializable { +class TaskMetrics private[spark] () extends Serializable with KryoSerializable { // Each metric is internally represented as an accumulator private val _executorDeserializeTime = new DoubleAccumulator private val _executorDeserializeCpuTime = new LongAccumulator @@ -263,6 +266,40 @@ class TaskMetrics private[spark] () extends Serializable { // value will be updated at driver side. internalAccums.filter(a => !a.isZero || a == _resultSize) } + + override def write(kryo: Kryo, output: Output): Unit = { + _executorDeserializeTime.write(kryo, output) + _executorRunTime.write(kryo, output) + _resultSize.write(kryo, output) + _jvmGCTime.write(kryo, output) + _resultSerializationTime.write(kryo, output) + _memoryBytesSpilled.write(kryo, output) + _diskBytesSpilled.write(kryo, output) + _peakExecutionMemory.write(kryo, output) + _updatedBlockStatuses.write(kryo, output) + inputMetrics.write(kryo, output) + outputMetrics.write(kryo, output) + shuffleReadMetrics.write(kryo, output) + shuffleWriteMetrics.write(kryo, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + // read the TaskContext thread-local once + val taskContext = TaskContext.get() + _executorDeserializeTime.read(kryo, input, taskContext) + _executorRunTime.read(kryo, input, taskContext) + _resultSize.read(kryo, input, taskContext) + _jvmGCTime.read(kryo, input, taskContext) + _resultSerializationTime.read(kryo, input, taskContext) + _memoryBytesSpilled.read(kryo, input, taskContext) + _diskBytesSpilled.read(kryo, input, taskContext) + _peakExecutionMemory.read(kryo, input, taskContext) + _updatedBlockStatuses.read(kryo, input, taskContext) + inputMetrics.read(kryo, input, taskContext) + outputMetrics.read(kryo, input, taskContext) + shuffleReadMetrics.read(kryo, input, taskContext) + shuffleWriteMetrics.read(kryo, input, taskContext) + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b7d8c35032763..e718336af82ca 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -24,9 +24,7 @@ import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import com.codahale.metrics.{Metric, MetricSet} - -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import 
org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} @@ -35,7 +33,6 @@ import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils -import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -52,7 +49,7 @@ private[spark] class NettyBlockTransferService( extends BlockTransferService { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. - private val serializer = new JavaSerializer(conf) + private val serializer = SparkEnv.getClosureSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0574abdca32ac..74871f836191e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -147,7 +147,8 @@ abstract class RDD[T: ClassTag]( def sparkContext: SparkContext = sc /** A unique ID for this RDD (within its SparkContext). */ - val id: Int = sc.newRddId() + protected var _id: Int = sc.newRddId() + def id: Int = _id /** A friendly name for this RDD */ @transient var name: String = _ @@ -1651,7 +1652,7 @@ abstract class RDD[T: ClassTag]( // Other internal methods and fields // ======================================================================= - private var storageLevel: StorageLevel = StorageLevel.NONE + protected var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). 
*/ @transient private[spark] val creationSite = sc.getCallSite() diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 3cb1231bd3477..7d4e5595fe860 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -21,16 +21,19 @@ import java.io.{IOException, ObjectOutputStream} import scala.reflect.ClassTag +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext} import org.apache.spark.util.Utils private[spark] class ZippedPartitionsPartition( - idx: Int, + private var idx: Int, @transient private val rdds: Seq[RDD[_]], @transient val preferredLocations: Seq[String]) - extends Partition { + extends Partition with KryoSerializable { - override val index: Int = idx + override def index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) def partitions: Seq[Partition] = partitionValues @@ -40,6 +43,27 @@ private[spark] class ZippedPartitionsPartition( partitionValues = rdds.map(rdd => rdd.partitions(idx)) oos.defaultWriteObject() } + + override def write(kryo: Kryo, output: Output): Unit = { + // Update the reference to parent split at the time of task serialization + partitionValues = rdds.map(rdd => rdd.partitions(idx)) + output.writeVarInt(idx, true) + output.writeVarInt(partitionValues.length, true) + for (p <- partitionValues) { + kryo.writeClassAndObject(output, p) + } + } + + override def read(kryo: Kryo, input: Input): Unit = { + idx = input.readVarInt(true) + var numPartitions = input.readVarInt(true) + val partitionBuilder = Seq.newBuilder[Partition] + while (numPartitions > 0) { + partitionBuilder += kryo.readClassAndObject(input).asInstanceOf[Partition] + numPartitions -= 1 + } + partitionValues = partitionBuilder.result() + } } private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 4d39f144dd198..4788230bbd000 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -27,12 +27,12 @@ import org.apache.spark.util.RpcUtils /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. 
*/ -private[spark] abstract class RpcEndpointRef(conf: SparkConf) - extends Serializable with Logging { +private[spark] abstract class RpcEndpointRef(conf: SparkConf, + _env: RpcEnv) extends Serializable with Logging { - private[this] val maxRetries = RpcUtils.numRetries(conf) - private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) + @transient protected var maxRetries = _env.maxRetries + @transient protected var retryWaitMs = _env.retryWaitMs + @transient protected var defaultAskTimeout = _env.defaultAskTimeout /** * return the address for the [[RpcEndpointRef]] diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index de2cc56bc6b16..59fce6a4e731b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -72,6 +72,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf) + private[spark] val maxRetries = RpcUtils.numRetries(conf) + private[spark] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[spark] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) + /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index a2936d6ad539c..79639f2f38745 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -29,7 +29,10 @@ import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal -import org.apache.spark.{SecurityManager, SparkConf} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -37,12 +40,12 @@ import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.server._ import org.apache.spark.rpc._ -import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream} -import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils} +import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, Serializer, SerializerInstance, SerializationStream} +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, RpcUtils, ThreadUtils, Utils} private[netty] class NettyRpcEnv( val conf: SparkConf, - javaSerializerInstance: JavaSerializerInstance, + serializer: Serializer, host: String, securityManager: SecurityManager, numUsableCores: Int) extends RpcEnv(conf) with Logging { @@ -52,6 +55,10 @@ private[netty] class NettyRpcEnv( "rpc", conf.getInt("spark.rpc.io.threads", 0)) + private val serializerInstance = new ThreadLocal[SerializerInstance] { + override def initialValue(): SerializerInstance = serializer.newInstance() + } + private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores) private val streamManager = new 
NettyStreamManager(this) @@ -255,7 +262,7 @@ private[netty] class NettyRpcEnv( } private[netty] def serialize(content: Any): ByteBuffer = { - javaSerializerInstance.serialize(content) + serializerInstance.get().serialize(content) } /** @@ -268,7 +275,7 @@ private[netty] class NettyRpcEnv( private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => - javaSerializerInstance.deserialize[T](bytes) + serializerInstance.get().deserialize[T](bytes) } } } @@ -453,10 +460,7 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { def create(config: RpcEnvConfig): RpcEnv = { val sparkConf = config.conf - // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support - // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance - val javaSerializerInstance = - new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] + val serializer = SparkEnv.getClosureSerializer(sparkConf) val nettyEnv = new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress, config.securityManager, config.numUsableCores) @@ -500,7 +504,8 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, private val endpointAddress: RpcEndpointAddress, - @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) { + @transient @volatile private var nettyEnv: NettyRpcEnv) + extends RpcEndpointRef(conf, nettyEnv) with Serializable with KryoSerializable with Logging { @transient @volatile var client: TransportClient = _ @@ -511,6 +516,10 @@ private[netty] class NettyRpcEndpointRef( in.defaultReadObject() nettyEnv = NettyRpcEnv.currentEnv.value client = NettyRpcEnv.currentClient.value + + maxRetries = nettyEnv.maxRetries + retryWaitMs = nettyEnv.retryWaitMs + defaultAskTimeout = nettyEnv.defaultAskTimeout } private def writeObject(out: ObjectOutputStream): Unit = { @@ -519,6 +528,33 @@ private[netty] class NettyRpcEndpointRef( override def name: String = endpointAddress.name + override def write(kryo: Kryo, output: Output): Unit = { + val addr = address + output.writeString(_name) + if (addr != null && addr.host != null) { + output.writeString(addr.host) + output.writeInt(addr.port) + } else { + output.writeString(null) + } + } + + override def read(kryo: Kryo, input: Input): Unit = { + _name = input.readString() + _address = null + val host = input.readString() + if (host != null) { + val port = input.readInt() + _address = RpcEndpointAddress(host, port, _name) + } + nettyEnv = NettyRpcEnv.currentEnv.value + client = NettyRpcEnv.currentClient.value + + maxRetries = nettyEnv.maxRetries + retryWaitMs = nettyEnv.retryWaitMs + defaultAskTimeout = nettyEnv.defaultAskTimeout + } + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 2acac37dcff94..86f188dbb6022 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -22,6 +22,9 @@ import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.Properties +import com.esotericsoftware.kryo.Kryo 
+import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -38,7 +41,7 @@ import org.apache.spark.rdd.RDD * (RDD[T], (TaskContext, Iterator[T]) => U). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling - * @param outputId index of the task in this job (a job can launch tasks on only a subset of the + * @param _outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side @@ -52,10 +55,10 @@ import org.apache.spark.rdd.RDD private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, - taskBinary: Broadcast[Array[Byte]], - partition: Partition, + private var taskBinary: Broadcast[Array[Byte]], + private var partition: Partition, locs: Seq[TaskLocation], - val outputId: Int, + private var _outputId: Int, localProperties: Properties, serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, @@ -65,6 +68,8 @@ private[spark] class ResultTask[T, U]( jobId, appId, appAttemptId) with Serializable { + final def outputId: Int = _outputId + @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } @@ -91,4 +96,18 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString: String = "ResultTask(" + stageId + ", " + partitionId + ")" + + override def write(kryo: Kryo, output: Output): Unit = { + super.write(kryo, output) + kryo.writeClassAndObject(output, taskBinary) + kryo.writeClassAndObject(output, partition) + output.writeInt(_outputId) + } + + override def read(kryo: Kryo, input: Input): Unit = { + super.read(kryo, input) + taskBinary = kryo.readClassAndObject(input).asInstanceOf[Broadcast[Array[Byte]]] + partition = kryo.readClassAndObject(input).asInstanceOf[Partition] + _outputId = input.readInt() + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index ef2b3e8764b71..c5e757024ed45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -23,6 +23,9 @@ import java.util.Properties import scala.language.existentials +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging @@ -53,8 +56,8 @@ import org.apache.spark.shuffle.ShuffleWriter private[spark] class ShuffleMapTask( stageId: Int, stageAttemptId: Int, - taskBinary: Broadcast[Array[Byte]], - partition: Partition, + private var taskBinary: Broadcast[Array[Byte]], + private var partition: Partition, @transient private var locs: Seq[TaskLocation], localProperties: Properties, serializedTaskMetrics: Array[Byte], @@ -112,4 +115,16 @@ private[spark] class ShuffleMapTask( override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) + + override def write(kryo: Kryo, output: Output): Unit = { + super.write(kryo, output) + 
kryo.writeClassAndObject(output, taskBinary) + kryo.writeClassAndObject(output, partition) + } + + override def read(kryo: Kryo, input: Input): Unit = { + super.read(kryo, input) + taskBinary = kryo.readClassAndObject(input).asInstanceOf[Broadcast[Array[Byte]]] + partition = kryo.readClassAndObject(input).asInstanceOf[Partition] + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 2a9b0628c1131..867bca3b84de2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -20,6 +20,12 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.Properties +import scala.collection.mutable +import scala.collection.mutable.HashMap + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config.APP_CALLER_CONTEXT @@ -38,9 +44,10 @@ import org.apache.spark.util._ * and sends the task output back to the driver application. A ShuffleMapTask executes the task * and divides the task output to multiple buckets (based on the task's partitioner). * - * @param stageId id of the stage this task belongs to - * @param stageAttemptId attempt id of the stage this task belongs to - * @param partitionId index of the number in the RDD + * @param _stageId id of the stage this task belongs to + * @param _stageAttemptId attempt id of the stage this task belongs to + * @param _partitionId index of the number in the RDD + * @param _metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side * and sent to executor side. @@ -51,16 +58,27 @@ import org.apache.spark.util._ * @param appAttemptId attempt id of the app this task belongs to */ private[spark] abstract class Task[T]( - val stageId: Int, - val stageAttemptId: Int, - val partitionId: Int, + private var _stageId: Int, + private var _stageAttemptId: Int, + private var _partitionId: Int, + // The default value is only used in tests. + private var _metrics: TaskMetrics = TaskMetrics.registered, @transient var localProperties: Properties = new Properties, // The default value is only used in tests. 
serializedTaskMetrics: Array[Byte] = SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), val jobId: Option[Int] = None, val appId: Option[String] = None, - val appAttemptId: Option[String] = None) extends Serializable { + val appAttemptId: Option[String] = None) extends Serializable + with KryoSerializable { + + final def stageId: Int = _stageId + + final def stageAttemptId: Int = _stageAttemptId + + final def partitionId: Int = _partitionId + + final def metrics: TaskMetrics = _metrics @transient lazy val metrics: TaskMetrics = SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics)) @@ -147,7 +165,7 @@ private[spark] abstract class Task[T]( } } - private var taskMemoryManager: TaskMemoryManager = _ + @transient private var taskMemoryManager: TaskMemoryManager = _ def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = { this.taskMemoryManager = taskMemoryManager @@ -216,4 +234,23 @@ private[spark] abstract class Task[T]( taskThread.interrupt() } } + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(_stageId) + output.writeVarInt(_stageAttemptId, true) + output.writeVarInt(_partitionId, true) + output.writeLong(epoch) + output.writeLong(_executorDeserializeTime) + _metrics.write(kryo, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + _stageId = input.readInt() + _stageAttemptId = input.readVarInt(true) + _partitionId = input.readVarInt(true) + epoch = input.readLong() + _executorDeserializeTime = input.readLong() + _metrics = new TaskMetrics + _metrics.read(kryo, input) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index c98b87148e404..583898e12145f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -25,7 +25,9 @@ import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, Map} -import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, SerializableBuffer, Utils} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} /** * Description of a task that gets passed onto executors to be executed, usually created by @@ -45,31 +47,46 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti * (which can introduce significant overhead when the maps are small). 
*/ private[spark] class TaskDescription( - val taskId: Long, - val attemptNumber: Int, - val executorId: String, - val name: String, - val index: Int, // Index within this task's TaskSet - val addedFiles: Map[String, Long], - val addedJars: Map[String, Long], - val properties: Properties, - val serializedTask: ByteBuffer) { - - override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) -} - -private[spark] object TaskDescription { - private def serializeStringLongMap(map: Map[String, Long], dataOut: DataOutputStream): Unit = { - dataOut.writeInt(map.size) - for ((key, value) <- map) { - dataOut.writeUTF(key) - dataOut.writeLong(value) - } + private var _taskId: Long, + private var _attemptNumber: Int, + private var _executorId: String, + private var _name: String, + private var _index: Int, // Index within this task's TaskSet + @transient private var _serializedTask: ByteBuffer) + extends Serializable with KryoSerializable { + + def taskId: Long = _taskId + def attemptNumber: Int = _attemptNumber + def executorId: String = _executorId + def name: String = _name + def index: Int = _index + + // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer + private val buffer = + if (_serializedTask ne null) new SerializableBuffer(_serializedTask) else null + + def serializedTask: ByteBuffer = + if (_serializedTask ne null) _serializedTask else buffer.value + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeLong(_taskId) + output.writeVarInt(_attemptNumber, true) + output.writeString(_executorId) + output.writeString(_name) + output.writeInt(_index) + output.writeInt(_serializedTask.remaining()) + Utils.writeByteBuffer(_serializedTask, output) } - def encode(taskDescription: TaskDescription): ByteBuffer = { - val bytesOut = new ByteBufferOutputStream(4096) - val dataOut = new DataOutputStream(bytesOut) + override def read(kryo: Kryo, input: Input): Unit = { + _taskId = input.readLong() + _attemptNumber = input.readVarInt(true) + _executorId = input.readString() + _name = input.readString() + _index = input.readInt() + val len = input.readInt() + _serializedTask = ByteBuffer.wrap(input.readBytes(len)) + } dataOut.writeLong(taskDescription.taskId) dataOut.writeInt(taskDescription.attemptNumber) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3da2116f10d97..3519669e96fe6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -427,7 +427,7 @@ private[spark] class TaskSetManager( } if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) { - for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host)) + for (index <- dequeueTaskFromList(execId, host, getPendingTasksForHost(host)) // don't return executor-local tasks that are still alive if canRunOnExecutor(execId, index)) { return Some((index, TaskLocality.NODE_LOCAL, false)) @@ -444,7 +444,7 @@ private[spark] class TaskSetManager( if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) { for { rack <- sched.getRackForHost(host) - index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack)) + index <- dequeueTaskFromList(execId, host, getPendingTasksForRack(rack)) // don't return executor-local tasks that are still alive if canRunOnExecutor(execId, index) } { @@ -453,7 +453,7 @@ private[spark] class TaskSetManager( } if 
(TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) { - for (index <- dequeueTaskFromList(execId, allPendingTasks) + for (index <- dequeueTaskFromList(execId, host, allPendingTasks) // don't return executor-local tasks that are still alive if canRunOnExecutor(execId, index)) { return Some((index, TaskLocality.ANY, false)) @@ -547,8 +547,8 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskStarted(task, info) new TaskDescription( - taskId, - attemptNum, + _taskId = taskId, + _attemptNumber = attemptNum, execId, taskName, index, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index e8b7fc0ef100a..fc5115c57fb03 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -19,10 +19,13 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.ExecutorLossReason -import org.apache.spark.util.SerializableBuffer +import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -66,8 +69,27 @@ private[spark] object CoarseGrainedClusterMessages { logUrls: Map[String, String]) extends CoarseGrainedClusterMessage - case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends CoarseGrainedClusterMessage + case class StatusUpdate(var executorId: String, var taskId: Long, + var state: TaskState, var data: SerializableBuffer) + extends CoarseGrainedClusterMessage with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeString(executorId) + output.writeLong(taskId) + output.writeVarInt(state.id, true) + val buffer = data.buffer + output.writeInt(buffer.remaining()) + Utils.writeByteBuffer(buffer, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + executorId = input.readString() + taskId = input.readLong() + state = org.apache.spark.TaskState(input.readVarInt(true)) + val len = input.readInt() + data = new SerializableBuffer(ByteBuffer.wrap(input.readBytes(len))) + } + } object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7ac2c71c18eb3..d0bfaafe8088f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -20,6 +20,10 @@ package org.apache.spark.storage import java.util.UUID import org.apache.spark.SparkException + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.annotation.DeveloperApi /** @@ -45,8 +49,19 @@ sealed abstract class BlockId { } @DeveloperApi -case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { - override def name: String = "rdd_" + rddId + "_" + splitIndex +case class RDDBlockId(var rddId: Int, var splitIndex: Int) + extends BlockId with KryoSerializable { + @transient override lazy val name: 
String = "rdd_" + rddId + "_" + splitIndex + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(rddId) + output.writeVarInt(splitIndex, true) + } + + override def read(kryo: Kryo, input: Input): Unit = { + rddId = input.readInt() + splitIndex = input.readVarInt(true) + } } // Format of the shuffle block ids (including data and index) should be kept in sync with diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 1bbe7a5b39509..b6f45c4f894c6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,6 +19,9 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils @@ -30,21 +33,63 @@ private[spark] object BlockManagerMessages { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave + case class RemoveBlock(private var blockId: BlockId) extends ToBlockManagerSlave + with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeString(blockId.name) + } + + override def read(kryo: Kryo, input: Input): Unit = { + blockId = BlockId(input.readString()) + } + } // Replicate blocks that were lost due to executor failure case class ReplicateBlock(blockId: BlockId, replicas: Seq[BlockManagerId], maxReplicas: Int) extends ToBlockManagerSlave // Remove all blocks belonging to a specific RDD. - case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + case class RemoveRdd(private var rddId: Int) extends ToBlockManagerSlave + with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(rddId) + } + + override def read(kryo: Kryo, input: Input): Unit = { + rddId = input.readInt() + } + } // Remove all blocks belonging to a specific shuffle. - case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + case class RemoveShuffle(private var shuffleId: Int) extends ToBlockManagerSlave + with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(shuffleId) + } + + override def read(kryo: Kryo, input: Input): Unit = { + shuffleId = input.readInt() + } + } // Remove all blocks belonging to a specific broadcast. - case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) - extends ToBlockManagerSlave + case class RemoveBroadcast(private var broadcastId: Long, + private var removeFromDriver: Boolean = true) + extends ToBlockManagerSlave with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeLong(broadcastId) + output.writeBoolean(removeFromDriver) + } + + override def read(kryo: Kryo, input: Input): Unit = { + broadcastId = input.readLong() + removeFromDriver = input.readBoolean() + } + } /** * Driver to Executor message to trigger a thread dump. 
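The block-manager and scheduler messages above all follow the same KryoSerializable pattern: fields become vars, write() emits them (using writeVarInt with optimizePositive = true for small non-negative values), and read() repopulates them in the same order. A minimal standalone sketch of that pattern and of a Kryo 4.x round trip follows; the ExampleRemoveBlock class, the buffer size and the explicit registration are illustrative assumptions, not part of this patch.

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

// Illustrative message (not from this patch): fields are vars so read() can repopulate them.
class ExampleRemoveBlock(var rddId: Int, var splitIndex: Int) extends KryoSerializable {
  // Kryo's default instantiator strategy needs a no-arg constructor
  def this() = this(-1, -1)

  override def write(kryo: Kryo, output: Output): Unit = {
    output.writeInt(rddId)
    // variable-length encoding with optimizePositive = true, as used throughout the patch
    output.writeVarInt(splitIndex, true)
  }

  override def read(kryo: Kryo, input: Input): Unit = {
    rddId = input.readInt()
    splitIndex = input.readVarInt(true)
  }
}

object KryoRoundTrip {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    // explicit registration keeps the stream free of class-name strings
    kryo.register(classOf[ExampleRemoveBlock])

    val out = new Output(4096, -1) // 4 KB initial buffer, grows without bound
    kryo.writeObject(out, new ExampleRemoveBlock(7, 42))

    val in = new Input(out.toBytes)
    val copy = kryo.readObject(in, classOf[ExampleRemoveBlock])
    assert(copy.rddId == 7 && copy.splitIndex == 42)
  }
}

The explicit registration step is what the commit message defers to the SnappyData-layer PooledKryoSerializer; with unregistered classes and writeClassAndObject, Kryo falls back to writing class names into the stream.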
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index f4a736d6d439a..42010a261c553 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -23,6 +23,11 @@ import java.util.{ArrayList, Collections} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong +import scala.collection.JavaConverters._ + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} import org.apache.spark.scheduler.AccumulableInfo @@ -41,7 +46,7 @@ private[spark] case class AccumulatorMetadata( */ abstract class AccumulatorV2[IN, OUT] extends Serializable { private[spark] var metadata: AccumulatorMetadata = _ - private[this] var atDriverSide = true + private[spark] var atDriverSide = true private[spark] def register( sc: SparkContext, @@ -207,6 +212,63 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } } +abstract class AccumulatorV2Kryo[IN, OUT] + extends AccumulatorV2[IN, OUT] with KryoSerializable { + + /** + * Child classes cannot override this and must instead implement + * writeKryo/readKryo for consistent writeReplace() behavior. + */ + override final def write(kryo: Kryo, output: Output): Unit = { + var instance = this + if (atDriverSide) { + instance = copyAndReset().asInstanceOf[AccumulatorV2Kryo[IN, OUT]] + assert(instance.isZero, "copyAndReset must return a zero value copy") + instance.metadata = this.metadata + } + val metadata = instance.metadata + output.writeLong(metadata.id) + metadata.name match { + case None => output.writeString(null) + case Some(name) => output.writeString(name) + } + output.writeBoolean(metadata.countFailedValues) + output.writeBoolean(instance.atDriverSide) + + instance.writeKryo(kryo, output) + } + + /** + * Child classes must implement readKryo() and cannot override this. + */ + override final def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + final def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + val id = input.readLong() + val name = input.readString() + metadata = AccumulatorMetadata(id, Option(name), input.readBoolean()) + atDriverSide = input.readBoolean() + if (atDriverSide) { + atDriverSide = false + // Automatically register the accumulator when it is deserialized with the task closure. + // This is for external accumulators and internal ones that do not represent task level + // metrics, e.g. internal SQL metrics, which are per-operator. + val taskContext = if (context != null) context else TaskContext.get() + if (taskContext != null) { + taskContext.registerAccumulator(this) + } + } else { + atDriverSide = true + } + + readKryo(kryo, input) + } + + def writeKryo(kryo: Kryo, output: Output): Unit + def readKryo(kryo: Kryo, input: Input): Unit +} /** * An internal class used to track accumulators by Spark itself. 
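The AccumulatorV2Kryo base class introduced above fixes the envelope (id, name, countFailedValues, driver/executor side) in its final write()/read() methods and leaves only the accumulator's own state to writeKryo()/readKryo(), which is what LongAccumulator, DoubleAccumulator, CollectionAccumulator and SQLMetric implement further down. A sketch of how a user-defined accumulator could plug into the same contract; MaxLongAccumulator is a hypothetical example, not part of this patch.

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.spark.util.{AccumulatorV2, AccumulatorV2Kryo}

// Hypothetical accumulator tracking a running maximum.
class MaxLongAccumulator extends AccumulatorV2Kryo[Long, Long] {
  private var _max = Long.MinValue

  override def isZero: Boolean = _max == Long.MinValue

  override def copy(): MaxLongAccumulator = {
    val acc = new MaxLongAccumulator
    acc._max = _max
    acc
  }

  override def reset(): Unit = _max = Long.MinValue

  override def add(v: Long): Unit = if (v > _max) _max = v

  override def merge(other: AccumulatorV2[Long, Long]): Unit = add(other.value)

  override def value: Long = _max

  // Only the accumulator's own state goes here; AccumulatorV2Kryo.write() has already
  // emitted the metadata and, on the driver, replaced `this` with a copyAndReset() copy.
  override def writeKryo(kryo: Kryo, output: Output): Unit = output.writeLong(_max)

  override def readKryo(kryo: Kryo, input: Input): Unit = _max = input.readLong()
}

Because write() serializes a copyAndReset() copy on the driver side and asserts isZero on it, copy() and reset() must leave the freshly reset instance in its zero state (Long.MinValue here).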
@@ -285,7 +347,8 @@ private[spark] object AccumulatorContext { * * @since 2.0.0 */ -class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { +class LongAccumulator extends AccumulatorV2Kryo[jl.Long, jl.Long] + with KryoSerializable { private var _sum = 0L private var _count = 0L @@ -355,6 +418,16 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private[spark] def setValue(newValue: Long): Unit = _sum = newValue override def value: jl.Long = _sum + + override def writeKryo(kryo: Kryo, output: Output): Unit = { + output.writeLong(_sum) + output.writeLong(_count) + } + + override def readKryo(kryo: Kryo, input: Input): Unit = { + _sum = input.readLong() + _count = input.readLong() + } } @@ -364,7 +437,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { * * @since 2.0.0 */ -class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { +class DoubleAccumulator extends AccumulatorV2Kryo[jl.Double, jl.Double] + with KryoSerializable { private var _sum = 0.0 private var _count = 0L @@ -430,6 +504,16 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private[spark] def setValue(newValue: Double): Unit = _sum = newValue override def value: jl.Double = _sum + + override def writeKryo(kryo: Kryo, output: Output): Unit = { + output.writeDouble(_sum) + output.writeVarLong(_count, true) + } + + override def readKryo(kryo: Kryo, input: Input): Unit = { + _sum = input.readDouble() + _count = input.readVarLong(true) + } } @@ -438,7 +522,8 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { * * @since 2.0.0 */ -class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { +class CollectionAccumulator[T] extends AccumulatorV2Kryo[T, java.util.List[T]] + with KryoSerializable { private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) override def isZero: Boolean = _list.isEmpty @@ -471,6 +556,23 @@ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { _list.clear() _list.addAll(newValue) } + + override def writeKryo(kryo: Kryo, output: Output): Unit = { + output.writeVarInt(_list.size(), true) + val iter = _list.iterator() + while (iter.hasNext) { + kryo.writeClassAndObject(output, iter.next()) + } + } + + override def readKryo(kryo: Kryo, input: Input): Unit = { + var len = input.readVarInt(true) + if (!_list.isEmpty) _list.clear() + while (len > 0) { + _list.add(kryo.readClassAndObject(input).asInstanceOf[T]) + len -= 1 + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 40616421b5bca..01c0588707594 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -207,7 +207,7 @@ private[spark] object ClosureCleaner extends Logging { accessedFields: Map[Class[_], Set[String]]): Unit = { if (!isClosure(func.getClass)) { - logWarning("Expected a closure; got " + func.getClass.getName) + // logWarning("Expected a closure; got " + func.getClass.getName) return } diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala index a06b6f84ef11b..5b27fe5cdc6eb 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -21,12 +21,17 @@ import java.io.{EOFException, IOException, 
ObjectInputStream, ObjectOutputStream import java.nio.ByteBuffer import java.nio.channels.Channels +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + /** * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make * it easier to pass ByteBuffers in case class messages. */ private[spark] -class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { +class SerializableBuffer(@transient var buffer: ByteBuffer) + extends Serializable with KryoSerializable { + def value: ByteBuffer = buffer private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { @@ -51,4 +56,20 @@ class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable } buffer.rewind() // Allow us to write it again later } + + override def write(kryo: Kryo, output: Output) { + if (buffer.position() != 0) { + throw new IOException(s"Unexpected buffer position ${buffer.position()}") + } + output.writeInt(buffer.limit()) + output.writeBytes(buffer.array(), buffer.arrayOffset(), buffer.limit()) + } + + override def read(kryo: Kryo, input: Input) { + val length = input.readInt() + val b = new Array[Byte](length) + input.readBytes(b) + buffer = ByteBuffer.wrap(b) + buffer.rewind() // Allow us to read it later + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index e63e0e3e1f68f..953699fe37b7a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -19,14 +19,17 @@ package org.apache.spark.util.collection import java.util.Arrays +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + /** * A simple, fixed-size bit set implementation. This implementation is fast because it avoids * safety/bound checking. */ -class BitSet(numBits: Int) extends Serializable { +class BitSet(numBits: Int) extends Serializable with KryoSerializable { - private val words = new Array[Long](bit2words(numBits)) - private val numWords = words.length + private var words = new Array[Long](bit2words(numBits)) + private var numWords = words.length /** * Compute the capacity (number of bits) that can be represented @@ -238,4 +241,27 @@ class BitSet(numBits: Int) extends Serializable { /** Return the number of longs it would take to hold numBits. 
*/ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 + + override def write(kryo: Kryo, output: Output): Unit = { + val words = this.words + val numWords = this.numWords + output.writeVarInt(numWords, true) + var i = 0 + while (i < numWords) { + output.writeLong(words(i)) + i += 1 + } + } + + override def read(kryo: Kryo, input: Input): Unit = { + val numWords = input.readVarInt(true) + val words = new Array[Long](numWords) + var i = 0 + while (i < numWords) { + words(i) = input.readLong() + i += 1 + } + this.words = words + this.numWords = numWords + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4a459571ed634..c49a8026af15c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -197,7 +197,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR /** * A lazily generated row ordering comparator. */ -class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) +class LazilyGeneratedOrdering(private var ordering: Seq[SortOrder]) extends Ordering[InternalRow] with KryoSerializable { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = @@ -220,7 +220,8 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) } override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { - generatedOrdering = GenerateOrdering.generate(kryo.readObject(in, classOf[Array[SortOrder]])) + ordering = kryo.readObject(in, classOf[Array[SortOrder]]) + generatedOrdering = GenerateOrdering.generate(ordering) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 67bb66da5f38b..6713464f336b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates -import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, AccumulatorV2Kryo, Utils} /** @@ -32,12 +35,12 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. */ final class SQLMetric(val metricType: String, initValue: Long = 0L) - extends AccumulatorV2[Long, Long] { + extends AccumulatorV2Kryo[Long, Long] with KryoSerializable { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. 
- private[this] var _value = initValue + private var _value = initValue private var _zeroValue = initValue override def copy(): SQLMetric = { @@ -78,6 +81,18 @@ final class SQLMetric(val metricType: String, initValue: Long = 0L) new AccumulableInfo( id, name, update, value, true, true, Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) } + + override def writeKryo(kryo: Kryo, output: Output): Unit = { + output.writeString(metricType) + output.writeLong(_value) + output.writeLong(_zeroValue) + } + + override def readKryo(kryo: Kryo, input: Input): Unit = { + metricType = input.readString() + _value = input.readLong() + _zeroValue = input.readLong() + } } object SQLMetrics {
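For reference, the headline feature of this patch is driven purely by configuration: SparkEnv.getClosureSerializer (added earlier in the patch) reads the optional spark.closure.serializer property, instantiates the named class via a (SparkConf, Boolean), (SparkConf) or no-arg constructor, and falls back to JavaSerializer. A sketch of what enabling a Kryo-based closure serializer could look like; the PooledKryoSerializer class name below is a placeholder for the SnappyData-layer implementation mentioned in the commit message, which lives outside this patch and must be on the classpath.

import org.apache.spark.{SparkConf, SparkContext}

object ClosureSerializerExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("kryo-closure-serializer")
      // property read by SparkEnv.getClosureSerializer; the value is a placeholder FQCN
      .set("spark.closure.serializer", "org.apache.spark.serializer.PooledKryoSerializer")
      // data serializer, standard Spark configuration left unchanged by this patch
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

    val sc = new SparkContext(conf)
    // Task closures, Netty RPC messages and TaskMetrics now pass through the configured
    // serializer instead of the previously hard-coded JavaSerializer.
    println(sc.parallelize(1 to 100).map(_ * 2).sum())
    sc.stop()
  }
}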