From f4ae251283e2418c3adafb26ef5a9b88a4adbb08 Mon Sep 17 00:00:00 2001
From: Joseph Batchik
Date: Thu, 25 Jun 2015 17:01:37 -0700
Subject: [PATCH] Fix serialization error: SparkConf cannot be serialized

---
 .../main/scala/org/apache/spark/SparkConf.scala    | 10 ++++++++--
 .../spark/serializer/GenericAvroSerializer.scala   | 15 ++++++---------
 .../apache/spark/serializer/KryoSerializer.scala   |  6 ++++--
 .../serializer/GenericAvroSerializerSuite.scala    | 13 +++++++------
 4 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index d2b0786cf6aba..a08d843dbc8da 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.LinkedHashSet
 
 import org.apache.avro.{Schema, SchemaNormalization}
 
-import org.apache.spark.serializer.GenericAvroSerializer.avroSchemaKey
+import org.apache.spark.serializer.GenericAvroSerializer.{avroSchemaNamespace, avroSchemaKey}
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.util.Utils
 
@@ -162,7 +162,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
     set("spark.serializer", classOf[KryoSerializer].getName)
     this
   }
-
   /**
    * Use Kryo serialization and register the given set of Avro schemas so that the generic
    * record serializer can decrease network IO
@@ -172,6 +171,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
       conf.set(avroSchemaKey(SchemaNormalization.parsingFingerprint64(schema)), schema.toString)
   }
 
+  /** Gets all the avro schemas in the configuration used in the generic Avro record serializer */
+  def getAvroSchema: Map[Long, String] = {
+    getAll.filter { case (k, v) => k.startsWith(avroSchemaNamespace) }
+      .map { case (k, v) => (k.substring(avroSchemaNamespace.length).toLong, v) }
+      .toMap
+  }
+
   /** Remove a parameter from the configuration */
   def remove(key: String): SparkConf = {
     settings.remove(key)
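Note: registerAvroSchema and the new getAvroSchema round-trip through ordinary
config entries. Each schema is stored under the key avro.schema.<64-bit parsing
fingerprint>, and getAvroSchema strips the avro.schema. prefix back off to
rebuild the fingerprint-to-schema map. A minimal sketch of that round trip,
assuming this patch is applied; the record schema here is illustrative,
modelled on the one the test suite uses:

    import org.apache.avro.{Schema, SchemaNormalization}
    import org.apache.spark.SparkConf

    val schema = new Schema.Parser().parse(
      """{"type": "record", "name": "test", "fields": [{"name": "data", "type": "string"}]}""")

    // loadDefaults = false keeps system properties out of the example
    val conf = new SparkConf(false)
    conf.registerAvroSchema(Array(schema))

    // getAvroSchema keys the map by the schema's 64-bit parsing fingerprint
    val fingerprint = SchemaNormalization.parsingFingerprint64(schema)
    assert(conf.getAvroSchema(fingerprint) == schema.toString)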
diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index 7b0bf461ef7ef..3071b65f9c70c 100644
--- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -29,12 +29,9 @@ import org.apache.avro.generic.{GenericData, GenericRecord}
 import org.apache.avro.io._
 import org.apache.avro.{Schema, SchemaNormalization}
 
-import org.apache.spark.SparkConf
-
-import GenericAvroSerializer._
-
 object GenericAvroSerializer {
-  def avroSchemaKey(fingerprint: Long): String = s"avro.schema.$fingerprint"
+  val avroSchemaNamespace = "avro.schema."
+  def avroSchemaKey(fingerprint: Long): String = avroSchemaNamespace + fingerprint
 }
 
 /**
@@ -44,7 +41,7 @@ object GenericAvroSerializer {
  * Actions like parsing or compressing schemas are computationally expensive so the serializer
  * caches all previously seen values as to reduce the amount of work needed to do.
  */
-class GenericAvroSerializer(conf: SparkConf) extends KSerializer[GenericRecord] {
+class GenericAvroSerializer(schemas: Map[Long, String]) extends KSerializer[GenericRecord] {
 
   /** Used to reduce the amount of effort to compress the schema */
   private val compressCache = new mutable.HashMap[Schema, Array[Byte]]()
@@ -58,7 +55,7 @@ class GenericAvroSerializer(conf: SparkConf) extends KSerializer[GenericRecord]
   private val fingerprintCache = new mutable.HashMap[Schema, Long]()
   private val schemaCache = new mutable.HashMap[Long, Schema]()
 
-  private def confSchema(fingerprint: Long) = conf.getOption(avroSchemaKey(fingerprint))
+  private def getSchema(fingerprint: Long): Option[String] = schemas.get(fingerprint)
 
   /**
    * Used to compress Schemas when they are being sent over the wire.
@@ -110,7 +107,7 @@ class GenericAvroSerializer(conf: SparkConf) extends KSerializer[GenericRecord]
     val fingerprint = fingerprintCache.getOrElseUpdate(schema, {
       SchemaNormalization.parsingFingerprint64(schema)
     })
-    confSchema(fingerprint) match {
+    getSchema(fingerprint) match {
       case Some(_) => {
         output.writeBoolean(true)
         output.writeLong(fingerprint)
@@ -139,7 +136,7 @@ class GenericAvroSerializer(conf: SparkConf) extends KSerializer[GenericRecord]
     if (input.readBoolean()) {
       val fingerprint = input.readLong()
       schemaCache.getOrElseUpdate(fingerprint, {
-        confSchema(fingerprint) match {
+        getSchema(fingerprint) match {
           case Some(s) => new Schema.Parser().parse(s)
           case None => throw new RuntimeException(s"Unknown fingerprint: $fingerprint")
         }
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 1edba877a52ae..a1b7deece3995 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -74,6 +74,8 @@ class KryoSerializer(conf: SparkConf)
   private val classesToRegister = conf.get("spark.kryo.classesToRegister", "")
     .split(',')
    .filter(!_.isEmpty)
+
+  private val avroSchemas = conf.getAvroSchema
 
   def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
 
@@ -101,8 +103,8 @@ class KryoSerializer(conf: SparkConf)
     kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
     kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
 
-    kryo.register(classOf[GenericRecord], new GenericAvroSerializer(conf))
-    kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(conf))
+    kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
+    kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
 
     try {
       // Use the default classloader when calling the user registrator.
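Note: this is the heart of the fix. Per the subject line, SparkConf cannot be
serialized, so a GenericAvroSerializer holding a SparkConf reference fails as
soon as the serializer's state has to travel to executors; an immutable
Map[Long, String] captured once in the KryoSerializer constructor serializes
cleanly. A rough illustration of the difference (the names here are
illustrative, not from the patch):

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    val schemas: Map[Long, String] = Map(1234L -> """{"type": "string"}""")
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    out.writeObject(schemas) // fine: Scala immutable maps are java.io.Serializable
    out.close()
    // writing a SparkConf the same way would throw java.io.NotSerializableException,
    // which is the failure this patch removes from the serializer's state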
diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
index 47dfa0eca1b3c..ffb05a9c79c66 100644
--- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
@@ -36,12 +36,12 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   record.put("data", "test data")
 
   test("schema compression and decompression") {
-    val genericSer = new GenericAvroSerializer(conf)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
     assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
   }
 
   test("record serialization and deserialization") {
-    val genericSer = new GenericAvroSerializer(conf)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
     val outputStream = new ByteArrayOutputStream()
     val output = new Output(outputStream)
 
@@ -54,25 +54,26 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   }
 
   test("uses schema fingerprint to decrease message size") {
-    val genericSer = new GenericAvroSerializer(conf)
+    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)
     val output = new Output(new ByteArrayOutputStream())
 
     val beginningNormalPosition = output.total()
-    genericSer.serializeDatum(record, output)
+    genericSerFull.serializeDatum(record, output)
     output.flush()
     val normalLength = output.total - beginningNormalPosition
 
     conf.registerAvroSchema(Array(schema))
+    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
     val beginningFingerprintPosition = output.total()
-    genericSer.serializeDatum(record, output)
+    genericSerFinger.serializeDatum(record, output)
     val fingerprintLength = output.total - beginningFingerprintPosition
 
     assert(fingerprintLength < normalLength)
   }
 
   test("caches previously seen schemas") {
-    val genericSer = new GenericAvroSerializer(conf)
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
     val compressedSchema = genericSer.compress(schema)
     val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema))
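Note: the reworked fingerprint test also records a behavioural shift. The
serializer now snapshots the schema map at construction time instead of
consulting SparkConf lazily, so a serializer built before registerAvroSchema
never learns the new fingerprint; that is why the test builds genericSerFinger
only after registration. A sketch of the contrast, assuming the suite's conf
and schema are in scope:

    // built before registration: snapshot is empty, records carry the full schema
    val fullSer = new GenericAvroSerializer(conf.getAvroSchema)
    conf.registerAvroSchema(Array(schema))
    // built after registration: snapshot contains the fingerprint mapping,
    // so serializeDatum writes the 8-byte fingerprint instead of the schema text
    val fingerSer = new GenericAvroSerializer(conf.getAvroSchema)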