PR #1237: tests for word2vec's train(). Continuing #1139
gensim/models/word2vec.py

@@ -473,7 +473,8 @@ def __init__(
         if isinstance(sentences, GeneratorType):
             raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.")
         self.build_vocab(sentences, trim_rule=trim_rule)
-        self.train(sentences)
+        self.train(sentences, total_examples=self.corpus_count, epochs=self.iter,
+                   start_alpha=self.alpha, end_alpha=self.min_alpha)

     def initialize_word_vectors(self):
         self.wv = KeyedVectors()
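With this change, the one-shot constructor path passes its own cached values to train() explicitly instead of relying on the old implicit defaults. For reference, a minimal sketch of the equivalent manual two-step flow under the new API (the toy corpus and size value are illustrative, not from the PR):

```python
from gensim.models import word2vec

sentences = [['human', 'interface', 'computer'],
             ['graph', 'minors', 'trees'],
             ['graph', 'trees']]

model = word2vec.Word2Vec(size=10, min_count=1)  # no corpus given: vocab and training deferred
model.build_vocab(sentences)                     # corpus survey; sets model.corpus_count
model.train(sentences,
            total_examples=model.corpus_count,   # sentence count, for progress and alpha decay
            epochs=model.iter)                   # explicit pass count, normally model.iter
```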
@@ -754,16 +755,23 @@ def _raw_word_count(self, job):
         """Return the number of words in a given job."""
         return sum(len(sentence) for sentence in job)

-    def train(self, sentences, total_words=None, word_count=0,
-              total_examples=None, queue_factor=2, report_delay=1.0):
+    def train(self, sentences, total_examples=None, total_words=None,
+              epochs=None, start_alpha=None, end_alpha=None,
[Review comment] No vertical indent -- use hanging indent.

[Reply] For clarity, what does 'hanging indent' mean in this context? Does the line with […]? (FWIW, this "aligned with opening delimiter" style is the first 'yes' example given in PEP8, and is used a lot throughout gensim already.)

[Reply] Ah, this is a function definition, vertical indent is probably OK here. I misread it as a function call when reviewing. Hanging indent would put the arguments on separate lines, with extra indentation (see the PEP8 examples).
+              word_count=0,
+              queue_factor=2, report_delay=1.0):
         """
         Update the model's neural weights from a sequence of sentences (can be a once-only generator stream).
         For Word2Vec, each sentence must be a list of unicode strings. (Subclasses may accept other examples.)

-        To support linear learning-rate decay from (initial) alpha to min_alpha, either total_examples
-        (count of sentences) or total_words (count of raw words in sentences) should be provided, unless the
-        sentences are the same as those that were used to initially build the vocabulary.
+        To support linear learning-rate decay from (initial) alpha to min_alpha, and accurate
+        progress-percentage logging, either total_examples (count of sentences) or total_words (count of
+        raw words in sentences) MUST be provided. (If the corpus is the same as was provided to
+        `build_vocab()`, the count of examples in that corpus will be available in the model's
+        `corpus_count` property.)
+
+        To avoid common mistakes around the model's ability to do multiple training passes itself, an
+        explicit `epochs` argument MUST be provided. In the common and recommended case, where `train()`
+        is only called once, the model's cached `iter` value should be supplied as the `epochs` value.
         """
         if (self.model_trimmed_post_training):
             raise RuntimeError("Parameters for training were discarded using model_trimmed_post_training method")
@@ -795,18 +803,18 @@ def train(self, sentences, total_words=None, word_count=0,
                 "Instead start with a blank model, scan_vocab on the new corpus, intersect_word2vec_format with the old model, then train.")

         if total_words is None and total_examples is None:
-            if self.corpus_count:
-                total_examples = self.corpus_count
-                logger.info("expecting %i sentences, matching count from corpus used for vocabulary survey", total_examples)
-            else:
-                raise ValueError("you must provide either total_words or total_examples, to enable alpha and progress calculations")
+            raise ValueError("you must specify either total_examples or total_words, for proper alpha and progress calculations")
+        if epochs is None:
+            raise ValueError("you must specify an explicit epochs count")
+        start_alpha = start_alpha or self.alpha
+        end_alpha = end_alpha or self.min_alpha

         job_tally = 0

-        if self.iter > 1:
-            sentences = utils.RepeatCorpusNTimes(sentences, self.iter)
-            total_words = total_words and total_words * self.iter
-            total_examples = total_examples and total_examples * self.iter
+        if epochs > 1:
+            sentences = utils.RepeatCorpusNTimes(sentences, epochs)
+            total_words = total_words and total_words * epochs
+            total_examples = total_examples and total_examples * epochs

         def worker_loop():
             """Train the model, lifting lists of sentences from the job_queue."""
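With the stricter validation above, under-specified calls fail fast instead of silently training on a wrong schedule. A small sketch of the expected behavior, reusing `model` and `sentences` from the first sketch above (this mirrors the test_train_with_explicit_param test added later in this PR):

```python
try:
    model.train(sentences)  # neither total_examples/total_words nor epochs given
except ValueError as err:
    print(err)  # the totals check fires first

try:
    model.train(sentences, total_examples=model.corpus_count)  # still no epochs
except ValueError as err:
    print(err)  # now the explicit-epochs check fires

# the fully specified call succeeds:
model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
```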
@@ -828,7 +836,7 @@ def job_producer():
             """Fill jobs queue using the input `sentences` iterator."""
             job_batch, batch_size = [], 0
             pushed_words, pushed_examples = 0, 0
-            next_alpha = self.alpha
+            next_alpha = start_alpha
             if next_alpha > self.min_alpha_yet_reached:
                 logger.warn("Effective 'alpha' higher than previous training cycles")
             self.min_alpha_yet_reached = next_alpha
@@ -851,7 +859,7 @@ def job_producer():
                     job_queue.put((job_batch, next_alpha))

                     # update the learning rate for the next job
-                    if self.min_alpha < next_alpha:
+                    if end_alpha < next_alpha:
                         if total_examples:
                             # examples-based decay
                             pushed_examples += len(job_batch)
@@ -860,8 +868,8 @@ def job_producer():
                             # words-based decay
                             pushed_words += self._raw_word_count(job_batch)
                             progress = 1.0 * pushed_words / total_words
-                        next_alpha = self.alpha - (self.alpha - self.min_alpha) * progress
-                        next_alpha = max(self.min_alpha, next_alpha)
+                        next_alpha = start_alpha - (start_alpha - end_alpha) * progress
+                        next_alpha = max(end_alpha, next_alpha)

                     # add the sentence that didn't fit as the first item of a new job
                     job_batch, batch_size = [sentence], sentence_length
gensim/test/test_word2vec.py
@@ -116,12 +116,12 @@ def onlineSanity(self, model):
                 others.append(l)
         self.assertTrue(all(['terrorism' not in l for l in others]))
         model.build_vocab(others)
-        model.train(others)
+        model.train(others, total_examples=model.corpus_count, epochs=model.iter)
         self.assertFalse('terrorism' in model.wv.vocab)
         model.build_vocab(terro, update=True)
         self.assertTrue('terrorism' in model.wv.vocab)
         orig0 = np.copy(model.wv.syn0)
-        model.train(terro)
+        model.train(terro, total_examples=len(terro), epochs=model.iter)
         self.assertFalse(np.allclose(model.wv.syn0, orig0))
         sim = model.n_similarity(['war'], ['terrorism'])
         self.assertLess(0., sim)
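Besides exercising the new arguments, this test documents the online-training pattern: after an initial build/train cycle, the vocabulary can be grown in place and training continued on just the new material. A condensed sketch of that flow with made-up corpora (note that total_examples switches to the size of the new batch, as in the test):

```python
from gensim.models import word2vec

base = [['human', 'interface', 'computer'], ['graph', 'trees']]
extra = [['terrorism', 'war'], ['war', 'graph']]  # introduces unseen words

model = word2vec.Word2Vec(min_count=1)
model.build_vocab(base)                           # initial vocabulary survey
model.train(base, total_examples=model.corpus_count, epochs=model.iter)

model.build_vocab(extra, update=True)             # grow vocab, keep trained weights
model.train(extra, total_examples=len(extra), epochs=model.iter)
assert 'terrorism' in model.wv.vocab
```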
@@ -363,7 +363,7 @@ def testTraining(self):
         self.assertTrue(model.wv.syn0.shape == (len(model.wv.vocab), 2))
         self.assertTrue(model.syn1.shape == (len(model.wv.vocab), 2))

-        model.train(sentences)
+        model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
         sims = model.most_similar('graph', topn=10)
         # self.assertTrue(sims[0][0] == 'trees', sims)  # most similar
@@ -399,7 +399,7 @@ def testLocking(self):
         # lock the vector in slot 0 against change
         model.syn0_lockf[0] = 0.0

-        model.train(corpus)
+        model.train(corpus, total_examples=model.corpus_count, epochs=model.iter)
         self.assertFalse((unlocked1 == model.wv.syn0[1]).all())  # unlocked vector should vary
         self.assertTrue((locked0 == model.wv.syn0[0]).all())  # locked vector should not vary
@@ -428,7 +428,7 @@ def model_sanity(self, model, train=True):
         if train:
             model.build_vocab(list_corpus)
             orig0 = np.copy(model.wv.syn0[0])
-            model.train(list_corpus)
+            model.train(list_corpus, total_examples=model.corpus_count, epochs=model.iter)
             self.assertFalse((orig0 == model.wv.syn0[1]).all())  # vector should vary after training
         sims = model.most_similar('war', topn=len(model.wv.index2word))
         t_rank = [word for word, score in sims].index('terrorism')
@@ -481,7 +481,7 @@ def testTrainingCbow(self):
         self.assertTrue(model.wv.syn0.shape == (len(model.wv.vocab), 2))
         self.assertTrue(model.syn1.shape == (len(model.wv.vocab), 2))

-        model.train(sentences)
+        model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
         sims = model.most_similar('graph', topn=10)
         # self.assertTrue(sims[0][0] == 'trees', sims)  # most similar
@@ -504,7 +504,7 @@ def testTrainingSgNegative(self):
         self.assertTrue(model.wv.syn0.shape == (len(model.wv.vocab), 2))
         self.assertTrue(model.syn1neg.shape == (len(model.wv.vocab), 2))

-        model.train(sentences)
+        model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
         sims = model.most_similar('graph', topn=10)
         # self.assertTrue(sims[0][0] == 'trees', sims)  # most similar
@@ -527,7 +527,7 @@ def testTrainingCbowNegative(self):
         self.assertTrue(model.wv.syn0.shape == (len(model.wv.vocab), 2))
         self.assertTrue(model.syn1neg.shape == (len(model.wv.vocab), 2))

-        model.train(sentences)
+        model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
         sims = model.most_similar('graph', topn=10)
         # self.assertTrue(sims[0][0] == 'trees', sims)  # most similar
@@ -546,7 +546,7 @@ def testSimilarities(self):
         # The model is trained using CBOW
         model = word2vec.Word2Vec(size=2, min_count=1, sg=0, hs=0, negative=2)
         model.build_vocab(sentences)
-        model.train(sentences)
+        model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)

         self.assertTrue(model.n_similarity(['graph', 'trees'], ['trees', 'graph']))
         self.assertTrue(model.n_similarity(['graph'], ['trees']) == model.similarity('graph', 'trees'))
@@ -637,14 +637,23 @@ def testTrainWarning(self, l):
         model = word2vec.Word2Vec(min_count=1)
         model.build_vocab(sentences)
         for epoch in range(10):
-            model.train(sentences)
+            model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
             model.alpha -= 0.002
             model.min_alpha = model.alpha
             if epoch == 5:
                 model.alpha += 0.05
         warning = "Effective 'alpha' higher than previous training cycles"
         self.assertTrue(warning in str(l))

+    def test_train_with_explicit_param(self):
+        model = word2vec.Word2Vec(size=2, min_count=1, hs=1, negative=0)
+        model.build_vocab(sentences)
+        with self.assertRaises(ValueError):
+            model.train(sentences, total_examples=model.corpus_count)
+
+        with self.assertRaises(ValueError):
+            model.train(sentences, epochs=model.iter)
+
[Review comment] please add a simple […]
     def test_sentences_should_not_be_a_generator(self):
         """
         Is sentences a generator object?