Skip to content

Commit

Permalink
piskvorky#1380: Make topn a property so setting it to higher values…
Browse files Browse the repository at this point in the history
… will uncache the accumulator and the topics will be shrunk/expanded accordingly.
  • Loading branch information
Sweeney, Mack committed Aug 13, 2017
1 parent 92e5455 commit f8ecab7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 5 deletions.
30 changes: 25 additions & 5 deletions gensim/models/coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@

import numpy as np

from gensim import interfaces
from gensim.matutils import argsort
from gensim import interfaces, matutils
from gensim.topic_coherence import (segmentation, probability_estimation,
direct_confirmation_measure, indirect_confirmation_measure,
aggregation)
Expand Down Expand Up @@ -209,7 +208,7 @@ def __init__(self, model=None, topics=None, texts=None, corpus=None, dictionary=
else:
raise ValueError("%s coherence is not currently supported.", coherence)

self.topn = topn
self._topn = topn
self._model = model
self._accumulator = None
self._topics = None
Expand All @@ -232,13 +231,34 @@ def model(self, model):
self._update_accumulator(new_topics)
self._topics = new_topics

@property
def topn(self):
return self._topn

@topn.setter
def topn(self, topn):
current_topic_length = len(self._topics[0])
requires_expansion = current_topic_length < topn

if self.model is not None:
self._topn = topn
if requires_expansion:
self.model = self._model # trigger topic expansion from model
else:
if requires_expansion:
raise ValueError("Model unavailable and topic sizes are less than topn=%d" % topn)
self._topn = topn # topics will be truncated in getter

@property
def measure(self):
return COHERENCE_MEASURES[self.coherence]

@property
def topics(self):
return self._topics
if len(self._topics[0]) > self._topn:
return [topic[:self._topn] for topic in self._topics]
else:
return self._topics

@topics.setter
def topics(self, topics):
Expand Down Expand Up @@ -279,7 +299,7 @@ def _get_topics(self):
"""Internal helper function to return topics from a trained topic model."""
try:
return [
argsort(topic, topn=self.topn, reverse=True) for topic in
matutils.argsort(topic, topn=self.topn, reverse=True) for topic in
self.model.get_topics()
]
except AttributeError:
Expand Down
39 changes: 39 additions & 0 deletions gensim/test/test_coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,45 @@ def testAccumulatorCachingWithModelSetting(self):
self.assertTrue(np.array_equal(topics, cm1.topics))
self.assertIsNone(cm1._accumulator)

def testAccumulatorCachingWithTopnSettingGivenTopics(self):
kwargs = dict(corpus=self.corpus, dictionary=self.dictionary, topn=5, coherence='u_mass')
cm1 = CoherenceModel(topics=self.topics1, **kwargs)
cm1.estimate_probabilities()
self.assertIsNotNone(cm1._accumulator)

accumulator = cm1._accumulator
topics_before = cm1._topics
cm1.topn = 3
self.assertEqual(accumulator, cm1._accumulator)
self.assertEqual(3, len(cm1.topics[0]))
self.assertEqual(topics_before, cm1._topics)

# Topics should not have been truncated, so topn settings below 5 should work
cm1.topn = 4
self.assertEqual(accumulator, cm1._accumulator)
self.assertEqual(4, len(cm1.topics[0]))
self.assertEqual(topics_before, cm1._topics)

with self.assertRaises(ValueError):
cm1.topn = 6 # can't expand topics any further without model

def testAccumulatorCachingWithTopnSettingGivenModel(self):
kwargs = dict(corpus=self.corpus, dictionary=self.dictionary, topn=5, coherence='u_mass')
cm1 = CoherenceModel(model=self.ldamodel, **kwargs)
cm1.estimate_probabilities()
self.assertIsNotNone(cm1._accumulator)

accumulator = cm1._accumulator
topics_before = cm1._topics
cm1.topn = 3
self.assertEqual(accumulator, cm1._accumulator)
self.assertEqual(3, len(cm1.topics[0]))
self.assertEqual(topics_before, cm1._topics)

cm1.topn = 6 # should be able to expand given the model
self.assertIsNone(cm1._accumulator) # should uncache due to missing terms in accumulator
self.assertEqual(6, len(cm1.topics[0]))


if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
Expand Down

0 comments on commit f8ecab7

Please sign in to comment.