Allow use of truncated Dictionary for coherence measures #1349
Changes from 13 commits
File: gensim/corpora/wikicorpus.py
@@ -23,6 +23,7 @@
import re
from xml.etree.cElementTree import iterparse  # LXML isn't faster, so let's go with the built-in solution
import multiprocessing
import signal

from gensim import utils
@@ -249,6 +250,10 @@ def process_article(args):
    return result, title, pageid


def init_worker():
    signal.signal(signal.SIGINT, signal.SIG_IGN)


class WikiCorpus(TextCorpus):
    """
    Treat a wikipedia articles dump (\*articles.xml.bz2) as a (read-only) corpus.
@@ -300,22 +305,26 @@ def get_texts(self):
        articles, articles_all = 0, 0
        positions, positions_all = 0, 0
        texts = ((text, self.lemmatize, title, pageid) for title, text, pageid in extract_pages(bz2.BZ2File(self.fname), self.filter_namespaces))
        pool = multiprocessing.Pool(self.processes)
        pool = multiprocessing.Pool(self.processes, init_worker)
        # process the corpus in smaller chunks of docs, because multiprocessing.Pool
        # is dumb and would load the entire input into RAM at once...
        for group in utils.chunkize(texts, chunksize=10 * self.processes, maxsize=1):
            for tokens, title, pageid in pool.imap(process_article, group):  # chunksize=10):
                articles_all += 1
                positions_all += len(tokens)
                # article redirects and short stubs are pruned here
                if len(tokens) < ARTICLE_MIN_WORDS or any(title.startswith(ignore + ':') for ignore in IGNORED_NAMESPACES):
                    continue
                articles += 1
                positions += len(tokens)
                if self.metadata:
                    yield (tokens, (pageid, title))
                else:
                    yield tokens
        try:
            for group in utils.chunkize(texts, chunksize=10 * self.processes, maxsize=1):
                for tokens, title, pageid in pool.imap(process_article, group):  # chunksize=10):
                    articles_all += 1
                    positions_all += len(tokens)
                    # article redirects and short stubs are pruned here
                    if len(tokens) < ARTICLE_MIN_WORDS or any(title.startswith(ignore + ':') for ignore in IGNORED_NAMESPACES):
                        continue
                    articles += 1
                    positions += len(tokens)
                    if self.metadata:
                        yield (tokens, (pageid, title))
                    else:
                        yield tokens
        except KeyboardInterrupt:
            pass

        pool.terminate()

        logger.info(

Review discussion on the new KeyboardInterrupt handling:

Reviewer: What is this for? Masking user interrupts is an anti-pattern; deserves a detailed comment, at the very least.

Author: I've updated the code to handle the interrupt more appropriately. As for why I chose to handle the interrupt at all (at the risk of repeating the above comment): this pool may be active during many other phases of gensim execution if the underlying corpus/texts being iterated come from the `wikicorpus`.

Reviewer: Sorry, not sure I understand. When iterating over a […] Now when you do CTRL+C, what happened (old behaviour)? And what happens now (after your changes here)? Why is that preferable?

Author: Here is the code I am running: […]

Before: […]

After: […]

Reviewer: So the advantage is a cleaner log?
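To make the intent of the new `init_worker` and try/except/terminate lines concrete, here is a minimal standalone sketch of the same pattern (illustrative function names, not code from this PR): worker processes ignore SIGINT via the pool initializer, so Ctrl+C raises KeyboardInterrupt only in the parent, which then shuts the pool down.

```python
import multiprocessing
import signal


def init_worker():
    # Workers ignore SIGINT so an interrupt is delivered only to the parent process.
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def square(x):  # stand-in for a real worker function such as process_article
    return x * x


if __name__ == '__main__':
    pool = multiprocessing.Pool(4, init_worker)
    try:
        for result in pool.imap(square, range(1000)):
            pass  # consume results; Ctrl+C here interrupts only this loop
    except KeyboardInterrupt:
        pass  # the parent decides how to react; the workers stay quiet
    pool.terminate()
    pool.join()
```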
File: gensim/models/coherencemodel.py
@@ -19,24 +19,25 @@
"""

import logging
from collections import namedtuple
import multiprocessing as mp

import numpy as np

from gensim import interfaces
from gensim.matutils import argsort
from gensim.models.ldamodel import LdaModel
from gensim.models.wrappers import LdaVowpalWabbit, LdaMallet
from gensim.topic_coherence import (segmentation, probability_estimation,
                                    direct_confirmation_measure, indirect_confirmation_measure,
                                    aggregation)
from gensim.matutils import argsort
from gensim.topic_coherence.probability_estimation import unique_ids_from_segments
from gensim.utils import is_corpus, FakeDict
from gensim.models.ldamodel import LdaModel
from gensim.models.wrappers import LdaVowpalWabbit, LdaMallet

import numpy as np

from collections import namedtuple

logger = logging.getLogger(__name__)

boolean_document_based = ['u_mass']
sliding_window_based = ['c_v', 'c_uci', 'c_npmi']
boolean_document_based = {'u_mass'}
sliding_window_based = {'c_v', 'c_uci', 'c_npmi'}
make_pipeline = namedtuple('Coherence_Measure', 'seg, prob, conf, aggr')

coherence_dict = {
@@ -64,10 +65,9 @@
    'c_npmi': 10
}


class CoherenceModel(interfaces.TransformationABC):
    """
    Objects of this class allow for building and maintaining a model for topic
    coherence.
    """Objects of this class allow for building and maintaining a model for topic coherence.

    The main methods are:
@@ -89,7 +89,8 @@ class CoherenceModel(interfaces.TransformationABC):

    Model persistency is achieved via its load/save methods.
    """
    def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None, window_size=None, coherence='c_v', topn=10):
    def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None, window_size=None,
                 coherence='c_v', topn=10, processes=-1):
        """
        Args:
        ----
@@ -128,8 +129,10 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
            raise ValueError("One of model or topics has to be provided.")
        elif topics is not None and dictionary is None:
            raise ValueError("dictionary has to be provided if topics are to be used.")

        if texts is None and corpus is None:
            raise ValueError("One of texts or corpus has to be provided.")

        # Check if associated dictionary is provided.
        if dictionary is None:
            if isinstance(model.id2word, FakeDict):
@@ -139,7 +142,9 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
                self.dictionary = model.id2word
        else:
            self.dictionary = dictionary

        # Check for correct inputs for u_mass coherence measure.
        self.coherence = coherence
        if coherence in boolean_document_based:
            if is_corpus(corpus)[0]:
                self.corpus = corpus
@@ -148,30 +153,72 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
                self.corpus = [self.dictionary.doc2bow(text) for text in self.texts]
            else:
                raise ValueError("Either 'corpus' with 'dictionary' or 'texts' should be provided for %s coherence." % coherence)

        # Check for correct inputs for c_v coherence measure.
        elif coherence in sliding_window_based:
            self.window_size = window_size
            if self.window_size is None:
                self.window_size = sliding_windows_dict[self.coherence]
            if texts is None:
                raise ValueError("'texts' should be provided for %s coherence." % coherence)
            else:
                self.texts = texts
        else:
            raise ValueError("%s coherence is not currently supported." % coherence)

        self.topn = topn
        self.model = model
        if model is not None:
            self.topics = self._get_topics()

        self._accumulator = None
        self._topics = None
        self.topics = topics

        self.processes = processes if processes > 1 else max(1, mp.cpu_count() - 1)

    def __str__(self):
        return str(self.measure)

    @property
    def measure(self):
        return coherence_dict[self.coherence]

    @property
    def topics(self):
        return self._topics

    @topics.setter
    def topics(self, topics):
        new_topics = None
        if self.model is not None:
            new_topics = self._get_topics()
            if topics is not None:
                logger.warn("Ignoring topics you are attempting to set in favor of model's topics: %s" % self.model)
        elif topics is not None:
            self.topics = []
            new_topics = []
            for topic in topics:
                t_i = []
                for n, _ in enumerate(topic):
                    t_i.append(dictionary.token2id[topic[n]])
                self.topics.append(np.array(t_i))
        self.coherence = coherence
                t_i = np.array([self.dictionary.token2id[topic[n]] for n, _ in enumerate(topic)])
                new_topics.append(np.array(t_i))

    def __str__(self):
        return coherence_dict[self.coherence].__str__()
        if self._relevant_ids_will_differ(new_topics):
            logger.debug("Wiping cached accumulator since it does not contain all relevant ids.")
            self._accumulator = None

        self._topics = new_topics

    def _relevant_ids_will_differ(self, new_topics):
        if not self._topics_differ(new_topics):
            return False

        measure = self.measure
        current_set = unique_ids_from_segments(measure.seg(self.topics))
        new_set = unique_ids_from_segments(measure.seg(new_topics))
        return not current_set.issuperset(new_set)

    def _topics_differ(self, new_topics):
        return (new_topics is not None and
                self._topics is not None and
                self._accumulator is not None and
                not np.equal(new_topics, self._topics).all())

Reviewer (on the indentation of this return statement): No vertical indent -- please use hanging indent.

Author: done

Reviewer (on the np.equal comparison): Not sure what type the arguments are, but […]
    def _get_topics(self):
        """Internal helper function to return topics from a trained topic model."""
@@ -193,27 +240,51 @@ def _get_topics(self):
                             "LdaModel, LdaVowpalWabbit and LdaMallet.")
        return topics

    def get_coherence(self):
        """
        Return coherence value based on pipeline parameters.
    def segment_topics(self):
        return self.measure.seg(self.topics)

    def estimate_probabilities(self, segmented_topics=None):
        """Accumulate word occurrences and co-occurrences from texts or corpus using
        the optimal method for the chosen coherence metric. This operation may take
        quite some time for the sliding window based coherence methods.
        """
        measure = coherence_dict[self.coherence]
        segmented_topics = measure.seg(self.topics)
        if segmented_topics is None:
            segmented_topics = self.segment_topics()

        if self.coherence in boolean_document_based:
            per_topic_postings, num_docs = measure.prob(self.corpus, segmented_topics)
            confirmed_measures = measure.conf(segmented_topics, per_topic_postings, num_docs)
        elif self.coherence in sliding_window_based:
            if self.window_size is not None:
                self.window_size = sliding_windows_dict[self.coherence]
            per_topic_postings, num_windows = measure.prob(texts=self.texts, segmented_topics=segmented_topics,
                                                           dictionary=self.dictionary, window_size=self.window_size)
            if self.coherence == 'c_v':
                confirmed_measures = measure.conf(self.topics, segmented_topics, per_topic_postings, 'nlr', 1, num_windows)
            else:
                if self.coherence == 'c_npmi':
                    normalize = True
                else:
                    # For c_uci
                    normalize = False
                confirmed_measures = measure.conf(segmented_topics, per_topic_postings, num_windows, normalize=normalize)
        return measure.aggr(confirmed_measures)
            self._accumulator = self.measure.prob(self.corpus, segmented_topics)
        else:
            self._accumulator = self.measure.prob(
                texts=self.texts, segmented_topics=segmented_topics,
                dictionary=self.dictionary, window_size=self.window_size,
                processes=self.processes)

        return self._accumulator

    def get_coherence_per_topic(self, segmented_topics=None):
        """Return list of coherence values for each topic based on pipeline parameters."""
        measure = self.measure
        if segmented_topics is None:
            segmented_topics = measure.seg(self.topics)
        if self._accumulator is None:
            self.estimate_probabilities(segmented_topics)

        if self.coherence in boolean_document_based:
            kwargs = {}
        elif self.coherence == 'c_v':
            kwargs = dict(topics=self.topics, measure='nlr', gamma=1)
        else:
            kwargs = dict(normalize=(self.coherence == 'c_npmi'))

        return measure.conf(segmented_topics, self._accumulator, **kwargs)

    def aggregate_measures(self, confirmed_measures):
        """Aggregate the individual topic coherence measures using
        the pipeline's aggregation function.
        """
        return self.measure.aggr(confirmed_measures)

    def get_coherence(self):
        """Return coherence value based on pipeline parameters."""
        confirmed_measures = self.get_coherence_per_topic()
        return self.aggregate_measures(confirmed_measures)
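For orientation, here is a hedged usage sketch of the refactored API above (the toy texts, LDA setup, and variable names are illustrative, not from this PR): the pipeline stages are exposed as separate methods that can be driven individually, or all at once through `get_coherence`, and the `processes` argument added in this PR controls how many workers the sliding-window probability estimation uses.

```python
from gensim.corpora import Dictionary
from gensim.models import LdaModel
from gensim.models.coherencemodel import CoherenceModel

# Toy data purely for illustration; real corpora would be much larger.
texts = [['human', 'computer', 'interface', 'system'],
         ['graph', 'minors', 'trees', 'paths'],
         ['human', 'system', 'computer', 'interface'],
         ['graph', 'trees', 'minors', 'survey']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]
lda = LdaModel(corpus=corpus, id2word=dictionary, num_topics=2)

# topn=5 keeps the requested top words within the tiny toy vocabulary
cm = CoherenceModel(model=lda, texts=texts, dictionary=dictionary,
                    coherence='c_v', topn=5, processes=2)

# Drive the pipeline stage by stage ...
segmented = cm.segment_topics()
cm.estimate_probabilities(segmented)             # cached internally for reuse
per_topic = cm.get_coherence_per_topic(segmented)
overall = cm.aggregate_measures(per_topic)

# ... or run everything end to end.
print(overall, cm.get_coherence())
```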
Reviewer: What is this change about? Needs a comment.

Author: @piskvorky I was running `get_coherence` using `coherence="c_v"` and trying to program it to gracefully handle `KeyboardInterrupt`. You can see the mechanisms I put in place for this in the `text_analysis` module in the PR. While doing this, I faced some confusion because some background process was still raising `KeyboardInterrupt`. After some digging, I noticed that the `wikicorpus` pool workers were the culprit. This pool may be active during many other phases of gensim execution if the underlying corpus/texts being iterated come from the `wikicorpus`. I think it makes things slightly cleaner in this case to handle the `KeyboardInterrupt` in some manner that does not propagate. Perhaps some sort of logging would improve this?
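Sketching that closing suggestion (this is not code from the PR; the function names and log message are made up): instead of a bare pass, the parent could log how far iteration got before the interrupt and then shut the pool down.

```python
import logging
import multiprocessing
import signal

logger = logging.getLogger(__name__)


def init_worker():
    # Same trick as in the diff: workers ignore SIGINT so only the parent reacts to Ctrl+C.
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def tokenize(doc):  # stand-in for process_article
    return doc.split()


def iter_tokenized(docs, processes=2):
    """Yield tokenized documents, logging an interrupt instead of swallowing it silently."""
    pool = multiprocessing.Pool(processes, init_worker)
    processed = 0
    try:
        for tokens in pool.imap(tokenize, docs):
            processed += 1
            yield tokens
    except KeyboardInterrupt:
        logger.warning("user interrupted iteration after %i documents; stopping early", processed)
    finally:
        pool.terminate()
        pool.join()
```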