From f0fd13c0245bd81829c48e4ab4ef99e978ee50bc Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 26 Aug 2015 17:15:45 -0700 Subject: [PATCH 1/6] Quick addition of window size test --- .../apache/spark/ml/feature/Word2Vec.scala | 15 ++++++++ .../apache/spark/mllib/feature/Word2Vec.scala | 11 +++++- .../spark/ml/feature/Word2VecSuite.scala | 37 ++++++++++++++++++- 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 5af775a4159ad..58704e0e17cf9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -48,6 +48,17 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getVectorSize: Int = $(vectorSize) + /** + * The window size (context words from [-window, window]) + * @group param + */ + final val windowSize = new IntParam( + this, "windowSize", "the window size (context words from [-window, window])") + setDefault(windowSize -> 5) + + /** @group getParam */ + def getWindowSize: Int = $(windowSize) + /** * Number of partitions for sentences of words. 
* @group param @@ -102,6 +113,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setVectorSize(value: Int): this.type = set(vectorSize, value) + /** @group setParam */ + def setWindowSize(value: Int): this.type = set(windowSize, value) + /** @group setParam */ def setStepSize(value: Double): this.type = set(stepSize, value) @@ -127,6 +141,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setNumPartitions($(numPartitions)) .setSeed($(seed)) .setVectorSize($(vectorSize)) + .setWindowSize($(windowSize)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 36b124c5d2966..99ef8fa18f514 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -128,6 +128,15 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets the window of words (default: 5) + */ + @Since("1.6.0") + def setWindowSize(window: Int): this.type = { + this.window = window + this + } + /** * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). 
@@ -144,7 +153,7 @@ class Word2Vec extends Serializable with Logging { private val MAX_SENTENCE_LENGTH = 1000 /** context words from [-window, window] */ - private val window = 5 + private var window = 5 private var trainWordsCount = 0 private var vocabSize = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a2e46f2029956..9646718785756 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -131,7 +131,42 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { expectedSimilarity.zip(similarity).map { case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) } + } + + test("window size") { + + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + val docDF = doc.zip(doc).toDF("text", "alsotext") + val model = new Word2Vec() + .setVectorSize(3) + .setWindowSize(2) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .fit(docDF) + + val (synonyms, similarity) = model.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + + // Increase the window size + val biggerModel = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .setSeed(42L) + .setWindowSize(10) + .fit(docDF) + + val (synonymsLarger, similarityLarger) = biggerModel.findSynonyms("a", 6).map { + case Row(w: String, sim: Double) => (w, sim) + }.collect().unzip + // The similarity score should be very different with the larger window + assert(math.abs((similarity(5) - similarityLarger(5)) / similarity(5)) > 1E-5) } } - From c125c3bfd725a6668cca13d43f8c85af9a25ba38 Mon Sep 17 00:00:00 2001 From: Holden Karau 
Date: Mon, 31 Aug 2015 17:49:30 -0700 Subject: [PATCH 2/6] CR feedback --- mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 58704e0e17cf9..3ff239d8983ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -49,7 +49,7 @@ private[feature] trait Word2VecBase extends Params def getVectorSize: Int = $(vectorSize) /** - * The window size (context words from [-window, window]) + * The window size (context words from [-window, window]). * @group param */ final val windowSize = new IntParam( From e68f86024208ebfab0daffad521a2e36fad8b153 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 8 Dec 2015 15:19:11 -0800 Subject: [PATCH 3/6] switch to expert param --- .../main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 8 ++++---- .../scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 3ff239d8983ab..a1ce5dd4d2c3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -49,14 +49,14 @@ private[feature] trait Word2VecBase extends Params def getVectorSize: Int = $(vectorSize) /** - * The window size (context words from [-window, window]). - * @group param + * The window size (context words from [-window, window]) default 5. 
+ * @group expertParam */ final val windowSize = new IntParam( this, "windowSize", "the window size (context words from [-window, window])") setDefault(windowSize -> 5) - /** @group getParam */ + /** @group expertGetParam */ def getWindowSize: Int = $(windowSize) /** @@ -113,7 +113,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setVectorSize(value: Int): this.type = set(vectorSize, value) - /** @group setParam */ + /** @group expertSetParam */ def setWindowSize(value: Int): this.type = set(windowSize, value) /** @group setParam */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 9646718785756..418bddc9a3031 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -35,7 +35,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Word2Vec") { - val sqlContext = new SQLContext(sc) import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 From 27ae763df7636302b81ad94720ba786deb855bcd Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 8 Dec 2015 15:20:12 -0800 Subject: [PATCH 4/6] Use sqlContext from MLlibTestSparkContext --- .../test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 0bc80551e4c54..62dadb5e0d49e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -76,7 +76,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("getVectors") { - val sqlContext = new SQLContext(sc) import sqlContext.implicits._ val 
sentence = "a b " * 100 + "a c " * 10 @@ -117,7 +116,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val sqlContext = new SQLContext(sc) import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -144,7 +142,6 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { - val sqlContext = new SQLContext(sc) import sqlContext.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 From 2846981ae3b0efab3f83b24b8382535daee6b7a5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 8 Dec 2015 15:24:09 -0800 Subject: [PATCH 5/6] make stable copy of the var for import --- .../scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 62dadb5e0d49e..63a9f47e93e73 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -35,6 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("Word2Vec") { + + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -76,6 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("getVectors") { + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -116,6 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { + val sqlContext = this.sqlContext import sqlContext.implicits._ val sentence = "a b " * 100 + "a c " * 10 @@ -142,6 +146,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { + val sqlContext = 
this.sqlContext import sqlContext.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 From 76d7b5bf4ec7d8d1aeb0852fc86b8c4476745e3d Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 8 Dec 2015 15:27:02 -0800 Subject: [PATCH 6/6] Add newline that got chomped --- .../test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 63a9f47e93e73..d561bbbb25529 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -207,3 +207,4 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } } +