combine syn0Global and syn1Global to synGlobal
Ishiihara committed Aug 13, 2014
1 parent aa2ab36 commit 9075e1c
Showing 1 changed file with 23 additions and 55 deletions.
78 changes: 23 additions & 55 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
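This change folds the two weight matrices of the skip-gram model, syn0Global (the word vectors) and syn1Global (the hierarchical-softmax inner-node vectors), into a single flat array synGlobal of length 2 * vocabSize * vectorSize: word rows occupy the first half and inner-node rows the second half. Because the Huffman-tree point values number inner nodes from vocabSize upward (as the removal of the "- vocabSize" shift in the first hunk suggests), point(d) can index the combined array directly. A minimal sketch of that layout, in plain Scala with toy sizes; everything except synGlobal, vocabSize and vectorSize is illustrative only, not code from this patch:

// Illustrative only -- shows how the combined synGlobal array is laid out
// and indexed after this change.
object SynLayoutSketch {
  def main(args: Array[String]): Unit = {
    val vocabSize = 4   // toy sizes
    val vectorSize = 3
    val synGlobal = new Array[Float](2 * vocabSize * vectorSize)

    // Offset of a word's input vector (the former syn0 half).
    def l1(word: Int): Int = word * vectorSize
    // Offset of a Huffman inner node's vector (the former syn1 half): inner
    // nodes are numbered from vocabSize upward, so no "- vocabSize" shift is needed.
    def l2(innerNode: Int): Int = innerNode * vectorSize

    println(l1(1))             // 3  -> inside the first (syn0) half
    println(l2(vocabSize + 2)) // 18 -> inside the second (syn1) half, which starts at 12
    println(synGlobal.length)  // 24 = 2 * vocabSize * vectorSize
  }
}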
@@ -235,7 +235,7 @@ class Word2Vec extends Serializable with Logging {
b = 0
while (b < i) {
vocab(a).code(i - b - 1) = code(b)
- vocab(a).point(i - b) = point(b) - vocabSize
+ vocab(a).point(i - b) = point(b)
b += 1
}
a += 1
@@ -284,19 +284,15 @@ class Word2Vec extends Serializable with Logging {

val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
- var syn0Global =
-   Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
- var syn1Global = new Array[Float](vocabSize * vectorSize)
-
+ var synGlobal =
+   Array.fill[Float](2 * vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
- val syn0Modify = new Array[Int](vocabSize)
- val syn1Modify = new Array[Int](vocabSize)
-
- val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
-   case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
+ val synModify = new Array[Int](2 * vocabSize)
+ val model = iter.foldLeft((synGlobal, 0, 0)) {
+   case ((syn, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
if (wordCount - lastWordCount > 10000) {
@@ -325,73 +321,45 @@ class Word2Vec extends Serializable with Logging {
var d = 0
while (d < bcVocab.value(word).codeLen) {
val ind = bcVocab.value(word).point(d)
- val l2 = bcVocab.value(word).point(d) * vectorSize
+ val l2 = ind * vectorSize
// Propagate hidden -> output
- syn1Modify(ind) += 1
- var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
+ synModify(ind) += 1
+ var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind)
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
- blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
- blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+ blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
+ blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
}
d += 1
}
- blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
- syn0Modify(lastWord) += 1
+ blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn, l1, 1)
+ synModify(lastWord) += 1
}
}
a += 1
}
pos += 1
}
- (syn0, syn1, lwc, wc)
+ (syn, lwc, wc)
}
- val syn0Local = model._1
- val syn1Local = model._2
-
- val syn0Out = new PrimitiveKeyOpenHashMap[Int, Array[Float]]
- // val syn1Out = new PrimitiveKeyOpenHashMap[Int, Array[Float]]
+ val synLocal = model._1
+ val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
var index = 0
- while(index < vocabSize) {
-   if (syn0Modify(index) != 0) syn0Out.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
-   if (syn1Modify(index) != 0) syn0Out.update(index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+ while(index < 2 * vocabSize) {
+   if (synModify(index) != 0) {
+     synOut.update(index, synLocal.slice(index * vectorSize, (index + 1) * vectorSize))
+   }
index += 1
}
- Iterator(syn0Out)
+ Iterator(synOut)
}
- // partial.cache()
-
- val synAgg = partial.flatMap(x => x).reduceByKey {
+ synGlobal = partial.flatMap(x => x).reduceByKey {
case (v1,v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
v1
}.collect().sortBy(_._1).flatMap(x => x._2)

- syn0Global = synAgg.slice(0, vocabSize * vectorSize)
- syn1Global = synAgg.slice(vocabSize * vectorSize, synAgg.length)
- //logInfo("syn0Global length = " + syn0Global.length)
- // syn1Global = partial.flatMap(x => x._2).reduceByKey {
- //   case (v1,v2) =>
- //     blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
- //     v1
- //   }.collect().sortBy(_._1).flatMap(x => x._2)
- // logInfo("syn1Global length = " + syn1Global.length)
- // logInfo("vocab size = " + vocabSize)
- // val (aggSyn0, aggSyn1, _, _) =
- //   partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
- //     val n = syn0_1.length
- //     val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
- //     val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
- //     blas.sscal(n, weight1, syn0_1, 1)
- //     blas.sscal(n, weight1, syn1_1, 1)
- //     blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
- //     blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
- //     (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
- //   }
- //syn0Global = aggSyn0
- //syn1Global = aggSyn1
}
newSentences.unpersist()

@@ -400,7 +368,7 @@ class Word2Vec extends Serializable with Logging {
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Float](vectorSize)
- Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
+ Array.copy(synGlobal, i * vectorSize, vector, 0, vectorSize)
word2VecMap += word -> vector
i += 1
}
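Within each partition, synModify records which rows of the local copy were touched, and only those rows are emitted as (rowIndex, vector) pairs in synOut; reduceByKey then sums matching rows across partitions (the blas.saxpy call in the diff) and the result, sorted by row index and flattened, becomes the synGlobal used in the next iteration. A standalone sketch of that merge step, with plain Scala collections standing in for the RDD; the data and names below are made up for illustration:

// Illustrative only -- plain collections instead of an RDD, and an
// element-wise sum instead of blas.saxpy.
object SynMergeSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical per-partition outputs: only the modified rows are emitted.
    val partition0 = Seq(0 -> Array(0.1f, 0.2f), 3 -> Array(1.0f, 1.0f))
    val partition1 = Seq(3 -> Array(0.5f, -0.5f), 5 -> Array(2.0f, 2.0f))

    // Stand-in for blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1).
    def add(v1: Array[Float], v2: Array[Float]): Array[Float] =
      v1.zip(v2).map { case (a, b) => a + b }

    // reduceByKey: rows with the same index are summed across partitions;
    // collect().sortBy(_._1).flatMap(_._2) then rebuilds the flat array.
    val merged = (partition0 ++ partition1)
      .groupBy(_._1)
      .map { case (row, pairs) => row -> pairs.map(_._2).reduce(add) }
      .toSeq
      .sortBy(_._1)

    merged.foreach { case (row, vec) => println(s"row $row: ${vec.mkString(", ")}") }
    // row 0: 0.1, 0.2
    // row 3: 1.5, 0.5
    // row 5: 2.0, 2.0
  }
}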
