Commit 6c2be93: Remove GraphX MessageToPartition for compatibility with sort-based shuffle

MessageToPartition was used in `Graph#partitionBy`. Unlike a Tuple2, it marked the key as transient to avoid sending it over the network. However, it was incompatible with sort-based shuffle (SPARK-2045) and represented only a minor optimization: for partitionBy, it improved performance by 6.3% (30.4 s to 28.5 s) and reduced communication by 5.6% (114.2 MB to 107.8 MB).
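To illustrate the incompatibility (a minimal standalone sketch, not Spark code — the class and object names below are invented for this example): hash-based shuffle routed each record by its key before serializing it, so a @transient key that was dropped on the wire was never missed. Sort-based shuffle serializes records, may spill and re-read them, and then sorts them by partition ID, so the key must survive a serialization round trip:

import java.io._

// Illustrative stand-in for the removed MessageToPartition pattern:
// a Product2 whose key (_1) is @transient and therefore never serialized.
class KeyedMsg[T](@transient var partition: Int, var data: T)
  extends Product2[Int, T] with Serializable {
  override def _1: Int = partition
  override def _2: T = data
  override def canEqual(that: Any): Boolean = that.isInstanceOf[KeyedMsg[_]]
}

object TransientKeyDemo extends App {
  val bout = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bout)
  out.writeObject(new KeyedMsg(3, "payload"))
  out.close()

  val in = new ObjectInputStream(new ByteArrayInputStream(bout.toByteArray))
  val back = in.readObject().asInstanceOf[KeyedMsg[String]]

  // Hash-based shuffle never needs the key again after routing, so losing it
  // was harmless. Sort-based shuffle sorts records by partition ID after
  // (de)serialization, and here the ID is gone:
  println(back._1) // prints 0, not 3
}

A plain Tuple2, whose key is an ordinary serialized field, has no such problem, which is why the commit substitutes one in GraphImpl below.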

Author: Ankur Dave <[email protected]>

Closes #1537 from ankurdave/remove-MessageToPartition and squashes the following commits:

f9d0054 [Ankur Dave] Remove MessageToPartition
ab71364 [Ankur Dave] Remove unused VertexBroadcastMsg
ankurdave authored and rxin committed Jul 23, 2014
1 parent 02e4572 commit 6c2be93
Showing 5 changed files with 2 additions and 228 deletions.
GraphKryoRegistrator.scala
@@ -35,8 +35,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
 
   def registerClasses(kryo: Kryo) {
     kryo.register(classOf[Edge[Object]])
-    kryo.register(classOf[MessageToPartition[Object]])
-    kryo.register(classOf[VertexBroadcastMsg[Object]])
     kryo.register(classOf[RoutingTableMessage])
     kryo.register(classOf[(VertexId, Object)])
     kryo.register(classOf[EdgePartition[Object, Object]])
GraphImpl.scala
@@ -26,7 +26,6 @@ import org.apache.spark.storage.StorageLevel
 
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl._
-import org.apache.spark.graphx.impl.MsgRDDFunctions._
 import org.apache.spark.graphx.util.BytecodeUtils
 
@@ -83,15 +82,13 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
     val vdTag = classTag[VD]
     val newEdges = edges.withPartitionsRDD(edges.map { e =>
       val part: PartitionID = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions)
-
-      // Should we be using 3-tuple or an optimized class
-      new MessageToPartition(part, (e.srcId, e.dstId, e.attr))
+      (part, (e.srcId, e.dstId, e.attr))
     }
       .partitionBy(new HashPartitioner(numPartitions))
       .mapPartitionsWithIndex( { (pid, iter) =>
         val builder = new EdgePartitionBuilder[ED, VD]()(edTag, vdTag)
         iter.foreach { message =>
-          val data = message.data
+          val data = message._2
           builder.add(data._1, data._2, data._3)
         }
         val edgePartition = builder.toEdgePartition
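The replacement path above uses nothing but ordinary pair-RDD operations, which sort-based shuffle supports directly. A rough self-contained sketch of the same routing idea (assumes an existing SparkContext sc; the modulo partition function is a toy stand-in for PartitionStrategy.getPartition):

import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD functions in Spark 1.x

def routeEdges(sc: SparkContext, numPartitions: Int) = {
  val edges = sc.parallelize(Seq((1L, 2L, 7), (2L, 3L, 9))) // (src, dst, attr)
  edges
    .map { case (src, dst, attr) =>
      // Toy stand-in for PartitionStrategy.getPartition.
      val part = ((src + dst) % numPartitions).toInt
      // A plain Tuple2 key is fully serialized, so it survives the
      // spill-and-sort phases of sort-based shuffle.
      (part, (src, dst, attr))
    }
    .partitionBy(new HashPartitioner(numPartitions))
}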
MessageToPartition.scala
@@ -25,82 +25,6 @@ import org.apache.spark.graphx.{PartitionID, VertexId}
 import org.apache.spark.rdd.{ShuffledRDD, RDD}
 
 
-private[graphx]
-class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T](
-    @transient var partition: PartitionID,
-    var vid: VertexId,
-    var data: T)
-  extends Product2[PartitionID, (VertexId, T)] with Serializable {
-
-  override def _1 = partition
-
-  override def _2 = (vid, data)
-
-  override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexBroadcastMsg[_]]
-}
-
-
-/**
- * A message used to send a specific value to a partition.
- * @param partition index of the target partition.
- * @param data value to send
- */
-private[graphx]
-class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T](
-    @transient var partition: PartitionID,
-    var data: T)
-  extends Product2[PartitionID, T] with Serializable {
-
-  override def _1 = partition
-
-  override def _2 = data
-
-  override def canEqual(that: Any): Boolean = that.isInstanceOf[MessageToPartition[_]]
-}
-
-
-private[graphx]
-class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) {
-  def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
-    val rdd = new ShuffledRDD[PartitionID, (VertexId, T), (VertexId, T), VertexBroadcastMsg[T]](
-      self, partitioner)
-
-    // Set a custom serializer if the data is of int or double type.
-    if (classTag[T] == ClassTag.Int) {
-      rdd.setSerializer(new IntVertexBroadcastMsgSerializer)
-    } else if (classTag[T] == ClassTag.Long) {
-      rdd.setSerializer(new LongVertexBroadcastMsgSerializer)
-    } else if (classTag[T] == ClassTag.Double) {
-      rdd.setSerializer(new DoubleVertexBroadcastMsgSerializer)
-    }
-    rdd
-  }
-}
-
-
-private[graphx]
-class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) {
-
-  /**
-   * Return a copy of the RDD partitioned using the specified partitioner.
-   */
-  def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = {
-    new ShuffledRDD[PartitionID, T, T, MessageToPartition[T]](self, partitioner)
-  }
-
-}
-
-private[graphx]
-object MsgRDDFunctions {
-  implicit def rdd2PartitionRDDFunctions[T: ClassTag](rdd: RDD[MessageToPartition[T]]) = {
-    new MsgRDDFunctions(rdd)
-  }
-
-  implicit def rdd2vertexMessageRDDFunctions[T: ClassTag](rdd: RDD[VertexBroadcastMsg[T]]) = {
-    new VertexBroadcastMsgRDDFunctions(rdd)
-  }
-}
-
 private[graphx]
 class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
   def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
Serializers.scala
@@ -76,78 +76,6 @@ class VertexIdMsgSerializer extends Serializer with Serializable {
   }
 }
 
-/** A special shuffle serializer for VertexBroadcastMessage[Int]. */
-private[graphx]
-class IntVertexBroadcastMsgSerializer extends Serializer with Serializable {
-  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
-    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
-      def writeObject[T: ClassTag](t: T) = {
-        val msg = t.asInstanceOf[VertexBroadcastMsg[Int]]
-        writeVarLong(msg.vid, optimizePositive = false)
-        writeInt(msg.data)
-        this
-      }
-    }
-
-    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
-      override def readObject[T: ClassTag](): T = {
-        val a = readVarLong(optimizePositive = false)
-        val b = readInt()
-        new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T]
-      }
-    }
-  }
-}
-
-/** A special shuffle serializer for VertexBroadcastMessage[Long]. */
-private[graphx]
-class LongVertexBroadcastMsgSerializer extends Serializer with Serializable {
-  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
-    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
-      def writeObject[T: ClassTag](t: T) = {
-        val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
-        writeVarLong(msg.vid, optimizePositive = false)
-        writeLong(msg.data)
-        this
-      }
-    }
-
-    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
-      override def readObject[T: ClassTag](): T = {
-        val a = readVarLong(optimizePositive = false)
-        val b = readLong()
-        new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
-      }
-    }
-  }
-}
-
-/** A special shuffle serializer for VertexBroadcastMessage[Double]. */
-private[graphx]
-class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable {
-  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
-    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
-      def writeObject[T: ClassTag](t: T) = {
-        val msg = t.asInstanceOf[VertexBroadcastMsg[Double]]
-        writeVarLong(msg.vid, optimizePositive = false)
-        writeDouble(msg.data)
-        this
-      }
-    }
-
-    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
-      def readObject[T: ClassTag](): T = {
-        val a = readVarLong(optimizePositive = false)
-        val b = readDouble()
-        new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
-      }
-    }
-  }
-}
-
 /** A special shuffle serializer for AggregationMessage[Int]. */
 private[graphx]
 class IntAggMsgSerializer extends Serializer with Serializable {
SerializerSuite.scala
@@ -26,75 +26,11 @@ import org.scalatest.FunSuite
 
 import org.apache.spark._
 import org.apache.spark.graphx.impl._
-import org.apache.spark.graphx.impl.MsgRDDFunctions._
 import org.apache.spark.serializer.SerializationStream
 
 
 class SerializerSuite extends FunSuite with LocalSparkContext {
 
-  test("IntVertexBroadcastMsgSerializer") {
-    val outMsg = new VertexBroadcastMsg[Int](3, 4, 5)
-    val bout = new ByteArrayOutputStream
-    val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
-    outStrm.writeObject(outMsg)
-    outStrm.writeObject(outMsg)
-    bout.flush()
-    val bin = new ByteArrayInputStream(bout.toByteArray)
-    val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
-    val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject()
-    val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject()
-    assert(outMsg.vid === inMsg1.vid)
-    assert(outMsg.vid === inMsg2.vid)
-    assert(outMsg.data === inMsg1.data)
-    assert(outMsg.data === inMsg2.data)
-
-    intercept[EOFException] {
-      inStrm.readObject()
-    }
-  }
-
-  test("LongVertexBroadcastMsgSerializer") {
-    val outMsg = new VertexBroadcastMsg[Long](3, 4, 5)
-    val bout = new ByteArrayOutputStream
-    val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
-    outStrm.writeObject(outMsg)
-    outStrm.writeObject(outMsg)
-    bout.flush()
-    val bin = new ByteArrayInputStream(bout.toByteArray)
-    val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
-    val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject()
-    val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject()
-    assert(outMsg.vid === inMsg1.vid)
-    assert(outMsg.vid === inMsg2.vid)
-    assert(outMsg.data === inMsg1.data)
-    assert(outMsg.data === inMsg2.data)
-
-    intercept[EOFException] {
-      inStrm.readObject()
-    }
-  }
-
-  test("DoubleVertexBroadcastMsgSerializer") {
-    val outMsg = new VertexBroadcastMsg[Double](3, 4, 5.0)
-    val bout = new ByteArrayOutputStream
-    val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
-    outStrm.writeObject(outMsg)
-    outStrm.writeObject(outMsg)
-    bout.flush()
-    val bin = new ByteArrayInputStream(bout.toByteArray)
-    val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
-    val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject()
-    val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject()
-    assert(outMsg.vid === inMsg1.vid)
-    assert(outMsg.vid === inMsg2.vid)
-    assert(outMsg.data === inMsg1.data)
-    assert(outMsg.data === inMsg2.data)
-
-    intercept[EOFException] {
-      inStrm.readObject()
-    }
-  }
-
   test("IntAggMsgSerializer") {
     val outMsg = (4: VertexId, 5)
     val bout = new ByteArrayOutputStream
@@ -152,15 +88,6 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
     }
   }
 
-  test("TestShuffleVertexBroadcastMsg") {
-    withSpark { sc =>
-      val bmsgs = sc.parallelize(0 until 100, 10).map { pid =>
-        new VertexBroadcastMsg[Int](pid, pid, pid)
-      }
-      bmsgs.partitionBy(new HashPartitioner(3)).collect()
-    }
-  }
-
   test("variable long encoding") {
     def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
       val bout = new ByteArrayOutputStream
