LDAMallet load_word_topics() returns value (#767)
bhargavvader authored and tmylk committed Jul 1, 2016
Commit 003a886 (1 parent: bf26bdd)
Showing 2 changed files with 11 additions and 9 deletions.
gensim/models/wrappers/ldamallet.py: 12 changes (7 additions, 5 deletions)
@@ -182,7 +182,7 @@ def __getitem__(self, bow, iterations=100):
 
     def load_word_topics(self):
         logger.info("loading assigned topics from %s", self.fstate())
-        wordtopics = numpy.zeros((self.num_topics, self.num_terms), dtype=numpy.float32)
+        word_topics = numpy.zeros((self.num_topics, self.num_terms), dtype=numpy.float32)
         if hasattr(self.id2word, 'token2id'):
             word2id = self.id2word.token2id
         else:
@@ -199,10 +199,10 @@ def load_word_topics(self):
                 if token not in word2id:
                     continue
                 tokenid = word2id[token]
-                wordtopics[int(topic), tokenid] += 1.0
-        logger.info("loaded assigned topics for %i tokens", wordtopics.sum())
-        self.wordtopics = wordtopics
+                word_topics[int(topic), tokenid] += 1.0
+        logger.info("loaded assigned topics for %i tokens", word_topics.sum())
         self.print_topics(15)
+        return word_topics
 
     def print_topics(self, num_topics=10, num_words=10):
         return self.show_topics(num_topics, num_words, log=True)
@@ -242,7 +242,9 @@ def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True):
         return shown
 
     def show_topic(self, topicid, topn=10):
-        topic = self.wordtopics[topicid]
+        if self.word_topics is None:
+            logger.warn("Run train or load_word_topics before showing topics.")
+        topic = self.word_topics[topicid]
         topic = topic / topic.sum()  # normalize to probability dist
         bestn = matutils.argsort(topic, topn, reverse=True)
         beststr = [(topic[id], self.id2word[id]) for id in bestn]

gensim/test/test_ldamallet_wrapper.py: 8 changes (4 additions, 4 deletions)
@@ -104,7 +104,7 @@ def testPersistence(self):
         model.save(fname)
         model2 = ldamallet.LdaMallet.load(fname)
         self.assertEqual(model.num_topics, model2.num_topics)
-        self.assertTrue(numpy.allclose(model.wordtopics, model2.wordtopics))
+        self.assertTrue(numpy.allclose(model.word_topics, model2.word_topics))
         tstvec = []
         self.assertTrue(numpy.allclose(model[tstvec], model2[tstvec]))  # try projecting an empty vector
 
@@ -116,7 +116,7 @@ def testPersistenceCompressed(self):
         model.save(fname)
         model2 = ldamallet.LdaMallet.load(fname, mmap=None)
         self.assertEqual(model.num_topics, model2.num_topics)
-        self.assertTrue(numpy.allclose(model.wordtopics, model2.wordtopics))
+        self.assertTrue(numpy.allclose(model.word_topics, model2.word_topics))
         tstvec = []
         self.assertTrue(numpy.allclose(model[tstvec], model2[tstvec]))  # try projecting an empty vector
 
@@ -132,8 +132,8 @@ def testLargeMmap(self):
         # test loading the large model arrays with mmap
         model2 = ldamodel.LdaModel.load(testfile(), mmap='r')
         self.assertEqual(model.num_topics, model2.num_topics)
-        self.assertTrue(isinstance(model2.wordtopics, numpy.memmap))
-        self.assertTrue(numpy.allclose(model.wordtopics, model2.wordtopics))
+        self.assertTrue(isinstance(model2.word_topics, numpy.memmap))
+        self.assertTrue(numpy.allclose(model.word_topics, model2.word_topics))
         tstvec = []
         self.assertTrue(numpy.allclose(model[tstvec], model2[tstvec]))  # try projecting an empty vector
 
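For context, a minimal usage sketch of the changed call: after this commit, load_word_topics() returns the topic-term count matrix rather than only storing it on the model. The toy corpus, dictionary, and MALLET path below are illustrative assumptions, not part of the commit.

from gensim.corpora import Dictionary
from gensim.models.wrappers.ldamallet import LdaMallet

# Tiny toy corpus; the mallet path must point to a local MALLET binary
# (a placeholder here, not something this commit provides).
texts = [["human", "computer", "interaction"], ["graph", "minors", "trees"]]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

model = LdaMallet("/path/to/mallet-2.0.8/bin/mallet", corpus=corpus,
                  num_topics=2, id2word=dictionary)

# load_word_topics() now returns the (num_topics, num_terms) count matrix.
word_topics = model.load_word_topics()
print(word_topics.shape)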
