Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Added sklearn wrapper for AuthorTopic model #1403

Merged
merged 16 commits into from
Jun 28, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gensim/sklearn_integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
from .sklearn_wrapper_gensim_lsimodel import SklLsiModel # noqa: F401
from .sklearn_wrapper_gensim_rpmodel import SklRpModel # noqa: F401
from .sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel # noqa: F401
from .sklearn_wrapper_gensim_atmodel import SklATModel # noqa: F401
118 changes: 118 additions & 0 deletions gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <[email protected]>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""
import numpy as np
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import BaseSklearnWrapper


class SklATModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
"""
Base AuthorTopic module
"""

def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=None,
chunksize=2000, passes=1, iterations=50, decay=0.5, offset=1.0,
alpha='symmetric', eta='symmetric', update_every=1, eval_every=10,
gamma_threshold=0.001, serialized=False, serialization_path=None,
minimum_probability=0.01, random_state=None):
"""
Sklearn wrapper for AuthorTopic model. Class derived from gensim.models.AuthorTopicModel
"""
self.gensim_model = None
self.num_topics = num_topics
self.id2word = id2word
self.author2doc = author2doc
self.doc2author = doc2author
self.chunksize = chunksize
self.passes = passes
self.iterations = iterations
self.decay = decay
self.offset = offset
self.alpha = alpha
self.eta = eta
self.update_every = update_every
self.eval_every = eval_every
self.gamma_threshold = gamma_threshold
self.serialized = serialized
self.serialization_path = serialization_path
self.minimum_probability = minimum_probability
self.random_state = random_state

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"num_topics": self.num_topics, "id2word": self.id2word,
"author2doc": self.author2doc, "doc2author": self.doc2author, "chunksize": self.chunksize,
"passes": self.passes, "iterations": self.iterations, "decay": self.decay,
"offset": self.offset, "alpha": self.alpha, "eta": self.eta, "update_every": self.update_every,
"eval_every": self.eval_every, "gamma_threshold": self.gamma_threshold,
"serialized": self.serialized, "serialization_path": self.serialization_path,
"minimum_probability": self.minimum_probability, "random_state": self.random_state}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklATModel, self).set_params(**parameters)
return self

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Calls gensim.models.AuthorTopicModel
"""
self.gensim_model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word,
author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes,
iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta,
update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized,
serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state)
return self

def transform(self, author_names):
"""
Return topic distribution for input authors as a list of
(topic_id, topic_probabiity) 2-tuples.
"""
# The input as array of array
if self.gensim_model is None:
raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

check = lambda x: [x] if not isinstance(x, list) else x
author_names = check(author_names)
X = [[] for _ in range(0, len(author_names))]

for k, v in enumerate(author_names):
transformed_author = self.gensim_model[v]
probs_author = list(map(lambda x: x[1], transformed_author))
# Everything should be equal in length
if len(probs_author) != self.num_topics:
probs_author.extend([1e-12] * (self.num_topics - len(probs_author)))
X[k] = probs_author

return np.reshape(np.array(X), (len(author_names), self.num_topics))

def partial_fit(self, X, author2doc=None, doc2author=None):
"""
Train model over X.
"""
if self.gensim_model is None:
self.gensim_model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word,
author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes,
iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta,
update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized,
serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state)

self.gensim_model.update(corpus=X, author2doc=author2doc, doc2author=doc2author)
return self
64 changes: 63 additions & 1 deletion gensim/test/test_sklearn_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.datasets import load_files
from sklearn import linear_model
from sklearn import linear_model, cluster
from sklearn.exceptions import NotFittedError
except ImportError:
raise unittest.SkipTest("Test requires scikit-learn to be installed, which is not available")
Expand All @@ -19,6 +19,7 @@
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklLdaModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklLsiModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaseqmodel import SklLdaSeqModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklATModel
from gensim.corpora import mmcorpus, Dictionary
from gensim import matutils

Expand All @@ -39,6 +40,12 @@
]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]
author2doc = {'john': [0, 1, 2, 3, 4, 5, 6], 'jane': [2, 3, 4, 5, 6, 7, 8], 'jack': [0, 2, 4, 6, 8], 'jill': [1, 3, 5, 7]}

texts_new = texts[0:3]
author2doc_new = {'jill': [0], 'bob': [0, 1], 'sally': [1, 2]}
dictionary_new = Dictionary(texts_new)
corpus_new = [dictionary_new.doc2bow(text) for text in texts_new]

texts_ldaseq = [
[u'senior', u'studios', u'studios', u'studios', u'creators', u'award', u'mobile', u'currently', u'challenges', u'senior', u'summary', u'senior', u'motivated', u'creative', u'senior'],
Expand Down Expand Up @@ -396,5 +403,60 @@ def testModelNotFitted(self):
self.assertRaises(NotFittedError, rpmodel_wrapper.transform, doc)


class TestSklATModelWrapper(unittest.TestCase):
def setUp(self):
self.model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=100)
self.model.fit(corpus)

def testTransform(self):
# transforming multiple authors
author_list = ['jill', 'jack']
author_topics = self.model.transform(author_list)
self.assertEqual(author_topics.shape[0], 2)
self.assertEqual(author_topics.shape[1], self.model.num_topics)

# transforming one author
jill_topics = self.model.transform('jill')
self.assertEqual(jill_topics.shape[0], 1)
self.assertEqual(jill_topics.shape[1], self.model.num_topics)

def testPartialFit(self):
self.model.partial_fit(corpus_new, author2doc=author2doc_new)

# Did we learn something about Sally?
output_topics = self.model.transform('sally')
sally_topics = output_topics[0] # getting the topics corresponding to 'sally' (from the list of lists)
self.assertTrue(all(sally_topics > 0))

def testSetGetParams(self):
# updating only one param
self.model.set_params(num_topics=3)
model_params = self.model.get_params()
self.assertEqual(model_params["num_topics"], 3)

# updating multiple params
param_dict = {"passes": 5, "iterations": 10}
self.model.set_params(**param_dict)
model_params = self.model.get_params()
for key in param_dict.keys():
self.assertEqual(model_params[key], param_dict[key])

def testPipeline(self):
# train the AuthorTopic model first
model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=10, passes=100)
model.fit(corpus)

# create and train clustering model
clstr = cluster.MiniBatchKMeans(n_clusters=2)
authors_full = ['john', 'jane', 'jack', 'jill']
clstr.fit(model.transform(authors_full))

# stack together the two models in a pipeline
text_atm = Pipeline((('features', model,), ('cluster', clstr)))
author_list = ['jane', 'jack', 'jill']
ret_val = text_atm.predict(author_list)
self.assertEqual(len(ret_val), len(author_list))


if __name__ == '__main__':
unittest.main()