diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 9c730704ded05..e57a6ae8e846c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -55,15 +55,6 @@ import org.apache.spark.Logging
 
  */
 object Pregel extends Logging {
 
-  class Context(val iteration: Int)
-
-  class VertexContext(iteration: Int, val wasActive: Boolean, var isActive: Boolean = true) extends Context(iteration) {
-    def deactivate() { isActive = false }
-  }
-
-  class EdgeContext(iteration: Int, val srcIsActive: Boolean, val dstIsActive: Boolean) extends Context(iteration)
-
-
   /**
    * Execute a Pregel-like iterative vertex-parallel abstraction.  The
@@ -169,6 +160,7 @@ object Pregel extends Logging {
     g
   } // end of apply
 
+
   /**
    * Execute a Pregel-like iterative vertex-parallel abstraction.  The
    * user-defined vertex-program `vprog` is executed in parallel on
@@ -227,50 +219,41 @@ object Pregel extends Logging {
      (graph: Graph[VD, ED],
       maxIterations: Int = Int.MaxValue,
       activeDirection: EdgeDirection = EdgeDirection.Either)
-     (vertexProgram: (VertexId, VD, Option[A], VertexContext) => VD,
-      sendMsg: (EdgeTriplet[VD, ED], EdgeContext) => Iterator[(VertexId, A)],
+     (vertexProgram: (Int, VertexId, VD, Boolean, Option[A]) => (VD, Boolean),
+      sendMsg: (Int, EdgeTriplet[(VD, Boolean), ED]) => Iterator[(VertexId, A)],
       mergeMsg: (A, A) => A)
     : Graph[VD, ED] = {
     // Initialize the graph with all vertices active
-    var g: Graph[(VD, Boolean), ED] = graph.mapVertices { (vid, vdata) => (vdata, true) }.cache()
+    var currentGraph: Graph[(VD, Boolean), ED] =
+      graph.mapVertices { (vid, vdata) => (vdata, true) }.cache()
 
     // Determine the set of vertices that did not vote to halt
-    var activeVertices = g.vertices
+    var activeVertices = currentGraph.vertices
     var numActive = activeVertices.count()
-    var i = 0
-    while (numActive > 0 && i < maxIterations) {
-      // The send message wrapper removes the active fields from the triplet and places them in the edge context.
-      def sendMessageWrapper(triplet: EdgeTriplet[(VD, Boolean),ED]): Iterator[(VertexId, A)] = {
-        val simpleTriplet = new EdgeTriplet[VD, ED]()
-        simpleTriplet.set(triplet)
-        simpleTriplet.srcAttr = triplet.srcAttr._1
-        simpleTriplet.dstAttr = triplet.dstAttr._1
-        val ctx = new EdgeContext(i, triplet.srcAttr._2, triplet.dstAttr._2)
-        sendMsg(simpleTriplet, ctx)
-      }
-
-      // get a reference to the current graph so that we can unpersist it once the new graph is created.
-      val prevG = g
+    var iteration = 0
+    while (numActive > 0 && iteration < maxIterations) {
+      // get a reference to the current graph so that it can be unpersisted once the new graph is created
+      val prevG = currentGraph
 
       // Compute the messages for all the active vertices
-      val messages = g.mapReduceTriplets(sendMessageWrapper, mergeMsg, Some((activeVertices, activeDirection)))
+      val messages = currentGraph.mapReduceTriplets(t => sendMsg(iteration, t), mergeMsg,
+        Some((activeVertices, activeDirection)))
 
       // Receive the messages to the subset of active vertices
-      g = g.outerJoinVertices(messages){ (vid, dataAndActive, msgOpt) =>
+      currentGraph = currentGraph.outerJoinVertices(messages){ (vid, dataAndActive, msgOpt) =>
         val (vdata, active) = dataAndActive
         // If the vertex voted to halt and received no message then we can skip the vertex program
         if (!active && msgOpt.isEmpty) {
           dataAndActive
         } else {
-          val ctx = new VertexContext(i, active)
           // The vertex program is either active or received a message (or both).
           // A vertex program should vote to halt again even if it has previously voted to halt
-          (vertexProgram(vid, vdata, msgOpt, ctx), ctx.isActive)
+          vertexProgram(iteration, vid, vdata, active, msgOpt)
         }
       }.cache()
 
       // Recompute the active vertices (those that have not voted to halt)
-      activeVertices = g.vertices.filter(v => v._2._2)
+      activeVertices = currentGraph.vertices.filter(v => v._2._2)
 
       // Force all computation!
       numActive = activeVertices.count()
@@ -282,11 +265,11 @@ object Pregel extends Logging {
       //println("Finished Iteration " + i)
       // g.vertices.foreach(println(_))
 
-      logInfo("Pregel finished iteration " + i)
+      logInfo("Pregel finished iteration " + iteration)
 
       // count the iteration
-      i += 1
+      iteration += 1
     }
 
-    g.mapVertices((id, vdata) => vdata._1)
+    currentGraph.mapVertices((id, vdata) => vdata._1)
   } // end of apply
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index b8fdfb3a1bc20..e0fdaeb9221ce 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -90,7 +90,6 @@ object PageRank extends Logging {
 
       // Set the vertex attributes to the initial pagerank values
       .mapVertices( (id, attr) => resetProb )
 
-    var iteration = 0
     var prevRankGraph: Graph[Double, Double] = null
     while (iteration < numIter) {
@@ -152,23 +151,22 @@
 
     // Define the three functions needed to implement PageRank in the GraphX
     // version of Pregel
-    def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Option[Double], ctx: VertexContext) = {
+    def vertexProgram(iter: Int, id: VertexId, attr: (Double, Double), wasActive: Boolean,
+        msgSum: Option[Double]) = {
       var (oldPR, pendingDelta) = attr
       val newPR = oldPR + msgSum.getOrElse(0.0)
       // if we were active then we sent the pending delta on the last iteration
-      if (ctx.wasActive) {
+      if (wasActive) {
        pendingDelta = 0.0
       }
       pendingDelta += (1.0 - resetProb) * msgSum.getOrElse(0.0)
-      if (math.abs(pendingDelta) <= tol) {
-        ctx.deactivate()
-      }
-      (newPR, pendingDelta)
+      val isActive = math.abs(pendingDelta) >= tol
+      ((newPR, pendingDelta), isActive)
     }
 
-    def sendMessage(edge: EdgeTriplet[(Double, Double), Double], ctx: EdgeContext) = {
-      val (srcPr, srcDelta) = edge.srcAttr
-      assert(ctx.srcIsActive)
+    def sendMessage(iter: Int, edge: EdgeTriplet[((Double, Double), Boolean), Double]) = {
+      val ((srcPr, srcDelta), srcIsActive) = edge.srcAttr
+      assert(srcIsActive)
       Iterator((edge.dstId, srcDelta * edge.attr))
     }
 
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
index 14af24c7e04b9..aadc53fe54635 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala
@@ -170,7 +170,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext with Matchers {
       val nVertices = 10
       val starGraph = GraphGenerators.cycleGraph(sc, nVertices)
       val resetProb = 0.15
-      val staticRanks: VertexRDD[Double] = starGraph.staticPageRank(numIter = 10, resetProb).vertices
+      val staticRanks: VertexRDD[Double] =
+        starGraph.staticPageRank(numIter = 10, resetProb).vertices
       // Check the static pagerank
       val pageranks: Map[VertexId, Double] = staticRanks.collect().toMap
       for (i <- 0 until nVertices) {
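
Below is a minimal, hypothetical usage sketch of the revised Pregel overload introduced by this diff, written as single-source shortest paths. It is not part of the patch: the object name PregelSketch, the source vertex, and the toy edge list are illustrative only, and the sketch assumes the patched overload (graph, maxIterations, activeDirection)(vertexProgram, sendMsg, mergeMsg) compiles and resolves alongside the existing apply as written above.

// Hypothetical sketch, not part of the patch; assumes the revised Pregel overload above.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._

object PregelSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("pregel-sketch").setMaster("local[2]"))
    val sourceId: VertexId = 0L

    // Toy weighted graph; every vertex starts at infinity except the source.
    val edges = sc.parallelize(Seq(
      Edge(0L, 1L, 1.0), Edge(1L, 2L, 2.0), Edge(0L, 2L, 4.0), Edge(2L, 3L, 1.0)))
    val graph: Graph[Double, Double] = Graph.fromEdges(edges, Double.PositiveInfinity)
      .mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity)

    // New-style vertex program: fold the best incoming distance into the vertex value and
    // return (newValue, isActive); a vertex stays active only while its value improves.
    // This algorithm does not need the wasActive flag, so it is ignored.
    def vertexProgram(iter: Int, id: VertexId, dist: Double, wasActive: Boolean,
        msgOpt: Option[Double]): (Double, Boolean) = {
      val newDist = math.min(dist, msgOpt.getOrElse(Double.PositiveInfinity))
      (newDist, newDist < dist)
    }

    // New-style sendMsg: triplet attributes carry (value, isActive) pairs; relax only
    // edges whose source is still active and can improve the destination.
    def sendMessage(iter: Int, edge: EdgeTriplet[(Double, Boolean), Double])
      : Iterator[(VertexId, Double)] = {
      val (srcDist, srcIsActive) = edge.srcAttr
      val (dstDist, _) = edge.dstAttr
      if (srcIsActive && srcDist + edge.attr < dstDist) {
        Iterator((edge.dstId, srcDist + edge.attr))
      } else {
        Iterator.empty
      }
    }

    val shortestPaths = Pregel(graph, Int.MaxValue, EdgeDirection.Out)(
      vertexProgram, sendMessage, (a: Double, b: Double) => math.min(a, b))
    shortestPaths.vertices.collect().foreach(println)
    sc.stop()
  }
}

The sketch relies only on behavior visible in the patched loop: messages are computed for active vertices before the vertex program runs, and a vertex that votes to halt is revisited whenever it receives a message, so deactivating as soon as the distance stops improving is safe.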