update synGlobal in place and reduce synOut size
Ishiihara committed Aug 14, 2014
1 parent 9075e1c commit 083aa66
Showing 1 changed file with 9 additions and 4 deletions.
@@ -35,6 +35,7 @@ import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap

/**
* Entry in vocabulary
*/
@@ -323,14 +324,14 @@ class Word2Vec extends Serializable with Logging {
val ind = bcVocab.value(word).point(d)
val l2 = ind * vectorSize
// Propagate hidden -> output
-synModify(ind) += 1
var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
+synModify(ind) += 1
}
d += 1
}
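The hunk above moves synModify(ind) += 1 into the if (f > -MAX_EXP && f < MAX_EXP) branch, so a row is only counted as modified when the two saxpy updates actually run. That is the "reduce synOut size" half of the commit: a partition then has fewer rows to ship back to the driver. A minimal sketch of that emit-only-modified-rows step follows; modifiedRows and synLocal are hypothetical names for illustration, not the synOut construction in this file.

  def modifiedRows(
      synLocal: Array[Float],  // flat vocabSize x vectorSize matrix held by one partition
      synModify: Array[Int],   // per-row modification counters, as in the hunk above
      vectorSize: Int): Iterator[(Int, Array[Float])] = {
    val out = Array.newBuilder[(Int, Array[Float])]
    var row = 0
    while (row < synModify.length) {
      if (synModify(row) > 0) {
        // ship only rows the saxpy calls actually touched
        out += ((row, synLocal.slice(row * vectorSize, (row + 1) * vectorSize)))
      }
      row += 1
    }
    out.result().iterator
  }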
@@ -355,11 +356,15 @@
}
Iterator(synOut)
}
-synGlobal = partial.flatMap(x => x).reduceByKey {
-case (v1,v2) =>
+val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
v1
-}.collect().sortBy(_._1).flatMap(x => x._2)
+}.collect()
+var i = 0
+while (i < synAgg.length) {
+Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
+i += 1
+}
}
newSentences.unpersist()

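The last hunk is the "update synGlobal in place" half. The removed code rebuilt the whole flattened matrix each time (collect, sort by row index, concatenate), which also forces every row to be present in the aggregate. The added code collects (rowIndex, rowVector) pairs and copies each one into synGlobal at offset rowIndex * vectorSize, so rows that no partition modified keep their previous values. A minimal standalone sketch of that driver-side merge follows; mergeInPlace is a hypothetical wrapper around the loop in the diff, and the element-wise summation is assumed to have happened upstream in the reduceByKey.

  def mergeInPlace(
      synAgg: Array[(Int, Array[Float])],  // collected (rowIndex, summedRow) pairs
      synGlobal: Array[Float],             // flat vocabSize x vectorSize matrix, updated in place
      vectorSize: Int): Unit = {
    var i = 0
    while (i < synAgg.length) {
      val (row, vec) = synAgg(i)
      // overwrite only this row; rows absent from synAgg are left untouched
      Array.copy(vec, 0, synGlobal, row * vectorSize, vectorSize)
      i += 1
    }
  }

Together with the smaller per-partition output above, this avoids allocating a replacement array and sorting the collected pairs on every training iteration.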
