From 9075e1cba5ae64add2986514be99dc51083ff177 Mon Sep 17 00:00:00 2001 From: Liquan Pei Date: Wed, 13 Aug 2014 16:17:58 -0700 Subject: [PATCH] combine syn0Global and syn1Global to synGlobal --- .../apache/spark/mllib/feature/Word2Vec.scala | 78 ++++++------------- 1 file changed, 23 insertions(+), 55 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 3ec1b3407a87c..8cfd13a837c98 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -235,7 +235,7 @@ class Word2Vec extends Serializable with Logging { b = 0 while (b < i) { vocab(a).code(i - b - 1) = code(b) - vocab(a).point(i - b) = point(b) - vocabSize + vocab(a).point(i - b) = point(b) b += 1 } a += 1 @@ -284,19 +284,15 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - var syn0Global = - Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) - var syn1Global = new Array[Float](vocabSize * vectorSize) - + var synGlobal = + Array.fill[Float](2 * vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) - val syn0Modify = new Array[Int](vocabSize) - val syn1Modify = new Array[Int](vocabSize) - - val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { - case ((syn0, syn1, lastWordCount, wordCount), sentence) => + val synModify = new Array[Int](2 * vocabSize) + val model = iter.foldLeft((synGlobal, 0, 0)) { + case ((syn, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount if (wordCount - lastWordCount > 10000) { @@ -325,73 +321,45 @@ class Word2Vec extends Serializable with Logging { var d = 0 while (d < bcVocab.value(word).codeLen) { val ind = bcVocab.value(word).point(d) - val l2 = bcVocab.value(word).point(d) * vectorSize + val l2 = ind * vectorSize // Propagate hidden -> output - syn1Modify(ind) += 1 - var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) + synModify(ind) += 1 + var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat - blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) - blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) + blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1) + blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1) } d += 1 } - blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) - syn0Modify(lastWord) += 1 + blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn, l1, 1) + synModify(lastWord) += 1 } } a += 1 } pos += 1 } - (syn0, syn1, lwc, wc) + (syn, lwc, wc) } - val syn0Local = model._1 - val syn1Local = model._2 - - val syn0Out = new PrimitiveKeyOpenHashMap[Int, Array[Float]] - // val syn1Out = new PrimitiveKeyOpenHashMap[Int, Array[Float]] + val synLocal = model._1 + val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2) var index = 0 - while(index < vocabSize) { - if (syn0Modify(index) != 0) syn0Out.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)) - if (syn1Modify(index) != 0) syn0Out.update(index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)) + while(index < 2 * vocabSize) { + if (synModify(index) != 0) { + synOut.update(index, synLocal.slice(index * vectorSize, (index + 1) * vectorSize)) + } index += 1 } - Iterator(syn0Out) + Iterator(synOut) } - // partial.cache() - - val synAgg = partial.flatMap(x => x).reduceByKey { + synGlobal = partial.flatMap(x => x).reduceByKey { case (v1,v2) => blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) v1 }.collect().sortBy(_._1).flatMap(x => x._2) - - syn0Global = synAgg.slice(0, vocabSize * vectorSize) - syn1Global = synAgg.slice(vocabSize * vectorSize, synAgg.length) - //logInfo("syn0Global length = " + syn0Global.length) - // syn1Global = partial.flatMap(x => x._2).reduceByKey { - // case (v1,v2) => - // blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) - // v1 - // }.collect().sortBy(_._1).flatMap(x => x._2) - // logInfo("syn1Global length = " + syn1Global.length) - // logInfo("vocab size = " + vocabSize) - // val (aggSyn0, aggSyn1, _, _) = - // partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - // val n = syn0_1.length - // val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) - // val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) - // blas.sscal(n, weight1, syn0_1, 1) - // blas.sscal(n, weight1, syn1_1, 1) - // blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) - // blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) - // (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) - // } - //syn0Global = aggSyn0 - //syn1Global = aggSyn1 } newSentences.unpersist() @@ -400,7 +368,7 @@ class Word2Vec extends Serializable with Logging { while (i < vocabSize) { val word = bcVocab.value(i).word val vector = new Array[Float](vectorSize) - Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) + Array.copy(synGlobal, i * vectorSize, vector, 0, vectorSize) word2VecMap += word -> vector i += 1 }