JIRA issue: [SPARK-1405] Gibbs sampling based Latent Dirichlet Allocation (LDA) for MLlib #476
@@ -0,0 +1,169 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.clustering

import java.util.Random

import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.{AccumulableParam, Logging, SparkContext}
import org.apache.spark.mllib.expectation.GibbsSampling
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

case class Document(docId: Int, content: Iterable[Int])

case class LDAParams (
Review comment: I don't think it should be a case class.
    docCounts: Vector,
    topicCounts: Vector,
    docTopicCounts: Array[Vector],
    topicTermCounts: Array[Vector])

Review comment: I expect that this will be really big - maybe the last two variables should be RDDs, similar to what we do with ALS.
Reply: That makes sense. I think the …
  extends Serializable {

  def update(docId: Int, term: Int, topic: Int, inc: Int) = {
    docCounts.toBreeze(docId) += inc
Review comment: Doing the breeze conversion on every update seems inefficient. These variables should be private and created as breeze variables at initialization; only the user-facing APIs need to be Vector, Array[Vector], etc.
    topicCounts.toBreeze(topic) += inc
    docTopicCounts(docId).toBreeze(topic) += inc
Review comment: Again, I think in this case the model might be really big - e.g. a billion documents in hundreds of topics. Or for the term side, millions of words in a vocabulary and hundreds of topics.
    topicTermCounts(topic).toBreeze(term) += inc
    this
  }
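
A minimal sketch of the change the review comment above suggests (hypothetical, not part of this diff): hold the counts as Breeze vectors from initialization so that update() mutates them directly, and convert to the public Vector types only at the API boundary.

// Hypothetical sketch only: counts stored as Breeze vectors, avoiding toBreeze on every update.
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.mllib.linalg.{Vector, Vectors}

class LDACounts(
    private val docCounts: BDV[Double],
    private val topicCounts: BDV[Double],
    private val docTopicCounts: Array[BDV[Double]],
    private val topicTermCounts: Array[BDV[Double]]) extends Serializable {

  def update(docId: Int, term: Int, topic: Int, inc: Int): this.type = {
    docCounts(docId) += inc
    topicCounts(topic) += inc
    docTopicCounts(docId)(topic) += inc
    topicTermCounts(topic)(term) += inc
    this
  }

  // Convert to the user-facing MLlib types only when asked.
  def docCountsVector: Vector = Vectors.dense(docCounts.toArray)
  def docTopicCountsVectors: Array[Vector] = docTopicCounts.map(v => Vectors.dense(v.toArray))
}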

  def merge(other: LDAParams) = {
    docCounts.toBreeze += other.docCounts.toBreeze
    topicCounts.toBreeze += other.topicCounts.toBreeze

    var i = 0
    while (i < docTopicCounts.length) {
      docTopicCounts(i).toBreeze += other.docTopicCounts(i).toBreeze
Review comment: More breeze conversion.
      i += 1
    }

    i = 0
    while (i < topicTermCounts.length) {
      topicTermCounts(i).toBreeze += other.topicTermCounts(i).toBreeze
      i += 1
    }
    this
  }

  /**
   * Computes the new topic distribution after dropping the current assignment of one term in the
   * current document. This is an essential step of Gibbs sampling for LDA; see the paper
   * <I>Parameter estimation for text analysis</I>.
Review comment: Link to paper, please.
   */
  def dropOneDistSampler(
      docTopicSmoothing: Double,
      topicTermSmoothing: Double,
      termId: Int,
      docId: Int,
      rand: Random): Int = {
    val (numTopics, numTerms) = (topicCounts.size, topicTermCounts.head.size)
    val topicThisTerm = BDV.zeros[Double](numTopics)
    var i = 0
    while (i < numTopics) {
      // Weight for topic i: (topic-term factor) * (doc-topic factor).
      topicThisTerm(i) =
        ((topicTermCounts(i)(termId) + topicTermSmoothing)
          / (topicCounts(i) + (numTerms * topicTermSmoothing))
        ) * (docTopicCounts(docId)(i) + docTopicSmoothing)
      i += 1
    }
    GibbsSampling.multinomialDistSampler(rand, topicThisTerm)
  }
}
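
For reference, the per-topic weight computed in dropOneDistSampler above is the standard collapsed Gibbs full conditional from the cited paper, with α = docTopicSmoothing, β = topicTermSmoothing, V = numTerms, and all counts taken with the current assignment of the resampled term removed:

    p(z_i = k \mid z_{-i}, w) \;\propto\; \frac{n_{k, w_i} + \beta}{n_k + V\beta}\,\left(n_{d_i, k} + \alpha\right)

The document-side normalizer (n_{d_i} + Kα) is constant across topics, so it can be omitted before passing the weights to multinomialDistSampler.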

object LDAParams {
  implicit val ldaParamsAP = new LDAParamsAccumulableParam

  def apply(numDocs: Int, numTopics: Int, numTerms: Int) = new LDAParams(
    Vectors.fromBreeze(BDV.zeros[Double](numDocs)),
    Vectors.fromBreeze(BDV.zeros[Double](numTopics)),
    Array(0 until numDocs: _*).map(_ => Vectors.fromBreeze(BDV.zeros[Double](numTopics))),
    Array(0 until numTopics: _*).map(_ => Vectors.fromBreeze(BDV.zeros[Double](numTerms))))
}

class LDAParamsAccumulableParam extends AccumulableParam[LDAParams, (Int, Int, Int, Int)] {
  def addAccumulator(r: LDAParams, t: (Int, Int, Int, Int)) = {
    val (docId, term, topic, inc) = t
    r.update(docId, term, topic, inc)
  }

  def addInPlace(r1: LDAParams, r2: LDAParams): LDAParams = r1.merge(r2)

  def zero(initialValue: LDAParams): LDAParams = initialValue
}
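
With the implicit LDAParamsAccumulableParam above, driver code can build an LDAParams accumulable and apply fine-grained updates from tasks. A small illustrative sketch (sc, numDocs, numTopics, numTerms are assumed to be defined; the implicit is resolved from the LDAParams companion object):

// Illustrative sketch only, not part of the diff.
val paramsAcc = sc.accumulable(LDAParams(numDocs, numTopics, numTerms))
// Fine-grained update from a task: increment the counts for (docId, term, topic) by 1.
paramsAcc += (docId, term, topic, 1)
// Read the merged counts back on the driver.
val learned: LDAParams = paramsAcc.value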

class LDA private (
    var numTopics: Int,
    var docTopicSmoothing: Double,
    var topicTermSmoothing: Double,
    var numIteration: Int,
    var numDocs: Int,
    var numTerms: Int)
  extends Serializable with Logging {
  def run(input: RDD[Document]): (GibbsSampling, LDAParams) = {
    val trainer = new GibbsSampling(
      input,
      numIteration,
      1,
      docTopicSmoothing,
      topicTermSmoothing)
    (trainer, trainer.runGibbsSampling(LDAParams(numDocs, numTopics, numTerms)))
  }
}

object LDA extends Logging {

  def train(
      data: RDD[Document],
      numTopics: Int,
      docTopicSmoothing: Double,
      topicTermSmoothing: Double,
      numIterations: Int,
      numDocs: Int,
      numTerms: Int): (Array[Vector], Array[Vector]) = {
    val lda = new LDA(numTopics,
      docTopicSmoothing,
      topicTermSmoothing,
      numIterations,
      numDocs,
      numTerms)
    val (trainer, model) = lda.run(data)
    trainer.solvePhiAndTheta(model)
  }

  def main(args: Array[String]) {
    if (args.length != 5) {
      println("Usage: LDA <master> <input_dir> <k> <max_iterations> <mini-split>")
      System.exit(1)
    }

    val (master, inputDir, k, iters, minSplit) =
      (args(0), args(1), args(2).toInt, args(3).toInt, args(4).toInt)
    val checkPointDir = System.getProperty("spark.gibbsSampling.checkPointDir", "/tmp/lda-cp")
    val sc = new SparkContext(master, "LDA")
    sc.setCheckpointDir(checkPointDir)
    val (data, wordMap, docMap) = MLUtils.loadCorpus(sc, inputDir, minSplit)
    val numDocs = docMap.size
    val numTerms = wordMap.size

    val (phi, theta) = LDA.train(data, k, 0.01, 0.01, iters, numDocs, numTerms)
    val pp = GibbsSampling.perplexity(data, phi, theta)
    println(s"final model perplexity is $pp")
  }
}
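
A hypothetical usage sketch of the LDA.train API above, mirroring main(); the master, path, parallelism, and hyperparameter values are illustrative only:

// Hypothetical usage sketch of LDA.train (illustrative values only).
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.util.MLUtils

val sc = new SparkContext("local[4]", "LDAExample")
val (corpus, wordMap, docMap) = MLUtils.loadCorpus(sc, "/path/to/corpus", 4)

// 20 topics, symmetric smoothing of 0.01, 50 Gibbs iterations.
val (phi, theta) = LDA.train(corpus, 20, 0.01, 0.01, 50, docMap.size, wordMap.size)

// phi(k) is the term distribution of topic k; theta(d) is the topic distribution of document d.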
@@ -0,0 +1,219 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.expectation

import java.util.Random

import breeze.linalg.{DenseVector => BDV, sum}

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.clustering.{Document, LDAParams}
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
 * Gibbs sampling from a given dataset and model.
 * @param data Dataset, such as a corpus.
 * @param numOuterIterations Number of outer iterations.
 * @param numInnerIterations Number of inner iterations, used in each partition.
 * @param docTopicSmoothing Document-topic smoothing.
 * @param topicTermSmoothing Topic-term smoothing.
 */
class GibbsSampling(
Review comment: Gibbs sampling is a very useful general-purpose tool to have. Its interface should be something more generic than RDD[Document], and the parameters should be amenable to domains other than text.
    data: RDD[Document],
    numOuterIterations: Int,
    numInnerIterations: Int,
    docTopicSmoothing: Double,
    topicTermSmoothing: Double)
  extends Logging with Serializable {

  import GibbsSampling._

  /**
   * Main entry point for running Gibbs sampling. It consists of two phases: initialization,
   * followed by the actual sampling.
   */
  def runGibbsSampling(
      initParams: LDAParams,
      data: RDD[Document] = data,
      numOuterIterations: Int = numOuterIterations,
      numInnerIterations: Int = numInnerIterations,
      docTopicSmoothing: Double = docTopicSmoothing,
      topicTermSmoothing: Double = topicTermSmoothing): LDAParams = {

    val numTerms = initParams.topicTermCounts.head.size
    val numDocs = initParams.docCounts.size
    val numTopics = initParams.topicCounts.size

    // Construct topic assignment RDD
    logInfo("Start initialization")

    val cpInterval = System.getProperty("spark.gibbsSampling.checkPointInterval", "10").toInt
    val sc = data.context
    val (initialParams, initialChosenTopics) = sampleTermAssignment(initParams, data)

    // Gibbs sampling
    val (params, _, _) = Iterator.iterate((sc.accumulable(initialParams), initialChosenTopics, 0)) {
Review comment: Why an accumulator and not an .aggregate()?
Reply: Accumulators have the ability to do fine-grained updating of the LDA parameters. For aggregate, I have to use …
      case (lastParams, lastChosenTopics, i) =>
        logInfo("Start Gibbs sampling")

        val rand = new Random(42 + i * i)
        val params = sc.accumulable(LDAParams(numDocs, numTopics, numTerms))
        val chosenTopics = data.zip(lastChosenTopics).map {
          case (Document(docId, content), topics) =>
            content.zip(topics).map { case (term, topic) =>
              lastParams += (docId, term, topic, -1)

              val chosenTopic = lastParams.localValue.dropOneDistSampler(
                docTopicSmoothing, topicTermSmoothing, term, docId, rand)

              lastParams += (docId, term, chosenTopic, 1)
              params += (docId, term, chosenTopic, 1)

              chosenTopic
            }
        }.cache()

        if ((i + 1) % cpInterval == 0) {
          chosenTopics.checkpoint()
        }

        // Trigger a job to collect accumulable LDA parameters.
        chosenTopics.count()
        lastChosenTopics.unpersist()

        (params, chosenTopics, i + 1)
    }.drop(1 + numOuterIterations).next()

    params.value
  }
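
Regarding the accumulator-versus-aggregate discussion above, a rough sketch of how one sweep could be phrased with RDD.aggregate instead (hypothetical, not part of this PR; sampleDocument is an assumed helper, and the sketch glosses over emitting the per-token topic assignments and the fine-grained in-sweep updates the reply mentions):

// Hypothetical sketch only: one Gibbs sweep expressed via RDD.aggregate.
def sweepWithAggregate(
    data: RDD[Document],
    numDocs: Int,
    numTopics: Int,
    numTerms: Int): LDAParams = {
  data.aggregate(LDAParams(numDocs, numTopics, numTerms))(
    // seqOp: sample every token of one document and fold the count deltas into the local params.
    (localParams, doc) => sampleDocument(doc, localParams),  // sampleDocument is an assumed helper
    // combOp: merge the per-partition parameter counts.
    (p1, p2) => p1.merge(p2))
}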

  /**
   * Infer the model matrices Phi and Theta from LDAParams.
   */
  def solvePhiAndTheta(
      params: LDAParams,
      docTopicSmoothing: Double = docTopicSmoothing,
      topicTermSmoothing: Double = topicTermSmoothing): (Array[Vector], Array[Vector]) = {
Review comment: Again, Phi and Theta might be too big.
    val numTopics = params.topicCounts.size
    val numTerms = params.topicTermCounts.head.size

    val docCount = params.docCounts.toBreeze :+ (docTopicSmoothing * numTopics)
    val topicCount = params.topicCounts.toBreeze :+ (topicTermSmoothing * numTerms)
    val docTopicCount = params.docTopicCounts.map(vec => vec.toBreeze :+ docTopicSmoothing)
    val topicTermCount = params.topicTermCounts.map(vec => vec.toBreeze :+ topicTermSmoothing)

    var i = 0
    while (i < numTopics) {
      topicTermCount(i) :/= topicCount(i)
      i += 1
    }

    i = 0
    while (i < docCount.length) {
      docTopicCount(i) :/= docCount(i)
      i += 1
    }

    (topicTermCount.map(vec => Vectors.fromBreeze(vec)),
      docTopicCount.map(vec => Vectors.fromBreeze(vec)))
  }
}
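
For reference, solvePhiAndTheta above computes the usual smoothed point estimates, with α = docTopicSmoothing, β = topicTermSmoothing, K topics and V terms:

    \phi_{k,w} = \frac{n_{k,w} + \beta}{n_k + V\beta}, \qquad \theta_{d,k} = \frac{n_{d,k} + \alpha}{n_d + K\alpha}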

object GibbsSampling extends Logging {

  /**
   * Initial step of Gibbs sampling, which supports incremental LDA.
   */
  private def sampleTermAssignment(
      params: LDAParams,
      data: RDD[Document]): (LDAParams, RDD[Iterable[Int]]) = {

    val sc = data.context
    val initialParams = sc.accumulable(params)
    val rand = new Random(42)
    val initialChosenTopics = data.map { case Document(docId, content) =>
      val docTopics = params.docTopicCounts(docId)
      if (docTopics.toBreeze.norm(2) == 0) {
        content.map { term =>
          val topic = uniformDistSampler(rand, params.topicCounts.size)
          initialParams += (docId, term, topic, 1)
          topic
        }
      } else {
        content.map { term =>
          val topicTerms = Vectors.dense(params.topicTermCounts.map(_(term))).toBreeze
          val dist = docTopics.toBreeze :* topicTerms
          multinomialDistSampler(rand, dist.asInstanceOf[BDV[Double]])
        }
      }
    }.cache()

    // Trigger a job to collect accumulable LDA parameters.
    initialChosenTopics.count()

    (initialParams.value, initialChosenTopics)
  }

  /**
   * A uniform distribution sampler, which is only used for initialization.
   */
  private def uniformDistSampler(rand: Random, dimension: Int): Int = rand.nextInt(dimension)

  /**
   * A multinomial distribution sampler that uses the roulette-wheel method to sample an index.
   */
  def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = {
    val roulette = rand.nextDouble()

    dist :/= sum[BDV[Double], Double](dist)

    def loop(index: Int, accum: Double): Int = {
      if (index == dist.length) return dist.length - 1
      val sum = accum + dist(index)
      if (sum >= roulette) index else loop(index + 1, sum)
    }

    loop(0, 0.0)
  }
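
A small illustrative sketch of calling multinomialDistSampler directly (values are made up; note the sampler normalizes, and mutates, the weight vector in place):

// Illustrative sketch only.
import java.util.Random
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.mllib.expectation.GibbsSampling

val rand = new Random(7)
val weights = BDV(0.1, 0.4, 0.2, 0.3)  // unnormalized weights over 4 topics
val topic = GibbsSampling.multinomialDistSampler(rand, weights)  // index in 0..3, drawn proportionally to the weights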

  /**
   * Perplexity is a common evaluation metric for LDA. It is usually computed on unseen data, but
   * here we apply it to the training documents, which is also acceptable. When using it on unseen
   * data, you must run an iteration of Gibbs sampling before calling this. A smaller perplexity
   * indicates a better model.
   */
  def perplexity(data: RDD[Document], phi: Array[Vector], theta: Array[Vector]): Double = {
    val (termProb, totalNum) = data.flatMap { case Document(docId, content) =>
      val currentTheta = BDV.zeros[Double](phi.head.size)
      var col = 0
      var row = 0
      while (col < phi.head.size) {
        row = 0
        while (row < phi.length) {
          currentTheta(col) += phi(row)(col) * theta(docId)(row)
          row += 1
        }
        col += 1
      }
      content.map(x => (math.log(currentTheta(x)), 1))
    }.reduce { (lhs, rhs) =>
      (lhs._1 + rhs._1, lhs._2 + rhs._2)
    }
    math.exp(-1 * termProb / totalNum)
  }
}
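
For reference, the value returned by perplexity above is

    \mathrm{perplexity} = \exp\!\left(-\frac{1}{N}\sum_{d}\sum_{i \in d}\log p(w_{d,i})\right), \qquad p(w_{d,i}) = \sum_{k}\theta_{d,k}\,\phi_{k,w_{d,i}},

where N is the total number of tokens in the corpus.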
Review comment: Hmm... if documents are just an ID and a list of token IDs, maybe something like a SparseVector is a better representation?
Reply: You mean "termId: count"? Yes, it is a common way to do that. But I was considering the trade-off between statistical efficiency and hardware efficiency: if we combine occurrences of the same term together in one document, the randomness seems worse. Anyway, I'll try to modify it to use SparseVector.
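
A minimal sketch of the representation discussed above (hypothetical, not part of this PR): store each document as (termId, count) pairs, or equivalently as an MLlib SparseVector over the vocabulary.

// Hypothetical sketch of the "termId: count" document representation discussed above.
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Each document keeps its id plus (termId, count) pairs instead of a flat token list.
case class BagOfWordsDocument(docId: Int, termCounts: Seq[(Int, Int)])

// Conversion from the PR's Document(docId, content: Iterable[Int]) representation.
def toBagOfWords(docId: Int, content: Iterable[Int]): BagOfWordsDocument =
  BagOfWordsDocument(docId,
    content.groupBy(identity).map { case (term, occurrences) => term -> occurrences.size }.toSeq)

// The same information as an MLlib SparseVector over a vocabulary of numTerms words.
def toSparseVector(doc: BagOfWordsDocument, numTerms: Int): Vector = {
  val sorted = doc.termCounts.sortBy(_._1)
  Vectors.sparse(numTerms, sorted.map(_._1).toArray, sorted.map(_._2.toDouble).toArray)
}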