Merge pull request #6 from mengxr/SPARK-1405-lda
SPARK-1405 LDA: fixes for GC issues (mainly avoiding temporary instances of Breeze Vectors)
jkbradley committed Jan 16, 2015
Parents: 984c414 + c9c52f8 · Commit: 648f66c
Showing 1 changed file with 65 additions and 33 deletions: mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
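
Before the diff, a note on the theme: Breeze operator expressions such as a + b or g :* h allocate a fresh vector on every evaluation, and inside GraphX's per-edge aggregation loop those temporaries translate directly into GC pressure. A minimal sketch of the allocating versus in-place styles (the names here are ours, purely illustrative):

    import breeze.linalg.{DenseVector => BDV, axpy => brzAxpy}

    object GcSketch extends App {
      val x = BDV.ones[Double](4)
      val acc = BDV.zeros[Double](4)

      // Allocating style: `+` and `*` each build a temporary DenseVector,
      // which becomes garbage as soon as the expression is consumed.
      val tmp = acc + (x * 2.0)

      // In-place style favored by this patch: mutate an existing buffer.
      acc += x             // element-wise add into acc, no temporary
      acc *= 2.0           // scale acc in place
      brzAxpy(2.0, x, acc) // acc += 2.0 * x, still allocation-free

      println((tmp, acc))
    }
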
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering

import java.util.Random

-import breeze.linalg.{DenseVector => BDV, sum => brzSum, normalize}
+import breeze.linalg.{DenseVector => BDV, sum => brzSum, normalize, axpy => brzAxpy}

import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
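
The one new import, axpy => brzAxpy, is the BLAS-style update y <- y + a*x, performed in place on y; everything else in the import line is unchanged. A quick sketch of its contract:

    import breeze.linalg.{DenseVector => BDV, axpy => brzAxpy}

    val x = BDV(1.0, 2.0, 3.0)
    val y = BDV(10.0, 10.0, 10.0)
    brzAxpy(0.5, x, y)
    // y is now DenseVector(10.5, 11.0, 11.5); no new vector was allocated
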
@@ -279,7 +279,6 @@ object LDA {
vocabSize: Int,
topicSmoothing: Double,
termSmoothing: Double) {
-
// TODO: Checkpoint periodically?
def next(): LearningState = copy(graph = step(graph))

@@ -289,25 +288,42 @@
val alpha = topicSmoothing

val N_k = collectTopicTotals()
-    val sendMsg: EdgeContext[TopicCounts, TokenCount, TopicCounts] => Unit = (edgeContext) => {
-      // Compute N_{wj} gamma_{wjk}
-      val N_wj = edgeContext.attr
-      // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count N_{wj}.
-      val scaledTopicDistribution: TopicCounts =
-        computePTopic(edgeContext, N_k, W, eta, alpha) * N_wj
-      edgeContext.sendToDst(scaledTopicDistribution)
-      edgeContext.sendToSrc(scaledTopicDistribution)
-    }
+    val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
+      (edgeContext) => {
+        // Compute N_{wj} gamma_{wjk}
+        val N_wj = edgeContext.attr
+        // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
+        // N_{wj}.
+        val scaledTopicDistribution: TopicCounts =
+          computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
+        edgeContext.sendToDst((false, scaledTopicDistribution))
+        edgeContext.sendToSrc((false, scaledTopicDistribution))
+      }
+    // This is a hack to detect whether we could modify the values in-place.
+    // TODO: Add zero/seqOp/combOp option to aggregateMessages.
+    val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
+      (m0, m1) => {
+        val sum =
+          if (m0._1) {
+            m0._2 += m1._2
+          } else if (m1._1) {
+            m1._2 += m0._2
+          } else {
+            m0._2 + m1._2
+          }
+        (true, sum)
+      }
// M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
val docTopicDistributions: VertexRDD[TopicCounts] =
-      graph.aggregateMessages[TopicCounts](sendMsg, _ + _)
+      graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
+        .mapValues(_._2)
// Update the vertex descriptors with the new counts.
-    graph.outerJoinVertices(docTopicDistributions){ (vid, oldDist, newDist) => newDist.get }
+    graph.outerJoinVertices(docTopicDistributions) { (vid, oldDist, newDist) => newDist.get }
}

def collectTopicTotals(): TopicCounts = {
val numTopics = k
-    graph.vertices.filter(isTermVertex).map(_._2).fold(BDV.zeros[Double](numTopics))(_ + _)
+    graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
}
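
A note on the (Boolean, TopicCounts) message type introduced above: the flag records whether a vector is an accumulator that mergeMsg itself produced (safe to mutate) or a freshly sent message (must be left intact), so at most one combine per vertex allocates and every later one updates in place. A standalone sketch of the same pattern, with illustrative names of our own:

    import breeze.linalg.{DenseVector => BDV}

    // (owned, vector): owned is false for fresh messages and true once the
    // vector is a partial sum created by a previous merge, hence mutable.
    def merge(m0: (Boolean, BDV[Double]), m1: (Boolean, BDV[Double])): (Boolean, BDV[Double]) = {
      val sum =
        if (m0._1) m0._2 += m1._2      // accumulate into our own buffer
        else if (m1._1) m1._2 += m0._2
        else m0._2 + m1._2             // neither owned: the one allocation
      (true, sum)
    }

    val msgs = Seq((false, BDV(1.0, 2.0)), (false, BDV(3.0, 4.0)), (false, BDV(5.0, 6.0)))
    println(msgs.reduce(merge)._2) // DenseVector(9.0, 12.0), one allocation total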

/**
@@ -373,18 +389,30 @@ object LDA {
* Compute gamma_{wjk}, a distribution over topics k.
*/
private def computePTopic(
-      edgeContext: EdgeContext[TopicCounts, TokenCount, TopicCounts],
-      N_k: TopicCounts,
+      docTopicCounts: TopicCounts,
+      wordTopicCounts: TopicCounts,
+      totalTopicCounts: TopicCounts,
vocabSize: Int,
eta: Double,
alpha: Double): TopicCounts = {
-    val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0)
-    val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0)
-    val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0))
-    // proportional to p(w|z) * p(z|d) / p(z)
-    val unnormalizedGamma = smoothed_N_wk :* smoothed_N_kj :/ smoothed_N_k
+    val K = docTopicCounts.length
+    val N_j = docTopicCounts.data
+    val N_w = wordTopicCounts.data
+    val N = totalTopicCounts.data
+    val eta1 = eta - 1.0
+    val alpha1 = alpha - 1.0
+    val Weta1 = vocabSize * eta1
+    var sum = 0.0
+    val gamma_wj = new Array[Double](K)
+    var k = 0
+    while (k < K) {
+      val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1)
+      gamma_wj(k) = gamma_wjk
+      sum += gamma_wjk
+      k += 1
+    }
// normalize
-    unnormalizedGamma /= brzSum(unnormalizedGamma)
+    BDV(gamma_wj) /= sum
}
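
The while loop fuses exactly what the removed Breeze expression computed, per topic k: gamma_wjk proportional to (N_wk + eta - 1) * (N_kj + alpha - 1) / (N_k + W * (eta - 1)), followed by normalization; it just trades three temporary vectors and an extra pass for one array and a single pass. A small equivalence check we sketched (made-up values, old-style Breeze operators):

    import breeze.linalg.{DenseVector => BDV, sum => brzSum}

    val (eta, alpha, w) = (1.1, 1.5, 4)
    val nJ = BDV(2.0, 0.5, 1.0)  // doc-topic counts N_{kj}
    val nW = BDV(1.0, 3.0, 0.5)  // term-topic counts N_{wk}
    val n  = BDV(9.0, 7.0, 5.0)  // topic totals N_k

    // Old style: three temporaries, then a normalization pass.
    val viaBreeze = {
      val g = (nW + (eta - 1.0)) :* (nJ + (alpha - 1.0)) :/ (n + w * (eta - 1.0))
      g /= brzSum(g)
    }

    // New style: one array, one pass.
    val fused = {
      val g = new Array[Double](3); var s = 0.0; var k = 0
      while (k < 3) {
        g(k) = (nW(k) + (eta - 1.0)) * (nJ(k) + (alpha - 1.0)) / (n(k) + w * (eta - 1.0))
        s += g(k); k += 1
      }
      BDV(g) /= s
    }
    // viaBreeze and fused agree element-wise up to floating-point rounding
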

/**
@@ -398,12 +426,10 @@
termSmoothing: Double,
randomSeed: Long): LearningState = {
// For each document, create an edge (Document -> Term) for each unique term in the document.
-    val edges: RDD[Edge[TokenCount]] = docs.mapPartitionsWithIndex { case (partIndex, partDocs) =>
-      partDocs.flatMap { doc: Document =>
-        // Add edges for terms with non-zero counts.
-        doc.counts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
-          Edge(doc.id, term2index(term), cnt)
-        }
+    val edges: RDD[Edge[TokenCount]] = docs.flatMap { doc =>
+      // Add edges for terms with non-zero counts.
+      doc.counts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
+        Edge(doc.id, term2index(term), cnt)
}
}

@@ -420,12 +446,19 @@
}
}
def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = {
-      val verticesTMP: RDD[(VertexId, TopicCounts)] =
+      val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] =
edgesWithGamma.map { case (edge, gamma: TopicCounts) =>
-          val N_wj = edge.attr
-          (sendToWhere(edge), gamma * N_wj)
+          (sendToWhere(edge), (edge.attr, gamma))
}
-      verticesTMP.foldByKey(BDV.zeros[Double](k))(_ + _)
+      verticesTMP.aggregateByKey(BDV.zeros[Double](k))(
+        (sum, t) => {
+          brzAxpy(t._1, t._2, sum)
+          sum
+        },
+        (sum0, sum1) => {
+          sum0 += sum1
+        }
+      )
}
val docVertices = createVertices(_.srcId)
val termVertices = createVertices(_.dstId)
@@ -436,5 +469,4 @@

LearningState(graph, k, vocabSize, topicSmoothing, termSmoothing)
}
-
}
