Allow use of truncated Dictionary for coherence measures #1349

Merged (34 commits) on Jun 14, 2017
Changes shown below are from 13 of the 34 commits.

Commits (34)
9d06a1f
#1342: Allow use of truncated `Dictionary` for coherence calculation …
May 22, 2017
f69a2ff
#1342: Do not produce sliding windows for texts with no relevant word…
May 22, 2017
26de547
#1342: Remove unused multiprocessing import in `coherencemodel` module.
May 22, 2017
dfe159b
add utility functions for strided windowing of texts (lists of string…
May 24, 2017
2e3852e
handle edge cases with window_size equal to or exceeding document siz…
May 24, 2017
ec7af1b
move code for building inverted index into a new text_analysis module…
May 24, 2017
3f8fb7f
complete migration to using the accumulators in the new text_analysis…
May 24, 2017
b12edef
fix bug in WordOccurrenceAccumulator so that co-occurrences of same w…
May 25, 2017
91b8a05
make wikicorpus parsing handle KeyboardInterrupt gracefully
May 25, 2017
c6224b7
add ParallelWordOccurrenceAccumulator and make default method for p_b…
May 26, 2017
f00d389
clean up, clarify, and optimize the indirect_confirmation_measure.cos…
May 27, 2017
327b739
#1342: Cleanup, documentation improvements, proper caching of accumul…
May 30, 2017
3c0752b
Merge branch 'develop' into develop
macks22 May 30, 2017
e06c7c3
#1342: Do not swallow `KeyboardInterrupt` naively in `WikiCorpus.get_…
May 30, 2017
2ca43f7
#1342: Formatting fixes (hanging indent in `coherencemodel` and non-e…
May 30, 2017
825b0e9
#1342: Improve `CoherenceModel` documentation and minor refactor for …
May 30, 2017
314a400
#1342: Optimize word occurrence accumulation and fix a bug with repea…
May 30, 2017
e785773
#1342: Minor bug fixes and improved logging in text_analysis module; …
May 31, 2017
5f78cdb
#1342: Optimize data structures being used for window set tracking an…
May 31, 2017
bbd2748
#1342: Fix accidental typo.
May 31, 2017
5fb0b95
#1342: Further optimize word co-occurrence accumulation by using a `c…
May 31, 2017
880b8d0
#1342: Clean up logging in `text_analysis` module and remove empty li…
Jun 1, 2017
1d32b8e
#1342: Remove unused traceback module.
Jun 1, 2017
8e04b41
#1342: Fixes for python3 compatibility.
Jun 1, 2017
e3ce402
#1342: Hopefully `six.viewitems` works for python3 compatibility?
Jun 1, 2017
7f7f55d
#1342: Realized the python3 compatibility issue was due to the Dictio…
Jun 1, 2017
343da69
#1342: Fixed a few bugs and added test coverage for the coherencemode…
Jun 2, 2017
1ce8a72
#1342: Further tests for persistence of accumulator.
Jun 2, 2017
96fd343
#1342: Add test case for `CorpusAccumulator`.
Jun 4, 2017
a631ab6
#1342: Formatting fixes for hanging indents and overly long lines.
Jun 5, 2017
5f58bda
#1342: Fix `indirect_confirmation_measure.cosine_similarity` to retur…
Jun 6, 2017
b941f3c
#1342: Fix `direct_confirmation_measure` functions to return individu…
Jun 7, 2017
75fcac8
#1342: Hanging indents and switch out `union` with `update` for uniqu…
Jun 8, 2017
96d1349
#1342: Clarify documentation in the `probability_estimation` module.
Jun 9, 2017
37 changes: 23 additions & 14 deletions gensim/corpora/wikicorpus.py
@@ -23,6 +23,7 @@
import re
from xml.etree.cElementTree import iterparse # LXML isn't faster, so let's go with the built-in solution
import multiprocessing
import signal

from gensim import utils

@@ -249,6 +250,10 @@ def process_article(args):
return result, title, pageid


def init_worker():
signal.signal(signal.SIGINT, signal.SIG_IGN)


class WikiCorpus(TextCorpus):
"""
Treat a wikipedia articles dump (\*articles.xml.bz2) as a (read-only) corpus.
@@ -300,22 +305,26 @@ def get_texts(self):
articles, articles_all = 0, 0
positions, positions_all = 0, 0
texts = ((text, self.lemmatize, title, pageid) for title, text, pageid in extract_pages(bz2.BZ2File(self.fname), self.filter_namespaces))
pool = multiprocessing.Pool(self.processes)
pool = multiprocessing.Pool(self.processes, init_worker)
Owner: What is this change about? Needs a comment.

Contributor Author: @piskvorky I was running get_coherence with coherence="c_v" and trying to make it handle KeyboardInterrupt gracefully; you can see the mechanisms I put in place for this in the text_analysis module in this PR. While doing that, I was confused to find that some background process was still raising KeyboardInterrupt. After some digging, I noticed that the wikicorpus pool workers were the culprit.

This pool may be active during many other phases of gensim execution if the underlying corpus/texts being iterated come from the wikicorpus. I think it makes things slightly cleaner in this case to handle the KeyboardInterrupt in some manner that does not propagate. Perhaps some sort of logging would improve this?
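For context, a minimal standalone sketch of the pattern applied here (the work function and numbers are made up for illustration; this is not the gensim code itself): each pool worker installs SIG_IGN for SIGINT at startup, so Ctrl+C raises KeyboardInterrupt only in the parent process, which can then terminate the pool cleanly instead of every worker dumping its own traceback.

import signal
import multiprocessing

def init_worker():
    # Ignore SIGINT in pool workers; only the parent handles Ctrl+C.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def work(x):
    # Stand-in for real work, e.g. parsing one wiki article.
    return x * x

if __name__ == '__main__':
    pool = multiprocessing.Pool(4, init_worker)
    try:
        for _ in pool.imap(work, range(10 ** 7)):
            pass  # Ctrl+C during this loop is raised in the parent only
    except KeyboardInterrupt:
        print("interrupted; shutting down workers")
    finally:
        pool.terminate()
        pool.join()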

# process the corpus in smaller chunks of docs, because multiprocessing.Pool
# is dumb and would load the entire input into RAM at once...
for group in utils.chunkize(texts, chunksize=10 * self.processes, maxsize=1):
for tokens, title, pageid in pool.imap(process_article, group): # chunksize=10):
articles_all += 1
positions_all += len(tokens)
# article redirects and short stubs are pruned here
if len(tokens) < ARTICLE_MIN_WORDS or any(title.startswith(ignore + ':') for ignore in IGNORED_NAMESPACES):
continue
articles += 1
positions += len(tokens)
if self.metadata:
yield (tokens, (pageid, title))
else:
yield tokens
try:
for group in utils.chunkize(texts, chunksize=10 * self.processes, maxsize=1):
for tokens, title, pageid in pool.imap(process_article, group): # chunksize=10):
articles_all += 1
positions_all += len(tokens)
# article redirects and short stubs are pruned here
if len(tokens) < ARTICLE_MIN_WORDS or any(title.startswith(ignore + ':') for ignore in IGNORED_NAMESPACES):
continue
articles += 1
positions += len(tokens)
if self.metadata:
yield (tokens, (pageid, title))
else:
yield tokens
except KeyboardInterrupt:
Owner: What is this for? Masking user interrupts is an anti-pattern; deserves a detailed comment, at the very least.

Contributor Author: I've updated the code to handle the interrupt more appropriately. As for why I chose to handle the interrupt at all (at the risk of repeating the comment above): this pool may be active during many other phases of gensim execution if the underlying corpus/texts being iterated come from the wikicorpus. It was confusing to see stdout activity from it when issuing an interrupt during execution of an entirely different code path.

Owner (piskvorky, May 30, 2017): Sorry, not sure I understand.

When iterating over a WikiCorpus, gensim uses a multiprocessing.Pool (forks), yes.

Now when you do CTRL+C, what happened (old behaviour)? And what happens now (after your changes here)? Why is that preferable?

Contributor Author: Here is the code I am running:

import os
import logging
import gensim
from gensim.models.coherencemodel import CoherenceModel

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

home = '/Users/user/workshop/nlp'
id2word = gensim.corpora.Dictionary.load_from_text(os.path.join(home, 'data', 'wikipedia-may-17_wordids.txt.bz2'))
texts = gensim.corpora.WikiCorpus(os.path.join(home, 'data', 'enwiki-latest-pages-articles.xml.bz2'), dictionary=id2word).get_texts()
lda = gensim.models.LdaModel.load(os.path.join(home, 'models', 'lda-k100_en-wiki.model'))
topics = [[t[0] for t in lda.get_topic_terms(i, 20)] for i in range(100)]
topic_words = [[id2word[token_id] for token_id in t] for t in topics]
cm = gensim.models.coherencemodel.CoherenceModel(topics=topic_words, texts=texts, dictionary=id2word, coherence='c_v', topn=20, window_size=110)
topic_coherences = cm.get_coherence_per_topic()
print(topic_coherences)

Before:

2017-05-30 07:34:29,027 : INFO : loading LdaModel object from /Users/user/workshop/nlp/models/lda-k100_en-wiki.model
2017-05-30 07:34:29,028 : INFO : loading expElogbeta from /Users/user/workshop/nlp/models/lda-k100_en-wiki.model.expElogbeta.npy with mmap=None
2017-05-30 07:34:29,080 : INFO : setting ignored attribute id2word to None
2017-05-30 07:34:29,080 : INFO : setting ignored attribute state to None
2017-05-30 07:34:29,080 : INFO : setting ignored attribute dispatcher to None
2017-05-30 07:34:29,080 : INFO : loaded /Users/user/workshop/nlp/models/lda-k100_en-wiki.model
2017-05-30 07:34:29,081 : INFO : loading LdaModel object from /Users/user/workshop/nlp/models/lda-k100_en-wiki.model.state
2017-05-30 07:34:29,249 : INFO : loaded /Users/user/workshop/nlp/models/lda-k100_en-wiki.model.state
2017-05-30 07:34:32,603 : INFO : using ParallelWordOccurrenceAccumulator(processes=7, batch_size=16) to estimate probabilities from sliding windows
2017-05-30 07:34:33,101 : INFO : submitted 0 batches to accumulate stats from 0 documents (90752 virtual)
2017-05-30 07:34:33,298 : INFO : submitted 1 batches to accumulate stats from 16 documents (154816 virtual)
2017-05-30 07:34:33,462 : INFO : submitted 2 batches to accumulate stats from 32 documents (223497 virtual)
2017-05-30 07:34:33,709 : INFO : submitted 3 batches to accumulate stats from 48 documents (285340 virtual)
2017-05-30 07:34:33,790 : INFO : submitted 4 batches to accumulate stats from 64 documents (342337 virtual)
2017-05-30 07:34:34,297 : INFO : submitted 5 batches to accumulate stats from 80 documents (415139 virtual)
2017-05-30 07:34:34,872 : INFO : submitted 6 batches to accumulate stats from 96 documents (484709 virtual)
2017-05-30 07:34:35,093 : INFO : submitted 7 batches to accumulate stats from 112 documents (542834 virtual)
2017-05-30 07:34:35,381 : INFO : submitted 8 batches to accumulate stats from 128 documents (628469 virtual)
2017-05-30 07:34:35,443 : INFO : submitted 9 batches to accumulate stats from 144 documents (691420 virtual)
2017-05-30 07:34:35,764 : INFO : submitted 10 batches to accumulate stats from 160 documents (741122 virtual)
2017-05-30 07:34:35,983 : INFO : submitted 11 batches to accumulate stats from 176 documents (774924 virtual)
2017-05-30 07:34:36,234 : INFO : submitted 12 batches to accumulate stats from 192 documents (829056 virtual)
2017-05-30 07:34:36,682 : INFO : submitted 13 batches to accumulate stats from 208 documents (887935 virtual)
^C2017-05-30 07:34:48,466 : WARNING : stats accumulation interrupted; <= 887935 documents processed
Process PoolWorker-12:
Process PoolWorker-10:
Process PoolWorker-11:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
Process PoolWorker-9:
Traceback (most recent call last):
Process PoolWorker-8:
Process PoolWorker-13:
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
Process PoolWorker-14:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
Traceback (most recent call last):
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
    self.run()
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 114, in run
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 114, in run
    self.run()
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/pool.py", line 102, in worker
    self._target(*self._args, **self._kwargs)
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/pool.py", line 102, in worker
    self.run()
    self.run()
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 114, in run
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
    self.run()
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
    self._target(*self._args, **self._kwargs)
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/pool.py", line 102, in worker
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/pool.py", line 102, in worker
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/pool.py", line 102, in worker
    self.run()
    self._target(*self._args, **self._kwargs)
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/process.py", line 114, in run
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/pool.py", line 102, in worker
    self._target(*self._args, **self._kwargs)
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/pool.py", line 102, in worker
2017-05-30 07:34:48,467 : INFO : AccumulatingWorker interrupted after processing 8598 documents
2017-05-30 07:34:48,468 : INFO : AccumulatingWorker interrupted after processing 3630 documents
2017-05-30 07:34:48,467 : INFO : AccumulatingWorker interrupted after processing 7670 documents
2017-05-30 07:34:48,468 : INFO : AccumulatingWorker interrupted after processing 5859 documents
2017-05-30 07:34:48,468 : INFO : AccumulatingWorker interrupted after processing 9993 documents
2017-05-30 07:34:48,468 : INFO : AccumulatingWorker interrupted after processing 8987 documents
2017-05-30 07:34:48,469 : INFO : serializing accumulator to return to master...
2017-05-30 07:34:48,469 : INFO : serializing accumulator to return to master...
2017-05-30 07:34:48,468 : INFO : AccumulatingWorker interrupted after processing 5062 documents
2017-05-30 07:34:48,469 : INFO : serializing accumulator to return to master...
2017-05-30 07:34:48,469 : INFO : serializing accumulator to return to master...
2017-05-30 07:34:48,469 : INFO : serializing accumulator to return to master...
2017-05-30 07:34:48,469 : INFO : serializing accumulator to return to master...
    task = get()
    task = get()
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/queues.py", line 376, in get
    task = get()
    task = get()
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/queues.py", line 376, in get
    task = get()
    task = get()
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/queues.py", line 376, in get
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/queues.py", line 376, in get
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/queues.py", line 378, in get
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/queues.py", line 376, in get
    task = get()
  File "/Users/user/anaconda2/lib/python2.7/multiprocessing/queues.py", line 376, in get
    racquire()
    racquire()
    racquire()
    racquire()
    racquire()
    racquire()
    return recv()
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
2017-05-30 07:34:48,469 : INFO : serializing accumulator to return to master...
2017-05-30 07:34:48,471 : INFO : accumulator serialized
2017-05-30 07:34:48,473 : INFO : accumulator serialized
2017-05-30 07:34:48,471 : INFO : accumulator serialized
2017-05-30 07:34:48,473 : INFO : accumulator serialized
2017-05-30 07:34:48,471 : INFO : accumulator serialized
2017-05-30 07:34:48,474 : INFO : accumulator serialized
2017-05-30 07:34:48,474 : INFO : accumulator serialized
2017-05-30 07:34:50,727 : INFO : 7 accumulators retrieved from output queue

After:

2017-05-30 07:29:17,246 : WARNING : Slow version of gensim.models.doc2vec is being used
2017-05-30 07:29:17,253 : INFO : 'pattern' package not found; tag filters are not available for English
2017-05-30 07:29:18,188 : INFO : loading LdaModel object from /Users/user/workshop/nlp/models/lda-k100_en-wiki.model
2017-05-30 07:29:18,198 : INFO : loading expElogbeta from /Users/user/workshop/nlp/models/lda-k100_en-wiki.model.expElogbeta.npy with mmap=None
2017-05-30 07:29:18,267 : INFO : setting ignored attribute id2word to None
2017-05-30 07:29:18,267 : INFO : setting ignored attribute state to None
2017-05-30 07:29:18,267 : INFO : setting ignored attribute dispatcher to None
2017-05-30 07:29:18,267 : INFO : loaded /Users/user/workshop/nlp/models/lda-k100_en-wiki.model
2017-05-30 07:29:18,267 : INFO : loading LdaModel object from /Users/user/workshop/nlp/models/lda-k100_en-wiki.model.state
2017-05-30 07:29:18,422 : INFO : loaded /Users/user/workshop/nlp/models/lda-k100_en-wiki.model.state
2017-05-30 07:29:20,667 : INFO : using ParallelWordOccurrenceAccumulator(processes=7, batch_size=16) to estimate probabilities from sliding windows
2017-05-30 07:29:21,117 : INFO : submitted 0 batches to accumulate stats from 0 documents (90752 virtual)
2017-05-30 07:29:21,312 : INFO : submitted 1 batches to accumulate stats from 16 documents (154816 virtual)
2017-05-30 07:29:21,441 : INFO : submitted 2 batches to accumulate stats from 32 documents (223497 virtual)
2017-05-30 07:29:21,715 : INFO : submitted 3 batches to accumulate stats from 48 documents (285340 virtual)
2017-05-30 07:29:21,783 : INFO : submitted 4 batches to accumulate stats from 64 documents (342337 virtual)
2017-05-30 07:29:22,270 : INFO : submitted 5 batches to accumulate stats from 80 documents (415139 virtual)
2017-05-30 07:29:22,634 : INFO : submitted 6 batches to accumulate stats from 96 documents (484709 virtual)
2017-05-30 07:29:22,701 : INFO : submitted 7 batches to accumulate stats from 112 documents (542834 virtual)
2017-05-30 07:29:22,995 : INFO : submitted 8 batches to accumulate stats from 128 documents (628469 virtual)
2017-05-30 07:29:23,118 : INFO : submitted 9 batches to accumulate stats from 144 documents (691420 virtual)
2017-05-30 07:29:23,238 : INFO : submitted 10 batches to accumulate stats from 160 documents (741122 virtual)
2017-05-30 07:29:23,336 : INFO : submitted 11 batches to accumulate stats from 176 documents (774924 virtual)
2017-05-30 07:29:23,471 : INFO : submitted 12 batches to accumulate stats from 192 documents (829056 virtual)
2017-05-30 07:29:23,665 : INFO : submitted 13 batches to accumulate stats from 208 documents (887935 virtual)
^C2017-05-30 07:29:58,875 : WARNING : stats accumulation interrupted; <= 887935 documents processed
2017-05-30 07:29:58,876 : INFO : AccumulatingWorker interrupted after processing 17775 documents
2017-05-30 07:29:58,876 : INFO : AccumulatingWorker interrupted after processing 27228 documents
2017-05-30 07:29:58,876 : INFO : AccumulatingWorker interrupted after processing 20910 documents
2017-05-30 07:29:58,876 : INFO : AccumulatingWorker interrupted after processing 26774 documents
2017-05-30 07:29:58,876 : INFO : AccumulatingWorker interrupted after processing 31392 documents
2017-05-30 07:29:58,876 : INFO : AccumulatingWorker interrupted after processing 32224 documents
2017-05-30 07:29:58,876 : INFO : AccumulatingWorker interrupted after processing 22497 documents
2017-05-30 07:29:58,877 : INFO : serializing accumulator to return to master...
2017-05-30 07:29:58,877 : INFO : serializing accumulator to return to master...
2017-05-30 07:29:58,877 : INFO : serializing accumulator to return to master...
2017-05-30 07:29:58,877 : INFO : serializing accumulator to return to master...
2017-05-30 07:29:58,877 : INFO : serializing accumulator to return to master...
2017-05-30 07:29:58,877 : INFO : serializing accumulator to return to master...
2017-05-30 07:29:58,877 : INFO : serializing accumulator to return to master...
2017-05-30 07:29:58,878 : INFO : accumulator serialized
2017-05-30 07:29:58,878 : INFO : accumulator serialized
2017-05-30 07:29:58,878 : INFO : accumulator serialized
2017-05-30 07:29:58,878 : INFO : accumulator serialized
2017-05-30 07:29:58,878 : INFO : accumulator serialized
2017-05-30 07:29:58,878 : INFO : accumulator serialized
2017-05-30 07:29:58,878 : INFO : accumulator serialized
2017-05-30 07:30:01,106 : INFO : 7 accumulators retrieved from output queue

Owner: So the advantage is a cleaner log? What are the disadvantages? (I am not familiar with this type of functionality, but I assume there must be some disadvantages, otherwise it would be the default behaviour).

pass

pool.terminate()

logger.info(
161 changes: 116 additions & 45 deletions gensim/models/coherencemodel.py
@@ -19,24 +19,25 @@
"""

import logging
from collections import namedtuple
import multiprocessing as mp

import numpy as np

from gensim import interfaces
from gensim.matutils import argsort
from gensim.models.ldamodel import LdaModel
from gensim.models.wrappers import LdaVowpalWabbit, LdaMallet
from gensim.topic_coherence import (segmentation, probability_estimation,
direct_confirmation_measure, indirect_confirmation_measure,
aggregation)
from gensim.matutils import argsort
from gensim.topic_coherence.probability_estimation import unique_ids_from_segments
from gensim.utils import is_corpus, FakeDict
from gensim.models.ldamodel import LdaModel
from gensim.models.wrappers import LdaVowpalWabbit, LdaMallet

import numpy as np

from collections import namedtuple

logger = logging.getLogger(__name__)

boolean_document_based = ['u_mass']
sliding_window_based = ['c_v', 'c_uci', 'c_npmi']
boolean_document_based = {'u_mass'}
sliding_window_based = {'c_v', 'c_uci', 'c_npmi'}
make_pipeline = namedtuple('Coherence_Measure', 'seg, prob, conf, aggr')

coherence_dict = {
@@ -64,10 +65,9 @@
'c_npmi': 10
}


class CoherenceModel(interfaces.TransformationABC):
"""
Objects of this class allow for building and maintaining a model for topic
coherence.
"""Objects of this class allow for building and maintaining a model for topic coherence.

The main methods are:

@@ -89,7 +89,8 @@ class CoherenceModel(interfaces.TransformationABC):

Model persistency is achieved via its load/save methods.
"""
def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None, window_size=None, coherence='c_v', topn=10):
def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None, window_size=None,
coherence='c_v', topn=10, processes=-1):
"""
Args:
----
@@ -128,8 +129,10 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
raise ValueError("One of model or topics has to be provided.")
elif topics is not None and dictionary is None:
raise ValueError("dictionary has to be provided if topics are to be used.")

if texts is None and corpus is None:
raise ValueError("One of texts or corpus has to be provided.")

# Check if associated dictionary is provided.
if dictionary is None:
if isinstance(model.id2word, FakeDict):
@@ -139,7 +142,9 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
self.dictionary = model.id2word
else:
self.dictionary = dictionary

# Check for correct inputs for u_mass coherence measure.
self.coherence = coherence
if coherence in boolean_document_based:
if is_corpus(corpus)[0]:
self.corpus = corpus
@@ -148,30 +153,72 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
self.corpus = [self.dictionary.doc2bow(text) for text in self.texts]
else:
raise ValueError("Either 'corpus' with 'dictionary' or 'texts' should be provided for %s coherence." % coherence)

# Check for correct inputs for c_v coherence measure.
elif coherence in sliding_window_based:
self.window_size = window_size
if self.window_size is None:
self.window_size = sliding_windows_dict[self.coherence]
if texts is None:
raise ValueError("'texts' should be provided for %s coherence." % coherence)
else:
self.texts = texts
else:
raise ValueError("%s coherence is not currently supported." % coherence)

self.topn = topn
self.model = model
if model is not None:
self.topics = self._get_topics()

self._accumulator = None
self._topics = None
self.topics = topics

self.processes = processes if processes > 1 else max(1, mp.cpu_count() - 1)

def __str__(self):
return str(self.measure)

@property
def measure(self):
return coherence_dict[self.coherence]

@property
def topics(self):
return self._topics

@topics.setter
def topics(self, topics):
new_topics = None
if self.model is not None:
new_topics = self._get_topics()
if topics is not None:
logger.warn("Ignoring topics you are attempting to set in favor of model's topics: %s" % self.model)
elif topics is not None:
self.topics = []
new_topics = []
for topic in topics:
t_i = []
for n, _ in enumerate(topic):
t_i.append(dictionary.token2id[topic[n]])
self.topics.append(np.array(t_i))
self.coherence = coherence
t_i = np.array([self.dictionary.token2id[topic[n]] for n, _ in enumerate(topic)])
new_topics.append(np.array(t_i))

def __str__(self):
return coherence_dict[self.coherence].__str__()
if self._relevant_ids_will_differ(new_topics):
logger.debug("Wiping cached accumulator since it does not contain all relevant ids.")
self._accumulator = None

self._topics = new_topics

def _relevant_ids_will_differ(self, new_topics):
if not self._topics_differ(new_topics):
return False

measure = self.measure
current_set = unique_ids_from_segments(measure.seg(self.topics))
new_set = unique_ids_from_segments(measure.seg(new_topics))
return not current_set.issuperset(new_set)

def _topics_differ(self, new_topics):
return (new_topics is not None and
Owner: No vertical indent -- please use hanging indent.

Contributor Author: done

self._topics is not None and
self._accumulator is not None and
not np.equal(new_topics, self._topics).all())
Owner: Not sure what type the arguments are, but is np.allclose applicable?

Contributor Author: They are lists of np.ndarray of ints, so there should be no need for np.allclose.
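A quick illustration of that point (toy id values only): the topics are integer token-id arrays, so exact elementwise comparison is sufficient and no floating-point tolerance is involved.

import numpy as np

old_topics = [np.array([4, 15, 9]), np.array([2, 8, 3])]
new_topics = [np.array([4, 15, 9]), np.array([2, 8, 3])]

# Integer ids compare exactly; np.allclose is only needed for float arrays.
print(np.equal(new_topics, old_topics).all())  # True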


def _get_topics(self):
"""Internal helper function to return topics from a trained topic model."""
@@ -193,27 +240,51 @@ def _get_topics(self):
"LdaModel, LdaVowpalWabbit and LdaMallet.")
return topics

def get_coherence(self):
"""
Return coherence value based on pipeline parameters.
def segment_topics(self):
return self.measure.seg(self.topics)

def estimate_probabilities(self, segmented_topics=None):
"""Accumulate word occurrences and co-occurrences from texts or corpus using
the optimal method for the chosen coherence metric. This operation may take
quite some time for the sliding window based coherence methods.
"""
measure = coherence_dict[self.coherence]
segmented_topics = measure.seg(self.topics)
if segmented_topics is None:
segmented_topics = self.segment_topics()

if self.coherence in boolean_document_based:
per_topic_postings, num_docs = measure.prob(self.corpus, segmented_topics)
confirmed_measures = measure.conf(segmented_topics, per_topic_postings, num_docs)
elif self.coherence in sliding_window_based:
if self.window_size is not None:
self.window_size = sliding_windows_dict[self.coherence]
per_topic_postings, num_windows = measure.prob(texts=self.texts, segmented_topics=segmented_topics,
dictionary=self.dictionary, window_size=self.window_size)
if self.coherence == 'c_v':
confirmed_measures = measure.conf(self.topics, segmented_topics, per_topic_postings, 'nlr', 1, num_windows)
else:
if self.coherence == 'c_npmi':
normalize = True
else:
# For c_uci
normalize = False
confirmed_measures = measure.conf(segmented_topics, per_topic_postings, num_windows, normalize=normalize)
return measure.aggr(confirmed_measures)
self._accumulator = self.measure.prob(self.corpus, segmented_topics)
else:
self._accumulator = self.measure.prob(
texts=self.texts, segmented_topics=segmented_topics,
dictionary=self.dictionary, window_size=self.window_size,
processes=self.processes)

return self._accumulator

def get_coherence_per_topic(self, segmented_topics=None):
"""Return list of coherence values for each topic based on pipeline parameters."""
measure = self.measure
if segmented_topics is None:
segmented_topics = measure.seg(self.topics)
if self._accumulator is None:
self.estimate_probabilities(segmented_topics)

if self.coherence in boolean_document_based:
kwargs = {}
elif self.coherence == 'c_v':
kwargs = dict(topics=self.topics, measure='nlr', gamma=1)
else:
kwargs = dict(normalize=(self.coherence == 'c_npmi'))

return measure.conf(segmented_topics, self._accumulator, **kwargs)

def aggregate_measures(self, confirmed_measures):
"""Aggregate the individual topic coherence measures using
the pipeline's aggregation function.
"""
return self.measure.aggr(confirmed_measures)

def get_coherence(self):
"""Return coherence value based on pipeline parameters."""
confirmed_measures = self.get_coherence_per_topic()
return self.aggregate_measures(confirmed_measures)
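A rough usage sketch of the refactored pipeline above (the toy texts, topics, and variable names are illustrative only; the method names are the ones introduced in this diff). The expensive probability-estimation step is cached on the model, so the per-topic measures, the aggregate score, and any new topic set drawn from the same word ids can all reuse it.

from gensim.corpora import Dictionary
from gensim.models.coherencemodel import CoherenceModel

# Toy data purely for illustration.
texts = [['human', 'interface', 'computer'],
         ['graph', 'trees', 'minors'],
         ['graph', 'minors', 'survey', 'computer']]
dictionary = Dictionary(texts)
topics = [['human', 'computer', 'interface'],
          ['graph', 'trees', 'minors']]

cm = CoherenceModel(topics=topics, texts=texts, dictionary=dictionary, coherence='c_v')

# Drive the pipeline step by step...
segmented = cm.segment_topics()
cm.estimate_probabilities(segmented)               # expensive; accumulator is cached
per_topic = cm.get_coherence_per_topic(segmented)  # list of per-topic coherence values
print(per_topic)
print(cm.aggregate_measures(per_topic))            # aggregated score

# ...or in one call; the cached accumulator is reused rather than recomputed.
print(cm.get_coherence())

# Assigning new topics made of the same words keeps the cached accumulator too.
cm.topics = [['human', 'interface', 'trees'], ['graph', 'minors', 'computer']]
print(cm.get_coherence_per_topic())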
34 changes: 18 additions & 16 deletions gensim/test/test_coherencemodel.py
@@ -14,7 +14,7 @@
import os.path
import tempfile

from gensim.models.coherencemodel import CoherenceModel
from gensim.models.coherencemodel import CoherenceModel, boolean_document_based
from gensim.models.ldamodel import LdaModel
from gensim.models.wrappers import LdaMallet
from gensim.models.wrappers import LdaVowpalWabbit
@@ -35,23 +35,12 @@
['graph', 'minors', 'survey']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]
boolean_document_based = ['u_mass']
sliding_window_based = ['c_v', 'c_uci', 'c_npmi']


def testfile():
# temporary data will be stored to this file
return os.path.join(tempfile.gettempdir(), 'gensim_models.tst')

def checkCoherenceMeasure(topics1, topics2, coherence):
"""Check provided topic coherence algorithm on given topics"""
if coherence in boolean_document_based:
cm1 = CoherenceModel(topics=topics1, corpus=corpus, dictionary=dictionary, coherence=coherence)
cm2 = CoherenceModel(topics=topics2, corpus=corpus, dictionary=dictionary, coherence=coherence)
else:
cm1 = CoherenceModel(topics=topics1, texts=texts, dictionary=dictionary, coherence=coherence)
cm2 = CoherenceModel(topics=topics2, texts=texts, dictionary=dictionary, coherence=coherence)
return cm1.get_coherence() > cm2.get_coherence()

class TestCoherenceModel(unittest.TestCase):
def setUp(self):
@@ -77,21 +66,33 @@ def setUp(self):
self.vw_path = vw_path
self.vwmodel = LdaVowpalWabbit(self.vw_path, corpus=corpus, id2word=dictionary, num_topics=2, passes=0)

def check_coherence_measure(self, coherence):
"""Check provided topic coherence algorithm on given topics"""
if coherence in boolean_document_based:
kwargs = dict(corpus=corpus, dictionary=dictionary, coherence=coherence)
cm1 = CoherenceModel(topics=self.topics1, **kwargs)
cm2 = CoherenceModel(topics=self.topics2, **kwargs)
else:
kwargs = dict(texts=texts, dictionary=dictionary, coherence=coherence)
cm1 = CoherenceModel(topics=self.topics1, **kwargs)
cm2 = CoherenceModel(topics=self.topics2, **kwargs)
self.assertGreater(cm1.get_coherence(), cm2.get_coherence())

def testUMass(self):
"""Test U_Mass topic coherence algorithm on given topics"""
self.assertTrue(checkCoherenceMeasure(self.topics1, self.topics2, 'u_mass'))
self.check_coherence_measure('u_mass')

def testCv(self):
"""Test C_v topic coherence algorithm on given topics"""
self.assertTrue(checkCoherenceMeasure(self.topics1, self.topics2, 'c_v'))
self.check_coherence_measure('c_v')

def testCuci(self):
"""Test C_uci topic coherence algorithm on given topics"""
self.assertTrue(checkCoherenceMeasure(self.topics1, self.topics2, 'c_uci'))
self.check_coherence_measure('c_uci')

def testCnpmi(self):
"""Test C_npmi topic coherence algorithm on given topics"""
self.assertTrue(checkCoherenceMeasure(self.topics1, self.topics2, 'c_npmi'))
self.check_coherence_measure('c_npmi')

def testUMassLdaModel(self):
"""Perform sanity check to see if u_mass coherence works with LDA Model"""
@@ -219,6 +220,7 @@ def testPersistenceCompressed(self):
model2 = CoherenceModel.load(fname)
self.assertTrue(model.get_coherence() == model2.get_coherence())


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()