diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py
index 39a7219433..4ca0974a17 100644
--- a/gensim/models/word2vec.py
+++ b/gensim/models/word2vec.py
@@ -647,13 +647,19 @@ def build_vocab_from_freq(self, word_freq, keep_raw_vocab=False, corpus_count=No
 
         Examples
         --------
-        >>> build_vocab_from_freq({"Word1":15,"Word2":20}, update=True)
+        >>> from gensim.models.word2vec import Word2Vec
+        >>> model = Word2Vec()
+        >>> model.build_vocab_from_freq({"Word1": 15, "Word2": 20})
         """
         logger.info("Processing provided word frequencies")
-        vocab = defaultdict(int, word_freq)
+        raw_vocab = word_freq  # Instead of scanning text, directly assign the provided word-frequency dict (word_freq) as the raw vocab
+        logger.info(
+            "collected %i different raw words, with total frequency of %i",
+            len(raw_vocab), sum(itervalues(raw_vocab))
+        )
 
-        self.corpus_count = corpus_count if corpus_count else 0
-        self.raw_vocab = vocab
+        self.corpus_count = corpus_count if corpus_count else 0  # Since no sentences are provided, this lets the caller control corpus_count
+        self.raw_vocab = raw_vocab
 
         self.scale_vocab(keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, update=update)  # trim by min_count & precalculate downsampling
         self.finalize_vocab(update=update)  # build tables & arrays
@@ -682,7 +688,7 @@ def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
                 )
             for word in sentence:
                 vocab[word] += 1
-            total_words += 1
+            total_words += len(sentence)
 
             if self.max_vocab_size and len(vocab) > self.max_vocab_size:
                 utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule)
@@ -694,6 +700,7 @@ def scan_vocab(self, sentences, progress_per=10000, trim_rule=None):
         )
         self.corpus_count = sentence_no + 1
         self.raw_vocab = vocab
+        return total_words
 
     def scale_vocab(self, min_count=None, sample=None, dry_run=False,
                     keep_raw_vocab=False, trim_rule=None, update=False):
diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py
index 4c642ce5d2..242b6d39bd 100644
--- a/gensim/test/test_word2vec.py
+++ b/gensim/test/test_word2vec.py
@@ -97,8 +97,8 @@ def testBuildVocabFromFreq(self):
         model_neg = word2vec.Word2Vec(size=10, min_count=0, seed=42, hs=0, negative=5)
         model_hs.build_vocab_from_freq(freq_dict)
         model_neg.build_vocab_from_freq(freq_dict)
-        self.assertTrue(len(model_hs.wv.vocab), 12)
-        self.assertTrue(len(model_neg.wv.vocab), 12)
+        self.assertEqual(len(model_hs.wv.vocab), 12)
+        self.assertEqual(len(model_neg.wv.vocab), 12)
         self.assertEqual(model_hs.wv.vocab['minors'].count, 2)
         self.assertEqual(model_hs.wv.vocab['graph'].count, 3)
         self.assertEqual(model_hs.wv.vocab['system'].count, 4)
@@ -126,11 +126,42 @@ def testBuildVocabFromFreq(self):
         new_freq_dict = {'computer': 1, 'artificial': 4, 'human': 1, 'graph': 1, 'intelligence': 4, 'system': 1, 'trees': 1}
         model_hs.build_vocab_from_freq(new_freq_dict, update=True)
         model_neg.build_vocab_from_freq(new_freq_dict, update=True)
-        self.assertTrue(model_hs.wv.vocab['graph'].count, 4)
-        self.assertTrue(model_hs.wv.vocab['artificial'].count, 4)
+        self.assertEqual(model_hs.wv.vocab['graph'].count, 4)
+        self.assertEqual(model_hs.wv.vocab['artificial'].count, 4)
         self.assertEqual(len(model_hs.wv.vocab), 14)
         self.assertEqual(len(model_neg.wv.vocab), 14)
 
+    def testPruneVocab(self):
+        """Test pruning of the vocab while scanning sentences."""
+        sentences = [
+            ["graph", "system"],
+            ["graph", "system"],
+            ["system", "eps"],
+            ["graph", "system"]
+        ]
+        model = word2vec.Word2Vec(sentences, size=10, min_count=0, max_vocab_size=2, seed=42, hs=1, negative=0)
+        self.assertEqual(len(model.wv.vocab), 2)
+        self.assertEqual(model.wv.vocab['graph'].count, 3)
+        self.assertEqual(model.wv.vocab['system'].count, 4)
+
+        sentences = [
+            ["graph", "system"],
+            ["graph", "system"],
+            ["system", "eps"],
+            ["graph", "system"],
+            ["minors", "survey", "minors", "survey", "minors"]
+        ]
+        model = word2vec.Word2Vec(sentences, size=10, min_count=0, max_vocab_size=2, seed=42, hs=1, negative=0)
+        self.assertEqual(len(model.wv.vocab), 3)
+        self.assertEqual(model.wv.vocab['graph'].count, 3)
+        self.assertEqual(model.wv.vocab['minors'].count, 3)
+        self.assertEqual(model.wv.vocab['system'].count, 4)
+
+    def testTotalWordCount(self):
+        model = word2vec.Word2Vec(size=10, min_count=0, seed=42)
+        total_words = model.scan_vocab(sentences)
+        self.assertEqual(total_words, 29)
+
     def testOnlineLearning(self):
         """Test that the algorithm is able to add new words to the vocabulary and
         to a trained model when using a sorted vocabulary"""
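
For reviewers, a minimal usage sketch of the two touched entry points, in the same doctest style the patched docstring uses. It assumes the gensim 3.x `Word2Vec` API shown in this diff; the toy word counts and sentences below are illustrative only, not taken from the test corpus:

>>> from gensim.models.word2vec import Word2Vec
>>>
>>> # Build a vocabulary directly from a word -> count mapping, skipping the corpus scan.
>>> model = Word2Vec(size=10, min_count=0, seed=42)
>>> model.build_vocab_from_freq({"graph": 3, "system": 4, "eps": 1})
>>>
>>> # After this patch, scan_vocab returns the total number of words seen, not the number of sentences.
>>> model2 = Word2Vec(size=10, min_count=0, seed=42)
>>> model2.scan_vocab([["graph", "system"], ["system", "eps"]])
4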