Skip to content

Commit

Permalink
piskvorky#1380: Initial implementation of coherence using word2vec si…
Browse files Browse the repository at this point in the history
…milarity.
  • Loading branch information
Sweeney, Mack committed Aug 13, 2017
1 parent 718b1c6 commit a1f9127
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 14 deletions.
21 changes: 14 additions & 7 deletions gensim/models/coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@

logger = logging.getLogger(__name__)

boolean_document_based = {'u_mass'}
sliding_window_based = {'c_v', 'c_uci', 'c_npmi'}
_make_pipeline = namedtuple('Coherence_Measure', 'seg, prob, conf, aggr')
BOOLEAN_DOCUMENT_BASED = {'u_mass'}
SLIDING_WINDOW_BASED = {'c_v', 'c_uci', 'c_npmi', 'c_w2v'}

_make_pipeline = namedtuple('Coherence_Measure', 'seg, prob, conf, aggr')
COHERENCE_MEASURES = {
'u_mass': _make_pipeline(
segmentation.s_one_pre,
Expand All @@ -53,6 +53,12 @@
indirect_confirmation_measure.cosine_similarity,
aggregation.arithmetic_mean
),
'c_w2v': _make_pipeline(
segmentation.s_one_set,
probability_estimation.p_word2vec,
indirect_confirmation_measure.word2vec_similarity,
aggregation.arithmetic_mean
),
'c_uci': _make_pipeline(
segmentation.s_one_one,
probability_estimation.p_boolean_sliding_window,
Expand All @@ -69,6 +75,7 @@

# Fallback sliding-window width per window-based coherence measure; looked up by
# coherence name when the caller leaves `window_size` as None.
SLIDING_WINDOW_SIZES = dict(
    c_v=110,
    c_w2v=5,
    c_uci=10,
    c_npmi=10
)
Expand Down Expand Up @@ -177,7 +184,7 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=

# Check for correct inputs for u_mass coherence measure.
self.coherence = coherence
if coherence in boolean_document_based:
if coherence in BOOLEAN_DOCUMENT_BASED:
if is_corpus(corpus)[0]:
self.corpus = corpus
elif texts is not None:
Expand All @@ -189,7 +196,7 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
"be provided for %s coherence.", coherence)

# Check for correct inputs for c_v coherence measure.
elif coherence in sliding_window_based:
elif coherence in SLIDING_WINDOW_BASED:
self.window_size = window_size
if self.window_size is None:
self.window_size = SLIDING_WINDOW_SIZES[self.coherence]
Expand Down Expand Up @@ -297,7 +304,7 @@ def estimate_probabilities(self, segmented_topics=None):
if segmented_topics is None:
segmented_topics = self.segment_topics()

if self.coherence in boolean_document_based:
if self.coherence in BOOLEAN_DOCUMENT_BASED:
self._accumulator = self.measure.prob(self.corpus, segmented_topics)
else:
self._accumulator = self.measure.prob(
Expand All @@ -315,7 +322,7 @@ def get_coherence_per_topic(self, segmented_topics=None):
if self._accumulator is None:
self.estimate_probabilities(segmented_topics)

if self.coherence in boolean_document_based:
if self.coherence in BOOLEAN_DOCUMENT_BASED or self.coherence == 'c_w2v':
kwargs = {}
elif self.coherence == 'c_v':
kwargs = dict(topics=self.topics, measure='nlr', gamma=1)
Expand Down
20 changes: 18 additions & 2 deletions gensim/test/test_coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from gensim.corpora.dictionary import Dictionary
from gensim.matutils import argsort
from gensim.models.coherencemodel import CoherenceModel, boolean_document_based
from gensim.models.coherencemodel import CoherenceModel, BOOLEAN_DOCUMENT_BASED
from gensim.models.ldamodel import LdaModel
from gensim.models.wrappers import LdaMallet
from gensim.models.wrappers import LdaVowpalWabbit
Expand Down Expand Up @@ -82,7 +82,7 @@ def setUp(self):

def check_coherence_measure(self, coherence):
"""Check provided topic coherence algorithm on given topics"""
if coherence in boolean_document_based:
if coherence in BOOLEAN_DOCUMENT_BASED:
kwargs = dict(corpus=self.corpus, dictionary=self.dictionary, coherence=coherence)
else:
kwargs = dict(texts=self.texts, dictionary=self.dictionary, coherence=coherence)
Expand Down Expand Up @@ -118,6 +118,10 @@ def testCvLdaModel(self):
"""Perform sanity check to see if c_v coherence works with LDA Model"""
CoherenceModel(model=self.ldamodel, texts=self.texts, coherence='c_v')

def testCw2vLdaModel(self):
    """Sanity check: constructing a c_w2v CoherenceModel from the LDA model must not raise."""
    CoherenceModel(
        model=self.ldamodel,
        texts=self.texts,
        coherence='c_w2v'
    )

def testCuciLdaModel(self):
    """Sanity check: constructing a c_uci CoherenceModel from the LDA model must not raise."""
    CoherenceModel(
        model=self.ldamodel,
        texts=self.texts,
        coherence='c_uci'
    )
Expand All @@ -138,6 +142,12 @@ def testCvMalletModel(self):
return
CoherenceModel(model=self.malletmodel, texts=self.texts, coherence='c_v')

def testCw2vMalletModel(self):
    """Sanity check: c_w2v coherence can be constructed from the LDA Mallet gensim wrapper.

    Does nothing when `self.mallet_path` is not set.
    """
    if self.mallet_path:
        CoherenceModel(
            model=self.malletmodel,
            texts=self.texts,
            coherence='c_w2v'
        )

def testCuciMalletModel(self):
"""Perform sanity check to see if c_uci coherence works with LDA Mallet gensim wrapper"""
if not self.mallet_path:
Expand All @@ -162,6 +172,12 @@ def testCvVWModel(self):
return
CoherenceModel(model=self.vwmodel, texts=self.texts, coherence='c_v')

def testCw2vVWModel(self):
    """Sanity check: c_w2v coherence can be constructed from the LDA VW gensim wrapper.

    Does nothing when `self.vw_path` is not set.
    """
    if self.vw_path:
        CoherenceModel(
            model=self.vwmodel,
            texts=self.texts,
            coherence='c_w2v'
        )

def testCuciVWModel(self):
"""Perform sanity check to see if c_uci coherence works with LDA VW gensim wrapper"""
if not self.vw_path:
Expand Down
6 changes: 3 additions & 3 deletions gensim/test/test_text_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import unittest

from gensim.corpora.dictionary import Dictionary
from gensim.topic_coherence.text_analysis import \
InvertedIndexAccumulator, WordOccurrenceAccumulator, ParallelWordOccurrenceAccumulator, \
CorpusAccumulator
from gensim.topic_coherence.text_analysis import (
InvertedIndexAccumulator, WordOccurrenceAccumulator, ParallelWordOccurrenceAccumulator,
CorpusAccumulator)


class BaseTestCases(object):
Expand Down
20 changes: 20 additions & 0 deletions gensim/topic_coherence/indirect_confirmation_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@
logger = logging.getLogger(__name__)


def word2vec_similarity(segmented_topics, accumulator):
    """Score each topic by the mean similarity of its segment pairs.

    Every `(w_prime, w_star)` pair of a topic is handed to
    `accumulator.ids_similarity`; either side may be a single id or an iterable
    of ids. The per-pair results are averaged per topic.

    Returns:
    ----
    numpy array with one mean similarity per entry of `segmented_topics`.
    """
    def as_id_list(ids):
        # A bare id has no `__iter__`; wrap it so both sides are always iterables.
        return ids if hasattr(ids, '__iter__') else [ids]

    per_topic = np.zeros(len(segmented_topics))
    for topic_idx, segments in enumerate(segmented_topics):
        pair_scores = np.zeros(len(segments))
        for pair_idx, (w_prime, w_star) in enumerate(segments):
            pair_scores[pair_idx] = accumulator.ids_similarity(
                as_id_list(w_prime), as_id_list(w_star))
        per_topic[topic_idx] = pair_scores.mean()

    return per_topic


def cosine_similarity(segmented_topics, accumulator, topics, measure='nlr', gamma=1):
"""
This function calculates the indirect cosine measure. Given context vectors
Expand Down
17 changes: 15 additions & 2 deletions gensim/topic_coherence/probability_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import itertools
import logging

from gensim.topic_coherence.text_analysis import \
CorpusAccumulator, WordOccurrenceAccumulator, ParallelWordOccurrenceAccumulator
from gensim.topic_coherence.text_analysis import (
CorpusAccumulator, WordOccurrenceAccumulator, ParallelWordOccurrenceAccumulator,
WordVectorsAccumulator)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,6 +61,18 @@ def p_boolean_sliding_window(texts, segmented_topics, dictionary, window_size, p
return accumulator.accumulate(texts, window_size)


def p_word2vec(texts, segmented_topics, dictionary, window_size=None, processes=1, model=None):
    """Build a word-vector accumulator for the ids used in `segmented_topics`.

    A new word2vec model is trained on `texts` only if `model` is None; a supplied
    `model` is used as-is and `texts` is not consumed.

    Args:
    ----
    texts: tokenized texts to train on when `model` is None.
    segmented_topics: segmented topics whose unique ids are handed to the accumulator.
    dictionary: dictionary used by the accumulator to map ids to tokens.
    window_size: passed to the accumulator as the word2vec `window`.
    processes: passed to the accumulator as the word2vec `workers`.
    model: optional pre-trained word vectors; skips training when given.

    Returns:
    ----
    accumulator: text accumulator with trained context vectors.
    """
    top_ids = unique_ids_from_segments(segmented_topics)
    accumulator = WordVectorsAccumulator(
        top_ids, dictionary, model, window=window_size, workers=processes)
    # Return the accumulator itself rather than the result of `accumulate`, so a
    # usable object comes back even when accumulation is skipped for a
    # pre-trained model.
    accumulator.accumulate(texts, window_size)
    return accumulator


def unique_ids_from_segments(segmented_topics):
"""Return the set of all unique ids in a list of segmented topics.
Expand Down
60 changes: 60 additions & 0 deletions gensim/topic_coherence/text_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from six import viewitems, string_types

from gensim import utils
from gensim.models.word2vec import Word2Vec

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -467,3 +468,62 @@ def reply_to_master(self):
logger.info("serializing accumulator to return to master...")
self.output_q.put(self.accumulator, block=False)
logger.info("accumulator serialized")


class WordVectorsAccumulator(UsesDictionary):
    """Accumulate context vectors for words using word vector embeddings.

    Holds either caller-supplied word vectors or the vectors of a Word2Vec model
    trained on demand by `accumulate`, and answers similarity queries keyed by
    dictionary ids.
    """

    def __init__(self, relevant_ids, dictionary, model=None, **model_kwargs):
        """
        Args:
        ----
        relevant_ids: forwarded unchanged to the `UsesDictionary` constructor.
        dictionary: forwarded to the `UsesDictionary` constructor; its `id2token`
            mapping is used to translate ids to tokens.
        model: if None, a new Word2Vec model is trained on the given text corpus. If not None,
            it should be a pre-trained Word2Vec context
            vectors (gensim.models.keyedvectors.KeyedVectors instance).
        model_kwargs: if model is None, these keyword arguments will be passed through to the
            Word2Vec constructor.
        """
        super(WordVectorsAccumulator, self).__init__(relevant_ids, dictionary)
        self.model = model
        self.model_kwargs = model_kwargs

    def get_occurrences(self, word):
        """Return the count stored in the model vocabulary for `word`, once `accumulate` has been called.

        `word` may be a token or a dictionary id; an id is translated to its token first.
        """
        # NOTE(review): Word2Vec vocabulary counts are presumably token frequencies, not
        # document frequencies -- confirm callers expect that.
        try:
            self.token2id[word]  # is this a token or an id?
        except KeyError:
            word = self.dictionary.id2token[word]
        return self.model.vocab[word].count

    def get_co_occurrences(self, word1, word2):
        """Unsupported for word vectors; always raises NotImplementedError."""
        raise NotImplementedError("Word2Vec model does not support co-occurrence counting")

    def accumulate(self, texts, window_size):
        """Train a Word2Vec model on `texts` unless a model is already present.

        Args:
        ----
        texts: tokenized texts; iterated twice (vocabulary build, then training), so it
            must be re-iterable.
        window_size: if not None, overrides any `window` given in the constructor kwargs.

        Returns:
        ----
        self, in both the pre-trained and the freshly-trained case.
        """
        if self.model is not None:
            logger.debug("model is already trained; no accumulation necessary")
            return self  # always hand back the accumulator so callers can chain on it

        kwargs = self.model_kwargs.copy()
        if window_size is not None:
            kwargs['window'] = window_size
        elif kwargs.get('window') is None:
            # Do not forward an explicit `window=None`; let Word2Vec use its own default.
            kwargs.pop('window', None)

        # Defaults apply only when the caller did not choose a value. The original
        # looked up the misspelled key 'hw', which always reset `hs` to 0.
        kwargs.setdefault('min_count', 1)
        kwargs.setdefault('sg', 1)
        kwargs.setdefault('hs', 0)

        self.model = Word2Vec(**kwargs)
        self.model.build_vocab(texts)
        self.model.train(texts, total_examples=self.model.corpus_count, epochs=self.model.iter)
        self.model = self.model.wv  # retain KeyedVectors
        return self

    def ids_similarity(self, ids1, ids2):
        """Return the model's `n_similarity` between the tokens of `ids1` and those of `ids2`.

        Each argument may be a single dictionary id or an iterable of ids.
        """
        if not hasattr(ids1, '__iter__'):
            ids1 = [ids1]
        if not hasattr(ids2, '__iter__'):
            ids2 = [ids2]

        words1 = [self.dictionary.id2token[word_id] for word_id in ids1]
        words2 = [self.dictionary.id2token[word_id] for word_id in ids2]
        return self.model.n_similarity(words1, words2)

0 comments on commit a1f9127

Please sign in to comment.