diff --git a/test_corpora_2574.py b/test_corpora_2574.py new file mode 100644 index 0000000000..f47f3a27ab --- /dev/null +++ b/test_corpora_2574.py @@ -0,0 +1,933 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2010 Radim Rehurek +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +""" +Automated tests for checking corpus I/O formats (the corpora package). +""" + +from __future__ import unicode_literals + +import codecs +import itertools +import logging +import os.path +import tempfile +import unittest + +import numpy as np + +from gensim.corpora import (bleicorpus, mmcorpus, lowcorpus, svmlightcorpus, + ucicorpus, malletcorpus, textcorpus, indexedcorpus, wikicorpus) +from gensim.interfaces import TransformedCorpus +from gensim.utils import to_unicode +from gensim.test.utils import datapath, get_tmpfile, common_corpus + + +class DummyTransformer(object): + def __getitem__(self, bow): + if len(next(iter(bow))) == 2: + # single bag of words + transformed = [(termid, count + 1) for termid, count in bow] + else: + # sliced corpus + transformed = [[(termid, count + 1) for termid, count in doc] for doc in bow] + return transformed + + +class CorpusTestCase(unittest.TestCase): + TEST_CORPUS = [[(1, 1.0)], [], [(0, 0.5), (2, 1.0)], []] + + def setUp(self): + self.corpus_class = None + self.file_extension = None + + def run(self, result=None): + if type(self) is not CorpusTestCase: + super(CorpusTestCase, self).run(result) + + def tearDown(self): + # remove all temporary test files + fname = get_tmpfile('gensim_corpus.tst') + extensions = ['', '', '.bz2', '.gz', '.index', '.vocab'] + for ext in itertools.permutations(extensions, 2): + try: + os.remove(fname + ext[0] + ext[1]) + except OSError: + pass + + def test_load(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + + docs = list(corpus) + # the deerwester corpus always has nine documents + self.assertEqual(len(docs), 9) + + def test_len(self): + fname = datapath('testcorpus.' 
+ self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + + # make sure corpus.index works, too + corpus = self.corpus_class(fname) + self.assertEqual(len(corpus), 9) + + # for subclasses of IndexedCorpus, we need to nuke this so we don't + # test length on the index, but just testcorpus contents + if hasattr(corpus, 'index'): + corpus.index = None + + self.assertEqual(len(corpus), 9) + + def test_empty_input(self): + tmpf = get_tmpfile('gensim_corpus.tst') + with open(tmpf, 'w') as f: + f.write('') + + with open(tmpf + '.vocab', 'w') as f: + f.write('') + + corpus = self.corpus_class(tmpf) + self.assertEqual(len(corpus), 0) + + docs = list(corpus) + self.assertEqual(len(docs), 0) + + def test_save(self): + corpus = self.TEST_CORPUS + tmpf = get_tmpfile('gensim_corpus.tst') + + # make sure the corpus can be saved + self.corpus_class.save_corpus(tmpf, corpus) + + # and loaded back, resulting in exactly the same corpus + corpus2 = list(self.corpus_class(tmpf)) + self.assertEqual(corpus, corpus2) + + def test_serialize(self): + corpus = self.TEST_CORPUS + tmpf = get_tmpfile('gensim_corpus.tst') + + # make sure the corpus can be saved + self.corpus_class.serialize(tmpf, corpus) + + # and loaded back, resulting in exactly the same corpus + corpus2 = self.corpus_class(tmpf) + self.assertEqual(corpus, list(corpus2)) + + # make sure the indexing corpus[i] works + for i in range(len(corpus)): + self.assertEqual(corpus[i], corpus2[i]) + + # make sure that subclasses of IndexedCorpus support fancy indexing + # after deserialisation + if isinstance(corpus, indexedcorpus.IndexedCorpus): + idx = [1, 3, 5, 7] + self.assertEqual(corpus[idx], corpus2[idx]) + + def test_serialize_compressed(self): + corpus = self.TEST_CORPUS + tmpf = get_tmpfile('gensim_corpus.tst') + + for extension in ['.gz', '.bz2']: + fname = tmpf + extension + # make sure the corpus can be saved + self.corpus_class.serialize(fname, corpus) + + # and loaded back, resulting in exactly the same corpus + corpus2 = self.corpus_class(fname) + self.assertEqual(corpus, list(corpus2)) + + # make sure the indexing `corpus[i]` syntax works + for i in range(len(corpus)): + self.assertEqual(corpus[i], corpus2[i]) + + def test_switch_id2word(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + if hasattr(corpus, 'id2word'): + firstdoc = next(iter(corpus)) + testdoc = set((to_unicode(corpus.id2word[x]), y) for x, y in firstdoc) + + self.assertEqual(testdoc, {('computer', 1), ('human', 1), ('interface', 1)}) + + d = corpus.id2word + d[0], d[1] = d[1], d[0] + corpus.id2word = d + + firstdoc2 = next(iter(corpus)) + testdoc2 = set((to_unicode(corpus.id2word[x]), y) for x, y in firstdoc2) + self.assertEqual(testdoc2, {('computer', 1), ('human', 1), ('interface', 1)}) + + def test_indexing(self): + fname = datapath('testcorpus.' 
+ self.file_extension.lstrip('.'))
+        corpus = self.corpus_class(fname)
+        docs = list(corpus)
+
+        for idx, doc in enumerate(docs):
+            self.assertEqual(doc, corpus[idx])
+            self.assertEqual(doc, corpus[np.int64(idx)])
+
+        self.assertEqual(docs, list(corpus[:]))
+        self.assertEqual(docs[0:], list(corpus[0:]))
+        self.assertEqual(docs[0:-1], list(corpus[0:-1]))
+        self.assertEqual(docs[2:4], list(corpus[2:4]))
+        self.assertEqual(docs[::2], list(corpus[::2]))
+        self.assertEqual(docs[::-1], list(corpus[::-1]))
+
+        # make sure sliced corpora can be iterated over multiple times
+        c = corpus[:]
+        self.assertEqual(docs, list(c))
+        self.assertEqual(docs, list(c))
+        self.assertEqual(len(docs), len(corpus))
+        self.assertEqual(len(docs), len(corpus[:]))
+        self.assertEqual(len(docs[::2]), len(corpus[::2]))
+
+        def _get_slice(corpus, slice_):
+            # assertRaises for python 2.6 takes a callable
+            return corpus[slice_]
+
+        # make sure proper input validation for sliced corpora is done
+        self.assertRaises(ValueError, _get_slice, corpus, {1})
+        self.assertRaises(ValueError, _get_slice, corpus, 1.0)
+
+        # check sliced corpora that use fancy indexing
+        c = corpus[[1, 3, 4]]
+        self.assertEqual([d for i, d in enumerate(docs) if i in [1, 3, 4]], list(c))
+        self.assertEqual([d for i, d in enumerate(docs) if i in [1, 3, 4]], list(c))
+        self.assertEqual(len(corpus[[0, 1, -1]]), 3)
+        self.assertEqual(len(corpus[np.asarray([0, 1, -1])]), 3)
+
+        # check that TransformedCorpus supports indexing when the underlying
+        # corpus does, and throws an error otherwise
+        corpus_ = TransformedCorpus(DummyTransformer(), corpus)
+        if hasattr(corpus, 'index') and corpus.index is not None:
+            self.assertEqual(corpus_[0][0][1], docs[0][0][1] + 1)
+            self.assertRaises(ValueError, _get_slice, corpus_, {1})
+            transformed_docs = [val + 1 for i, d in enumerate(docs) for _, val in d if i in [1, 3, 4]]
+            self.assertEqual(transformed_docs, list(v for doc in corpus_[[1, 3, 4]] for _, v in doc))
+            self.assertEqual(3, len(corpus_[[1, 3, 4]]))
+        else:
+            self.assertRaises(RuntimeError, _get_slice, corpus_, [1, 3, 4])
+            self.assertRaises(RuntimeError, _get_slice, corpus_, {1})
+            self.assertRaises(RuntimeError, _get_slice, corpus_, 1.0)
+
+
+class TestMmCorpusWithIndex(CorpusTestCase):
+    def setUp(self):
+        # same fixture pattern as the other MmCorpus test cases, pointing at the
+        # indexed testcorpus.mm that test_closed_file_object and test_load use
+        self.corpus_class = mmcorpus.MmCorpus
+        self.corpus = self.corpus_class(datapath('testcorpus.mm'))
+        self.file_extension = '.mm'
+
+    def test_serialize_compressed(self):
+        # MmCorpus needs file write with seek => doesn't support compressed output (only input)
+        pass
+
+    def test_closed_file_object(self):
+        file_obj = open(datapath('testcorpus.mm'))
+        f = file_obj.closed
+        mmcorpus.MmCorpus(file_obj)
+        s = file_obj.closed
+        self.assertEqual(f, 0)
+        self.assertEqual(s, 0)
+
+    def test_load(self):
+        self.assertEqual(self.corpus.num_docs, 9)
+        self.assertEqual(self.corpus.num_terms, 12)
+        self.assertEqual(self.corpus.num_nnz, 28)
+
+        # confirm we can iterate and that document values match expected for first three docs
+        it = iter(self.corpus)
+        self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)])
+        self.assertEqual(next(it), [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)])
+        self.assertEqual(next(it), [(2, 1.0), (5, 1.0), (7, 1.0), (8, 1.0)])
+
+        # confirm that accessing document by index works
+        self.assertEqual(self.corpus[3], [(1, 1.0), (5, 2.0), (8, 1.0)])
+        self.assertEqual(tuple(self.corpus.index), (97, 121, 169, 201, 225, 249, 258, 276, 303))
+
+
+class TestMmCorpusNoIndex(CorpusTestCase):
+    def setUp(self):
+        self.corpus_class = mmcorpus.MmCorpus
+        self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm'))
+        self.file_extension = '.mm'
+
+    def
test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertEqual(self.corpus.num_docs, 9) + self.assertEqual(self.corpus.num_terms, 12) + self.assertEqual(self.corpus.num_nnz, 28) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) + self.assertEqual(next(it), []) + self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) + + # confirm that accessing document by index fails + self.assertRaises(RuntimeError, lambda: self.corpus[3]) + + +class TestMmCorpusNoIndexGzip(CorpusTestCase): + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm.gz')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertEqual(self.corpus.num_docs, 9) + self.assertEqual(self.corpus.num_terms, 12) + self.assertEqual(self.corpus.num_nnz, 28) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) + self.assertEqual(next(it), []) + self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) + + # confirm that accessing document by index fails + self.assertRaises(RuntimeError, lambda: self.corpus[3]) + + +class TestMmCorpusNoIndexBzip(CorpusTestCase): + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_no_index.mm.bz2')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertEqual(self.corpus.num_docs, 9) + self.assertEqual(self.corpus.num_terms, 12) + self.assertEqual(self.corpus.num_nnz, 28) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it), [(0, 1.0), (1, 1.0), (2, 1.0)]) + self.assertEqual(next(it), []) + self.assertEqual(next(it), [(2, 0.42371910849), (5, 0.6625174), (7, 1.0), (8, 1.0)]) + + # confirm that accessing document by index fails + self.assertRaises(RuntimeError, lambda: self.corpus[3]) + + +class TestMmCorpusCorrupt(CorpusTestCase): + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_corrupt.mm')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertRaises(ValueError, lambda: [doc for doc in self.corpus]) + + +class TestMmCorpusOverflow(CorpusTestCase): + """ + Test to make sure cython mmreader doesn't overflow on large number of docs or terms + + """ + def setUp(self): + self.corpus_class = mmcorpus.MmCorpus + self.corpus = self.corpus_class(datapath('test_mmcorpus_overflow.mm')) + self.file_extension = '.mm' + + def test_serialize_compressed(self): + # MmCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + def test_load(self): + self.assertEqual(self.corpus.num_docs, 44270060) + self.assertEqual(self.corpus.num_terms, 500) + 
self.assertEqual(self.corpus.num_nnz, 22134988630) + + # confirm we can iterate and that document values match expected for first three docs + it = iter(self.corpus) + self.assertEqual(next(it)[:3], [(0, 0.3913027376444812), + (1, -0.07658791716226626), + (2, -0.020870794080588395)]) + self.assertEqual(next(it), []) + self.assertEqual(next(it), []) + + # confirm count of terms + count = 0 + for doc in self.corpus: + for term in doc: + count += 1 + + self.assertEqual(count, 12) + + # confirm that accessing document by index fails + self.assertRaises(RuntimeError, lambda: self.corpus[3]) + + +class TestSvmLightCorpus(CorpusTestCase): + def setUp(self): + self.corpus_class = svmlightcorpus.SvmLightCorpus + self.file_extension = '.svmlight' + + def test_serialization(self): + path = get_tmpfile("svml.corpus") + labels = [1] * len(common_corpus) + second_corpus = [(0, 1.0), (3, 1.0), (4, 1.0), (5, 1.0), (6, 1.0), (7, 1.0)] + self.corpus_class.serialize(path, common_corpus, labels=labels) + serialized_corpus = self.corpus_class(path) + self.assertEqual(serialized_corpus[1], second_corpus) + self.corpus_class.serialize(path, common_corpus, labels=np.array(labels)) + serialized_corpus = self.corpus_class(path) + self.assertEqual(serialized_corpus[1], second_corpus) + + +class TestBleiCorpus(CorpusTestCase): + def setUp(self): + self.corpus_class = bleicorpus.BleiCorpus + self.file_extension = '.blei' + + def test_save_format_for_dtm(self): + corpus = [[(1, 1.0)], [], [(0, 5.0), (2, 1.0)], []] + test_file = get_tmpfile('gensim_corpus.tst') + self.corpus_class.save_corpus(test_file, corpus) + with open(test_file) as f: + for line in f: + # unique_word_count index1:count1 index2:count2 ... indexn:counnt + tokens = line.split() + words_len = int(tokens[0]) + if words_len > 0: + tokens = tokens[1:] + else: + tokens = [] + self.assertEqual(words_len, len(tokens)) + for token in tokens: + word, count = token.split(':') + self.assertEqual(count, str(int(count))) + + +class TestLowCorpus(CorpusTestCase): + TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] + CORPUS_LINE = 'mom wash window window was washed' + + def setUp(self): + self.corpus_class = lowcorpus.LowCorpus + self.file_extension = '.low' + + def test_line2doc(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + id2word = {1: 'mom', 2: 'window'} + + corpus = self.corpus_class(fname, id2word=id2word) + + # should return all words in doc + corpus.use_wordids = False + self.assertEqual( + sorted(corpus.line2doc(self.CORPUS_LINE)), + [('mom', 1), ('was', 1), ('wash', 1), ('washed', 1), ('window', 2)]) + + # should return words in word2id + corpus.use_wordids = True + self.assertEqual( + sorted(corpus.line2doc(self.CORPUS_LINE)), + [(1, 1), (2, 2)]) + + +class TestUciCorpus(CorpusTestCase): + TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] + + def setUp(self): + self.corpus_class = ucicorpus.UciCorpus + self.file_extension = '.uci' + + def test_serialize_compressed(self): + # UciCorpus needs file write with seek => doesn't support compressed output (only input) + pass + + +class TestMalletCorpus(TestLowCorpus): + TEST_CORPUS = [[(1, 1)], [], [(0, 2), (2, 1)], []] + CORPUS_LINE = '#3 lang mom wash window window was washed' + + def setUp(self): + self.corpus_class = malletcorpus.MalletCorpus + self.file_extension = '.mallet' + + def test_load_with_metadata(self): + fname = datapath('testcorpus.' 
+ self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + corpus.metadata = True + self.assertEqual(len(corpus), 9) + + docs = list(corpus) + self.assertEqual(len(docs), 9) + + for i, docmeta in enumerate(docs): + doc, metadata = docmeta + self.assertEqual(metadata[0], str(i + 1)) + self.assertEqual(metadata[1], 'en') + + def test_line2doc(self): + # case with metadata=False (by default) + super(TestMalletCorpus, self).test_line2doc() + + # case with metadata=True + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + id2word = {1: 'mom', 2: 'window'} + + corpus = self.corpus_class(fname, id2word=id2word, metadata=True) + + # should return all words in doc + corpus.use_wordids = False + doc, (docid, doclang) = corpus.line2doc(self.CORPUS_LINE) + self.assertEqual(docid, '#3') + self.assertEqual(doclang, 'lang') + self.assertEqual( + sorted(doc), + [('mom', 1), ('was', 1), ('wash', 1), ('washed', 1), ('window', 2)]) + + # should return words in word2id + corpus.use_wordids = True + doc, (docid, doclang) = corpus.line2doc(self.CORPUS_LINE) + + self.assertEqual(docid, '#3') + self.assertEqual(doclang, 'lang') + self.assertEqual( + sorted(doc), + [(1, 1), (2, 2)]) + + +class TestTextCorpus(CorpusTestCase): + + def setUp(self): + self.corpus_class = textcorpus.TextCorpus + self.file_extension = '.txt' + + def test_load_with_metadata(self): + fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + corpus = self.corpus_class(fname) + corpus.metadata = True + self.assertEqual(len(corpus), 9) + + docs = list(corpus) + self.assertEqual(len(docs), 9) + + for i, docmeta in enumerate(docs): + doc, metadata = docmeta + self.assertEqual(metadata[0], i) + + def test_default_preprocessing(self): + lines = [ + "Šéf chomutovských komunistů dostal poštou bílý prášek", + "this is a test for stopwords", + "zf tooth spaces " + ] + expected = [ + ['Sef', 'chomutovskych', 'komunistu', 'dostal', 'postou', 'bily', 'prasek'], + ['test', 'stopwords'], + ['tooth', 'spaces'] + ] + + corpus = self.corpus_from_lines(lines) + texts = list(corpus.get_texts()) + self.assertEqual(expected, texts) + + def corpus_from_lines(self, lines): + fpath = tempfile.mktemp() + with codecs.open(fpath, 'w', encoding='utf8') as f: + f.write('\n'.join(lines)) + + return self.corpus_class(fpath) + + def test_sample_text(self): + lines = ["document%d" % i for i in range(10)] + corpus = self.corpus_from_lines(lines) + corpus.tokenizer = lambda text: text.split() + docs = [doc for doc in corpus.get_texts()] + + sample1 = list(corpus.sample_texts(1)) + self.assertEqual(len(sample1), 1) + self.assertIn(sample1[0], docs) + + sample2 = list(corpus.sample_texts(len(lines))) + self.assertEqual(len(sample2), len(corpus)) + for i in range(len(corpus)): + self.assertEqual(sample2[i], ["document%s" % i]) + + with self.assertRaises(ValueError): + list(corpus.sample_texts(len(corpus) + 1)) + + with self.assertRaises(ValueError): + list(corpus.sample_texts(-1)) + + def test_sample_text_length(self): + lines = ["document%d" % i for i in range(10)] + corpus = self.corpus_from_lines(lines) + corpus.tokenizer = lambda text: text.split() + + sample1 = list(corpus.sample_texts(1, length=1)) + self.assertEqual(sample1[0], ["document0"]) + + sample2 = list(corpus.sample_texts(2, length=2)) + self.assertEqual(sample2[0], ["document0"]) + self.assertEqual(sample2[1], ["document1"]) + + def test_sample_text_seed(self): + lines = ["document%d" % i for i in range(10)] + corpus = self.corpus_from_lines(lines) + + sample1 = 
list(corpus.sample_texts(5, seed=42)) + sample2 = list(corpus.sample_texts(5, seed=42)) + self.assertEqual(sample1, sample2) + + def test_save(self): + pass + + def test_serialize(self): + pass + + def test_serialize_compressed(self): + pass + + def test_indexing(self): + pass + + +# Needed for the test_custom_tokenizer is the TestWikiCorpus class. +# Cannot be nested due to serializing. +def custom_tokenizer(content, token_min_len=2, token_max_len=15, lower=True): + return [ + to_unicode(token.lower()) if lower else to_unicode(token) for token in content.split() + if token_min_len <= len(token) <= token_max_len and not token.startswith('_') + ] + + +class TestWikiCorpus(TestTextCorpus): + def setUp(self): + self.corpus_class = wikicorpus.WikiCorpus + self.file_extension = '.xml.bz2' + self.fname = datapath('testcorpus.' + self.file_extension.lstrip('.')) + self.enwiki = datapath('enwiki-latest-pages-articles1.xml-p000000010p000030302-shortened.bz2') + + def test_default_preprocessing(self): + expected = ['computer', 'human', 'interface'] + corpus = self.corpus_class(self.fname, article_min_tokens=0) + first_text = next(corpus.get_texts()) + self.assertEqual(expected, first_text) + + def test_len(self): + # When there is no min_token limit all 9 articles must be registered. + corpus = self.corpus_class(self.fname, article_min_tokens=0) + all_articles = corpus.get_texts() + assert (len(list(all_articles)) == 9) + + # With a huge min_token limit, all articles should be filtered out. + corpus = self.corpus_class(self.fname, article_min_tokens=100000) + all_articles = corpus.get_texts() + assert (len(list(all_articles)) == 0) + + def test_load_with_metadata(self): + corpus = self.corpus_class(self.fname, article_min_tokens=0) + corpus.metadata = True + self.assertEqual(len(corpus), 9) + + docs = list(corpus) + self.assertEqual(len(docs), 9) + + for i, docmeta in enumerate(docs): + doc, metadata = docmeta + article_no = i + 1 # Counting IDs from 1 + self.assertEqual(metadata[0], str(article_no)) + self.assertEqual(metadata[1], 'Article%d' % article_no) + + def test_load(self): + corpus = self.corpus_class(self.fname, article_min_tokens=0) + + docs = list(corpus) + # the deerwester corpus always has nine documents + self.assertEqual(len(docs), 9) + + def test_first_element(self): + """ + First two articles in this sample are + 1) anarchism + 2) autism + """ + corpus = self.corpus_class(self.enwiki, processes=1) + + texts = corpus.get_texts() + self.assertTrue(u'anarchism' in next(texts)) + self.assertTrue(u'autism' in next(texts)) + + def test_unicode_element(self): + """ + First unicode article in this sample is + 1) папа + """ + bgwiki = datapath('bgwiki-latest-pages-articles-shortened.xml.bz2') + corpus = self.corpus_class(bgwiki) + texts = corpus.get_texts() + self.assertTrue(u'папа' in next(texts)) + + def test_custom_tokenizer(self): + """ + define a custom tokenizer function and use it + """ + wc = self.corpus_class(self.enwiki, processes=1, lemmatize=False, tokenizer_func=custom_tokenizer, + token_max_len=16, token_min_len=1, lower=False) + row = wc.get_texts() + list_tokens = next(row) + self.assertTrue(u'Anarchism' in list_tokens) + self.assertTrue(u'collectivization' in list_tokens) + self.assertTrue(u'a' in list_tokens) + self.assertTrue(u'i.e.' 
in list_tokens) + + def test_lower_case_set_true(self): + """ + Set the parameter lower to True and check that upper case 'Anarchism' token doesnt exist + """ + corpus = self.corpus_class(self.enwiki, processes=1, lower=True, lemmatize=False) + row = corpus.get_texts() + list_tokens = next(row) + self.assertTrue(u'Anarchism' not in list_tokens) + self.assertTrue(u'anarchism' in list_tokens) + + def test_lower_case_set_false(self): + """ + Set the parameter lower to False and check that upper case Anarchism' token exists + """ + corpus = self.corpus_class(self.enwiki, processes=1, lower=False, lemmatize=False) + row = corpus.get_texts() + list_tokens = next(row) + self.assertTrue(u'Anarchism' in list_tokens) + self.assertTrue(u'anarchism' in list_tokens) + + def test_min_token_len_not_set(self): + """ + Don't set the parameter token_min_len and check that 'a' as a token doesn't exist + Default token_min_len=2 + """ + corpus = self.corpus_class(self.enwiki, processes=1, lemmatize=False) + self.assertTrue(u'a' not in next(corpus.get_texts())) + + def test_min_token_len_set(self): + """ + Set the parameter token_min_len to 1 and check that 'a' as a token exists + """ + corpus = self.corpus_class(self.enwiki, processes=1, token_min_len=1, lemmatize=False) + self.assertTrue(u'a' in next(corpus.get_texts())) + + def test_max_token_len_not_set(self): + """ + Don't set the parameter token_max_len and check that 'collectivisation' as a token doesn't exist + Default token_max_len=15 + """ + corpus = self.corpus_class(self.enwiki, processes=1, lemmatize=False) + self.assertTrue(u'collectivization' not in next(corpus.get_texts())) + + def test_max_token_len_set(self): + """ + Set the parameter token_max_len to 16 and check that 'collectivisation' as a token exists + """ + corpus = self.corpus_class(self.enwiki, processes=1, token_max_len=16, lemmatize=False) + self.assertTrue(u'collectivization' in next(corpus.get_texts())) + + def test_removed_table_markup(self): + """ + Check if all the table markup has been removed. 
+ """ + enwiki_file = datapath('enwiki-table-markup.xml.bz2') + corpus = self.corpus_class(enwiki_file) + texts = corpus.get_texts() + table_markup = ["style", "class", "border", "cellspacing", "cellpadding", "colspan", "rowspan"] + for text in texts: + for word in table_markup: + self.assertTrue(word not in text) + + # #TODO: sporadic failure to be investigated + # def test_get_texts_returns_generator_of_lists(self): + # corpus = self.corpus_class(self.enwiki) + # l = corpus.get_texts() + # self.assertEqual(type(l), types.GeneratorType) + # first = next(l) + # self.assertEqual(type(first), list) + # self.assertTrue(isinstance(first[0], bytes) or isinstance(first[0], str)) + + def test_sample_text(self): + # Cannot instantiate WikiCorpus from lines + pass + + def test_sample_text_length(self): + # Cannot instantiate WikiCorpus from lines + pass + + def test_sample_text_seed(self): + # Cannot instantiate WikiCorpus from lines + pass + + def test_empty_input(self): + # An empty file is not legit XML + pass + + def test_custom_filterfunction(self): + def reject_all(elem, *args, **kwargs): + return False + corpus = self.corpus_class(self.enwiki, filter_articles=reject_all) + texts = corpus.get_texts() + self.assertFalse(any(texts)) + + def keep_some(elem, title, *args, **kwargs): + return title[0] == 'C' + corpus = self.corpus_class(self.enwiki, filter_articles=reject_all) + corpus.metadata = True + texts = corpus.get_texts() + for text, (pageid, title) in texts: + self.assertEquals(title[0], 'C') + + +class TestTextDirectoryCorpus(unittest.TestCase): + + def write_one_level(self, *args): + if not args: + args = ('doc1', 'doc2') + dirpath = tempfile.mkdtemp() + self.write_docs_to_directory(dirpath, *args) + return dirpath + + def write_docs_to_directory(self, dirpath, *args): + for doc_num, name in enumerate(args): + with open(os.path.join(dirpath, name), 'w') as f: + f.write('document %d content' % doc_num) + + def test_one_level_directory(self): + dirpath = self.write_one_level() + + corpus = textcorpus.TextDirectoryCorpus(dirpath) + self.assertEqual(len(corpus), 2) + docs = list(corpus) + self.assertEqual(len(docs), 2) + + def write_two_levels(self): + dirpath = self.write_one_level() + next_level = os.path.join(dirpath, 'level_two') + os.mkdir(next_level) + self.write_docs_to_directory(next_level, 'doc1', 'doc2') + return dirpath, next_level + + def test_two_level_directory(self): + dirpath, next_level = self.write_two_levels() + + corpus = textcorpus.TextDirectoryCorpus(dirpath) + self.assertEqual(len(corpus), 4) + docs = list(corpus) + self.assertEqual(len(docs), 4) + + corpus = textcorpus.TextDirectoryCorpus(dirpath, min_depth=1) + self.assertEqual(len(corpus), 2) + docs = list(corpus) + self.assertEqual(len(docs), 2) + + corpus = textcorpus.TextDirectoryCorpus(dirpath, max_depth=0) + self.assertEqual(len(corpus), 2) + docs = list(corpus) + self.assertEqual(len(docs), 2) + + def test_filename_filtering(self): + dirpath = self.write_one_level('test1.log', 'test1.txt', 'test2.log', 'other1.log') + corpus = textcorpus.TextDirectoryCorpus(dirpath, pattern=r"test.*\.log") + filenames = list(corpus.iter_filepaths()) + expected = [os.path.join(dirpath, name) for name in ('test1.log', 'test2.log')] + self.assertEqual(sorted(expected), sorted(filenames)) + + corpus.pattern = ".*.txt" + filenames = list(corpus.iter_filepaths()) + expected = [os.path.join(dirpath, 'test1.txt')] + self.assertEqual(expected, filenames) + + corpus.pattern = None + corpus.exclude_pattern = ".*.log" + filenames = 
list(corpus.iter_filepaths()) + self.assertEqual(expected, filenames) + + def test_lines_are_documents(self): + dirpath = tempfile.mkdtemp() + lines = ['doc%d text' % i for i in range(5)] + fpath = os.path.join(dirpath, 'test_file.txt') + with open(fpath, 'w') as f: + f.write('\n'.join(lines)) + + corpus = textcorpus.TextDirectoryCorpus(dirpath, lines_are_documents=True) + docs = [doc for doc in corpus.getstream()] + self.assertEqual(len(lines), corpus.length) # should have cached + self.assertEqual(lines, docs) + + corpus.lines_are_documents = False + docs = [doc for doc in corpus.getstream()] + self.assertEqual(1, corpus.length) + self.assertEqual('\n'.join(lines), docs[0]) + + def test_non_trivial_structure(self): + """Test with non-trivial directory structure, shown below: + . + ├── 0.txt + ├── a_folder + │ └── 1.txt + └── b_folder + ├── 2.txt + ├── 3.txt + └── c_folder + └── 4.txt + """ + dirpath = tempfile.mkdtemp() + self.write_docs_to_directory(dirpath, '0.txt') + + a_folder = os.path.join(dirpath, 'a_folder') + os.mkdir(a_folder) + self.write_docs_to_directory(a_folder, '1.txt') + + b_folder = os.path.join(dirpath, 'b_folder') + os.mkdir(b_folder) + self.write_docs_to_directory(b_folder, '2.txt', '3.txt') + + c_folder = os.path.join(b_folder, 'c_folder') + os.mkdir(c_folder) + self.write_docs_to_directory(c_folder, '4.txt') + + corpus = textcorpus.TextDirectoryCorpus(dirpath) + filenames = list(corpus.iter_filepaths()) + base_names = sorted(name[len(dirpath) + 1:] for name in filenames) + expected = sorted([ + '0.txt', + 'a_folder/1.txt', + 'b_folder/2.txt', + 'b_folder/3.txt', + 'b_folder/c_folder/4.txt' + ]) + expected = [os.path.normpath(path) for path in expected] + self.assertEqual(expected, base_names) + + corpus.max_depth = 1 + self.assertEqual(expected[:-1], base_names[:-1]) + + corpus.min_depth = 1 + self.assertEqual(expected[2:-1], base_names[2:-1]) + + corpus.max_depth = 0 + self.assertEqual(expected[2:], base_names[2:]) + + corpus.pattern = "4.*" + self.assertEqual(expected[-1], base_names[-1]) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/test_corpora_dictionary_2574.py b/test_corpora_dictionary_2574.py new file mode 100644 index 0000000000..a29311e894 --- /dev/null +++ b/test_corpora_dictionary_2574.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +""" +Unit tests for the `corpora.Dictionary` class. 
+""" + + +from collections import Mapping +from itertools import chain +import logging +import unittest +import codecs +import os +import os.path + +import scipy +import gensim +from gensim.corpora import Dictionary +from gensim.utils import to_utf8 +from gensim.test.utils import get_tmpfile, common_texts +from six import PY3 +from six.moves import zip + + +class TestDictionary(unittest.TestCase): + def setUp(self): + self.texts = common_texts + + def testDocFreqOneDoc(self): + texts = [['human', 'interface', 'computer']] + d = Dictionary(texts) + expected = {0: 1, 1: 1, 2: 1} + self.assertEqual(d.dfs, expected) + + def testDocFreqAndToken2IdForSeveralDocsWithOneWord(self): + # two docs + texts = [['human'], ['human']] + d = Dictionary(texts) + expected = {0: 2} + self.assertEqual(d.dfs, expected) + # only one token (human) should exist + expected = {'human': 0} + self.assertEqual(d.token2id, expected) + + # three docs + texts = [['human'], ['human'], ['human']] + d = Dictionary(texts) + expected = {0: 3} + self.assertEqual(d.dfs, expected) + # only one token (human) should exist + expected = {'human': 0} + self.assertEqual(d.token2id, expected) + + # four docs + texts = [['human'], ['human'], ['human'], ['human']] + d = Dictionary(texts) + expected = {0: 4} + self.assertEqual(d.dfs, expected) + # only one token (human) should exist + expected = {'human': 0} + self.assertEqual(d.token2id, expected) + + def testDocFreqForOneDocWithSeveralWord(self): + # two words + texts = [['human', 'cat']] + d = Dictionary(texts) + expected = {0: 1, 1: 1} + self.assertEqual(d.dfs, expected) + + # three words + texts = [['human', 'cat', 'minors']] + d = Dictionary(texts) + expected = {0: 1, 1: 1, 2: 1} + self.assertEqual(d.dfs, expected) + + def testDocFreqAndCollectionFreq(self): + # one doc + texts = [['human', 'human', 'human']] + d = Dictionary(texts) + self.assertEqual(d.cfs, {0: 3}) + self.assertEqual(d.dfs, {0: 1}) + + # two docs + texts = [['human', 'human'], ['human']] + d = Dictionary(texts) + self.assertEqual(d.cfs, {0: 3}) + self.assertEqual(d.dfs, {0: 2}) + + # three docs + texts = [['human'], ['human'], ['human']] + d = Dictionary(texts) + self.assertEqual(d.cfs, {0: 3}) + self.assertEqual(d.dfs, {0: 3}) + + def testBuild(self): + d = Dictionary(self.texts) + + # Since we don't specify the order in which dictionaries are built, + # we cannot reliably test for the mapping; only the keys and values. 
+ expected_keys = list(range(12)) + expected_values = [2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + self.assertEqual(sorted(d.dfs.keys()), expected_keys) + self.assertEqual(sorted(d.dfs.values()), expected_values) + + expected_keys = sorted([ + 'computer', 'eps', 'graph', 'human', 'interface', + 'minors', 'response', 'survey', 'system', 'time', 'trees', 'user' + ]) + expected_values = list(range(12)) + self.assertEqual(sorted(d.token2id.keys()), expected_keys) + self.assertEqual(sorted(d.token2id.values()), expected_values) + + def testMerge(self): + d = Dictionary(self.texts) + f = Dictionary(self.texts[:3]) + g = Dictionary(self.texts[3:]) + + f.merge_with(g) + self.assertEqual(sorted(d.token2id.keys()), sorted(f.token2id.keys())) + + def testFilter(self): + d = Dictionary(self.texts) + d.filter_extremes(no_below=2, no_above=1.0, keep_n=4) + expected = {0: 3, 1: 3, 2: 3, 3: 3} + self.assertEqual(d.dfs, expected) + + def testFilterKeepTokens_keepTokens(self): + # provide keep_tokens argument, keep the tokens given + d = Dictionary(self.texts) + d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['human', 'survey']) + expected = {'graph', 'trees', 'human', 'system', 'user', 'survey'} + self.assertEqual(set(d.token2id.keys()), expected) + + def testFilterKeepTokens_unchangedFunctionality(self): + # do not provide keep_tokens argument, filter_extremes functionality is unchanged + d = Dictionary(self.texts) + d.filter_extremes(no_below=3, no_above=1.0) + expected = {'graph', 'trees', 'system', 'user'} + self.assertEqual(set(d.token2id.keys()), expected) + + def testFilterKeepTokens_unseenToken(self): + # do provide keep_tokens argument with unseen tokens, filter_extremes functionality is unchanged + d = Dictionary(self.texts) + d.filter_extremes(no_below=3, no_above=1.0, keep_tokens=['unknown_token']) + expected = {'graph', 'trees', 'system', 'user'} + self.assertEqual(set(d.token2id.keys()), expected) + + def testFilterKeepTokens_keepn(self): + # keep_tokens should also work if the keep_n parameter is used, but only + # to keep a maximum of n (so if keep_n < len(keep_n) the tokens to keep are + # still getting removed to reduce the size to keep_n!) + d = Dictionary(self.texts) + # Note: there are four tokens with freq 3, all the others have frequence 2 + # in self.texts. 
In order to make the test result deterministic, we add + # 2 tokens of frequency one + d.add_documents([['worda'], ['wordb']]) + # this should keep the 3 tokens with freq 3 and the one we want to keep + d.filter_extremes(keep_n=5, no_below=0, no_above=1.0, keep_tokens=['worda']) + expected = {'graph', 'trees', 'system', 'user', 'worda'} + self.assertEqual(set(d.token2id.keys()), expected) + + def testFilterMostFrequent(self): + d = Dictionary(self.texts) + d.filter_n_most_frequent(4) + expected = {0: 2, 1: 2, 2: 2, 3: 2, 4: 2, 5: 2, 6: 2, 7: 2} + self.assertEqual(d.dfs, expected) + + def testFilterTokens(self): + self.maxDiff = 10000 + d = Dictionary(self.texts) + + removed_word = d[0] + d.filter_tokens([0]) + + expected = { + 'computer': 0, 'graph': 10, 'human': 1, + 'interface': 2, 'response': 3, 'survey': 4, + 'system': 5, 'time': 6, 'trees': 9, 'user': 7 + } + del expected[removed_word] + self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys())) + + expected[removed_word] = len(expected) + d.add_documents([[removed_word]]) + self.assertEqual(sorted(d.token2id.keys()), sorted(expected.keys())) + + def test_doc2bow(self): + d = Dictionary([["žluťoučký"], ["žluťoučký"]]) + + # pass a utf8 string + self.assertEqual(d.doc2bow(["žluťoučký"]), [(0, 1)]) + + # doc2bow must raise a TypeError if passed a string instead of array of strings by accident + self.assertRaises(TypeError, d.doc2bow, "žluťoučký") + + # unicode must be converted to utf8 + self.assertEqual(d.doc2bow([u'\u017elu\u0165ou\u010dk\xfd']), [(0, 1)]) + + def test_saveAsText(self): + """`Dictionary` can be saved as textfile. """ + tmpf = get_tmpfile('save_dict_test.txt') + small_text = [ + ["prvé", "slovo"], + ["slovo", "druhé"], + ["druhé", "slovo"] + ] + + d = Dictionary(small_text) + + d.save_as_text(tmpf) + with codecs.open(tmpf, 'r', encoding='utf-8') as file: + serialized_lines = file.readlines() + self.assertEqual(serialized_lines[0], u"3\n") + self.assertEqual(len(serialized_lines), 4) + # We do not know, which word will have which index + self.assertEqual(serialized_lines[1][1:], u"\tdruhé\t2\n") + self.assertEqual(serialized_lines[2][1:], u"\tprvé\t1\n") + self.assertEqual(serialized_lines[3][1:], u"\tslovo\t3\n") + + d.save_as_text(tmpf, sort_by_word=False) + with codecs.open(tmpf, 'r', encoding='utf-8') as file: + serialized_lines = file.readlines() + self.assertEqual(serialized_lines[0], u"3\n") + self.assertEqual(len(serialized_lines), 4) + self.assertEqual(serialized_lines[1][1:], u"\tslovo\t3\n") + self.assertEqual(serialized_lines[2][1:], u"\tdruhé\t2\n") + self.assertEqual(serialized_lines[3][1:], u"\tprvé\t1\n") + + def test_loadFromText_legacy(self): + """ + `Dictionary` can be loaded from textfile in legacy format. + Legacy format does not have num_docs on the first line. 
+ """ + tmpf = get_tmpfile('load_dict_test_legacy.txt') + no_num_docs_serialization = to_utf8("1\tprvé\t1\n2\tslovo\t2\n") + with open(tmpf, "wb") as file: + file.write(no_num_docs_serialization) + + d = Dictionary.load_from_text(tmpf) + self.assertEqual(d.token2id[u"prvé"], 1) + self.assertEqual(d.token2id[u"slovo"], 2) + self.assertEqual(d.dfs[1], 1) + self.assertEqual(d.dfs[2], 2) + self.assertEqual(d.num_docs, 0) + + def test_loadFromText(self): + """`Dictionary` can be loaded from textfile.""" + tmpf = get_tmpfile('load_dict_test.txt') + no_num_docs_serialization = to_utf8("2\n1\tprvé\t1\n2\tslovo\t2\n") + with open(tmpf, "wb") as file: + file.write(no_num_docs_serialization) + + d = Dictionary.load_from_text(tmpf) + self.assertEqual(d.token2id[u"prvé"], 1) + self.assertEqual(d.token2id[u"slovo"], 2) + self.assertEqual(d.dfs[1], 1) + self.assertEqual(d.dfs[2], 2) + self.assertEqual(d.num_docs, 2) + + def test_saveAsText_and_loadFromText(self): + """`Dictionary` can be saved as textfile and loaded again from textfile. """ + tmpf = get_tmpfile('dict_test.txt') + for sort_by_word in [True, False]: + d = Dictionary(self.texts) + d.save_as_text(tmpf, sort_by_word=sort_by_word) + self.assertTrue(os.path.exists(tmpf)) + + d_loaded = Dictionary.load_from_text(tmpf) + self.assertNotEqual(d_loaded, None) + self.assertEqual(d_loaded.token2id, d.token2id) + + def test_from_corpus(self): + """build `Dictionary` from an existing corpus""" + + documents = [ + "Human machine interface for lab abc computer applications", + "A survey of user opinion of computer system response time", + "The EPS user interface management system", + "System and human system engineering testing of EPS", + "Relation of user perceived response time to error measurement", + "The generation of random binary unordered trees", + "The intersection graph of paths in trees", + "Graph minors IV Widths of trees and well quasi ordering", + "Graph minors A survey" + ] + stoplist = set('for a of the and to in'.split()) + texts = [ + [word for word in document.lower().split() if word not in stoplist] + for document in documents] + + # remove words that appear only once + all_tokens = list(chain.from_iterable(texts)) + tokens_once = set(word for word in set(all_tokens) if all_tokens.count(word) == 1) + texts = [[word for word in text if word not in tokens_once] for text in texts] + + dictionary = Dictionary(texts) + corpus = [dictionary.doc2bow(text) for text in texts] + + # Create dictionary from corpus without a token map + dictionary_from_corpus = Dictionary.from_corpus(corpus) + + dict_token2id_vals = sorted(dictionary.token2id.values()) + dict_from_corpus_vals = sorted(dictionary_from_corpus.token2id.values()) + self.assertEqual(dict_token2id_vals, dict_from_corpus_vals) + self.assertEqual(dictionary.dfs, dictionary_from_corpus.dfs) + self.assertEqual(dictionary.num_docs, dictionary_from_corpus.num_docs) + self.assertEqual(dictionary.num_pos, dictionary_from_corpus.num_pos) + self.assertEqual(dictionary.num_nnz, dictionary_from_corpus.num_nnz) + + # Create dictionary from corpus with an id=>token map + dictionary_from_corpus_2 = Dictionary.from_corpus(corpus, id2word=dictionary) + + self.assertEqual(dictionary.token2id, dictionary_from_corpus_2.token2id) + self.assertEqual(dictionary.dfs, dictionary_from_corpus_2.dfs) + self.assertEqual(dictionary.num_docs, dictionary_from_corpus_2.num_docs) + self.assertEqual(dictionary.num_pos, dictionary_from_corpus_2.num_pos) + self.assertEqual(dictionary.num_nnz, 
dictionary_from_corpus_2.num_nnz) + + # Ensure Sparse2Corpus is compatible with from_corpus + bow = gensim.matutils.Sparse2Corpus(scipy.sparse.rand(10, 100)) + dictionary = Dictionary.from_corpus(bow) + self.assertEqual(dictionary.num_docs, 100) + + def test_dict_interface(self): + """Test Python 2 dict-like interface in both Python 2 and 3.""" + d = Dictionary(self.texts) + + self.assertTrue(isinstance(d, Mapping)) + + self.assertEqual(list(zip(d.keys(), d.values())), list(d.items())) + + # Even in Py3, we want the iter* members. + self.assertEqual(list(d.items()), list(d.iteritems())) + self.assertEqual(list(d.keys()), list(d.iterkeys())) + self.assertEqual(list(d.values()), list(d.itervalues())) + + # XXX Do we want list results from the dict members in Py3 too? + if not PY3: + self.assertTrue(isinstance(d.items(), list)) + self.assertTrue(isinstance(d.keys(), list)) + self.assertTrue(isinstance(d.values(), list)) + + def test_patch_with_special_tokens(self): + special_tokens = {'pad': 0, 'space': 1, 'quake': 3} + corpus = [["máma", "mele", "maso"], ["ema", "má", "máma"]] + d = Dictionary(corpus) + self.assertEqual(len(d.token2id), 5) + d.patch_with_special_tokens(special_tokens) + self.assertEqual(d.token2id['pad'], 0) + self.assertEqual(d.token2id['space'], 1) + self.assertEqual(d.token2id['quake'], 3) + self.assertEqual(len(d.token2id), 8) + self.assertNotIn((0, 1), d.doc2bow(corpus[0])) + self.assertIn((0, 1), d.doc2bow(['pad'] + corpus[0])) + corpus_with_special_tokens = [["máma", "mele", "maso"], ["ema", "má", "máma", "space"]] + d = Dictionary(corpus_with_special_tokens) + self.assertEqual(len(d.token2id), 6) + self.assertNotEqual(d.token2id['space'], 1) + d.patch_with_special_tokens(special_tokens) + self.assertEqual(len(d.token2id), 8) + self.assertEqual(max(d.token2id.values()), 7) + self.assertEqual(d.token2id['space'], 1) + self.assertNotIn((1, 1), d.doc2bow(corpus_with_special_tokens[0])) + self.assertIn((1, 1), d.doc2bow(corpus_with_special_tokens[1])) + + +# endclass TestDictionary + + +if __name__ == '__main__': + logging.basicConfig(level=logging.WARNING) + unittest.main() diff --git a/test_datatype.py b/test_datatype.py new file mode 100644 index 0000000000..58198fbd35 --- /dev/null +++ b/test_datatype.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +""" +Automated tests for checking various matutils functions. 
+""" + +import logging +import unittest + +import numpy as np + +from gensim.test.utils import datapath +from gensim.models.keyedvectors import KeyedVectors + + +class TestDataType(unittest.TestCase): + def load_model(self, datatype): + path = datapath('high_precision.kv.txt') + kv = KeyedVectors.load_word2vec_format(path, binary=False, + datatype=datatype) + return kv + + def test_high_precision(self): + kv = self.load_model(np.float64) + self.assertAlmostEqual(kv['horse.n.01'][0], -0.0008546282343595379) + self.assertEqual(kv['horse.n.01'][0].dtype, np.float64) + + def test_medium_precision(self): + kv = self.load_model(np.float32) + self.assertAlmostEqual(kv['horse.n.01'][0], -0.00085462822) + self.assertEqual(kv['horse.n.01'][0].dtype, np.float32) + + def test_low_precision(self): + kv = self.load_model(np.float16) + + self.assertEqual(kv['horse.n.01'][0].dtype, np.float16) + + def test_type_conversion(self): + path = datapath('high_precision.kv.txt') + binary_path = datapath('high_precision.kv.bin') + model1 = KeyedVectors.load_word2vec_format(path, datatype=np.float16) + model1.save_word2vec_format(binary_path, binary=True) + model2 = KeyedVectors.load_word2vec_format(binary_path, datatype=np.float64, binary=True) + self.assertAlmostEqual(model1["horse.n.01"][0], np.float16(model2["horse.n.01"][0])) + self.assertEqual(model1["horse.n.01"][0].dtype, np.float16) + self.assertEqual(model2["horse.n.01"][0].dtype, np.float64) + + +if __name__ == '__main__': + logging.root.setLevel(logging.WARNING) + unittest.main() diff --git a/test_hdpmodel.py b/test_hdpmodel.py new file mode 100644 index 0000000000..1f73ea7e3c --- /dev/null +++ b/test_hdpmodel.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2010 Radim Rehurek +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +""" +Automated tests for checking transformation algorithms (the models package). +""" + + +import logging +import unittest + +from gensim.corpora import mmcorpus, Dictionary +from gensim.models import hdpmodel +from gensim.test import basetmtests +from gensim.test.utils import datapath, common_texts + +import numpy as np + +dictionary = Dictionary(common_texts) +corpus = [dictionary.doc2bow(text) for text in common_texts] + + +class TestHdpModel(unittest.TestCase, basetmtests.TestBaseTopicModel): + def setUp(self): + self.corpus = mmcorpus.MmCorpus(datapath('testcorpus.mm')) + self.class_ = hdpmodel.HdpModel + self.model = self.class_(corpus, id2word=dictionary, random_state=np.random.seed(0)) + + def testTopicValues(self): + """ + Check show topics method + """ + results = self.model.show_topics()[0] + expected_prob, expected_word = '0.264', 'trees ' + prob, word = results[1].split('+')[0].split('*') + self.assertEqual(results[0], 0) + self.assertEqual(prob, expected_prob) + self.assertEqual(word, expected_word) + + return + + def testLDAmodel(self): + """ + Create ldamodel object, and check if the corresponding alphas are equal. + """ + + + +if __name__ == '__main__': + logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) + unittest.main()