Commit

Fixed serialization error: SparkConf cannot be serialized
Joseph Batchik committed Jun 26, 2015
1 parent 2b545cc commit f4ae251
Showing 4 changed files with 26 additions and 19 deletions.
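
For context on the one-line summary: at this point in Spark's history, `SparkConf` extends only `Cloneable` and `Logging` (see the hunk headers below), so it is not `java.io.Serializable`, and a serializer that holds a reference to it cannot be shipped to executors. The fix threads a plain `Map[Long, String]` of schema fingerprints to schema strings through instead. A minimal sketch of the distinction, not part of the commit:

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

import org.apache.spark.SparkConf

object SerializabilityCheck {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
    val schemas: Map[Long, String] = conf.getAvroSchema // plain, serializable snapshot

    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    out.writeObject(schemas) // fine: an immutable Map of Longs and Strings
    // out.writeObject(conf) // would throw java.io.NotSerializableException here
  }
}
```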
10 changes: 8 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.LinkedHashSet
 
 import org.apache.avro.{Schema, SchemaNormalization}
 
-import org.apache.spark.serializer.GenericAvroSerializer.avroSchemaKey
+import org.apache.spark.serializer.GenericAvroSerializer.{avroSchemaNamespace, avroSchemaKey}
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.util.Utils
 
@@ -162,7 +162,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
     set("spark.serializer", classOf[KryoSerializer].getName)
     this
   }
-
   /**
    * Use Kryo serialization and register the given set of Avro schemas so that the generic
    * record serializer can decrease network IO
@@ -172,6 +171,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
     conf.set(avroSchemaKey(SchemaNormalization.parsingFingerprint64(schema)), schema.toString)
   }
 
+  /** Gets all the avro schemas in the configuration used in the generic Avro record serializer */
+  def getAvroSchema: Map[Long, String] = {
+    getAll.filter { case (k, v) => k.startsWith(avroSchemaNamespace) }
+      .map { case (k, v) => (k.substring(avroSchemaNamespace.length).toLong, v) }
+      .toMap
+  }
+
   /** Remove a parameter from the configuration */
   def remove(key: String): SparkConf = {
     settings.remove(key)
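
The new `getAvroSchema` is the read side of the existing `registerAvroSchema`: registration stores each schema under `avro.schema.<fingerprint>`, and the getter filters those keys back out of the conf. A round-trip sketch against this API; the schema literal is invented for illustration:

```scala
import org.apache.avro.{Schema, SchemaNormalization}

import org.apache.spark.SparkConf

// Any Avro record schema works; this one mirrors the test suite's shape.
val schema: Schema = new Schema.Parser().parse(
  """{"type": "record", "name": "test", "fields": [{"name": "data", "type": "string"}]}""")

val conf = new SparkConf(loadDefaults = false)
conf.registerAvroSchema(Array(schema))

// The getter strips the "avro.schema." prefix and parses the fingerprint back to a Long.
val schemas: Map[Long, String] = conf.getAvroSchema
assert(schemas(SchemaNormalization.parsingFingerprint64(schema)) == schema.toString)
```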
core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -29,12 +29,9 @@ import org.apache.avro.generic.{GenericData, GenericRecord}
 import org.apache.avro.io._
 import org.apache.avro.{Schema, SchemaNormalization}
 
-import org.apache.spark.SparkConf
-
-import GenericAvroSerializer._
-
 object GenericAvroSerializer {
-  def avroSchemaKey(fingerprint: Long): String = s"avro.schema.$fingerprint"
+  val avroSchemaNamespace = "avro.schema."
+  def avroSchemaKey(fingerprint: Long): String = avroSchemaNamespace + fingerprint
 }
 
 /**
@@ -44,7 +41,7 @@ object GenericAvroSerializer {
  * Actions like parsing or compressing schemas are computationally expensive so the serializer
  * caches all previously seen values as to reduce the amount of work needed to do.
  */
-class GenericAvroSerializer(conf: SparkConf) extends KSerializer[GenericRecord] {
+class GenericAvroSerializer(schemas: Map[Long, String]) extends KSerializer[GenericRecord] {
 
   /** Used to reduce the amount of effort to compress the schema */
   private val compressCache = new mutable.HashMap[Schema, Array[Byte]]()
@@ -58,7 +55,7 @@ class GenericAvroSerializer(conf: SparkConf) extends KSerializer[GenericRecord]
   private val fingerprintCache = new mutable.HashMap[Schema, Long]()
   private val schemaCache = new mutable.HashMap[Long, Schema]()
 
-  private def confSchema(fingerprint: Long) = conf.getOption(avroSchemaKey(fingerprint))
+  private def getSchema(fingerprint: Long): Option[String] = schemas.get(fingerprint)
 
   /**
    * Used to compress Schemas when they are being sent over the wire.
@@ -110,7 +107,7 @@ class GenericAvroSerializer(conf: SparkConf) extends KSerializer[GenericRecord]
     val fingerprint = fingerprintCache.getOrElseUpdate(schema, {
       SchemaNormalization.parsingFingerprint64(schema)
     })
-    confSchema(fingerprint) match {
+    getSchema(fingerprint) match {
       case Some(_) => {
         output.writeBoolean(true)
         output.writeLong(fingerprint)
@@ -139,7 +136,7 @@ class GenericAvroSerializer(conf: SparkConf) extends KSerializer[GenericRecord]
     if (input.readBoolean()) {
       val fingerprint = input.readLong()
       schemaCache.getOrElseUpdate(fingerprint, {
-        confSchema(fingerprint) match {
+        getSchema(fingerprint) match {
          case Some(s) => new Schema.Parser().parse(s)
          case None => throw new RuntimeException(s"Unknown fingerprint: $fingerprint")
        }
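
The class doc above flags parsing and compression as computationally expensive, which is why every lookup funnels through a `getOrElseUpdate` cache. A stripped-down sketch of just that memoization pattern (the `SchemaCaches` name is hypothetical; the compression half is omitted):

```scala
import scala.collection.mutable

import org.apache.avro.{Schema, SchemaNormalization}

// Fingerprints and parsed schemas are memoized so that repeated records
// sharing a schema never redo the expensive normalization or parsing.
class SchemaCaches(schemas: Map[Long, String]) {
  private val fingerprintCache = new mutable.HashMap[Schema, Long]()
  private val schemaCache = new mutable.HashMap[Long, Schema]()

  def fingerprintOf(schema: Schema): Long =
    fingerprintCache.getOrElseUpdate(schema,
      SchemaNormalization.parsingFingerprint64(schema))

  def schemaFor(fingerprint: Long): Schema =
    schemaCache.getOrElseUpdate(fingerprint, schemas.get(fingerprint) match {
      case Some(s) => new Schema.Parser().parse(s)
      case None => throw new RuntimeException(s"Unknown fingerprint: $fingerprint")
    })
}
```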
core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -74,6 +74,9 @@ class KryoSerializer(conf: SparkConf)
   private val classesToRegister = conf.get("spark.kryo.classesToRegister", "")
     .split(',')
     .filter(!_.isEmpty)
+  conf.getExecutorEnv
+
+  private val avroSchemas = conf.getAvroSchema
 
   def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
 
@@ -101,8 +104,8 @@ class KryoSerializer(conf: SparkConf)
     kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
     kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
 
-    kryo.register(classOf[GenericRecord], new GenericAvroSerializer(conf))
-    kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(conf))
+    kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
+    kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
 
     try {
       // Use the default classloader when calling the user registrator.
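
Because `avroSchemas` is captured from the conf once, at construction, the Kryo registration above never touches `SparkConf` again. From the application side the flow would look roughly like this; the app name and schema are placeholders:

```scala
import org.apache.avro.Schema

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

val schema: Schema = new Schema.Parser().parse(
  """{"type": "record", "name": "test", "fields": [{"name": "data", "type": "string"}]}""")

val conf = new SparkConf()
  .setAppName("avro-kryo-example") // placeholder
  .set("spark.serializer", classOf[KryoSerializer].getName)

// Executors resolve the 64-bit fingerprint locally from this registration
// instead of receiving the full schema with every record.
conf.registerAvroSchema(Array(schema))
```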
core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
@@ -36,12 +36,12 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   record.put("data", "test data")
 
   test("schema compression and decompression") {
-    val genericSer = new GenericAvroSerializer(conf)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
     assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
   }
 
   test("record serialization and deserialization") {
-    val genericSer = new GenericAvroSerializer(conf)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
 
     val outputStream = new ByteArrayOutputStream()
     val output = new Output(outputStream)
@@ -54,25 +54,26 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   }
 
   test("uses schema fingerprint to decrease message size") {
-    val genericSer = new GenericAvroSerializer(conf)
+    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)
 
     val output = new Output(new ByteArrayOutputStream())
 
     val beginningNormalPosition = output.total()
-    genericSer.serializeDatum(record, output)
+    genericSerFull.serializeDatum(record, output)
     output.flush()
     val normalLength = output.total - beginningNormalPosition
 
     conf.registerAvroSchema(Array(schema))
+    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
     val beginningFingerprintPosition = output.total()
-    genericSer.serializeDatum(record, output)
+    genericSerFinger.serializeDatum(record, output)
     val fingerprintLength = output.total - beginningFingerprintPosition
 
     assert(fingerprintLength < normalLength)
   }
 
   test("caches previously seen schemas") {
-    val genericSer = new GenericAvroSerializer(conf)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
     val compressedSchema = genericSer.compress(schema)
     val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema))
 
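
The fingerprint test above is the payoff of the whole mechanism: once a schema is registered, `serializeDatum` writes a boolean flag plus an 8-byte fingerprint rather than the compressed schema. A condensed sketch of that comparison, reusing the suite's `conf`, `schema`, and `record`:

```scala
import java.io.ByteArrayOutputStream

import com.esotericsoftware.kryo.io.Output

val output = new Output(new ByteArrayOutputStream())

// No registered schemas: the full (compressed) schema precedes each datum.
val fullSer = new GenericAvroSerializer(Map.empty[Long, String])
fullSer.serializeDatum(record, output)
output.flush()
val fullLength = output.total()

// Registered schema: only the 64-bit fingerprint precedes each datum.
conf.registerAvroSchema(Array(schema))
val fingerSer = new GenericAvroSerializer(conf.getAvroSchema)
fingerSer.serializeDatum(record, output)
output.flush()
assert(output.total() - fullLength < fullLength)
```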
