[SNAP-1136] Kryo closure serialization support and optimizations #27

Merged · 5 commits · Nov 28, 2016
Changes from all commits
3 changes: 2 additions & 1 deletion build.gradle
@@ -61,7 +61,8 @@ allprojects {
javaxServletVersion = '3.1.0'
guavaVersion = '14.0.1'
hiveVersion = '1.2.1.spark2'
chillVersion = '0.8.0'
chillVersion = '0.8.1'
kryoVersion = '4.0.0'
nettyVersion = '3.8.0.Final'
nettyAllVersion = '4.0.29.Final'
derbyVersion = '10.12.1.1'
5 changes: 4 additions & 1 deletion common/unsafe/build.gradle
@@ -20,7 +20,10 @@ description = 'Spark Project Unsafe'
dependencies {
compile project(subprojectBase + 'snappy-spark-tags_' + scalaBinaryVersion)

compile group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion
compile group: 'com.esotericsoftware', name: 'kryo-shaded', version: kryoVersion
compile(group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion) {
exclude(group: 'com.esotericsoftware', module: 'kryo-shaded')
}
compile group: 'com.google.code.findbugs', name: 'jsr305', version: jsr305Version
compile group: 'com.google.guava', name: 'guava', version: guavaVersion

9 changes: 7 additions & 2 deletions core/build.gradle
@@ -40,8 +40,13 @@ dependencies {
exclude(group: 'org.apache.avro', module: 'avro-ipc')
}
compile group: 'com.google.guava', name: 'guava', version: guavaVersion
compile group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion
compile group: 'com.twitter', name: 'chill-java', version: chillVersion
compile group: 'com.esotericsoftware', name: 'kryo-shaded', version: kryoVersion
compile(group: 'com.twitter', name: 'chill_' + scalaBinaryVersion, version: chillVersion) {
exclude(group: 'com.esotericsoftware', module: 'kryo-shaded')
}
compile(group: 'com.twitter', name: 'chill-java', version: chillVersion) {
exclude(group: 'com.esotericsoftware', module: 'kryo-shaded')
}
compile group: 'org.apache.xbean', name: 'xbean-asm5-shaded', version: '4.4'
// explicitly include netty from akka-remote to not let zookeeper override it
compile group: 'io.netty', name: 'netty', version: nettyVersion
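The exclusions above are what keep a single Kryo on the classpath: chill 0.8.x declares its own kryo-shaded dependency (3.0.x), while this change pins kryo-shaded 4.0.0 directly. A small, hypothetical check (not part of this PR) to confirm which jar is actually loaded at runtime:

// Hypothetical sanity check: verify the resolved Kryo comes from the pinned 4.0.0 jar.
import com.esotericsoftware.kryo.Kryo

object KryoVersionCheck {
  def main(args: Array[String]): Unit = {
    val source = classOf[Kryo].getProtectionDomain.getCodeSource
    // With the chill exclusions in place, this should point at kryo-shaded-4.0.0.jar.
    println(s"Kryo loaded from: ${source.getLocation}")
  }
}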
58 changes: 39 additions & 19 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -167,6 +167,43 @@ object SparkEnv extends Logging {
env
}

// Create an instance of the class with the given name, possibly initializing it with our conf
def instantiateClass[T](className: String, conf: SparkConf,
isDriver: Boolean): T = {
val cls = Utils.classForName(className)
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, new java.lang.Boolean(isDriver))
.asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
}
}
}

def getClosureSerializer(conf: SparkConf, doLog: Boolean = false): Serializer = {
val defaultClosureSerializerClass = classOf[JavaSerializer].getName
val closureSerializerClass = conf.get("spark.closure.serializer",
defaultClosureSerializerClass)
val closureSerializer = instantiateClass[Serializer](
closureSerializerClass, conf, isDriver = false)
if (doLog) {
if (closureSerializerClass != defaultClosureSerializerClass) {
logInfo(s"Using non-default closure serializer: $closureSerializerClass")
} else {
logDebug(s"Using closure serializer: $closureSerializerClass")
}
}
closureSerializer
}

/**
* Create a SparkEnv for the driver.
*/
@@ -251,26 +288,9 @@
conf.set("spark.executor.port", rpcEnv.address.port.toString)
}

// Create an instance of the class with the given name, possibly initializing it with our conf
def instantiateClass[T](className: String): T = {
val cls = Utils.classForName(className)
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, new java.lang.Boolean(isDriver))
.asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
}
}
SparkEnv.instantiateClass(className, conf, isDriver)
}

// Create an instance of the class named by the given SparkConf property, or defaultClassName
// if the property is not set, possibly initializing it with our conf
def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = {
@@ -283,7 +303,7 @@

val serializerManager = new SerializerManager(serializer, conf)

val closureSerializer = new JavaSerializer(conf)
val closureSerializer = getClosureSerializer(conf, doLog = true)

def registerOrLookupEndpoint(
name: String, endpointCreator: => RpcEndpoint):
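instantiateClass is hoisted into the SparkEnv companion object so that the new getClosureSerializer helper can use it, and the closure serializer is now resolved from the spark.closure.serializer property instead of being hard-wired to JavaSerializer. A minimal usage sketch (the Kryo serializer class name below is an assumption; this diff does not show which closure-capable implementation the fork ships):

// Sketch only: switching the closure serializer through configuration.
// "org.apache.spark.serializer.KryoSerializer" is a placeholder for whatever
// closure-capable Kryo serializer this fork provides.
import org.apache.spark.{SparkConf, SparkContext}

object KryoClosureDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("kryo-closure-demo")
      .setMaster("local[2]")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.closure.serializer", "org.apache.spark.serializer.KryoSerializer")

    val sc = new SparkContext(conf)
    // Task closures are now serialized by the configured class, instantiated via
    // SparkEnv.instantiateClass using its (SparkConf, isDriver) constructor if present.
    val total = sc.parallelize(1 to 1000).map(_ * 2).sum()
    println(s"sum = $total")
    sc.stop()
  }
}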
20 changes: 19 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -17,6 +17,10 @@

package org.apache.spark.executor

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.LongAccumulator

@@ -39,7 +43,7 @@ object DataReadMethod extends Enumeration with Serializable {
* A collection of accumulators that represents metrics about reading data from external systems.
*/
@DeveloperApi
class InputMetrics private[spark] () extends Serializable {
class InputMetrics private[spark] () extends Serializable with KryoSerializable {
private[executor] val _bytesRead = new LongAccumulator
private[executor] val _recordsRead = new LongAccumulator

@@ -56,4 +60,18 @@ class InputMetrics private[spark] () extends Serializable {
private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v)
private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v)

override def write(kryo: Kryo, output: Output): Unit = {
_bytesRead.write(kryo, output)
_recordsRead.write(kryo, output)
}

override final def read(kryo: Kryo, input: Input): Unit = {
read(kryo, input, context = null)
}

def read(kryo: Kryo, input: Input, context: TaskContext): Unit = {
_bytesRead.read(kryo, input, context)
_recordsRead.read(kryo, input, context)
}
}
core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
@@ -17,6 +17,10 @@

package org.apache.spark.executor

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.LongAccumulator

@@ -38,7 +42,7 @@ object DataWriteMethod extends Enumeration with Serializable {
* A collection of accumulators that represents metrics about writing data to external systems.
*/
@DeveloperApi
class OutputMetrics private[spark] () extends Serializable {
class OutputMetrics private[spark] () extends Serializable with KryoSerializable {
private[executor] val _bytesWritten = new LongAccumulator
private[executor] val _recordsWritten = new LongAccumulator

@@ -54,4 +58,18 @@ class OutputMetrics private[spark] () extends Serializable {

private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v)
private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v)

override def write(kryo: Kryo, output: Output): Unit = {
_bytesWritten.write(kryo, output)
_recordsWritten.write(kryo, output)
}

override final def read(kryo: Kryo, input: Input): Unit = {
read(kryo, input, context = null)
}

def read(kryo: Kryo, input: Input, context: TaskContext): Unit = {
_bytesWritten.read(kryo, input, context)
_recordsWritten.read(kryo, input, context)
}
}
core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
@@ -17,6 +17,10 @@

package org.apache.spark.executor

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.LongAccumulator

@@ -27,7 +31,7 @@ import org.apache.spark.util.LongAccumulator
* Operations are not thread-safe.
*/
@DeveloperApi
class ShuffleReadMetrics private[spark] () extends Serializable {
class ShuffleReadMetrics private[spark] () extends Serializable with KryoSerializable {
private[executor] val _remoteBlocksFetched = new LongAccumulator
private[executor] val _localBlocksFetched = new LongAccumulator
private[executor] val _remoteBytesRead = new LongAccumulator
@@ -111,6 +115,28 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
_recordsRead.add(metric.recordsRead)
}
}

override def write(kryo: Kryo, output: Output): Unit = {
_remoteBlocksFetched.write(kryo, output)
_localBlocksFetched.write(kryo, output)
_remoteBytesRead.write(kryo, output)
_localBytesRead.write(kryo, output)
_fetchWaitTime.write(kryo, output)
_recordsRead.write(kryo, output)
}

override final def read(kryo: Kryo, input: Input): Unit = {
read(kryo, input, context = null)
}

def read(kryo: Kryo, input: Input, context: TaskContext): Unit = {
_remoteBlocksFetched.read(kryo, input, context)
_localBlocksFetched.read(kryo, input, context)
_remoteBytesRead.read(kryo, input, context)
_localBytesRead.read(kryo, input, context)
_fetchWaitTime.read(kryo, input, context)
_recordsRead.read(kryo, input, context)
}
}

/**
core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -17,6 +17,10 @@

package org.apache.spark.executor

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.spark.TaskContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.LongAccumulator

@@ -27,7 +31,7 @@ import org.apache.spark.util.LongAccumulator
* Operations are not thread-safe.
*/
@DeveloperApi
class ShuffleWriteMetrics private[spark] () extends Serializable {
class ShuffleWriteMetrics private[spark] () extends Serializable with KryoSerializable {
private[executor] val _bytesWritten = new LongAccumulator
private[executor] val _recordsWritten = new LongAccumulator
private[executor] val _writeTime = new LongAccumulator
@@ -57,6 +61,22 @@ class ShuffleWriteMetrics private[spark] () extends Serializable {
_recordsWritten.setValue(recordsWritten - v)
}

override def write(kryo: Kryo, output: Output): Unit = {
_bytesWritten.write(kryo, output)
_recordsWritten.write(kryo, output)
_writeTime.write(kryo, output)
}

override def read(kryo: Kryo, input: Input): Unit = {
read(kryo, input, context = null)
}

def read(kryo: Kryo, input: Input, context: TaskContext): Unit = {
_bytesWritten.read(kryo, input, context)
_recordsWritten.read(kryo, input, context)
_writeTime.read(kryo, input, context)
}

// Legacy methods for backward compatibility.
// TODO: remove these once we make this class private.
@deprecated("use bytesWritten instead", "2.0.0")
39 changes: 38 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -20,6 +20,9 @@ package org.apache.spark.executor
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
@@ -42,7 +45,7 @@ import org.apache.spark.util._
* be sent to the driver.
*/
@DeveloperApi
class TaskMetrics private[spark] () extends Serializable {
class TaskMetrics private[spark] () extends Serializable with KryoSerializable {
// Each metric is internally represented as an accumulator
private val _executorDeserializeTime = new LongAccumulator
private val _executorRunTime = new LongAccumulator
@@ -241,6 +244,40 @@ class TaskMetrics private[spark] () extends Serializable {
acc.name.isDefined && acc.name.get == name
}
}

override def write(kryo: Kryo, output: Output): Unit = {
_executorDeserializeTime.write(kryo, output)
_executorRunTime.write(kryo, output)
_resultSize.write(kryo, output)
_jvmGCTime.write(kryo, output)
_resultSerializationTime.write(kryo, output)
_memoryBytesSpilled.write(kryo, output)
_diskBytesSpilled.write(kryo, output)
_peakExecutionMemory.write(kryo, output)
_updatedBlockStatuses.write(kryo, output)
inputMetrics.write(kryo, output)
outputMetrics.write(kryo, output)
shuffleReadMetrics.write(kryo, output)
shuffleWriteMetrics.write(kryo, output)
}

override def read(kryo: Kryo, input: Input): Unit = {
// read the TaskContext thread-local once
val taskContext = TaskContext.get()
_executorDeserializeTime.read(kryo, input, taskContext)
_executorRunTime.read(kryo, input, taskContext)
_resultSize.read(kryo, input, taskContext)
_jvmGCTime.read(kryo, input, taskContext)
_resultSerializationTime.read(kryo, input, taskContext)
_memoryBytesSpilled.read(kryo, input, taskContext)
_diskBytesSpilled.read(kryo, input, taskContext)
_peakExecutionMemory.read(kryo, input, taskContext)
_updatedBlockStatuses.read(kryo, input, taskContext)
inputMetrics.read(kryo, input, taskContext)
outputMetrics.read(kryo, input, taskContext)
shuffleReadMetrics.read(kryo, input, taskContext)
shuffleWriteMetrics.read(kryo, input, taskContext)
}
}


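Because the metrics classes implement KryoSerializable, Kryo dispatches to these write/read hooks instead of reflecting over every field, and TaskMetrics.read looks up the TaskContext thread-local once and threads it through the nested accumulators. A minimal round-trip sketch, placed in the org.apache.spark.executor package because the constructor is private[spark], and assuming the LongAccumulator Kryo hooks added elsewhere in this PR:

package org.apache.spark.executor

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

// Hypothetical harness, not part of this PR.
object TaskMetricsKryoRoundTrip {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()

    val metrics = new TaskMetrics()
    metrics.setExecutorRunTime(42L)

    val buffer = new ByteArrayOutputStream()
    val output = new Output(buffer)
    kryo.writeObject(output, metrics)   // dispatches to TaskMetrics.write(kryo, output)
    output.close()

    val input = new Input(new ByteArrayInputStream(buffer.toByteArray))
    // Dispatches to TaskMetrics.read(kryo, input); TaskContext.get() is null outside a task.
    val restored = kryo.readObject(input, classOf[TaskMetrics])
    input.close()

    assert(restored.executorRunTime == 42L)
  }
}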
core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import scala.concurrent.{Future, Promise}
import scala.reflect.ClassTag

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.{SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
@@ -32,7 +32,6 @@ import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher}
import org.apache.spark.network.shuffle.protocol.UploadBlock
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils

@@ -47,7 +46,7 @@ private[spark] class NettyBlockTransferService(
extends BlockTransferService {

// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
private val serializer = new JavaSerializer(conf)
private val serializer = SparkEnv.getClosureSerializer(conf)
private val authEnabled = securityManager.isAuthenticationEnabled()
private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores)

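Swapping the hard-coded JavaSerializer for SparkEnv.getClosureSerializer(conf) means the block-transfer RPC metadata follows the same configurable serializer as closures. The selection pattern it relies on, sketched in isolation (illustrative only):

// Sketch of the serializer-selection pattern NettyBlockTransferService now uses.
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.storage.StorageLevel

object ClosureSerializerSelection {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Resolves spark.closure.serializer, defaulting to JavaSerializer when unset.
    val serializer = SparkEnv.getClosureSerializer(conf)
    // SerializerInstance is not guaranteed thread-safe, so create one per use.
    val instance = serializer.newInstance()

    // Round-trip the kind of metadata uploadBlock ships alongside block data.
    val bytes = instance.serialize(StorageLevel.MEMORY_AND_DISK)
    val restored = instance.deserialize[StorageLevel](bytes)
    assert(restored == StorageLevel.MEMORY_AND_DISK)
  }
}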
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -146,7 +146,8 @@ abstract class RDD[T: ClassTag](
def sparkContext: SparkContext = sc

/** A unique ID for this RDD (within its SparkContext). */
val id: Int = sc.newRddId()
protected var _id: Int = sc.newRddId()
def id: Int = _id

/** A friendly name for this RDD */
@transient var name: String = null
@@ -1627,7 +1628,7 @@ abstract class RDD[T: ClassTag](
// Other internal methods and fields
// =======================================================================

private var storageLevel: StorageLevel = StorageLevel.NONE
protected var storageLevel: StorageLevel = StorageLevel.NONE

/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
@transient private[spark] val creationSite = sc.getCallSite()