piskvorky#1380: Add the keyed_vectors kwarg to the CoherenceModel
Add the `keyed_vectors` kwarg to the CoherenceModel to allow passing in pre-trained, pre-loaded word embeddings, and adjust the similarity measure to handle terms missing from the vocabulary. Also add a `with_std` option to all confirmation measures, which lets the caller get the standard deviation across the topic segment sets as well as the means.
Sweeney, Mack committed Aug 13, 2017 · 1 parent a1f9127 · commit 345a644
Showing 4 changed files with 97 additions and 48 deletions.
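
For orientation, here is a minimal sketch of the call pattern this commit enables. The toy corpus and the embeddings path are hypothetical; only the new `keyed_vectors` kwarg and the existing `'c_w2v'` measure are taken from the diff below.

```python
# Sketch only: toy data, hypothetical embeddings file. With pre-loaded
# KeyedVectors, 'c_w2v' no longer has to train word2vec on `texts` itself.
from gensim.corpora import Dictionary
from gensim.models import CoherenceModel, KeyedVectors

texts = [['human', 'computer', 'interface'],
         ['graph', 'trees', 'minors']]
dictionary = Dictionary(texts)
topics = [['human', 'computer'], ['graph', 'trees']]  # topics as word lists

# Pre-trained, pre-loaded word embeddings (path is hypothetical).
word_vectors = KeyedVectors.load_word2vec_format('embeddings.bin', binary=True)

# Per the new input validation, neither `texts` nor `corpus` is required
# when embeddings are supplied for 'c_w2v'.
cm = CoherenceModel(topics=topics, dictionary=dictionary,
                    keyed_vectors=word_vectors, coherence='c_w2v')
print(cm.get_coherence())  # mean of the per-topic word2vec similarities
```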
gensim/models/coherencemodel.py (38 changes: 23 additions & 15 deletions)

```diff
@@ -77,7 +77,8 @@
     'c_v': 110,
     'c_w2v': 5,
     'c_uci': 10,
-    'c_npmi': 10
+    'c_npmi': 10,
+    'u_mass': None
 }


@@ -118,7 +119,7 @@ class CoherenceModel(interfaces.TransformationABC):
     Model persistency is achieved via its load/save methods.
     """
     def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=None,
-                 window_size=None, coherence='c_v', topn=10, processes=-1):
+                 window_size=None, keyed_vectors=None, coherence='c_v', topn=10, processes=-1):
         """
         Args:
             model : Pre-trained topic model. Should be provided if topics is not provided.
@@ -168,7 +169,8 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
         elif topics is not None and dictionary is None:
             raise ValueError("dictionary has to be provided if topics are to be used.")

-        if texts is None and corpus is None:
+        self.keyed_vectors = keyed_vectors
+        if keyed_vectors is None and texts is None and corpus is None:
             raise ValueError("One of texts or corpus has to be provided.")

         # Check if associated dictionary is provided.
@@ -184,26 +186,28 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=

         # Check for correct inputs for u_mass coherence measure.
         self.coherence = coherence
+        self.window_size = window_size
+        if self.window_size is None:
+            self.window_size = SLIDING_WINDOW_SIZES[self.coherence]
+        self.texts = texts
+        self.corpus = corpus
+
         if coherence in BOOLEAN_DOCUMENT_BASED:
             if is_corpus(corpus)[0]:
                 self.corpus = corpus
-            elif texts is not None:
-                self.texts = texts
+            elif self.texts is not None:
                 self.corpus = [self.dictionary.doc2bow(text) for text in self.texts]
             else:
                 raise ValueError(
                     "Either 'corpus' with 'dictionary' or 'texts' should "
                     "be provided for %s coherence.", coherence)

-        # Check for correct inputs for c_v coherence measure.
+        # Check for correct inputs for sliding window coherence measure.
+        elif coherence == 'c_w2v' and keyed_vectors is not None:
+            pass
         elif coherence in SLIDING_WINDOW_BASED:
-            self.window_size = window_size
-            if self.window_size is None:
-                self.window_size = SLIDING_WINDOW_SIZES[self.coherence]
-            if texts is None:
+            if self.texts is None:
                 raise ValueError("'texts' should be provided for %s coherence.", coherence)
-            else:
-                self.texts = texts
         else:
             raise ValueError("%s coherence is not currently supported.", coherence)
```

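A small illustration of the reorganized defaults above, assuming only the new `SLIDING_WINDOW_SIZES` entry: `window_size` is now resolved up front for every measure, and `'u_mass'` maps to `None` because it is document-based rather than window-based.

```python
from gensim.corpora import Dictionary
from gensim.models import CoherenceModel

texts = [['human', 'computer', 'interface'],
         ['graph', 'trees', 'minors']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

cm = CoherenceModel(topics=[['human', 'computer']], corpus=corpus,
                    dictionary=dictionary, coherence='u_mass')
print(cm.window_size)  # None: u_mass uses whole documents, not sliding windows
```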
```diff
@@ -307,14 +311,18 @@ def estimate_probabilities(self, segmented_topics=None):
         if self.coherence in BOOLEAN_DOCUMENT_BASED:
             self._accumulator = self.measure.prob(self.corpus, segmented_topics)
         else:
-            self._accumulator = self.measure.prob(
+            kwargs = dict(
                 texts=self.texts, segmented_topics=segmented_topics,
                 dictionary=self.dictionary, window_size=self.window_size,
                 processes=self.processes)
+            if self.coherence == 'c_w2v':
+                kwargs['model'] = self.keyed_vectors
+
+            self._accumulator = self.measure.prob(**kwargs)

         return self._accumulator

-    def get_coherence_per_topic(self, segmented_topics=None):
+    def get_coherence_per_topic(self, segmented_topics=None, with_std=False):
         """Return list of coherence values for each topic based on pipeline parameters."""
         measure = self.measure
         if segmented_topics is None:
@@ -323,7 +331,7 @@ def get_coherence_per_topic(self, segmented_topics=None):
             self.estimate_probabilities(segmented_topics)

         if self.coherence in BOOLEAN_DOCUMENT_BASED or self.coherence == 'c_w2v':
-            kwargs = {}
+            kwargs = dict(with_std=with_std)
         elif self.coherence == 'c_v':
             kwargs = dict(topics=self.topics, measure='nlr', gamma=1)
         else:
```
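
With this plumbing, `with_std=True` reaches the confirmation measure on the document-based path, and each per-topic value comes back as a `(mean, std)` pair. A sketch with toy data (exact numbers will vary):

```python
from gensim.corpora import Dictionary
from gensim.models import CoherenceModel

texts = [['human', 'computer', 'interface'],
         ['computer', 'system', 'interface'],
         ['graph', 'trees', 'minors']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]
topics = [['human', 'computer', 'interface'], ['graph', 'trees', 'minors']]

cm = CoherenceModel(topics=topics, corpus=corpus, dictionary=dictionary,
                    coherence='u_mass')
# Each entry is now (mean, std) over the topic's segment pairs
# instead of a bare mean.
print(cm.get_coherence_per_topic(with_std=True))
```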
gensim/topic_coherence/direct_confirmation_measure.py (34 changes: 23 additions & 11 deletions)

```diff
@@ -17,16 +17,18 @@
 EPSILON = 1e-12  # Should be small. Value as suggested in paper.


-def log_conditional_probability(segmented_topics, accumulator):
+def log_conditional_probability(segmented_topics, accumulator, with_std=False):
     """
     This function calculates the log-conditional-probability measure
     which is used by coherence measures such as U_mass.
     This is defined as: m_lc(S_i) = log[(P(W', W*) + e) / P(W*)]

     Args:
-        segmented_topics : Output from the segmentation module of the segmented topics.
+        segmented_topics: Output from the segmentation module of the segmented topics.
             Is a list of list of tuples.
         accumulator: word occurrence accumulator from probability_estimation.
+        with_std (bool): True to also include standard deviation across topic segment
+            sets in addition to the mean coherence for each topic; default is False.

     Returns:
         m_lc : List of log conditional probability measure for each topic.
@@ -36,20 +38,24 @@ def log_conditional_probability(segmented_topics, accumulator):
     for s_i in segmented_topics:
         segment_sims = []
         for w_prime, w_star in s_i:
-            w_star_count = accumulator[w_star]
-            if w_star_count == 0:
-                raise ValueError("Topic with id %d not found in corpus used to compute coherence. "
-                                 "Try using a larger corpus with a smaller vocobulary and/or setting a smaller value of `topn` for `CoherenceModel`." % (w_star))
-            co_occur_count = accumulator[w_prime, w_star]
-            m_lc_i = np.log(((co_occur_count / num_docs) + EPSILON) / (w_star_count / num_docs))
+            try:
+                w_star_count = accumulator[w_star]
+                co_occur_count = accumulator[w_prime, w_star]
+                m_lc_i = np.log(((co_occur_count / num_docs) + EPSILON) / (w_star_count / num_docs))
+            except KeyError:
+                m_lc_i = 0.0

             segment_sims.append(m_lc_i)
-        m_lc.append(np.mean(segment_sims))
+
+        if with_std:
+            m_lc.append((np.mean(segment_sims), np.std(segment_sims)))
+        else:
+            m_lc.append(np.mean(segment_sims))

     return m_lc
```
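
To exercise the measure in isolation, here is a sketch with a minimal stand-in for the accumulator. The stub only assumes the accesses visible above: `accumulator[id]`, `accumulator[id1, id2]`, and an `accumulator.num_docs` attribute (the real accumulator comes from `probability_estimation`).

```python
from gensim.topic_coherence import direct_confirmation_measure

class StubAccumulator(object):
    """Minimal stand-in mimicking the accumulator lookups used above."""
    def __init__(self, counts, num_docs):
        self._counts = counts
        self.num_docs = num_docs

    def __getitem__(self, key):  # key is a word id or an (id, id) pair
        return self._counts[key]

# W* (id 2) occurs in 20 of 100 docs; (W', W*) co-occur in 10 of 100.
acc = StubAccumulator({1: 20, 2: 20, (1, 2): 10}, num_docs=100.0)
segmented_topics = [[(1, 2)]]  # one topic with a single (W', W*) pair

result = direct_confirmation_measure.log_conditional_probability(
    segmented_topics, acc, with_std=True)
print(result)  # [(-0.693..., 0.0)]: log(0.1 / 0.2), zero std for one segment
```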


```diff
-def log_ratio_measure(segmented_topics, accumulator, normalize=False):
+def log_ratio_measure(segmented_topics, accumulator, normalize=False, with_std=False):
     """
     If normalize=False:
         Popularly known as PMI.
@@ -66,6 +72,8 @@ def log_ratio_measure(segmented_topics, accumulator, normalize=False):
         segmented topics : Output from the segmentation module of the segmented topics.
             Is a list of list of tuples.
         accumulator: word occurrence accumulator from probability_estimation.
+        with_std (bool): True to also include standard deviation across topic segment
+            sets in addition to the mean coherence for each topic; default is False.

     Returns:
         m_lr : List of log ratio measures for each topic.
@@ -91,6 +99,10 @@ def log_ratio_measure(segmented_topics, accumulator, normalize=False):
             m_lr_i = np.log(numerator / denominator)

             segment_sims.append(m_lr_i)
-        m_lr.append(np.mean(segment_sims))
+
+        if with_std:
+            m_lr.append((np.mean(segment_sims), np.std(segment_sims)))
+        else:
+            m_lr.append(np.mean(segment_sims))

     return m_lr
```
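
For the un-normalized case the segment score is plain PMI, m_lr = log[(P(W', W*) + e) / (P(W') * P(W*))], with probabilities estimated as document counts over num_docs. A toy computation with invented counts:

```python
import numpy as np

EPSILON = 1e-12
num_docs = 100.0
w_prime_count, w_star_count, co_occur_count = 20.0, 25.0, 10.0

# PMI: log[(P(W', W*) + e) / (P(W') * P(W*))]
numerator = (co_occur_count / num_docs) + EPSILON
denominator = (w_prime_count / num_docs) * (w_star_count / num_docs)
print(np.log(numerator / denominator))  # log(0.1 / 0.05) ~= 0.693
```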
gensim/topic_coherence/indirect_confirmation_measure.py (55 changes: 41 additions & 14 deletions)

```diff
@@ -35,27 +35,47 @@
 logger = logging.getLogger(__name__)


-def word2vec_similarity(segmented_topics, accumulator):
+def word2vec_similarity(segmented_topics, accumulator, with_std=False):
     """For each topic segmentation, compute average cosine similarity using a
     WordVectorsAccumulator.

     Args:
     ----
     segmented_topics : Output from the segmentation module of the segmented topics.
         Is a list of list of tuples.
     accumulator : Output from the probability_estimation module.
         Is an accumulator of word occurrences (see text_analysis module).
+    with_std : True to also include standard deviation across topic segment sets in addition
+        to the mean coherence for each topic; default is False.

     Returns:
     -------
     topic_coherences : list of word2vec cosine similarities per topic.
     """
-    topic_similarities = np.zeros(len(segmented_topics))
+    topic_coherences = np.zeros(len(segmented_topics))
     for i, topic_segments in enumerate(segmented_topics):
-        segment_similarities = np.zeros(len(topic_segments))
-        for j, (w_prime, w_star) in enumerate(topic_segments):
+        segment_similarities = []
+        for w_prime, w_star in topic_segments:
             if not hasattr(w_prime, '__iter__'):
                 w_prime = [w_prime]
             if not hasattr(w_star, '__iter__'):
                 w_star = [w_star]

-            segment_similarities[j] = accumulator.ids_similarity(w_prime, w_star)
+            try:
+                segment_similarities.append(accumulator.ids_similarity(w_prime, w_star))
+            except ZeroDivisionError:
+                logger.warn("at least one topic word not in word2vec model")

-        topic_similarities[i] = segment_similarities.mean()
+        if with_std:
+            topic_coherences[i] = (np.mean(segment_similarities), np.std(segment_similarities))
+        else:
+            topic_coherences[i] = np.mean(segment_similarities)

-    return topic_similarities
+    return topic_coherences
```
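
Seen from the pipeline, this is the function behind `'c_w2v'` when no `keyed_vectors` are passed: the accumulator trains word2vec on `texts`, and `word2vec_similarity` averages the pairwise segment similarities. A sketch with a toy corpus, repeated so word2vec's default `min_count` is satisfied; scores on data this small are not meaningful:

```python
from gensim.corpora import Dictionary
from gensim.models import CoherenceModel

texts = [['human', 'computer', 'interface', 'system'],
         ['graph', 'trees', 'minors', 'survey']] * 10  # repeat for min_count
dictionary = Dictionary(texts)
topics = [['human', 'computer'], ['graph', 'trees']]

cm = CoherenceModel(topics=topics, texts=texts, dictionary=dictionary,
                    coherence='c_w2v')  # no keyed_vectors: trains on `texts`
print(cm.get_coherence())
```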


```diff
-def cosine_similarity(segmented_topics, accumulator, topics, measure='nlr', gamma=1):
+def cosine_similarity(segmented_topics, accumulator, topics, measure='nlr', gamma=1,
+                      with_std=False):
     """
     This function calculates the indirect cosine measure. Given context vectors
     u = V(W') and w = V(W*) for the word sets of a pair S_i = (W', W*) indirect
@@ -70,15 +90,18 @@ def cosine_similarity(segmented_topics, accumulator, topics, measure='nlr', gamm

     \vec{V}^{\,}_{m,\gamma}(W') = \Bigg \{{\sum_{w_{i} \in W'}^{ } m(w_{i}, w_{j})^{\gamma}}\Bigg \}_{j = 1,...,|W|}

     Args:
-        segmented_topics : Output from the segmentation module of the segmented topics. Is a list of list of tuples.
-        accumulator : Output from the probability_estimation module. Is an accumulator of word occurrences (see text_analysis module).
+        segmented_topics : Output from the segmentation module of the segmented topics.
+            Is a list of list of tuples.
+        accumulator : Output from the probability_estimation module. Is an accumulator
+            of word occurrences (see text_analysis module).
         topics : Topics obtained from the trained topic model.
-        measure : String. Direct confirmation measure to be used. Supported values are "nlr" (normalized log ratio).
+        measure (str): Direct confirmation measure to be used. Supported values are
+            "nlr" (normalized log ratio).
         gamma : Gamma value for computing W', W* vectors; default is 1.
+        with_std (bool): True to also include standard deviation across topic segment
+            sets in addition to the mean coherence for each topic; default is False.

     Returns:
         s_cos_sim : list of indirect cosine similarity measure for each topic.
     """
@@ -92,7 +115,11 @@ def cosine_similarity(segmented_topics, accumulator, topics, measure='nlr', gamm
             w_prime_cv = context_vectors[w_prime, topic_words]
             w_star_cv = context_vectors[w_star, topic_words]
             segment_sims[i] = _cossim(w_prime_cv, w_star_cv)
-        s_cos_sim.append(np.mean(segment_sims))
+
+        if with_std:
+            s_cos_sim.append((np.mean(segment_sims), np.std(segment_sims)))
+        else:
+            s_cos_sim.append(np.mean(segment_sims))

     return s_cos_sim
```
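
`cosine_similarity` is the last step of the `'c_v'` pipeline. Note that at this commit `get_coherence_per_topic` forwards `with_std` only on the document-based and `'c_w2v'` paths, so a `'c_v'` caller who wants `(mean, std)` pairs would call `cosine_similarity` directly. A plain `'c_v'` run for reference, with toy data:

```python
from gensim.corpora import Dictionary
from gensim.models import CoherenceModel

texts = [['human', 'computer', 'interface'],
         ['computer', 'system', 'interface'],
         ['graph', 'trees', 'minors']]
dictionary = Dictionary(texts)
topics = [['human', 'computer', 'interface'], ['graph', 'trees', 'minors']]

cm = CoherenceModel(topics=topics, texts=texts, dictionary=dictionary,
                    coherence='c_v')  # c_v segments, then cosine_similarity
print(cm.get_coherence_per_topic())  # one mean per topic
```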

gensim/topic_coherence/text_analysis.py (18 changes: 10 additions & 8 deletions)

```diff
@@ -502,7 +502,7 @@ def get_co_occurrences(self, word1, word2):
     def accumulate(self, texts, window_size):
         if self.model is not None:
             logger.debug("model is already trained; no accumulation necessary")
-            return
+            return self

         kwargs = self.model_kwargs.copy()
         if window_size is not None:
@@ -518,12 +518,14 @@ def accumulate(self, texts, window_size):
         return self

     def ids_similarity(self, ids1, ids2):
-        if not hasattr(ids1, '__iter__'):
-            ids1 = [ids1]
-        if not hasattr(ids2, '__iter__'):
-            ids2 = [ids2]
-
-        words1 = [self.dictionary.id2token[word_id] for word_id in ids1]
-        words2 = [self.dictionary.id2token[word_id] for word_id in ids2]
+        words1 = self._words_with_embeddings(ids1)
+        words2 = self._words_with_embeddings(ids2)
         return self.model.n_similarity(words1, words2)

+    def _words_with_embeddings(self, ids):
+        if not hasattr(ids, '__iter__'):
+            ids = [ids]
+
+        words = [self.dictionary.id2token[word_id] for word_id in ids]
+        return [word for word in words if word in self.model.vocab]
```
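
The heart of the out-of-vocabulary fix is the `vocab` filter in `_words_with_embeddings`. A self-contained sketch of the same logic, with hypothetical stand-ins for `dictionary.id2token` and `model.vocab`:

```python
# Stand-ins: id2token and vocab are hypothetical; the real code reads
# self.dictionary.id2token and self.model.vocab.
id2token = {0: 'human', 1: 'computer', 2: 'qwerty'}
vocab = {'human', 'computer'}  # words that actually have embeddings

def words_with_embeddings(ids):
    if not hasattr(ids, '__iter__'):
        ids = [ids]  # accept a single id as well as an iterable
    words = [id2token[word_id] for word_id in ids]
    return [word for word in words if word in vocab]

print(words_with_embeddings([0, 1, 2]))  # ['human', 'computer']: 'qwerty' dropped
```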
