Skip to content

Commit

Permalink
use syn0Global and syn1Global to represent model
Browse files Browse the repository at this point in the history
  • Loading branch information
Ishiihara committed Aug 17, 2014
1 parent cad2011 commit d5377a9
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class Word2Vec extends Serializable with Logging {
b = 0
while (b < i) {
vocab(a).code(i - b - 1) = code(b)
vocab(a).point(i - b) = point(b)
vocab(a).point(i - b) = point(b) - vocabSize
b += 1
}
a += 1
Expand Down Expand Up @@ -285,15 +285,17 @@ class Word2Vec extends Serializable with Logging {

val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
var synGlobal =
Array.fill[Float](2 * vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val synModify = new Array[Int](2 * vocabSize)
val model = iter.foldLeft((synGlobal, 0, 0)) {
case ((syn, lastWordCount, wordCount), sentence) =>
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
if (wordCount - lastWordCount > 10000) {
Expand Down Expand Up @@ -324,45 +326,55 @@ class Word2Vec extends Serializable with Logging {
val inner = bcVocab.value(word).point(d)
val l2 = inner * vectorSize
// Propagate hidden -> output
var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
synModify(inner) += 1
blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
syn1Modify(inner) += 1
}
d += 1
}
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn, l1, 1)
synModify(lastWord) += 1
blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
syn0Modify(lastWord) += 1
}
}
a += 1
}
pos += 1
}
(syn, lwc, wc)
(syn0, syn1, lwc, wc)
}
val synLocal = model._1
val syn0Local = model._1
val syn1Local = model._2
val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
var index = 0
while(index < 2 * vocabSize) {
if (synModify(index) != 0) {
synOut.update(index, synLocal.slice(index * vectorSize, (index + 1) * vectorSize))
while(index < vocabSize) {
if (syn0Modify(index) != 0) {
synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
}
if (syn1Modify(index) != 0) {
synOut.update(index + vocabSize,
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
}
index += 1
}
Iterator(synOut)
}
val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
v1
}.collect()
var i = 0
while (i < synAgg.length) {
Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
val index = synAgg(i)._1
if (index < vocabSize) {
Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
} else {
Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
}
i += 1
}
}
Expand All @@ -373,7 +385,7 @@ class Word2Vec extends Serializable with Logging {
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Float](vectorSize)
Array.copy(synGlobal, i * vectorSize, vector, 0, vectorSize)
Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
word2VecMap += word -> vector
i += 1
}
Expand Down

0 comments on commit d5377a9

Please sign in to comment.