Skip to content

Commit

Permalink
removing context and passing activeness directly through the vertex property.
Browse files Browse the repository at this point in the history
This addresses the issue where byte-code inspection would always force a 3-way join.
  • Loading branch information
Joseph E. Gonzalez authored and jegonzal committed Oct 30, 2014
1 parent 5bb1396 commit 8f41371
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 46 deletions.
53 changes: 18 additions & 35 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,6 @@ import org.apache.spark.Logging
*/
object Pregel extends Logging {

class Context(val iteration: Int)

class VertexContext(iteration: Int, val wasActive: Boolean, var isActive: Boolean = true) extends Context(iteration) {
def deactivate() { isActive = false }
}

class EdgeContext(iteration: Int, val srcIsActive: Boolean, val dstIsActive: Boolean) extends Context(iteration)



/**
* Execute a Pregel-like iterative vertex-parallel abstraction. The
Expand Down Expand Up @@ -169,6 +160,7 @@ object Pregel extends Logging {
g
} // end of apply


/**
* Execute a Pregel-like iterative vertex-parallel abstraction. The
* user-defined vertex-program `vprog` is executed in parallel on
Expand Down Expand Up @@ -227,50 +219,41 @@ object Pregel extends Logging {
(graph: Graph[VD, ED],
maxIterations: Int = Int.MaxValue,
activeDirection: EdgeDirection = EdgeDirection.Either)
(vertexProgram: (VertexId, VD, Option[A], VertexContext) => VD,
sendMsg: (EdgeTriplet[VD, ED], EdgeContext) => Iterator[(VertexId, A)],
(vertexProgram: (Int, VertexId, VD, Boolean, Option[A]) => (VD, Boolean),
sendMsg: (Int, EdgeTriplet[(VD, Boolean), ED]) => Iterator[(VertexId, A)],
mergeMsg: (A, A) => A)
: Graph[VD, ED] =
{
// Initialize the graph with all vertices active
var g: Graph[(VD, Boolean), ED] = graph.mapVertices { (vid, vdata) => (vdata, true) }.cache()
var currengGraph: Graph[(VD, Boolean), ED] =
graph.mapVertices { (vid, vdata) => (vdata, true) }.cache()
// Determine the set of vertices that did not vote to halt
var activeVertices = g.vertices
var activeVertices = currengGraph.vertices
var numActive = activeVertices.count()
var i = 0
while (numActive > 0 && i < maxIterations) {
// The send message wrapper removes the active fields from the triplet and places them in the edge context.
def sendMessageWrapper(triplet: EdgeTriplet[(VD, Boolean),ED]): Iterator[(VertexId, A)] = {
val simpleTriplet = new EdgeTriplet[VD, ED]()
simpleTriplet.set(triplet)
simpleTriplet.srcAttr = triplet.srcAttr._1
simpleTriplet.dstAttr = triplet.dstAttr._1
val ctx = new EdgeContext(i, triplet.srcAttr._2, triplet.dstAttr._2)
sendMsg(simpleTriplet, ctx)
}

// get a reference to the current graph so that we can unpersist it once the new graph is created.
val prevG = g
var iteration = 0
while (numActive > 0 && iteration < maxIterations) {
// Keep a reference to the current graph so that it can be unpersisted once the new graph is created.
val prevG = currengGraph

// Compute the messages for all the active vertices
val messages = g.mapReduceTriplets(sendMessageWrapper, mergeMsg, Some((activeVertices, activeDirection)))
val messages = currengGraph.mapReduceTriplets( t => sendMsg(iteration, t), mergeMsg,
Some((activeVertices, activeDirection)))

// Receive the messages to the subset of active vertices
g = g.outerJoinVertices(messages){ (vid, dataAndActive, msgOpt) =>
currengGraph = currengGraph.outerJoinVertices(messages){ (vid, dataAndActive, msgOpt) =>
val (vdata, active) = dataAndActive
// If the vertex voted to halt and received no message then we can skip the vertex program
if (!active && msgOpt.isEmpty) {
dataAndActive
} else {
val ctx = new VertexContext(i, active)
// The vertex program is either active or received a message (or both).
// A vertex program should vote to halt again even if it has previously voted to halt
(vertexProgram(vid, vdata, msgOpt, ctx), ctx.isActive)
vertexProgram(iteration, vid, vdata, active, msgOpt)
}
}.cache()

// Recompute the active vertices (those that have not voted to halt)
activeVertices = g.vertices.filter(v => v._2._2)
activeVertices = currengGraph.vertices.filter(v => v._2._2)

// Force all computation!
numActive = activeVertices.count()
Expand All @@ -282,11 +265,11 @@ object Pregel extends Logging {
//println("Finished Iteration " + i)
// g.vertices.foreach(println(_))

logInfo("Pregel finished iteration " + i)
logInfo("Pregel finished iteration " + iteration)
// count the iteration
i += 1
iteration += 1
}
g.mapVertices((id, vdata) => vdata._1)
currengGraph.mapVertices((id, vdata) => vdata._1)
} // end of apply


Expand Down
18 changes: 8 additions & 10 deletions graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ object PageRank extends Logging {
// Set the vertex attributes to the initial pagerank values
.mapVertices( (id, attr) => resetProb )


var iteration = 0
var prevRankGraph: Graph[Double, Double] = null
while (iteration < numIter) {
Expand Down Expand Up @@ -152,23 +151,22 @@ object PageRank extends Logging {

// Define the three functions needed to implement PageRank in the GraphX
// version of Pregel
def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Option[Double], ctx: VertexContext) = {
def vertexProgram(iter: Int, id: VertexId, attr: (Double, Double), wasActive: Boolean,
msgSum: Option[Double]) = {
var (oldPR, pendingDelta) = attr
val newPR = oldPR + msgSum.getOrElse(0.0)
// if we were active then we sent the pending delta on the last iteration
if (ctx.wasActive) {
if (wasActive) {
pendingDelta = 0.0
}
pendingDelta += (1.0 - resetProb) * msgSum.getOrElse(0.0)
if (math.abs(pendingDelta) <= tol) {
ctx.deactivate()
}
(newPR, pendingDelta)
val isActive = math.abs(pendingDelta) >= tol
((newPR, pendingDelta), isActive)
}

def sendMessage(edge: EdgeTriplet[(Double, Double), Double], ctx: EdgeContext) = {
val (srcPr, srcDelta) = edge.srcAttr
assert(ctx.srcIsActive)
def sendMessage(iter: Int, edge: EdgeTriplet[((Double, Double), Boolean), Double]) = {
val ((srcPr, srcDelta), srcIsActive) = edge.srcAttr
assert(srcIsActive)
Iterator((edge.dstId, srcDelta * edge.attr))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext with Matchers {
val nVertices = 10
val starGraph = GraphGenerators.cycleGraph(sc, nVertices)
val resetProb = 0.15
val staticRanks: VertexRDD[Double] = starGraph.staticPageRank(numIter = 10, resetProb).vertices
val staticRanks: VertexRDD[Double] =
starGraph.staticPageRank(numIter = 10, resetProb).vertices
// Check the static pagerank
val pageranks: Map[VertexId, Double] = staticRanks.collect().toMap
for (i <- 0 until nVertices) {
Expand Down

0 comments on commit 8f41371

Please sign in to comment.