diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 75ee730b7f1d3..452794d8e6379 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,13 +17,20 @@ package org.apache.spark.api.python +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io._ + /** * Utilities for working with Python objects -> Hadoop-related objects */ private[python] object PythonHadoopUtil { + /** + * Convert a Map of properties to a [[org.apache.hadoop.conf.Configuration]] + */ def mapToConf(map: java.util.Map[String, String]) = { import collection.JavaConversions._ val conf = new Configuration() @@ -42,4 +49,38 @@ private[python] object PythonHadoopUtil { copy } + /** + * Converts an RDD of key-value pairs, where key and/or value could be instances of + * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)] + */ + def convertRDD[K, V](rdd: RDD[(K, V)]) = { + rdd.map{ + case (k: Writable, v: Writable) => (convert(k).asInstanceOf[K], convert(v).asInstanceOf[V]) + case (k: Writable, v) => (convert(k).asInstanceOf[K], v.asInstanceOf[V]) + case (k, v: Writable) => (k.asInstanceOf[K], convert(v).asInstanceOf[V]) + case (k, v) => (k.asInstanceOf[K], v.asInstanceOf[V]) + } + } + + /** + * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or + * object representation + */ + private def convert(writable: Writable): Any = { + import collection.JavaConversions._ + writable match { + case iw: IntWritable => SparkContext.intWritableConverter().convert(iw) + case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw) + case lw: LongWritable => SparkContext.longWritableConverter().convert(lw) + case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw) + case t: Text => SparkContext.stringWritableConverter().convert(t) + case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw) + case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw) + case n: NullWritable => null + case aw: ArrayWritable => aw.get().map(convert(_)) + case mw: MapWritable => mapAsJavaMap(mw.map{ case (k, v) => (convert(k), convert(v)) }.toMap) + case other => other + } + } + } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8cabbfe05e1c9..7957d6340ea2f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -211,7 +211,7 @@ private object SpecialLengths { val TIMING_DATA = -3 } -private[spark] object PythonRDD { +private[spark] object PythonRDD extends Logging { val UTF8 = Charset.forName("UTF-8") def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): @@ -273,20 +273,21 @@ private[spark] object PythonRDD { } /** Create and RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]] */ - def sequenceFile[K, V](sc: JavaSparkContext, - path: String, - keyClass: String, - valueClass: String, - keyWrapper: String, - valueWrapper: String, - minSplits: Int) = { + def sequenceFile[K, V]( + sc: JavaSparkContext, + path: String, + keyClass: String, + valueClass: String, + keyWrapper: String, + valueWrapper: String, + minSplits: 
Int) = { implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] val kc = kcm.runtimeClass.asInstanceOf[Class[K]] val vc = vcm.runtimeClass.asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val converted = SerDeUtil.convertRDD[K, V](rdd) - JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } /** @@ -308,8 +309,8 @@ private[spark] object PythonRDD { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClazz, keyClazz, valueClazz, mergedConf) - val converted = SerDeUtil.convertRDD[K, V](rdd) - JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } /** @@ -329,8 +330,8 @@ private[spark] object PythonRDD { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClazz, keyClazz, valueClazz, conf) - val converted = SerDeUtil.convertRDD[K, V](rdd) - JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } private def newAPIHadoopRDDFromClassNames[K, V, F <: NewInputFormat[K, V]]( @@ -373,8 +374,8 @@ private[spark] object PythonRDD { val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClazz, keyClazz, valueClazz, mergedConf) - val converted = SerDeUtil.convertRDD[K, V](rdd) - JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } /** @@ -394,8 +395,8 @@ private[spark] object PythonRDD { val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClazz, keyClazz, valueClazz, conf) - val converted = SerDeUtil.convertRDD[K, V](rdd) - JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) } private def hadoopRDDFromClassNames[K, V, F <: InputFormat[K, V]]( diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 58a67a062dfde..ed265008a8c2b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -17,13 +17,12 @@ package org.apache.spark.api.python -import org.msgpack.ScalaMessagePack import scala.util.Try import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkContext, Logging} -import org.apache.hadoop.io._ +import org.apache.spark.Logging import scala.util.Success import scala.util.Failure +import net.razorvine.pickle.Pickler /** * Utilities for serialization / deserialization between Python and Java, using MsgPack. @@ -33,106 +32,58 @@ import scala.util.Failure private[python] object SerDeUtil extends Logging { /** - * Checks whether a Scala object needs to be registered with MsgPack. String, primitives - * and the standard collections don't need to be registered as MsgPack takes care of serializing - * them and registering them throws scary looking errors (but still works). + * Convert an RDD of key-value pairs to an RDD of serialized Python objects, that is usable + * by PySpark. 
By default, if serialization fails, toString is called and the string + * representation is serialized */ - def needsToBeRegistered[T](t: T) = { - t match { - case d: Double => false - case f: Float => false - case i: Int => false - case l: Long => false - case b: Byte => false - case c: Char => false - case bool: Boolean => false - case s: String => false - case m: Map[_, _] => false - case a: Seq[_] => false - case o: Option[_] => false - case _ => true - } - } - - /** Attempts to register a class with MsgPack */ - def register[T](t: T, msgpack: ScalaMessagePack) { - if (!needsToBeRegistered(t)) { - return - } - val clazz = t.getClass - Try { - msgpack.register(clazz) - log.info(s"Registered key/value class with MsgPack: $clazz") - } match { - case Failure(err) => - log.warn(s"""Failed to register class ($clazz) with MsgPack. - Falling back to default MsgPack serialization, or 'toString' as last resort. - Error: ${err.getMessage}""") - case Success(result) => - } - } - - /** Serializes an RDD[(K, V)] -> RDD[Array[Byte]] using MsgPack */ - def serMsgPack[K, V](rdd: RDD[(K, V)]) = { - import org.msgpack.ScalaMessagePack._ - rdd.mapPartitions{ pairs => - val mp = new ScalaMessagePack - var triedReg = false - pairs.map{ pair => - Try { - if (!triedReg) { - register(pair._1, mp) - register(pair._2, mp) - triedReg = true + def rddToPython[K, V](rdd: RDD[(K, V)]): RDD[Array[Byte]] = { + rdd.mapPartitions{ iter => + val pickle = new Pickler + var keyFailed = false + var valueFailed = false + var firstRecord = true + iter.map{ case (k, v) => + if (firstRecord) { + Try { + pickle.dumps(Array(k, v)) + } match { + case Success(b) => + case Failure(err) => + val kt = Try { + pickle.dumps(k) + } + val vt = Try { + pickle.dumps(v) + } + (kt, vt) match { + case (Failure(kf), Failure(vf)) => + log.warn(s"""Failed to pickle Java object as key: ${k.getClass.getSimpleName}; + Error: ${kf.getMessage}""") + log.warn(s"""Failed to pickle Java object as value: ${v.getClass.getSimpleName}; + Error: ${vf.getMessage}""") + keyFailed = true + valueFailed = true + case (Failure(kf), _) => + log.warn(s"""Failed to pickle Java object as key: ${k.getClass.getSimpleName}; + Error: ${kf.getMessage}""") + keyFailed = true + case (_, Failure(vf)) => + log.warn(s"""Failed to pickle Java object as value: ${v.getClass.getSimpleName}; + Error: ${vf.getMessage}""") + valueFailed = true + } } - mp.write(pair) - } match { - case Failure(err) => - log.debug("Failed to write", err) - Try { - write((pair._1.toString, pair._2.toString)) - } match { - case Success(result) => result - case Failure(e) => throw e - } - case Success(result) => result + firstRecord = false + } + (keyFailed, valueFailed) match { + case (true, true) => pickle.dumps(Array(k.toString, v.toString)) + case (true, false) => pickle.dumps(Array(k.toString, v)) + case (false, true) => pickle.dumps(Array(k, v.toString)) + case (false, false) => pickle.dumps(Array(k, v)) } } } } - /** - * Converts an RDD of key-value pairs, where key and/or value could be instances of - * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)] - */ - def convertRDD[K, V](rdd: RDD[(K, V)]) = { - rdd.map{ - case (k: Writable, v: Writable) => (convert(k).asInstanceOf[K], convert(v).asInstanceOf[V]) - case (k: Writable, v) => (convert(k).asInstanceOf[K], v.asInstanceOf[V]) - case (k, v: Writable) => (k.asInstanceOf[K], convert(v).asInstanceOf[V]) - case (k, v) => (k.asInstanceOf[K], v.asInstanceOf[V]) - } - } - - /** Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, 
String or - * object representation - */ - def convert(writable: Writable): Any = { - import collection.JavaConversions._ - writable match { - case iw: IntWritable => SparkContext.intWritableConverter().convert(iw) - case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw) - case lw: LongWritable => SparkContext.longWritableConverter().convert(lw) - case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw) - case t: Text => SparkContext.stringWritableConverter().convert(t) - case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw) - case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw) - case n: NullWritable => None - case aw: ArrayWritable => aw.get().map(convert(_)) - case mw: MapWritable => mw.map{ case (k, v) => (convert(k), convert(v)) }.toMap - case other => other - } - } - } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b89b6f760e302..33f9d644ca66d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -344,8 +344,7 @@ object SparkBuild extends Build { "com.twitter" % "chill-java" % chillVersion excludeAll(excludeAsm), "org.tachyonproject" % "tachyon" % "0.4.1-thrift" excludeAll(excludeHadoop, excludeCurator, excludeEclipseJetty, excludePowermock), "com.clearspring.analytics" % "stream" % "2.5.1" excludeAll(excludeFastutil), - "org.spark-project" % "pyrolite" % "2.0", - "org.msgpack" %% "msgpack-scala" % "0.6.8" + "org.spark-project" % "pyrolite" % "2.0" ), libraryDependencies ++= maybeAvro )
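For reference, here is a minimal end-to-end sketch of how the two helpers introduced above fit together on the JVM side, roughly what PythonRDD.sequenceFile now does for a SequenceFile of (IntWritable, Text) pairs. It is not part of the patch: the path and object name are placeholders, and it sits in the org.apache.spark.api.python package only because PythonHadoopUtil and SerDeUtil are private[python].

// Minimal sketch (not part of the patch), assuming a placeholder SequenceFile at
// /tmp/pairs.seq with IntWritable keys and Text values.
package org.apache.spark.api.python

import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD

object SequenceFileToPythonSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "sequencefile-sketch")

    // Load the SequenceFile with its Writable key/value classes.
    val rdd: RDD[(IntWritable, Text)] =
      sc.sequenceFile("/tmp/pairs.seq", classOf[IntWritable], classOf[Text], 2)

    // 1. Unwrap the Writables into plain values (Int keys, String values).
    val converted = PythonHadoopUtil.convertRDD(rdd)

    // 2. Pickle each (key, value) pair into bytes PySpark can read; pairs whose
    //    key or value cannot be pickled fall back to their toString form.
    val pickled: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))

    println(s"pickled ${pickled.count()} records")
    sc.stop()
  }
}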
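A note on the design the sketch relies on: swapping MsgPack for Pyrolite's Pickler means rddToPython hands PySpark ordinary pickle bytes, which is why msgpack-scala is dropped from SparkBuild.scala while the pyrolite dependency stays. The toString fallback is decided once per partition: the first record's key and value are test-pickled, the keyFailed/valueFailed flags are set from the outcome, and every later record in that partition reuses that decision instead of retrying the pickling per record.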