[WIP] scikit_learn wrapper for LSI Model in Gensim #1244
gensim/sklearn_integration/sklearn_wrapper_gensim_lsimodel.py
@@ -0,0 +1,103 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <[email protected]>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
#
"""
Scikit-learn interface for gensim, for easy use of gensim models with scikit-learn.
Follows scikit-learn API conventions.
"""
import numpy as np

from gensim import models
from gensim import matutils
from scipy import sparse
from sklearn.base import TransformerMixin, BaseEstimator

# accuracy defaults for the multi-pass stochastic algo
P2_EXTRA_DIMS = 100  # set to `None` for dynamic P2_EXTRA_DIMS=k
P2_EXTRA_ITERS = 2


class SklearnWrapperLsiModel(models.LsiModel, TransformerMixin, BaseEstimator):
    """
    Base LSI module
    """

    def __init__(self, corpus=None, num_topics=200, id2word=None, chunksize=20000,
                 decay=1.0, onepass=True, power_iters=P2_EXTRA_ITERS, extra_samples=P2_EXTRA_DIMS):
        """
        Sklearn wrapper for LSI model. Class derived from gensim.models.LsiModel.
        """
        self.corpus = corpus
        self.num_topics = num_topics
        self.id2word = id2word
        self.chunksize = chunksize
        self.decay = decay
        self.onepass = onepass
        self.extra_samples = extra_samples
        self.power_iters = power_iters

        # if the 'fit' method is not used, 'corpus' is passed in here and the model is trained immediately
        if self.corpus:
            models.LsiModel.__init__(self, corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word,
                                     chunksize=self.chunksize, decay=self.decay, onepass=self.onepass,
                                     power_iters=self.power_iters, extra_samples=self.extra_samples)

    def get_params(self, deep=True):
        """
        Returns all parameters as a dictionary.
        """
        return {"corpus": self.corpus, "num_topics": self.num_topics, "id2word": self.id2word,
                "chunksize": self.chunksize, "decay": self.decay, "onepass": self.onepass,
                "extra_samples": self.extra_samples, "power_iters": self.power_iters}

    def set_params(self, **parameters):
        """
        Set all parameters.
        """
        for parameter, value in parameters.items():
            # use setattr so each keyword updates the attribute it names;
            # `self.parameter = value` would only ever overwrite an attribute called 'parameter'
            setattr(self, parameter, value)
        return self

    def fit(self, X, y=None):
        """
        Fit the model according to the given corpus X.
        Calls gensim.models.LsiModel:
        >>> gensim.models.LsiModel(corpus=corpus, num_topics=num_topics, id2word=id2word, chunksize=chunksize, decay=decay, onepass=onepass, power_iters=power_iters, extra_samples=extra_samples)
        """
        if sparse.issparse(X):
            self.corpus = matutils.Sparse2Corpus(X)
        else:
            self.corpus = X

        models.LsiModel.__init__(self, corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word,
                                 chunksize=self.chunksize, decay=self.decay, onepass=self.onepass,
                                 power_iters=self.power_iters, extra_samples=self.extra_samples)
        return self

    def transform(self, docs):
        """
        Takes a list of documents as input ('docs').
        Returns a matrix of shape (len(docs), num_topics), where entry (i, j)
        is the weight of topic j in document i.
        """
        # the input is either a single bow document (a list of (id, value) tuples)
        # or a list of such documents; normalize it to a list of documents
        if isinstance(docs[0], tuple):
            docs = [docs]
        X = [[] for _ in range(len(docs))]
        for k, v in enumerate(docs):
            doc_topics = self[v]
            probs_docs = [topic_prob for _, topic_prob in doc_topics]
            # pad with near-zero weights so every row has exactly num_topics entries
            if len(probs_docs) != self.num_topics:
                probs_docs.extend([1e-12] * (self.num_topics - len(probs_docs)))
            X[k] = probs_docs
        return np.reshape(np.array(X), (len(docs), self.num_topics))

    def partial_fit(self, X):
        """
        Train model over X.
        """
        if sparse.issparse(X):
            X = matutils.Sparse2Corpus(X)
        self.add_documents(corpus=X)
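
Taken together, the wrapper is used like any other scikit-learn estimator. A minimal usage sketch, assuming `dictionary` and `corpus` are a gensim Dictionary and bag-of-words corpus built beforehand (as in the tests below):

from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklearnWrapperLsiModel

model = SklearnWrapperLsiModel(id2word=dictionary, num_topics=2)
model.fit(corpus)            # delegates to models.LsiModel.__init__
X = model.transform(corpus)  # numpy array of shape (len(corpus), num_topics)
model.partial_fit(corpus)    # incremental update via LsiModel.add_documents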
@@ -11,6 +11,7 @@
from sklearn.datasets import load_files
from sklearn import linear_model
from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklearnWrapperLdaModel
from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklearnWrapperLsiModel
from gensim.corpora import Dictionary
from gensim import matutils

@@ -55,7 +56,7 @@ def testTransform(self):
        X = self.model.transform(bow)
        self.assertTrue(X.shape[0], 3)
        self.assertTrue(X.shape[1], self.model.num_topics)

    def testGetTopicDist(self):
        texts_new = ['graph', 'eulerian']
        bow = self.model.id2word.doc2bow(texts_new)

@@ -97,7 +98,7 @@ def testPipeline(self):
            compressed_content = f.read()
            uncompressed_content = codecs.decode(compressed_content, 'zlib_codec')
            cache = pickle.loads(uncompressed_content)
        data = cache
        id2word = Dictionary(map(lambda x: x.split(), data.data))
        corpus = [id2word.doc2bow(i.split()) for i in data.data]
        rand = numpy.random.mtrand.RandomState(1)  # set seed for getting same result

@@ -107,5 +108,55 @@ def testPipeline(self):
        score = text_lda.score(corpus, data.target)
        self.assertGreater(score, 0.50)

class TestSklearnLSIWrapper(unittest.TestCase):
    def setUp(self):
        self.model = SklearnWrapperLsiModel(id2word=dictionary, num_topics=2)
        self.model.fit(corpus)

    def testPrintTopic(self):
        topic = self.model.print_topics(2)

Reviewer (tmylk): This test seems redundant. Please justify including it.

Author: @tmylk My rationale for having this test was basically to serve as a sanity check after fitting the model.

Reviewer (tmylk): please rename as testModelSanity then
        for k, v in topic:
            self.assertTrue(isinstance(v, six.string_types))
            self.assertTrue(isinstance(k, int))

    def testTransform(self):
        texts_new = ['graph', 'eulerian']
        bow = self.model.id2word.doc2bow(texts_new)
        X = self.model.transform(bow)
        self.assertTrue(X.shape[0], 1)
        self.assertTrue(X.shape[1], self.model.num_topics)
        texts_new = [['graph', 'eulerian'], ['server', 'flow'], ['path', 'system']]
        bow = []
        for i in texts_new:
            bow.append(self.model.id2word.doc2bow(i))
        X = self.model.transform(bow)
        self.assertTrue(X.shape[0], 3)
        self.assertTrue(X.shape[1], self.model.num_topics)

    def testPartialFit(self):
        for i in range(10):
            self.model.partial_fit(X=corpus)  # fit against the model again
        doc = list(corpus)[0]  # transform only the first document
        transformed = self.model[doc]
        transformed_approx = matutils.sparse2full(transformed, 2)  # better approximation
        expected = [1.39, 0.0]
        passed = numpy.allclose(sorted(transformed_approx), sorted(expected), atol=1e-1)
        self.assertTrue(passed)

    def testPipeline(self):
        model = SklearnWrapperLsiModel(num_topics=2)
        with open(datapath('mini_newsgroup'), 'rb') as f:
            compressed_content = f.read()
            uncompressed_content = codecs.decode(compressed_content, 'zlib_codec')
            cache = pickle.loads(uncompressed_content)
        data = cache
        id2word = Dictionary(map(lambda x: x.split(), data.data))
        corpus = [id2word.doc2bow(i.split()) for i in data.data]

Reviewer (tmylk): The transformation of a text into a gensim bow-format corpus should have a sklearn wrapper around it, so that the pipeline can take in text, not bow. See more about the gensim corpus format in this [tutorial](https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/Corpora_and_Vector_Spaces.ipynb). Please either add it in this PR or create a wishlist issue.

Reviewer (tmylk): Please respond to this comment.

Author: @tmylk My sincerest apologies for the delayed response. I'd be happy to submit a PR to resolve this issue. I did not add code for this in the current PR because some of the code for the wrapper for LdaModel might also have to be refactored. So maintaining modularity and proposing a solution in a different PR seemed to be a better choice in this case.
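
One possible shape for the wrapper requested above, as a sketch only: `Text2BowTransformer` does not exist in this PR, and its name and details are hypothetical. It would let the pipeline start from raw text instead of a pre-built bow corpus:

from gensim.corpora import Dictionary
from sklearn.base import TransformerMixin, BaseEstimator

class Text2BowTransformer(TransformerMixin, BaseEstimator):
    """Hypothetical transformer: raw text documents -> gensim bag-of-words corpus."""

    def fit(self, X, y=None):
        # X: iterable of raw text strings; whitespace tokenization kept deliberately simple
        self.id2word = Dictionary(doc.split() for doc in X)
        return self

    def transform(self, X):
        return [self.id2word.doc2bow(doc.split()) for doc in X]

With such a transformer, the pipeline below could be built as Pipeline([('bow', Text2BowTransformer()), ('features', model), ('classifier', clf)]) and fed data.data directly.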
        clf = linear_model.LogisticRegression(penalty='l2', C=0.1)
        text_lda = Pipeline((('features', model,), ('classifier', clf)))
        text_lda.fit(corpus, data.target)
        score = text_lda.score(corpus, data.target)
        self.assertGreater(score, 0.50)


if __name__ == '__main__':
    unittest.main()

Reviewer (tmylk): Please make the defaults explicit. What is the reason for using constant variables?

Author: @tmylk Thanks for pointing this out. I'll make the default values explicit here.
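
For illustration, the change agreed on above might look like this in the `__init__` signature (a sketch, not the committed code; the literals mirror the module constants P2_EXTRA_ITERS = 2 and P2_EXTRA_DIMS = 100):

def __init__(self, corpus=None, num_topics=200, id2word=None, chunksize=20000,
             decay=1.0, onepass=True, power_iters=2, extra_samples=100):
    # default values written out literally instead of referencing module-level constants
    ...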