Refactor initial step of LDA to remove redundant operations.
viirya committed Feb 10, 2015
1 parent 2d1e916 commit 9af1487
Showing 1 changed file with 15 additions and 24 deletions.
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala (15 additions, 24 deletions)
@@ -450,34 +450,25 @@ private[clustering] object LDA {

     // Create vertices.
     // Initially, we use random soft assignments of tokens to topics (random gamma).
-    val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] =
-      edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
-        val random = new Random(partIndex + randomSeed)
-        partEdges.map { edge =>
-          // Create a random gamma_{wjk}
-          (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0))
-        }
-      }
-    def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = {
-      val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] =
-        edgesWithGamma.map { case (edge, gamma: TopicCounts) =>
-          (sendToWhere(edge), (edge.attr, gamma))
-        }
-      verticesTMP.aggregateByKey(BDV.zeros[Double](k))(
-        (sum, t) => {
-          brzAxpy(t._1, t._2, sum)
-          sum
-        },
-        (sum0, sum1) => {
-          sum0 += sum1
-        }
-      )
+    def createVertices(): RDD[(VertexId, TopicCounts)] = {
+      val verticesTMP: RDD[(VertexId, TopicCounts)] =
+        edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
+          val random = new Random(partIndex + randomSeed)
+          partEdges.flatMap { edge =>
+            val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
+            val sum = BDV.zeros[Double](k)
+            brzAxpy(edge.attr, gamma, sum)
+
+            Seq((edge.srcId, sum), (edge.dstId, sum))
+          }
+        }
+      verticesTMP.reduceByKey((sum0, sum1) => { sum0 + sum1 })
     }
-    val docVertices = createVertices(_.srcId)
-    val termVertices = createVertices(_.dstId)
+
+    val docTermVertices = createVertices()
 
     // Partition such that edges are grouped by document
-    val graph = Graph(docVertices ++ termVertices, edges)
+    val graph = Graph(docTermVertices, edges)
       .partitionBy(PartitionStrategy.EdgePartition1D)
 
     new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
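What the change does: the old path materialized a per-edge gamma RDD and then ran createVertices twice (keyed by srcId, then by dstId), each time re-aggregating with aggregateByKey. The new path computes each edge's contribution once inside a flatMap, sends it to both endpoints, and sums per vertex with a single reduceByKey. Below is a minimal sketch of that aggregation pattern in plain Scala collections (no Spark or Breeze; Edge, createVertices, the Array-based vector math, and the sample ids are simplified stand-ins for illustration, not the MLlib API):

    import scala.util.Random

    object LdaInitSketch {
      // Simplified stand-ins for the MLlib types: a topic-count vector is an
      // Array[Double]; an edge carries a token count between a document vertex
      // (srcId) and a term vertex (dstId).
      type TopicCounts = Array[Double]
      final case class Edge(srcId: Long, dstId: Long, attr: Double)

      // L1-normalize a vector, mirroring breeze's normalize(v, 1.0).
      private def normalize(v: Array[Double]): Array[Double] = {
        val s = v.sum
        v.map(_ / s)
      }

      // One pass over the edges: draw a random gamma per edge, scale it by the
      // token count (the brzAxpy step), emit the same contribution for both
      // endpoints, then sum contributions per vertex (the reduceByKey step).
      def createVertices(edges: Seq[Edge], k: Int, seed: Long): Map[Long, TopicCounts] = {
        val random = new Random(seed)
        edges
          .flatMap { edge =>
            val gamma = normalize(Array.fill(k)(random.nextDouble()))
            val sum = gamma.map(_ * edge.attr)
            Seq(edge.srcId -> sum, edge.dstId -> sum)
          }
          .groupBy(_._1) // local analogue of RDD.reduceByKey
          .map { case (id, contribs) =>
            id -> contribs.map(_._2).reduce { (a, b) =>
              a.zip(b).map { case (x, y) => x + y }
            }
          }
      }

      def main(args: Array[String]): Unit = {
        // Two documents (ids 0, 1) and two terms; terms get negative ids here
        // so they never collide with document ids (a stand-in for MLlib's
        // term-index convention). attr is the (doc, term) token count.
        val edges = Seq(Edge(0L, -1L, 2.0), Edge(0L, -2L, 1.0), Edge(1L, -1L, 3.0))
        createVertices(edges, k = 3, seed = 42L).foreach { case (id, counts) =>
          println(s"vertex $id -> ${counts.mkString("[", ", ", "]")}")
        }
      }
    }

The net effect is that two keyed aggregations over an intermediate RDD are replaced by a single emit-and-reduce pass: each vertex ends up with the token-count-weighted sum of the random gammas on its incident edges, which matches what the removed aggregateByKey pass computed for document and term vertices separately.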
