[SPARK-5062][Graphx] replace mapReduceTriplets with aggregateMessage in Pregel Api #3883

Closed
wants to merge 3 commits into from
54 changes: 53 additions & 1 deletion graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark.SparkException
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

import org.apache.spark.graphx.lib._
@@ -336,6 +335,59 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg)
}

/**
* A variant of [[GraphOps.pregel]] implemented with `aggregateMessages`.
*
* Only the parameter `sendMsg` is different from [[GraphOps.pregel]].
*
* @example for `sendMsg`:
* {{{
* private def sendMessage(ctx: EdgeContext[VD, ED, A]): Unit = {
* // user-defined logic goes here
* ctx.sendToDst(aMsg1)
* ctx.sendToSrc(aMsg2)
* }
* }}}
*
* @tparam A the Pregel message type
* @param graph the input graph.
* @param initialMsg the message each vertex will receive on
* the first iteration
* @param maxIterations the maximum number of iterations to run for
*
* @param tripletFields which fields should be included in the [[EdgeContext]]
* passed to the `sendMsg` function. If not all fields are needed,
* specifying this can improve performance.
*
* @param vprog the user-defined vertex program which runs on each
* vertex and receives the inbound message and computes a new vertex
* value. On the first iteration the vertex program is invoked on
* all vertices and is passed the default message. On subsequent
* iterations the vertex program is only invoked on those vertices
* that receive messages.
*
* @param sendMsg a user supplied function that is applied to out
* edges of vertices that received messages in the current
* iteration
*
* @param mergeMsg a user supplied function that takes two incoming
* messages of type A and merges them into a single message of type
* A. ''This function must be commutative and associative and
* ideally the size of A should not increase.''
*
* @return the resulting graph at the end of the computation
*/
def pregel2[A: ClassTag](graph: Graph[VD, ED],
initialMsg: A,
maxIterations: Int = Int.MaxValue,
tripletFields: TripletFields = TripletFields.All)
(vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeContext[VD, ED, A] => Unit,
mergeMsg: (A, A) => A)
: Graph[VD, ED] = {
Pregel.apply2(graph, initialMsg, maxIterations, tripletFields)(vprog, sendMsg, mergeMsg)
}

/**
* Run a dynamic version of PageRank returning a graph with vertex attributes containing the
* PageRank and edge attributes containing the normalized edge weight.
90 changes: 89 additions & 1 deletion graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -17,9 +17,10 @@

package org.apache.spark.graphx

import scala.reflect.ClassTag
import org.apache.spark.Logging

import scala.reflect.ClassTag


/**
* Implements a Pregel-like bulk-synchronous message-passing API.
@@ -158,4 +159,91 @@ object Pregel extends Logging {
g
} // end of apply


/**
* A variant of [[Pregel.apply]] implemented with `aggregateMessages`.
*
* @tparam VD the vertex data type
* @tparam ED the edge data type
* @tparam A the Pregel message type
* @param graph the input graph.
* @param initialMsg the message each vertex will receive on
* the first iteration
* @param maxIterations the maximum number of iterations to run for
*
* @param tripletFields which fields should be included in the [[EdgeContext]]
* passed to the `sendMsg` function. If not all fields are needed,
* specifying this can improve performance.
*
* @param vprog the user-defined vertex program which runs on each
* vertex and receives the inbound message and computes a new vertex
* value. On the first iteration the vertex program is invoked on
* all vertices and is passed the default message. On subsequent
* iterations the vertex program is only invoked on those vertices
* that receive messages.
*
* @param sendMsg a user supplied function that is applied to out
* edges of vertices that received messages in the current
* iteration
*
* @param mergeMsg a user supplied function that takes two incoming
* messages of type A and merges them into a single message of type
* A. ''This function must be commutative and associative and
* ideally the size of A should not increase.''
*
* @return the resulting graph at the end of the computation
*/
def apply2[VD: ClassTag, ED: ClassTag, A: ClassTag]
(graph: Graph[VD, ED],
initialMsg: A,
maxIterations: Int = Int.MaxValue,
tripletFields: TripletFields = TripletFields.All)
(vprog: (VertexId, VD, A) => VD,
sendMsg: EdgeContext[VD, ED, A] => Unit,
mergeMsg: (A, A) => A)
: Graph[VD, ED] = {

var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
// compute the messages
var messages = g.aggregateMessages(sendMsg, mergeMsg, tripletFields)
var activeMessages = messages.count()
// Loop
var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
// Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
// Update the graph with the new vertices.
prevG = g
g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
Review comment (Contributor):

It seems that newVerts is no longer needed as a separate variable, because aggregateMessages does not need a working set the way mapReduceTriplets does. If newVerts is removed from lines 215 and 235, then I think the following should work:

      g = g.outerJoinVertices(messages) { (vid, old, msgOpt) =>
        msgOpt match {
          case Some(msg) => vprog(vid, old, msg)
          case None => old
        }
      }

Review comment (Contributor):


I just realized that this would not work: the aggregation stage needs the active set of vertices, because only those vertices should emit messages. It seems that you need to use aggregateMessagesWithActiveSet instead of aggregateMessages.
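
To illustrate the point about the active set, here is a minimal, Spark-free sketch in plain Scala. It is not GraphX code and not part of this PR: `ActiveSetSketch`, `Edge`, `aggregate`, and `iterations` are hypothetical names, the vertex program is omitted, and the send step always emits a message. It also assumes that a vertex is active only as the source of an edge, following the scaladoc wording "out edges of vertices that received messages".

```scala
object ActiveSetSketch {
  type VertexId = Long
  final case class Edge(src: VertexId, dst: VertexId)

  /** One aggregation round: every edge accepted by `edgeFilter` sends the
    * message 1 to its destination; messages for the same vertex are merged
    * by addition (commutative and associative). */
  def aggregate(edges: Seq[Edge], edgeFilter: Edge => Boolean): Map[VertexId, Int] =
    edges
      .filter(edgeFilter)
      .groupBy(_.dst)
      .map { case (dst, incoming) => dst -> incoming.size }

  /** Runs a Pregel-style loop and returns the number of iterations executed.
    * With `useActiveSet = false` every edge sends in every round, which is
    * the behaviour the review comment warns about. */
  def iterations(edges: Seq[Edge], useActiveSet: Boolean, maxIterations: Int): Int = {
    // Initial round: all vertices are considered active, so every edge sends.
    var messages = aggregate(edges, _ => true)
    var i = 0
    while (messages.nonEmpty && i < maxIterations) {
      // Vertices that received a message in the previous round.
      val active = messages.keySet
      messages =
        if (useActiveSet) aggregate(edges, e => active.contains(e.src))
        else aggregate(edges, _ => true)
      i += 1
    }
    i
  }
}
```

On the chain 1 -> 2 -> 3, `ActiveSetSketch.iterations(chain, useActiveSet = true, maxIterations = 10)` returns 2, because the set of sending edges shrinks until no messages remain. With `useActiveSet = false` the same call returns 10: every edge keeps sending, so the loop only stops at `maxIterations`.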

g.cache()

val oldMessages = messages
// Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
// get to send messages. We must cache messages so it can be materialized on the next line,
// allowing us to uncache the previous iteration.
messages = g.aggregateMessages(sendMsg, mergeMsg, tripletFields).cache()
// The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
// hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
// vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
activeMessages = messages.count()

logInfo("Pregel finished iteration " + i)

// Unpersist the RDDs hidden by newly-materialized RDDs
oldMessages.unpersist(blocking=false)
newVerts.unpersist(blocking=false)
prevG.unpersistVertices(blocking=false)
prevG.edges.unpersist(blocking=false)
if (i == 0) {
graph.unpersist(blocking = false)
graph.unpersistVertices(blocking = false)
}
// count the iteration
i += 1
}

g
} // end of apply2

} // end of class Pregel