Skip to content

Commit

Permalink
piskvorky#1380: Add a get_topics method to all topic models, add te…
Browse files Browse the repository at this point in the history
…st coverage for this, and update the `CoherenceModel` to use this for getting topics from models.
  • Loading branch information
Sweeney, Mack committed Aug 13, 2017
1 parent 94fe67b commit 24686ce
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 48 deletions.
28 changes: 9 additions & 19 deletions gensim/models/coherencemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@

from gensim import interfaces
from gensim.matutils import argsort
from gensim.models.ldamodel import LdaModel
from gensim.models.wrappers import LdaVowpalWabbit, LdaMallet
from gensim.topic_coherence import (segmentation, probability_estimation,
direct_confirmation_measure, indirect_confirmation_measure,
aggregation)
Expand Down Expand Up @@ -279,23 +277,15 @@ def _topics_differ(self, new_topics):

def _get_topics(self):
"""Internal helper function to return topics from a trained topic model."""
topics = []
if isinstance(self.model, LdaModel):
for topic in self.model.state.get_lambda():
bestn = argsort(topic, topn=self.topn, reverse=True)
topics.append(bestn)
elif isinstance(self.model, LdaVowpalWabbit):
for topic in self.model._get_topics():
bestn = argsort(topic, topn=self.topn, reverse=True)
topics.append(bestn)
elif isinstance(self.model, LdaMallet):
for topic in self.model.word_topics:
bestn = argsort(topic, topn=self.topn, reverse=True)
topics.append(bestn)
else:
raise ValueError("This topic model is not currently supported. Supported topic models "
" are LdaModel, LdaVowpalWabbit and LdaMallet.")
return topics
try:
return [
argsort(topic, topn=self.topn, reverse=True) for topic in
self.model.get_topics()
]
except AttributeError:
raise ValueError(
"This topic model is not currently supported. Supported topic models"
" should implement the `get_topics` method.")

def segment_topics(self):
    """Run the measure's `seg` callable over `self.topics` and return its output."""
    segment = self.measure.seg
    return segment(self.topics)
Expand Down
10 changes: 9 additions & 1 deletion gensim/models/hdpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@
import logging
import time
import warnings

import numpy as np
from scipy.special import gammaln, psi # gamma function utils
from six.moves import xrange

from gensim import interfaces, utils, matutils
from gensim.matutils import dirichlet_expectation
from gensim.models import basemodel, ldamodel
from six.moves import xrange

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -456,6 +457,13 @@ def show_topic(self, topic_id, topn=20, log=False, formatted=False, num_words=No
hdp_formatter = HdpTopicFormatter(self.id2word, betas)
return hdp_formatter.show_topic(topic_id, topn, log, formatted)

def get_topics(self):
    """
    Build the term topic matrix learned during inference.
    It is computed as `self.m_lambda` with `self.m_eta` added to it, giving a
    `num_topics` x `vocabulary_size` np.ndarray of floats.
    """
    term_topic_matrix = self.m_lambda + self.m_eta
    return term_topic_matrix

def show_topics(self, num_topics=20, num_words=20, log=False, formatted=True):
"""
Print the `num_words` most probable words for `num_topics` number of topics.
Expand Down
31 changes: 19 additions & 12 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@


import logging
import numpy as np
import numbers
from random import sample
import os

from gensim import interfaces, utils, matutils
from gensim.matutils import dirichlet_expectation
from gensim.models import basemodel
from gensim.matutils import kullback_leibler, hellinger, jaccard_distance

from itertools import chain
from random import sample

import numpy as np
import six
from scipy.special import gammaln, psi # gamma function utils
from scipy.special import polygamma
from six.moves import xrange
import six

from gensim import interfaces, utils, matutils
from gensim.matutils import dirichlet_expectation
from gensim.matutils import kullback_leibler, hellinger, jaccard_distance
from gensim.models import basemodel

# log(sum(exp(x))) that tries to avoid overflow
try:
Expand Down Expand Up @@ -815,6 +815,13 @@ def show_topic(self, topicid, topn=10):
"""
return [(self.id2word[id], value) for id, value in self.get_topic_terms(topicid, topn)]

def get_topics(self):
    """
    Fetch the term topic matrix learned during inference, as provided by
    `self.state.get_lambda()`.
    This is a `num_topics` x `vocabulary_size` np.ndarray of floats.
    """
    model_state = self.state
    return model_state.get_lambda()

def get_topic_terms(self, topicid, topn=10):
"""
Return a list of `(word_id, probability)` 2-tuples for the most
Expand All @@ -823,7 +830,7 @@ def get_topic_terms(self, topicid, topn=10):
Only return 2-tuples for the topn most probable words (ignore the rest).
"""
topic = self.state.get_lambda()[topicid]
topic = self.get_topics()[topicid]
topic = topic / topic.sum() # normalize to probability distribution
bestn = matutils.argsort(topic, topn, reverse=True)
return [(id, topic[id]) for id in bestn]
Expand All @@ -840,7 +847,7 @@ def top_topics(self, corpus, num_words=20):

topics = []
str_topics = []
for topic in self.state.get_lambda():
for topic in self.get_topics():
topic = topic / topic.sum() # normalize to probability distribution
bestn = matutils.argsort(topic, topn=num_words, reverse=True)
topics.append(bestn)
Expand Down Expand Up @@ -1013,7 +1020,7 @@ def diff(self, other, distance="kullback_leibler", num_words=100, n_ann_terms=10
raise ValueError("The parameter `other` must be of type `{}`".format(self.__name__))

distance_func = distances[distance]
d1, d2 = self.state.get_lambda(), other.state.get_lambda()
d1, d2 = self.get_topics(), other.get_topics()
t1_size, t2_size = d1.shape[0], d2.shape[0]
annotation_terms = None

Expand Down
20 changes: 16 additions & 4 deletions gensim/models/lsimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,11 @@
import scipy.linalg
import scipy.sparse
from scipy.sparse import sparsetools

from gensim import interfaces, matutils, utils
from gensim.models import basemodel

from six import iterkeys
from six.moves import xrange

from gensim import interfaces, matutils, utils
from gensim.models import basemodel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -470,6 +468,20 @@ def __getitem__(self, bow, scaled=False, chunksize=512):
result = matutils.Dense2Corpus(topic_dist)
return result

def get_topics(self):
    """
    Return the term topic matrix learned during inference.
    This is a `num_topics` x `vocabulary_size` np.ndarray of floats.

    NOTE: The number of topics can actually be smaller than `self.num_topics`,
    if there were not enough factors (real rank of input matrix smaller than
    `self.num_topics`).
    """
    # A single vectorized copy of the transposed projection replaces the
    # per-row Python loop. The result is always a fresh, C-ordered, plain 2-D
    # ndarray, and it keeps its (0, vocabulary_size) shape when there are no
    # factors instead of degrading to an empty 1-D array.
    return np.array(self.projection.u.T, order='C')

def show_topic(self, topicno, topn=10):
"""
Return a specified topic (=left singular vector), 0 <= `topicno` < `self.num_topics`,
Expand Down
18 changes: 11 additions & 7 deletions gensim/models/wrappers/ldamallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,19 @@


import logging
import os
import random
import tempfile
import os

import numpy

import xml.etree.ElementTree as et
import zipfile

from six import iteritems
import numpy
from smart_open import smart_open

from gensim import utils, matutils
from gensim.utils import check_output, revdict
from gensim.models.ldamodel import LdaModel
from gensim.models import basemodel
from gensim.models.ldamodel import LdaModel
from gensim.utils import check_output, revdict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -213,6 +210,13 @@ def load_document_topics(self):
"""
return self.read_doctopics(self.fdoctopics())

def get_topics(self):
    """
    Expose the term topic matrix learned during inference, i.e. the value
    stored in `self.word_topics`.
    This is a `num_topics` x `vocabulary_size` np.ndarray of floats.
    """
    topic_matrix = self.word_topics
    return topic_matrix

def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True):
"""
Print the `num_words` most probable words for `num_topics` number of topics.
Expand Down
15 changes: 11 additions & 4 deletions gensim/models/wrappers/ldavowpalwabbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@
.. [2] http://www.cs.princeton.edu/~mdhoffma/
"""

from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import logging
import tempfile
import os
import shutil
import subprocess
import tempfile

import numpy

Expand Down Expand Up @@ -235,6 +235,13 @@ def log_perplexity(self, chunk):
corpus_words)
return bound

def get_topics(self):
    """
    Public accessor for the term topic matrix learned during inference;
    the work is delegated to `self._get_topics()`.
    This is a `num_topics` x `vocabulary_size` np.ndarray of floats.
    """
    fetch = self._get_topics
    return fetch()

def print_topics(self, num_topics=10, num_words=10):
    """Delegate to `show_topics` with `log=True` and return whatever it returns."""
    shown = self.show_topics(num_topics, num_words, log=True)
    return shown

Expand Down
12 changes: 11 additions & 1 deletion gensim/test/basetests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
Automated tests for checking transformation algorithms (the models package).
"""

import six
import numpy as np
import six


class TestBaseTopicModel(object):
def testPrintTopic(self):
Expand Down Expand Up @@ -41,3 +42,12 @@ def testShowTopics(self):
for k, v in topic:
self.assertTrue(isinstance(k, six.string_types))
self.assertTrue(isinstance(v, (np.floating, float)))

def testGetTopics(self):
    """Each row from `get_topics()` is a float64 ndarray as long as `id2word`."""
    topic_rows = self.model.get_topics()
    expected_width = len(self.model.id2word)
    for row in topic_rows:
        self.assertTrue(isinstance(row, np.ndarray))
        self.assertEqual(row.dtype, np.float64)
        self.assertEqual(expected_width, row.shape[0])

0 comments on commit 24686ce

Please sign in to comment.