JIRA issue: [SPARK-1405] Gibbs sampling based Latent Dirichlet Allocation (LDA) for MLlib #476

Closed
wants to merge 5 commits into from
169 changes: 169 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -0,0 +1,169 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.clustering

import java.util.Random

import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.{AccumulableParam, Logging, SparkContext}
import org.apache.spark.mllib.expectation.GibbsSampling
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

case class Document(docId: Int, content: Iterable[Int])
Contributor

Hmm... if documents are just an ID and a list of token IDs, maybe something like a SparseVector is a better representation?

Contributor Author

You mean "termId: count"? Yes, that is a common representation. I was weighing the trade-off between statistical efficiency and hardware efficiency: if we merge repeated occurrences of a term within one document, the sampling randomness seems to get worse. Anyway, I'll try to modify it to use SparseVector.
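
To make the two representations concrete, here is a minimal sketch of collapsing a token-ID list into the "termId: count" layout that a sparse vector would store. The names `TermCounts` and `toTermCounts` are hypothetical and not part of this PR, and the sketch uses plain Scala collections rather than MLlib's `SparseVector`. A token-level Gibbs sampler would still need one topic assignment per occurrence, so the counts would presumably be expanded again at sampling time.

```scala
object TermCounts {
  // Hypothetical helper: collapse a token-ID sequence into parallel arrays of
  // sorted term IDs and their occurrence counts (the "termId: count" layout).
  def toTermCounts(tokens: Seq[Int]): (Array[Int], Array[Double]) = {
    val counts = tokens.groupBy(identity).map { case (term, occ) => (term, occ.size.toDouble) }
    val sorted = counts.toArray.sortBy(_._1)
    (sorted.map(_._1), sorted.map(_._2))
  }
}
```

For example, `TermCounts.toTermCounts(Seq(3, 1, 3, 7, 1, 3))` yields the index array `1, 3, 7` with the counts `2.0, 3.0, 1.0`.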


case class LDAParams (
Contributor

I don't think it should be a case class.

docCounts: Vector,
topicCounts: Vector,
docTopicCounts: Array[Vector],
topicTermCounts: Array[Vector])
Contributor

I expect that this will be really big - maybe the last two variables should be RDDs - similar to what we do with ALS.

Contributor Author

That makes sense. I think docTopicCounts could be sliced easily with respect to the document partitions. topicTermCounts, however, is hard to slice. I'll find a way to handle it.

extends Serializable {

def update(docId: Int, term: Int, topic: Int, inc: Int) = {
docCounts.toBreeze(docId) += inc
Contributor

Doing the breeze conversion on every update seems inefficient. These variables should be private and created as breeze variables at initialization, only the user facing APIs need to be Vector, Array[Vector], etc.
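
One way to read this suggestion, sketched under my own assumptions: keep the counts as private mutable arrays created once at construction, and expose read-only views only at the API boundary. The class name `TopicCounts` and the accessors `docTopic` and `topicTotal` are hypothetical; plain `Array[Double]` stands in for Breeze vectors so the sketch runs without dependencies.

```scala
// Hypothetical sketch: counts live in private arrays, so update and merge
// never convert between vector types.
final class TopicCounts(numDocs: Int, numTopics: Int, numTerms: Int) extends Serializable {
  private val docCounts = new Array[Double](numDocs)
  private val topicCounts = new Array[Double](numTopics)
  private val docTopicCounts = Array.ofDim[Double](numDocs, numTopics)
  private val topicTermCounts = Array.ofDim[Double](numTopics, numTerms)

  // Apply one (docId, term, topic) increment or decrement in place.
  def update(docId: Int, term: Int, topic: Int, inc: Int): this.type = {
    docCounts(docId) += inc
    topicCounts(topic) += inc
    docTopicCounts(docId)(topic) += inc
    topicTermCounts(topic)(term) += inc
    this
  }

  // Element-wise addition of src into dst.
  private def addInto(dst: Array[Double], src: Array[Double]): Unit = {
    var i = 0
    while (i < dst.length) { dst(i) += src(i); i += 1 }
  }

  // Merge another table of the same shape into this one.
  def merge(other: TopicCounts): this.type = {
    addInto(docCounts, other.docCounts)
    addInto(topicCounts, other.topicCounts)
    docTopicCounts.indices.foreach(d => addInto(docTopicCounts(d), other.docTopicCounts(d)))
    topicTermCounts.indices.foreach(k => addInto(topicTermCounts(k), other.topicTermCounts(k)))
    this
  }

  // Read-only copies for user-facing code.
  def docTopic(docId: Int): IndexedSeq[Double] = docTopicCounts(docId).toIndexedSeq
  def topicTotal(topic: Int): Double = topicCounts(topic)
}
```

The conversion to a public vector type would then happen once, in whatever method returns results to the caller.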

topicCounts.toBreeze(topic) += inc
docTopicCounts(docId).toBreeze(topic) += inc
Contributor

Again, I think in this case the model might be really big - e.g. a billion documents in hundreds of topics. Or for the term side, millions of words in a vocabulary and hundreds of topics.

topicTermCounts(topic).toBreeze(term) += inc
this
}

def merge(other: LDAParams) = {
docCounts.toBreeze += other.docCounts.toBreeze
topicCounts.toBreeze += other.topicCounts.toBreeze

var i = 0
while (i < docTopicCounts.length) {
docTopicCounts(i).toBreeze += other.docTopicCounts(i).toBreeze
Contributor

More Breeze conversions here.

i += 1
}

i = 0
while (i < topicTermCounts.length) {
topicTermCounts(i).toBreeze += other.topicTermCounts(i).toBreeze
i += 1
}
this
}

/**
* Computes the topic distribution of a term once its current assignment has been dropped
* from the counts, and samples a new topic from it. This is the essential step of Gibbs
* sampling for LDA; see the paper:
* <I>Parameter estimation for text analysis</I>
Contributor

Link to paper, please.
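
For readers without the paper at hand, the conditional that a collapsed Gibbs sampler for LDA draws from is commonly written as below. The notation is my own, chosen to map onto the identifiers in this file: $n_{k}^{(t)}$ is `topicTermCounts(k)(termId)`, $n_{k}$ is `topicCounts(k)`, $n_{d}^{(k)}$ is `docTopicCounts(docId)(k)`, $V$ is `numTerms`, $\beta$ is `topicTermSmoothing`, and $\alpha$ is `docTopicSmoothing`. The subscript $\neg i$ means the count excludes the token currently being resampled.

```latex
p(z_i = k \mid \mathbf{z}_{\neg i}, \mathbf{w})
  \;\propto\;
  \frac{n_{k,\neg i}^{(t)} + \beta}{n_{k,\neg i} + V\beta}
  \,\bigl(n_{d,\neg i}^{(k)} + \alpha\bigr)
```

The document-side denominator is the same for every topic $k$, which is why it can be omitted from the unnormalized weights.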

*/
def dropOneDistSampler(
docTopicSmoothing: Double,
topicTermSmoothing: Double,
termId: Int,
docId: Int,
rand: Random): Int = {
val (numTopics, numTerms) = (topicCounts.size, topicTermCounts.head.size)
val topicThisTerm = BDV.zeros[Double](numTopics)
var i = 0
while (i < numTopics) {
topicThisTerm(i) =
((topicTermCounts(i)(termId) + topicTermSmoothing)
/ (topicCounts(i) + (numTerms * topicTermSmoothing))
) * (docTopicCounts(docId)(i) + docTopicSmoothing)
i += 1
}
GibbsSampling.multinomialDistSampler(rand, topicThisTerm)
}
}

object LDAParams {
implicit val ldaParamsAP = new LDAParamsAccumulableParam

def apply(numDocs: Int, numTopics: Int, numTerms: Int) = new LDAParams(
Vectors.fromBreeze(BDV.zeros[Double](numDocs)),
Vectors.fromBreeze(BDV.zeros[Double](numTopics)),
Array(0 until numDocs: _*).map(_ => Vectors.fromBreeze(BDV.zeros[Double](numTopics))),
Array(0 until numTopics: _*).map(_ => Vectors.fromBreeze(BDV.zeros[Double](numTerms))))
}

class LDAParamsAccumulableParam extends AccumulableParam[LDAParams, (Int, Int, Int, Int)] {
def addAccumulator(r: LDAParams, t: (Int, Int, Int, Int)) = {
val (docId, term, topic, inc) = t
r.update(docId, term, topic, inc)
}

def addInPlace(r1: LDAParams, r2: LDAParams): LDAParams = r1.merge(r2)

def zero(initialValue: LDAParams): LDAParams = initialValue
}

class LDA private (
var numTopics: Int,
var docTopicSmoothing: Double,
var topicTermSmoothing: Double,
var numIteration: Int,
var numDocs: Int,
var numTerms: Int)
extends Serializable with Logging {
def run(input: RDD[Document]): (GibbsSampling, LDAParams) = {
val trainer = new GibbsSampling(
input,
numIteration,
1,
docTopicSmoothing,
topicTermSmoothing)
(trainer, trainer.runGibbsSampling(LDAParams(numDocs, numTopics, numTerms)))
}
}

object LDA extends Logging {

def train(
data: RDD[Document],
numTopics: Int,
docTopicSmoothing: Double,
topicTermSmoothing: Double,
numIterations: Int,
numDocs: Int,
numTerms: Int): (Array[Vector], Array[Vector]) = {
val lda = new LDA(numTopics,
docTopicSmoothing,
topicTermSmoothing,
numIterations,
numDocs,
numTerms)
val (trainer, model) = lda.run(data)
trainer.solvePhiAndTheta(model)
}

def main(args: Array[String]) {
if (args.length != 5) {
println("Usage: LDA <master> <input_dir> <k> <max_iterations> <mini-split>")
System.exit(1)
}

val (master, inputDir, k, iters, minSplit) =
(args(0), args(1), args(2).toInt, args(3).toInt, args(4).toInt)
val checkPointDir = System.getProperty("spark.gibbsSampling.checkPointDir", "/tmp/lda-cp")
val sc = new SparkContext(master, "LDA")
sc.setCheckpointDir(checkPointDir)
val (data, wordMap, docMap) = MLUtils.loadCorpus(sc, inputDir, minSplit)
val numDocs = docMap.size
val numTerms = wordMap.size

val (phi, theta) = LDA.train(data, k, 0.01, 0.01, iters, numDocs, numTerms)
val pp = GibbsSampling.perplexity(data, phi, theta)
println(s"final model perplexity is $pp")
}
}
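
As a dependency-free illustration of the sampling step that `dropOneDistSampler` performs, here is a minimal sketch based on the standard collapsed Gibbs formulation, in which the topic-term factor is multiplied by the document-topic factor. The object `DropOneSketch` and its methods are hypothetical names, plain arrays stand in for MLlib and Breeze vectors, and the caller is assumed to have already subtracted the token's current assignment from the counts.

```scala
import java.util.Random

object DropOneSketch {
  // Roulette-wheel draw of an index from unnormalized, non-negative weights.
  def sampleIndex(weights: Array[Double], rand: Random): Int = {
    val target = rand.nextDouble() * weights.sum
    var acc = 0.0
    var i = 0
    while (i < weights.length) {
      acc += weights(i)
      if (acc > target) return i
      i += 1
    }
    weights.length - 1 // guard against floating-point round-off
  }

  // Unnormalized weight of every topic for one token of term termId.
  // docTopicCounts is the topic-count row of the token's document.
  def topicWeights(
      topicTermCounts: Array[Array[Double]],
      topicCounts: Array[Double],
      docTopicCounts: Array[Double],
      termId: Int,
      docTopicSmoothing: Double,
      topicTermSmoothing: Double): Array[Double] = {
    val numTerms = topicTermCounts.head.length
    Array.tabulate(topicCounts.length) { k =>
      (topicTermCounts(k)(termId) + topicTermSmoothing) /
        (topicCounts(k) + numTerms * topicTermSmoothing) *
        (docTopicCounts(k) + docTopicSmoothing)
    }
  }
}
```

A new topic would then be drawn with `DropOneSketch.sampleIndex(DropOneSketch.topicWeights(...), rand)`, with no separate normalization pass needed because the draw is scaled by the sum of the weights.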
@@ -0,0 +1,219 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.expectation

import java.util.Random

import breeze.linalg.{DenseVector => BDV, sum}

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.clustering.{Document, LDAParams}
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
* Gibbs sampling from a given dataset and model.
* @param data Dataset, such as a corpus.
* @param numOuterIterations Number of outer iterations.
* @param numInnerIterations Number of inner iterations, run within each partition.
* @param docTopicSmoothing Document-topic smoothing.
* @param topicTermSmoothing Topic-term smoothing.
*/
class GibbsSampling(
Contributor

Gibbs sampling is a very useful general-purpose tool to have. Its interface should be something more generic than RDD[Document], and the parameters should be amenable to domains other than text.

data: RDD[Document],
numOuterIterations: Int,
numInnerIterations: Int,
docTopicSmoothing: Double,
topicTermSmoothing: Double)
extends Logging with Serializable {

import GibbsSampling._

/**
* Main entry point for running Gibbs sampling. It has two phases: initialization of the
* topic assignments, followed by the actual sampling iterations.
*/
def runGibbsSampling(
initParams: LDAParams,
data: RDD[Document] = data,
numOuterIterations: Int = numOuterIterations,
numInnerIterations: Int = numInnerIterations,
docTopicSmoothing: Double = docTopicSmoothing,
topicTermSmoothing: Double = topicTermSmoothing): LDAParams = {

val numTerms = initParams.topicTermCounts.head.size
val numDocs = initParams.docCounts.size
val numTopics = initParams.topicCounts.size

// Construct topic assignment RDD
logInfo("Start initialization")

val cpInterval = System.getProperty("spark.gibbsSampling.checkPointInterval", "10").toInt
val sc = data.context
val (initialParams, initialChosenTopics) = sampleTermAssignment(initParams, data)

// Gibbs sampling
val (params, _, _) = Iterator.iterate((sc.accumulable(initialParams), initialChosenTopics, 0)) {
Contributor

Why an accumulator and not an .aggregate()?

Contributor Author

An accumulator allows fine-grained updates of the LDA parameters. With aggregate, I would have to use mapPartitions, as in my previous version of the LDA implementation here. That previous implementation is slower than the current version, partly because of the serialization of the huge parameters.
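
For comparison, the aggregate-style alternative raised in this thread can be sketched with plain Scala collections standing in for RDD partitions. Each partition folds its assignments into a local count table (the `seqOp` role), and the tables are then merged (the `combOp` role), mirroring the shape of `RDD.aggregate(zeroValue)(seqOp, combOp)`. The object `AggregateSketch` and the `Assignment` alias are hypothetical, and only per-topic totals are tracked to keep the example short.

```scala
object AggregateSketch {
  // One (docId, term, topic) assignment.
  type Assignment = (Int, Int, Int)

  // seqOp: fold one assignment into a partition-local topic count table.
  def seqOp(counts: Array[Double], a: Assignment): Array[Double] = {
    counts(a._3) += 1.0
    counts
  }

  // combOp: merge two partition-local tables element-wise.
  def combOp(left: Array[Double], right: Array[Double]): Array[Double] = {
    var k = 0
    while (k < left.length) { left(k) += right(k); k += 1 }
    left
  }

  // Each inner Seq plays the role of one partition.
  def topicCounts(partitions: Seq[Seq[Assignment]], numTopics: Int): Array[Double] =
    partitions
      .map(_.foldLeft(new Array[Double](numTopics))(seqOp))
      .foldLeft(new Array[Double](numTopics))(combOp)
}
```

In this shape every partition produces a full table that must be shipped and merged, which is consistent with the serialization cost mentioned above for large parameter sets.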

case (lastParams, lastChosenTopics, i) =>
logInfo("Start Gibbs sampling")

val rand = new Random(42 + i * i)
val params = sc.accumulable(LDAParams(numDocs, numTopics, numTerms))
val chosenTopics = data.zip(lastChosenTopics).map {
case (Document(docId, content), topics) =>
content.zip(topics).map { case (term, topic) =>
lastParams += (docId, term, topic, -1)

val chosenTopic = lastParams.localValue.dropOneDistSampler(
docTopicSmoothing, topicTermSmoothing, term, docId, rand)

lastParams += (docId, term, chosenTopic, 1)
params += (docId, term, chosenTopic, 1)

chosenTopic
}
}.cache()

if ((i + 1) % cpInterval == 0) {
chosenTopics.checkpoint()
}

// Trigger a job to collect accumulable LDA parameters.
chosenTopics.count()
lastChosenTopics.unpersist()

(params, chosenTopics, i + 1)
}.drop(1 + numOuterIterations).next()

params.value
}

/**
* Model matrix Phi and Theta are inferred via LDAParams.
*/
def solvePhiAndTheta(
params: LDAParams,
docTopicSmoothing: Double = docTopicSmoothing,
topicTermSmoothing: Double = topicTermSmoothing): (Array[Vector], Array[Vector]) = {
Contributor

Again, Phi and Theta might be too big.
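
Since each row of Phi and Theta is smoothed and normalized independently, the computation could in principle be applied row by row (for instance per document inside a partition) instead of on fully materialized arrays. The sketch below shows only that per-row step; `PhiThetaSketch` and `smoothAndNormalize` are hypothetical names, and it assumes each row's sum equals the corresponding entry of `topicCounts` or `docCounts`.

```scala
object PhiThetaSketch {
  // Add the smoothing constant to every count in a row, then normalize the
  // row so that it sums to one.
  def smoothAndNormalize(counts: Array[Array[Double]], smoothing: Double): Array[Array[Double]] =
    counts.map { row =>
      val total = row.sum + smoothing * row.length
      row.map(c => (c + smoothing) / total)
    }
}
```

Under that assumption, Phi would correspond to `smoothAndNormalize(topicTermCounts, topicTermSmoothing)` and Theta to `smoothAndNormalize(docTopicCounts, docTopicSmoothing)`.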

val numTopics = params.topicCounts.size
val numTerms = params.topicTermCounts.head.size

val docCount = params.docCounts.toBreeze :+ (docTopicSmoothing * numTopics)
val topicCount = params.topicCounts.toBreeze :+ (topicTermSmoothing * numTerms)
val docTopicCount = params.docTopicCounts.map(vec => vec.toBreeze :+ docTopicSmoothing)
val topicTermCount = params.topicTermCounts.map(vec => vec.toBreeze :+ topicTermSmoothing)

var i = 0
while (i < numTopics) {
topicTermCount(i) :/= topicCount(i)
i += 1
}

i = 0
while (i < docCount.length) {
docTopicCount(i) :/= docCount(i)
i += 1
}

(topicTermCount.map(vec => Vectors.fromBreeze(vec)),
docTopicCount.map(vec => Vectors.fromBreeze(vec)))
}
}

object GibbsSampling extends Logging {

/**
* Initial step of Gibbs sampling, which supports incremental LDA.
*/
private def sampleTermAssignment(
params: LDAParams,
data: RDD[Document]): (LDAParams, RDD[Iterable[Int]]) = {

val sc = data.context
val initialParams = sc.accumulable(params)
val rand = new Random(42)
val initialChosenTopics = data.map { case Document(docId, content) =>
val docTopics = params.docTopicCounts(docId)
if (docTopics.toBreeze.norm(2) == 0) {
content.map { term =>
val topic = uniformDistSampler(rand, params.topicCounts.size)
initialParams += (docId, term, topic, 1)
topic
}
} else {
content.map { term =>
val topicTerms = Vectors.dense(params.topicTermCounts.map(_(term))).toBreeze
val dist = docTopics.toBreeze :* topicTerms
multinomialDistSampler(rand, dist.asInstanceOf[BDV[Double]])
}
}
}.cache()

// Trigger a job to collect accumulable LDA parameters.
initialChosenTopics.count()

(initialParams.value, initialChosenTopics)
}

/**
* A uniform distribution sampler, which is only used for initialization.
*/
private def uniformDistSampler(rand: Random, dimension: Int): Int = rand.nextInt(dimension)

/**
* A multinomial distribution sampler that uses the roulette-wheel method to draw an index.
* Note that `dist` is normalized in place.
*/
def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = {
val roulette = rand.nextDouble()

dist :/= sum[BDV[Double], Double](dist)

def loop(index: Int, accum: Double): Int = {
if(index == dist.length) return dist.length - 1
val sum = accum + dist(index)
if (sum >= roulette) index else loop(index + 1, sum)
}

loop(0, 0.0)
}

/**
* Perplexity is an evaluation measure for LDA. It is usually computed on unseen data, but here
* we compute it on the training documents, which is also acceptable. To use it on unseen data,
* run an iteration of Gibbs sampling first. Lower perplexity indicates a better model.
*/
def perplexity(data: RDD[Document], phi: Array[Vector], theta: Array[Vector]): Double = {
val (termProb, totalNum) = data.flatMap { case Document(docId, content) =>
val currentTheta = BDV.zeros[Double](phi.head.size)
var col = 0
var row = 0
while (col < phi.head.size) {
row = 0
while (row < phi.length) {
currentTheta(col) += phi(row)(col) * theta(docId)(row)
row += 1
}
col += 1
}
content.map(x => (math.log(currentTheta(x)), 1))
}.reduce { (lhs, rhs) =>
(lhs._1 + rhs._1, lhs._2 + rhs._2)
}
math.exp(-1 * termProb / totalNum)
}
}
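
The perplexity computation above can be restated without Spark as the exponential of the negative mean log-likelihood per token, where the probability of a term in a document is the sum over topics of `theta(doc)(topic) * phi(topic)(term)`. The following is a minimal sketch under that reading; `PerplexitySketch` is a hypothetical name, documents are plain token-ID sequences, and a document's ID is assumed to be its position in `docs`.

```scala
object PerplexitySketch {
  // docs: token-ID sequences; phi: topics x terms; theta: docs x topics.
  def perplexity(
      docs: Seq[Seq[Int]],
      phi: Array[Array[Double]],
      theta: Array[Array[Double]]): Double = {
    var logLik = 0.0
    var numTokens = 0L
    for ((tokens, d) <- docs.zipWithIndex; term <- tokens) {
      // p(term | doc d) = sum over topics k of theta(d)(k) * phi(k)(term)
      var p = 0.0
      var k = 0
      while (k < phi.length) { p += theta(d)(k) * phi(k)(term); k += 1 }
      logLik += math.log(p)
      numTokens += 1
    }
    math.exp(-logLik / numTokens)
  }
}
```

With a single topic that assigns probability 0.5 to each of two terms, every token has probability 0.5 and the perplexity comes out as 2, matching the intuition that perplexity measures the effective number of equally likely choices per token.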