-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-10299][ML] word2vec should allow users to specify the window size #8513
Changes from 2 commits
f0fd13c
c125c3b
e68f860
2a4739d
27ae763
2846981
76d7b5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,6 +48,17 @@ private[feature] trait Word2VecBase extends Params | |
/** | ||
 * Returns the value obtained by evaluating `$(vectorSize)` for the `vectorSize` param. | ||
 * NOTE(review): `$` is defined outside this view; presumably it yields the explicitly set | ||
 * value or else the registered default -- confirm against the Params trait. | ||
 * @group getParam | ||
 */ | ||
def getVectorSize: Int = $(vectorSize) | ||
|
||
/** | ||
* The window size (context words from [-window, window]). Default: 5. | ||
* @group param | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd use the groups expertParam, expertSetParam, and expertGetParam. |
||
*/ | ||
// Integer param constructed with this instance, the name "windowSize", and a description string. | ||
final val windowSize = new IntParam( | ||
this, "windowSize", "the window size (context words from [-window, window])") | ||
// Registers 5 as the default value for windowSize. | ||
setDefault(windowSize -> 5) | ||
|
||
/** | ||
 * Returns the value obtained by evaluating `$(windowSize)` for the `windowSize` param. | ||
 * NOTE(review): `$` is defined outside this view; presumably it yields the explicitly set | ||
 * value or else the registered default -- confirm against the Params trait. | ||
 * @group getParam | ||
 */ | ||
def getWindowSize: Int = $(windowSize) | ||
|
||
/** | ||
* Number of partitions for sentences of words. | ||
* @group param | ||
|
@@ -102,6 +113,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] | |
/** | ||
 * Assigns `value` to the `vectorSize` param through `set`. | ||
 * The declared result type is `this.type`, which lets setter calls be chained. | ||
 * @group setParam | ||
 */ | ||
def setVectorSize(value: Int): this.type = set(vectorSize, value) | ||
|
||
/** | ||
 * Assigns `value` to the `windowSize` param through `set`. | ||
 * The declared result type is `this.type`, which lets setter calls be chained. | ||
 * No range check on `value` is visible at this call site. | ||
 * @group setParam | ||
 */ | ||
def setWindowSize(value: Int): this.type = set(windowSize, value) | ||
|
||
/** | ||
 * Assigns the Double `value` to the `stepSize` param through `set`. | ||
 * The declared result type is `this.type`, which lets setter calls be chained. | ||
 * @group setParam | ||
 */ | ||
def setStepSize(value: Double): this.type = set(stepSize, value) | ||
|
||
|
@@ -127,6 +141,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] | |
.setNumPartitions($(numPartitions)) | ||
.setSeed($(seed)) | ||
.setVectorSize($(vectorSize)) | ||
.setWindowSize($(windowSize)) | ||
.fit(input) | ||
copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,7 +131,42 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { | |
expectedSimilarity.zip(similarity).map { | ||
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) | ||
} | ||
} | ||
|
||
test("window size") { | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why remove this final line? I think this would fail the style checker. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure why this shows up this way in the GitHub diff viewer; there is a newline after the window size test (I'll re-merge master and see if that fixes the diff view). |
||
val sqlContext = new SQLContext(sc) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Use the context from MLlibTestSparkContext. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point. All of the other tests in this suite also construct the SQLContext at the start this way, so I'll factor it out. |
||
import sqlContext.implicits._ | ||
|
||
// Synthetic corpus: a fixed token pattern repeated 100 times, followed by "a c " repeated 10 times. | ||
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 | ||
// Two identical documents, each split into tokens on single spaces. | ||
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) | ||
// Zip the documents with themselves to get two columns; only "text" is used as input below. | ||
val docDF = doc.zip(doc).toDF("text", "alsotext") | ||
|
||
// Baseline model: 3-dimensional vectors, window size 2, fixed seed 42L. | ||
val model = new Word2Vec() | ||
.setVectorSize(3) | ||
.setWindowSize(2) | ||
.setInputCol("text") | ||
.setOutputCol("result") | ||
.setSeed(42L) | ||
.fit(docDF) | ||
|
||
// Collect the rows returned by findSynonyms("a", 6) as (word, similarity) pairs, | ||
// then unzip them into separate collections of words and similarity scores. | ||
val (synonyms, similarity) = model.findSynonyms("a", 6).map { | ||
case Row(w: String, sim: Double) => (w, sim) | ||
}.collect().unzip | ||
|
||
// Increase the window size and retrain on the same data with the same vector size and seed, | ||
// so that the window size is the only setting that differs from `model`. | ||
val biggerModel = new Word2Vec() | ||
  .setVectorSize(3) | ||
  .setInputCol("text") | ||
  .setOutputCol("result") | ||
  .setSeed(42L) | ||
  .setWindowSize(10) | ||
  .fit(docDF) | ||
|
||
// Query the retrained model. Querying `model` again would reproduce `similarity` exactly | ||
// and make the assertion below compare a result with itself. | ||
val (synonymsLarger, similarityLarger) = biggerModel.findSynonyms("a", 6).map { | ||
  case Row(w: String, sim: Double) => (w, sim) | ||
}.collect().unzip | ||
// The similarity score should be very different with the larger window: | ||
// check the relative difference (a - b) / a of the sixth-ranked scores. | ||
assert(math.abs((similarity(5) - similarityLarger(5)) / similarity(5)) > 1E-5) | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
State default