From 9b3459d0f3ae1eeb1c69f43ef85d827dd6336e49 Mon Sep 17 00:00:00 2001
From: "Michael W. Sherman"
Date: Tue, 24 Oct 2017 08:22:54 -0400
Subject: [PATCH] Fix scoring function in Phrases. Fix #1533, #1635 (#1573)

* initial commit of fixes in comments of #1423
* removed unnecessary space in logger
* added support for custom Phrases scorers
* fixed Phrases.__getitem__ to support pluggable scoring #1533
* travisCI style fixes
* fixed __next__() to next() for python 3 compatibility
* misc fixes
* spacing fixes for style
* custom scorer support in sklearn api
* Phrases scikit interface tests for pluggable scoring
* missing line breaks
* style, clarity, and robustness fixes requested by @piskvorky
* check in Phrases init to make sure scorer is pickleable
* backwards scoring compatibility when loading a Phrases class
* removal of pickle testing objects in Phrases init
* switched to six for python 2/3 compatibility
* fix docstring
---
 gensim/models/phrases.py        | 187 ++++++++++++++++++++++++--------
 gensim/models/word2vec.py       |  25 +++--
 gensim/sklearn_api/phrases.py   |  16 ++-
 gensim/test/test_phrases.py     | 117 +++++++++++++++++++-
 gensim/test/test_sklearn_api.py |  61 +++++++++++
 5 files changed, 338 insertions(+), 68 deletions(-)

diff --git a/gensim/models/phrases.py b/gensim/models/phrases.py
index a7c0122652..2ec8592bcd 100644
--- a/gensim/models/phrases.py
+++ b/gensim/models/phrases.py
@@ -64,8 +64,10 @@
 import warnings
 from collections import defaultdict
 import itertools as it
-from functools import partial
 from math import log
+from inspect import getargspec
+import pickle
+import six
 
 from six import iteritems, string_types, next
 
@@ -73,7 +75,6 @@
 logger = logging.getLogger(__name__)
 
 
-
 def _is_single(obj):
     """
     Check whether `obj` is a single document or an entire corpus.
@@ -136,19 +137,36 @@ def __init__(self, sentences=None, min_count=5, threshold=10.0, max_vocab_size=4
         should be a byte string (e.g. b'_').
 
         `scoring` specifies how potential phrases are scored for comparison to the `threshold`
-        setting. two settings are available:
+        setting. `scoring` can be set either with a string that refers to a built-in scoring function,
+        or with a custom function that accepts the expected parameter names. Two built-in scoring
+        functions are available by setting `scoring` to a string:
 
         'default': from "Efficient Estimaton of Word Representations in Vector Space" by
-        Mikolov, et. al.:
-        (count(worda followed by wordb) - min_count) * N /
-        (count(worda) * count(wordb)) > `threshold`, where `N` is the total vocabulary size.
+            Mikolov et al.:
+            (count(worda followed by wordb) - min_count) * N /
+            (count(worda) * count(wordb)) > `threshold`, where `N` is the total vocabulary size.
        'npmi': normalized pointwise mutual information, from "Normalized (Pointwise) Mutual
-        Information in Colocation Extraction" by Gerlof Bouma:
-        ln(prop(worda followed by wordb) / (prop(worda)*prop(wordb))) /
-        - ln(prop(worda followed by wordb)
-        where prop(n) is the count of n / the count of everything in the entire corpus
-        'npmi' is more robust when dealing with common words that form part of common bigrams, and
-        ranges from -1 to 1, but is slower to calculate than the default
+            Information in Collocation Extraction" by Gerlof Bouma:
+            ln(prop(worda followed by wordb) / (prop(worda)*prop(wordb))) /
+            - ln(prop(worda followed by wordb)),
+            where prop(n) is the count of n divided by the count of all tokens in the entire corpus.
+
+        'npmi' is more robust when dealing with common words that form part of common bigrams, and
+        ranges from -1 to 1, but is slower to calculate than the default.
+
+        To use a custom scoring function, create a function with the following parameters and set
+        the `scoring` parameter to that function. The function must accept all of these parameters,
+        even if it does not use them all:
+
+        worda_count: number of occurrences in `sentences` of the first token in the phrase being scored
+        wordb_count: number of occurrences in `sentences` of the second token in the phrase being scored
+        bigram_count: number of occurrences in `sentences` of the phrase being scored
+        len_vocab: the number of unique tokens in `sentences`
+        min_count: the `min_count` setting of the Phrases class
+        corpus_word_count: the total number of (non-unique) tokens in `sentences`
+
+        A scoring function missing any of these parameters (even ones it would not use) will raise
+        a ValueError when the Phrases class is initialized. The scoring function must also be picklable.
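+
+        For example, a hypothetical custom scorer (an illustrative sketch only, not part of
+        gensim; `my_sentences` stands in for your corpus, and the function should be defined
+        at module scope so that it can be pickled):
+
+        >>> def bigram_fraction_scorer(worda_count, wordb_count, bigram_count,
+        ...                            len_vocab, min_count, corpus_word_count):
+        ...     # all six parameters must be accepted, even though only two are used here
+        ...     return bigram_count / corpus_word_count
+        >>> bigram = Phrases(my_sentences, min_count=1, threshold=.0001, scoring=bigram_fraction_scorer)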
""" if min_count <= 0: @@ -159,8 +177,24 @@ def __init__(self, sentences=None, min_count=5, threshold=10.0, max_vocab_size=4 if scoring == 'npmi' and (threshold < -1 or threshold > 1): raise ValueError("threshold should be between -1 and 1 for npmi scoring") - if not (scoring == 'default' or scoring == 'npmi'): - raise ValueError('unknown scoring function "' + scoring + '" specified') + # set scoring based on string + # intentially override the value of the scoring parameter rather than set self.scoring here, + # to still run the check of scoring function parameters in the next code block + + if isinstance(scoring, six.string_types): + if scoring == 'default': + scoring = original_scorer + elif scoring == 'npmi': + scoring = npmi_scorer + else: + raise ValueError('unknown scoring method string %s specified' % (scoring)) + + scoring_parameters = ['worda_count', 'wordb_count', 'bigram_count', 'len_vocab', 'min_count', 'corpus_word_count'] + if callable(scoring): + if all(parameter in getargspec(scoring)[0] for parameter in scoring_parameters): + self.scoring = scoring + else: + raise ValueError('scoring function missing expected parameters') self.min_count = min_count self.threshold = threshold @@ -169,9 +203,18 @@ def __init__(self, sentences=None, min_count=5, threshold=10.0, max_vocab_size=4 self.min_reduce = 1 # ignore any tokens with count smaller than this self.delimiter = delimiter self.progress_per = progress_per - self.scoring = scoring self.corpus_word_count = 0 + # ensure picklability of custom scorer + try: + test_pickle = pickle.dumps(self.scoring) + load_pickle = pickle.loads(test_pickle) + except pickle.PickleError: + raise pickle.PickleError('unable to pickle custom Phrases scoring function') + finally: + del(test_pickle) + del(load_pickle) + if sentences is not None: self.add_vocab(sentences) @@ -227,8 +270,7 @@ def add_vocab(self, sentences): # directly, but gives the new sentences a fighting chance to collect # sufficient counts, before being pruned out by the (large) accummulated # counts collected in previous learn_vocab runs. 
-        min_reduce, vocab, total_words = \
-            self.learn_vocab(sentences, self.max_vocab_size, self.delimiter, self.progress_per)
+        min_reduce, vocab, total_words = self.learn_vocab(sentences, self.max_vocab_size, self.delimiter, self.progress_per)
 
         self.corpus_word_count += total_words
         if len(self.vocab) > 0:
@@ -263,14 +305,11 @@ def export_phrases(self, sentences, out_delimiter=b' ', as_tuples=False):
         threshold = self.threshold
         delimiter = self.delimiter  # delimiter used for lookup
         min_count = self.min_count
-        scoring = self.scoring
-        corpus_word_count = self.corpus_word_count
-
-        if scoring == 'default':
-            scoring_function = partial(self.original_scorer, len_vocab=float(len(vocab)), min_count=float(min_count))
-        elif scoring == 'npmi':
-            scoring_function = partial(self.npmi_scorer, corpus_word_count=corpus_word_count)
-        # no else here to catch unknown scoring function, check is done in Phrases.__init__
+        scorer = self.scoring
+        # cast to floats for the scoring function
+        len_vocab = float(len(vocab))
+        scorer_min_count = float(min_count)
+        corpus_word_count = float(self.corpus_word_count)
 
         for sentence in sentences:
             s = [utils.any2utf8(w) for w in sentence]
@@ -284,7 +323,10 @@ def export_phrases(self, sentences, out_delimiter=b' ', as_tuples=False):
                     count_a = float(vocab[word_a])
                     count_b = float(vocab[word_b])
                     count_ab = float(vocab[bigram_word])
-                    score = scoring_function(count_a, count_b, count_ab)
+                    # the scorer MUST accept all these parameters, even if it does not use them
+                    score = scorer(
+                        worda_count=count_a, wordb_count=count_b, bigram_count=count_ab,
+                        len_vocab=len_vocab, min_count=scorer_min_count, corpus_word_count=corpus_word_count
+                    )
+                    # logger.debug("score for %s: (pab=%s - min_count=%s) / pa=%s / pb=%s * vocab_size=%s = %s",
+                    #     bigram_word, count_ab, scorer_min_count, count_a, count_b, len_vocab, score)
                     if score > threshold and count_ab >= min_count:
                         if as_tuples:
                             yield ((word_a, word_b), score)
@@ -315,6 +357,16 @@ def __getitem__(self, sentence):
         """
         warnings.warn("For a faster implementation, use the gensim.models.phrases.Phraser class")
 
+        vocab = self.vocab
+        threshold = self.threshold
+        delimiter = self.delimiter  # delimiter used for lookup
+        min_count = self.min_count
+        scorer = self.scoring
+        # cast to floats for the scoring function
+        len_vocab = float(len(vocab))
+        scorer_min_count = float(min_count)
+        corpus_word_count = float(self.corpus_word_count)
+
         is_single, sentence = _is_single(sentence)
         if not is_single:
             # if the input is an entire corpus (rather than a single sentence),
@@ -324,18 +376,20 @@ def __getitem__(self, sentence):
             s, new_s = [utils.any2utf8(w) for w in sentence], []
             last_bigram = False
             vocab = self.vocab
-            threshold = self.threshold
-            delimiter = self.delimiter
-            min_count = self.min_count
+
             for word_a, word_b in zip(s, s[1:]):
-                if word_a in vocab and word_b in vocab:
+                # the last-bigram check was moved here to save a few CPU cycles
+                if word_a in vocab and word_b in vocab and not last_bigram:
                     bigram_word = delimiter.join((word_a, word_b))
-                    if bigram_word in vocab and not last_bigram:
-                        pa = float(vocab[word_a])
-                        pb = float(vocab[word_b])
-                        pab = float(vocab[bigram_word])
-                        score = (pab - min_count) / pa / pb * len(vocab)
-                        if score > threshold:
+                    if bigram_word in vocab:
+                        count_a = float(vocab[word_a])
+                        count_b = float(vocab[word_b])
+                        count_ab = float(vocab[bigram_word])
+                        # the scorer MUST accept all these parameters, even if it does not use them
+                        score = scorer(
+                            worda_count=count_a, wordb_count=count_b, bigram_count=count_ab,
+                            len_vocab=len_vocab, min_count=scorer_min_count, corpus_word_count=corpus_word_count
+                        )
+                        # logger.debug("score for %s: (pab=%s - min_count=%s) / pa=%s / pb=%s * vocab_size=%s = %s",
+                        #     bigram_word, count_ab, scorer_min_count, count_a, count_b, len_vocab, score)
+                        if score > threshold and count_ab >= min_count:
                             new_s.append(bigram_word)
                             last_bigram = True
                             continue
@@ -351,19 +405,56 @@ def __getitem__(self, sentence):
 
         return [utils.to_unicode(w) for w in new_s]
 
-    # calculation of score based on original mikolov word2vec paper
-    # len_vocab and min_count set so functools.partial works
-    @staticmethod
-    def original_scorer(worda_count, wordb_count, bigram_count, len_vocab=0.0, min_count=0.0):
-        return (bigram_count - min_count) / worda_count / wordb_count * len_vocab
+    @classmethod
+    def load(cls, *args, **kwargs):
+        """
+        Load a previously saved Phrases object. Handles backwards compatibility with older
+        Phrases versions which did not support pluggable scoring functions. Otherwise, relies
+        on utils.SaveLoad.load.
+        """
 
-    # normalized PMI, requires corpus size
-    @staticmethod
-    def npmi_scorer(worda_count, wordb_count, bigram_count, corpus_word_count=0.0):
-        pa = worda_count / corpus_word_count
-        pb = wordb_count / corpus_word_count
-        pab = bigram_count / corpus_word_count
-        return log(pab / (pa * pb)) / -log(pab)
+        model = super(Phrases, cls).load(*args, **kwargs)
+        # update older models
+        # if there is no scoring attribute, use the default scoring
+        if not hasattr(model, 'scoring'):
+            logger.info('older version of Phrases loaded without scoring function')
+            logger.info('setting pluggable scoring method to original_scorer for compatibility')
+            model.scoring = original_scorer
+        # if the scoring attribute is a text value, load the matching scoring function
+        # (string_types from six is used for python 2/3 compatibility)
+        elif isinstance(model.scoring, string_types):
+            if model.scoring == 'default':
+                logger.info('older version of Phrases loaded with "default" scoring parameter')
+                logger.info('setting scoring method to original_scorer pluggable scoring method for compatibility')
+                model.scoring = original_scorer
+            elif model.scoring == 'npmi':
+                logger.info('older version of Phrases loaded with "npmi" scoring parameter')
+                logger.info('setting scoring method to npmi_scorer pluggable scoring method for compatibility')
+                model.scoring = npmi_scorer
+            else:
+                raise ValueError('failed to load Phrases model with unknown scoring setting %s' % model.scoring)
+        return model
+
+
+# these two built-in scoring methods don't cast everything to float, because the casting
+# is done at the call sites in __getitem__ and export_phrases.
+
+# calculation of score based on the original Mikolov word2vec paper
+def original_scorer(worda_count, wordb_count, bigram_count, len_vocab, min_count, corpus_word_count):
+    return (bigram_count - min_count) / worda_count / wordb_count * len_vocab
+
+
+# normalized PMI, requires the corpus size
+def npmi_scorer(worda_count, wordb_count, bigram_count, len_vocab, min_count, corpus_word_count):
+    pa = worda_count / corpus_word_count
+    pb = wordb_count / corpus_word_count
+    pab = bigram_count / corpus_word_count
+    return log(pab / (pa * pb)) / -log(pab)
 
 
 def pseudocorpus(source_vocab, sep):
diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py
index 754020a380..f718641d55 100644
--- a/gensim/models/word2vec.py
+++ b/gensim/models/word2vec.py
@@ -1666,15 +1666,20 @@ def __iter__(self):
 
 class PathLineSentences(object):
     """
-    Simple format: one sentence = one line; words already preprocessed and separated by whitespace.
-    Like LineSentence, but will process all files in a directory in alphabetical order by filename
+    Works like word2vec.LineSentence, but processes all files in a directory in alphabetical
+    order by filename. The directory can only contain files that can be read by LineSentence:
+    .bz2, .gz, and text files. Any file not ending in .bz2 or .gz is assumed to be a text file.
+    Does not work with subdirectories.
+
+    The format of the files (either text, or compressed text files) in the path is one sentence = one line,
+    with words already preprocessed and separated by whitespace.
     """
 
    def __init__(self, source, max_sentence_length=MAX_WORDS_IN_BATCH, limit=None):
        """
        `source` should be a path to a directory (as a string) where all files can be opened by the
-        LineSentence class. Each file will be read up to
-        `limit` lines (or no clipped if limit is None, the default).
+        LineSentence class. Each file will be read up to `limit` lines (or in full if `limit` is
+        None, the default).
        Example::
 
@@ -1688,23 +1693,23 @@ def __init__(self, source, max_sentence_length=MAX_WORDS_IN_BATCH, limit=None):
         self.limit = limit
 
         if os.path.isfile(self.source):
-            logging.warning('single file read, better to use models.word2vec.LineSentence')
+            logger.debug('single file given as source, rather than a directory of files')
+            logger.debug('consider using models.word2vec.LineSentence for a single file')
             self.input_files = [self.source]  # force code compatibility with list of files
         elif os.path.isdir(self.source):
             self.source = os.path.join(self.source, '')  # ensures os-specific slash at end of path
-            logging.debug('reading directory %s', self.source)
+            logger.info('reading directory %s', self.source)
             self.input_files = os.listdir(self.source)
-            self.input_files = [self.source + file for file in self.input_files]  # make full paths
+            self.input_files = [self.source + filename for filename in self.input_files]  # make full paths
             self.input_files.sort()  # makes sure it happens in filename order
         else:  # not a file or a directory, then we can't do anything with it
             raise ValueError('input is neither a file nor a path')
-
-        logging.info('files read into PathLineSentences:%s', '\n'.join(self.input_files))
+        logger.info('files read into PathLineSentences: %s', '\n'.join(self.input_files))
 
     def __iter__(self):
         """iterate through the files"""
         for file_name in self.input_files:
-            logging.info('reading file %s', file_name)
+            logger.info('reading file %s', file_name)
             with utils.smart_open(file_name) as fin:
                 for line in itertools.islice(fin, self.limit):
                     line = utils.to_unicode(line).split()
diff --git a/gensim/sklearn_api/phrases.py b/gensim/sklearn_api/phrases.py
index ad00c51c0e..2eab84b95e 100644
--- a/gensim/sklearn_api/phrases.py
+++ b/gensim/sklearn_api/phrases.py
@@ -21,7 +21,8 @@ class PhrasesTransformer(TransformerMixin, BaseEstimator):
     Base Phrases module
     """
 
-    def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000, delimiter=b'_', progress_per=10000):
+    def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000,
+                 delimiter=b'_', progress_per=10000, scoring='default'):
         """
         Sklearn wrapper for Phrases model.
         """
@@ -31,15 +32,14 @@ def __init__(self, min_count=5, threshold=10.0, max_vocab_size=40000000, delimit
         self.max_vocab_size = max_vocab_size
         self.delimiter = delimiter
         self.progress_per = progress_per
+        self.scoring = scoring
 
     def fit(self, X, y=None):
         """
         Fit the model according to the given training data.
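+
+        A minimal usage sketch (illustrative only; `token_lists` is a hypothetical iterable of
+        tokenized sentences, and `scoring` accepts the same strings or pluggable functions as
+        gensim.models.Phrases):
+
+        >>> from gensim.sklearn_api.phrases import PhrasesTransformer
+        >>> transformer = PhrasesTransformer(min_count=1, threshold=.5, scoring='npmi')
+        >>> phrased_docs = transformer.fit(token_lists).transform(token_lists)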
""" - self.gensim_model = models.Phrases( - sentences=X, min_count=self.min_count, threshold=self.threshold, - max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per - ) + self.gensim_model = models.Phrases(sentences=X, min_count=self.min_count, threshold=self.threshold, + max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per, scoring=self.scoring) return self def transform(self, docs): @@ -62,10 +62,8 @@ def transform(self, docs): def partial_fit(self, X): if self.gensim_model is None: - self.gensim_model = models.Phrases( - sentences=X, min_count=self.min_count, threshold=self.threshold, - max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per - ) + self.gensim_model = models.Phrases(sentences=X, min_count=self.min_count, threshold=self.threshold, + max_vocab_size=self.max_vocab_size, delimiter=self.delimiter, progress_per=self.progress_per, scoring=self.scoring) self.gensim_model.add_vocab(X) return self diff --git a/gensim/test/test_phrases.py b/gensim/test/test_phrases.py index 868947defb..cf008f14cc 100644 --- a/gensim/test/test_phrases.py +++ b/gensim/test/test_phrases.py @@ -123,6 +123,14 @@ def testEncoding(self): self.assertTrue(isinstance(transformed, unicode)) +# scorer for testCustomScorer +# function is outside of the scope of the test because for picklability of custom scorer +# Phrases tests for picklability +# all scores will be 1 +def dumb_scorer(worda_count, wordb_count, bigram_count, len_vocab, min_count, corpus_word_count): + return 1 + + class TestPhrasesModel(unittest.TestCase): def testExportPhrases(self): """Test Phrases bigram export_phrases functionality.""" @@ -162,12 +170,20 @@ def testScoringDefault(self): 3.444 # score for human interface } + def test__getitem__(self): + """ test Phrases[sentences] with a single sentence""" + bigram = Phrases(sentences, min_count=1, threshold=1) + # pdb.set_trace() + test_sentences = [['graph', 'minors', 'survey', 'human', 'interface']] + phrased_sentence = next(bigram[test_sentences].__iter__()) + + assert phrased_sentence == ['graph_minors', 'survey', 'human_interface'] + def testScoringNpmi(self): """ test normalized pointwise mutual information scoring """ bigram = Phrases(sentences, min_count=1, threshold=.5, scoring='npmi') seen_scores = set() - test_sentences = [['graph', 'minors', 'survey', 'human', 'interface']] for phrase, score in bigram.export_phrases(test_sentences): seen_scores.add(round(score, 3)) @@ -177,6 +193,19 @@ def testScoringNpmi(self): .714 # score for human interface } + def testCustomScorer(self): + """ test using a custom scoring function """ + + bigram = Phrases(sentences, min_count=1, threshold=.001, scoring=dumb_scorer) + + seen_scores = [] + test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']] + for phrase, score in bigram.export_phrases(test_sentences): + seen_scores.append(score) + + assert all(seen_scores) # all scores 1 + assert len(seen_scores) == 3 # 'graph minors' and 'survey human' and 'interface system' + def testBadParameters(self): """Test the phrases module with bad parameters.""" # should fail with something less or equal than 0 @@ -189,6 +218,92 @@ def testPruning(self): """Test that max_vocab_size parameter is respected.""" bigram = Phrases(sentences, max_vocab_size=5) self.assertTrue(len(bigram.vocab) <= 5) + + def testSaveLoadCustomScorer(self): + """ saving and loading a Phrases object with a custom scorer """ + + try: + bigram = 
+            bigram = Phrases(sentences, min_count=1, threshold=.001, scoring=dumb_scorer)
+            bigram.save("test_phrases_testSaveLoadCustomScorer_temp_save.pkl")
+            bigram_loaded = Phrases.load("test_phrases_testSaveLoadCustomScorer_temp_save.pkl")
+            seen_scores = []
+            test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
+            for phrase, score in bigram_loaded.export_phrases(test_sentences):
+                seen_scores.append(score)
+
+            assert all(seen_scores)  # all scores 1
+            assert len(seen_scores) == 3  # 'graph minors' and 'survey human' and 'interface system'
+
+        finally:
+            if os.path.exists("test_phrases_testSaveLoadCustomScorer_temp_save.pkl"):
+                os.remove("test_phrases_testSaveLoadCustomScorer_temp_save.pkl")
+
+    def testSaveLoad(self):
+        """Test saving and loading a Phrases object."""
+
+        try:
+            bigram = Phrases(sentences, min_count=1, threshold=1)
+            bigram.save("test_phrases_testSaveLoad_temp_save.pkl")
+            bigram_loaded = Phrases.load("test_phrases_testSaveLoad_temp_save.pkl")
+            seen_scores = set()
+            test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
+            for phrase, score in bigram_loaded.export_phrases(test_sentences):
+                seen_scores.add(round(score, 3))
+
+            assert seen_scores == set([
+                5.167,  # score for graph minors
+                3.444  # score for human interface
+            ])
+
+        finally:
+            if os.path.exists("test_phrases_testSaveLoad_temp_save.pkl"):
+                os.remove("test_phrases_testSaveLoad_temp_save.pkl")
+
+    def testSaveLoadStringScoring(self):
+        """Test saving and loading a Phrases object with a string scoring parameter.
+        This should ensure backwards compatibility with the previous version of Phrases."""
+
+        try:
+            bigram = Phrases(sentences, min_count=1, threshold=1)
+            bigram.scoring = "default"
+            bigram.save("test_phrases_testSaveLoadStringScoring_temp_save.pkl")
+            bigram_loaded = Phrases.load("test_phrases_testSaveLoadStringScoring_temp_save.pkl")
+            seen_scores = set()
+            test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
+            for phrase, score in bigram_loaded.export_phrases(test_sentences):
+                seen_scores.add(round(score, 3))
+
+            assert seen_scores == set([
+                5.167,  # score for graph minors
+                3.444  # score for human interface
+            ])
+
+        finally:
+            if os.path.exists("test_phrases_testSaveLoadStringScoring_temp_save.pkl"):
+                os.remove("test_phrases_testSaveLoadStringScoring_temp_save.pkl")
+
+    def testSaveLoadNoScoring(self):
+        """Test saving and loading a Phrases object with no scoring parameter.
+        This should ensure backwards compatibility with old versions of Phrases."""
+
+        try:
+            bigram = Phrases(sentences, min_count=1, threshold=1)
+            del bigram.scoring
+            bigram.save("test_phrases_testSaveLoadNoScoring_temp_save.pkl")
+            bigram_loaded = Phrases.load("test_phrases_testSaveLoadNoScoring_temp_save.pkl")
+            seen_scores = set()
+            test_sentences = [['graph', 'minors', 'survey', 'human', 'interface', 'system']]
+            for phrase, score in bigram_loaded.export_phrases(test_sentences):
+                seen_scores.add(round(score, 3))
+
+            assert seen_scores == set([
+                5.167,  # score for graph minors
+                3.444  # score for human interface
+            ])
+
+        finally:
+            if os.path.exists("test_phrases_testSaveLoadNoScoring_temp_save.pkl"):
+                os.remove("test_phrases_testSaveLoadNoScoring_temp_save.pkl")
 # endclass TestPhrasesModel
diff --git a/gensim/test/test_sklearn_api.py b/gensim/test/test_sklearn_api.py
index 07411aa9b9..2dd54073b7 100644
--- a/gensim/test/test_sklearn_api.py
+++ b/gensim/test/test_sklearn_api.py
@@ -988,5 +988,66 @@ def testModelNotFitted(self):
         self.assertRaises(NotFittedError, phrases_transformer.transform, phrases_sentences[0])
 
 
+# specifically test pluggable scoring in Phrases, because of possible pickling issues with the function parameter
+
+# this scorer is intentionally defined at module level rather than as a class method, to support pickling
+# all scores will be 1
+def dumb_scorer(worda_count, wordb_count, bigram_count, len_vocab, min_count, corpus_word_count):
+    return 1
+
+
+class TestPhrasesTransformerCustomScorer(unittest.TestCase):
+
+    def setUp(self):
+        numpy.random.seed(0)
+
+        self.model = PhrasesTransformer(min_count=1, threshold=.9, scoring=dumb_scorer)
+        self.model.fit(phrases_sentences)
+
+    def testTransform(self):
+        # transform one document
+        doc = phrases_sentences[-1]
+        phrase_tokens = self.model.transform(doc)[0]
+        expected_phrase_tokens = [u'graph_minors', u'survey_human', u'interface']
+        self.assertEqual(phrase_tokens, expected_phrase_tokens)
+
+    def testPartialFit(self):
+        new_sentences = [
+            ['world', 'peace', 'humans', 'world', 'peace', 'world', 'peace', 'people'],
+            ['world', 'peace', 'people'],
+            ['world', 'peace', 'humans']
+        ]
+        self.model.partial_fit(X=new_sentences)  # train the model with new sentences
+
+        doc = ['graph', 'minors', 'survey', 'human', 'interface', 'world', 'peace']
+        phrase_tokens = self.model.transform(doc)[0]
+        expected_phrase_tokens = [u'graph_minors', u'survey_human', u'interface', u'world_peace']
+        self.assertEqual(phrase_tokens, expected_phrase_tokens)
+
+    def testSetGetParams(self):
+        # updating only one param
+        self.model.set_params(progress_per=5000)
+        model_params = self.model.get_params()
+        self.assertEqual(model_params["progress_per"], 5000)
+
+        # verify that the attribute values are also changed for `gensim_model` after fitting
+        self.model.fit(phrases_sentences)
+        self.assertEqual(getattr(self.model.gensim_model, 'progress_per'), 5000)
+
+    def testPersistence(self):
+        model_dump = pickle.dumps(self.model)
+        model_load = pickle.loads(model_dump)
+
+        doc = phrases_sentences[-1]
+        loaded_phrase_tokens = model_load.transform(doc)
+
+        # comparing the original and loaded models
+        original_phrase_tokens = self.model.transform(doc)
+        self.assertEqual(original_phrase_tokens, loaded_phrase_tokens)
+
+    def testModelNotFitted(self):
+        phrases_transformer = PhrasesTransformer()
+        self.assertRaises(NotFittedError, phrases_transformer.transform, phrases_sentences[0])
+
 
 if __name__ == '__main__':
     unittest.main()
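
For reference, a minimal end-to-end sketch of the pluggable scoring API introduced by this patch (assumes a small toy corpus; `count_minus_min_scorer` is a hypothetical example, and it must be defined at module scope so that the picklability check in Phrases.__init__ passes):

    from gensim.models.phrases import Phrases

    # hypothetical custom scorer: all six parameters must appear in the signature,
    # even though this example uses only bigram_count and min_count
    def count_minus_min_scorer(worda_count, wordb_count, bigram_count,
                               len_vocab, min_count, corpus_word_count):
        return bigram_count - min_count

    toy_corpus = [
        ['new', 'york', 'city'],
        ['new', 'york', 'state'],
        ['visit', 'new', 'york'],
    ]

    bigram = Phrases(toy_corpus, min_count=1, threshold=0.5, scoring=count_minus_min_scorer)
    print(list(bigram.export_phrases(toy_corpus)))  # bigrams scored by the custom function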