Skip to content

Commit

Permalink
update IDF to separate Model from Algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Aug 6, 2014
1 parent e537b33 commit 89f3486
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 78 deletions.
130 changes: 61 additions & 69 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,87 +36,25 @@ class IDF {

// TODO: Allow different IDF formulations.

private var brzIdf: BDV[Double] = _

/**
* Computes the inverse document frequency.
* @param dataset an RDD of term frequency vectors
*/
def fit(dataset: RDD[Vector]): this.type = {
brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
def fit(dataset: RDD[Vector]): IDFModel = {
val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
seqOp = (df, v) => df.add(v),
combOp = (df1, df2) => df1.merge(df2)
).idf()
this
new IDFModel(idf)
}

/**
* Computes the inverse document frequency.
* @param dataset a JavaRDD of term frequency vectors
*/
def fit(dataset: JavaRDD[Vector]): this.type = {
def fit(dataset: JavaRDD[Vector]): IDFModel = {
fit(dataset.rdd)
}

/**
* Transforms term frequency (TF) vectors to TF-IDF vectors.
* @param dataset an RDD of term frequency vectors
* @return an RDD of TF-IDF vectors
*/
def transform(dataset: RDD[Vector]): RDD[Vector] = {
if (!initialized) {
throw new IllegalStateException("Haven't learned IDF yet. Call fit first.")
}
val theIdf = brzIdf
val bcIdf = dataset.context.broadcast(theIdf)
dataset.mapPartitions { iter =>
val thisIdf = bcIdf.value
iter.map { v =>
val n = v.size
v match {
case sv: SparseVector =>
val nnz = sv.indices.size
val newValues = new Array[Double](nnz)
var k = 0
while (k < nnz) {
newValues(k) = sv.values(k) * thisIdf(sv.indices(k))
k += 1
}
Vectors.sparse(n, sv.indices, newValues)
case dv: DenseVector =>
val newValues = new Array[Double](n)
var j = 0
while (j < n) {
newValues(j) = dv.values(j) * thisIdf(j)
j += 1
}
Vectors.dense(newValues)
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
}
}
}
}

/**
* Transforms term frequency (TF) vectors to TF-IDF vectors (Java version).
* @param dataset a JavaRDD of term frequency vectors
* @return a JavaRDD of TF-IDF vectors
*/
def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
transform(dataset.rdd).toJavaRDD()
}

/** Returns the IDF vector. */
def idf(): Vector = {
if (!initialized) {
throw new IllegalStateException("Haven't learned IDF yet. Call fit first.")
}
Vectors.fromBreeze(brzIdf)
}

private def initialized: Boolean = brzIdf != null
}

private object IDF {
Expand Down Expand Up @@ -177,18 +115,72 @@ private object IDF {
private def isEmpty: Boolean = m == 0L

/** Returns the current IDF vector. */
def idf(): BDV[Double] = {
def idf(): Vector = {
if (isEmpty) {
throw new IllegalStateException("Haven't seen any document yet.")
}
val n = df.length
val inv = BDV.zeros[Double](n)
val inv = new Array[Double](n)
var j = 0
while (j < n) {
inv(j) = math.log((m + 1.0)/ (df(j) + 1.0))
j += 1
}
inv
Vectors.dense(inv)
}
}
}

/**
* :: Experimental ::
* Represents an IDF model that can transform term frequency vectors.
*/
@Experimental
class IDFModel private[mllib] (val idf: Vector) extends Serializable {

/**
* Transforms term frequency (TF) vectors to TF-IDF vectors.
* @param dataset an RDD of term frequency vectors
* @return an RDD of TF-IDF vectors
*/
def transform(dataset: RDD[Vector]): RDD[Vector] = {
val bcIdf = dataset.context.broadcast(idf)
dataset.mapPartitions { iter =>
val thisIdf = bcIdf.value
iter.map { v =>
val n = v.size
v match {
case sv: SparseVector =>
val nnz = sv.indices.size
val newValues = new Array[Double](nnz)
var k = 0
while (k < nnz) {
newValues(k) = sv.values(k) * thisIdf(sv.indices(k))
k += 1
}
Vectors.sparse(n, sv.indices, newValues)
case dv: DenseVector =>
val newValues = new Array[Double](n)
var j = 0
while (j < n) {
newValues(j) = dv.values(j) * thisIdf(j)
j += 1
}
Vectors.dense(newValues)
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
}
}
}
}

/**
* Transforms term frequency (TF) vectors to TF-IDF vectors (Java version).
* @param dataset a JavaRDD of term frequency vectors
* @return a JavaRDD of TF-IDF vectors
*/
def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = {
transform(dataset.rdd).toJavaRDD()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,12 @@ class IDFSuite extends FunSuite with LocalSparkContext {
val m = localTermFrequencies.size
val termFrequencies = sc.parallelize(localTermFrequencies, 2)
val idf = new IDF
intercept[IllegalStateException] {
idf.idf()
}
intercept[IllegalStateException] {
idf.transform(termFrequencies)
}
idf.fit(termFrequencies)
val model = idf.fit(termFrequencies)
val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
math.log((m.toDouble + 1.0) / (x + 1.0))
})
assert(idf.idf() ~== expected absTol 1e-12)
val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
assert(model.idf ~== expected absTol 1e-12)
val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
assert(tfidf.size === 3)
val tfidf0 = tfidf(0L).asInstanceOf[SparseVector]
assert(tfidf0.indices === Array(1, 3))
Expand Down

0 comments on commit 89f3486

Please sign in to comment.