Commit 44f2857

Remove msgpack dependency and switch serialization to Pyrolite, plus some clean up and refactoring

MLnick committed Apr 21, 2014
1 parent c0ebfb6 commit 44f2857
Showing 4 changed files with 109 additions and 117 deletions.
41 changes: 41 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -17,13 +17,20 @@

package org.apache.spark.api.python

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io._


/**
* Utilities for working with Python objects -> Hadoop-related objects
*/
private[python] object PythonHadoopUtil {

/**
* Convert a Map of properties to a [[org.apache.hadoop.conf.Configuration]]
*/
def mapToConf(map: java.util.Map[String, String]) = {
import collection.JavaConversions._
val conf = new Configuration()
@@ -42,4 +49,38 @@ private[python] object PythonHadoopUtil {
copy
}

/**
* Converts an RDD of key-value pairs, where key and/or value could be instances of
* [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)]
*/
def convertRDD[K, V](rdd: RDD[(K, V)]) = {
rdd.map{
case (k: Writable, v: Writable) => (convert(k).asInstanceOf[K], convert(v).asInstanceOf[V])
case (k: Writable, v) => (convert(k).asInstanceOf[K], v.asInstanceOf[V])
case (k, v: Writable) => (k.asInstanceOf[K], convert(v).asInstanceOf[V])
case (k, v) => (k.asInstanceOf[K], v.asInstanceOf[V])
}
}

/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
* object representation
*/
private def convert(writable: Writable): Any = {
import collection.JavaConversions._
writable match {
case iw: IntWritable => SparkContext.intWritableConverter().convert(iw)
case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw)
case lw: LongWritable => SparkContext.longWritableConverter().convert(lw)
case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw)
case t: Text => SparkContext.stringWritableConverter().convert(t)
case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw)
case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw)
case n: NullWritable => null
case aw: ArrayWritable => aw.get().map(convert(_))
case mw: MapWritable => mapAsJavaMap(mw.map{ case (k, v) => (convert(k), convert(v)) }.toMap)
case other => other
}
}

}
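
For illustration, the Writable unwrapping that convert performs can be exercised outside Spark with a few lines of standalone Scala. This is only a sketch of the idea, not part of the commit: the object and method names below are ours, and hadoop-common on the classpath is assumed.

import org.apache.hadoop.io.{IntWritable, NullWritable, Text, Writable}

object WritableConversionSketch {
  // Unwrap a handful of common Writables to plain JVM values, mirroring the idea
  // behind PythonHadoopUtil.convert (which additionally handles arrays and maps).
  def toPrimitive(w: Writable): Any = w match {
    case iw: IntWritable => iw.get()      // Int
    case t: Text         => t.toString    // String
    case _: NullWritable => null          // no payload
    case other           => other         // anything else passes through unchanged
  }

  def main(args: Array[String]): Unit = {
    val pairs: Seq[(Writable, Writable)] =
      Seq((new Text("alpha"), new IntWritable(1)), (new Text("beta"), new IntWritable(2)))
    // Once unwrapped, the pairs hold plain types that Pyrolite can pickle for Python.
    println(pairs.map { case (k, v) => (toPrimitive(k), toPrimitive(v)) })
  }
}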
37 changes: 19 additions & 18 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -211,7 +211,7 @@ private object SpecialLengths {
val TIMING_DATA = -3
}

private[spark] object PythonRDD {
private[spark] object PythonRDD extends Logging {
val UTF8 = Charset.forName("UTF-8")

def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -273,20 +273,21 @@ private[spark] object PythonRDD {
}

/** Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]] */
def sequenceFile[K, V](sc: JavaSparkContext,
path: String,
keyClass: String,
valueClass: String,
keyWrapper: String,
valueWrapper: String,
minSplits: Int) = {
def sequenceFile[K, V](
sc: JavaSparkContext,
path: String,
keyClass: String,
valueClass: String,
keyWrapper: String,
valueWrapper: String,
minSplits: Int) = {
implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val converted = SerDeUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted))
val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
}

/**
@@ -308,8 +309,8 @@ private[spark] object PythonRDD {
val rdd =
newAPIHadoopRDDFromClassNames[K, V, F](sc,
Some(path), inputFormatClazz, keyClazz, valueClazz, mergedConf)
val converted = SerDeUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted))
val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
}

/**
@@ -329,8 +330,8 @@ private[spark] object PythonRDD {
val rdd =
newAPIHadoopRDDFromClassNames[K, V, F](sc,
None, inputFormatClazz, keyClazz, valueClazz, conf)
val converted = SerDeUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted))
val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
}

private def newAPIHadoopRDDFromClassNames[K, V, F <: NewInputFormat[K, V]](
@@ -373,8 +374,8 @@ private[spark] object PythonRDD {
val rdd =
hadoopRDDFromClassNames[K, V, F](sc,
Some(path), inputFormatClazz, keyClazz, valueClazz, mergedConf)
val converted = SerDeUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted))
val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
}

/**
@@ -394,8 +395,8 @@ private[spark] object PythonRDD {
val rdd =
hadoopRDDFromClassNames[K, V, F](sc,
None, inputFormatClazz, keyClazz, valueClazz, conf)
val converted = SerDeUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.serMsgPack[K, V](converted))
val converted = PythonHadoopUtil.convertRDD[K, V](rdd)
JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
}

private def hadoopRDDFromClassNames[K, V, F <: InputFormat[K, V]](
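
As a side note, the sequenceFile and hadoop*File methods above recover key and value types from class-name strings sent over from PySpark. A minimal standalone sketch of that Class.forName / ClassTag pattern follows; the helper and object names are illustrative only, and hadoop-common is assumed on the classpath so the example class names resolve.

import scala.reflect.ClassTag

object ClassNameToTagSketch {
  // Build a ClassTag[T] from a fully qualified class name, as sequenceFile does
  // for its keyClass / valueClass string arguments.
  def classTagFor[T](className: String): ClassTag[T] =
    ClassTag(Class.forName(className)).asInstanceOf[ClassTag[T]]

  def main(args: Array[String]): Unit = {
    val kcm = classTagFor[AnyRef]("org.apache.hadoop.io.Text")
    val vcm = classTagFor[AnyRef]("org.apache.hadoop.io.IntWritable")
    println(kcm.runtimeClass)  // class org.apache.hadoop.io.Text
    println(vcm.runtimeClass)  // class org.apache.hadoop.io.IntWritable
  }
}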
145 changes: 48 additions & 97 deletions core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -17,13 +17,12 @@

package org.apache.spark.api.python

import org.msgpack.ScalaMessagePack
import scala.util.Try
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, Logging}
import org.apache.hadoop.io._
import org.apache.spark.Logging
import scala.util.Success
import scala.util.Failure
import net.razorvine.pickle.Pickler

/**
* Utilities for serialization / deserialization between Python and Java, using MsgPack.
@@ -33,106 +32,58 @@ import scala.util.Failure
private[python] object SerDeUtil extends Logging {

/**
* Checks whether a Scala object needs to be registered with MsgPack. String, primitives
* and the standard collections don't need to be registered as MsgPack takes care of serializing
* them and registering them throws scary looking errors (but still works).
* Convert an RDD of key-value pairs to an RDD of serialized Python objects, that is usable
* by PySpark. By default, if serialization fails, toString is called and the string
* representation is serialized
*/
def needsToBeRegistered[T](t: T) = {
t match {
case d: Double => false
case f: Float => false
case i: Int => false
case l: Long => false
case b: Byte => false
case c: Char => false
case bool: Boolean => false
case s: String => false
case m: Map[_, _] => false
case a: Seq[_] => false
case o: Option[_] => false
case _ => true
}
}

/** Attempts to register a class with MsgPack */
def register[T](t: T, msgpack: ScalaMessagePack) {
if (!needsToBeRegistered(t)) {
return
}
val clazz = t.getClass
Try {
msgpack.register(clazz)
log.info(s"Registered key/value class with MsgPack: $clazz")
} match {
case Failure(err) =>
log.warn(s"""Failed to register class ($clazz) with MsgPack.
Falling back to default MsgPack serialization, or 'toString' as last resort.
Error: ${err.getMessage}""")
case Success(result) =>
}
}

/** Serializes an RDD[(K, V)] -> RDD[Array[Byte]] using MsgPack */
def serMsgPack[K, V](rdd: RDD[(K, V)]) = {
import org.msgpack.ScalaMessagePack._
rdd.mapPartitions{ pairs =>
val mp = new ScalaMessagePack
var triedReg = false
pairs.map{ pair =>
Try {
if (!triedReg) {
register(pair._1, mp)
register(pair._2, mp)
triedReg = true
def rddToPython[K, V](rdd: RDD[(K, V)]): RDD[Array[Byte]] = {
rdd.mapPartitions{ iter =>
val pickle = new Pickler
var keyFailed = false
var valueFailed = false
var firstRecord = true
iter.map{ case (k, v) =>
if (firstRecord) {
Try {
pickle.dumps(Array(k, v))
} match {
case Success(b) =>
case Failure(err) =>
val kt = Try {
pickle.dumps(k)
}
val vt = Try {
pickle.dumps(v)
}
(kt, vt) match {
case (Failure(kf), Failure(vf)) =>
log.warn(s"""Failed to pickle Java object as key: ${k.getClass.getSimpleName};
Error: ${kf.getMessage}""")
log.warn(s"""Failed to pickle Java object as value: ${v.getClass.getSimpleName};
Error: ${vf.getMessage}""")
keyFailed = true
valueFailed = true
case (Failure(kf), _) =>
log.warn(s"""Failed to pickle Java object as key: ${k.getClass.getSimpleName};
Error: ${kf.getMessage}""")
keyFailed = true
case (_, Failure(vf)) =>
log.warn(s"""Failed to pickle Java object as value: ${v.getClass.getSimpleName};
Error: ${vf.getMessage}""")
valueFailed = true
}
}
mp.write(pair)
} match {
case Failure(err) =>
log.debug("Failed to write", err)
Try {
write((pair._1.toString, pair._2.toString))
} match {
case Success(result) => result
case Failure(e) => throw e
}
case Success(result) => result
firstRecord = false
}
(keyFailed, valueFailed) match {
case (true, true) => pickle.dumps(Array(k.toString, v.toString))
case (true, false) => pickle.dumps(Array(k.toString, v))
case (false, true) => pickle.dumps(Array(k, v.toString))
case (false, false) => pickle.dumps(Array(k, v))
}
}
}
}

/**
* Converts an RDD of key-value pairs, where key and/or value could be instances of
* [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)]
*/
def convertRDD[K, V](rdd: RDD[(K, V)]) = {
rdd.map{
case (k: Writable, v: Writable) => (convert(k).asInstanceOf[K], convert(v).asInstanceOf[V])
case (k: Writable, v) => (convert(k).asInstanceOf[K], v.asInstanceOf[V])
case (k, v: Writable) => (k.asInstanceOf[K], convert(v).asInstanceOf[V])
case (k, v) => (k.asInstanceOf[K], v.asInstanceOf[V])
}
}

/** Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
* object representation
*/
def convert(writable: Writable): Any = {
import collection.JavaConversions._
writable match {
case iw: IntWritable => SparkContext.intWritableConverter().convert(iw)
case dw: DoubleWritable => SparkContext.doubleWritableConverter().convert(dw)
case lw: LongWritable => SparkContext.longWritableConverter().convert(lw)
case fw: FloatWritable => SparkContext.floatWritableConverter().convert(fw)
case t: Text => SparkContext.stringWritableConverter().convert(t)
case bw: BooleanWritable => SparkContext.booleanWritableConverter().convert(bw)
case byw: BytesWritable => SparkContext.bytesWritableConverter().convert(byw)
case n: NullWritable => None
case aw: ArrayWritable => aw.get().map(convert(_))
case mw: MapWritable => mw.map{ case (k, v) => (convert(k), convert(v)) }.toMap
case other => other
}
}

}
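
The heart of the new rddToPython path — try Pyrolite's Pickler first and fall back to toString when a record cannot be pickled — can be sketched independently of Spark roughly as below. This assumes pyrolite 2.0 on the classpath; the helper and object names are ours, not part of the commit.

import net.razorvine.pickle.Pickler
import scala.util.{Failure, Success, Try}

object PickleFallbackSketch {
  // Pickle a (key, value) pair; if pickling fails, retry with the pair's
  // string representations, mirroring the per-record fallback in rddToPython.
  def pickleOrToString(pickle: Pickler, k: Any, v: Any): Array[Byte] =
    Try(pickle.dumps(Array(k, v))) match {
      case Success(bytes) => bytes
      case Failure(_)     => pickle.dumps(Array[Any](k.toString, v.toString))
    }

  def main(args: Array[String]): Unit = {
    val pickle = new Pickler
    val bytes = pickleOrToString(pickle, "key", 42)
    println(s"pickled ${bytes.length} bytes")  // unpickles on the Python side as the pair ('key', 42)
  }
}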

3 changes: 1 addition & 2 deletions project/SparkBuild.scala
@@ -344,8 +344,7 @@ object SparkBuild extends Build {
"com.twitter" % "chill-java" % chillVersion excludeAll(excludeAsm),
"org.tachyonproject" % "tachyon" % "0.4.1-thrift" excludeAll(excludeHadoop, excludeCurator, excludeEclipseJetty, excludePowermock),
"com.clearspring.analytics" % "stream" % "2.5.1" excludeAll(excludeFastutil),
"org.spark-project" % "pyrolite" % "2.0",
"org.msgpack" %% "msgpack-scala" % "0.6.8"
"org.spark-project" % "pyrolite" % "2.0"
),
libraryDependencies ++= maybeAvro
)
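
Expressed as a plain sbt setting, the surviving dependency is the single Pyrolite artifact; an illustrative build.sbt fragment (not part of the commit) would be just:

// msgpack-scala is dropped; Pyrolite remains the only Python serialization bridge.
libraryDependencies += "org.spark-project" % "pyrolite" % "2.0"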
