diff --git a/build.gradle b/build.gradle index 2ef73bd706c29..7b5bf21ac784d 100644 --- a/build.gradle +++ b/build.gradle @@ -61,7 +61,8 @@ allprojects { javaxServletVersion = '3.1.0' guavaVersion = '14.0.1' hiveVersion = '1.2.1.spark2' - chillVersion = '0.8.0' + chillVersion = '0.8.1' + kryoVersion = '4.0.0' nettyVersion = '3.8.0.Final' nettyAllVersion = '4.0.29.Final' derbyVersion = '10.12.1.1' diff --git a/common/unsafe/build.gradle b/common/unsafe/build.gradle index 69d29942f5f1c..cc8a89304e8ff 100644 --- a/common/unsafe/build.gradle +++ b/common/unsafe/build.gradle @@ -20,7 +20,10 @@ description = 'Spark Project Unsafe' dependencies { compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion) - compile group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion + compile group: 'com.esotericsoftware', name: 'kryo-shaded', version: kryoVersion + compile(group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion) { + exclude(group: 'com.esotericsoftware', module: 'kryo-shaded') + } compile group: 'com.google.code.findbugs', name: 'jsr305', version: jsr305Version compile group: 'com.google.guava', name: 'guava', version: guavaVersion diff --git a/core/build.gradle b/core/build.gradle index 9395a129dac33..32a55761bbec3 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -40,8 +40,13 @@ dependencies { exclude(group: 'org.apache.avro', module: 'avro-ipc') } compile group: 'com.google.guava', name: 'guava', version: guavaVersion - compile group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion - compile group: 'com.twitter', name: 'chill-java', version: chillVersion + compile group: 'com.esotericsoftware', name: 'kryo-shaded', version: kryoVersion + compile(group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion) { + exclude(group: 'com.esotericsoftware', module: 'kryo-shaded') + } + compile(group: 'com.twitter', name: 'chill-java', version: chillVersion) { + exclude(group: 'com.esotericsoftware', module: 'kryo-shaded') + } compile group: 'org.apache.xbean', name: 'xbean-asm5-shaded', version: '4.4' // explicitly include netty from akka-remote to not let zookeeper override it compile group: 'io.netty', name: 'netty', version: nettyVersion diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 5da0dfcf5ad9d..beb69cbba5efc 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -167,6 +167,43 @@ object SparkEnv extends Logging { env } + // Create an instance of the class with the given name, possibly initializing it with our conf + def instantiateClass[T](className: String, conf: SparkConf, + isDriver: Boolean): T = { + val cls = Utils.classForName(className) + // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just + // SparkConf, then one taking no arguments + try { + cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) + .newInstance(conf, new java.lang.Boolean(isDriver)) + .asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + cls.getConstructor().newInstance().asInstanceOf[T] + } + } + } + + def getClosureSerializer(conf: SparkConf, doLog: Boolean = false): Serializer = { + val defaultClosureSerializerClass = classOf[JavaSerializer].getName + val 
closureSerializerClass = conf.get("spark.closure.serializer", + defaultClosureSerializerClass) + val closureSerializer = instantiateClass[Serializer]( + closureSerializerClass, conf, isDriver = false) + if (doLog) { + if (closureSerializerClass != defaultClosureSerializerClass) { + logInfo(s"Using non-default closure serializer: $closureSerializerClass") + } else { + logDebug(s"Using closure serializer: $closureSerializerClass") + } + } + closureSerializer + } + /** * Create a SparkEnv for the driver. */ @@ -251,26 +288,9 @@ object SparkEnv extends Logging { conf.set("spark.executor.port", rpcEnv.address.port.toString) } - // Create an instance of the class with the given name, possibly initializing it with our conf def instantiateClass[T](className: String): T = { - val cls = Utils.classForName(className) - // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just - // SparkConf, then one taking no arguments - try { - cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) - .newInstance(conf, new java.lang.Boolean(isDriver)) - .asInstanceOf[T] - } catch { - case _: NoSuchMethodException => - try { - cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] - } catch { - case _: NoSuchMethodException => - cls.getConstructor().newInstance().asInstanceOf[T] - } - } + SparkEnv.instantiateClass(className, conf, isDriver) } - // Create an instance of the class named by the given SparkConf property, or defaultClassName // if the property is not set, possibly initializing it with our conf def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = { @@ -283,7 +303,7 @@ object SparkEnv extends Logging { val serializerManager = new SerializerManager(serializer, conf) - val closureSerializer = new JavaSerializer(conf) + val closureSerializer = getClosureSerializer(conf, doLog = true) def registerOrLookupEndpoint( name: String, endpointCreator: => RpcEndpoint): diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index 3d15f3a0396e1..1647b06ce0481 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -17,6 +17,10 @@ package org.apache.spark.executor +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.LongAccumulator @@ -39,7 +43,7 @@ object DataReadMethod extends Enumeration with Serializable { * A collection of accumulators that represents metrics about reading data from external systems. 
*/ @DeveloperApi -class InputMetrics private[spark] () extends Serializable { +class InputMetrics private[spark] () extends Serializable with KryoSerializable { private[executor] val _bytesRead = new LongAccumulator private[executor] val _recordsRead = new LongAccumulator @@ -56,4 +60,18 @@ class InputMetrics private[spark] () extends Serializable { private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v) private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) + + override def write(kryo: Kryo, output: Output): Unit = { + _bytesRead.write(kryo, output) + _recordsRead.write(kryo, output) + } + + override final def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + _bytesRead.read(kryo, input, context) + _recordsRead.read(kryo, input, context) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala index dada9697c1cf9..418a831c7555f 100644 --- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -17,6 +17,10 @@ package org.apache.spark.executor +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.LongAccumulator @@ -38,7 +42,7 @@ object DataWriteMethod extends Enumeration with Serializable { * A collection of accumulators that represents metrics about writing data to external systems. */ @DeveloperApi -class OutputMetrics private[spark] () extends Serializable { +class OutputMetrics private[spark] () extends Serializable with KryoSerializable { private[executor] val _bytesWritten = new LongAccumulator private[executor] val _recordsWritten = new LongAccumulator @@ -54,4 +58,18 @@ class OutputMetrics private[spark] () extends Serializable { private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v) private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v) + + override def write(kryo: Kryo, output: Output): Unit = { + _bytesWritten.write(kryo, output) + _recordsWritten.write(kryo, output) + } + + override final def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + _bytesWritten.read(kryo, input, context) + _recordsWritten.read(kryo, input, context) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index f7a991770d402..bbc9f65312ca6 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -17,6 +17,10 @@ package org.apache.spark.executor +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.LongAccumulator @@ -27,7 +31,7 @@ import org.apache.spark.util.LongAccumulator * Operations are not thread-safe. 
*/ @DeveloperApi -class ShuffleReadMetrics private[spark] () extends Serializable { +class ShuffleReadMetrics private[spark] () extends Serializable with KryoSerializable { private[executor] val _remoteBlocksFetched = new LongAccumulator private[executor] val _localBlocksFetched = new LongAccumulator private[executor] val _remoteBytesRead = new LongAccumulator @@ -111,6 +115,28 @@ class ShuffleReadMetrics private[spark] () extends Serializable { _recordsRead.add(metric.recordsRead) } } + + override def write(kryo: Kryo, output: Output): Unit = { + _remoteBlocksFetched.write(kryo, output) + _localBlocksFetched.write(kryo, output) + _remoteBytesRead.write(kryo, output) + _localBytesRead.write(kryo, output) + _fetchWaitTime.write(kryo, output) + _recordsRead.write(kryo, output) + } + + override final def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + _remoteBlocksFetched.read(kryo, input, context) + _localBlocksFetched.read(kryo, input, context) + _remoteBytesRead.read(kryo, input, context) + _localBytesRead.read(kryo, input, context) + _fetchWaitTime.read(kryo, input, context) + _recordsRead.read(kryo, input, context) + } } /** diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index ada2e1bc08593..f6aaf90d93b9c 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -17,6 +17,10 @@ package org.apache.spark.executor +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.LongAccumulator @@ -27,7 +31,7 @@ import org.apache.spark.util.LongAccumulator * Operations are not thread-safe. */ @DeveloperApi -class ShuffleWriteMetrics private[spark] () extends Serializable { +class ShuffleWriteMetrics private[spark] () extends Serializable with KryoSerializable { private[executor] val _bytesWritten = new LongAccumulator private[executor] val _recordsWritten = new LongAccumulator private[executor] val _writeTime = new LongAccumulator @@ -57,6 +61,22 @@ class ShuffleWriteMetrics private[spark] () extends Serializable { _recordsWritten.setValue(recordsWritten - v) } + override def write(kryo: Kryo, output: Output): Unit = { + _bytesWritten.write(kryo, output) + _recordsWritten.write(kryo, output) + _writeTime.write(kryo, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + _bytesWritten.read(kryo, input, context) + _recordsWritten.read(kryo, input, context) + _writeTime.read(kryo, input, context) + } + // Legacy methods for backward compatibility. // TODO: remove these once we make this class private. 
@deprecated("use bytesWritten instead", "2.0.0") diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 47aec44bac019..d2a8328684f2b 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -20,6 +20,9 @@ package org.apache.spark.executor import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging @@ -42,7 +45,7 @@ import org.apache.spark.util._ * be sent to the driver. */ @DeveloperApi -class TaskMetrics private[spark] () extends Serializable { +class TaskMetrics private[spark] () extends Serializable with KryoSerializable { // Each metric is internally represented as an accumulator private val _executorDeserializeTime = new LongAccumulator private val _executorRunTime = new LongAccumulator @@ -241,6 +244,40 @@ class TaskMetrics private[spark] () extends Serializable { acc.name.isDefined && acc.name.get == name } } + + override def write(kryo: Kryo, output: Output): Unit = { + _executorDeserializeTime.write(kryo, output) + _executorRunTime.write(kryo, output) + _resultSize.write(kryo, output) + _jvmGCTime.write(kryo, output) + _resultSerializationTime.write(kryo, output) + _memoryBytesSpilled.write(kryo, output) + _diskBytesSpilled.write(kryo, output) + _peakExecutionMemory.write(kryo, output) + _updatedBlockStatuses.write(kryo, output) + inputMetrics.write(kryo, output) + outputMetrics.write(kryo, output) + shuffleReadMetrics.write(kryo, output) + shuffleWriteMetrics.write(kryo, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + // read the TaskContext thread-local once + val taskContext = TaskContext.get() + _executorDeserializeTime.read(kryo, input, taskContext) + _executorRunTime.read(kryo, input, taskContext) + _resultSize.read(kryo, input, taskContext) + _jvmGCTime.read(kryo, input, taskContext) + _resultSerializationTime.read(kryo, input, taskContext) + _memoryBytesSpilled.read(kryo, input, taskContext) + _diskBytesSpilled.read(kryo, input, taskContext) + _peakExecutionMemory.read(kryo, input, taskContext) + _updatedBlockStatuses.read(kryo, input, taskContext) + inputMetrics.read(kryo, input, taskContext) + outputMetrics.read(kryo, input, taskContext) + shuffleReadMetrics.read(kryo, input, taskContext) + shuffleWriteMetrics.read(kryo, input, taskContext) + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 33a3219607749..0c097c5406c44 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} @@ -32,7 +32,6 @@ 
import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils -import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -47,7 +46,7 @@ private[spark] class NettyBlockTransferService( extends BlockTransferService { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. - private val serializer = new JavaSerializer(conf) + private val serializer = SparkEnv.getClosureSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 1f449f8c7a0bc..6795ac69c119c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -146,7 +146,8 @@ abstract class RDD[T: ClassTag]( def sparkContext: SparkContext = sc /** A unique ID for this RDD (within its SparkContext). */ - val id: Int = sc.newRddId() + protected var _id: Int = sc.newRddId() + def id: Int = _id /** A friendly name for this RDD */ @transient var name: String = null @@ -1627,7 +1628,7 @@ abstract class RDD[T: ClassTag]( // Other internal methods and fields // ======================================================================= - private var storageLevel: StorageLevel = StorageLevel.NONE + protected var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). 
*/ @transient private[spark] val creationSite = sc.getCallSite() diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 3cb1231bd3477..7d4e5595fe860 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -21,16 +21,19 @@ import java.io.{IOException, ObjectOutputStream} import scala.reflect.ClassTag +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext} import org.apache.spark.util.Utils private[spark] class ZippedPartitionsPartition( - idx: Int, + private var idx: Int, @transient private val rdds: Seq[RDD[_]], @transient val preferredLocations: Seq[String]) - extends Partition { + extends Partition with KryoSerializable { - override val index: Int = idx + override def index: Int = idx var partitionValues = rdds.map(rdd => rdd.partitions(idx)) def partitions: Seq[Partition] = partitionValues @@ -40,6 +43,27 @@ private[spark] class ZippedPartitionsPartition( partitionValues = rdds.map(rdd => rdd.partitions(idx)) oos.defaultWriteObject() } + + override def write(kryo: Kryo, output: Output): Unit = { + // Update the reference to parent split at the time of task serialization + partitionValues = rdds.map(rdd => rdd.partitions(idx)) + output.writeVarInt(idx, true) + output.writeVarInt(partitionValues.length, true) + for (p <- partitionValues) { + kryo.writeClassAndObject(output, p) + } + } + + override def read(kryo: Kryo, input: Input): Unit = { + idx = input.readVarInt(true) + var numPartitions = input.readVarInt(true) + val partitionBuilder = Seq.newBuilder[Partition] + while (numPartitions > 0) { + partitionBuilder += kryo.readClassAndObject(input).asInstanceOf[Partition] + numPartitions -= 1 + } + partitionValues = partitionBuilder.result() + } } private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 994e18676ec49..cd2551e77e4ac 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -27,12 +27,12 @@ import org.apache.spark.util.RpcUtils /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. 
*/ -private[spark] abstract class RpcEndpointRef(conf: SparkConf) - extends Serializable with Logging { +private[spark] abstract class RpcEndpointRef(conf: SparkConf, + _env: RpcEnv) extends Serializable with Logging { - private[this] val maxRetries = RpcUtils.numRetries(conf) - private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) + @transient protected var maxRetries = _env.maxRetries + @transient protected var retryWaitMs = _env.retryWaitMs + @transient protected var defaultAskTimeout = _env.defaultAskTimeout /** * return the address for the [[RpcEndpointRef]] diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 56683771335a6..e35903c7d9bf6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -59,6 +59,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { private[spark] val defaultLookupTimeout = RpcUtils.lookupRpcTimeout(conf) + private[spark] val maxRetries = RpcUtils.numRetries(conf) + private[spark] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[spark] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) + /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement * [[RpcEndpoint.self]]. Return `null` if the corresponding [[RpcEndpointRef]] does not exist. diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 89d2fb9b47971..8ba997a3b4525 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -29,7 +29,10 @@ import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal -import org.apache.spark.{SecurityManager, SparkConf} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -37,12 +40,12 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.rpc._ -import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.serializer.{Serializer, SerializerInstance} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} private[netty] class NettyRpcEnv( val conf: SparkConf, - javaSerializerInstance: JavaSerializerInstance, + serializer: Serializer, host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { @@ -51,6 +54,10 @@ private[netty] class NettyRpcEnv( "rpc", conf.getInt("spark.rpc.io.threads", 0)) + private val serializerInstance = new ThreadLocal[SerializerInstance] { + override def initialValue(): SerializerInstance = serializer.newInstance() + } + private val dispatcher: Dispatcher = new Dispatcher(this) private val streamManager = new NettyStreamManager(this) @@ -250,13 +257,13 @@ private[netty] class NettyRpcEnv( } private[netty] def serialize(content: Any): ByteBuffer = { - javaSerializerInstance.serialize(content) + 
serializerInstance.get().serialize(content) } private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => - javaSerializerInstance.deserialize[T](bytes) + serializerInstance.get().deserialize[T](bytes) } } } @@ -436,12 +443,9 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { def create(config: RpcEnvConfig): RpcEnv = { val sparkConf = config.conf - // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support - // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance - val javaSerializerInstance = - new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] + val serializer = SparkEnv.getClosureSerializer(sparkConf) val nettyEnv = - new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) + new NettyRpcEnv(sparkConf, serializer, config.host, config.securityManager) if (!config.clientMode) { val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.startServer(actualPort) @@ -483,12 +487,12 @@ private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, endpointAddress: RpcEndpointAddress, @transient @volatile private var nettyEnv: NettyRpcEnv) - extends RpcEndpointRef(conf) with Serializable with Logging { + extends RpcEndpointRef(conf, nettyEnv) with Serializable with KryoSerializable with Logging { @transient @volatile var client: TransportClient = _ - private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null - private val _name = endpointAddress.name + private var _address = if (endpointAddress.rpcAddress != null) endpointAddress else null + private var _name = endpointAddress.name override def address: RpcAddress = if (_address != null) _address.rpcAddress else null @@ -496,12 +500,43 @@ private[netty] class NettyRpcEndpointRef( in.defaultReadObject() nettyEnv = NettyRpcEnv.currentEnv.value client = NettyRpcEnv.currentClient.value + + maxRetries = nettyEnv.maxRetries + retryWaitMs = nettyEnv.retryWaitMs + defaultAskTimeout = nettyEnv.defaultAskTimeout } private def writeObject(out: ObjectOutputStream): Unit = { out.defaultWriteObject() } + override def write(kryo: Kryo, output: Output): Unit = { + val addr = address + output.writeString(_name) + if (addr != null && addr.host != null) { + output.writeString(addr.host) + output.writeInt(addr.port) + } else { + output.writeString(null) + } + } + + override def read(kryo: Kryo, input: Input): Unit = { + _name = input.readString() + _address = null + val host = input.readString() + if (host != null) { + val port = input.readInt() + _address = RpcEndpointAddress(host, port, _name) + } + nettyEnv = NettyRpcEnv.currentEnv.value + client = NettyRpcEnv.currentClient.value + + maxRetries = nettyEnv.maxRetries + retryWaitMs = nettyEnv.retryWaitMs + defaultAskTimeout = nettyEnv.defaultAskTimeout + } + override def name: String = _name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { @@ -529,7 +564,44 @@ private[netty] class NettyRpcEndpointRef( * The message that is sent from the sender to the receiver. 
*/ private[netty] case class RequestMessage( - senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) + private var _senderAddress: RpcAddress, + private var _receiver: NettyRpcEndpointRef, + private var _content: Any) extends KryoSerializable { + + final def senderAddress: RpcAddress = _senderAddress + + final def receiver: NettyRpcEndpointRef = _receiver + + final def content: Any = _content + + override def write(kryo: Kryo, output: Output): Unit = { + if (_senderAddress != null) { + output.writeString(_senderAddress.host) + output.writeInt(_senderAddress.port) + } else { + output.writeString(null) + } + if (_receiver != null) { + output.writeBoolean(true) + _receiver.write(kryo, output) + } else { + output.writeBoolean(false) + } + kryo.writeClassAndObject(output, _content) + } + + override def read(kryo: Kryo, input: Input): Unit = { + val host = input.readString() + _senderAddress = if (host != null) RpcAddress(host, input.readInt()) else null + if (input.readBoolean()) { + _receiver = kryo.newInstance(classOf[NettyRpcEndpointRef]) + _receiver.read(kryo, input) + } else { + _receiver = null + } + _content = kryo.readClassAndObject(input) + } +} /** * A response that indicates some failure happens in the receiver side. diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 75c6018e214d8..57595fdce2c09 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -21,6 +21,9 @@ import java.io._ import java.nio.ByteBuffer import java.util.Properties +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -38,7 +41,7 @@ import org.apache.spark.rdd.RDD * (RDD[T], (TaskContext, Iterator[T]) => U). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling - * @param outputId index of the task in this job (a job can launch tasks on only a subset of the + * @param _outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. 
@@ -46,15 +49,17 @@ import org.apache.spark.rdd.RDD private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, - taskBinary: Broadcast[Array[Byte]], - partition: Partition, + private var taskBinary: Broadcast[Array[Byte]], + private var partition: Partition, locs: Seq[TaskLocation], - val outputId: Int, + private var _outputId: Int, localProperties: Properties, metrics: TaskMetrics) extends Task[U](stageId, stageAttemptId, partition.index, metrics, localProperties) with Serializable { + final def outputId: Int = _outputId + @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } @@ -74,4 +79,18 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString: String = "ResultTask(" + stageId + ", " + partitionId + ")" + + override def write(kryo: Kryo, output: Output): Unit = { + super.write(kryo, output) + kryo.writeClassAndObject(output, taskBinary) + kryo.writeClassAndObject(output, partition) + output.writeInt(_outputId) + } + + override def read(kryo: Kryo, input: Input): Unit = { + super.read(kryo, input) + taskBinary = kryo.readClassAndObject(input).asInstanceOf[Broadcast[Array[Byte]]] + partition = kryo.readClassAndObject(input).asInstanceOf[Partition] + _outputId = input.readInt() + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 84b3e5ba6c1f3..add8aaf3c9cc4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -22,6 +22,9 @@ import java.util.Properties import scala.language.existentials +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -47,8 +50,8 @@ import org.apache.spark.shuffle.ShuffleWriter private[spark] class ShuffleMapTask( stageId: Int, stageAttemptId: Int, - taskBinary: Broadcast[Array[Byte]], - partition: Partition, + private var taskBinary: Broadcast[Array[Byte]], + private var partition: Partition, @transient private var locs: Seq[TaskLocation], metrics: TaskMetrics, localProperties: Properties) @@ -95,4 +98,16 @@ private[spark] class ShuffleMapTask( override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) + + override def write(kryo: Kryo, output: Output): Unit = { + super.write(kryo, output) + kryo.writeClassAndObject(output, taskBinary) + kryo.writeClassAndObject(output, partition) + } + + override def read(kryo: Kryo, input: Input): Unit = { + super.read(kryo, input) + taskBinary = kryo.readClassAndObject(input).asInstanceOf[Broadcast[Array[Byte]]] + partition = kryo.readClassAndObject(input).asInstanceOf[Partition] + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 1ed36bf0692f8..859506f04680d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -24,6 +24,9 @@ import java.util.Properties import scala.collection.mutable import scala.collection.mutable.HashMap +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import 
org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} @@ -42,19 +45,28 @@ import org.apache.spark.util.{AccumulatorV2, ByteBufferInputStream, ByteBufferOu * and sends the task output back to the driver application. A ShuffleMapTask executes the task * and divides the task output to multiple buckets (based on the task's partitioner). * - * @param stageId id of the stage this task belongs to - * @param stageAttemptId attempt id of the stage this task belongs to - * @param partitionId index of the number in the RDD - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. + * @param _stageId id of the stage this task belongs to + * @param _stageAttemptId attempt id of the stage this task belongs to + * @param _partitionId index of the number in the RDD + * @param _metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. */ private[spark] abstract class Task[T]( - val stageId: Int, - val stageAttemptId: Int, - val partitionId: Int, + private var _stageId: Int, + private var _stageAttemptId: Int, + private var _partitionId: Int, // The default value is only used in tests. - val metrics: TaskMetrics = TaskMetrics.registered, - @transient var localProperties: Properties = new Properties) extends Serializable { + private var _metrics: TaskMetrics = TaskMetrics.registered, + @transient var localProperties: Properties = new Properties) extends Serializable + with KryoSerializable { + + final def stageId: Int = _stageId + + final def stageAttemptId: Int = _stageAttemptId + + final def partitionId: Int = _partitionId + + final def metrics: TaskMetrics = _metrics /** * Called by [[org.apache.spark.executor.Executor]] to run this task. 
@@ -115,7 +127,7 @@ private[spark] abstract class Task[T]( } } - private var taskMemoryManager: TaskMemoryManager = _ + @transient private var taskMemoryManager: TaskMemoryManager = _ def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = { this.taskMemoryManager = taskMemoryManager @@ -184,6 +196,25 @@ private[spark] abstract class Task[T]( taskThread.interrupt() } } + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(_stageId) + output.writeVarInt(_stageAttemptId, true) + output.writeVarInt(_partitionId, true) + output.writeLong(epoch) + output.writeLong(_executorDeserializeTime) + _metrics.write(kryo, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + _stageId = input.readInt() + _stageAttemptId = input.readVarInt(true) + _partitionId = input.readVarInt(true) + epoch = input.readLong() + _executorDeserializeTime = input.readLong() + _metrics = new TaskMetrics + _metrics.read(kryo, input) + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 1c7c81c488c3a..b974ce30d4ccd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -19,25 +19,56 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import org.apache.spark.util.SerializableBuffer +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.util.{SerializableBuffer, Utils} /** * Description of a task that gets passed onto executors to be executed, usually created by * [[TaskSetManager.resourceOffer]]. */ private[spark] class TaskDescription( - val taskId: Long, - val attemptNumber: Int, - val executorId: String, - val name: String, - val index: Int, // Index within this task's TaskSet - _serializedTask: ByteBuffer) - extends Serializable { + private var _taskId: Long, + private var _attemptNumber: Int, + private var _executorId: String, + private var _name: String, + private var _index: Int, // Index within this task's TaskSet + @transient private var _serializedTask: ByteBuffer) + extends Serializable with KryoSerializable { + + def taskId: Long = _taskId + def attemptNumber: Int = _attemptNumber + def executorId: String = _executorId + def name: String = _name + def index: Int = _index // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer - private val buffer = new SerializableBuffer(_serializedTask) + private val buffer = + if (_serializedTask ne null) new SerializableBuffer(_serializedTask) else null + + def serializedTask: ByteBuffer = + if (_serializedTask ne null) _serializedTask else buffer.value + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeLong(_taskId) + output.writeVarInt(_attemptNumber, true) + output.writeString(_executorId) + output.writeString(_name) + output.writeInt(_index) + output.writeInt(_serializedTask.remaining()) + Utils.writeByteBuffer(_serializedTask, output) + } - def serializedTask: ByteBuffer = buffer.value + override def read(kryo: Kryo, input: Input): Unit = { + _taskId = input.readLong() + _attemptNumber = input.readVarInt(true) + _executorId = input.readString() + _name = input.readString() + _index = input.readInt() + val len = input.readInt() + _serializedTask = ByteBuffer.wrap(input.readBytes(len)) + } override def toString: String = "TaskDescription(TID=%d, 
index=%d)".format(taskId, index) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 24baaffbe0ce5..7afb026764da5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -523,7 +523,7 @@ private[spark] class TaskSetManager( s" $taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) - return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, + return Some(new TaskDescription(_taskId = taskId, _attemptNumber = attemptNum, execId, taskName, index, serializedTask)) case _ => } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index edc8aac5d1515..38ddc94ad3639 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -19,10 +19,13 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.ExecutorLossReason -import org.apache.spark.util.SerializableBuffer +import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -54,8 +57,27 @@ private[spark] object CoarseGrainedClusterMessages { logUrls: Map[String, String]) extends CoarseGrainedClusterMessage - case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends CoarseGrainedClusterMessage + case class StatusUpdate(var executorId: String, var taskId: Long, + var state: TaskState, var data: SerializableBuffer) + extends CoarseGrainedClusterMessage with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeString(executorId) + output.writeLong(taskId) + output.writeVarInt(state.id, true) + val buffer = data.buffer + output.writeInt(buffer.remaining()) + Utils.writeByteBuffer(buffer, output) + } + + override def read(kryo: Kryo, input: Input): Unit = { + executorId = input.readString() + taskId = input.readLong() + state = org.apache.spark.TaskState(input.readVarInt(true)) + val len = input.readInt() + data = new SerializableBuffer(ByteBuffer.wrap(input.readBytes(len))) + } + } object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 524f6970992a5..2839b766cdf06 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -19,6 +19,9 @@ package org.apache.spark.storage import java.util.UUID +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.annotation.DeveloperApi /** @@ -49,8 +52,19 @@ sealed abstract class BlockId { } @DeveloperApi -case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { - override def name: String = "rdd_" + rddId + 
"_" + splitIndex +case class RDDBlockId(var rddId: Int, var splitIndex: Int) + extends BlockId with KryoSerializable { + @transient override lazy val name: String = "rdd_" + rddId + "_" + splitIndex + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(rddId) + output.writeVarInt(splitIndex, true) + } + + override def read(kryo: Kryo, input: Input): Unit = { + rddId = input.readInt() + splitIndex = input.readVarInt(true) + } } // Format of the shuffle block ids (including data and index) should be kept in sync with diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 6bded92700504..c84cd576055e1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,6 +19,9 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils @@ -30,17 +33,59 @@ private[spark] object BlockManagerMessages { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave + case class RemoveBlock(private var blockId: BlockId) extends ToBlockManagerSlave + with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeString(blockId.name) + } + + override def read(kryo: Kryo, input: Input): Unit = { + blockId = BlockId(input.readString()) + } + } // Remove all blocks belonging to a specific RDD. - case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + case class RemoveRdd(private var rddId: Int) extends ToBlockManagerSlave + with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(rddId) + } + + override def read(kryo: Kryo, input: Input): Unit = { + rddId = input.readInt() + } + } // Remove all blocks belonging to a specific shuffle. - case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + case class RemoveShuffle(private var shuffleId: Int) extends ToBlockManagerSlave + with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(shuffleId) + } + + override def read(kryo: Kryo, input: Input): Unit = { + shuffleId = input.readInt() + } + } // Remove all blocks belonging to a specific broadcast. - case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) - extends ToBlockManagerSlave + case class RemoveBroadcast(private var broadcastId: Long, + private var removeFromDriver: Boolean = true) + extends ToBlockManagerSlave with KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeLong(broadcastId) + output.writeBoolean(removeFromDriver) + } + + override def read(kryo: Kryo, input: Input): Unit = { + broadcastId = input.readLong() + removeFromDriver = input.readBoolean() + } + } /** * Driver -> Executor message to trigger a thread dump. 
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 8990355d254cd..3243b9421cc5b 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -45,15 +45,15 @@ import scala.reflect.ClassTag import com.google.common.io.ByteStreams +import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} import org.apache.spark.unsafe.Platform +import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} -import org.apache.spark.util.{SizeEstimator, Utils} -import org.apache.spark.{SparkConf, TaskContext} private sealed trait MemoryEntry[T] { def size: Long diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index d3ddd39131326..dc0b5752e596d 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -25,6 +25,9 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} import org.apache.spark.scheduler.AccumulableInfo @@ -44,7 +47,7 @@ private[spark] case class AccumulatorMetadata( */ abstract class AccumulatorV2[IN, OUT] extends Serializable { private[spark] var metadata: AccumulatorMetadata = _ - private[this] var atDriverSide = true + private[spark] var atDriverSide = true private[spark] def register( sc: SparkContext, @@ -194,6 +197,63 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } } +abstract class AccumulatorV2Kryo[IN, OUT] + extends AccumulatorV2[IN, OUT] with KryoSerializable { + + /** + * Child classes cannot override this and must instead implement + * writeKryo/readKryo for consistent writeReplace() behavior. + */ + override final def write(kryo: Kryo, output: Output): Unit = { + var instance = this + if (atDriverSide) { + instance = copyAndReset().asInstanceOf[AccumulatorV2Kryo[IN, OUT]] + assert(instance.isZero, "copyAndReset must return a zero value copy") + instance.metadata = this.metadata + } + val metadata = instance.metadata + output.writeLong(metadata.id) + metadata.name match { + case None => output.writeString(null) + case Some(name) => output.writeString(name) + } + output.writeBoolean(metadata.countFailedValues) + output.writeBoolean(instance.atDriverSide) + + instance.writeKryo(kryo, output) + } + + /** + * Child classes must implement readKryo() and cannot override this. 
+ */ + override final def read(kryo: Kryo, input: Input): Unit = { + read(kryo, input, context = null) + } + + final def read(kryo: Kryo, input: Input, context: TaskContext): Unit = { + val id = input.readLong() + val name = input.readString() + metadata = AccumulatorMetadata(id, Option(name), input.readBoolean()) + atDriverSide = input.readBoolean() + if (atDriverSide) { + atDriverSide = false + // Automatically register the accumulator when it is deserialized with the task closure. + // This is for external accumulators and internal ones that do not represent task level + // metrics, e.g. internal SQL metrics, which are per-operator. + val taskContext = if (context != null) context else TaskContext.get() + if (taskContext != null) { + taskContext.registerAccumulator(this) + } + } else { + atDriverSide = true + } + + readKryo(kryo, input) + } + + def writeKryo(kryo: Kryo, output: Output): Unit + def readKryo(kryo: Kryo, input: Input): Unit +} /** * An internal class used to track accumulators by Spark itself. @@ -282,7 +342,8 @@ private[spark] object AccumulatorContext { * * @since 2.0.0 */ -class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { +class LongAccumulator extends AccumulatorV2Kryo[jl.Long, jl.Long] + with KryoSerializable { private var _sum = 0L private var _count = 0L @@ -352,6 +413,16 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private[spark] def setValue(newValue: Long): Unit = _sum = newValue override def value: jl.Long = _sum + + override def writeKryo(kryo: Kryo, output: Output): Unit = { + output.writeLong(_sum) + output.writeLong(_count) + } + + override def readKryo(kryo: Kryo, input: Input): Unit = { + _sum = input.readLong() + _count = input.readLong() + } } @@ -361,7 +432,8 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { * * @since 2.0.0 */ -class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { +class DoubleAccumulator extends AccumulatorV2Kryo[jl.Double, jl.Double] + with KryoSerializable { private var _sum = 0.0 private var _count = 0L @@ -427,6 +499,16 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private[spark] def setValue(newValue: Double): Unit = _sum = newValue override def value: jl.Double = _sum + + override def writeKryo(kryo: Kryo, output: Output): Unit = { + output.writeDouble(_sum) + output.writeVarLong(_count, true) + } + + override def readKryo(kryo: Kryo, input: Input): Unit = { + _sum = input.readDouble() + _count = input.readVarLong(true) + } } @@ -435,7 +517,8 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { * * @since 2.0.0 */ -class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { +class CollectionAccumulator[T] extends AccumulatorV2Kryo[T, java.util.List[T]] + with KryoSerializable { private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) override def isZero: Boolean = _list.isEmpty @@ -468,6 +551,23 @@ class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { _list.clear() _list.addAll(newValue) } + + override def writeKryo(kryo: Kryo, output: Output): Unit = { + output.writeVarInt(_list.size(), true) + val iter = _list.iterator() + while (iter.hasNext) { + kryo.writeClassAndObject(output, iter.next()) + } + } + + override def readKryo(kryo: Kryo, input: Input): Unit = { + var len = input.readVarInt(true) + if (!_list.isEmpty) _list.clear() + while (len > 0) { + _list.add(kryo.readClassAndObject(input).asInstanceOf[T]) + len -= 1 + } + } } diff --git 
a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 489688cb0880f..d4b32cb192776 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -156,7 +156,7 @@ private[spark] object ClosureCleaner extends Logging { accessedFields: Map[Class[_], Set[String]]): Unit = { if (!isClosure(func.getClass)) { - logWarning("Expected a closure; got " + func.getClass.getName) + // logWarning("Expected a closure; got " + func.getClass.getName) return } diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala index a06b6f84ef11b..5b27fe5cdc6eb 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala @@ -21,12 +21,17 @@ import java.io.{EOFException, IOException, ObjectInputStream, ObjectOutputStream import java.nio.ByteBuffer import java.nio.channels.Channels +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + /** * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make * it easier to pass ByteBuffers in case class messages. */ private[spark] -class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { +class SerializableBuffer(@transient var buffer: ByteBuffer) + extends Serializable with KryoSerializable { + def value: ByteBuffer = buffer private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { @@ -51,4 +56,20 @@ class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable } buffer.rewind() // Allow us to write it again later } + + override def write(kryo: Kryo, output: Output) { + if (buffer.position() != 0) { + throw new IOException(s"Unexpected buffer position ${buffer.position()}") + } + output.writeInt(buffer.limit()) + output.writeBytes(buffer.array(), buffer.arrayOffset(), buffer.limit()) + } + + override def read(kryo: Kryo, input: Input) { + val length = input.readInt() + val b = new Array[Byte](length) + input.readBytes(b) + buffer = ByteBuffer.wrap(b) + buffer.rewind() // Allow us to read it later + } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala index 7ab67fc3a2de9..ac56d621d684f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -17,14 +17,17 @@ package org.apache.spark.util.collection +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + /** * A simple, fixed-size bit set implementation. This implementation is fast because it avoids * safety/bound checking. */ -class BitSet(numBits: Int) extends Serializable { +class BitSet(numBits: Int) extends Serializable with KryoSerializable { - private val words = new Array[Long](bit2words(numBits)) - private val numWords = words.length + private var words = new Array[Long](bit2words(numBits)) + private var numWords = words.length /** * Compute the capacity (number of bits) that can be represented @@ -230,4 +233,27 @@ class BitSet(numBits: Int) extends Serializable { /** Return the number of longs it would take to hold numBits. 
*/ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 + + override def write(kryo: Kryo, output: Output): Unit = { + val words = this.words + val numWords = this.numWords + output.writeVarInt(numWords, true) + var i = 0 + while (i < numWords) { + output.writeLong(words(i)) + i += 1 + } + } + + override def read(kryo: Kryo, input: Input): Unit = { + val numWords = input.readVarInt(true) + val words = new Array[Long](numWords) + var i = 0 + while (i < numWords) { + words(i) = input.readLong() + i += 1 + } + this.words = words + this.numWords = numWords + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index f4d35d232e691..c34022fc14ba7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -150,7 +150,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR /** * A lazily generated row ordering comparator. */ -class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) +class LazilyGeneratedOrdering(private var ordering: Seq[SortOrder]) extends Ordering[InternalRow] with KryoSerializable { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = @@ -173,7 +173,8 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) } override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { - generatedOrdering = GenerateOrdering.generate(kryo.readObject(in, classOf[Array[SortOrder]])) + ordering = kryo.readObject(in, classOf[Array[SortOrder]]) + generatedOrdering = GenerateOrdering.generate(ordering) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 1890d46a34c70..38ca2bfca63b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -55,16 +55,16 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val isHomogenousStruct = { var i = 1 val ref = ctx.javaType(schema.fields(0).dataType) - var broken = false || !ctx.isPrimitiveType(ref) || schema.length <=1 - while( !broken && i < schema.length) { + var broken = !ctx.isPrimitiveType(ref) || schema.length <= 1 + while (!broken && i < schema.length) { if (ctx.javaType(schema.fields(i).dataType) != ref) { broken = true } - i +=1 + i += 1 } !broken } - val allFields = if (isHomogenousStruct){ + val allFields = if (isHomogenousStruct) { val counter = ctx.freshName("counter") val converter = convertToSafe(ctx, ctx.getValue(tmp, schema.fields(0).dataType, counter), schema.fields(0).dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6e76d69625ba4..72940d7b8ab0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -112,16 +112,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val isHomogenousStruct = { var i = 1 val ref = ctx.javaType(t.fields(0).dataType) - var broken = false || !ctx.isPrimitiveType(ref) || t.length <=1 - while( !broken && i < t.length) { - if (ctx.javaType(t.fields(i).dataType) != ref){ + var broken = !ctx.isPrimitiveType(ref) || t.length <= 1 + while (!broken && i < t.length) { + if (ctx.javaType(t.fields(i).dataType) != ref) { broken = true } - i +=1 + i += 1 } !broken } - if(isHomogenousStruct) { + if (isHomogenousStruct) { val counter = ctx.freshName("counter") val rowWriterChild = ctx.freshName("rowWriterChild") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index e6e01a4a7479d..8baf0bc6a879c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -20,18 +20,21 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat import java.util.Locale +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo -import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, AccumulatorV2Kryo, Utils} -final class SQLMetric(val metricType: String, initValue: Long = 0L) - extends AccumulatorV2[Long, Long] { +final class SQLMetric(var metricType: String, initValue: Long = 0L) + extends AccumulatorV2Kryo[Long, Long] with KryoSerializable { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. - private[this] var _value = initValue + private var _value = initValue private var _zeroValue = initValue override def copy(): SQLMetric = { @@ -61,6 +64,18 @@ final class SQLMetric(val metricType: String, initValue: Long = 0L) new AccumulableInfo( id, name, update, value, true, true, Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) } + + override def writeKryo(kryo: Kryo, output: Output): Unit = { + output.writeString(metricType) + output.writeLong(_value) + output.writeLong(_zeroValue) + } + + override def readKryo(kryo: Kryo, input: Input): Unit = { + metricType = input.readString() + _value = input.readLong() + _zeroValue = input.readLong() + } }
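
The SparkEnv change above makes the closure serializer pluggable through the spark.closure.serializer property, with JavaSerializer remaining the default when the property is unset. A minimal sketch of a job opting into Kryo-serialized closures under this patch — the master, application name, and the choice of KryoSerializer here are illustrative assumptions, not taken from the patch:

    import org.apache.spark.{SparkConf, SparkContext}

    object KryoClosureExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[*]")                 // assumption: run locally for illustration
          .setAppName("kryo-closure-example")
          // Read by SparkEnv.getClosureSerializer in this patch; JavaSerializer is the default.
          .set("spark.closure.serializer", "org.apache.spark.serializer.KryoSerializer")
          // The data serializer is configured separately, as in stock Spark.
          .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        val sc = new SparkContext(conf)
        try {
          // Closures shipped to executors now pass through the configured serializer.
          println(sc.parallelize(1 to 100).map(_ * 2).sum())
        } finally {
          sc.stop()
        }
      }
    }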
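
The metrics, task, and RPC classes touched above all follow the same hand-rolled KryoSerializable convention: write each field in a fixed order using the narrowest suitable primitive, then read the fields back in exactly that order into vars. A minimal sketch of that convention for a hypothetical class — the class name and fields are assumptions for illustration, not part of the patch:

    import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
    import com.esotericsoftware.kryo.io.{Input, Output}

    // Hypothetical example class. Fields are vars so read() can repopulate an
    // instance created by Kryo, mirroring the classes changed in this patch.
    class CachedBlockStatus(private var rddId: Int, private var hostPort: String)
        extends Serializable with KryoSerializable {

      // Zero-arg constructor so Kryo can instantiate the class before calling read().
      def this() = this(0, null)

      override def write(kryo: Kryo, output: Output): Unit = {
        output.writeVarInt(rddId, true) // variable-length encoding keeps small ids compact
        output.writeString(hostPort)    // writeString tolerates null
      }

      override def read(kryo: Kryo, input: Input): Unit = {
        rddId = input.readVarInt(true)  // must mirror the write order and encodings exactly
        hostPort = input.readString()
      }
    }

Round-tripping such an instance through kryo.writeClassAndObject and kryo.readClassAndObject exercises exactly this write/read pair.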